from .abstract_callback import AbstractCallback

class EarlyStopping(AbstractCallback):
    """
    Implements Early Stopping as callback

    Tracks a single validation metric and requests that training be stopped
    once it has failed to improve for ``patience`` consecutive epochs.

    See Also
    --------
    :class:`AbstractCallback`

    """

    def __init__(self, monitor_key, min_delta=0, patience=0, mode='min'):
        """

        Parameters
        ----------
        monitor_key : str
            the validation key to monitor (looked up in the ``val_metrics``
            dict passed to :meth:`at_epoch_end`)
        min_delta : float or int
            the minimum difference between the best metric value so far and
            the current one for the current one to count as an improvement
        patience : int
            number of consecutive epochs without improvement after which
            training is stopped
        mode : str (default: 'min')
            defines the optimum for the monitored value; one of 'min' or
            'max'

        Raises
        ------
        ValueError
            if ``mode`` is neither 'min' nor 'max'

        """
        super().__init__()

        # stored as a plain key (not a tuple) so it can index ``val_metrics``
        self.monitor_key = monitor_key
        self.min_delta = min_delta
        self.patience = patience
        self.mode = mode

        # start from the worst possible value so the first epoch always
        # counts as an improvement
        if 'min' == mode:
            self.best_metric = float('inf')
        elif 'max' == mode:
            self.best_metric = - float('inf')
        else:
            raise ValueError("Unknown compare mode: Got %s, but expected one "
                             "of ['min', 'max']" % mode)

        # number of consecutive epochs without improvement
        self.epochs_waited = 0

    def _is_better(self, metric):
        """
        Helper function to decide whether the current metric is better than
        the best metric so far (by more than ``self.min_delta``)

        Parameters
        ----------
        metric :
            current metric value

        Returns
        -------
        bool
            whether this metric is the new best metric or not

        """
        if 'min' == self.mode:
            return metric < (self.best_metric - self.min_delta)
        else:
            return metric > (self.best_metric + self.min_delta)

    def at_epoch_end(self, trainer, **kwargs):
        """
        Actual early stopping: Checks at end of each epoch if the monitored
        metric is a new best. If it has not improved for ``self.patience``
        consecutive epochs, training should be stopped.

        Parameters
        ----------
        trainer : :class:`AbstractNetworkTrainer`
            the trainer (not used by this callback)
        **kwargs :
            additional keyword arguments; ``val_metrics`` (a dict) must
            contain ``self.monitor_key``

        Returns
        -------
        dict
            ``{"stop_training": bool}`` indicating whether training should
            be stopped

        Raises
        ------
        KeyError
            if ``self.monitor_key`` is missing from ``val_metrics`` (or
            ``val_metrics`` was not passed at all)

        """
        metric = kwargs.get("val_metrics", {})[self.monitor_key]

        if self._is_better(metric):
            # new best value: remember it and restart the waiting period
            self.best_metric = metric
            self.epochs_waited = 0
            return {"stop_training": False}

        # no improvement this epoch: stop once patience is exhausted
        self.epochs_waited += 1
        return {"stop_training": self.epochs_waited >= self.patience}