Source code for delira.data_loading.sampler.weighted_sampler

from random import choices
from ..dataset import AbstractDataset
from .abstract_sampler import AbstractSampler


[docs]class WeightedRandomSampler(AbstractSampler): """ Implements Weighted Random Sampling """ def __init__(self, indices, weights=None, cum_weights=None): """ Parameters ---------- indices : list list of classes each sample belongs to. List index corresponds to data index and the value at a certain index indicates the corresponding class weights : Any or None sampling weights; for more details see random.choices cum_weights : Any or None cumulative sampling weights; for more details see random.choices """ super().__init__() self._indices = list(range(len(indices))) self._weights = weights self._cum_weights = cum_weights self._global_index = 0
[docs] @classmethod def from_dataset(cls, dataset: AbstractDataset, **kwargs): """ Classmethod to initialize the sampler from a given dataset Parameters ---------- dataset : AbstractDataset the given dataset Returns ------- AbstractSampler The initialzed sampler """ indices = list(range(len(dataset))) labels = [d['label'] for d in dataset.data] return cls(labels, **kwargs)
[docs] def _get_indices(self, n_indices): """ Actual Sampling Parameters ---------- n_indices : int number of indices to return Returns ------- list list of sampled indices Raises ------ StopIteration If maximal number of samples is reached TypeError if weights and cum_weights are specified at the same time ValueError if weights or cum_weights don't match the population """ if self._global_index >= len(self._indices): self._global_index = 0 raise StopIteration new_global_idx = self._global_index + n_indices # If we reach end, make batch smaller if new_global_idx >= len(self._indices): new_global_idx = len(self._indices) samples = choices(self._indices, weights=self._weights, cum_weights=self._cum_weights, k=new_global_idx-self._global_index) self._global_index = new_global_idx return samples
def __len__(self): return len(self._indices)