import abc
import os
from tqdm import tqdm
import numpy as np
from skimage.transform import resize
from sklearn.model_selection import train_test_split
from ..utils import subdirs
class AbstractDataset:
    """
    Base Class for Dataset

    Notes
    -----
    ``_make_dataset`` and ``__getitem__`` are decorated with
    ``abc.abstractmethod``, but this class does not use ``abc.ABCMeta`` as
    metaclass, so the decorators act as markers only and are not enforced
    when instantiating.

    """

    def __init__(self, data_path, load_fn, img_extensions, gt_extensions):
        """

        Parameters
        ----------
        data_path : str
            path to data samples
        load_fn : function
            function to load single sample
        img_extensions : list
            valid extensions of image files
        gt_extensions : list
            valid extensions of label files

        """
        self._img_extensions = img_extensions
        self._gt_extensions = gt_extensions
        self.data_path = data_path
        self._load_fn = load_fn
        self.data = []

    @abc.abstractmethod
    def _make_dataset(self, path):
        """
        Create dataset

        Parameters
        ----------
        path : str
            path to data samples

        Returns
        -------
        list
            data: List of sample paths if lazy; List of samples if not

        """
        pass

    @abc.abstractmethod
    def __getitem__(self, index):
        """
        return data with given index (and loads it before if lazy)

        Parameters
        ----------
        index : int
            index of data

        Returns
        -------
        dict
            data

        """
        pass

    def __len__(self):
        """
        Return number of samples

        Returns
        -------
        int
            number of samples

        """
        return len(self.data)

    def train_test_split(self, *args, **kwargs):
        """
        split dataset into train and test data

        The returned datasets do not copy any samples: each one stores the
        indices assigned to it as its ``data`` and uses this dataset's
        ``__getitem__`` as ``load_fn``. They therefore keep a reference to
        this dataset and return exactly what indexing this dataset returns.

        Parameters
        ----------
        *args :
            positional arguments of ``train_test_split``
        **kwargs :
            keyword arguments of ``train_test_split``

        Returns
        -------
        :class:`BlankDataset`
            train dataset
        :class:`BlankDataset`
            test dataset

        See Also
        --------
        ``sklearn.model_selection.train_test_split``

        """
        train_idxs, test_idxs = train_test_split(
            np.arange(len(self)), *args, **kwargs)

        # Names which must not be forwarded as extra attributes: they either
        # collide with explicit parameters of ``BlankDataset.__init__``
        # (``data``, ``load_fn``, ``load_kwargs``) or would overwrite the
        # loader of the new dataset after it has been set (``_load_fn``).
        reserved = ("data", "load_fn", "load_kwargs", "_load_fn")

        attributes = {}
        for key, val in vars(self).items():
            if key in reserved:
                continue
            if key.startswith("__") and key.endswith("__"):
                continue
            attributes[key] = val

        # Plain Python ints keep the stored indices independent of numpy
        train_dset = BlankDataset([int(idx) for idx in train_idxs],
                                  self.__getitem__, **attributes)
        test_dset = BlankDataset([int(idx) for idx in test_idxs],
                                 self.__getitem__, **attributes)

        return train_dset, test_dset
class BlankDataset(AbstractDataset):
    """
    Blank Dataset loading the data, which has been passed
    in its ``__init__`` by its ``load_fn``

    """

    def __init__(self, data, load_fn, load_kwargs=None, **kwargs):
        """

        Parameters
        ----------
        data : iterable
            data to load
        load_fn : function
            function to load the ``data``
        load_kwargs : dict, optional
            dictionary containing all keyword arguments passed to the
            ``load_fn``; ``None`` (default) is treated as an empty dict
        **kwargs :
            additional keyword arguments (are set as class attribute)

        """
        super().__init__(None, load_fn, None, None)
        self.data = data
        # A fresh dict per instance; a ``{}`` default in the signature would
        # be one object shared by every instance created without this
        # argument
        self.load_kwargs = {} if load_kwargs is None else load_kwargs

        # Applied last, so these may overwrite the attributes set above
        for key, val in kwargs.items():
            setattr(self, key, val)

    def __getitem__(self, index):
        """
        returns single sample corresponding to ``index`` via the ``load_fn``

        Parameters
        ----------
        index : int
            index specifying the data to load

        Returns
        -------
        dict
            dictionary containing a single sample

        """
        return self._load_fn(self.data[index], **self.load_kwargs)

    def __len__(self):
        """
        returns the length of the dataset

        Returns
        -------
        int
            number of samples

        """
        return len(self.data)
