Source code for delira._backends

import os
import json
from delira._version import get_versions as _get_versions

# to register new possible backends, they have to be added to this list.
# each backend should consist of a tuple of length 2 with the first entry
# being the package import name and the second being the backend abbreviation.
# E.g. TensorFlow's package is named 'tensorflow' but if the package is found,
# it will be considered as 'tf' later on
__POSSIBLE_BACKENDS = (("torch", "torch"),
                       ("tensorflow", "tf"),
                       ("chainer", "chainer"),
                       ("sklearn", "sklearn"))
__BACKENDS = ()


def _determine_backends():
    """
    Internal Helper Function to determine the currently valid backends by
    trying to import them. The valid backends are not returned, but appended
    to the global ``__BACKENDS`` variable

    """

    _config_file = __file__.replace("_backends.py", ".delira")
    # look for config file to determine backend
    # if file exists: load config into environment variables

    if not os.path.isfile(_config_file):
        _backends = {}
        # try to import all possible backends to determine valid backends

        import importlib
        for curr_backend in __POSSIBLE_BACKENDS:
            try:
                assert len(curr_backend) == 2
                assert all([isinstance(_tmp, str) for _tmp in curr_backend]), \
                    "All entries in current backend must be strings"

                # check if backend can be imported
                bcknd = importlib.util.find_spec(curr_backend[0])

                if bcknd is not None:
                    _backends[curr_backend[1]] = True
                else:
                    _backends[curr_backend[1]] = False
                del bcknd

            except ValueError:
                _backends[curr_backend[1]] = False

        with open(_config_file, "w") as f:
            json.dump({"version": _get_versions()['version'],
                       "backend": _backends},
                      f, sort_keys=True, indent=4)

        del _backends

    # set values from config file to variable and empty Backend-List before
    global __BACKENDS
    __BACKENDS = []
    with open(_config_file) as f:
        _config_dict = json.load(f)
    for key, val in _config_dict.pop("backend").items():
        if val:
            __BACKENDS.append(key.upper())
    del _config_dict

    del _config_file

    # make __BACKENDS non mutable
    __BACKENDS = tuple(__BACKENDS)


[docs]def get_backends(): """ Return List of currently available backends Returns ------- list list of strings containing the currently installed backends """ global __BACKENDS if not __BACKENDS: _determine_backends() return __BACKENDS
def seed_all(seed): """ Helper Function to seed all available backends Parameters ---------- seed : int the new random seed """ import sys import numpy as np np.random.seed(seed) import random random.seed = seed if "torch" in sys.modules and "TORCH" in get_backends(): import torch torch.random.manual_seed(seed) elif "tensorflow" in sys.modules and "TF" in get_backends(): import tensorflow as tf tf.random.set_random_seed(seed) elif "chainer" in sys.modules and "CHAINER" in get_backends(): try: import cupy cupy.random.seed(seed) except ImportError: pass