Source code for delira.utils.decorators

import numpy as np
from functools import wraps
import warnings


[docs]def dtype_func(class_object): """ Decorator to Check whether the first argument of the decorated function is of a certain type Parameters ---------- class_object : Any type the first function argument should have Returns ------- Wrapped Function Raises ------ AssertionError First argument of decorated function is not of given type """ def instance_checker(func): @wraps(func) def func_wrapper(checked_object, *args, **kwargs): assertion_str = "Argument 1 is not of type %s but of type %s" % \ (class_object.__name__, checked_object.__class__.__name__) assert isinstance(checked_object, class_object), assertion_str return func(checked_object, *args, **kwargs) return func_wrapper return instance_checker
[docs]def classtype_func(class_object): """ Decorator to Check whether the first argument of the decorated function is a subclass of a certain type Parameters ---------- class_object : Any type the first function argument should be subclassed from Returns ------- Wrapped Function Raises ------ AssertionError First argument of decorated function is not a subclass of given type """ def subclass_checker(func): @wraps(func) def func_wrapper(checked_object, *args, **kwargs): assertion_str = "Argument 1 is not subclass of %s but of type %s" \ % (class_object.__name__, checked_object.__name__) assert issubclass(checked_object, class_object), assertion_str return func(checked_object, *args, **kwargs) return func_wrapper return subclass_checker
[docs]def make_deprecated(new_func): """ Decorator which raises a DeprecationWarning for the decorated object Parameters ---------- new_func : Any new function which should be used instead of the decorated one Returns ------- Wrapped Function Raises ------ Deprecation Warning """ def deprecation(func): @wraps(func) def func_wrapper(*args, **kwargs): warnings.warn(DeprecationWarning("%s is deprecated in favor of %s" " and will be removed at next " "release" % (func.__name__, new_func.__name__))) return func(*args, **kwargs) return func_wrapper return deprecation
numpy_array_func = dtype_func(np.ndarray) import os if "torch" in os.environ["DELIRA_BACKEND"]: import torch torch_tensor_func = dtype_func(torch.Tensor) torch_module_func = dtype_func(torch.nn.Module)