Source code for delira.data_loading.sampler.sequential_sampler

from collections import OrderedDict
from random import shuffle

from delira.data_loading.sampler.abstract_sampler import AbstractSampler
from delira.data_loading.dataset import AbstractDataset


class SequentialSampler(AbstractSampler):
    """
    Implements Sequential Sampling from the whole Dataset
    """

    def __init__(self, indices):
        """
        Parameters
        ----------
        indices : list
            list of classes each sample belongs to. List index corresponds
            to data index and the value at a certain index indicates the
            corresponding class

        """
        super().__init__(indices)

        self._indices = list(range(len(indices)))
    def _get_indices(self, n_indices):
        """
        Actual Sampling

        Parameters
        ----------
        n_indices : int
            number of indices to return

        Raises
        ------
        StopIteration : If end of dataset reached

        Returns
        -------
        list
            list of sampled indices
        """
        n_indices = self._check_batchsize(n_indices)

        return list(range(self._global_index - n_indices,
                          self._global_index))
    def __len__(self):
        return self._num_samples
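
# Illustrative usage sketch, not part of the original module: it assumes that
# ``AbstractSampler._check_batchsize`` advances an internal ``_global_index``,
# truncates the last batch and raises ``StopIteration`` once the dataset is
# exhausted (that base class is not shown in this file).
def _demo_sequential_sampler():
    # five samples; the class labels are irrelevant for purely
    # sequential sampling
    sampler = SequentialSampler([0, 1, 0, 1, 0])

    # draw batches of size 2 until the sampler signals exhaustion
    batches = []
    while True:
        try:
            batches.append(sampler._get_indices(2))
        except StopIteration:
            break

    # under the assumption above this yields [[0, 1], [2, 3], [4]]
    return batches
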

class PrevalenceSequentialSampler(AbstractSampler):
    """
    Implements Per-Class Sequential sampling and ensures the same number of
    samples per batch for each class; If out of samples for one class:
    restart at first sample
    """

    def __init__(self, indices, shuffle_batch=True):
        """
        Parameters
        ----------
        indices : list
            list of classes each sample belongs to. List index corresponds
            to data index and the value at a certain index indicates the
            corresponding class
        shuffle_batch : bool
            if False: indices per class will be returned in a sequential
            way (first: indices belonging to class 1, second: indices
            belonging to class 2 etc.)
            if True: indices will be sampled in a sequential way per class
            and sampled indices will be shuffled

        """
        super().__init__(indices)

        _indices = {}
        _global_idxs = {}
        for idx, class_idx in enumerate(indices):
            class_idx = int(class_idx)

            if class_idx in _indices.keys():
                _indices[class_idx].append(idx)
            else:
                _indices[class_idx] = [idx]
                _global_idxs[class_idx] = 0

        # sort classes by descending number of elements
        ordered_dict = OrderedDict()

        length = 0
        for k in sorted(_indices, key=lambda k: len(_indices[k]),
                        reverse=True):
            ordered_dict[k] = _indices[k]
            length += len(_indices[k])

        self._num_samples = length
        self._indices = ordered_dict
        self._n_classes = len(_indices.keys())
        self._global_idxs = _global_idxs
        self._shuffle = shuffle_batch
    @classmethod
    def from_dataset(cls, dataset: AbstractDataset, **kwargs):
        indices = range(len(dataset))
        labels = [dataset[idx]['label'] for idx in indices]
        return cls(labels, **kwargs)
    def _get_indices(self, n_indices):
        """
        Actual Sampling

        Parameters
        ----------
        n_indices : int
            number of indices to return

        Raises
        ------
        StopIteration : If end of class indices is reached

        Returns
        -------
        list
            list of sampled indices
        """
        n_indices = self._check_batchsize(n_indices)

        samples_per_class = n_indices // self._n_classes

        _samples = []
        for key, idx_list in self._indices.items():
            # restart at the first sample if this class is exhausted
            if self._global_idxs[key] >= len(idx_list):
                self._global_idxs[key] = 0

            new_global_idx = self._global_idxs[key] + samples_per_class

            if new_global_idx >= len(idx_list):
                new_global_idx = len(idx_list)

            # map the per-class positions back to dataset indices
            _samples += idx_list[self._global_idxs[key]:new_global_idx]
            self._global_idxs[key] = new_global_idx

        # fill up the batch with one additional sample per class
        # until n_indices is reached
        for key, idx_list in self._indices.items():
            if len(_samples) >= n_indices:
                break

            if self._global_idxs[key] >= len(idx_list):
                self._global_idxs[key] = 0

            new_global_idx = self._global_idxs[key] + 1

            _samples += idx_list[self._global_idxs[key]:new_global_idx]
            self._global_idxs[key] = new_global_idx

        if self._shuffle:
            shuffle(_samples)

        return _samples
    def __len__(self):
        return self._num_samples
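
# Illustrative sketch, not part of the original module: with an imbalanced
# label list, ``PrevalenceSequentialSampler`` draws roughly the same number
# of indices per class in every batch and restarts a class at its first
# sample once it is exhausted. Exact batch contents depend on
# ``AbstractSampler._check_batchsize``, which is not shown in this file.
def _demo_prevalence_sequential_sampler():
    # class 0 has four samples (indices 0-3), class 1 has two (indices 4-5)
    labels = [0, 0, 0, 0, 1, 1]
    sampler = PrevalenceSequentialSampler(labels, shuffle_batch=False)

    # a batch of 4 is expected to contain two indices per class,
    # e.g. [0, 1, 4, 5]; a following batch would restart class 1 at index 4
    return sampler._get_indices(4)
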

class StoppingPrevalenceSequentialSampler(AbstractSampler):
    """
    Implements Per-Class Sequential sampling and ensures the same number of
    samples per batch for each class; Stops as soon as the smallest class
    runs out of samples
    """

    def __init__(self, indices, shuffle_batch=True):
        """
        Parameters
        ----------
        indices : list
            list of classes each sample belongs to. List index corresponds
            to data index and the value at a certain index indicates the
            corresponding class
        shuffle_batch : bool
            if False: indices per class will be returned in a sequential
            way (first: indices belonging to class 1, second: indices
            belonging to class 2 etc.)
            if True: indices will be sampled in a sequential way per class
            and sampled indices will be shuffled

        """
        super().__init__(indices)

        _indices = {}
        _global_idxs = {}
        for idx, class_idx in enumerate(indices):
            class_idx = int(class_idx)

            if class_idx in _indices.keys():
                _indices[class_idx].append(idx)
            else:
                _indices[class_idx] = [idx]
                _global_idxs[class_idx] = 0

        # sort classes by descending number of elements
        ordered_dict = OrderedDict()

        length = float('inf')
        for k in sorted(_indices, key=lambda k: len(_indices[k]),
                        reverse=True):
            ordered_dict[k] = _indices[k]
            length = min(length, len(_indices[k]))

        self._length = length
        self._indices = ordered_dict
        self._n_classes = len(_indices.keys())
        self._global_idxs = _global_idxs
        self._shuffle = shuffle_batch
    @classmethod
    def from_dataset(cls, dataset: AbstractDataset):
        indices = range(len(dataset))
        labels = [dataset[idx]['label'] for idx in indices]
        return cls(labels)
    def _check_batchsize(self, n_indices):
        """
        Checks if the batchsize is valid for all classes

        Parameters
        ----------
        n_indices : int
            the number of samples to return

        Returns
        -------
        dict
            number of samples per class to return
        """
        n_indices = super()._check_batchsize(n_indices)

        samples_per_class = n_indices // self._n_classes
        remaining = n_indices % self._n_classes
        samples = {}

        try:
            # sample the same number of samples for each class
            for key, idx_list in self._indices.items():
                if self._global_idxs[key] >= len(idx_list):
                    raise StopIteration

                # truncate if necessary
                samples[key] = min(
                    samples_per_class,
                    len(self._indices[key]) - self._global_idxs[key])

                self._global_idxs[key] += samples[key]

            # distribute the remainder starting with the largest class
            while remaining:
                for key, idx_list in self._indices.items():
                    if not remaining:
                        break
                    samples[key] += 1
                    remaining -= 1

        except StopIteration as e:
            # set all global indices to 0 and re-raise to signal
            # the end of sampling
            for key in self._global_idxs.keys():
                self._global_idxs[key] = 0

            raise e

        return samples
    def _get_indices(self, n_indices):
        """
        Actual Sampling

        Parameters
        ----------
        n_indices : int
            number of indices to return

        Raises
        ------
        StopIteration : If end of class indices is reached for one class

        Returns
        -------
        list
            list of sampled indices
        """
        n_indices = self._check_batchsize(n_indices)

        samples = []

        for key, val in n_indices.items():
            start = self._global_idxs[key] - val
            end = self._global_idxs[key]
            # slice the per-class index list to obtain dataset indices
            samples += self._indices[key][start:end]

        if self._shuffle:
            shuffle(samples)

        return samples
    def __len__(self):
        return self._length
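
# Illustrative sketch, not part of the original module: in contrast to
# ``PrevalenceSequentialSampler``, the stopping variant raises
# ``StopIteration`` as soon as the smallest class runs out of samples, so
# the number of batches is bounded by the size of that class.
def _demo_stopping_prevalence_sequential_sampler():
    labels = [0, 0, 0, 0, 1, 1]  # class 1 is the smallest class
    sampler = StoppingPrevalenceSequentialSampler(labels, shuffle_batch=False)

    batches = []
    while True:
        try:
            batches.append(sampler._get_indices(2))
        except StopIteration:
            break

    # expected: two batches with one index per class each,
    # e.g. [[0, 4], [1, 5]], then StopIteration since class 1 is exhausted
    return batches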