Skip to content

Commit

Permalink
remove duplicate keys
Browse files Browse the repository at this point in the history
Signed-off-by: Carlos Gomes <[email protected]>
  • Loading branch information
CarlosGomes98 committed Aug 2, 2024
1 parent e444ba1 commit 5c94134
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 8 deletions.
1 change: 0 additions & 1 deletion tests/manufactured-finetune_prithvi_vit_100.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,6 @@ model:
- NIR_NARROW
- SWIR_1
- SWIR_2
num_frames: 1
head_dropout: 0.5708022831486758
head_final_act: torch.nn.ReLU
head_learned_upscale_layers: 2
Expand Down
1 change: 0 additions & 1 deletion tests/manufactured-finetune_prithvi_vit_300.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,6 @@ model:
- NIR_NARROW
- SWIR_1
- SWIR_2
num_frames: 1
head_dropout: 0.5708022831486758
head_final_act: torch.nn.ReLU
head_learned_upscale_layers: 2
Expand Down
20 changes: 14 additions & 6 deletions tests/test_finetune.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,39 +9,47 @@
from terratorch.cli_tools import build_lightning_cli


@pytest.fixture(autouse=True)
def cleanup():
yield # everything after this runs after each test

for file in os.listdir("tests"):
if file.endswith(".pt"):
os.remove(os.path.join("tests", file))

@pytest.mark.parametrize("model_name", ["prithvi_swin_B", "prithvi_swin_L", "prithvi_vit_100", "prithvi_vit_300"])
def test_finetune_multiple_backbones(model_name, tmpdir):
def test_finetune_multiple_backbones(model_name):
model_instance = timm.create_model(model_name)

state_dict = model_instance.state_dict()

torch.save(state_dict, os.path.join(tmpdir, model_name + ".pt"))
torch.save(state_dict, os.path.join("tests", model_name + ".pt"))

# Running the terratorch CLI
command_list = ["fit", "-c", f"tests/manufactured-finetune_{model_name}.yaml"]
_ = build_lightning_cli(command_list)


@pytest.mark.parametrize("model_name", ["prithvi_swin_B"])
def test_finetune_bands_intervals(model_name, tmpdir):
def test_finetune_bands_intervals(model_name):
model_instance = timm.create_model(model_name)

state_dict = model_instance.state_dict()

torch.save(state_dict, os.path.join(tmpdir, model_name + ".pt"))
torch.save(state_dict, os.path.join("tests", model_name + ".pt"))

# Running the terratorch CLI
command_list = ["fit", "-c", f"tests/manufactured-finetune_{model_name}_band_interval.yaml"]
_ = build_lightning_cli(command_list)


@pytest.mark.parametrize("model_name", ["prithvi_swin_B"])
def test_finetune_bands_str(model_name, tmpdir):
def test_finetune_bands_str(model_name):
model_instance = timm.create_model(model_name)

state_dict = model_instance.state_dict()

torch.save(state_dict, os.path.join(tmpdir, model_name + ".pt"))
torch.save(state_dict, os.path.join("tests", model_name + ".pt"))

# Running the terratorch CLI
command_list = ["fit", "-c", f"tests/manufactured-finetune_{model_name}_string.yaml"]
Expand Down

0 comments on commit 5c94134

Please sign in to comment.