import logging
import numpy as np
import typing
from batchgenerators.dataloading import SlimDataLoaderBase, \
MultiThreadedAugmenter
from torch.utils.data import ConcatDataset
from .dataset import AbstractDataset, BaseCacheDataset, BaseLazyDataset
from .data_loader import BaseDataLoader
from .load_utils import default_load_fn_2d
from .sampler import SequentialSampler
logger = logging.getLogger(__name__)
class BaseDataManager(object):
    """
    Class to handle data.

    Creates Dataset, DataLoader and BatchGenerator.
    """

    def __init__(self, data, batch_size, n_process_augmentation,
                 transforms, sampler_cls=SequentialSampler,
                 data_loader_cls=None, dataset_cls=None,
                 load_fn=default_load_fn_2d, from_disc=True, **kwargs):
        """
        Parameters
        ----------
        data : str or AbstractDataset
            if str: path to data samples
            if dataset: used directly as the manager's dataset
        batch_size : int
            number of samples per batch
        n_process_augmentation : int
            number of processes for augmentations
        transforms :
            data transformations for augmentation
        sampler_cls : AbstractSampler
            class defining the sampling strategy
        data_loader_cls : subclass of SlimDataLoaderBase
            DataLoader class
        dataset_cls : subclass of AbstractDataset
            Dataset class
        load_fn : function
            function to load a single sample
        from_disc : bool
            whether or not to load data from disc just at the time it is
            needed (lazy loading)
        **kwargs :
            other keyword arguments (needed for dataloading and passed to
            ``dataset_cls``)

        Raises
        ------
        AssertionError
            * `data_loader_cls` is not :obj:`None` and not a subclass of
              `SlimDataLoaderBase`
            * `dataset_cls` is not :obj:`None` and not a subclass of
              :class:`.AbstractDataset`

        See Also
        --------
        :class:`AbstractDataset`
        """
        self.batch_size = batch_size
        self.n_process_augmentation = n_process_augmentation
        self.transforms = transforms

        if data_loader_cls is None:
            logger.info("No DataLoader Class specified. Using BaseDataLoader")
            data_loader_cls = BaseDataLoader
        else:
            # BUGFIX: assertion message previously read "dater_loader_cls"
            assert issubclass(data_loader_cls, SlimDataLoaderBase), \
                "data_loader_cls must be subclass of SlimDataLoaderBase"
        self.data_loader_cls = data_loader_cls

        if isinstance(data, AbstractDataset):
            # an already-built dataset is wrapped as-is; dataset_cls,
            # load_fn and from_disc are ignored in this case
            self.dataset = data
        else:
            if dataset_cls is None:
                # lazy loading from disc vs. caching everything up front
                dataset_cls = BaseLazyDataset if from_disc \
                    else BaseCacheDataset
                logger.info("No DataSet Class specified. Using %s instead",
                            dataset_cls.__name__)
            else:
                assert issubclass(dataset_cls, AbstractDataset), \
                    "dataset_cls must be subclass of AbstractDataset"
            self.dataset = dataset_cls(data, load_fn, **kwargs)

        self.sampler = sampler_cls.from_dataset(self.dataset)

    def get_batchgen(self, seed=1):
        """
        Create DataLoader and Batchgenerator.

        Parameters
        ----------
        seed : int
            seed for Random Number Generator

        Returns
        -------
        MultiThreadedAugmenter
            Batchgenerator

        Raises
        ------
        AssertionError
            :attr:`BaseDataManager.n_batches` is smaller than or equal to
            zero
        """
        assert self.n_batches > 0
        data_loader = self.data_loader_cls(self.dataset,
                                           batch_size=self.batch_size,
                                           num_batches=self.n_batches,
                                           seed=seed,
                                           sampler=self.sampler
                                           )
        return MultiThreadedAugmenter(data_loader, self.transforms,
                                      self.n_process_augmentation,
                                      num_cached_per_queue=2,
                                      seeds=self.n_process_augmentation
                                      * [seed])

    def train_test_split(self, *args, **kwargs):
        """
        Calls :meth:`AbstractDataset.train_test_split` and returns
        a manager for each subset with the same configuration as the
        current manager.

        Parameters
        ----------
        *args :
            positional arguments for
            ``sklearn.model_selection.train_test_split``
        **kwargs :
            keyword arguments for
            ``sklearn.model_selection.train_test_split``

        Returns
        -------
        BaseDataManager
            manager for the train subset
        BaseDataManager
            manager for the validation subset
        """
        trainset, valset = self.dataset.train_test_split(*args, **kwargs)
        # subsets are already datasets, so dataset_cls/load_fn are unused
        subset_kwargs = {
            "batch_size": self.batch_size,
            "n_process_augmentation": self.n_process_augmentation,
            "transforms": self.transforms,
            "sampler_cls": self.sampler.__class__,
            "data_loader_cls": self.data_loader_cls,
            "dataset_cls": None,
            "load_fn": None,
            "from_disc": True
        }
        train_mgr = self.__class__(trainset, **subset_kwargs)
        val_mgr = self.__class__(valset, **subset_kwargs)
        # BUGFIX: the created managers were previously built and then
        # silently discarded (no return statement)
        return train_mgr, val_mgr

    @property
    def n_samples(self):
        """
        Number of Samples.

        Returns
        -------
        int
            Number of Samples
        """
        return len(self.sampler)

    @property
    def n_batches(self):
        """
        Returns Number of Batches based on batchsize,
        number of samples and number of processes.

        Returns
        -------
        int
            Number of Batches

        Raises
        ------
        AssertionError
            :attr:`BaseDataManager.n_samples` is smaller than or equal to
            zero
        ValueError
            :attr:`BaseDataManager.n_process_augmentation` is smaller than
            one
        """
        assert self.n_samples > 0
        if self.n_process_augmentation == 1:
            n_batches = int(np.floor(self.n_samples / self.batch_size))
        elif self.n_process_augmentation > 1:
            if (self.n_samples / self.batch_size) < \
                    self.n_process_augmentation:
                # BUGFIX: log the *original* value before overwriting it;
                # previously the attribute was reassigned first, so the
                # warning always reported "n_process_augmentation=1"
                logger.warning('Too few samples for '
                               'n_process_augmentation=%d. '
                               'Forcing n_process_augmentation=%d instead',
                               self.n_process_augmentation, 1)
                self.n_process_augmentation = 1
            n_batches = int(np.floor(self.n_samples / self.batch_size /
                                     self.n_process_augmentation))
        else:
            raise ValueError('Invalid value for n_process_augmentation')
        return n_batches
