Skip to content

Commit

Permalink
Alternative tests
Browse files Browse the repository at this point in the history
Signed-off-by: João Lucas de Sousa Almeida <[email protected]>
  • Loading branch information
Joao-L-S-Almeida committed Jul 3, 2024
1 parent eb815d1 commit 0c1fad8
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 2 deletions.
1 change: 0 additions & 1 deletion .github/workflows/test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -32,5 +32,4 @@ jobs:
run: pip list
- name: Test with pytest
run: |
export PYTHONPATH=.
pytest -s tests
19 changes: 18 additions & 1 deletion tests/test_finetune.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import os

from terratorch.cli_tools import build_lightning_cli

"""
@pytest.mark.parametrize("model_name", ["prithvi_swin_B", "prithvi_swin_L", "prithvi_vit_100", "prithvi_vit_300"])
def test_finetune_multiple_backbones(model_name):
Expand All @@ -22,5 +22,22 @@ def test_finetune_multiple_backbones(model_name):
# Running the terratorch CLI
command_list = ["fit", "-c", f"tests/manufactured-finetune_{model_name}.yaml"]
_ = build_lightning_cli(command_list)
"""

@pytest.mark.parametrize("model_name", ["prithvi_swin_B", "prithvi_swin_L", "prithvi_vit_100", "prithvi_vit_300"])
def test_finetune_multiple_backbones(model_name):

model_instance = timm.create_model(model_name)
pretrained_bands = [0, 1, 2, 3, 4, 5]
model_bands = [0, 1, 2, 3, 4, 5]

state_dict = model_instance.state_dict()

torch.save(state_dict, os.path.join("tests/", model_name + ".pt"))

# Running the terratorch CLI
command_str = f"python terratorch/__main__.py fit -c tests/manufactured-finetune_{model_name}.yaml"
command_out = subprocess.run(command_str, shell=True)

assert not command_out.returncode

0 comments on commit 0c1fad8

Please sign in to comment.