Skip to content

Commit

Permalink
fix: estimator overwrite problem with validation data loader (#143)
Browse files Browse the repository at this point in the history
* update code

* Update index.rst
  • Loading branch information
xuyxu authored Jan 27, 2023
1 parent a89cfb3 commit e3fb9fe
Show file tree
Hide file tree
Showing 5 changed files with 41 additions and 32 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ Changelog
Ver 0.1.*
---------

* |Fix| Fix the issue that model will be overwritten when using the validation data loder | `@xuyxu <https://github.com/xuyxu>`__
* |Feature| Add internal :meth:`unsqueeze` operation in :meth:`forward` of all classifiers | `@xuyxu <https://github.com/xuyxu>`__
* |Feature| |API| Add ``voting_strategy`` parameter for :class:`VotingClassifer`, :class:`NeuralForestClassifier`, and :class:`SnapshotEnsembleClassifier` | `@LukasGardberg <https://github.com/LukasGardberg>`__
* |Fix| Fix the sampling issue in :class:`BaggingClassifier` and :class:`BaggingRegressor` | `@SunHaozhe <https://github.com/SunHaozhe>`__
Expand Down
6 changes: 4 additions & 2 deletions docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,9 @@ Content
-------

.. toctree::
:maxdepth: 1
:caption: For Users
:maxdepth: 2
:hidden:

Quick Start <quick_start>
Introduction <introduction>
Expand All @@ -74,8 +75,9 @@ Content
API Reference <parameters>

.. toctree::
:maxdepth: 1
:caption: For Developers
:maxdepth: 2
:hidden:

Changelog <changelog>
Roadmap <roadmap>
Expand Down
22 changes: 12 additions & 10 deletions torchensemble/adversarial_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,6 +384,12 @@ def _forward(estimators, *x):
acc,
epoch,
)
# No validation
else:
self.estimators_ = nn.ModuleList()
self.estimators_.extend(estimators)
if save_model:
io.save(self, save_dir, self.logger)

# Update the scheduler
with warnings.catch_warnings():
Expand All @@ -402,11 +408,6 @@ def _forward(estimators, *x):
else:
scheduler_.step()

self.estimators_ = nn.ModuleList()
self.estimators_.extend(estimators)
if save_model and not test_loader:
io.save(self, save_dir, self.logger)

@torchensemble_model_doc(item="classifier_evaluate")
def evaluate(self, test_loader, return_loss=False):
return super().evaluate(test_loader, return_loss)
Expand Down Expand Up @@ -580,6 +581,12 @@ def _forward(estimators, *x):
val_loss,
epoch,
)
# No validation
else:
self.estimators_ = nn.ModuleList()
self.estimators_.extend(estimators)
if save_model:
io.save(self, save_dir, self.logger)

# Update the scheduler
with warnings.catch_warnings():
Expand All @@ -595,11 +602,6 @@ def _forward(estimators, *x):
else:
scheduler_.step()

self.estimators_ = nn.ModuleList()
self.estimators_.extend(estimators)
if save_model and not test_loader:
io.save(self, save_dir, self.logger)

@torchensemble_model_doc(item="regressor_evaluate")
def evaluate(self, test_loader):
return super().evaluate(test_loader)
Expand Down
22 changes: 12 additions & 10 deletions torchensemble/bagging.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,12 @@ def _forward(estimators, *x):
self.tb_logger.add_scalar(
"bagging/Validation_Acc", acc, epoch
)
# No validation
else:
self.estimators_ = nn.ModuleList()
self.estimators_.extend(estimators)
if save_model:
io.save(self, save_dir, self.logger)

# Update the scheduler
with warnings.catch_warnings():
Expand All @@ -271,11 +277,6 @@ def _forward(estimators, *x):
else:
scheduler_.step()

self.estimators_ = nn.ModuleList()
self.estimators_.extend(estimators)
if save_model and not test_loader:
io.save(self, save_dir, self.logger)

@torchensemble_model_doc(item="classifier_evaluate")
def evaluate(self, test_loader, return_loss=False):
return super().evaluate(test_loader, return_loss)
Expand Down Expand Up @@ -449,6 +450,12 @@ def _forward(estimators, *x):
self.tb_logger.add_scalar(
"bagging/Validation_Loss", val_loss, epoch
)
# No validation
else:
self.estimators_ = nn.ModuleList()
self.estimators_.extend(estimators)
if save_model:
io.save(self, save_dir, self.logger)

# Update the scheduler
with warnings.catch_warnings():
Expand All @@ -464,11 +471,6 @@ def _forward(estimators, *x):
else:
scheduler_.step()

self.estimators_ = nn.ModuleList()
self.estimators_.extend(estimators)
if save_model and not test_loader:
io.save(self, save_dir, self.logger)

@torchensemble_model_doc(item="regressor_evaluate")
def evaluate(self, test_loader):
return super().evaluate(test_loader)
Expand Down
22 changes: 12 additions & 10 deletions torchensemble/voting.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,12 @@ def _forward(estimators, *x):
self.tb_logger.add_scalar(
"voting/Validation_Acc", acc, epoch
)
# No validation
else:
self.estimators_ = nn.ModuleList()
self.estimators_.extend(estimators)
if save_model:
io.save(self, save_dir, self.logger)

# Update the scheduler
with warnings.catch_warnings():
Expand All @@ -294,11 +300,6 @@ def _forward(estimators, *x):
else:
scheduler_.step()

self.estimators_ = nn.ModuleList()
self.estimators_.extend(estimators)
if save_model and not test_loader:
io.save(self, save_dir, self.logger)

@torchensemble_model_doc(item="classifier_evaluate")
def evaluate(self, test_loader, return_loss=False):
return super().evaluate(test_loader, return_loss)
Expand Down Expand Up @@ -532,6 +533,12 @@ def _forward(estimators, *x):
self.tb_logger.add_scalar(
"voting/Validation_Loss", val_loss, epoch
)
# No validation
else:
self.estimators_ = nn.ModuleList()
self.estimators_.extend(estimators)
if save_model:
io.save(self, save_dir, self.logger)

# Update the scheduler
with warnings.catch_warnings():
Expand All @@ -547,11 +554,6 @@ def _forward(estimators, *x):
else:
scheduler_.step()

self.estimators_ = nn.ModuleList()
self.estimators_.extend(estimators)
if save_model and not test_loader:
io.save(self, save_dir, self.logger)

@torchensemble_model_doc(item="regressor_evaluate")
def evaluate(self, test_loader):
return super().evaluate(test_loader)
Expand Down

0 comments on commit e3fb9fe

Please sign in to comment.