import logging

from delira import get_backends
from delira.utils.decorators import make_deprecated

logger = logging.getLogger(__name__)

if "TORCH" in get_backends():
    import torch

    from delira.models.abstract_network import AbstractPyTorchNetwork

    class GenerativeAdversarialNetworkBasePyTorch(AbstractPyTorchNetwork):
        """Implementation of Vanilla DC-GAN to create 64x64 pixel images

        Notes
        -----
        The fully connected part in the discriminator has been replaced
        with an equivalent convolutional part

        References
        ----------
        https://arxiv.org/abs/1511.06434

        See Also
        --------
        :class:`AbstractPyTorchNetwork`
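
        Examples
        --------
        Illustrative instantiation (parameter values are placeholders; the
        constructor is deprecated and may emit a ``DeprecationWarning``):

        >>> net = GenerativeAdversarialNetworkBasePyTorch(n_channels=3,
        ...                                               noise_length=100)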
"""
@make_deprecated("Own repository to be announced")
def __init__(self, n_channels, noise_length, **kwargs):
"""
Parameters
----------
n_channels : int
number of image channels for generated images and input images
noise_length : int
length of noise vector
**kwargs :
additional keyword arguments
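
            Examples
            --------
            Additional keyword arguments are registered as attributes
            (illustrative; ``img_size`` is a placeholder name):

            >>> net = GenerativeAdversarialNetworkBasePyTorch(3, 100,
            ...                                               img_size=64)
            >>> net.img_size
            64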
"""
# register params by passing them as kwargs to parent class
# __init__
super().__init__(n_channels=n_channels,
noise_length=noise_length,
**kwargs)
gen, discr = self._build_models(n_channels, noise_length, **kwargs)
self.nz = noise_length
self.gen = gen
self.discr = discr
for key, value in kwargs.items():
setattr(self, key, value)

        def forward(self, real_image_batch):
            """
            Create fake images by feeding noise through the generator and
            feed the results and the real images through the discriminator

            Parameters
            ----------
            real_image_batch : torch.Tensor
                batch of real images

            Returns
            -------
            dict
                dictionary containing the generated fake images
                (``fake_images``) and the discriminator predictions for
                the fake (``discr_fake``) and real (``discr_real``) images
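
            Examples
            --------
            Illustrative shapes for a network built with ``n_channels=3``
            and ``noise_length=100``:

            >>> net = GenerativeAdversarialNetworkBasePyTorch(3, 100)
            >>> out = net(torch.randn(8, 3, 64, 64))
            >>> out["fake_images"].shape
            torch.Size([8, 3, 64, 64])
            >>> out["discr_real"].shape
            torch.Size([8, 1, 1, 1])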
"""
noise = torch.randn(real_image_batch.size(0), self.nz, 1, 1,
device=real_image_batch.device)
fake_image_batch = self.gen(noise)
discr_pred_fake = self.discr(fake_image_batch)
discr_pred_real = self.discr(real_image_batch)
return {"fake_images": fake_image_batch,
"discr_fake": discr_pred_fake,
"discr_real": discr_pred_real}

        @staticmethod
        def closure(model, data_dict: dict,
                    optimizers: dict, losses={}, metrics={},
                    fold=0, **kwargs):
            """
            closure method to do a single backpropagation step

            Parameters
            ----------
            model : :class:`GenerativeAdversarialNetworkBasePyTorch`
                trainable model
            data_dict : dict
                dictionary containing the data
            optimizers : dict
                dictionary of optimizers to optimize model's parameters;
                must contain the keys ``discr`` and ``gen`` during training
            losses : dict
                dict holding the losses to calculate errors
                (gradients from different losses will be accumulated)
            metrics : dict
                dict holding the metrics to calculate
            fold : int
                Current Fold in Crossvalidation (default: 0)
            **kwargs :
                additional keyword arguments

            Returns
            -------
            dict
                Metric values (with same keys as the input dict ``metrics``)
            dict
                Loss values (with same keys as the input dict ``losses``)
            dict
                Predictions as (detached) ``torch.Tensor``

            Raises
            ------
            KeyError
                if ``optimizers`` is non-empty but does not contain the
                keys ``discr`` and ``gen``
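
            Examples
            --------
            Illustrative evaluation step; ``optimizers`` is left empty, so
            no parameter update is performed (during training, the
            optimizers are expected to be supplied by delira's trainer,
            which makes ``scale_loss`` available on them):

            >>> net = GenerativeAdversarialNetworkBasePyTorch(3, 100)
            >>> metric_vals, loss_vals, preds = (
            ...     GenerativeAdversarialNetworkBasePyTorch.closure(
            ...         net, {"data": torch.randn(8, 3, 64, 64)},
            ...         optimizers={}, losses={"bce": torch.nn.BCELoss()}))
            >>> sorted(loss_vals)
            ['val_bce_adversarial', 'val_bce_discr_fake', 'val_bce_discr_real']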
"""
loss_vals = {}
metric_vals = {}
total_loss_discr_real = 0
total_loss_discr_fake = 0
total_loss_gen = 0
# choose suitable context manager:
if optimizers:
context_man = torch.enable_grad
else:
context_man = torch.no_grad
with context_man():
batch = data_dict.pop("data")
# predict batch
preds = model(batch)
# train discr with prediction from real image
for key, crit_fn in losses.items():
_loss_val = crit_fn(preds["discr_real"],
torch.ones_like(preds["discr_real"]))
loss_vals[key + "_discr_real"] = _loss_val.item()
total_loss_discr_real += _loss_val
# train discr with prediction from fake image
for key, crit_fn in losses.items():
_loss_val = crit_fn(preds["discr_fake"],
torch.zeros_like(preds["discr_fake"]))
loss_vals[key + "_discr_fake"] = _loss_val.item()
total_loss_discr_fake += _loss_val
total_loss_discr = total_loss_discr_fake + \
total_loss_discr_real
if optimizers:
# actual backpropagation
optimizers["discr"].zero_grad()
# perform loss scaling via apex if half precision is
# enabled
with optimizers["discr"].scale_loss(
total_loss_discr) as scaled_loss:
scaled_loss.backward(retain_graph=True)
optimizers["discr"].step()
# calculate adversarial loss for generator update
for key, crit_fn in losses.items():
_loss_val = crit_fn(preds["discr_fake"],
torch.ones_like(preds["discr_fake"]))
loss_vals[key + "_adversarial"] = _loss_val.item()
total_loss_gen += _loss_val
with torch.no_grad():
for key, metric_fn in metrics.items():
# calculate metrics for discriminator with real
# prediction
metric_vals[key + "_discr_real"] = metric_fn(
preds["discr_real"],
torch.ones_like(
preds["discr_real"])).item()
# calculate metrics for discriminator with fake
# prediction
metric_vals[key + "_discr_fake"] = metric_fn(
preds["discr_fake"],
torch.zeros_like(
preds["discr_fake"])).item()
# calculate adversarial metrics
metric_vals[key + "_adversarial"] = metric_fn(
preds["discr_fake"],
torch.ones_like(
preds["discr_fake"])).item()
if optimizers:
# actual backpropagation
optimizers["gen"].zero_grad()
# perform loss scaling via apex if half precision is
# enabled
with optimizers["gen"].scale_loss(
total_loss_gen) as scaled_loss:
scaled_loss.backward()
optimizers["gen"].step()
else:
# add prefix "val" in validation mode
eval_loss_vals, eval_metrics_vals = {}, {}
for key in loss_vals.keys():
eval_loss_vals["val_" + str(key)] = loss_vals[key]
for key in metric_vals:
eval_metrics_vals["val_" + str(key)] = metric_vals[key]
loss_vals = eval_loss_vals
metric_vals = eval_metrics_vals
return metric_vals, loss_vals, {k: v.detach()
for k, v in preds.items()}

        @staticmethod
        def _build_models(in_channels, noise_length, **kwargs):
            """
            Builds actual generator and discriminator models

            Parameters
            ----------
            in_channels : int
                number of channels for images generated by the generator
                and inputs of the discriminator
            noise_length : int
                length of noise vector (generator input)
            **kwargs :
                additional keyword arguments

            Returns
            -------
            torch.nn.Sequential
                generator
            torch.nn.Sequential
                discriminator
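
            Examples
            --------
            Illustrative shape check (argument values are placeholders):

            >>> build = GenerativeAdversarialNetworkBasePyTorch._build_models
            >>> gen, discr = build(3, 100)
            >>> gen(torch.randn(2, 100, 1, 1)).shape
            torch.Size([2, 3, 64, 64])
            >>> discr(torch.randn(2, 3, 64, 64)).shape
            torch.Size([2, 1, 1, 1])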
"""
gen = torch.nn.Sequential(
# input is Z, going into a convolution
torch.nn.ConvTranspose2d(
noise_length, 64 * 8, 4, 1, 0, bias=False),
torch.nn.BatchNorm2d(64 * 8),
torch.nn.ReLU(True),
# state size. (64*8) x 4 x 4
torch.nn.ConvTranspose2d(64 * 8, 64 * 4, 4, 2, 1, bias=False),
torch.nn.BatchNorm2d(64 * 4),
torch.nn.ReLU(True),
# state size. (64*4) x 8 x 8
torch.nn.ConvTranspose2d(64 * 4, 64 * 2, 4, 2, 1, bias=False),
torch.nn.BatchNorm2d(64 * 2),
torch.nn.ReLU(True),
# state size. (64*2) x 16 x 16
torch.nn.ConvTranspose2d(64 * 2, 64, 4, 2, 1, bias=False),
torch.nn.BatchNorm2d(64),
torch.nn.ReLU(True),
# state size. (64) x 32 x 32
torch.nn.ConvTranspose2d(64, in_channels, 4, 2, 1, bias=False),
torch.nn.Tanh()
# state size. (nc) x 64 x 64
)
discr = torch.nn.Sequential(
# input is (nc) x 64 x 64
torch.nn.Conv2d(in_channels, 64, 4, 2, 1, bias=False),
torch.nn.LeakyReLU(0.2, inplace=True),
# state size. (64) x 32 x 32
torch.nn.Conv2d(64, 64 * 2, 4, 2, 1, bias=False),
torch.nn.BatchNorm2d(64 * 2),
torch.nn.LeakyReLU(0.2, inplace=True),
# state size. (64*2) x 16 x 16
torch.nn.Conv2d(64 * 2, 64 * 4, 4, 2, 1, bias=False),
torch.nn.BatchNorm2d(64 * 4),
torch.nn.LeakyReLU(0.2, inplace=True),
# state size. (64*4) x 8 x 8
torch.nn.Conv2d(64 * 4, 64 * 8, 4, 2, 1, bias=False),
torch.nn.BatchNorm2d(64 * 8),
torch.nn.LeakyReLU(0.2, inplace=True),
# state size. (64*8) x 4 x 4
torch.nn.Conv2d(64 * 8, 1, 4, 1, 0, bias=False),
torch.nn.Sigmoid()
)
return gen, discr