Generative Adversarial Networks

GenerativeAdversarialNetworkBasePyTorch

class GenerativeAdversarialNetworkBasePyTorch(n_channels, noise_length, **kwargs)

Bases: delira.models.abstract_network.AbstractPyTorchNetwork

Implementation of a vanilla DCGAN that creates 64x64 pixel images.

Notes

The fully connected part of the discriminator has been replaced with an equivalent convolutional part.
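The replacement is equivalent because a fully connected layer applied to a flattened C×4×4 feature map computes the same weighted sum as a convolution whose kernel spans the whole 4×4 map. The sketch below checks this numerically with PyTorch. The channel count (8) and the 4×4 spatial size are illustrative assumptions, not values read from this class.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

channels, size = 8, 4  # illustrative final feature-map shape (assumption)
x = torch.randn(2, channels, size, size)  # batch of 2 feature maps

# Fully connected head: flatten, then a single linear output per sample
linear = nn.Linear(channels * size * size, 1)

# Convolutional head: the kernel covers the entire 4x4 map, so the output is 1x1
conv = nn.Conv2d(channels, 1, kernel_size=size)

# Copy the linear weights into the conv kernel so both compute the same sum
with torch.no_grad():
    conv.weight.copy_(linear.weight.view(1, channels, size, size))
    conv.bias.copy_(linear.bias)

fc_out = linear(x.flatten(start_dim=1))   # shape (2, 1)
conv_out = conv(x).flatten(start_dim=1)   # (2, 1, 1, 1) flattened to (2, 1)

print(torch.allclose(fc_out, conv_out, atol=1e-5))  # True
```

Because the flattening order (channel, row, column) matches the kernel layout, only a reshape of the weights is needed; the two heads differ in bookkeeping, not in what they compute.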

References

https://arxiv.org/abs/1511.06434

See also

AbstractPyTorchNetwork

static _build_models(in_channels, noise_length, **kwargs)

Builds the actual generator and discriminator models.

Parameters:
  • in_channels (int) – number of channels of the images produced by the generator and fed to the discriminator
  • noise_length (int) – length of noise vector (generator input)
  • **kwargs – additional keyword arguments
Returns:

  • torch.nn.Sequential – generator
  • torch.nn.Sequential – discriminator
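The 64x64 output size follows from the output-size formulas for transposed and regular convolutions. The sketch below traces the spatial size through both networks, assuming the layer configuration of PyTorch's DCGAN example (kernel size 4, stride 2, padding 1 for the up- and downsampling layers). It illustrates the arithmetic only and is not the exact layer list built by _build_models.

```python
def conv_transpose_out(size, kernel, stride, padding):
    """Output size of nn.ConvTranspose2d (no dilation, no output_padding)."""
    return (size - 1) * stride - 2 * padding + kernel


def conv_out(size, kernel, stride, padding):
    """Output size of nn.Conv2d (no dilation)."""
    return (size + 2 * padding - kernel) // stride + 1


# Generator: the noise vector is treated as a 1x1 "image" with noise_length channels
gen_sizes = [1]
gen_sizes.append(conv_transpose_out(gen_sizes[-1], kernel=4, stride=1, padding=0))  # 1 -> 4
for _ in range(4):  # four upsampling layers: 4 -> 8 -> 16 -> 32 -> 64
    gen_sizes.append(conv_transpose_out(gen_sizes[-1], kernel=4, stride=2, padding=1))

# Discriminator mirrors it: four downsampling layers, then a 4x4 convolution
# that plays the role of the fully connected layer and yields one score
discr_sizes = [64]
for _ in range(4):  # 64 -> 32 -> 16 -> 8 -> 4
    discr_sizes.append(conv_out(discr_sizes[-1], kernel=4, stride=2, padding=1))
discr_sizes.append(conv_out(discr_sizes[-1], kernel=4, stride=1, padding=0))  # 4 -> 1

print(gen_sizes)    # [1, 4, 8, 16, 32, 64]
print(discr_sizes)  # [64, 32, 16, 8, 4, 1]
```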

_init_kwargs = {}

static closure(model, data_dict: dict, optimizers: dict, criterions={}, metrics={}, fold=0, **kwargs)

Closure method that performs a single backpropagation step.

Parameters:
  • model (GenerativeAdversarialNetworkBasePyTorch) – trainable model
  • data_dict (dict) – dictionary containing the data
  • optimizers (dict) – dictionary of optimizers to optimize the model’s parameters
  • criterions (dict) – dict holding the criterions to calculate errors (gradients from different criterions will be accumulated)
  • metrics (dict) – dict holding the metrics to calculate
  • fold (int) – current fold in cross-validation (default: 0)
  • kwargs (dict) – additional keyword arguments
Returns:

  • dict – Metric values (with same keys as input dict metrics)
  • dict – Loss values (with same keys as input dict criterions)
  • list – Arbitrary number of predictions as torch.Tensor

Raises:

AssertionError – if optimizers or criterions are empty, or if the required optimizers are not specified
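For a GAN, a single training step usually updates both networks. The sketch below is not delira's implementation; it only illustrates the common order of a DCGAN step (discriminator update, then generator update) in plain PyTorch. The tiny fully connected networks and the dict keys "gen", "discr" and "adv" are hypothetical choices; the Adam settings are the ones recommended in the DCGAN paper.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical tiny networks operating on flat 16-dimensional "images"
noise_length, img_dim = 8, 16
generator = nn.Sequential(nn.Linear(noise_length, img_dim), nn.Tanh())
discriminator = nn.Sequential(nn.Linear(img_dim, 1), nn.Sigmoid())

# Hypothetical dict keys: "gen"/"discr" for optimizers, "adv" for the criterion
optimizers = {
    "gen": torch.optim.Adam(generator.parameters(), lr=2e-4, betas=(0.5, 0.999)),
    "discr": torch.optim.Adam(discriminator.parameters(), lr=2e-4, betas=(0.5, 0.999)),
}
criterions = {"adv": nn.BCELoss()}


def gan_step(real_images):
    """One adversarial update: discriminator first, then generator."""
    batch_size = real_images.size(0)
    ones = torch.ones(batch_size, 1)
    zeros = torch.zeros(batch_size, 1)
    adv = criterions["adv"]

    noise = torch.randn(batch_size, noise_length)
    fake_images = generator(noise)

    # Discriminator: real -> 1, fake -> 0; detach so the generator gets no gradients
    optimizers["discr"].zero_grad()
    loss_discr = (adv(discriminator(real_images), ones)
                  + adv(discriminator(fake_images.detach()), zeros))
    loss_discr.backward()
    optimizers["discr"].step()

    # Generator: push the (updated) discriminator towards predicting 1 on fakes
    optimizers["gen"].zero_grad()
    loss_gen = adv(discriminator(fake_images), ones)
    loss_gen.backward()
    optimizers["gen"].step()

    return {"discr": loss_discr.item(), "gen": loss_gen.item()}, fake_images.detach()


losses, fakes = gan_step(torch.randn(4, img_dim))
print(losses, fakes.shape)
```

Detaching the fake images during the discriminator update keeps that loss from flowing into the generator, so each optimizer only sees gradients meant for its own network.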

forward(real_image_batch)

Creates fake images by feeding noise through the generator, then feeds both the fake images and the real images through the discriminator.

Parameters:
  • real_image_batch (torch.Tensor) – batch of real images
Returns:

  • torch.Tensor – Generated fake images
  • torch.Tensor – Discriminator prediction of fake images
  • torch.Tensor – Discriminator prediction of real images
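A minimal sketch of the documented forward pass, using a hypothetical TinyGAN module whose generator produces 4x4 images instead of 64x64 so the example stays small. As in a DCGAN, the noise is assumed to be shaped (batch, noise_length, 1, 1); how this class actually samples its noise is not stated here.

```python
import torch
import torch.nn as nn


class TinyGAN(nn.Module):
    """Hypothetical stand-in illustrating the three documented outputs."""

    def __init__(self, n_channels=3, noise_length=16):
        super().__init__()
        self.noise_length = noise_length
        # 1x1 noise "image" -> 4x4 image (a real DCGAN would upsample to 64x64)
        self.gen = nn.Sequential(
            nn.ConvTranspose2d(noise_length, n_channels, kernel_size=4), nn.Tanh())
        # 4x4 image -> single score via a convolution covering the whole map
        self.discr = nn.Sequential(
            nn.Conv2d(n_channels, 1, kernel_size=4), nn.Sigmoid())

    def forward(self, real_image_batch):
        batch_size = real_image_batch.size(0)
        # Sample noise on the same device and dtype as the real images
        noise = torch.randn(batch_size, self.noise_length, 1, 1,
                            device=real_image_batch.device,
                            dtype=real_image_batch.dtype)
        fake_images = self.gen(noise)
        discr_fake = self.discr(fake_images)
        discr_real = self.discr(real_image_batch)
        return fake_images, discr_fake, discr_real


model = TinyGAN()
fake, d_fake, d_real = model(torch.randn(2, 3, 4, 4))
print(fake.shape, d_fake.shape, d_real.shape)
# torch.Size([2, 3, 4, 4]) torch.Size([2, 1, 1, 1]) torch.Size([2, 1, 1, 1])
```

Deriving the noise's device and dtype from real_image_batch keeps the fake and real branches consistent without extra bookkeeping.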

init_kwargs

Returns all arguments registered as init kwargs.

Returns:

init kwargs

Return type:

dict

static prepare_batch(batch: dict, input_device, output_device)

Helper function to prepare network inputs and labels (converts them to the correct type and shape and pushes them to the correct devices).

Parameters:
  • batch (dict) – dictionary containing all the data
  • input_device (torch.device) – device for network inputs
  • output_device (torch.device) – device for network outputs
Returns:

dictionary containing the data in the correct type and shape and on the correct devices

Return type:

dict
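A minimal sketch of such a helper, assuming the batch holds NumPy arrays, that the network input is stored under a hypothetical "data" key, and that every other entry is a target belonging on output_device. The name prepare_batch_sketch is hypothetical, and the real method may convert types and shapes differently.

```python
import numpy as np
import torch


def prepare_batch_sketch(batch: dict, input_device, output_device):
    """Hypothetical helper: convert arrays to float tensors on target devices.

    Assumes the network input is stored under "data" and that every other
    entry is a target that should live on output_device.
    """
    prepared = {"data": torch.from_numpy(batch["data"]).to(
        device=input_device, dtype=torch.float)}
    for key, value in batch.items():
        if key != "data":
            prepared[key] = torch.from_numpy(value).to(
                device=output_device, dtype=torch.float)
    return prepared


batch = {"data": np.random.rand(2, 3, 64, 64), "label": np.zeros((2, 1))}
cpu = torch.device("cpu")
out = prepare_batch_sketch(batch, cpu, cpu)
print({k: (v.dtype, tuple(v.shape)) for k, v in out.items()})
```

Splitting input_device from output_device allows, for example, inputs on a GPU while targets and losses are collected on another device.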