Skip to content

Commit

Permalink
fix: add model saving in unit tests (#122)
Browse files Browse the repository at this point in the history
* Update build-and-test.yml

* Update io.py

* Update io.py

* test

* test

* Update test_all_models.py

* update

* Update test_all_models.py

* Update test_all_models.py
  • Loading branch information
xuyxu authored Jul 25, 2022
1 parent 10949e9 commit cddad01
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 7 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/build-and-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ jobs:
strategy:
matrix:
os: [ubuntu-latest, windows-latest]
python-version: [3.8, 3.9]
python-version: [3.7, 3.8, 3.9]
steps:
- uses: actions/checkout@v2
- name: Set up Python
Expand Down
16 changes: 10 additions & 6 deletions torchensemble/tests/test_all_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
torchensemble.VotingClassifier,
torchensemble.BaggingClassifier,
torchensemble.GradientBoostingClassifier,
# torchensemble.SnapshotEnsembleClassifier,
torchensemble.SnapshotEnsembleClassifier,
torchensemble.AdversarialTrainingClassifier,
torchensemble.FastGeometricClassifier,
torchensemble.SoftGradientBoostingClassifier,
Expand All @@ -38,7 +38,7 @@

np.random.seed(0)
torch.manual_seed(0)
set_logger("pytest_all_models")
logger = set_logger("pytest_all_models")


# Base estimator
Expand Down Expand Up @@ -117,8 +117,9 @@ def test_clf_class(clf):
# Train
model.fit(train_loader, epochs=epochs, test_loader=test_loader)

# Evaluate
# Evaluate & Save
model.evaluate(test_loader)
io.save(model, "./", logger)

# Predict
for _, (data, target) in enumerate(test_loader):
Expand Down Expand Up @@ -166,8 +167,9 @@ def test_clf_object(clf):
# Train
model.fit(train_loader, epochs=epochs, test_loader=test_loader)

# Evaluate
# Evaluate & Save
model.evaluate(test_loader)
io.save(model, "./", logger)

# Predict
for _, (data, target) in enumerate(test_loader):
Expand Down Expand Up @@ -215,8 +217,9 @@ def test_reg_class(reg):
# Train
model.fit(train_loader, epochs=epochs, test_loader=test_loader)

# Evaluate
# Evaluate & Save
model.evaluate(test_loader)
io.save(model, "./", logger)

# Predict
for _, (data, target) in enumerate(test_loader):
Expand Down Expand Up @@ -264,8 +267,9 @@ def test_reg_object(reg):
# Train
model.fit(train_loader, epochs=epochs, test_loader=test_loader)

# Evaluate
# Evaluate & Save
model.evaluate(test_loader)
io.save(model, "./", logger)

# Predict
for _, (data, target) in enumerate(test_loader):
Expand Down

0 comments on commit cddad01

Please sign in to comment.