TensorFlow Graph Execution

AbstractTfGraphNetwork

class AbstractTfGraphNetwork(sess=tensorflow.Session, **kwargs)

Bases: delira.models.abstract_network.AbstractNetwork

Abstract class for TensorFlow networks that use graph (session-based) execution.

See also

AbstractNetwork

_add_losses(losses: dict)

Adds losses to the model that are to be used by optimizers or during evaluation. Can be overridden for more advanced loss behavior.

Parameters

losses (dict) – dictionary containing all losses. Individual losses are averaged
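The averaging of individual losses can be illustrated without TensorFlow. The following is a minimal pure-Python sketch, assuming every loss has already been evaluated to a scalar; the helper name average_losses and the "total" key are hypothetical, and in the actual graph network the reduction would be performed on loss tensors rather than on floats.

```python
from typing import Dict


def average_losses(losses: Dict[str, float]) -> Dict[str, float]:
    """Hypothetical helper: return the individual losses plus their mean.

    Assumes every value is an already-evaluated scalar loss; a graph
    network would average loss tensors inside the TensorFlow graph instead.
    """
    if not losses:
        raise ValueError("at least one loss is required")
    result = dict(losses)
    # Individual losses are averaged (not summed) into one total value
    result["total"] = sum(losses.values()) / len(losses)
    return result


print(average_losses({"ce": 0.5, "l1": 1.5}))
# → {'ce': 0.5, 'l1': 1.5, 'total': 1.0}
```

Averaging rather than summing keeps the scale of the total independent of how many losses are registered.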

_add_optims(optims: dict)

Adds optimizers to the model that are to be used during training. Can be overridden for more advanced optimizer behavior.

Parameters

optims (dict) – dictionary containing all optimizers; optimizers should be of Type[tf.train.Optimizer]

_init_kwargs = {}
static closure(model, data_dict: dict, optimizers: dict, losses={}, metrics={}, fold=0, **kwargs)

Default closure method to perform a single training step. Can be overridden for more advanced models.

Parameters
  • model (AbstractTfGraphNetwork) – trainable model

  • data_dict (dict) – dictionary containing the data

  • optimizers (dict) – dictionary of optimizers to optimize model’s parameters; ignored here, just passed for compatibility reasons

  • losses (dict) – dict holding the losses to calculate errors; ignored here, just passed for compatibility reasons

  • metrics (dict) – dict holding the metrics to calculate

  • fold (int) – current fold in cross-validation (default: 0)

  • **kwargs – additional keyword arguments

Returns

  • dict – Metric values (with the same keys as the input dict metrics)

  • dict – Loss values (with the same keys as the input dict losses; always empty here)

  • dict – dictionary containing all predictions
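The closure contract described above (metrics computed, losses and optimizers ignored, an empty loss dict returned) can be sketched in plain Python. This is a minimal illustration, not the library's implementation: StubModel, closure_sketch and accuracy are hypothetical names, and the "data"/"label" keys of data_dict are an assumption. The sketch uses None defaults instead of the mutable {} defaults shown in the documented signature.

```python
from typing import Any, Callable, Dict, Optional, Tuple


class StubModel:
    """Hypothetical stand-in for a graph network: run() returns predictions."""

    def run(self, **inputs: Any) -> Dict[str, Any]:
        # Pretend the graph predicts its input unchanged
        return {"pred": list(inputs["data"])}


def closure_sketch(model, data_dict: dict, optimizers: dict,
                   losses: Optional[Dict[str, Callable]] = None,
                   metrics: Optional[Dict[str, Callable]] = None,
                   fold: int = 0, **kwargs) -> Tuple[dict, dict, dict]:
    # optimizers and losses are accepted only for interface compatibility
    metrics = metrics or {}
    preds = model.run(data=data_dict["data"], label=data_dict["label"])
    # Each metric is evaluated on the predictions and the labels
    metric_vals = {name: fn(preds["pred"], data_dict["label"])
                   for name, fn in metrics.items()}
    # Loss values are returned as an empty dict, as documented above
    return metric_vals, {}, preds


def accuracy(pred, label):
    return sum(p == t for p, t in zip(pred, label)) / len(label)


metric_vals, loss_vals, preds = closure_sketch(
    StubModel(), {"data": [1, 0, 1, 1], "label": [1, 1, 1, 1]},
    optimizers={}, metrics={"acc": accuracy})
print(metric_vals, loss_vals)  # → {'acc': 0.75} {}
```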

property init_kwargs

Returns all arguments registered as init kwargs

Returns

init kwargs

Return type

dict
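A minimal sketch of how init kwargs could be registered and exposed through a read-only property. InitKwargsSketch is a hypothetical class; it stores the dict per instance (rather than relying on a shared class-level _init_kwargs = {}) and returns a copy, both of which are design choices of this sketch and not necessarily what the documented property does. Keeping such a record makes it possible, for example, to rebuild an equivalent instance.

```python
from typing import Any, Dict


class InitKwargsSketch:
    """Hypothetical class showing how init kwargs could be registered."""

    def __init__(self, **kwargs: Any) -> None:
        # Per-instance dict, so instances never share registered kwargs
        self._init_kwargs: Dict[str, Any] = {}
        for key, value in kwargs.items():
            self._init_kwargs[key] = value

    @property
    def init_kwargs(self) -> Dict[str, Any]:
        # Return a copy so callers cannot mutate the registered kwargs
        return dict(self._init_kwargs)


net = InitKwargsSketch(in_channels=3, n_classes=10)
print(net.init_kwargs)  # → {'in_channels': 3, 'n_classes': 10}

# Re-create an equivalent instance from the registered arguments
clone = InitKwargsSketch(**net.init_kwargs)
```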

static prepare_batch(batch: dict, input_device, output_device)

Helper function to prepare network inputs and labels (convert them to the correct type and shape and push them to the correct devices)

Parameters
  • batch (dict) – dictionary containing all the data

  • input_device (Any) – device for module inputs (will be ignored here; just given for compatibility)

  • output_device (Any) – device for module outputs (will be ignored here; just given for compatibility)

Returns

dictionary containing the data in the correct type and shape and on the correct device

Return type

dict
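A minimal sketch of a batch-preparation helper with the documented call signature. The target dtype (float32 for every entry) is an assumption made for illustration, as is the helper name prepare_batch_sketch; real code might need integer labels, for instance. As in the documented method, both device arguments are ignored and kept only for interface compatibility.

```python
from typing import Any, Dict

import numpy as np


def prepare_batch_sketch(batch: Dict[str, Any], input_device: Any = None,
                         output_device: Any = None) -> Dict[str, np.ndarray]:
    """Hypothetical prepare_batch: cast every entry to a float32 array.

    input_device and output_device are ignored; they exist only so the
    signature matches the documented method.
    """
    # np.asarray keeps the nested-list shape and applies the assumed dtype
    return {key: np.asarray(value, dtype=np.float32)
            for key, value in batch.items()}


batch = {"data": [[0, 1], [2, 3]], "label": [1, 0]}
prepared = prepare_batch_sketch(batch, input_device=None, output_device=None)
print(prepared["data"].dtype, prepared["data"].shape)  # → float32 (2, 2)
```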

run(*args, **kwargs)

Evaluates self.outputs_train or self.outputs_eval, depending on the value of self.training

Parameters
  • *args – currently unused, exist for compatibility reasons

  • **kwargs – values to feed into self.inputs; the keys must match those of self.inputs

Returns

the same keys as outputs_train or outputs_eval, with the evaluated expressions as values

Return type

dict
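The dispatch performed by run can be sketched without TensorFlow. In this minimal illustration, FakeSession and GraphNetworkSketch are hypothetical, the "placeholders" are plain strings and each "tensor" is a Python callable of the feed dict. The fake session returns a dict with the same keys as its fetches, much like tf.Session.run does when given a dict of fetches.

```python
from typing import Any, Callable, Dict


class FakeSession:
    """Hypothetical session stand-in that evaluates a dict of fetches."""

    def run(self, fetches: Dict[str, Callable],
            feed_dict: Dict[str, Any]) -> Dict[str, Any]:
        # Same keys as the fetches, evaluated values as the values
        return {name: fn(feed_dict) for name, fn in fetches.items()}


class GraphNetworkSketch:
    def __init__(self, sess: FakeSession) -> None:
        self._sess = sess
        self.training = True
        # Hypothetical placeholders, identified by name only
        self.inputs = {"data": "data_ph"}
        self.outputs_train = {"pred": lambda feed: [2 * x for x in feed["data_ph"]]}
        self.outputs_eval = {"pred": lambda feed: [x + 1 for x in feed["data_ph"]]}

    def run(self, *args: Any, **kwargs: Any) -> Dict[str, Any]:
        # *args is unused, mirroring the documented signature
        unknown = set(kwargs) - set(self.inputs)
        if unknown:
            raise KeyError(f"unknown input keys: {sorted(unknown)}")
        # Map each keyword to its placeholder to build the feed dict
        feed = {self.inputs[key]: value for key, value in kwargs.items()}
        # self.training selects which set of outputs is evaluated
        fetches = self.outputs_train if self.training else self.outputs_eval
        return self._sess.run(fetches, feed_dict=feed)


net = GraphNetworkSketch(FakeSession())
print(net.run(data=[1, 2]))  # → {'pred': [2, 4]}
net.training = False
print(net.run(data=[1, 2]))  # → {'pred': [2, 3]}
```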