import logging
import inspect
from batchgenerators.dataloading import MultiThreadedAugmenter, \
SingleThreadedAugmenter, SlimDataLoaderBase
from batchgenerators.transforms import AbstractTransform
from delira.data_loading.data_loader import BaseDataLoader
from delira.data_loading.dataset import AbstractDataset, BaseCacheDataset, \
BaseLazyDataset
from delira.data_loading.load_utils import default_load_fn_2d
from delira.data_loading.sampler import SequentialSampler, AbstractSampler
from delira.utils.decorators import make_deprecated
from delira import get_current_debug_mode
from multiprocessing import Queue
from queue import Full
# module-level logger, named after the module it lives in
logger = logging.getLogger(name=__name__)
class Augmenter(object):
    """
    Class wrapping ``MultiThreadedAugmenter`` and ``SingleThreadedAugmenter``
    to provide a uniform API and to disable multiprocessing/multithreading
    inside the dataloading pipeline (``SingleThreadedAugmenter`` is used in
    debug mode).

    Sample indices are drawn by the single sampler held by this class and
    handed to the dataloader round-robin through ``sampler_queues``.
    """

    def __init__(self, data_loader: BaseDataLoader, transforms,
                 n_process_augmentation, sampler, sampler_queues: list,
                 num_cached_per_queue=2, seeds=None, **kwargs):
        """
        Parameters
        ----------
        data_loader : :class:`BaseDataLoader`
            the dataloader providing the actual data
        transforms : Callable or None
            the transforms to use. Can be single callable or None
        n_process_augmentation : int
            the number of processes to use for augmentation (only necessary if
            not in debug mode)
        sampler : :class:`AbstractSampler`
            the sampler to use; must be used here instead of inside the
            dataloader to avoid duplications and oversampling due to
            multiprocessing
        sampler_queues : list of :class:`multiprocessing.Queue`
            queues to pass the sample indices to the actual dataloader
        num_cached_per_queue : int
            the number of samples to cache per queue (only necessary if not in
            debug mode)
        seeds : int or list
            the seeds for each process (only necessary if not in debug mode).
            ``None`` means a seed of 1; a single int is replicated for each
            process. A given list is copied and never modified.
        **kwargs :
            additional keyword arguments (passed to
            ``MultiThreadedAugmenter``; unused in debug mode)

        Raises
        ------
        AssertionError
            not in debug mode and `n_process_augmentation` is not an int
        """
        self._batchsize = data_loader.batch_size

        # don't use multiprocessing in debug mode
        if get_current_debug_mode():
            augmenter = SingleThreadedAugmenter(data_loader, transforms)
        else:
            assert isinstance(n_process_augmentation, int)
            augmenter = MultiThreadedAugmenter(
                data_loader, transforms,
                num_processes=n_process_augmentation,
                num_cached_per_queue=num_cached_per_queue,
                seeds=self._prepare_seeds(seeds, n_process_augmentation),
                **kwargs)

        self._augmenter = augmenter
        self._sampler = sampler
        self._sampler_queues = sampler_queues
        self._queue_id = 0

    @staticmethod
    def _prepare_seeds(seeds, n_processes):
        """
        Builds the list of per-process seeds

        Parameters
        ----------
        seeds : int, list or None
            the given seed(s); ``None`` is treated as a seed of 1 and a
            single int is replicated for each process
        n_processes : int
            the number of augmentation processes

        Returns
        -------
        list
            a new list of seeds; if the first seed occurs more than once,
            every seed is offset by its process index
        """
        # no seeds are given -> use default seed of 1
        if seeds is None:
            seeds = 1
        if isinstance(seeds, int):
            # only an int is given as seed -> replicate it for each process
            seeds = [seeds] * n_processes
        else:
            # work on a copy so the caller's sequence is never mutated
            seeds = list(seeds)

        # avoid same seeds for all processes
        if any(seeds[0] == _seed for _seed in seeds[1:]):
            seeds = [_seed + idx for idx, _seed in enumerate(seeds)]
        return seeds

    def __iter__(self):
        """
        Function returning an iterator

        Returns
        -------
        Augmenter
            self
        """
        return self

    def _next_queue(self):
        """
        Returns the sampler queue to feed next and advances the round-robin
        queue index

        Returns
        -------
        :class:`multiprocessing.Queue`
            the queue to put the next batch of sample indices into
        """
        idx = self._queue_id
        self._queue_id = (self._queue_id + 1) % len(self._sampler_queues)
        return self._sampler_queues[idx]

    def __next__(self):
        """
        Function to sample and load the next batch

        Returns
        -------
        dict
            the next batch
        """
        idxs = self._sampler(self._batchsize)
        sampler_queue = self._next_queue()

        # don't wait forever. Release this after short timeout and try again
        # to avoid deadlock
        while True:
            try:
                sampler_queue.put(idxs, timeout=0.2)
                break
            except Full:
                continue

        return next(self._augmenter)

    def next(self):
        """
        Function to sample and load

        Returns
        -------
        dict
            the next batch
        """
        return next(self)

    @staticmethod
    def __identity_fn(*args, **kwargs):
        """
        Helper function accepting arbitrary args and kwargs and returning
        without doing anything

        Parameters
        ----------
        *args
            positional arguments (ignored)
        **kwargs
            keyword arguments (ignored)
        """
        return

    def _fn_checker(self, function_name):
        """
        Checks if the internal augmenter has a given attribute and returns it.
        Otherwise it returns ``__identity_fn``

        Parameters
        ----------
        function_name : str
            the function name to check for

        Returns
        -------
        Callable
            either the function corresponding to the given function name or
            ``__identity_fn``
        """
        # a single getattr call; the default is returned on AttributeError
        return getattr(self._augmenter, function_name, self.__identity_fn)

    @property
    def _start(self):
        """
        Property to provide uniform API of ``_start``

        Returns
        -------
        Callable
            either the augmenter's ``_start`` method (if available) or
            ``__identity_fn`` (if not available)
        """
        return self._fn_checker("_start")

    @property
    def restart(self):
        """
        Property to provide uniform API of ``restart``; calling
        ``augmenter.restart()`` therefore restarts the internal augmenter
        (or does nothing if it has no ``restart`` method)

        Returns
        -------
        Callable
            either the augmenter's ``restart`` method (if available) or
            ``__identity_fn`` (if not available)
        """
        return self._fn_checker("restart")

    def _finish(self):
        """
        Calls the augmenter's ``_finish`` method (if available) and closes
        all sampler queues afterwards

        Returns
        -------
        Any
            the return value of the augmenter's ``_finish`` method, or
            ``None`` if the augmenter has none
        """
        ret_val = self._fn_checker("_finish")()
        for sampler_queue in self._sampler_queues:
            sampler_queue.close()
            sampler_queue.join_thread()
        return ret_val

    @property
    def num_batches(self):
        """
        Property returning the number of batches

        Returns
        -------
        int
            number of batches
        """
        if isinstance(self._augmenter, MultiThreadedAugmenter):
            return self._augmenter.generator.num_batches

        return self._augmenter.data_loader.num_batches

    @property
    def num_processes(self):
        """
        Property returning the number of processes to use for loading and
        augmentation

        Returns
        -------
        int
            number of processes to use for loading and
            augmentation
        """
        if isinstance(self._augmenter, MultiThreadedAugmenter):
            return self._augmenter.num_processes

        return 1

    def __del__(self):
        """
        Function defining what to do, if object should be deleted
        """
        # ``__init__`` may have raised before these attributes were set;
        # then there is nothing to clean up
        if not (hasattr(self, "_augmenter")
                and hasattr(self, "_sampler_queues")):
            return
        self._finish()
        del self._augmenter
