import logging
import numpy as np
import typing
import inspect
from batchgenerators.dataloading import SlimDataLoaderBase, \
MultiThreadedAugmenter
from batchgenerators.transforms import AbstractTransform
from .dataset import AbstractDataset, BaseCacheDataset, BaseLazyDataset, \
ConcatDataset
from .data_loader import BaseDataLoader
from .load_utils import default_load_fn_2d
from .sampler import SequentialSampler, AbstractSampler
from ..utils.decorators import make_deprecated
logger = logging.getLogger(__name__)
class BaseDataManager(object):
    """
    Class to handle data.

    Creates Dataset, DataLoader and BatchGenerator.

    See Also
    --------
    :class:`AbstractDataset`
    """

    def __init__(self, data, batch_size, n_process_augmentation,
                 transforms, sampler_cls=SequentialSampler,
                 sampler_kwargs=None,
                 data_loader_cls=None, dataset_cls=None,
                 load_fn=default_load_fn_2d, from_disc=True, **kwargs):
        """
        Parameters
        ----------
        data : str or AbstractDataset
            if str: path to the data samples
            if dataset: the dataset itself
        batch_size : int
            number of samples per batch
        n_process_augmentation : int
            number of processes for augmentations
        transforms : None or ``AbstractTransform``
            data transformations for augmentation
        sampler_cls : type
            class defining the sampling strategy
            (subclass of :class:`AbstractSampler`)
        sampler_kwargs : dict, optional
            keyword arguments for sampling (default: None, treated as
            an empty dict)
        data_loader_cls : type, optional
            DataLoader class (subclass of ``SlimDataLoaderBase``);
            defaults to :class:`BaseDataLoader` if None
        dataset_cls : type, optional
            Dataset class (subclass of :class:`AbstractDataset`);
            if None, chosen based on ``from_disc``
        load_fn : function
            function to load a single sample
        from_disc : bool
            whether or not to load data from disc just the time it is
            needed
        **kwargs :
            other keyword arguments (needed for dataloading and passed
            to ``dataset_cls``)

        Raises
        ------
        AssertionError
            * `data_loader_cls` is not :obj:`None` and not a subclass of
              ``SlimDataLoaderBase``
            * `dataset_cls` is not :obj:`None` and not a subclass of
              :class:`.AbstractDataset`
            * `sampler_cls` is not a class or not a subclass of
              :class:`AbstractSampler`

        See Also
        --------
        :class:`AbstractDataset`
        """
        # BUGFIX: mutable default argument ({}) replaced by None to
        # avoid sharing a single dict instance across all calls
        if sampler_kwargs is None:
            sampler_kwargs = {}

        # Instantiate hidden variables backing the property access
        self._batch_size = None
        self._n_process_augmentation = None
        self._transforms = None
        self._data_loader_cls = None
        self._dataset = None
        self._sampler = None

        # set actual values through the validating properties
        self.batch_size = batch_size
        self.n_process_augmentation = n_process_augmentation
        self.transforms = transforms

        if data_loader_cls is None:
            logger.info("No DataLoader Class specified. Using BaseDataLoader")
            data_loader_cls = BaseDataLoader
        else:
            assert inspect.isclass(data_loader_cls), \
                "data_loader_cls must be class not instance of class"
            # BUGFIX: assertion message typo fixed ("dater_loader_cls")
            assert issubclass(data_loader_cls, SlimDataLoaderBase), \
                "data_loader_cls must be subclass of SlimDataLoaderBase"
        self.data_loader_cls = data_loader_cls

        if isinstance(data, AbstractDataset):
            # an already-built dataset is used as-is
            self.dataset = data
        else:
            if dataset_cls is None:
                # lazy loading from disc vs. caching everything up front
                dataset_cls = BaseLazyDataset if from_disc \
                    else BaseCacheDataset
                logger.info("No DataSet Class specified. Using %s instead" %
                            dataset_cls.__name__)
            else:
                assert issubclass(dataset_cls, AbstractDataset), \
                    "dataset_cls must be subclass of AbstractDataset"
            self.dataset = dataset_cls(data, load_fn, **kwargs)

        assert inspect.isclass(sampler_cls) and issubclass(
            sampler_cls, AbstractSampler), \
            "sampler_cls must be a subclass of AbstractSampler"
        self.sampler = sampler_cls.from_dataset(self.dataset,
                                                **sampler_kwargs)
