Source code for delira.data_loading.dataset

import abc
import os
import typing
import numpy as np
from tqdm import tqdm
from skimage.transform import resize
from sklearn.model_selection import train_test_split
from delira import get_backends

from ..utils import subdirs
from ..utils.decorators import make_deprecated


class AbstractDataset:
    """
    Base Class for Dataset
    """

    def __init__(self, data_path: str, load_fn: typing.Callable):
        """
        Parameters
        ----------
        data_path : str
            path to data samples
        load_fn : function
            function to load single sample
        """
        self.data_path = data_path
        self._load_fn = load_fn
        self.data = []
    @abc.abstractmethod
    def _make_dataset(self, path: str):
        """
        Create dataset

        Parameters
        ----------
        path : str
            path to data samples

        Returns
        -------
        list
            data: List of sample paths if lazy; List of samples if not
        """
        pass
    @abc.abstractmethod
    def __getitem__(self, index):
        """
        return data with given index (and loads it before if lazy)

        Parameters
        ----------
        index : int
            index of data

        Returns
        -------
        dict
            data
        """
        pass

    def __len__(self):
        """
        Return number of samples

        Returns
        -------
        int
            number of samples
        """
        return len(self.data)

    def __iter__(self):
        """
        Return an iterator for the dataset

        Returns
        -------
        _DatasetIter
            iterator yielding one sample at a time
        """
        return _DatasetIter(self)
    def get_sample_from_index(self, index):
        """
        Returns the data sample for a given index
        (without any loading if it would be necessary)

        This implements the base case and can be subclassed for index
        mappings. The actual loading behaviour (lazy or cached) should be
        implemented in ``__getitem__``

        See Also
        --------
        :meth:`ConcatDataset.get_sample_from_index`
        :meth:`BaseLazyDataset.__getitem__`
        :meth:`BaseCacheDataset.__getitem__`

        Parameters
        ----------
        index : int
            index corresponding to targeted sample

        Returns
        -------
        Any
            sample corresponding to given index
        """
        return self.data[index]
    def get_subset(self, indices):
        """
        Returns a subset of the current dataset based on given indices

        Parameters
        ----------
        indices : iterable
            valid indices to extract subset from current dataset

        Returns
        -------
        :class:`BlankDataset`
            the subset
        """
        # extract other important attributes from current dataset
        kwargs = {}
        for key, val in vars(self).items():
            if not (key.startswith("__") and key.endswith("__")):
                if key == "data":
                    continue
                kwargs[key] = val

        kwargs["old_getitem"] = self.__class__.__getitem__
        subset_data = [self.get_sample_from_index(idx) for idx in indices]

        return BlankDataset(subset_data, **kwargs)
[docs] @make_deprecated("Dataset.get_subset") def train_test_split(self, *args, **kwargs): """ split dataset into train and test data .. deprecated:: 0.3 method will be removed in next major release Parameters ---------- *args : positional arguments of ``train_test_split`` **kwargs : keyword arguments of ``train_test_split`` Returns ------- :class:`BlankDataset` train dataset :class:`BlankDataset` test dataset See Also -------- ``sklearn.model_selection.train_test_split`` """ train_idxs, test_idxs = train_test_split( np.arange(len(self)), *args, **kwargs) return self.get_subset(train_idxs), self.get_subset(test_idxs)
class _DatasetIter(object):
    """
    Iterator for dataset
    """

    def __init__(self, dset):
        """
        Parameters
        ----------
        dset : :class:`AbstractDataset`
            the dataset which should be iterated
        """
        self._dset = dset
        self._curr_index = 0

    def __iter__(self):
        return self

    def __next__(self):
        if self._curr_index >= len(self._dset):
            raise StopIteration

        sample = self._dset[self._curr_index]
        self._curr_index += 1
        return sample


class BlankDataset(AbstractDataset):
    """
    Blank Dataset which returns the data passed to its ``__init__`` via
    the ``old_getitem`` method of the originating dataset
    """

    def __init__(self, data, old_getitem, **kwargs):
        """
        Parameters
        ----------
        data : iterable
            data to load
        old_getitem : function
            get item method of previous dataset
        **kwargs :
            additional keyword arguments (are set as class attributes)
        """
        super().__init__(None, None)

        self.data = data
        self._old_getitem = old_getitem

        for key, val in kwargs.items():
            setattr(self, key, val)

    def __getitem__(self, index):
        """
        returns single sample corresponding to ``index`` via the
        ``old_getitem`` method

        Parameters
        ----------
        index : int
            index specifying the data to load

        Returns
        -------
        dict
            dictionary containing a single sample
        """
        return self._old_getitem(self, index)

    def __len__(self):
        """
        returns the length of the dataset

        Returns
        -------
        int
            number of samples
        """
        return len(self.data)
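

# A minimal usage sketch (not part of the original module): ``get_subset``
# copies a dataset's non-dunder attributes plus its ``__getitem__`` into a
# ``BlankDataset``, so the subset loads samples exactly like its parent.
# The toy ``RangeDataset`` below is purely hypothetical.
def _example_get_subset():
    class RangeDataset(AbstractDataset):
        def __init__(self, n):
            super().__init__(None, None)
            self.data = [{"data": i} for i in range(n)]

        def __getitem__(self, index):
            return self.get_sample_from_index(index)

    dset = RangeDataset(10)
    subset = dset.get_subset([0, 2, 4])  # BlankDataset with 3 samples
    assert len(subset) == 3 and subset[1] == {"data": 2}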
class BaseCacheDataset(AbstractDataset):
    """
    Dataset to preload and cache data

    Notes
    -----
    data needs to fit completely into RAM!
    """

    def __init__(self, data_path: typing.Union[str, list],
                 load_fn: typing.Callable, **load_kwargs):
        """
        Parameters
        ----------
        data_path : str or list
            if data_path is a string, ``load_fn`` is called for all items
            inside the specified directory;
            if data_path is a list, ``load_fn`` is called for all elements
            in the list
        load_fn : function
            function to load a single data sample
        **load_kwargs :
            additional loading keyword arguments (image shape, channel
            number, ...); passed to ``load_fn``
        """
        super().__init__(data_path, load_fn)
        self._load_kwargs = load_kwargs
        self.data = self._make_dataset(data_path)
    def _make_dataset(self, path: typing.Union[str, list]):
        """
        Helper Function to make a dataset containing all samples in a
        certain directory

        Parameters
        ----------
        path : str or list
            if path is a string, ``load_fn`` is called for all items
            inside the specified directory;
            if path is a list, ``load_fn`` is called for all elements in
            the list

        Returns
        -------
        list
            list of items which were returned from ``load_fn``
            (typically dict)

        Raises
        ------
        AssertionError
            if ``path`` is not a list and is not a valid directory
        """
        data = []
        if isinstance(path, list):
            # iterate over all elements of the list
            for p in tqdm(path, unit='samples', desc="Loading samples"):
                data.append(self._load_fn(p, **self._load_kwargs))
        else:
            # call load_fn for all items inside the directory
            assert os.path.isdir(path), '%s is not a valid directory' % path
            for p in tqdm(os.listdir(path), unit='samples',
                          desc="Loading samples"):
                data.append(self._load_fn(os.path.join(path, p),
                                          **self._load_kwargs))
        return data
    def __getitem__(self, index):
        """
        return data sample specified by index

        Parameters
        ----------
        index : int
            index to specify which data sample to return

        Returns
        -------
        dict
            data sample
        """
        data_dict = self.get_sample_from_index(index)
        return data_dict
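

# A minimal usage sketch (not part of the original module), assuming a
# hypothetical ``load_sample`` function: with a list as ``data_path`` the
# loader is called once per element and every result is cached up front,
# so indexing afterwards does no I/O.
def _example_cache_dataset():
    def load_sample(path, **kwargs):
        # stand-in loader; a real one would e.g. read an image from disk
        return {"data": np.zeros((3, 32, 32)), "path": path}

    dset = BaseCacheDataset(["a.npy", "b.npy"], load_sample)
    assert len(dset) == 2
    sample = dset[0]  # already in RAM, no loading on access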
class BaseLazyDataset(AbstractDataset):
    """
    Dataset to load data in a lazy way
    """

    def __init__(self, data_path: typing.Union[str, list],
                 load_fn: typing.Callable, **load_kwargs):
        """
        Parameters
        ----------
        data_path : str or list
            if data_path is a string, ``load_fn`` is called for all items
            inside the specified directory;
            if data_path is a list, ``load_fn`` is called for all elements
            in the list
        load_fn : function
            function to load a single data sample
        **load_kwargs :
            additional loading keyword arguments (image shape, channel
            number, ...); passed to ``load_fn``
        """
        super().__init__(data_path, load_fn)
        self._load_kwargs = load_kwargs
        self.data = self._make_dataset(self.data_path)
    def _make_dataset(self, path: typing.Union[str, list]):
        """
        Helper Function to make a dataset containing paths to all images
        in a certain directory

        Parameters
        ----------
        path : str or list
            path to data samples

        Returns
        -------
        list
            list of sample paths

        Raises
        ------
        AssertionError
            if ``path`` is not a list and is not a valid directory
        """
        if isinstance(path, list):
            # generate list from iterable
            data = list(path)
        else:
            # generate list from all items inside the directory
            assert os.path.isdir(path), '%s is not a valid directory' % path
            data = [os.path.join(path, p) for p in os.listdir(path)]
        return data
    def __getitem__(self, index):
        """
        load data sample specified by index

        Parameters
        ----------
        index : int
            index to specify which data sample to load

        Returns
        -------
        dict
            loaded data sample
        """
        data_dict = self._load_fn(self.get_sample_from_index(index),
                                  **self._load_kwargs)
        return data_dict
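

# A minimal usage sketch (not part of the original module): the lazy
# variant only stores the sample paths at construction time and calls the
# (hypothetical) loader again on every ``__getitem__``.
def _example_lazy_dataset():
    def load_sample(path, **kwargs):
        return {"data": np.zeros((3, 32, 32)), "path": path}

    dset = BaseLazyDataset(["a.npy", "b.npy"], load_sample)
    assert dset.data == ["a.npy", "b.npy"]  # only the paths are kept
    sample = dset[1]  # loading happens here, on access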
class BaseExtendCacheDataset(BaseCacheDataset):
    """
    Dataset to preload and cache data. The function to load samples is
    expected to return an iterable which can contain multiple samples

    Notes
    -----
    data needs to fit completely into RAM!
    """

    def __init__(self, data_path: typing.Union[str, list],
                 load_fn: typing.Callable, **load_kwargs):
        """
        Parameters
        ----------
        data_path : str or list
            if data_path is a string, ``load_fn`` is called for all items
            inside the specified directory;
            if data_path is a list, ``load_fn`` is called for all elements
            in the list
        load_fn : function
            function to load multiple data samples at once. Needs to
            return an iterable which extends the internal list.
        **load_kwargs :
            additional loading keyword arguments (image shape, channel
            number, ...); passed to ``load_fn``

        See Also
        --------
        :class:`BaseCacheDataset`
        """
        super().__init__(data_path, load_fn, **load_kwargs)

    def _make_dataset(self, path: typing.Union[str, list]):
        """
        Helper Function to make a dataset containing all samples in a
        certain directory

        Parameters
        ----------
        path : str or iterable
            if path is a string, ``load_fn`` is called for all items
            inside the specified directory;
            if path is a list, ``load_fn`` is called for all elements in
            the list

        Returns
        -------
        list
            list of items which were returned from ``load_fn``
            (typically dict)

        Raises
        ------
        AssertionError
            if ``path`` is not a list and is not a valid directory
        """
        data = []
        if isinstance(path, list):
            # iterate over all elements of the list
            for p in tqdm(path, unit='samples', desc="Loading samples"):
                data.extend(self._load_fn(p, **self._load_kwargs))
        else:
            # call load_fn for all items inside the directory
            assert os.path.isdir(path), '%s is not a valid directory' % path
            for p in tqdm(os.listdir(path), unit='samples',
                          desc="Loading samples"):
                data.extend(self._load_fn(os.path.join(path, p),
                                          **self._load_kwargs))
        return data
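

# A minimal usage sketch (not part of the original module): here the
# (hypothetical) loader returns several samples per input element, e.g.
# all slices of a 3D volume, and the internal list is extended with them.
def _example_extend_cache_dataset():
    def load_slices(path, **kwargs):
        # stand-in: pretend every file contributes two 2D slices
        return [{"data": np.zeros((1, 32, 32)), "slice": i}
                for i in range(2)]

    dset = BaseExtendCacheDataset(["vol_0.nii", "vol_1.nii"], load_slices)
    assert len(dset) == 4  # 2 files x 2 slices each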
class ConcatDataset(AbstractDataset):
    def __init__(self, *datasets):
        """
        Concatenate multiple datasets into one

        Parameters
        ----------
        datasets :
            variable number of datasets
        """
        super().__init__(None, None)

        # TODO: why should datasets[0] be a list and not an AbstractDataset?
        # if a single list of datasets was passed instead of multiple
        # positional datasets, unpack it
        if (len(datasets) == 1) and isinstance(datasets[0], list):
            datasets = datasets[0]

        self.data = datasets
    def get_sample_from_index(self, index):
        """
        Returns the data sample for a given index
        (without any loading if it would be necessary)

        This method implements the index mapping of a global index to the
        subindices for each dataset. The actual loading behaviour (lazy or
        cached) should be implemented in ``__getitem__``

        See Also
        --------
        :meth:`AbstractDataset.get_sample_from_index`
        :meth:`BaseLazyDataset.__getitem__`
        :meth:`BaseCacheDataset.__getitem__`

        Parameters
        ----------
        index : int
            index corresponding to targeted sample

        Returns
        -------
        Any
            sample corresponding to given index
        """
        curr_max_index = 0
        for dset in self.data:
            prev_max_index = curr_max_index
            curr_max_index += len(dset)

            if prev_max_index <= index < curr_max_index:
                return dset[index - prev_max_index]

        raise IndexError("Index %d is out of range for %d items in datasets"
                         % (index, len(self)))
    def __getitem__(self, index):
        return self.get_sample_from_index(index)

    def __len__(self):
        return sum(len(dset) for dset in self.data)
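

# A minimal usage sketch (not part of the original module): a global index
# into the concatenation is mapped to (dataset, local index), so index 2
# below addresses the first element of the second dataset. The loader is
# hypothetical.
def _example_concat_dataset():
    def load_sample(path, **kwargs):
        return {"path": path}

    dset_a = BaseCacheDataset(["a0", "a1"], load_sample)
    dset_b = BaseCacheDataset(["b0", "b1"], load_sample)
    concat = ConcatDataset(dset_a, dset_b)
    assert len(concat) == 4
    assert concat[2] == {"path": "b0"}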
@make_deprecated('Will be removed in favour of LoadSample function.')
class Nii3DLazyDataset(BaseLazyDataset):
    """
    Dataset to load 3D medical images (e.g. from .nii files) during
    training
    """

    def __init__(self, data_path, load_fn, img_extensions, gt_extensions,
                 img_files, label_file, **load_kwargs):
        """
        Parameters
        ----------
        data_path : str
            root path to data samples where each sample has its own folder
        load_fn : function
            function to load single data sample
        img_extensions : list
            valid extensions of image files
        gt_extensions : list
            valid extensions of label files
        img_files : list
            list of image filenames
        label_file : str
            label file name
        **load_kwargs :
            additional loading keyword arguments (image shape, channel
            number, ...); passed to ``load_fn``
        """
        self.img_files = img_files
        self.label_file = label_file
        super().__init__(data_path, load_fn, **load_kwargs)

    def _make_dataset(self, path):
        """
        Helper Function to make a dataset containing all samples in a
        certain directory

        Parameters
        ----------
        path : str
            path to data samples

        Returns
        -------
        list
            list of sample paths

        Raises
        ------
        AssertionError
            if ``path`` is not a valid directory
        """
        assert os.path.isdir(path), '%s is not a valid directory' % path
        data = [[{'img': [os.path.join(t, i) for i in self.img_files],
                  'label': os.path.join(t, self.label_file)}]
                for t in subdirs(path)]
        return data


@make_deprecated('Will be removed in favour of LoadSample function.')
class Nii3DCacheDatset(BaseCacheDataset):
    """
    Dataset to load 3D medical images (e.g. from .nii files) before
    training
    """

    def __init__(self, data_path, load_fn, img_extensions, gt_extensions,
                 img_files, label_file, **load_kwargs):
        """
        Parameters
        ----------
        data_path : str
            root path to data samples where each sample has its own folder
        load_fn : function
            function to load single data sample
        img_extensions : list
            valid extensions of image files
        gt_extensions : list
            valid extensions of label files
        img_files : list
            list of image filenames
        label_file : str
            label file name
        **load_kwargs :
            additional loading keyword arguments (image shape, channel
            number, ...); passed to ``load_fn``
        """
        self.img_files = img_files
        self.label_file = label_file
        super().__init__(data_path, load_fn, **load_kwargs)

    def _make_dataset(self, path):
        """
        Helper Function to make a dataset containing all samples in a
        certain directory

        Parameters
        ----------
        path : str
            path to data samples

        Returns
        -------
        list
            list of samples

        Raises
        ------
        AssertionError
            if ``path`` is not a valid directory
        """
        assert os.path.isdir(path), '%s is not a valid directory' % path
        data = []
        for s in tqdm(subdirs(path), unit='samples',
                      desc="Loading samples"):
            files = {'img': [os.path.join(s, i) for i in self.img_files],
                     'label': os.path.join(s, self.label_file)}
            data.append(self._load_fn(files, **self._load_kwargs))
        return data


if "TORCH" in get_backends():
    from torchvision.datasets import CIFAR10, CIFAR100, EMNIST, MNIST, \
        FashionMNIST

    class TorchvisionClassificationDataset(AbstractDataset):
        """
        Wrapper for torchvision classification datasets to provide a
        consistent API
        """

        def __init__(self, dataset, root="/tmp/", train=True,
                     download=True, img_shape=(28, 28), one_hot=False,
                     **kwargs):
            """
            Parameters
            ----------
            dataset : str
                Defines the dataset to use;
                must be one of ['mnist', 'emnist', 'fashion_mnist',
                'cifar10', 'cifar100']
            root : str
                path to the dataset (if download is True, the dataset will
                be extracted here; otherwise, path to the already
                extracted dataset)
            train : bool
                whether to use the train or the test set
            download : bool
                whether or not to download the dataset (if it has already
                been downloaded to the specified path, it won't be
                downloaded again)
            img_shape : tuple
                height and width of the output images (will be
                interpolated)
            one_hot : bool
                whether to return the labels in one-hot format
            **kwargs :
                additional keyword arguments passed to the torchvision
                dataset class for initialization
            """
            super().__init__("", None)

            self.download = download
            self.train = train
            self.root = root
            self.img_shape = img_shape
            self.num_classes = None
            self.one_hot = one_hot

            self.data = self._make_dataset(dataset, **kwargs)

        def _make_dataset(self, dataset, **kwargs):
            """
            Create the actual dataset

            Parameters
            ----------
            dataset : str
                Defines the dataset to use;
                must be one of ['mnist', 'emnist', 'fashion_mnist',
                'cifar10', 'cifar100']
            **kwargs :
                additional keyword arguments passed to the torchvision
                dataset class for initialization

            Returns
            -------
            torchvision.Dataset
                actual dataset

            Raises
            ------
            KeyError
                if the dataset string does not specify a valid dataset
            """
            if dataset.lower() == "mnist":
                _dataset_cls = MNIST
                self.num_classes = 10
            elif dataset.lower() == "emnist":
                _dataset_cls = EMNIST
                # TODO: EMNIST requires 'split' as kwarg. Search for
                # 'split' in kwargs and update self.num_classes
                # accordingly:
                # https://pytorch.org/docs/stable/torchvision/datasets.html#torchvision.datasets.EMNIST
                self.num_classes = None
            elif dataset.lower() == "fashion_mnist":
                _dataset_cls = FashionMNIST
                self.num_classes = 10
            elif dataset.lower() == "cifar10":
                _dataset_cls = CIFAR10
                self.num_classes = 10
            elif dataset.lower() == "cifar100":
                _dataset_cls = CIFAR100
                self.num_classes = 100
            else:
                raise KeyError("Dataset %s not found!" % dataset.lower())

            return _dataset_cls(root=self.root, train=self.train,
                                download=self.download, **kwargs)

        def __getitem__(self, index):
            """
            return data sample specified by index

            Parameters
            ----------
            index : int
                index to specify which data sample to return

            Returns
            -------
            dict
                data sample
            """
            data = self.data[index]
            data_dict = {"data": np.array(data[0]),
                         "label": data[1].reshape(1).astype(np.float32)}

            if self.one_hot:
                # TODO: remove and refer to the batchgenerators transform:
                # https://github.com/MIC-DKFZ/batchgenerators/blob/master/batchgenerators/transforms/utility_transforms.py#L97
                def make_onehot(num_classes, labels):
                    """
                    Convert labels from label-encoding to one-hot format.
                    Parameters
                    ----------
                    num_classes : int
                        number of classes present in the dataset
                    labels : np.ndarray
                        labels in label-encoding format

                    Returns
                    -------
                    np.ndarray
                        labels in one-hot format
                    """
                    if isinstance(labels, (list, int)):
                        labels = np.asarray(labels)
                    assert isinstance(labels, np.ndarray)

                    if len(labels.shape) > 1:
                        one_hot = np.zeros(
                            shape=(list(labels.shape) + [num_classes]),
                            dtype=labels.dtype)
                        for i, c in enumerate(np.arange(num_classes)):
                            one_hot[..., i][labels == c] = 1
                    else:
                        one_hot = np.zeros(shape=([num_classes]),
                                           dtype=labels.dtype)
                        for i, c in enumerate(np.arange(num_classes)):
                            if labels == c:
                                one_hot[i] = 1
                    return one_hot

                data_dict['label'] = make_onehot(self.num_classes,
                                                 data_dict['label'])

            img = data_dict["data"]
            img = resize(img, self.img_shape, mode='reflect',
                         anti_aliasing=True)

            # add a channel dimension if necessary and convert from
            # channels-last to channels-first layout
            if len(img.shape) <= 3:
                img = img.reshape(*img.shape, 1)
            img = img.transpose(
                (len(img.shape) - 1, *range(len(img.shape) - 1)))
            data_dict["data"] = img.astype(np.float32)

            return data_dict

        def __len__(self):
            """
            Return number of samples

            Returns
            -------
            int
                number of samples
            """
            return len(self.data)
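

    # A minimal usage sketch (not part of the original module): downloads
    # MNIST to ``root`` on first use; samples come back as dicts with a
    # channels-first float32 image resized to ``img_shape``.
    def _example_torchvision_dataset():
        dset = TorchvisionClassificationDataset(
            "mnist", root="/tmp/", train=False, download=True,
            img_shape=(32, 32))
        sample = dset[0]
        assert sample["data"].shape == (1, 32, 32)
        assert sample["data"].dtype == np.float32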