class BaseDataManager(object):
    """
    Class to Handle Data
    Creates Dataset, Dataloader and BatchGenerator
    """

    def __init__(self, data, batch_size, n_process_augmentation,
                 transforms, sampler_cls=SequentialSampler,
                 sampler_kwargs=None,
                 data_loader_cls=None, dataset_cls=None,
                 load_fn=default_load_fn_2d, from_disc=True, **kwargs):
        """
        Parameters
        ----------
        data : str or Dataset
            if str: Path to data samples
            if dataset: Dataset
        batch_size : int
            Number of samples per batch
        n_process_augmentation : int
            Number of processes for augmentations
        transforms : None or ``AbstractTransform``
            Data transformations for augmentation
        sampler_cls : AbstractSampler
            class defining the sampling strategy
        sampler_kwargs : dict
            keyword arguments for sampling
        data_loader_cls : subclass of SlimDataLoaderBase
            DataLoader class (defaults to :class:`BaseDataLoader`)
        dataset_cls : subclass of AbstractDataset
            Dataset class; if None, ``BaseLazyDataset`` is used for
            ``from_disc=True`` and ``BaseCacheDataset`` otherwise
        load_fn : function
            function to load simple sample
        from_disc : bool
            whether or not to load data from disc just the time it is needed
            (only used if `dataset_cls` is None)
        **kwargs :
            other keyword arguments (needed for dataloading and passed to
            dataset_cls)

        Raises
        ------
        AssertionError
            * `data_loader_cls` is not :obj:`None` and not a subclass of
              `SlimDataLoaderBase`
            * `dataset_cls` is not :obj:`None` and not a subclass of
              :class:`.AbstractDataset`
            * `sampler_cls` is not a subclass of :class:`AbstractSampler`

        See Also
        --------
        :class:`AbstractDataset`
        """
        if sampler_kwargs is None:
            sampler_kwargs = {}

        # Instantiate Hidden variables for property access
        self._batch_size = None
        self._n_process_augmentation = None
        self._transforms = None
        self._data_loader_cls = None
        self._dataset = None
        self._sampler = None

        # set actual values to properties
        self.batch_size = batch_size
        self.n_process_augmentation = n_process_augmentation
        self.transforms = transforms

        if data_loader_cls is None:
            logger.info("No DataLoader Class specified. Using BaseDataLoader")
            data_loader_cls = BaseDataLoader
        else:
            assert inspect.isclass(data_loader_cls), \
                "data_loader_cls must be class not instance of class"
            assert issubclass(data_loader_cls, SlimDataLoaderBase), \
                "data_loader_cls must be subclass of SlimDataLoaderBase"

        self.data_loader_cls = data_loader_cls

        if isinstance(data, AbstractDataset):
            self.dataset = data
        else:
            if dataset_cls is None:
                if from_disc:
                    dataset_cls = BaseLazyDataset
                else:
                    dataset_cls = BaseCacheDataset

                logger.info("No DataSet Class specified. Using %s instead",
                            dataset_cls.__name__)
            else:
                # check isclass first: issubclass raises TypeError otherwise
                assert inspect.isclass(dataset_cls) and issubclass(
                    dataset_cls, AbstractDataset), \
                    "dataset_cls must be subclass of AbstractDataset"

            self.dataset = dataset_cls(data, load_fn, **kwargs)

        assert inspect.isclass(sampler_cls) and issubclass(
            sampler_cls, AbstractSampler), \
            "sampler_cls must be subclass of AbstractSampler"
        self.sampler = sampler_cls.from_dataset(self.dataset, **sampler_kwargs)

    def get_batchgen(self, seed=1):
        """
        Create DataLoader and Batchgenerator

        Parameters
        ----------
        seed : int
            seed for Random Number Generator

        Returns
        -------
        Augmenter
            Batchgenerator

        Raises
        ------
        AssertionError
            :attr:`BaseDataManager.n_batches` is smaller than or equal to zero
        """
        assert self.n_batches > 0

        n_processes = self.n_process_augmentation
        # one queue per augmentation process to pass the sampled indices
        sampler_queues = [Queue() for _ in range(n_processes)]

        data_loader = self.data_loader_cls(
            self.dataset,
            batch_size=self.batch_size,
            num_batches=self.n_batches,
            seed=seed,
            sampler_queues=sampler_queues
        )

        return Augmenter(data_loader, self.transforms,
                         n_processes,
                         sampler=self.sampler,
                         sampler_queues=sampler_queues,
                         num_cached_per_queue=2,
                         seeds=n_processes * [seed])

    def _subset_kwargs(self):
        """
        Collects the keyword arguments to create a new manager with the same
        configuration as the current one (excluding the data itself)

        Returns
        -------
        dict
            keyword arguments for ``self.__class__``
        """
        # NOTE: sampler keyword arguments are not propagated to the new
        # manager; the sampler class is re-created with its defaults
        return {
            "batch_size": self.batch_size,
            # use the stored value: the public property returns 1 in debug
            # mode, which would otherwise be baked into the new manager
            "n_process_augmentation": self._n_process_augmentation,
            "transforms": self.transforms,
            "sampler_cls": self.sampler.__class__,
            "data_loader_cls": self.data_loader_cls,
            "dataset_cls": None,
            "load_fn": None,
            "from_disc": True
        }

    def get_subset(self, indices):
        """
        Returns a Subset of the current datamanager based on given indices

        Parameters
        ----------
        indices : iterable
            valid indices to extract subset from current dataset

        Returns
        -------
        :class:`BaseDataManager`
            manager containing the subset
        """
        return self.__class__(
            self.dataset.get_subset(indices),
            **self._subset_kwargs())

    def update_state_from_dict(self, new_state: dict):
        """
        Updates internal state and therefore the behavior from dict.
        If a key is not specified, the old attribute value will be used

        Parameters
        ----------
        new_state : dict
            The dict to update the state from (it is not modified).
            Valid keys are:

            * ``batch_size``
            * ``n_process_augmentation``
            * ``data_loader_cls``
            * ``sampler``
            * ``sampling_kwargs``
            * ``transforms``

            If a key is not specified, the old value of the corresponding
            attribute will be used. ``sampling_kwargs`` are passed to the
            sampler's ``from_dataset``; if they are given without
            ``sampler``, the current sampler class is re-created with them.

        Raises
        ------
        KeyError
            Invalid keys are specified (checked before any attribute is
            changed)
        """
        valid_keys = ("batch_size", "n_process_augmentation",
                      "data_loader_cls", "sampler", "sampling_kwargs",
                      "transforms")
        invalid_keys = [key for key in new_state if key not in valid_keys]
        if invalid_keys:
            raise KeyError("Invalid Keys in new_state given: %s"
                           % (','.join(map(str, invalid_keys))))

        # update batch_size if specified
        self.batch_size = new_state.get("batch_size", self.batch_size)

        # update n_process_augmentation if specified; fall back to the stored
        # value since the property returns 1 in debug mode
        self.n_process_augmentation = new_state.get(
            "n_process_augmentation", self._n_process_augmentation)

        # update data_loader_cls if specified
        self.data_loader_cls = new_state.get("data_loader_cls",
                                             self.data_loader_cls)

        # update sampler if a new sampler and/or new sampling kwargs are given
        new_sampler = new_state.get("sampler", None)
        sampling_kwargs = new_state.get("sampling_kwargs", None)
        if new_sampler is None and sampling_kwargs is not None:
            new_sampler = self.sampler.__class__
        if new_sampler is not None:
            self.sampler = new_sampler.from_dataset(
                self.dataset,
                **(sampling_kwargs or {}))

        # update transforms if specified
        self.transforms = new_state.get("transforms", self.transforms)

    @make_deprecated("BaseDataManager.get_subset")
    def train_test_split(self, *args, **kwargs):
        """
        Calls :meth:`AbstractDataset.train_test_split` and returns
        a manager for each subset with same configuration as current manager

        .. deprecated:: 0.3
            method will be removed in next major release

        Parameters
        ----------
        *args :
            positional arguments for
            ``sklearn.model_selection.train_test_split``
        **kwargs :
            keyword arguments for
            ``sklearn.model_selection.train_test_split``

        Returns
        -------
        :class:`BaseDataManager`
            manager containing the training subset
        :class:`BaseDataManager`
            manager containing the validation subset
        """
        trainset, valset = self.dataset.train_test_split(*args, **kwargs)

        subset_kwargs = self._subset_kwargs()

        train_mgr = self.__class__(trainset, **subset_kwargs)
        val_mgr = self.__class__(valset, **subset_kwargs)

        return train_mgr, val_mgr

    @property
    def batch_size(self):
        """
        Property to access the batchsize

        Returns
        -------
        int
            the batchsize
        """
        return self._batch_size

    @batch_size.setter
    def batch_size(self, new_batch_size):
        """
        Setter for current batchsize, casts to int before setting the attribute

        Parameters
        ----------
        new_batch_size : int, Any
            the new batchsize; should be int but can be of any type that can be
            casted to an int
        """
        self._batch_size = int(new_batch_size)

    @property
    def n_process_augmentation(self):
        """
        Property to access the number of augmentation processes

        Returns
        -------
        int
            number of augmentation processes (always 1 in debug mode)
        """
        if get_current_debug_mode():
            return 1
        return self._n_process_augmentation

    @n_process_augmentation.setter
    def n_process_augmentation(self, new_process_number):
        """
        Setter for number of augmentation processes, casts to int before
        setting the attribute

        Parameters
        ----------
        new_process_number : int, Any
            new number of augmentation processes; should be int but can be of
            any type that can be casted to an int
        """
        self._n_process_augmentation = int(new_process_number)

    @property
    def transforms(self):
        """
        Property to access the current data transforms

        Returns
        -------
        None, ``AbstractTransform``
            The transformation, can either be None or an instance of
            ``AbstractTransform``
        """
        return self._transforms

    @transforms.setter
    def transforms(self, new_transforms):
        """
        Setter for data transforms, assert if transforms are of valid type
        (either None or instance of ``AbstractTransform``)

        Parameters
        ----------
        new_transforms : None, ``AbstractTransform``
            the new transforms
        """
        assert new_transforms is None or isinstance(new_transforms,
                                                    AbstractTransform)

        self._transforms = new_transforms

    @property
    def data_loader_cls(self):
        """
        Property to access the current data loader class

        Returns
        -------
        type
            Subclass of ``SlimDataLoaderBase``
        """
        return self._data_loader_cls

    @data_loader_cls.setter
    def data_loader_cls(self, new_loader_cls):
        """
        Setter for current data loader class, asserts if class is of valid type
        (must be a class and a subclass of ``SlimDataLoaderBase``)

        Parameters
        ----------
        new_loader_cls : type
            the new data loader class
        """
        assert inspect.isclass(new_loader_cls) and issubclass(
            new_loader_cls, SlimDataLoaderBase)

        self._data_loader_cls = new_loader_cls

    @property
    def dataset(self):
        """
        Property to access the current dataset

        Returns
        -------
        :class:`AbstractDataset`
            the current dataset
        """
        return self._dataset

    @dataset.setter
    def dataset(self, new_dataset):
        """
        Setter for new dataset

        Parameters
        ----------
        new_dataset : :class:`AbstractDataset`
            the new dataset (asserted to be an instance of
            :class:`AbstractDataset`)
        """
        assert isinstance(new_dataset, AbstractDataset)

        self._dataset = new_dataset

    @property
    def sampler(self):
        """
        Property to access the current sampler

        Returns
        -------
        :class:`AbstractSampler`
            the current sampler
        """
        return self._sampler

    @sampler.setter
    def sampler(self, new_sampler):
        """
        Setter for current sampler.
        If a valid class instance is passed, the sampler is simply assigned, if
        a valid class type is passed, the sampler is created from the dataset

        Parameters
        ----------
        new_sampler : :class:`AbstractSampler`, type
            instance or class object of new sampler

        Raises
        ------
        ValueError
            Neither a valid class instance nor a valid class type is given
        """
        if inspect.isclass(new_sampler) and issubclass(new_sampler,
                                                       AbstractSampler):
            self._sampler = new_sampler.from_dataset(self.dataset)
        elif isinstance(new_sampler, AbstractSampler):
            self._sampler = new_sampler
        else:
            raise ValueError("Given Sampler is neither a subclass of "
                             "AbstractSampler, nor an instance of a sampler")

    @property
    def n_samples(self):
        """
        Number of Samples (the length of the current sampler)

        Returns
        -------
        int
            Number of Samples
        """
        return len(self.sampler)

    @property
    def n_batches(self):
        """
        Returns Number of Batches based on batchsize and number of samples;
        a trailing incomplete batch counts as a batch

        Returns
        -------
        int
            Number of Batches

        Raises
        ------
        AssertionError
            :attr:`BaseDataManager.n_samples` is smaller than or equal to zero
        """
        n_samples = self.n_samples
        assert n_samples > 0

        n_full_batches, n_remaining = divmod(n_samples, self.batch_size)
        return n_full_batches + int(bool(n_remaining))