class BaseLazyDataset(AbstractDataset):
    """
    Dataset to load data in a lazy way

    """

    def __init__(self, data_path, load_fn, img_extensions, gt_extensions,
                 **load_kwargs):
        """

        Parameters
        ----------
        data_path : str
            path to data samples
        load_fn : function
            function to load single data sample
        img_extensions : list
            valid extensions of image files
        gt_extensions : list
            valid extensions of label files
        **load_kwargs :
            additional loading keyword arguments (image shape,
            channel number, ...); passed to load_fn

        """
        super().__init__(data_path, load_fn, img_extensions, gt_extensions)
        self._load_kwargs = load_kwargs
        self.data = self._make_dataset(self.data_path)

    def _make_dataset(self, path):
        """
        Helper Function to make a dataset containing paths to all images in a
        certain directory

        Parameters
        ----------
        path : str
            path to data samples

        Returns
        -------
        list
            list of samples; each sample is a list holding the image path
            followed by the paths of all existing label files

        Raises
        ------
        AssertionError
            if `path` is not a valid directory

        """
        assert os.path.isdir(path), '%s is not a valid directory' % path

        data = []
        for root, _, fnames in sorted(os.walk(path)):
            # os.walk reports the files of a directory in arbitrary order;
            # sorting makes the sample order reproducible
            for fname in sorted(fnames):
                fpath = os.path.join(root, fname)
                if self._is_valid_image_file(fpath):
                    data.append([fpath] + self._find_label_files(fpath))

        return data

    def _find_label_files(self, fpath):
        """
        Helper Function to collect the existing label files of an image file

        Parameters
        ----------
        fpath : str
            path of the image file

        Returns
        -------
        list
            paths of all existing files which share the image path up to its
            last ``.`` and end with one of the label extensions (ordered
            like the label extensions)

        """
        stem = fpath.rsplit(".", maxsplit=1)[0]
        candidates = [stem + ext for ext in self._gt_extensions]
        return [cand for cand in candidates if os.path.isfile(cand)]

    def _is_valid_image_file(self, fname):
        """
        Helper Function to check whether file is image file and has at least
        one label file

        Parameters
        ----------
        fname : str
            filename of image path

        Returns
        -------
        bool
            is valid data sample

        """
        is_image = any(fname.endswith(ext) for ext in self._img_extensions)
        # the file system is only queried for files with an image extension
        return is_image and bool(self._find_label_files(fname))

    def __getitem__(self, index):
        """
        load data sample specified by index

        Parameters
        ----------
        index : int
            index to specify which data sample to load

        Returns
        -------
        dict
            loaded data sample

        """
        # every entry of a sample becomes one positional argument of load_fn
        data_dict = self._load_fn(*self.data[index], **self._load_kwargs)
        return data_dict
class BaseCacheDataset(AbstractDataset):
    """
    Dataset to preload and cache data

    Notes
    -----
    data needs to fit completely into RAM!

    """

    def __init__(self, data_path, load_fn, img_extensions, gt_extensions,
                 **load_kwargs):
        """

        Parameters
        ----------
        data_path : str
            path to data samples
        load_fn : function
            function to load single data sample
        img_extensions : list
            valid extensions of image files
        gt_extensions : list
            valid extensions of label files
        **load_kwargs :
            additional loading keyword arguments (image shape,
            channel number, ...); passed to load_fn

        """
        super().__init__(data_path, load_fn, img_extensions, gt_extensions)
        self._load_kwargs = load_kwargs
        self.data = self._make_dataset(data_path)

    def _make_dataset(self, path):
        """
        Helper Function to make a dataset containing all samples in a certain
        directory

        Parameters
        ----------
        path: str
            path to data samples

        Returns
        -------
        list
            list of loaded samples (the return values of ``load_fn``)

        Raises
        ------
        AssertionError
            if `path` is not a valid directory

        """
        assert os.path.isdir(path), '%s is not a valid directory' % path

        data = []
        for root, _, fnames in sorted(os.walk(path)):
            # os.walk reports the files of a directory in arbitrary order;
            # sorting makes the sample order reproducible
            for fname in sorted(fnames):
                fpath = os.path.join(root, fname)
                if self._is_valid_image_file(fpath):
                    sample = [fpath] + self._find_label_files(fpath)
                    # image path and label paths are passed as separate
                    # positional arguments
                    data.append(self._load_fn(*sample, **self._load_kwargs))

        return data

    def _find_label_files(self, fpath):
        """
        Helper Function to collect the existing label files of an image file

        Parameters
        ----------
        fpath : str
            path of the image file

        Returns
        -------
        list
            paths of all existing files which share the image path up to its
            last ``.`` and end with one of the label extensions (ordered
            like the label extensions)

        """
        stem = fpath.rsplit(".", maxsplit=1)[0]
        candidates = [stem + ext for ext in self._gt_extensions]
        return [cand for cand in candidates if os.path.isfile(cand)]

    def _is_valid_image_file(self, fname):
        """
        Helper Function to check whether file is image file and has at least
        one label file

        Parameters
        ----------
        fname : str
            filename of image path

        Returns
        -------
        bool
            is valid data sample

        """
        is_image = any(fname.endswith(ext) for ext in self._img_extensions)
        # the file system is only queried for files with an image extension
        return is_image and bool(self._find_label_files(fname))

    def __getitem__(self, index):
        """
        return data sample specified by index

        Parameters
        ----------
        index : int
            index to specify which data sample to return

        Returns
        -------
        dict
            data sample

        """
        data_dict = self.data[index]
        return data_dict
