import numpy as np
from delira import get_backends
from ..utils.decorators import dtype_func
def _check_and_correct_zero_shape(arg):
"""
Corrects the shape of numpy array to be at least 1d and returns the
argument as is otherwise
Parameters
----------
arg : Any
the argument which must be corrected in its shape if it's
zero-dimensional
Returns
-------
Any
argument (shape corrected if necessary)
"""
if isinstance(arg, np.ndarray) and arg.shape == ():
arg = arg.reshape(1)
return arg
[docs]def convert_batch_to_numpy_identity(*args, **kwargs):
"""
Corrects the shape of all zero-sized numpy arrays to be at least 1d
Parameters
----------
*args :
positional arguments of potential arrays to be corrected
**kwargs :
keyword arguments of potential arrays to be corrected
Returns
-------
"""
args = list(args)
for idx, arg in args:
args[idx] = _check_and_correct_zero_shape(arg)
for key, val in kwargs.items():
kwargs[key] = _check_and_correct_zero_shape(val)
return args, kwargs
if "TORCH" in get_backends():
    # Everything in this branch depends on torch and is therefore only
    # defined when the "TORCH" backend is reported by get_backends().
    import torch
    from ..utils.decorators import torch_module_func

    # NOTE(review): dtype_func comes from ..utils.decorators and is not
    # visible here -- presumably it validates/casts the argument to
    # ``float``; confirm against its definition.
    @dtype_func(float)
    def float_to_pytorch_tensor(f: float):
        """
        Converts a single float to a PyTorch Float-Tensor

        Parameters
        ----------
        f : float
            float to convert

        Returns
        -------
        torch.Tensor
            one-element tensor created from a ``np.float32`` array
            holding ``f``

        """
        return torch.from_numpy(np.array([f], dtype=np.float32))

    # NOTE(review): torch_module_func is defined in ..utils.decorators;
    # its effect on the wrapped functions cannot be seen from this file.
    @torch_module_func
    def create_optims_default_pytorch(model, optim_cls, **optim_params):
        """
        Function to create a optimizer dictionary
        (in this case only one optimizer for the whole network)

        Parameters
        ----------
        model : :class:`AbstractPyTorchNetwork`
            model whose parameters should be updated by the optimizer
        optim_cls :
            Class implementing an optimization algorithm
        **optim_params :
            Additional keyword arguments (passed to the optimizer class)

        Returns
        -------
        dict
            dictionary containing all created optimizers; the single
            optimizer is stored under the key ``"default"``

        """
        return {"default": optim_cls(model.parameters(), **optim_params)}

    @torch_module_func
    def create_optims_gan_default_pytorch(model, optim_cls, **optim_params):
        """
        Function to create a optimizer dictionary
        (in this case two optimizers: One for the generator,
        one for the discriminator)

        Parameters
        ----------
        model : :class:`AbstractPyTorchNetwork`
            model whose parameters should be updated by the optimizer;
            must expose the sub-modules ``gen`` and ``discr``
        optim_cls :
            Class implementing an optimization algorithm
        **optim_params :
            Additional keyword arguments (passed to the optimizer class);
            the same arguments are used for both optimizers

        Returns
        -------
        dict
            dictionary containing all created optimizers (keys ``"gen"``
            and ``"discr"``)

        """
        return {"gen": optim_cls(model.gen.parameters(), **optim_params),
                "discr": optim_cls(model.discr.parameters(),
                                   **optim_params)}

    def convert_torch_tensor_to_npy(*args, **kwargs):
        """
        Function to convert all torch Tensors to numpy arrays and
        reshape zero-size tensors

        Parameters
        ----------
        *args :
            arbitrary positional arguments
        **kwargs :
            arbitrary keyword arguments

        Returns
        -------
        Iterable
            all given positional arguments (converted if necessary)
        dict
            all given keyword arguments (converted if necessary)

        """
        # Non-tensor arguments must be kept (not filtered out), otherwise
        # the positional arguments returned to the caller would silently
        # shrink and shift.
        args = [_arg.detach().cpu().numpy()
                if isinstance(_arg, torch.Tensor) else _arg
                for _arg in args]
        kwargs = {key: val.detach().cpu().numpy()
                  if isinstance(val, torch.Tensor) else val
                  for key, val in kwargs.items()}

        # zero-dimensional results are promoted to 1d by the shared helper
        return convert_batch_to_numpy_identity(*args, **kwargs)
if "TF" in get_backends():
    # Helpers in this branch are only defined when the "TF" backend is
    # reported by get_backends().
    import tensorflow as tf

    def create_optims_default_tf(optim_cls, **optim_params):
        """
        Build the optimizer dictionary for a setup with a single optimizer

        Parameters
        ----------
        optim_cls :
            Class implementing an optimization algorithm
        **optim_params :
            Additional keyword arguments (passed to the optimizer class)

        Returns
        -------
        dict
            dictionary holding the created optimizer under the key
            ``"default"``

        """
        optimizer = optim_cls(**optim_params)
        return {"default": optimizer}

    def initialize_uninitialized(sess):
        """
        Run the initializers of all global variables which have not been
        initialized yet; already initialized variables are left alone

        Parameters
        ----------
        sess : tf.Session()
            session used both to query the initialization state and to
            run the initializer op

        """
        variables = tf.global_variables()

        # one boolean per variable: True if it is already initialized
        init_checks = [tf.is_variable_initialized(var) for var in variables]
        init_flags = sess.run(init_checks)

        pending = [var for var, is_init in zip(variables, init_flags)
                   if not is_init]

        # nothing to do if every variable is initialized already
        if not pending:
            return
        sess.run(tf.variables_initializer(pending))

    def convert_tf_tensor_to_npy(*args, **kwargs):
        """
        Counterpart of the torch conversion function for the tf backend

        No tf-specific conversion is performed here: all arguments are
        forwarded as they are to :func:`convert_batch_to_numpy_identity`,
        which only reshapes zero-dimensional numpy arrays.

        Parameters
        ----------
        *args :
            arbitrary positional arguments
        **kwargs :
            arbitrary keyword arguments

        Returns
        -------
        Iterable
            all given positional arguments (reshaped if necessary)
        dict
            all given keyword arguments (reshaped if necessary)

        """
        converted = convert_batch_to_numpy_identity(*args, **kwargs)
        return converted