diff --git a/.github/workflows/build-and-test.yml b/.github/workflows/build-and-test.yml index 30ceedb..6709681 100644 --- a/.github/workflows/build-and-test.yml +++ b/.github/workflows/build-and-test.yml @@ -12,7 +12,7 @@ jobs: strategy: matrix: os: [ubuntu-latest, windows-latest] - python-version: [3.8, 3.9] + python-version: [3.7, 3.8, 3.9] steps: - uses: actions/checkout@v2 - name: Set up Python diff --git a/torchensemble/tests/test_all_models.py b/torchensemble/tests/test_all_models.py index 9a5c9bd..00d1b78 100644 --- a/torchensemble/tests/test_all_models.py +++ b/torchensemble/tests/test_all_models.py @@ -16,7 +16,7 @@ torchensemble.VotingClassifier, torchensemble.BaggingClassifier, torchensemble.GradientBoostingClassifier, - # torchensemble.SnapshotEnsembleClassifier, + torchensemble.SnapshotEnsembleClassifier, torchensemble.AdversarialTrainingClassifier, torchensemble.FastGeometricClassifier, torchensemble.SoftGradientBoostingClassifier, @@ -38,7 +38,7 @@ np.random.seed(0) torch.manual_seed(0) -set_logger("pytest_all_models") +logger = set_logger("pytest_all_models") # Base estimator @@ -117,8 +117,9 @@ def test_clf_class(clf): # Train model.fit(train_loader, epochs=epochs, test_loader=test_loader) - # Evaluate + # Evaluate & Save model.evaluate(test_loader) + io.save(model, "./", logger) # Predict for _, (data, target) in enumerate(test_loader): @@ -166,8 +167,9 @@ def test_clf_object(clf): # Train model.fit(train_loader, epochs=epochs, test_loader=test_loader) - # Evaluate + # Evaluate & Save model.evaluate(test_loader) + io.save(model, "./", logger) # Predict for _, (data, target) in enumerate(test_loader): @@ -215,8 +217,9 @@ def test_reg_class(reg): # Train model.fit(train_loader, epochs=epochs, test_loader=test_loader) - # Evaluate + # Evaluate & Save model.evaluate(test_loader) + io.save(model, "./", logger) # Predict for _, (data, target) in enumerate(test_loader): @@ -264,8 +267,9 @@ def test_reg_object(reg): # Train model.fit(train_loader, epochs=epochs, test_loader=test_loader) - # Evaluate + # Evaluate & Save model.evaluate(test_loader) + io.save(model, "./", logger) # Predict for _, (data, target) in enumerate(test_loader):