class ConcatDataManager(object):
    """
    Class to concatenate DataManagers.
    """

    def __init__(self, datamanager: typing.List["BaseDataManager"]):
        """
        Parameters
        ----------
        datamanager : list
            the datamanagers which should be concatenated
            (all attributes except the dataset are extracted
            from the first manager inside the list)
        """
        # BUGFIX: the parameter was declared as
        # ``datamanager=typing.List[BaseDataManager]`` (a default value of
        # the typing construct itself, which is not iterable) instead of a
        # type annotation
        self.dataset = ConcatDataset(
            [tmp.dataset for tmp in datamanager])
        # all configuration comes from the first manager in the list
        self.data_loader_cls = datamanager[0].data_loader_cls
        self.batch_size = datamanager[0].batch_size
        self.n_process_augmentation = datamanager[0].n_process_augmentation
        self.transforms = datamanager[0].transforms
        # a fresh sampler of the same type, spanning the concatenated set
        self.sampler = datamanager[0].sampler.__class__.from_dataset(
            self.dataset
        )

    def get_batchgen(self, seed=1):
        """
        Create DataLoader and Batchgenerator.

        Parameters
        ----------
        seed : int
            seed for Random Number Generator

        Returns
        -------
        MultiThreadedAugmenter
            Batchgenerator

        Raises
        ------
        AssertionError
            :attr:`ConcatDataManager.n_batches` is smaller than or equal
            to zero
        """
        assert self.n_batches > 0
        data_loader = self.data_loader_cls(self.dataset,
                                           batch_size=self.batch_size,
                                           num_batches=self.n_batches,
                                           seed=seed,
                                           sampler=self.sampler
                                           )
        return MultiThreadedAugmenter(data_loader, self.transforms,
                                      self.n_process_augmentation,
                                      num_cached_per_queue=2,
                                      seeds=self.n_process_augmentation
                                      * [seed])

    @property
    def n_samples(self):
        """
        Number of Samples.

        Returns
        -------
        int
            Number of Samples
        """
        return len(self.sampler)

    @property
    def n_batches(self):
        """
        Returns Number of Batches based on batchsize,
        number of samples and number of processes.

        Returns
        -------
        int
            Number of Batches

        Raises
        ------
        AssertionError
            :attr:`ConcatDataManager.n_samples` is smaller than or equal
            to zero
        ValueError
            :attr:`ConcatDataManager.n_process_augmentation` is smaller
            than one
        """
        assert self.n_samples > 0
        if self.n_process_augmentation == 1:
            n_batches = int(np.floor(self.n_samples / self.batch_size))
        elif self.n_process_augmentation > 1:
            n_batches = int(np.floor(
                self.n_samples / self.batch_size /
                self.n_process_augmentation))
        else:
            # BUGFIX: message previously referenced a non-existent
            # attribute name ('n_process')
            raise ValueError('Invalid value for n_process_augmentation')
        return n_batches