Source code for delira.models.backends.chainer.data_parallel

import chainer

from delira.models.backends.chainer.abstract_network import \
    AbstractChainerNetwork


def _apply_scatter(inputs: chainer.Variable, target_devices: list,
                   dim: int = 0):
    """
    Scatters inputs to target devices; slicing will be done along a given
    dimension

    Parameters
    ----------
    inputs : :class:`chainer.Variable`
        the input variable to scatter
    target_devices : list
        the target devices to scatter to
    dim : int
        the dimension to use for slicing

    Returns
    -------
    list
        list of variable slices on correct devices
    """


    def _slice_inputs(input_var, dim, num_dims, start, end, target_device):
        """
        Slices the input variable along a given dimension from start to end
        and pushes it to the correct device

        Parameters
        ----------
        input_var : :class:`chainer.Variable`
            the variable to slice
        dim : int
            the dimension to slice along
        num_dims : int
            the dimensionality of ``input_var``
        start : int
            the start value for slicing (included)
        end : int
            the end value for slicing (excluded)
        target_device : str or :class:`chainer.backend.Device`
            the device to push to

        Returns
        -------
        :class:`chainer.Variable`
            the slice of the variable
        """

        slc = [slice(None)] * num_dims
        slc[dim] = slice(start, end)
        sliced_var = input_var[slc]
        # push the slice to its target device (as described in the
        # docstring) before reshaping
        sliced_var.to_device(target_device)
        output_shape = list(input_var.shape)
        output_shape[dim] = -1
        return sliced_var.reshape(output_shape)

    # create empty sliced input list
    scattered_inputs = []

    # calculate constant only once
    num_devices = len(target_devices)
    samples_per_device = inputs.shape[dim] // num_devices
    num_dims = len(inputs.shape)

    # iterate over number of devices and slice accordingly
    # (exclude last device)
    # iterating until the minimum of num_devices and inputs.shape[dim] -1
    # ensures that if the batchsize is too small to be scattered across all
    # devices, we will only scatter across as many devices as possible
    for i in range(min(num_devices, inputs.shape[dim]) - 1):
        start, end = i * samples_per_device, (i + 1) * samples_per_device
        scattered_inputs.append(_slice_inputs(inputs, dim, num_dims,
                                              start, end,
                                              target_devices[i]))

    # all remaining samples (not yet sliced) are appended now
    # (all samples used; will be pushed to last device later)
    scattered_inputs.append(_slice_inputs(
        inputs, dim, num_dims,
        (num_devices - 1) * samples_per_device,
        inputs.shape[dim], target_devices[-1]))

    return scattered_inputs
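
# Usage sketch (illustrative, not part of the original module; the device
# strings are assumptions): splitting a batch of 4 samples across two
# devices along the batch dimension.
#
#   import numpy as np
#   x = chainer.as_variable(np.zeros((4, 3), dtype="float32"))
#   parts = _apply_scatter(x, ["@numpy", "@numpy"], dim=0)
#   # -> two variables of shape (2, 3), one per target device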

def _apply_gather(target_device, dim, *outputs):
    """Gathers ``outputs`` on ``target_device`` and concatenates them
    along ``dim``"""
    # push every replica's output to the gathering device first
    for _output in outputs:
        _output.to_device(target_device)

    return chainer.functions.concat(outputs, dim)
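
# Usage sketch (illustrative): re-combining two per-device outputs on one
# device by concatenation along dimension 0.
#
#   import numpy as np
#   a = chainer.as_variable(np.ones((2, 3), dtype="float32"))
#   b = chainer.as_variable(np.ones((2, 3), dtype="float32"))
#   gathered = _apply_gather("@numpy", 0, a, b)   # variable of shape (4, 3)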

def _scatter(inputs, target_devices: list, dim):
    """
    Scatters all inputs across the given target devices

    Parameters
    ----------
    inputs : Any
        the inputs to scatter
    target_devices : list
        list of devices to scatter to
    dim : int
        dimension to use for slicing

    Returns
    -------
    list
        list of scattered inputs
    """


    def _scatter_map(inputs):
        """
        Scatters ``inputs`` across the given target devices

        Parameters
        ----------
        inputs : Any
            the inputs to scatter

        Returns
        -------
        list
            list of scattered inputs
        """


        # directly apply the scattering on variable
        if isinstance(inputs, chainer.Variable):
            return _apply_scatter(inputs, target_devices, dim)

        # map _scatter_map recursively to all samples in tuple
        if isinstance(inputs, tuple) and inputs:
            return list(zip(*map(_scatter_map, inputs)))

        # map _scatter_map recursively to all samples in list
        if isinstance(inputs, list) and inputs:
            return list(map(list, zip(*map(_scatter_map, inputs))))

        # map _scatter_map recursively to all samples in dict
        if isinstance(inputs, dict) and inputs:
            return list(map(type(inputs), zip(*map(_scatter_map,
                                                   inputs.items()))))

        # try to convert inputs to a chainer variable first and afterwards
        # apply _scatter_map again
        try:
            return _scatter_map(chainer.as_variable(inputs))
        except TypeError:
            return [inputs for targets in target_devices]

    # After scatter_map is called, a scatter_map cell will exist. This cell
    # has a reference to the actual function scatter_map, which has
    # references to a closure that has a reference to the scatter_map cell
    # (because the fn is recursive). To avoid this reference cycle, we set
    # the function to None, clearing the cell

    try:
        return _scatter_map(inputs)
    finally:
        _scatter_map = None
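
# Usage sketch (illustrative): a dict of arrays is scattered into one dict
# per device, each holding the corresponding batch slice as a variable.
#
#   import numpy as np
#   batch = {"data": np.zeros((4, 3), dtype="float32"),
#            "label": np.zeros((4, 1), dtype="float32")}
#   per_device = _scatter(batch, ["@numpy", "@numpy"], dim=0)
#   # -> [{"data": <(2, 3)>, "label": <(2, 1)>},
#   #     {"data": <(2, 3)>, "label": <(2, 1)>}]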

def _gather(outputs, target_device, dim=0):
    """
    Gathers variables from different devices onto a specified device
    (-1 means the CPU).
    """

    def gather_map(outputs):
        out = outputs[0]
        if isinstance(out, chainer.Variable):
            return _apply_gather(target_device, dim, *outputs)
        if out is None:
            return None
        if isinstance(out, dict):
            if not all((len(out) == len(d) for d in outputs)):
                raise ValueError(
                    'All dicts must have the same number of keys')

            return type(out)(((k, gather_map([d[k] for d in outputs]))
                              for k in out))
        return type(out)(map(gather_map, zip(*outputs)))

    # Recursive function calls like this create reference cycles.
    # Setting the function to None clears the refcycle.
    try:
        return gather_map(outputs)
    finally:
        gather_map = None
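
# Usage sketch (illustrative): per-device prediction dicts are re-combined
# into a single dict whose values are concatenated along ``dim``.
#
#   import numpy as np
#   preds = [{"pred": chainer.as_variable(np.zeros((2, 3), "float32"))},
#            {"pred": chainer.as_variable(np.zeros((2, 3), "float32"))}]
#   merged = _gather(preds, "@numpy", dim=0)   # {"pred": (4, 3) variable}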


class DataParallelChainerNetwork(AbstractChainerNetwork):
    """
    A wrapper around an :class:`AbstractChainerNetwork` instance to
    implement parallel training by splitting the batches
    """

    def __init__(self, module: AbstractChainerNetwork, devices: list,
                 output_device=None, batch_dim=0):
        """
        Parameters
        ----------
        module : :class:`AbstractChainerNetwork`
            the module to wrap (will be replicated on all devices)
        devices : list
            a list containing the devices to use (either as strings or as
            :class:`chainer.backend.Device`)
        output_device : str or :class:`chainer.backend.Device`
            the output device; make sure your labels are also on this
            device for loss calculation! If not specified, the second
            device of ``devices`` will be used for output gathering.
        batch_dim : int
            the index of the batch dimension (usually 0, but can become
            e.g. 1 in NLP tasks)
        """
        super().__init__()

        modules = [module.copy() for _ in devices]

        for _module, _device in zip(modules, devices):
            _module.to_device(_device)

        with self.init_scope():
            self.modules = chainer.ChainList(*modules)

        self.devices = devices

        if output_device is None:
            output_device = devices[1]

        self._output_device = output_device

        assert self._output_device in self.devices
        self._output_device_idx = self.devices.index(self._output_device)

        self.dim = batch_dim

    def forward(self, *args, **kwargs):
        """
        Scatters the inputs (both positional and keyword arguments) across
        all devices, feeds them through the model replicas and re-builds
        batches on the output device

        Parameters
        ----------
        *args :
            positional arguments of arbitrary number and type
        **kwargs :
            keyword arguments of arbitrary number and type

        Returns
        -------
        Any
            combined output from all scattered models
        """
        scattered_args, scattered_kwargs = self._scatter(args, kwargs,
                                                         self.devices,
                                                         self.dim)

        predictions = []
        for _args, _kwargs, _module in zip(scattered_args, scattered_kwargs,
                                           self.modules):
            predictions.append(_module(*_args, **_kwargs))

        predictions = self._gather(predictions, self.dim,
                                   self._output_device)

        return predictions

    def params(self, include_uninit=True):
        """
        Only the parameters of the module on the first device will actually
        be updated; all other parameters will be replicated by the optimizer
        after an update

        Parameters
        ----------
        include_uninit : bool
            whether to include uninitialized parameters

        Returns
        -------
        generator
            a generator holding the root module's parameters
        """
        return self.modules[0].params(include_uninit)

    @staticmethod
    def _scatter(inputs, kwargs, target_devices: list, dim=0):
        """
        Scatters all inputs (args and kwargs) to target devices and splits
        them along the given dimension

        Parameters
        ----------
        inputs : list or tuple
            positional arguments
        kwargs : dict
            keyword arguments
        target_devices : list
            list of target devices (either strings or
            :class:`chainer.backend.Device`)
        dim : int
            the dimension to use for splitting the batch

        Returns
        -------
        tuple
            scattered positional arguments
        tuple
            scattered keyword arguments
        """
        # scatter inputs if given
        inputs = _scatter(inputs, target_devices, dim) if inputs else []

        # scatter kwargs if given
        kwargs = _scatter(kwargs, target_devices, dim) if kwargs else []

        # extend lengths by empty tuples if necessary
        if len(inputs) < len(kwargs):
            inputs.extend([() for _ in range(len(kwargs) - len(inputs))])
        elif len(kwargs) < len(inputs):
            kwargs.extend([{} for _ in range(len(inputs) - len(kwargs))])

        inputs = tuple(inputs)
        kwargs = tuple(kwargs)

        return inputs, kwargs

    @staticmethod
    def _gather(predictions, dim, target_device):
        """
        Re-builds batches on the target device

        Parameters
        ----------
        predictions : list
            list containing the predictions from all replicated models
        dim : int
            dimension to use for concatenating single predictions
        target_device : str or :class:`chainer.backend.Device`
            the device the re-built batch should lie on

        Returns
        -------
        Any
            the rebuilt batch (lying on ``target_device``)
        """
        return _gather(predictions, target_device, dim)

    def cleargrads(self):
        for module in self.modules:
            module.cleargrads()

    def zerograds(self):
        for module in self.modules:
            module.zerograds()

    @property
    def closure(self):
        return self.modules[0].closure

    @property
    def prepare_batch(self):
        return self.modules[0].prepare_batch
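
# Usage sketch (illustrative, not part of the original module): wrapping a
# network for two GPUs. ``MyNetwork`` is a hypothetical
# AbstractChainerNetwork subclass and the device strings are assumptions.
#
#   net = MyNetwork()
#   parallel_net = DataParallelChainerNetwork(
#       net, devices=["@cupy:0", "@cupy:1"], batch_dim=0)
#   preds = parallel_net(data)   # data is split along batch_dim, the
#                                # per-replica outputs are gathered again
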
class ParallelOptimizerCumulateGradientsHook(object):
    """
    A hook which sums up all replications' gradients in a
    DataParallel scenario
    """
    name = "DataParallelCumulateGradients"
    call_for_each_param = False
    timing = 'pre'

    def __call__(self, optimizer: chainer.Optimizer):
        """
        Sums up all replications' gradients if the target is an instance
        of ``DataParallel``

        Parameters
        ----------
        optimizer : chainer.Optimizer
            the optimizer holding the target whose gradients should be
            summed across the replications
        """
        if isinstance(, DataParallelChainerNetwork):
            for module in[1:]:
      [0].addgrads(module)

class ParallelOptimizerUpdateModelParameters(object):
    """
    A hook to replicate all parameters from the root model to all
    model replicas after the optimizer step
    """
    name = "DataParallelUpdateModelParams"
    call_for_each_param = False
    timing = "post"

    def __call__(self, optimizer: chainer.Optimizer):
        if isinstance(, DataParallelChainerNetwork):
            for module in[1:]:
                module.copyparams([0])
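
# Usage sketch (illustrative): both hooks can also be registered manually on
# a plain chainer optimizer whose target is a DataParallelChainerNetwork
# (``parallel_net`` is assumed to be such an instance).
#
#   opt = chainer.optimizers.SGD()
#   opt.setup(parallel_net)
#   opt.add_hook(ParallelOptimizerCumulateGradientsHook())
#   opt.add_hook(ParallelOptimizerUpdateModelParameters())
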
class DataParallelChainerOptimizer(chainer.Optimizer):
    """
    An optimizer wrapper to enable DataParallel. Basically this forwards
    all functions to the internal optimizer, but registers the additional
    hooks needed for DataParallel (namely
    :class:`ParallelOptimizerUpdateModelParameters` as a post-update hook
    and :class:`ParallelOptimizerCumulateGradientsHook` as a pre-update
    hook)
    """

    def __init__(self, optimizer):
        """
        Parameters
        ----------
        optimizer : :class:`chainer.Optimizer`
            the optimizer to wrap
        """
        if isinstance(optimizer, chainer.Optimizer):
            self._optimizer = optimizer
        else:
            raise RuntimeError("Invalid optimizer class given: Expected "
                               "instance of chainer.Optimizer, but got %s"
                               % optimizer.__class__.__name__)

    @classmethod
    def from_optimizer_class(cls, optim_cls, *args, **kwargs):
        """
        Creates the wrapper from an optimizer class

        Parameters
        ----------
        optim_cls : subclass of :class:`chainer.Optimizer`
            the optimizer class to use internally
        *args :
            arbitrary positional arguments (will be used for initialization
            of the internally used optimizer)
        **kwargs :
            arbitrary keyword arguments (will be used for initialization
            of the internally used optimizer)

        Returns
        -------
        :class:`DataParallelChainerOptimizer`
            the wrapper holding the newly created optimizer
        """
        if optim_cls is not None and issubclass(optim_cls,
                                                chainer.Optimizer):
            _optim = optim_cls(*args, **kwargs)
        else:
            raise RuntimeError("Invalid optimizer class given: Expected "
                               "subclass of chainer.Optimizer, but got %s"
                               % optim_cls.__name__)

        return cls(_optim)

    def setup(self, link):
        """
        Calls the setup method of the internal optimizer and registers the
        hooks necessary for data-parallel behavior

        Parameters
        ----------
        link : :class:`DataParallelChainerNetwork`
            the target whose parameters should be updated
        """
        self._optimizer.setup(link)
        self._optimizer.add_hook(ParallelOptimizerCumulateGradientsHook())
        self._optimizer.add_hook(ParallelOptimizerUpdateModelParameters())

    @property
    def target(self):
        return

    @property
    def epoch(self):
        return self._optimizer.epoch

    @property
    def _pre_update_hooks(self):
        return self._optimizer._pre_update_hooks

    @property
    def _loss_scale(self):
        return self._optimizer._loss_scale

    @property
    def _loss_scale_max(self):
        return self._optimizer._loss_scale_max

    @property
    def _loss_scaling_is_dynamic(self):
        return self._optimizer._loss_scaling_is_dynamic

    @property
    def use_auto_new_epoch(self):
        return self._optimizer.use_auto_new_epoch

    @property
    def update(self):
        return self._optimizer.update

    @property
    def new_epoch(self):
        return self._optimizer.new_epoch

    @property
    def add_hook(self):
        return self._optimizer.add_hook

    @property
    def remove_hook(self):
        return self._optimizer.remove_hook

    @property
    def call_hooks(self):
        return self._optimizer.call_hooks

    @property
    def serialize(self):
        return self._optimizer.serialize

    @property
    def loss_scaling(self):
        return self._optimizer.loss_scaling

    @property
    def set_loss_scale(self):
        return self._optimizer.set_loss_scale

    @property
    def check_nan_in_grads(self):
        return self._optimizer.check_nan_in_grads

    @property
    def is_safe_to_update(self):
        return self._optimizer.is_safe_to_update

    @property
    def update_loss_scale(self):
        return self._optimizer.update_loss_scale
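
# Usage sketch (illustrative, assumes ``parallel_net`` is a
# DataParallelChainerNetwork instance): the wrapper registers both
# DataParallel hooks during ``setup``.
#
#   optim = DataParallelChainerOptimizer.from_optimizer_class(
#       chainer.optimizers.Adam, alpha=1e-3)
#   optim.setup(parallel_net)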