From e3fb9fe7a3767e28805a381c115301618de968e0 Mon Sep 17 00:00:00 2001 From: Yi-Xuan Xu Date: Fri, 27 Jan 2023 18:53:23 +0800 Subject: [PATCH] fix: estimator overwrite problem with validation data loader (#143) * update code * Update index.rst --- CHANGELOG.rst | 1 + docs/index.rst | 6 ++++-- torchensemble/adversarial_training.py | 22 ++++++++++++---------- torchensemble/bagging.py | 22 ++++++++++++---------- torchensemble/voting.py | 22 ++++++++++++---------- 5 files changed, 41 insertions(+), 32 deletions(-) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 51f055b..e9c4d99 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -18,6 +18,7 @@ Changelog Ver 0.1.* --------- +* |Fix| Fix the issue that model will be overwritten when using the validation data loder | `@xuyxu `__ * |Feature| Add internal :meth:`unsqueeze` operation in :meth:`forward` of all classifiers | `@xuyxu `__ * |Feature| |API| Add ``voting_strategy`` parameter for :class:`VotingClassifer`, :class:`NeuralForestClassifier`, and :class:`SnapshotEnsembleClassifier` | `@LukasGardberg `__ * |Fix| Fix the sampling issue in :class:`BaggingClassifier` and :class:`BaggingRegressor` | `@SunHaozhe `__ diff --git a/docs/index.rst b/docs/index.rst index 54c9a69..3db9db6 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -64,8 +64,9 @@ Content ------- .. toctree:: - :maxdepth: 1 :caption: For Users + :maxdepth: 2 + :hidden: Quick Start Introduction @@ -74,8 +75,9 @@ Content API Reference .. toctree:: - :maxdepth: 1 :caption: For Developers + :maxdepth: 2 + :hidden: Changelog Roadmap diff --git a/torchensemble/adversarial_training.py b/torchensemble/adversarial_training.py index 9e0d318..b42e3bb 100644 --- a/torchensemble/adversarial_training.py +++ b/torchensemble/adversarial_training.py @@ -384,6 +384,12 @@ def _forward(estimators, *x): acc, epoch, ) + # No validation + else: + self.estimators_ = nn.ModuleList() + self.estimators_.extend(estimators) + if save_model: + io.save(self, save_dir, self.logger) # Update the scheduler with warnings.catch_warnings(): @@ -402,11 +408,6 @@ def _forward(estimators, *x): else: scheduler_.step() - self.estimators_ = nn.ModuleList() - self.estimators_.extend(estimators) - if save_model and not test_loader: - io.save(self, save_dir, self.logger) - @torchensemble_model_doc(item="classifier_evaluate") def evaluate(self, test_loader, return_loss=False): return super().evaluate(test_loader, return_loss) @@ -580,6 +581,12 @@ def _forward(estimators, *x): val_loss, epoch, ) + # No validation + else: + self.estimators_ = nn.ModuleList() + self.estimators_.extend(estimators) + if save_model: + io.save(self, save_dir, self.logger) # Update the scheduler with warnings.catch_warnings(): @@ -595,11 +602,6 @@ def _forward(estimators, *x): else: scheduler_.step() - self.estimators_ = nn.ModuleList() - self.estimators_.extend(estimators) - if save_model and not test_loader: - io.save(self, save_dir, self.logger) - @torchensemble_model_doc(item="regressor_evaluate") def evaluate(self, test_loader): return super().evaluate(test_loader) diff --git a/torchensemble/bagging.py b/torchensemble/bagging.py index 623b7c3..b62fce3 100644 --- a/torchensemble/bagging.py +++ b/torchensemble/bagging.py @@ -253,6 +253,12 @@ def _forward(estimators, *x): self.tb_logger.add_scalar( "bagging/Validation_Acc", acc, epoch ) + # No validation + else: + self.estimators_ = nn.ModuleList() + self.estimators_.extend(estimators) + if save_model: + io.save(self, save_dir, self.logger) # Update the scheduler with warnings.catch_warnings(): @@ -271,11 +277,6 @@ def _forward(estimators, *x): else: scheduler_.step() - self.estimators_ = nn.ModuleList() - self.estimators_.extend(estimators) - if save_model and not test_loader: - io.save(self, save_dir, self.logger) - @torchensemble_model_doc(item="classifier_evaluate") def evaluate(self, test_loader, return_loss=False): return super().evaluate(test_loader, return_loss) @@ -449,6 +450,12 @@ def _forward(estimators, *x): self.tb_logger.add_scalar( "bagging/Validation_Loss", val_loss, epoch ) + # No validation + else: + self.estimators_ = nn.ModuleList() + self.estimators_.extend(estimators) + if save_model: + io.save(self, save_dir, self.logger) # Update the scheduler with warnings.catch_warnings(): @@ -464,11 +471,6 @@ def _forward(estimators, *x): else: scheduler_.step() - self.estimators_ = nn.ModuleList() - self.estimators_.extend(estimators) - if save_model and not test_loader: - io.save(self, save_dir, self.logger) - @torchensemble_model_doc(item="regressor_evaluate") def evaluate(self, test_loader): return super().evaluate(test_loader) diff --git a/torchensemble/voting.py b/torchensemble/voting.py index 33da787..241f2d9 100644 --- a/torchensemble/voting.py +++ b/torchensemble/voting.py @@ -276,6 +276,12 @@ def _forward(estimators, *x): self.tb_logger.add_scalar( "voting/Validation_Acc", acc, epoch ) + # No validation + else: + self.estimators_ = nn.ModuleList() + self.estimators_.extend(estimators) + if save_model: + io.save(self, save_dir, self.logger) # Update the scheduler with warnings.catch_warnings(): @@ -294,11 +300,6 @@ def _forward(estimators, *x): else: scheduler_.step() - self.estimators_ = nn.ModuleList() - self.estimators_.extend(estimators) - if save_model and not test_loader: - io.save(self, save_dir, self.logger) - @torchensemble_model_doc(item="classifier_evaluate") def evaluate(self, test_loader, return_loss=False): return super().evaluate(test_loader, return_loss) @@ -532,6 +533,12 @@ def _forward(estimators, *x): self.tb_logger.add_scalar( "voting/Validation_Loss", val_loss, epoch ) + # No validation + else: + self.estimators_ = nn.ModuleList() + self.estimators_.extend(estimators) + if save_model: + io.save(self, save_dir, self.logger) # Update the scheduler with warnings.catch_warnings(): @@ -547,11 +554,6 @@ def _forward(estimators, *x): else: scheduler_.step() - self.estimators_ = nn.ModuleList() - self.estimators_.extend(estimators) - if save_model and not test_loader: - io.save(self, save_dir, self.logger) - @torchensemble_model_doc(item="regressor_evaluate") def evaluate(self, test_loader): return super().evaluate(test_loader)