TensorFlow Eager Execution
AbstractTfEagerNetwork
- class AbstractTfEagerNetwork(data_format='channels_first', trainable=True, name=None, dtype=None, **kwargs)
Bases: delira.models.abstract_network.AbstractNetwork, tensorflow.keras.layers.Layer
Abstract network for the TF eager execution backend. All models used with this backend should be derived from this class.
- _init_kwargs = {}
- abstract call(*args, **kwargs)
Defines the model’s forward pass.
- Parameters
*args – arbitrary positional arguments
**kwargs – arbitrary keyword arguments
- Raises
NotImplementedError – If not overwritten by subclass
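Subclassing follows the usual abstract-base-class pattern: the base leaves `call` abstract and each model supplies its own forward pass. The minimal sketch below uses only the standard library, so `SketchNetwork` and `AffineNetwork` are hypothetical stand-ins that do not derive from `AbstractNetwork` or `tf.keras.layers.Layer`; returning predictions as a dict under the key `"pred"` is likewise an assumption.

```python
import abc


class SketchNetwork(abc.ABC):
    """Hypothetical stand-in for AbstractTfEagerNetwork (no TensorFlow involved)."""

    @abc.abstractmethod
    def call(self, *args, **kwargs):
        # The base class defines no forward pass.
        raise NotImplementedError()

    def __call__(self, *args, **kwargs):
        # Dispatch to the forward pass, similar to how a Keras layer's __call__ invokes call.
        return self.call(*args, **kwargs)


class AffineNetwork(SketchNetwork):
    """Hypothetical subclass whose forward pass computes weight * x + bias per element."""

    def __init__(self, weight=2.0, bias=1.0):
        self.weight = weight
        self.bias = bias

    def call(self, data):
        # Return predictions in a dict (the "pred" key is an assumption).
        return {"pred": [self.weight * x + self.bias for x in data]}


model = AffineNetwork()
print(model([0.0, 1.0, 2.0]))  # {'pred': [1.0, 3.0, 5.0]}
```

Because `call` is declared abstract, Python refuses to instantiate `SketchNetwork` directly, so a forgotten override surfaces as soon as the class is instantiated.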
- static closure(model, data_dict: dict, optimizers: Dict[str, tensorflow.train.Optimizer], losses={}, metrics={}, fold=0, **kwargs)
Function that handles prediction from a batch, logging, loss calculation and the optimizer step.
- Parameters
model (AbstractNetwork) – model to forward data through
data_dict (dict) – dictionary containing the data
optimizers (dict) – dictionary containing all optimizers to perform parameter update
losses (dict) – Functions or classes to calculate losses
metrics (dict) – Functions or classes to calculate other metrics
fold (int) – current fold in cross-validation (default: 0)
kwargs (dict) – additional keyword arguments
- Returns
dict – Metric values (with same keys as input dict metrics)
dict – Loss values (with same keys as input dict losses)
dict – Arbitrary number of predictions
- Raises
NotImplementedError – If not overwritten by subclass
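The steps the closure performs (predict, compute losses and metrics, update parameters, log) can be sketched in plain Python. This is not delira's implementation: `ScalarModel`, `SGD`, `mse`, `mae`, the `"default"` optimizer key and the `"data"`/`"label"` batch keys are all hypothetical, and a central finite-difference gradient stands in for TensorFlow's automatic differentiation. `None` replaces the `{}` defaults so that no mutable default is shared between calls.

```python
import logging

logger = logging.getLogger(__name__)


class ScalarModel:
    """Hypothetical one-parameter model: pred = w * x."""

    def __init__(self, w=0.0):
        self.w = w

    def __call__(self, data, w=None):
        # Allow evaluating at a trial weight without changing self.w.
        w = self.w if w is None else w
        return {"pred": [w * x for x in data]}


class SGD:
    """Hypothetical plain gradient-descent optimizer."""

    def __init__(self, lr=0.1):
        self.lr = lr

    def step(self, model, grad):
        model.w -= self.lr * grad


def mse(pred, label):
    return sum((p - y) ** 2 for p, y in zip(pred, label)) / len(label)


def mae(pred, label):
    return sum(abs(p - y) for p, y in zip(pred, label)) / len(label)


def closure(model, data_dict, optimizers, losses=None, metrics=None, fold=0, **kwargs):
    losses = losses or {}
    metrics = metrics or {}
    data, label = data_dict["data"], data_dict["label"]

    # Prediction from the batch.
    preds = model(data)

    # Loss and metric values, keyed exactly like the input dicts.
    loss_vals = {key: fn(preds["pred"], label) for key, fn in losses.items()}
    metric_vals = {key: fn(preds["pred"], label) for key, fn in metrics.items()}

    # Optimizer step on the summed losses, using a central finite difference
    # as a stand-in for automatic differentiation.
    if losses and "default" in optimizers:
        eps = 1e-6

        def total_loss(w):
            trial = model(data, w=w)["pred"]
            return sum(fn(trial, label) for fn in losses.values())

        grad = (total_loss(model.w + eps) - total_loss(model.w - eps)) / (2 * eps)
        optimizers["default"].step(model, grad)

    # Logging.
    logger.info("fold %d losses: %s", fold, loss_vals)
    return metric_vals, loss_vals, preds


model = ScalarModel(w=0.0)
batch = {"data": [1.0, 2.0], "label": [2.0, 4.0]}
metric_vals, loss_vals, preds = closure(
    model, batch, {"default": SGD(lr=0.1)}, losses={"mse": mse}, metrics={"mae": mae}
)
print(loss_vals, metric_vals)  # {'mse': 10.0} {'mae': 3.0}
print(round(model.w, 4))  # 1.0
```

The returned loss and metric values describe the predictions made before the update, and their keys mirror the `losses` and `metrics` dicts, matching the documented return values.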
- property init_kwargs
Returns all arguments registered as init kwargs
- Returns
init kwargs
- Return type
dict
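One plausible use of registered init kwargs is re-creating a model with the same configuration. The following sketch is hypothetical (`RegistersInitKwargs` and `ToyNet` are not part of delira) and only illustrates the pattern of storing constructor arguments and exposing them through a read-only property. It copies the arguments into a per-instance dict, because a mutable class attribute such as `_init_kwargs = {}` would be shared by all instances if it were mutated in place.

```python
class RegistersInitKwargs:
    """Hypothetical base class that records its constructor keyword arguments."""

    def __init__(self, **kwargs):
        # A per-instance dict, so no two models share the same registry.
        self._init_kwargs = dict(kwargs)

    @property
    def init_kwargs(self):
        # Read-only access to the registered arguments.
        return self._init_kwargs


class ToyNet(RegistersInitKwargs):
    """Hypothetical model whose hyperparameters are registered on construction."""

    def __init__(self, in_channels=1, n_classes=10):
        super().__init__(in_channels=in_channels, n_classes=n_classes)
        self.in_channels = in_channels
        self.n_classes = n_classes


net = ToyNet(in_channels=3)
print(net.init_kwargs)  # {'in_channels': 3, 'n_classes': 10}

# The registered kwargs are enough to build an identically configured model.
clone = ToyNet(**net.init_kwargs)
print(clone.in_channels, clone.n_classes)  # 3 10
```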
DataParallelTfEagerNetwork
- class DataParallelTfEagerNetwork(module, devices)
Bases: delira.models.backends.tf_eager.abstract_network.AbstractTfEagerNetwork
Data-parallel module for the TF eager execution backend
Warning
This module is highly experimental and is not guaranteed to work properly!
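Data parallelism generally follows a scatter/apply/gather pattern: the batch is split into one chunk per device, the module processes each chunk, and the outputs are concatenated in the original order. The sketch below shows only that general pattern in plain Python and runs the chunks sequentially; `DataParallelSketch` and `scatter` are hypothetical and do not reproduce the internals of `DataParallelTfEagerNetwork`, beyond taking the same `module` and `devices` arguments.

```python
def scatter(batch, n_chunks):
    """Split a list into n_chunks contiguous chunks whose sizes differ by at most one."""
    size, rest = divmod(len(batch), n_chunks)
    chunks, start = [], 0
    for i in range(n_chunks):
        # The first `rest` chunks receive one extra element.
        stop = start + size + (1 if i < rest else 0)
        chunks.append(batch[start:stop])
        start = stop
    return chunks


class DataParallelSketch:
    """Hypothetical stand-in taking the same (module, devices) arguments."""

    def __init__(self, module, devices):
        self.module = module
        self.devices = list(devices)

    def __call__(self, data):
        outputs = []
        chunks = scatter(data, len(self.devices))
        for _device, chunk in zip(self.devices, chunks):
            # A real implementation would run this chunk on `_device`;
            # here every chunk runs sequentially on the host.
            outputs.extend(self.module(chunk))
        # Gather: outputs keep the original batch order.
        return outputs


def double(xs):
    return [2 * x for x in xs]


parallel = DataParallelSketch(double, devices=["/gpu:0", "/gpu:1"])
print(parallel([1, 2, 3, 4, 5]))  # [2, 4, 6, 8, 10]
```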
- _init_kwargs = {}
- call(*args, **kwargs)
Defines the forward pass of the module
- Parameters
*args – arbitrary positional arguments
**kwargs – arbitrary keyword arguments
- property closure
Function that handles prediction from a batch, logging, loss calculation and the optimizer step.
- Parameters
model (AbstractNetwork) – model to forward data through
data_dict (dict) – dictionary containing the data
optimizers (dict) – dictionary containing all optimizers to perform parameter update
losses (dict) – Functions or classes to calculate losses
metrics (dict) – Functions or classes to calculate other metrics
fold (int) – current fold in cross-validation (default: 0)
kwargs (dict) – additional keyword arguments
- Returns
dict – Metric values (with same keys as input dict metrics)
dict – Loss values (with same keys as input dict losses)
dict – Arbitrary number of predictions
- Raises
NotImplementedError – If not overwritten by subclass
- property init_kwargs
Returns all arguments registered as init kwargs
- Returns
init kwargs
- Return type
dict
- property prepare_batch
Helper function to prepare network inputs and labels (convert them to the correct type and shape and push them to the correct devices).
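A minimal sketch of two of the steps named above, type conversion and device placement, in plain Python. The function name, its `batch`/`input_device`/`output_device` parameters, the `"data"` key for network inputs, and representing a device as a string tag are all assumptions; this sketch also omits reshaping and uses plain lists where the backend would work with tensors.

```python
def prepare_batch(batch, input_device, output_device):
    """Hypothetical helper: cast every entry to floats and pick its target device."""
    prepared = {}
    for key, values in batch.items():
        # Convert to the expected type: a list of Python floats.
        as_float = [float(v) for v in values]
        # Network inputs go to the input device; labels go to the output device.
        device = input_device if key == "data" else output_device
        prepared[key] = {"device": device, "values": as_float}
    return prepared


batch = {"data": [1, 2, 3], "label": [0, 1, 1]}
prepared = prepare_batch(batch, input_device="/gpu:0", output_device="/cpu:0")
print(prepared["data"])   # {'device': '/gpu:0', 'values': [1.0, 2.0, 3.0]}
print(prepared["label"])  # {'device': '/cpu:0', 'values': [0.0, 1.0, 1.0]}
```

Splitting inputs and labels across two devices mirrors the usual layout where the network runs on an accelerator while losses are gathered on a single output device.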