class Nii3DLazyDataset(BaseLazyDataset):
    """
    Dataset to load 3D medical images (e.g. from .nii files) during training

    """

    def __init__(self, data_path, load_fn, img_extensions, gt_extensions,
                 img_files, label_file, **load_kwargs):
        """

        Parameters
        ----------
        data_path : str
            root path to data samples where each sample has its own folder
        load_fn : function
            function to load single data sample
        img_extensions : list
            valid extensions of image files
        gt_extensions : list
            valid extensions of label files
        img_files : list
            list of image filenames
        label_file : string
            label file name
        **load_kwargs :
            additional loading keyword arguments (image shape,
            channel number, ...); passed to load_fn

        """
        # Both attributes are read by ``_make_dataset``, which the parent
        # constructor invokes, so they have to exist before calling it
        self.img_files = img_files
        self.label_file = label_file
        super().__init__(data_path, load_fn, img_extensions, gt_extensions,
                         **load_kwargs)

    def _make_dataset(self, path):
        """
        Helper Function to make a dataset containing all samples in a certain
        directory

        Parameters
        ----------
        path: str
            path to data samples

        Returns
        -------
        list
            one entry per item yielded by ``subdirs(path)``; each entry is a
            single-element list wrapping a dict with the image paths
            (``'img'``) and the label path (``'label'``)

        Raises
        ------
        AssertionError
            if `path` is not a valid directory

        """
        assert os.path.isdir(path)

        samples = []
        # presumably ``subdirs`` yields the sub-directory paths of ``path``
        # (one folder per sample) - verify against ``..utils.subdirs``
        for sample_dir in subdirs(path):
            img_paths = [os.path.join(sample_dir, img_name)
                         for img_name in self.img_files]
            label_path = os.path.join(sample_dir, self.label_file)
            # wrapped in a list, so that star-unpacking the entry passes the
            # dict as one positional argument
            samples.append([{'img': img_paths, 'label': label_path}])

        return samples
class Nii3DCacheDatset(BaseCacheDataset):
    """
    Dataset to load 3D medical images (e.g. from .nii files) before training

    """

    def __init__(self, data_path, load_fn, img_extensions, gt_extensions,
                 img_files, label_file, **load_kwargs):
        """

        Parameters
        ----------
        data_path : str
            root path to data samples where each sample has its own folder
        load_fn : function
            function to load single data sample
        img_extensions : list
            valid extensions of image files
        gt_extensions : list
            valid extensions of label files
        img_files : list
            list of image filenames
        label_file : str
            label file name
        **load_kwargs :
            additional loading keyword arguments (image shape,
            channel number, ...); passed to load_fn

        """
        # Both attributes are read by ``_make_dataset``, which the parent
        # constructor invokes, so they have to exist before calling it
        self.img_files = img_files
        self.label_file = label_file
        super().__init__(data_path, load_fn, img_extensions, gt_extensions,
                         **load_kwargs)

    def _make_dataset(self, path):
        """
        Helper Function to make a dataset containing all samples in a certain
        directory

        Parameters
        ----------
        path: str
            path to data samples

        Returns
        -------
        list
            list of samples (one return value of ``load_fn`` per item
            yielded by ``subdirs(path)``)

        Raises
        ------
        AssertionError
            if `path` is not a valid directory

        """
        assert os.path.isdir(path)

        # presumably ``subdirs`` yields the sub-directory paths of ``path``
        # (one folder per sample) - verify against ``..utils.subdirs``
        progress = tqdm(subdirs(path), unit='samples', desc="Loading samples")

        loaded = []
        for sample_dir in progress:
            img_paths = [os.path.join(sample_dir, img_name)
                         for img_name in self.img_files]
            label_path = os.path.join(sample_dir, self.label_file)
            # load_fn receives the dict of paths as its only positional
            # argument
            sample = self._load_fn({'img': img_paths, 'label': label_path},
                                   **self._load_kwargs)
            loaded.append(sample)

        return loaded
