Generative Adversarial Networks

GenerativeAdversarialNetworkBasePyTorch

class GenerativeAdversarialNetworkBasePyTorch(n_channels, noise_length, **kwargs)

Bases: delira.models.abstract_network.AbstractPyTorchNetwork

Implementation of a vanilla DCGAN that generates 64x64-pixel images.

Notes

The fully connected part of the discriminator has been replaced with an equivalent convolutional part.
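The replacement described in the note works because a fully connected layer applied to a flattened feature map computes the same function as a convolution whose kernel covers the whole map. The following PyTorch sketch checks this numerically. The 512-channel, 4x4 feature map is an illustrative assumption, not necessarily the shape used by this class.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Illustrative shape of the last discriminator feature map (an assumption).
channels, height, width = 512, 4, 4

# Fully connected head: flatten the feature map and map it to a single logit.
fc = nn.Linear(channels * height * width, 1)

# Convolutional head: one kernel that covers the entire 4x4 feature map.
conv = nn.Conv2d(channels, 1, kernel_size=(height, width))

# Copy the FC weights into the conv kernel so both heads compute the same function.
with torch.no_grad():
    conv.weight.copy_(fc.weight.view(1, channels, height, width))
    conv.bias.copy_(fc.bias)

features = torch.randn(8, channels, height, width)
with torch.no_grad():
    out_fc = fc(features.flatten(start_dim=1))      # shape (8, 1)
    out_conv = conv(features).flatten(start_dim=1)  # (8, 1, 1, 1) -> (8, 1)

print(torch.allclose(out_fc, out_conv, atol=1e-4))  # True
```

Both heads have the same 8193 parameters; the convolutional form simply avoids the explicit flatten and keeps the whole discriminator fully convolutional.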

References

https://arxiv.org/abs/1511.06434

See also

AbstractPyTorchNetwork

static _build_models(in_channels, noise_length, **kwargs)

Builds the actual generator and discriminator models.

Parameters
  • in_channels (int) – number of channels of the images produced by the generator and of the discriminator inputs

  • noise_length (int) – length of the noise vector (generator input)

  • **kwargs – additional keyword arguments

Returns

  • torch.nn.Sequential – generator

  • torch.nn.Sequential – discriminator
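As an illustration of what `_build_models` might return, the sketch below builds a generator and a discriminator for 64x64 images following the layer layout described in the DCGAN paper (transposed convolutions with batch norm and ReLU in the generator, strided convolutions with LeakyReLU in the discriminator). The function name `build_dcgan_models`, the `base_features` parameter, and the exact layer widths are hypothetical; delira's own layers may differ.

```python
import torch
import torch.nn as nn


def build_dcgan_models(in_channels: int, noise_length: int, base_features: int = 64):
    """Hypothetical stand-in for _build_models: returns (generator, discriminator)."""
    f = base_features
    generator = nn.Sequential(
        # noise_length x 1 x 1 -> (8f) x 4 x 4
        nn.ConvTranspose2d(noise_length, f * 8, 4, 1, 0, bias=False),
        nn.BatchNorm2d(f * 8), nn.ReLU(True),
        # -> (4f) x 8 x 8
        nn.ConvTranspose2d(f * 8, f * 4, 4, 2, 1, bias=False),
        nn.BatchNorm2d(f * 4), nn.ReLU(True),
        # -> (2f) x 16 x 16
        nn.ConvTranspose2d(f * 4, f * 2, 4, 2, 1, bias=False),
        nn.BatchNorm2d(f * 2), nn.ReLU(True),
        # -> f x 32 x 32
        nn.ConvTranspose2d(f * 2, f, 4, 2, 1, bias=False),
        nn.BatchNorm2d(f), nn.ReLU(True),
        # -> in_channels x 64 x 64, values in [-1, 1]
        nn.ConvTranspose2d(f, in_channels, 4, 2, 1, bias=False),
        nn.Tanh(),
    )
    discriminator = nn.Sequential(
        # in_channels x 64 x 64 -> f x 32 x 32
        nn.Conv2d(in_channels, f, 4, 2, 1, bias=False),
        nn.LeakyReLU(0.2, inplace=True),
        # -> (2f) x 16 x 16
        nn.Conv2d(f, f * 2, 4, 2, 1, bias=False),
        nn.BatchNorm2d(f * 2), nn.LeakyReLU(0.2, inplace=True),
        # -> (4f) x 8 x 8
        nn.Conv2d(f * 2, f * 4, 4, 2, 1, bias=False),
        nn.BatchNorm2d(f * 4), nn.LeakyReLU(0.2, inplace=True),
        # -> (8f) x 4 x 4
        nn.Conv2d(f * 4, f * 8, 4, 2, 1, bias=False),
        nn.BatchNorm2d(f * 8), nn.LeakyReLU(0.2, inplace=True),
        # -> 1 x 1 x 1: a conv over the full 4x4 map replaces a fully connected layer
        nn.Conv2d(f * 8, 1, 4, 1, 0, bias=False),
        nn.Sigmoid(),
    )
    return generator, discriminator


gen, discr = build_dcgan_models(in_channels=3, noise_length=100)
fake = gen(torch.randn(2, 100, 1, 1))
print(fake.shape)         # torch.Size([2, 3, 64, 64])
print(discr(fake).shape)  # torch.Size([2, 1, 1, 1])
```

The last discriminator layer is a 4x4 convolution over a 4x4 feature map, which is the convolutional counterpart of a fully connected output layer.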

_init_kwargs = {}

static closure(model, data_dict: dict, optimizers: dict, losses=None, metrics=None, fold=0, **kwargs)

Closure method to perform a single backpropagation step.

Parameters
  • model (GenerativeAdversarialNetworkBasePyTorch) – trainable model

  • data_dict (dict) – dictionary containing data

  • optimizers (dict) – dictionary of optimizers to optimize the model’s parameters

  • losses (dict) – dict holding the losses to calculate errors (gradients from different losses will be accumulated)

  • metrics (dict) – dict holding the metrics to calculate

  • fold (int) – current fold in cross-validation (default: 0)

  • kwargs (dict) – additional keyword arguments

Returns

  • dict – Metric values (with same keys as input dict metrics)

  • dict – Loss values (with same keys as input dict losses)

  • list – Arbitrary number of predictions as torch.Tensor

Raises

AssertionError – if optimizers or losses are empty or the optimizers are not specified
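A single GAN step typically alternates a discriminator update and a generator update. The sketch below shows that pattern in plain PyTorch; it is not delira's closure. The function `gan_closure_sketch`, the dictionary keys `"gen"`, `"discr"` and `"adv"`, and the returned loss names are assumptions made for illustration, and tiny fully connected models stand in for the convolutional networks, so the noise is a flat vector.

```python
import torch
import torch.nn as nn


def gan_closure_sketch(generator, discriminator, real, optimizers, losses, noise_length):
    """Hypothetical single GAN step; keys 'gen', 'discr' and 'adv' are assumptions."""
    # Mirrors the documented AssertionError for empty optimizers/losses.
    assert optimizers and losses, "optimizers and losses must not be empty"
    adv_loss = losses["adv"]

    # Discriminator step: push real predictions towards 1 and fake ones towards 0.
    noise = torch.randn(real.size(0), noise_length, device=real.device)
    fake = generator(noise)
    optimizers["discr"].zero_grad()
    pred_real = discriminator(real)
    pred_fake = discriminator(fake.detach())  # detach: no generator gradients here
    loss_d = (adv_loss(pred_real, torch.ones_like(pred_real))
              + adv_loss(pred_fake, torch.zeros_like(pred_fake)))
    loss_d.backward()
    optimizers["discr"].step()

    # Generator step: make the discriminator classify the fakes as real.
    optimizers["gen"].zero_grad()
    pred_gen = discriminator(fake)
    loss_g = adv_loss(pred_gen, torch.ones_like(pred_gen))
    loss_g.backward()
    optimizers["gen"].step()

    return {"discr_adv": loss_d.item(), "gen_adv": loss_g.item()}, fake.detach()


torch.manual_seed(0)
generator = nn.Sequential(nn.Linear(8, 16), nn.Tanh())
discriminator = nn.Sequential(nn.Linear(16, 1), nn.Sigmoid())
optimizers = {
    "gen": torch.optim.Adam(generator.parameters(), lr=2e-4, betas=(0.5, 0.999)),
    "discr": torch.optim.Adam(discriminator.parameters(), lr=2e-4, betas=(0.5, 0.999)),
}
loss_vals, fake = gan_closure_sketch(generator, discriminator, torch.randn(4, 16),
                                     optimizers, {"adv": nn.BCELoss()}, noise_length=8)
print(loss_vals)
```

Detaching the fake batch for the discriminator update keeps that loss from producing gradients for the generator; the generator update then reuses the same fake batch with its graph intact.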

forward(real_image_batch)

Creates fake images by feeding noise through the generator, then feeds these fake images and the real images through the discriminator.

Parameters

real_image_batch (torch.Tensor) – batch of real images

Returns

  • torch.Tensor – Generated fake images

  • torch.Tensor – Discriminator prediction of fake images

  • torch.Tensor – Discriminator prediction of real images
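The return order of `forward` can be illustrated with a toy module. `TinyGAN`, its `gen`/`discr` attributes, and the use of flat vectors instead of 64x64 images are hypothetical simplifications; only the order of the three outputs mirrors the documentation above.

```python
import torch
import torch.nn as nn


class TinyGAN(nn.Module):
    """Hypothetical toy model mirroring the documented forward() outputs."""

    def __init__(self, n_features: int = 16, noise_length: int = 8):
        super().__init__()
        self.noise_length = noise_length
        self.gen = nn.Sequential(nn.Linear(noise_length, n_features), nn.Tanh())
        self.discr = nn.Sequential(nn.Linear(n_features, 1), nn.Sigmoid())

    def forward(self, real_image_batch):
        # One noise vector per real sample, on the same device as the batch.
        noise = torch.randn(real_image_batch.size(0), self.noise_length,
                            device=real_image_batch.device)
        fake_images = self.gen(noise)
        # Order: fake images, prediction on fakes, prediction on reals.
        return fake_images, self.discr(fake_images), self.discr(real_image_batch)


fake, pred_fake, pred_real = TinyGAN()(torch.randn(4, 16))
print(fake.shape, pred_fake.shape, pred_real.shape)
```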

property init_kwargs

Returns all arguments registered as init kwargs

Returns

init kwargs

Return type

dict

static prepare_batch(batch: dict, input_device, output_device)

Helper function to prepare network inputs and labels (converts them to the correct type and shape and pushes them to the correct devices).

Parameters
  • batch (dict) – dictionary containing all the data

  • input_device (torch.device) – device for network inputs

  • output_device (torch.device) – device for network outputs

Returns

dictionary containing the data in the correct type and shape and on the correct device

Return type

dict
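A minimal sketch of the kind of conversion `prepare_batch` performs, assuming the batch holds NumPy arrays, that the network input is stored under the key `"data"`, and that every entry should become a `float32` tensor. `prepare_batch_sketch` is a hypothetical name, not part of delira.

```python
import numpy as np
import torch


def prepare_batch_sketch(batch: dict, input_device, output_device) -> dict:
    """Hypothetical helper; the 'data' key and float32 dtype are assumptions."""
    prepared = {}
    for key, value in batch.items():
        # Network inputs go to the input device, everything else (labels) to the output device.
        device = input_device if key == "data" else output_device
        prepared[key] = torch.as_tensor(np.asarray(value), dtype=torch.float32).to(device)
    return prepared


cpu = torch.device("cpu")
prepared = prepare_batch_sketch({"data": np.random.rand(2, 3, 64, 64)}, cpu, cpu)
print(prepared["data"].dtype, prepared["data"].shape)
# torch.float32 torch.Size([2, 3, 64, 64])
```

The input dictionary is not modified; a new dictionary of tensors is returned.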