Source code for federatedscope.core.monitors.early_stopper

import operator
import numpy as np


# TODO: make this as a sub-module of monitor class
[docs]class EarlyStopper(object): """ Track the history of metric (e.g., validation loss), \ check whether should stop (training) process if the metric doesn't \ improve after a given patience. Args: patience (int): (Default: 5) How long to wait after last time the \ monitored metric improved. Note that the \ ``actual_checking_round = patience * cfg.eval.freq`` delta (float): (Default: 0) Minimum change in the monitored metric to \ indicate an improvement. improve_indicator_mode (str): Early stop when no improve to \ last ``patience`` round, in ``['mean', 'best']`` """ def __init__(self, patience=5, delta=0, improve_indicator_mode='best', the_larger_the_better=True): assert 0 <= patience == int( patience ), "Please use a non-negtive integer to indicate the patience" assert delta >= 0, "Please use a positive value to indicate the change" assert improve_indicator_mode in [ 'mean', 'best' ], "Please make sure `improve_indicator_mode` is 'mean' or 'best']" self.patience = patience self.counter_no_improve = 0 self.best_metric = None self.early_stopped = False self.the_larger_the_better = the_larger_the_better self.delta = delta self.improve_indicator_mode = improve_indicator_mode # For expansion usages of comparisons self.comparator = operator.lt self.improvement_operator = operator.add def __track_and_check_dummy(self, new_result): """ Dummy stopper, always return false Args: new_result: Returns: False """ self.early_stopped = False return self.early_stopped def __track_and_check_best(self, history_result): """ Tracks the best result and checks whether the patience is exceeded. Args: history_result: results of all evaluation round Returns: Bool: whether stop """ new_result = history_result[-1] if self.best_metric is None: self.best_metric = new_result elif not self.the_larger_the_better and self.comparator( self.improvement_operator(self.best_metric, -self.delta), new_result): # add(best_metric, -delta) < new_result self.counter_no_improve += 1 elif self.the_larger_the_better and self.comparator( new_result, self.improvement_operator(self.best_metric, self.delta)): # new_result < add(best_metric, delta) self.counter_no_improve += 1 else: self.best_metric = new_result self.counter_no_improve = 0 self.early_stopped = self.counter_no_improve >= self.patience return self.early_stopped def __track_and_check_mean(self, history_result): new_result = history_result[-1] if len(history_result) > self.patience: if not self.the_larger_the_better and self.comparator( self.improvement_operator( np.mean(history_result[-self.patience - 1:-1]), -self.delta), new_result): self.early_stopped = True elif self.the_larger_the_better and self.comparator( new_result, self.improvement_operator( np.mean(history_result[-self.patience - 1:-1]), self.delta)): self.early_stopped = True else: self.early_stopped = False return self.early_stopped
[docs] def track_and_check(self, new_result): """ Checks the new result and if it improves it returns True. Args: new_result: new evaluation result Returns: Bool: whether stop """ track_method = self.__track_and_check_dummy # do nothing if self.patience == 0: track_method = self.__track_and_check_dummy elif self.improve_indicator_mode == 'best': track_method = self.__track_and_check_best elif self.improve_indicator_mode == 'mean': track_method = self.__track_and_check_mean return track_method(new_result)