[docs] def get_batchgen(self, seed=1):
"""
Create DataLoader and Batchgenerator
Parameters
----------
seed : int
seed for Random Number Generator
Returns
-------
MultiThreadedAugmenter
Batchgenerator
Raises
------
AssertionError
:attr:`BaseDataManager.n_batches` is smaller than or equal to zero
"""
assert self.n_batches > 0
data_loader = self.data_loader_cls(self.dataset,
batch_size=self.batch_size,
num_batches=self.n_batches,
seed=seed,
sampler=self.sampler
)
return MultiThreadedAugmenter(data_loader, self.transforms,
self.n_process_augmentation,
num_cached_per_queue=2,
seeds=self.n_process_augmentation*[seed])
[docs] def get_subset(self, indices):
"""
Returns a Subset of the current datamanager based on given indices
Parameters
----------
indices : iterable
valid indices to extract subset from current dataset
Returns
-------
:class:`BaseDataManager`
manager containing the subset
"""
subset_kwargs = {
"batch_size": self.batch_size,
"n_process_augmentation": self.n_process_augmentation,
"transforms": self.transforms,
"sampler_cls": self.sampler.__class__,
"data_loader_cls": self.data_loader_cls,
"dataset_cls": None,
"load_fn": None,
"from_disc": True
}
return self.__class__(self.dataset.get_subset(indices), **subset_kwargs)
[docs] def update_state_from_dict(self, new_state: dict):
"""
Updates internal state and therfore the behavior from dict.
If a key is not specified, the old attribute value will be used
Parameters
----------
new_state : dict
The dict to update the state from.
Valid keys are:
* ``batch_size``
* ``n_process_augmentation``
* ``data_loader_cls``
* ``sampler``
* ``sampling_kwargs``
* ``transforms``
If a key is not specified, the old value of the corresponding
attribute will be used
Raises
------
KeyError
Invalid keys are specified
"""
# update batch_size if specified
self.batch_size = new_state.pop("batch_size", self.batch_size)
# update n_process_augmentation if specified
self.n_process_augmentation = new_state.pop("n_process_augmentation",
self.n_process_augmentation)
# update data_loader_cls if specified
self.data_loader_cls = new_state.pop("data_loader_cls",
self.data_loader_cls)
# update
new_sampler = new_state.pop("sampler", None)
if new_sampler is not None:
self.sampler = new_sampler.from_dataset(
self.dataset,
**new_state.pop("sampling_kwargs", {}))
self.transforms = new_state.pop("transforms", self.transforms)
if new_state:
raise KeyError("Invalid Keys in new_state given: %s"
% (','.join(map(str, new_state.keys()))))
[docs] @make_deprecated("BaseDataManager.get_subset")
def train_test_split(self, *args, **kwargs):
"""
Calls :method:`AbstractDataset.train_test_split` and returns
a manager for each subset with same configuration as current manager
.. deprecation:: 0.3
method will be removed in next major release
Parameters
----------
*args :
positional arguments for
``sklearn.model_selection.train_test_split``
**kwargs :
keyword arguments for
``sklearn.model_selection.train_test_split``
"""
trainset, valset = self.dataset.train_test_split(*args, **kwargs)
subset_kwargs = {
"batch_size": self.batch_size,
"n_process_augmentation": self.n_process_augmentation,
"transforms": self.transforms,
"sampler_cls": self.sampler.__class__,
"data_loader_cls": self.data_loader_cls,
"dataset_cls": None,
"load_fn": None,
"from_disc": True
}
train_mgr = self.__class__(trainset, **subset_kwargs)
val_mgr = self.__class__(valset, **subset_kwargs)
return train_mgr, val_mgr
@property
def batch_size(self):
"""
Property to access the batchsize
Returns
-------
int
the batchsize
"""
return self._batch_size
@batch_size.setter
def batch_size(self, new_batch_size):
"""
Setter for current batchsize, casts to int before setting the attribute
Parameters
----------
new_batch_size : int, Any
the new batchsize; should be int but can be of any type that can be
casted to an int
"""
self._batch_size = int(new_batch_size)
@property
def n_process_augmentation(self):
"""
Property to access the number of augmentation processes
Returns
-------
int
number of augmentation processes
"""
return self._n_process_augmentation
@n_process_augmentation.setter
def n_process_augmentation(self, new_process_number):
"""
Setter for number of augmentation processes, casts to int before setting
the attribute
Parameters
----------
new_process_number : int, Any
new number of augmentation processes; should be int but can be of
any type that can be casted to an int
"""
self._n_process_augmentation = int(new_process_number)
@property
def transforms(self):
"""
Property to access the current data transforms
Returns
-------
None, ``AbstractTransform``
The transformation, can either be None or an instance of
``AbstractTransform``
"""
return self._transforms
@transforms.setter
def transforms(self, new_transforms):
"""
Setter for data transforms, assert if transforms are of valid type
(either None or instance of ``AbstractTransform``)
Parameters
----------
new_transforms : None, ``AbstractTransform``
the new transforms
"""
assert new_transforms is None or isinstance(new_transforms,
AbstractTransform)
self._transforms = new_transforms
@property
def data_loader_cls(self):
"""
Property to access the current data loader class
Returns
-------
type
Subclass of ``SlimDataLoaderBase``
"""
return self._data_loader_cls
@data_loader_cls.setter
def data_loader_cls(self, new_loader_cls):
"""
Setter for current data loader class, asserts if class is of valid type
(must be a class and a subclass of ``SlimDataLoaderBase``)
Parameters
----------
new_loader_cls : type
the new data loader class
"""
assert inspect.isclass(new_loader_cls) and issubclass(new_loader_cls,
SlimDataLoaderBase)
self._data_loader_cls = new_loader_cls
@property
def dataset(self):
"""
Property to access the current dataset
Returns
-------
:class:`AbstractDataset`
the current dataset
"""
return self._dataset
@dataset.setter
def dataset(self, new_dataset):
"""
Setter for new dataset
Parameters
----------
new_dataset : :class:`AbstractDataset`
"""
assert isinstance(new_dataset, AbstractDataset)
self._dataset = new_dataset
@property
def sampler(self):
"""
Property to access the current sampler
Returns
-------
:class:`AbstractSampler`
the current sampler
"""
return self._sampler
@sampler.setter
def sampler(self, new_sampler):
"""
Setter for current sampler.
If a valid class instance is passed, the sampler is simply assigned, if
a valid class type is passed, the sampler is created from the dataset
Parameters
----------
new_sampler : :class:`AbstractSampler`, type
instance or class object of new sampler
Raises
------
ValueError
Neither a valid class instance nor a valid class type is given
"""
if inspect.isclass(new_sampler) and issubclass(new_sampler,
AbstractSampler):
self._sampler = new_sampler.from_dataset(self.dataset)
elif isinstance(new_sampler, AbstractSampler):
self._sampler = new_sampler
else:
raise ValueError("Given Sampler is neither a subclass of \
AbstractSampler, nor an instance of a sampler ")
@property
def n_samples(self):
"""
Number of Samples
Returns
-------
int
Number of Samples
"""
return len(self.sampler)
@property
def n_batches(self):
"""
Returns Number of Batches based on batchsize,
number of samples and number of processes
Returns
-------
int
Number of Batches
Raises
------
AssertionError
:attr:`BaseDataManager.n_samples` is smaller than or equal to zero
"""
assert self.n_samples > 0
if self.n_process_augmentation == 1:
n_batches = int(np.floor(self.n_samples / self.batch_size))
elif self.n_process_augmentation > 1:
if (self.n_samples / self.batch_size) < self.n_process_augmentation:
self.n_process_augmentation = 1
logger.warning('Too few samples for n_process_augmentation={}. '
'Forcing n_process_augmentation={} '
'instead'.format(self.n_process_augmentation, 1))
n_batches = int(np.floor(self.n_samples / self.batch_size /
self.n_process_augmentation))
else:
raise ValueError('Invalid value for n_process_augmentation')
return n_batches