Generative Adversarial Networks
GenerativeAdversarialNetworkBasePyTorch
class GenerativeAdversarialNetworkBasePyTorch(n_channels, noise_length, **kwargs)
Bases: delira.models.abstract_network.AbstractPyTorchNetwork
Implementation of Vanilla DC-GAN to create 64x64 pixel images
Notes
The fully connected part in the discriminator has been replaced with an equivalent convolutional part
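The equivalence mentioned in this note can be checked numerically. The following is a minimal sketch, not taken from the delira source: the feature-map size (8 channels, 4x4) is a hypothetical choice, and the weights of a linear layer are copied into a convolution whose kernel spans the whole feature map.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical sizes for illustration: a 4x4 feature map with 8 channels.
channels, height, width = 8, 4, 4
features = torch.randn(2, channels, height, width)

# Fully connected head: flatten the feature map, then map it to one output.
fc = nn.Linear(channels * height * width, 1)

# Convolutional head: a single kernel that covers the entire feature map.
conv = nn.Conv2d(channels, 1, kernel_size=(height, width))

# Copy the linear weights into the convolution so both compute the same function.
with torch.no_grad():
    conv.weight.copy_(fc.weight.view(1, channels, height, width))
    conv.bias.copy_(fc.bias)

fc_out = fc(features.flatten(start_dim=1))      # shape (2, 1)
conv_out = conv(features).flatten(start_dim=1)  # shape (2, 1)
```

Both heads produce the same values up to floating-point tolerance, which is why the swap does not change what the discriminator can represent.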
References
https://arxiv.org/abs/1511.06434
See also
AbstractPyTorchNetwork
static _build_models(in_channels, noise_length, **kwargs)
Builds the actual generator and discriminator models.
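To illustrate what such a builder might return, here is a hedged sketch of a DCGAN-style generator and discriminator for 64x64 images, following the layer pattern of the referenced paper. The function name `build_models_sketch` and the `width` parameter are hypothetical; the layer sizes used by delira may differ.

```python
import torch
import torch.nn as nn


def build_models_sketch(in_channels, noise_length, width=8):
    """Hypothetical DCGAN-style builder; `width` is an illustrative parameter."""

    def up(c_in, c_out):
        # Transposed convolution that doubles the spatial size.
        return [nn.ConvTranspose2d(c_in, c_out, 4, 2, 1, bias=False),
                nn.BatchNorm2d(c_out), nn.ReLU(inplace=True)]

    def down(c_in, c_out):
        # Strided convolution that halves the spatial size.
        return [nn.Conv2d(c_in, c_out, 4, 2, 1, bias=False),
                nn.BatchNorm2d(c_out), nn.LeakyReLU(0.2, inplace=True)]

    generator = nn.Sequential(
        # (N, noise_length, 1, 1) -> (N, 8 * width, 4, 4)
        nn.ConvTranspose2d(noise_length, 8 * width, 4, 1, 0, bias=False),
        nn.BatchNorm2d(8 * width), nn.ReLU(inplace=True),
        *up(8 * width, 4 * width),   # 8x8
        *up(4 * width, 2 * width),   # 16x16
        *up(2 * width, width),       # 32x32
        nn.ConvTranspose2d(width, in_channels, 4, 2, 1, bias=False),  # 64x64
        nn.Tanh(),
    )
    discriminator = nn.Sequential(
        nn.Conv2d(in_channels, width, 4, 2, 1, bias=False),  # 32x32
        nn.LeakyReLU(0.2, inplace=True),
        *down(width, 2 * width),      # 16x16
        *down(2 * width, 4 * width),  # 8x8
        *down(4 * width, 8 * width),  # 4x4
        # A 4x4 convolution takes the place of a fully connected output layer.
        nn.Conv2d(8 * width, 1, 4, 1, 0, bias=False),
        nn.Sigmoid(),
    )
    return generator, discriminator


generator, discriminator = build_models_sketch(in_channels=1, noise_length=10)
fake = generator(torch.randn(2, 10, 1, 1))  # two noise vectors -> two images
score = discriminator(fake)                 # one probability per image
```

With these settings the generator maps a noise tensor of shape (2, 10, 1, 1) to images of shape (2, 1, 64, 64), and the discriminator reduces each image to a single value.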
_init_kwargs = {}
static closure(model, data_dict: dict, optimizers: dict, criterions={}, metrics={}, fold=0, **kwargs)
Closure method that performs a single backpropagation step.
- Parameters
model (ClassificationNetworkBase) – trainable model
data_dict (dict) – dictionary containing data
optimizers (dict) – dictionary of optimizers to optimize model’s parameters
criterions (dict) – dict holding the criterions to calculate errors (gradients from different criterions will be accumulated)
metrics (dict) – dict holding the metrics to calculate
fold (int) – current fold in cross-validation (default: 0)
kwargs (dict) – additional keyword arguments
- Returns
dict – Metric values (with same keys as input dict metrics)
dict – Loss values (with same keys as input dict criterions)
list – Arbitrary number of predictions as torch.Tensor
- Raises
AssertionError – if optimizers or criterions are empty or the optimizers are not specified
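As a rough illustration of what a single GAN backpropagation step can look like, the sketch below alternates one discriminator update and one generator update. It is an assumption-laden example, not the delira implementation: the function `gan_step`, the dictionary keys `"generator"` and `"discriminator"`, and the tiny linear models are all hypothetical. The return value mirrors the documented order (metrics, losses, predictions).

```python
import torch
import torch.nn as nn


def gan_step(generator, discriminator, real, optimizers, criterion, noise_length):
    """One alternating update; all names here are illustrative."""
    assert optimizers, "optimizers must not be empty"
    batch_size = real.size(0)
    noise = torch.randn(batch_size, noise_length, device=real.device)
    fake = generator(noise)

    real_target = torch.ones(batch_size, 1, device=real.device)
    fake_target = torch.zeros(batch_size, 1, device=real.device)

    # Discriminator update: detach the fakes so no gradient reaches the generator.
    optimizers["discriminator"].zero_grad()
    loss_discr = (criterion(discriminator(real), real_target)
                  + criterion(discriminator(fake.detach()), fake_target))
    loss_discr.backward()
    optimizers["discriminator"].step()

    # Generator update: push the discriminator to label the fakes as real.
    optimizers["generator"].zero_grad()
    loss_gen = criterion(discriminator(fake), real_target)
    loss_gen.backward()
    optimizers["generator"].step()

    loss_vals = {"discriminator": loss_discr.item(), "generator": loss_gen.item()}
    return {}, loss_vals, [fake.detach()]


torch.manual_seed(0)
noise_length, n_features = 4, 6
# Tiny stand-in models on flat vectors, only to keep the example fast.
generator = nn.Sequential(nn.Linear(noise_length, n_features), nn.Tanh())
discriminator = nn.Sequential(nn.Linear(n_features, 1), nn.Sigmoid())
optimizers = {
    "generator": torch.optim.Adam(generator.parameters(), lr=2e-4, betas=(0.5, 0.999)),
    "discriminator": torch.optim.Adam(discriminator.parameters(), lr=2e-4, betas=(0.5, 0.999)),
}
metric_vals, loss_vals, preds = gan_step(
    generator, discriminator, torch.randn(8, n_features),
    optimizers, nn.BCELoss(), noise_length,
)
```

Using one optimizer per sub-network keeps the two updates independent, which is one common reason for passing optimizers as a dictionary.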
forward(real_image_batch)
Creates fake images by feeding noise through the generator, then feeds both the fake and the real images through the discriminator.
- Parameters
real_image_batch (torch.Tensor) – batch of real images
- Returns
torch.Tensor – Generated fake images
torch.Tensor – Discriminator prediction of fake images
torch.Tensor – Discriminator prediction of real images
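The data flow described for `forward` can be sketched as follows. `TinyGAN` is a hypothetical stand-in that works on 8x8 images so it stays small; it is not the delira class, and how delira samples its noise is an assumption here (standard normal, one vector per real image).

```python
import torch
import torch.nn as nn


class TinyGAN(nn.Module):
    """Hypothetical stand-in, not the delira class."""

    def __init__(self, n_channels=1, noise_length=16):
        super().__init__()
        self.noise_length = noise_length
        # Generator: noise of shape (N, noise_length, 1, 1) -> 8x8 image.
        self.generator = nn.Sequential(
            nn.ConvTranspose2d(noise_length, 8, kernel_size=4, stride=1),
            nn.ReLU(),
            nn.ConvTranspose2d(8, n_channels, kernel_size=4, stride=2, padding=1),
            nn.Tanh(),
        )
        # Discriminator: 8x8 image -> one probability per image.
        self.discriminator = nn.Sequential(
            nn.Conv2d(n_channels, 8, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(8, 1, kernel_size=4),
            nn.Sigmoid(),
        )

    def forward(self, real_image_batch):
        batch_size = real_image_batch.size(0)
        # Assumption: one standard-normal noise vector per real image.
        noise = torch.randn(batch_size, self.noise_length, 1, 1,
                            device=real_image_batch.device)
        fake_image_batch = self.generator(noise)
        discr_pred_fake = self.discriminator(fake_image_batch)
        discr_pred_real = self.discriminator(real_image_batch)
        return fake_image_batch, discr_pred_fake, discr_pred_real


model = TinyGAN()
real = torch.randn(4, 1, 8, 8)
fake, pred_fake, pred_real = model(real)
```

The three return values appear in the same order as documented: generated images, discriminator prediction for the fakes, discriminator prediction for the real images.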
static prepare_batch(batch: dict, input_device, output_device)
Helper function to prepare network inputs and labels (converts them to the correct type and shape and pushes them to the correct devices).
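A batch-preparation helper of this kind might look like the sketch below. It assumes that the batch holds NumPy arrays, that the network input lives under the key `"data"`, and that everything is cast to `float32`; these are illustrative assumptions, and `prepare_batch_sketch` is a hypothetical name rather than the delira function.

```python
import numpy as np
import torch


def prepare_batch_sketch(batch: dict, input_device, output_device):
    # Assumption: the network input is stored under the key "data".
    prepared = {
        "data": torch.from_numpy(batch["data"]).to(device=input_device,
                                                   dtype=torch.float)
    }
    # Assumption: every other entry is a label and goes to the output device.
    for key, value in batch.items():
        if key != "data":
            prepared[key] = torch.from_numpy(value).to(device=output_device,
                                                       dtype=torch.float)
    return prepared


batch = {"data": np.zeros((2, 1, 64, 64)), "label": np.ones((2, 1))}
prepared = prepare_batch_sketch(batch, "cpu", "cpu")
```

Separate input and output devices allow the labels to sit on the device where the loss is computed, which can differ from the device that receives the inputs.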
-
static