diff --git a/avalanche/training/plugins/early_stopping.py b/avalanche/training/plugins/early_stopping.py index 2514fa25e..5ede6d788 100644 --- a/avalanche/training/plugins/early_stopping.py +++ b/avalanche/training/plugins/early_stopping.py @@ -115,18 +115,20 @@ def _update_best(self, strategy): f"Metric {self.metric_name} used by the EarlyStopping plugin " f"is not computed yet. EarlyStopping will not be triggered." ) - if self.best_val is None or self.operator(val_acc, self.best_val): + + if self.best_val is None: + self.best_state = deepcopy(strategy.model.state_dict()) + self.best_val = val_acc + self.best_step = self._get_strategy_counter(strategy) + return None + + delta_val = float(val_acc - self.best_val) + if self.operator(delta_val, 0) and abs(delta_val) >= self.margin: self.best_state = deepcopy(strategy.model.state_dict()) - if self.best_val is None: - self.best_val = val_acc - self.best_step = 0 - return None - - if self.operator(float(val_acc - self.best_val), self.margin): - self.best_step = self._get_strategy_counter(strategy) - self.best_val = val_acc - if self.verbose: - print("EarlyStopping: new best value:", val_acc) + self.best_val = val_acc + self.best_step = self._get_strategy_counter(strategy) + if self.verbose: + print("EarlyStopping: new best value:", val_acc) return self.best_val