Source code for delira.data_loading.sampler.random_sampler

from collections import OrderedDict
from ..dataset import AbstractDataset
from .abstract_sampler import AbstractSampler

from numpy.random import choice, shuffle
from numpy import concatenate


class RandomSampler(AbstractSampler):
    """
    Implements Random Sampling from whole Dataset
    """

    def __init__(self, indices):
        """
        Parameters
        ----------
        indices : list
            list of classes each sample belongs to. List index corresponds
            to data index and the value at a certain index indicates the
            corresponding class

        """
        super().__init__()

        self._indices = list(range(len(indices)))
        self._global_index = 0
    def _get_indices(self, n_indices):
        """
        Actual Sampling

        Parameters
        ----------
        n_indices : int
            number of indices to return

        Returns
        -------
        list
            list of sampled indices

        Raises
        ------
        StopIteration
            If maximal number of samples is reached

        """
        if self._global_index >= len(self._indices):
            self._global_index = 0
            raise StopIteration

        new_global_idx = self._global_index + n_indices

        # if we reach the end, make the batch smaller
        if new_global_idx >= len(self._indices):
            new_global_idx = len(self._indices)

        indices = choice(self._indices,
                         size=new_global_idx - self._global_index)
        # indices = choices(self._indices, k=new_global_idx - self._global_index)
        self._global_index = new_global_idx
        return indices
    def __len__(self):
        return len(self._indices)
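
# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original module): instantiate
# ``RandomSampler`` with a hypothetical list of per-sample class labels and
# draw one batch of indices. ``_get_indices`` is called directly here purely
# for demonstration; in practice the sampler is driven by the data-loading
# pipeline. Note that ``numpy.random.choice`` is used with its default
# ``replace=True``, so a batch may contain duplicate indices.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    _example_labels = [0, 1, 0, 1, 1, 0, 0, 1]  # hypothetical class labels

    _random_sampler = RandomSampler(_example_labels)
    print(len(_random_sampler))              # 8 -> total number of samples
    print(_random_sampler._get_indices(4))   # e.g. array([2, 7, 7, 0])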
class PrevalenceRandomSampler(AbstractSampler):
    """
    Implements random Per-Class Sampling and ensures the same number of
    samples per batch for each class
    """

    def __init__(self, indices, shuffle_batch=True):
        """
        Parameters
        ----------
        indices : list
            list of classes each sample belongs to. List index corresponds
            to data index and the value at a certain index indicates the
            corresponding class
        shuffle_batch : bool
            if False: indices per class will be returned in a sequential
            way (first: indices belonging to class 1, second: indices
            belonging to class 2 etc.) while indices for each class are
            sampled in a random way
            if True: indices will be sampled in a random way per class and
            sampled indices will be shuffled

        """
        super().__init__()

        self._num_indices = 0
        _indices = {}
        for idx, class_idx in enumerate(indices):
            self._num_indices += 1
            class_idx = int(class_idx)

            if class_idx in _indices.keys():
                _indices[class_idx].append(idx)
            else:
                _indices[class_idx] = [idx]

        # sort classes by descending number of elements
        ordered_dict = OrderedDict()
        for k in sorted(_indices, key=lambda k: len(_indices[k]),
                        reverse=True):
            ordered_dict[k] = _indices[k]

        self._indices = ordered_dict
        self._n_classes = len(_indices.keys())
        self._shuffle = shuffle_batch

        self._global_index = 0
    @classmethod
    def from_dataset(cls, dataset: AbstractDataset, **kwargs):
        """
        Classmethod to initialize the sampler from a given dataset

        Parameters
        ----------
        dataset : AbstractDataset
            the given dataset

        Returns
        -------
        AbstractSampler
            The initialized sampler

        """
        indices = range(len(dataset))
        labels = [dataset[idx]['label'] for idx in indices]
        return cls(labels, **kwargs)
    def _get_indices(self, n_indices):
        """
        Actual Sampling

        Parameters
        ----------
        n_indices : int
            number of indices to return

        Returns
        -------
        list
            list of sampled indices

        Raises
        ------
        StopIteration
            If maximal number of samples is reached

        """
        if self._global_index >= self._num_indices:
            self._global_index = 0
            raise StopIteration

        samples_per_class = int(n_indices / self._n_classes)

        _samples = []
        n_sampled = 0
        for key, idx_list in self._indices.items():
            _samples.append(choice(idx_list, size=samples_per_class))
            n_sampled += samples_per_class

        # add single elements until n_sampled == n_indices;
        # works because fewer indices are missing than there are classes
        # and self._indices is sorted by decreasing number of elements
        for key, idx_list in self._indices.items():
            if n_sampled >= n_indices:
                break
            _samples.append(choice(idx_list, size=1))
            n_sampled += 1

        _samples = concatenate(_samples)

        self._global_index += n_indices

        if self._shuffle:
            shuffle(_samples)

        return _samples
    def __len__(self):
        return self._num_indices
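
# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original module): with an
# imbalanced, hypothetical label list, ``PrevalenceRandomSampler`` draws the
# same number of indices per class for each batch (here 3 indices per class
# for a batch size of 6). ``_get_indices`` is called directly purely for
# demonstration.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # hypothetical labels: 6 samples of class 0, 2 samples of class 1
    _imbalanced_labels = [0, 0, 0, 0, 0, 0, 1, 1]

    _prevalence_sampler = PrevalenceRandomSampler(_imbalanced_labels)
    _batch = _prevalence_sampler._get_indices(6)
    print(_batch)   # 6 indices, 3 drawn from each class (shuffled)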
class StoppingPrevalenceRandomSampler(AbstractSampler):
    """
    Implements random Per-Class Sampling and ensures the same number of
    samples per batch for each class; stops once the smallest class runs
    out of samples
    """

    def __init__(self, indices, shuffle_batch=True):
        """
        Parameters
        ----------
        indices : list
            list of classes each sample belongs to. List index corresponds
            to data index and the value at a certain index indicates the
            corresponding class
        shuffle_batch : bool
            if False: indices per class will be returned in a sequential
            way (first: indices belonging to class 1, second: indices
            belonging to class 2 etc.) while indices for each class are
            sampled in a random way
            if True: indices will be sampled in a random way per class and
            sampled indices will be shuffled

        """
        super().__init__()

        _indices = {}
        _global_idxs = {}
        for idx, class_idx in enumerate(indices):
            class_idx = int(class_idx)

            if class_idx in _indices.keys():
                _indices[class_idx].append(idx)
            else:
                _indices[class_idx] = [idx]
                _global_idxs[class_idx] = 0

        # sort classes by descending number of elements
        ordered_dict = OrderedDict()
        length = float('inf')
        for k in sorted(_indices, key=lambda k: len(_indices[k]),
                        reverse=True):
            ordered_dict[k] = _indices[k]
            length = min(length, len(_indices[k]))

        self._length = length
        self._indices = ordered_dict
        self._n_classes = len(_indices.keys())
        self._global_idxs = _global_idxs
        self._shuffle = shuffle_batch
    @classmethod
    def from_dataset(cls, dataset: AbstractDataset, **kwargs):
        """
        Classmethod to initialize the sampler from a given dataset

        Parameters
        ----------
        dataset : AbstractDataset
            the given dataset

        Returns
        -------
        AbstractSampler
            The initialized sampler

        """
        indices = range(len(dataset))
        labels = [dataset[idx]['label'] for idx in indices]
        return cls(labels, **kwargs)
    def _get_indices(self, n_indices):
        """
        Actual Sampling

        Parameters
        ----------
        n_indices : int
            number of indices to return

        Returns
        -------
        list
            list of sampled indices

        Raises
        ------
        StopIteration
            If end of class indices is reached for one class

        """
        samples_per_class = int(n_indices / self._n_classes)

        _samples = []
        n_sampled = 0

        for key, idx_list in self._indices.items():
            if self._global_idxs[key] >= len(idx_list):
                self._global_idxs[key] = 0
                raise StopIteration

            new_global_idx = self._global_idxs[key] + samples_per_class

            # if we reach the end of a class, only advance to its last index
            if new_global_idx >= len(idx_list):
                new_global_idx = len(idx_list)

            _samples.append(choice(idx_list, size=samples_per_class))
            self._global_idxs[key] = new_global_idx
            n_sampled += samples_per_class

        # add single elements until n_sampled == n_indices;
        # works because fewer indices are missing than there are classes
        # and self._indices is sorted by decreasing number of elements
        for key, idx_list in self._indices.items():
            if n_sampled >= n_indices:
                break

            if self._global_idxs[key] >= len(idx_list):
                self._global_idxs[key] = 0

            new_global_idx = self._global_idxs[key] + 1

            _samples.append(choice(idx_list, size=1))
            self._global_idxs[key] = new_global_idx
            n_sampled += 1

        _samples = concatenate(_samples)

        if self._shuffle:
            shuffle(_samples)

        return _samples
    def __len__(self):
        return self._length
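
# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original module): in contrast
# to ``PrevalenceRandomSampler``, the stopping variant raises
# ``StopIteration`` once the smallest class is exhausted, so its length is
# the size of the smallest class. The labels are hypothetical and
# ``_get_indices`` is called directly purely for demonstration.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # hypothetical labels: 5 samples of class 0, 3 samples of class 1
    _stopping_labels = [0, 0, 0, 0, 0, 1, 1, 1]

    _stopping_sampler = StoppingPrevalenceRandomSampler(_stopping_labels)
    print(len(_stopping_sampler))   # 3 -> size of the smallest class

    _n_batches = 0
    while True:
        try:
            _stopping_sampler._get_indices(2)   # 1 index per class
        except StopIteration:
            break
        _n_batches += 1
    print(_n_batches)   # 3 -> sampling stops once class 1 is used up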