Generative Adversarial Networks
GenerativeAdversarialNetworkBasePyTorch
class GenerativeAdversarialNetworkBasePyTorch(n_channels, noise_length, **kwargs)
Bases: delira.models.abstract_network.AbstractPyTorchNetwork
Implementation of a vanilla DCGAN that generates 64x64 pixel images.
Notes
The fully connected part of the discriminator has been replaced with an equivalent convolutional part.
References
https://arxiv.org/abs/1511.06434
See also
AbstractPyTorchNetwork
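The note about the discriminator can be checked with the standard convolution output-size formula. The sketch below traces the feature-map size through a DCGAN-style discriminator; the layer settings (four 4x4 convolutions with stride 2 and padding 1, then a final 4x4 convolution without padding) are assumptions borrowed from common DCGAN setups, not values read from delira's source. The point it shows is that the last kernel covers the whole 4x4 map, so it plays the role of a fully connected layer.

```python
from typing import List


def conv_out_size(size: int, kernel: int, stride: int, padding: int) -> int:
    """Spatial output size of a Conv2d layer (dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1


def discriminator_sizes(image_size: int = 64) -> List[int]:
    """Trace the feature-map size through an assumed DCGAN-style discriminator."""
    sizes = [image_size]
    # Assumption: four 4x4 convolutions with stride 2 and padding 1, each halving the size
    for _ in range(4):
        sizes.append(conv_out_size(sizes[-1], kernel=4, stride=2, padding=1))
    # Final 4x4 convolution without padding: the kernel covers the whole 4x4 map,
    # so each output sees every input feature, just like a fully connected layer
    sizes.append(conv_out_size(sizes[-1], kernel=4, stride=1, padding=0))
    return sizes


print(discriminator_sizes())  # [64, 32, 16, 8, 4, 1]
```

With C input channels in that last layer, the convolution holds C * 4 * 4 weights per output, the same number as a linear layer applied to the flattened 4x4xC features.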
static _build_models(in_channels, noise_length, **kwargs)
Builds the actual generator and discriminator models.
Returns:
- torch.nn.Sequential – generator
- torch.nn.Sequential – discriminator
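On the generator side, a DCGAN-style model treats the noise vector as a 1x1 map with `noise_length` channels and upsamples it with transposed convolutions until it reaches 64x64. The sketch below traces the spatial size through such a stack; the kernel, stride, and padding values are assumptions taken from common PyTorch DCGAN setups, not from delira's `_build_models`.

```python
from typing import List


def conv_transpose_out_size(size: int, kernel: int, stride: int, padding: int) -> int:
    """Spatial output size of a ConvTranspose2d layer (dilation 1, output_padding 0)."""
    return (size - 1) * stride - 2 * padding + kernel


def generator_sizes() -> List[int]:
    """Trace the spatial size through an assumed DCGAN-style generator."""
    # The noise vector enters as a (noise_length, 1, 1) tensor
    sizes = [1]
    # Assumption: a 4x4 transposed conv with stride 1 and no padding projects it to 4x4
    sizes.append(conv_transpose_out_size(1, kernel=4, stride=1, padding=0))
    # Assumption: four 4x4 transposed convs with stride 2 and padding 1 double the size
    for _ in range(4):
        sizes.append(conv_transpose_out_size(sizes[-1], kernel=4, stride=2, padding=1))
    return sizes


print(generator_sizes())  # [1, 4, 8, 16, 32, 64]
```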
_init_kwargs = {}
static closure(model, data_dict: dict, optimizers: dict, criterions={}, metrics={}, fold=0, **kwargs)
Closure method to perform a single backpropagation step.
Parameters:
- model (GenerativeAdversarialNetworkBasePyTorch) – trainable model
- data_dict (dict) – dictionary containing the data
- optimizers (dict) – dictionary of optimizers to optimize the model’s parameters
- criterions (dict) – dict holding the criterions to calculate errors (gradients from different criterions will be accumulated)
- metrics (dict) – dict holding the metrics to calculate
- fold (int) – current fold in cross-validation (default: 0)
- kwargs (dict) – additional keyword arguments
Returns: - dict – Metric values (with same keys as input dict metrics)
- dict – Loss values (with same keys as input dict criterions)
- list – Arbitrary number of predictions as torch.Tensor
Raises: AssertionError – if optimizers or criterions are empty or the optimizers are not specified
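The closure is documented only in terms of generic criterions. For a GAN, one backpropagation step usually combines a discriminator loss (real images labelled 1, fakes labelled 0) with a generator loss (fakes labelled 1). The sketch below computes both as binary cross-entropy on plain probabilities; whether delira's closure uses this non-saturating generator loss, and under which criterion keys, is an assumption here.

```python
import math
from typing import Sequence, Tuple


def bce(probs: Sequence[float], target: float, eps: float = 1e-12) -> float:
    """Mean binary cross-entropy of predicted probabilities against a constant target."""
    total = 0.0
    for p in probs:
        p = min(max(p, eps), 1.0 - eps)  # clamp to avoid log(0)
        total += -(target * math.log(p) + (1.0 - target) * math.log(1.0 - p))
    return total / len(probs)


def gan_losses(d_real: Sequence[float], d_fake: Sequence[float]) -> Tuple[float, float]:
    # Discriminator: push predictions on real images to 1 and on fake images to 0
    d_loss = bce(d_real, 1.0) + bce(d_fake, 0.0)
    # Generator (non-saturating variant): make the discriminator output 1 on fakes
    g_loss = bce(d_fake, 1.0)
    return d_loss, g_loss


# A confident discriminator gives a small d_loss and a large g_loss
d_loss, g_loss = gan_losses(d_real=[0.9, 0.8], d_fake=[0.1, 0.2])
```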
forward(real_image_batch)
Creates fake images by feeding noise through the generator, then feeds both the fake and the real images through the discriminator.
Parameters: real_image_batch (torch.Tensor) – batch of real images
Returns:
- torch.Tensor – generated fake images
- torch.Tensor – discriminator prediction for the fake images
- torch.Tensor – discriminator prediction for the real images
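The data flow of `forward` can be sketched with plain callables standing in for the two networks. In this hypothetical sketch (`gan_forward` is not part of delira), noise is drawn from a standard normal distribution with one noise vector per real image; both choices are assumptions, since the sampling scheme is not documented here, and the real method works on `torch.Tensor` batches rather than lists.

```python
import random
from typing import Callable, List, Sequence, Tuple

Vector = List[float]


def gan_forward(
    real_batch: Sequence[Vector],
    generator: Callable[[Vector], Vector],
    discriminator: Callable[[Vector], float],
    noise_length: int,
    rng: random.Random,
) -> Tuple[List[Vector], List[float], List[float]]:
    # One noise vector per real image so both batches have the same size
    noise = [[rng.gauss(0.0, 1.0) for _ in range(noise_length)] for _ in real_batch]
    fake_batch = [generator(z) for z in noise]
    # The same discriminator scores the generated and the real images
    fake_preds = [discriminator(x) for x in fake_batch]
    real_preds = [discriminator(x) for x in real_batch]
    return fake_batch, fake_preds, real_preds


# Toy stand-ins: the "generator" scales the noise, the "discriminator" averages its input
fake, d_fake, d_real = gan_forward(
    real_batch=[[1.0, 1.0], [0.5, 0.5]],
    generator=lambda z: [2.0 * v for v in z],
    discriminator=lambda x: sum(x) / len(x),
    noise_length=2,
    rng=random.Random(0),
)
```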
static prepare_batch(batch: dict, input_device, output_device)
Helper function to prepare network inputs and labels (converts them to the correct type and shape and pushes them to the correct devices).
Parameters: - batch (dict) – dictionary containing all the data
- input_device (torch.device) – device for network inputs
- output_device (torch.device) – device for network outputs
Returns: dictionary containing data in correct type and shape and on correct device
Return type: dict
static