#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Author: Benjamin Vial
# This file is part of nannos
# License: GPLv3
# See the documentation at nannos.gitlab.io

import pyvista  # isort: skip

pyvista.set_jupyter_backend("trame")
pyvista.set_plot_theme("document")
pyvista.global_theme.background = "white"
# pyvista.global_theme.window_size = [600, 400]
pyvista.global_theme.axes.show = True
# pyvista.global_theme.smooth_shading = True
# pyvista.global_theme.antialiasing = True
# pyvista.global_theme.axes.box = True

import os
import warnings

from packaging.version import parse as parse_version

from .__about__ import __author__, __description__, __version__
from .__about__ import data as _data
from .log import *

warnings.filterwarnings(
    action="ignore",
    category=RuntimeWarning,
    message=".*deadlock.",
)

available_backends = ["numpy", "scipy", "autograd"]


def print_info():
    print(f"nannos v{__version__}")
    print("=============")
    print(__description__)
    print(f"Author: {__author__}")
    print(f"Licence: {_data['License']}")


def has_torch():
    try:
        import torch

        return True
    except ModuleNotFoundError:
        return False


def has_jax():
    try:
        import jax

        return True
    except ModuleNotFoundError:
        return False


def _has_cuda():
    try:
        import torch

        return torch.cuda.is_available()
    except ModuleNotFoundError:
        return False


HAS_TORCH = has_torch()
HAS_CUDA = _has_cuda()
HAS_JAX = has_jax()

if HAS_TORCH:
    available_backends.append("torch")
if HAS_JAX:
    available_backends.append("jax")


def use_gpu(boolean):
    global DEVICE
    global _CPU_DEVICE
    global _GPU_DEVICE
    global _FORCE_GPU
    _FORCE_GPU = 1
    if boolean:
        if BACKEND not in ["torch", "jax"]:
            logger.debug(f"Cannot use GPU with {BACKEND} backend.")
            _delvar("_GPU_DEVICE")
            _CPU_DEVICE = True
        elif BACKEND == "torch" and not HAS_TORCH:
            logger.warning("pytorch not found. Cannot use GPU.")
            _delvar("_GPU_DEVICE")
            _CPU_DEVICE = True
        elif BACKEND == "torch" and not HAS_CUDA:
            logger.warning("cuda not found. Cannot use GPU.")
            _delvar("_GPU_DEVICE")
            _CPU_DEVICE = True
        elif BACKEND == "jax" and not HAS_JAX:
            logger.warning("jax not found. Cannot use GPU.")
            _delvar("_GPU_DEVICE")
            _CPU_DEVICE = True
        else:
            DEVICE = "cuda"
            logger.debug("Using GPU.")
            _delvar("_CPU_DEVICE")
            _GPU_DEVICE = True
    else:
        _CPU_DEVICE = True
        _delvar("_GPU_DEVICE")
        DEVICE = "cpu"
        logger.debug("Using CPU.")
    _reload_package()


def jit(fun, **kwargs):
    # if get_backend() == "jax":
    #     import jax
    #     return jax.jit(fun)
    return fun


def _delvar(VAR):
    if VAR in globals():
        del globals()[VAR]


def _del_vars(VARS):
    for VAR in VARS:
        _delvar(VAR)
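# Usage sketch (kept as a comment so nothing executes at import time): how a
# user would toggle GPU execution. ``use_gpu`` only takes effect with the
# "torch" or "jax" backends, and it reloads the package, so module-level
# names should be re-imported afterwards. The 'cuda' result below assumes a
# CUDA-capable device is present; otherwise a warning is logged and the
# device stays "cpu".
#
#     >>> import nannos
#     >>> nannos.set_backend("torch")  # a GPU-capable backend
#     >>> nannos.use_gpu(True)
#     >>> nannos.get_device()
#     'cuda'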
def set_backend(backend):
    """Set the numerical backend.

    Parameters
    ----------
    backend : str
        Either ``numpy``, ``scipy``, ``autograd``, ``torch`` or ``jax``.

    """
    global _NUMPY
    global _SCIPY
    global _AUTOGRAD
    global _JAX
    global _TORCH
    global _FORCE_BACKEND
    _FORCE_BACKEND = 1
    if backend == get_backend():
        return
    # _backend_env_var = os.environ.get("NANNOS_BACKEND")
    # if _backend_env_var is not None:
    #     if _backend_env_var in available_backends:
    #         if backend != _backend_env_var:
    #             # _delvar("_FORCE_BACKEND")
    #             pass
    #     else:
    #         backend = _backend_env_var
    if backend == "autograd":
        logger.debug("Setting autograd backend")
        _AUTOGRAD = True
        _del_vars(["_JAX", "_TORCH", "_SCIPY"])
    elif backend == "scipy":
        logger.debug("Setting scipy backend")
        _SCIPY = True
        _del_vars(["_JAX", "_TORCH", "_AUTOGRAD"])
    elif backend == "jax":
        logger.debug("Setting jax backend")
        _JAX = True
        _del_vars(["_SCIPY", "_TORCH", "_AUTOGRAD"])
    elif backend == "torch":
        _TORCH = True
        logger.debug("Setting torch backend")
        _del_vars(["_SCIPY", "_JAX", "_AUTOGRAD"])
    elif backend == "numpy":
        _NUMPY = True
        logger.debug("Setting numpy backend")
        _del_vars(["_SCIPY", "_JAX", "_AUTOGRAD", "_TORCH"])
    else:
        raise ValueError(
            f"Unknown backend '{backend}'. Please choose between 'numpy', "
            "'scipy', 'jax', 'torch' and 'autograd'."
        )
    _reload_package()
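# Usage sketch (as a comment, not executed at import): switching to a
# differentiable backend and taking a gradient with the ``grad`` wrapper
# defined below. ``set_backend`` reloads the package, so names imported from
# nannos before the switch should be re-imported afterwards. The printed
# value assumes autograd is installed.
#
#     >>> import nannos
#     >>> nannos.set_backend("autograd")
#     >>> from nannos import backend as bk, grad
#     >>> df = grad(lambda x: bk.sum(x**2))
#     >>> df(bk.array([1.0, 2.0]))
#     array([2., 4.])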
def _reload_package():
    import importlib
    import sys

    import nannos

    importlib.reload(nannos)
    its = [s for s in sys.modules.items() if s[0].startswith("nannos")]
    for k, v in its:
        importlib.reload(v)


def get_backend():
    if "_SCIPY" in globals():
        return "scipy"
    elif "_AUTOGRAD" in globals():
        return "autograd"
    elif "_JAX" in globals():
        return "jax"
    elif "_TORCH" in globals():
        return "torch"
    else:
        return "numpy"


def _grad(f):
    raise NotImplementedError(f"grad is not implemented for {BACKEND} backend.")


# Select the array/autodiff provider at import time, based on the sentinel
# globals set by set_backend().
if "_SCIPY" in globals():
    import numpy

    grad = _grad
    backend = numpy
elif "_AUTOGRAD" in globals():
    from autograd import grad, numpy

    backend = numpy
elif "_JAX" in globals():
    if HAS_JAX:
        import jax

        if parse_version(jax.__version__) >= parse_version("0.4.16"):
            jax.numpy.trapz = jax.scipy.integrate.trapezoid
        if DEVICE == "cpu":
            os.environ["CUDA_VISIBLE_DEVICES"] = ""
            jax.config.update("jax_platform_name", "cpu")
        else:
            # os.environ.pop("CUDA_VISIBLE_DEVICES", None)
            os.environ["CUDA_VISIBLE_DEVICES"] = "0"
            jax.config.update("jax_platform_name", "gpu")
        jax.config.update("jax_enable_x64", True)
        # TODO: jax eig not implemented on GPU
        # see https://github.com/google/jax/issues/1259
        # TODO: support jax properly (is it faster than autograd? use jit?)
        # jax does not support eig for autodiff wrt eigenvectors yet.
        # see: https://github.com/google/jax/issues/2748
        # from jax import grad, numpy
        from jax import numpy

        grad = _grad
        backend = numpy
    else:
        logger.warning("jax not found. Falling back to default numpy backend.")
        set_backend("numpy")
elif "_TORCH" in globals():
    if HAS_TORCH:
        import numpy
        import torch

        # torch.set_default_tensor_type(torch.cuda.FloatTensor)
        backend = torch

        def _array(a, **kwargs):
            # Move existing tensors (or new ones built from array-likes) to
            # the current device.
            if isinstance(a, backend.Tensor):
                return a.to(torch.device(DEVICE))
            else:
                return backend.tensor(a, **kwargs).to(torch.device(DEVICE))

        backend.array = _array

        def grad(f):
            def df(x, *args, **kwargs):
                x = backend.array(x, dtype=backend.float64)
                _x = x.clone().detach().requires_grad_(True)
                out = backend.autograd.grad(
                    f(_x, *args, **kwargs), _x, allow_unused=True
                )[0]
                return out

            return df

    else:
        logger.warning("pytorch not found. Falling back to default numpy backend.")
        set_backend("numpy")
else:
    import numpy

    grad = _grad
    backend = numpy


def get_device():
    return "cuda" if "_GPU_DEVICE" in globals() else "cpu"


def use_32_bit(yes):
    global _USE_32
    if yes:
        _USE_32 = True
    else:
        _delvar("_USE_32")
    _reload_package()


def get_types():
    return (
        (backend.float32, backend.complex64)
        if "_USE_32" in globals()
        else (backend.float64, backend.complex128)
    )


BACKEND = get_backend()
DEVICE = get_device()
FLOAT, COMPLEX = get_types()

_backend_env_var = os.environ.get("NANNOS_BACKEND")
_gpu_env_var = os.environ.get("NANNOS_GPU")

if (
    _backend_env_var in available_backends
    and _backend_env_var is not None
    and BACKEND != _backend_env_var
    and "_FORCE_BACKEND" not in globals()
):
    logger.debug(f"Found environment variable NANNOS_BACKEND={_backend_env_var}")
    set_backend(_backend_env_var)

if _gpu_env_var is not None and "_FORCE_GPU" not in globals():
    logger.debug(f"Found environment variable NANNOS_GPU={_gpu_env_var}")
    use_gpu(True)

from . import optimize
from .constants import *
from .excitation import *
from .lattice import *
from .parallel import *
from .sample import *
from .simulation import *
from .utils import *

array = backend.array
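# Usage sketch (comment only): the NANNOS_BACKEND and NANNOS_GPU environment
# variables read above select the backend/device before the first import,
# avoiding the package reload triggered by set_backend()/use_gpu(). Note
# that, as written, any set NANNOS_GPU value (even "0") triggers
# use_gpu(True). The shell line below is illustrative; "my_script.py" is a
# placeholder.
#
#     $ NANNOS_BACKEND=torch NANNOS_GPU=1 python my_script.py
#
# ``use_32_bit`` switches the module-level FLOAT/COMPLEX types via the same
# reload mechanism:
#
#     >>> import nannos
#     >>> nannos.use_32_bit(True)
#     >>> nannos.get_types()  # expected: the backend's (float32, complex64)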