Source code for delira.models.backends.torch.data_parallel

import torch

from delira.models.backends.torch.abstract_network import \
    AbstractPyTorchNetwork


class DataParallelPyTorchNetwork(AbstractPyTorchNetwork, torch.nn.DataParallel):
    """
    A wrapper around an :class:`AbstractPyTorchNetwork` instance to
    implement parallel training by splitting the batches across devices
    """

    def __init__(self, module: AbstractPyTorchNetwork, device_ids=None,
                 output_device=None, dim=0):
        """
        Parameters
        ----------
        module : :class:`AbstractPyTorchNetwork`
            the module to wrap (will be replicated on all devices)
        device_ids : list
            a list containing the devices to use (either as strings or as
            :class:`torch.device`)
        output_device : str or :class:`torch.device`
            the device on which the outputs are gathered.
            Make sure your labels are also on this device for the loss
            calculation! If not specified, the first device of
            ``device_ids`` will be used for output gathering.
        dim : int
            the index of the batch dimension (usually 0, but can become
            e.g. 1 in NLP tasks)
        """
        AbstractPyTorchNetwork.__init__(self)
        torch.nn.DataParallel.__init__(self, module, device_ids,
                                       output_device, dim)
    def forward(self, *args, **kwargs):
        """
        Scatters the inputs (both positional and keyword arguments) across
        all devices, feeds them through the model replicas and gathers the
        outputs on the output device

        Parameters
        ----------
        *args :
            positional arguments of arbitrary number and type
        **kwargs :
            keyword arguments of arbitrary number and type

        Returns
        -------
        Any
            the combined output of all model replicas
        """
        return torch.nn.DataParallel.forward(self, *args, **kwargs)
    @property
    def closure(self):
        return self.module.closure

    @property
    def prepare_batch(self):
        return self.module.prepare_batch
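
A minimal usage sketch of the wrapper is given below. It assumes at least two
CUDA devices are available; ``DummyNetwork`` is a hypothetical
:class:`AbstractPyTorchNetwork` subclass introduced only for illustration, and
the exact signatures of its ``closure``/``prepare_batch`` stubs may differ
between delira versions.

import torch

from delira.models.backends.torch.abstract_network import \
    AbstractPyTorchNetwork
from delira.models.backends.torch.data_parallel import \
    DataParallelPyTorchNetwork


class DummyNetwork(AbstractPyTorchNetwork):
    """Hypothetical network used only to illustrate the wrapper."""

    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(32, 2)

    def forward(self, x):
        return {"pred": self.fc(x)}

    @staticmethod
    def closure(model, data_dict, optimizers, losses, metrics={},
                fold=0, **kwargs):
        # placeholder stub; a real closure would compute the losses,
        # backpropagate and step the optimizers
        return {}, {}, {}

    @staticmethod
    def prepare_batch(batch, input_device, output_device):
        # placeholder stub; a real implementation would move the inputs to
        # the replication device and the labels to the output device
        return batch


if torch.cuda.device_count() >= 2:
    net = DummyNetwork()
    # replicate the network on GPUs 0 and 1; outputs (and labels for the
    # loss) are gathered on GPU 0
    parallel_net = DataParallelPyTorchNetwork(
        net, device_ids=[0, 1], output_device=0).to("cuda:0")

    batch = torch.rand(8, 32, device="cuda:0")
    out = parallel_net(batch)   # split along dim 0 across both GPUs
    print(out["pred"].shape)    # torch.Size([8, 2]), gathered on cuda:0

Because ``closure`` and ``prepare_batch`` are exposed as properties that
delegate to the wrapped module, the wrapper can be passed to the trainer in
place of the original network without further changes.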