Skip to content

Commit

Permalink
fix finetune tests
Browse files Browse the repository at this point in the history
Signed-off-by: Carlos Gomes <[email protected]>
  • Loading branch information
CarlosGomes98 committed Aug 2, 2024
1 parent 5c94134 commit 94102ca
Showing 1 changed file with 9 additions and 26 deletions.
35 changes: 9 additions & 26 deletions tests/test_finetune.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import importlib
import os
import subprocess
import shutil

import pytest
import timm
Expand All @@ -10,47 +9,31 @@


@pytest.fixture(autouse=True)
def cleanup():
yield # everything after this runs after each test

for file in os.listdir("tests"):
if file.endswith(".pt"):
os.remove(os.path.join("tests", file))

@pytest.mark.parametrize("model_name", ["prithvi_swin_B", "prithvi_swin_L", "prithvi_vit_100", "prithvi_vit_300"])
def test_finetune_multiple_backbones(model_name):
def setup_and_cleanup(model_name):
model_instance = timm.create_model(model_name)

state_dict = model_instance.state_dict()

torch.save(state_dict, os.path.join("tests", model_name + ".pt"))

# Running the terratorch CLI
yield # everything after this runs after each test

os.remove(os.path.join("tests", model_name + ".pt"))
shutil.rmtree(os.path.join("tests", "all_ecos_random"))

@pytest.mark.parametrize("model_name", ["prithvi_swin_B", "prithvi_swin_L", "prithvi_vit_100", "prithvi_vit_300"])
def test_finetune_multiple_backbones(model_name):
command_list = ["fit", "-c", f"tests/manufactured-finetune_{model_name}.yaml"]
_ = build_lightning_cli(command_list)


@pytest.mark.parametrize("model_name", ["prithvi_swin_B"])
def test_finetune_bands_intervals(model_name):
model_instance = timm.create_model(model_name)

state_dict = model_instance.state_dict()

torch.save(state_dict, os.path.join("tests", model_name + ".pt"))

# Running the terratorch CLI
command_list = ["fit", "-c", f"tests/manufactured-finetune_{model_name}_band_interval.yaml"]
_ = build_lightning_cli(command_list)


@pytest.mark.parametrize("model_name", ["prithvi_swin_B"])
def test_finetune_bands_str(model_name):
model_instance = timm.create_model(model_name)

state_dict = model_instance.state_dict()

torch.save(state_dict, os.path.join("tests", model_name + ".pt"))

# Running the terratorch CLI
command_list = ["fit", "-c", f"tests/manufactured-finetune_{model_name}_string.yaml"]
_ = build_lightning_cli(command_list)

0 comments on commit 94102ca

Please sign in to comment.