Skip to content

Commit

Permalink
Testing more backbones for instantiation
Browse files Browse the repository at this point in the history
Signed-off-by: João Lucas de Sousa Almeida <[email protected]>
Signed-off-by: Pedro Henrique Conrado <[email protected]>
  • Loading branch information
Joao-L-S-Almeida authored and PedroConrado committed Dec 13, 2024
1 parent 74ca62a commit 78763db
Showing 1 changed file with 6 additions and 3 deletions.
9 changes: 6 additions & 3 deletions tests/test_backbones.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,22 +35,23 @@ def input_386():
return torch.ones((1, NUM_CHANNELS, 386, 386))


@pytest.mark.parametrize("model_name", ["prithvi_eo_v1_100", "prithvi_eo_v2_300", "prithvi_swin_B"])
@pytest.mark.parametrize("model_name", ["prithvi_swin_B", "prithvi_swin_L", "prithvi_eo_v1_100", "prithvi_eo_v2_300", "prithvi_swin_B"])
@pytest.mark.parametrize("test_input", ["input_224", "input_512"])
def test_can_create_backbones_from_timm(model_name, test_input, request):
backbone = timm.create_model(model_name, pretrained=False)
input_tensor = request.getfixturevalue(test_input)
backbone(input_tensor)
gc.collect()

@pytest.mark.parametrize("model_name", ["prithvi_eo_v1_100", "prithvi_eo_v2_300", "prithvi_swin_B"])
@pytest.mark.parametrize("model_name", ["prithvi_swin_B", "prithvi_swin_L", "prithvi_eo_v1_100", "prithvi_eo_v2_300", "prithvi_swin_B"])
@pytest.mark.parametrize("test_input", ["input_224", "input_512"])
def test_can_create_backbones_from_timm_features_only(model_name, test_input, request):
backbone = timm.create_model(model_name, pretrained=False, features_only=True)
input_tensor = request.getfixturevalue(test_input)
backbone(input_tensor)
gc.collect()
@pytest.mark.parametrize("model_name", ["prithvi_eo_v1_100", "prithvi_eo_v2_300", "prithvi_swin_B"])

@pytest.mark.parametrize("model_name", ["prithvi_swin_L", "prithvi_swin_L", "prithvi_eo_v1_100", "prithvi_eo_v2_300", "prithvi_swin_B"])
@pytest.mark.parametrize("prefix", ["", "timm_"])
def test_can_create_timm_backbones_from_registry(model_name, input_224, prefix):
backbone = BACKBONE_REGISTRY.build(prefix+model_name, pretrained=False)
Expand All @@ -62,12 +63,14 @@ def test_vit_models_accept_multitemporal(model_name, input_224_multitemporal):
backbone = timm.create_model(model_name, pretrained=False, num_frames=NUM_FRAMES)
backbone(input_224_multitemporal)
gc.collect()

@pytest.mark.parametrize("model_name", ["prithvi_eo_v1_100", "prithvi_eo_v2_300"])
def test_vit_models_non_divisible_input(model_name, input_non_divisible):
#padding 'none','constant', 'reflect', 'replicate' or 'circular' default is 'none'
backbone = timm.create_model(model_name, pretrained=False, features_only=True, num_frames=NUM_FRAMES, padding='constant')
backbone(input_non_divisible)
gc.collect()

@pytest.mark.parametrize("model_name", ["prithvi_eo_v1_100", "prithvi_eo_v2_300"])
@pytest.mark.parametrize("patch_size", [8, 16])
@pytest.mark.parametrize("patch_size_time", [1, 2, 4])
Expand Down

0 comments on commit 78763db

Please sign in to comment.