try:
    from torchvision.datasets import CIFAR10, CIFAR100, MNIST, FashionMNIST

    class TorchvisionClassificationDataset(AbstractDataset):
        """
        Wrapper for torchvision classification datasets to provide consistent
        API

        """

        def __init__(self, dataset, root="/tmp/", train=True, download=True,
                     img_shape=(28, 28), **kwargs):
            """

            Parameters
            ----------
            dataset : str
                Defines the dataset to use.
                must be one of
                ['mnist', 'emnist', 'fashion_mnist', 'cifar10', 'cifar100']
            root : str
                path dataset (If download is True: dataset will be extracted
                here; else: path to extracted dataset)
            train : bool
                whether to use the train or the testset
            download : bool
                whether or not to download the dataset
                (If already downloaded at specified path,
                it won't be downloaded again)
            img_shape : tuple
                Height and width of output images (will be interpolated)
            **kwargs :
                Additional keyword arguments passed to the torchvision dataset
                class for initialization

            """
            super().__init__("", None, [], [])

            # these attributes are read by ``_make_dataset`` and therefore
            # have to be set before calling it
            self.download = download
            self.train = train
            self.root = root
            self.img_shape = img_shape
            self.data = self._make_dataset(dataset, **kwargs)

        def _make_dataset(self, dataset, **kwargs):
            """
            Create the actual dataset

            Parameters
            ----------
            dataset: str
                Defines the dataset to use (case-insensitive).
                must be one of
                ['mnist', 'emnist', 'fashion_mnist', 'cifar10', 'cifar100']
            **kwargs :
                Additional keyword arguments passed to the torchvision dataset
                class for initialization

            Returns
            -------
            torchvision.Dataset
                actual Dataset

            Raises
            ------
            KeyError
                Dataset string does not specify a valid dataset

            """
            name = dataset.lower()

            if name == "mnist":
                _dataset_cls = MNIST
            elif name == "emnist":
                # imported on demand: ``EMNIST`` is not part of the
                # module-level import, which stays unchanged so that
                # importing this module does not depend on it
                from torchvision.datasets import EMNIST
                _dataset_cls = EMNIST
            elif name == "fashion_mnist":
                _dataset_cls = FashionMNIST
            elif name == "cifar10":
                _dataset_cls = CIFAR10
            elif name == "cifar100":
                _dataset_cls = CIFAR100
            else:
                raise KeyError("Dataset %s not found!" % name)

            return _dataset_cls(root=self.root, train=self.train,
                                download=self.download, **kwargs)

        def __getitem__(self, index):
            """
            return data sample specified by index

            Parameters
            ----------
            index : int
                index to specify which data sample to return

            Returns
            -------
            dict
                data sample with the keys ``"data"`` (resized float32 image,
                last image axis moved to the front) and ``"label"`` (float32
                array with a single element)

            """
            sample = self.data[index]

            img = resize(np.array(sample[0]), self.img_shape,
                         mode='reflect', anti_aliasing=True)

            # Only images without a channel axis get one appended; an image
            # which already has three axes must not receive a second one
            if img.ndim < 3:
                img = img.reshape(*img.shape, 1)

            # move the last axis to the front
            img = img.transpose((img.ndim - 1, *range(img.ndim - 1)))

            # ``np.asarray`` accepts plain ints as well as tensors, whereas
            # calling ``.numpy()`` fails for labels returned as Python ints
            label = np.asarray(sample[1]).reshape(1).astype(np.float32)

            return {"data": img.astype(np.float32), "label": label}

        def __len__(self):
            """
            Return Number of samples

            Returns
            -------
            int
                number of samples

            """
            return len(self.data)

except ImportError as e:
    # torchvision is mandatory for this module: the error is passed on
    raise e