diff --git a/torchensemble/soft_gradient_boosting.py b/torchensemble/soft_gradient_boosting.py index 2da5a73..ab688be 100644 --- a/torchensemble/soft_gradient_boosting.py +++ b/torchensemble/soft_gradient_boosting.py @@ -278,7 +278,7 @@ def fit( # Validation if test_loader: flag = self._evaluate_during_fit(test_loader, epoch) - if flag: + if save_model and flag: io.save(self, save_dir, self.logger) # Update the scheduler