From cddad01fbdff9619b30e8739c5c42ebfacccefb9 Mon Sep 17 00:00:00 2001 From: Yi-Xuan Xu Date: Mon, 25 Jul 2022 23:31:29 +0800 Subject: [PATCH] fix: add model saving in unit tests (#122) * Update build-and-test.yml * Update io.py * Update io.py * test * test * Update test_all_models.py * update * Update test_all_models.py * Update test_all_models.py --- .github/workflows/build-and-test.yml | 2 +- torchensemble/tests/test_all_models.py | 16 ++++++++++------ 2 files changed, 11 insertions(+), 7 deletions(-) diff --git a/.github/workflows/build-and-test.yml b/.github/workflows/build-and-test.yml index 30ceedb..6709681 100644 --- a/.github/workflows/build-and-test.yml +++ b/.github/workflows/build-and-test.yml @@ -12,7 +12,7 @@ jobs: strategy: matrix: os: [ubuntu-latest, windows-latest] - python-version: [3.8, 3.9] + python-version: [3.7, 3.8, 3.9] steps: - uses: actions/checkout@v2 - name: Set up Python diff --git a/torchensemble/tests/test_all_models.py b/torchensemble/tests/test_all_models.py index 9a5c9bd..00d1b78 100644 --- a/torchensemble/tests/test_all_models.py +++ b/torchensemble/tests/test_all_models.py @@ -16,7 +16,7 @@ torchensemble.VotingClassifier, torchensemble.BaggingClassifier, torchensemble.GradientBoostingClassifier, - # torchensemble.SnapshotEnsembleClassifier, + torchensemble.SnapshotEnsembleClassifier, torchensemble.AdversarialTrainingClassifier, torchensemble.FastGeometricClassifier, torchensemble.SoftGradientBoostingClassifier, @@ -38,7 +38,7 @@ np.random.seed(0) torch.manual_seed(0) -set_logger("pytest_all_models") +logger = set_logger("pytest_all_models") # Base estimator @@ -117,8 +117,9 @@ def test_clf_class(clf): # Train model.fit(train_loader, epochs=epochs, test_loader=test_loader) - # Evaluate + # Evaluate & Save model.evaluate(test_loader) + io.save(model, "./", logger) # Predict for _, (data, target) in enumerate(test_loader): @@ -166,8 +167,9 @@ def test_clf_object(clf): # Train model.fit(train_loader, epochs=epochs, test_loader=test_loader) - # Evaluate + # Evaluate & Save model.evaluate(test_loader) + io.save(model, "./", logger) # Predict for _, (data, target) in enumerate(test_loader): @@ -215,8 +217,9 @@ def test_reg_class(reg): # Train model.fit(train_loader, epochs=epochs, test_loader=test_loader) - # Evaluate + # Evaluate & Save model.evaluate(test_loader) + io.save(model, "./", logger) # Predict for _, (data, target) in enumerate(test_loader): @@ -264,8 +267,9 @@ def test_reg_object(reg): # Train model.fit(train_loader, epochs=epochs, test_loader=test_loader) - # Evaluate + # Evaluate & Save model.evaluate(test_loader) + io.save(model, "./", logger) # Predict for _, (data, target) in enumerate(test_loader):