Commit 512634a
Updated prithvi tests
Signed-off-by: Benedikt Blumenstiel <[email protected]>
blumenstiel committed Jan 21, 2025
1 parent 498fbf6 commit 512634a
Showing 6 changed files with 28 additions and 190 deletions.
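
The common thread across the changed files: the Prithvi ViT backbones are no longer instantiated through timm in these tests; they are built by name from terratorch's BACKBONE_REGISTRY, and the old prithvi_vit_100 identifier becomes prithvi_eo_v1_100. A minimal sketch of the before/after pattern, assuming terratorch exposes these registry names as the tests imply (the 6-band, 224×224 input shape is an assumption based on the fixtures, which are not shown in this diff):

```python
import torch
from terratorch.registry import BACKBONE_REGISTRY

# New pattern used throughout the updated tests: build the backbone by registry name.
# Keyword arguments such as num_frames, patch_size or out_indices are forwarded to it.
backbone = BACKBONE_REGISTRY.build("prithvi_eo_v1_100", pretrained=False)

x = torch.ones((1, 6, 224, 224))  # assumed band count and spatial size
features = backbone(x)

# Old pattern, dropped from these tests:
# backbone = timm.create_model("prithvi_vit_100", pretrained=False, features_only=True)
```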
@@ -96,7 +96,7 @@ model:
model_args:
decoder: UperNetDecoder
pretrained: false
-backbone: prithvi_vit_100
+backbone: prithvi_eo_v1_100
#backbone_pretrained_cfg_overlay:
#file: tests/all_ecos_random/version_0/checkpoints/epoch=0_state_dict.ckpt #tests/prithvi_vit_100.pt
backbone_drop_path_rate: 0.3
150 changes: 0 additions & 150 deletions tests/resources/configs/manufactured-finetune_prithvi_vit_300.yaml

This file was deleted.

42 changes: 17 additions & 25 deletions tests/test_backbones.py
@@ -35,52 +35,53 @@ def input_386():
return torch.ones((1, NUM_CHANNELS, 386, 386))


-@pytest.mark.parametrize("model_name", ["prithvi_swin_B", "prithvi_swin_L", "prithvi_eo_v1_100", "prithvi_eo_v2_300", "prithvi_swin_B"])
+@pytest.mark.parametrize("model_name", ["prithvi_swin_B", "prithvi_swin_L", "prithvi_swin_B"])
@pytest.mark.parametrize("test_input", ["input_224", "input_512"])
def test_can_create_backbones_from_timm(model_name, test_input, request):
backbone = timm.create_model(model_name, pretrained=False)
input_tensor = request.getfixturevalue(test_input)
backbone(input_tensor)
gc.collect()

-@pytest.mark.parametrize("model_name", ["prithvi_swin_B", "prithvi_swin_L", "prithvi_eo_v1_100", "prithvi_eo_v2_300", "prithvi_swin_B"])
+@pytest.mark.parametrize("model_name", ["prithvi_swin_B", "prithvi_swin_L", "prithvi_swin_B"])
@pytest.mark.parametrize("test_input", ["input_224", "input_512"])
def test_can_create_backbones_from_timm_features_only(model_name, test_input, request):
backbone = timm.create_model(model_name, pretrained=False, features_only=True)
input_tensor = request.getfixturevalue(test_input)
backbone(input_tensor)
gc.collect()

-@pytest.mark.parametrize("model_name", ["prithvi_swin_L", "prithvi_swin_L", "prithvi_eo_v1_100", "prithvi_eo_v2_300", "prithvi_swin_B"])
+@pytest.mark.parametrize("model_name", ["prithvi_swin_L", "prithvi_swin_L", "prithvi_swin_B"])
@pytest.mark.parametrize("prefix", ["", "timm_"])
def test_can_create_timm_backbones_from_registry(model_name, input_224, prefix):
backbone = BACKBONE_REGISTRY.build(prefix+model_name, pretrained=False)
backbone(input_224)
gc.collect()


@pytest.mark.parametrize("model_name", ["prithvi_eo_v1_100", "prithvi_eo_v2_300"])
-def test_vit_models_accept_multitemporal(model_name, input_224_multitemporal):
-    backbone = timm.create_model(model_name, pretrained=False, num_frames=NUM_FRAMES)
-    backbone(input_224_multitemporal)
+def test_can_create_backbones_from_registry(model_name, input_224):
+    backbone = BACKBONE_REGISTRY.build(model_name, pretrained=False)
+    backbone(input_224)
gc.collect()


@pytest.mark.parametrize("model_name", ["prithvi_eo_v1_100", "prithvi_eo_v2_300"])
-def test_vit_models_non_divisible_input(model_name, input_non_divisible):
-    #padding 'none','constant', 'reflect', 'replicate' or 'circular' default is 'none'
-    backbone = timm.create_model(model_name, pretrained=False, features_only=True, num_frames=NUM_FRAMES, padding='constant')
-    backbone(input_non_divisible)
+def test_vit_models_accept_multitemporal(model_name, input_224_multitemporal):
+    backbone = BACKBONE_REGISTRY.build(model_name, pretrained=False, num_frames=NUM_FRAMES)
+    backbone(input_224_multitemporal)
gc.collect()
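
In the multitemporal test above, num_frames is passed straight through BACKBONE_REGISTRY.build to the backbone. A hedged sketch of that call; the five-dimensional (batch, channels, time, height, width) input layout, the frame count and the band count are assumptions inferred from the fixture names rather than shown in this diff:

```python
import torch
from terratorch.registry import BACKBONE_REGISTRY

NUM_FRAMES = 3    # placeholder; the real constant is defined at the top of test_backbones.py
NUM_CHANNELS = 6  # assumed Prithvi band count

backbone = BACKBONE_REGISTRY.build("prithvi_eo_v1_100", pretrained=False, num_frames=NUM_FRAMES)

# Assumed multitemporal layout: (batch, channels, time, height, width).
input_224_multitemporal = torch.ones((1, NUM_CHANNELS, NUM_FRAMES, 224, 224))
backbone(input_224_multitemporal)
```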


@pytest.mark.parametrize("model_name", ["prithvi_eo_v1_100", "prithvi_eo_v2_300"])
@pytest.mark.parametrize("patch_size", [8, 16])
@pytest.mark.parametrize("patch_size_time", [1, 2, 4])
def test_vit_models_different_patch_tubelet_sizes(model_name, patch_size, patch_size_time, input_224_multitemporal):
-    backbone = timm.create_model(
+    backbone = BACKBONE_REGISTRY.build(
model_name,
pretrained=False,
num_frames=NUM_FRAMES,
patch_size=[patch_size_time, patch_size, patch_size],
-        features_only=True,
)
embedding = backbone(input_224_multitemporal)
processed_embedding = backbone.prepare_features_for_image_model(embedding)
@@ -105,29 +106,18 @@ def test_vit_models_different_patch_tubelet_sizes(model_name, patch_size, patch_
gc.collect()
@pytest.mark.parametrize("model_name", ["prithvi_eo_v1_100", "prithvi_eo_v2_300"])
def test_out_indices(model_name, input_224):
# out_indices = [2, 4, 8, 10]
out_indices = (2, 4, 8, 10)
-    backbone = timm.create_model(model_name, pretrained=False, features_only=True, out_indices=out_indices)
-    assert backbone.feature_info.out_indices == out_indices
+    backbone = BACKBONE_REGISTRY.build(model_name, pretrained=False, out_indices=out_indices)
+    assert backbone.out_indices == out_indices

output = backbone(input_224)
full_output = backbone.forward_features(input_224)

for filtered_index, full_index in enumerate(out_indices):
assert torch.allclose(full_output[full_index], output[filtered_index])
gc.collect()
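
test_out_indices pins down the contract behind out_indices: a backbone built with a subset of indices returns only those feature maps, and each one must match the corresponding entry of forward_features, which yields the per-block features. The same check, condensed into a standalone sketch (registry name and the 6-band 224×224 input are assumptions carried over from above):

```python
import torch
from terratorch.registry import BACKBONE_REGISTRY

out_indices = (2, 4, 8, 10)
backbone = BACKBONE_REGISTRY.build("prithvi_eo_v1_100", pretrained=False, out_indices=out_indices)
assert backbone.out_indices == out_indices

x = torch.ones((1, 6, 224, 224))            # assumed input shape
output = backbone(x)                        # only the features selected by out_indices
full_output = backbone.forward_features(x)  # per-block features

for filtered_index, full_index in enumerate(out_indices):
    assert torch.allclose(full_output[full_index], output[filtered_index])
```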
-@pytest.mark.parametrize("model_name", ["prithvi_eo_v1_100", "prithvi_eo_v2_300"])
-def test_out_indices_non_divisible(model_name, input_non_divisible):
-    out_indices = [2, 4, 8, 10]
-    backbone = timm.create_model(model_name, pretrained=False, features_only=True, num_frames=NUM_FRAMES, out_indices=out_indices, padding='constant')
-    assert backbone.feature_info.out_indices == tuple(out_indices)
-
-    output = backbone(input_non_divisible)
-    full_output = backbone.forward_features(input_non_divisible)
-
-    for filtered_index, full_index in enumerate(out_indices):
-        assert torch.allclose(full_output[full_index], output[filtered_index])
-    gc.collect()
@pytest.mark.parametrize("model_name", ["vit_base_patch16", "vit_large_patch16"])
def test_scale_mae(model_name):
# out_indices = [2, 4, 8, 10]
@@ -139,6 +129,8 @@ def test_scale_mae(model_name):

assert len(output) == len(out_indices)
gc.collect()


@pytest.mark.parametrize("model_name", ["vit_base_patch16", "vit_large_patch16"])
@pytest.mark.parametrize("bands", [2, 4, 6])
def test_scale_mae_new_channels(model_name, bands):
5 changes: 3 additions & 2 deletions tests/test_finetune.py
@@ -6,10 +6,11 @@
import torch

from terratorch.cli_tools import build_lightning_cli
+from terratorch.registry import BACKBONE_REGISTRY

@pytest.fixture(autouse=True)
def setup_and_cleanup(model_name):
-    model_instance = timm.create_model(model_name)
+    model_instance = BACKBONE_REGISTRY.build(model_name)

state_dict = model_instance.state_dict()

@@ -22,7 +23,7 @@ def setup_and_cleanup(model_name):
if os.path.isdir(os.path.join("tests", "all_ecos_random")):
shutil.rmtree(os.path.join("tests", "all_ecos_random"))

-@pytest.mark.parametrize("model_name", ["prithvi_swin_B", "prithvi_swin_L", "prithvi_vit_100", "prithvi_eo_v2_300", "prithvi_eo_v2_600"])
+@pytest.mark.parametrize("model_name", ["prithvi_eo_v1_100", "prithvi_eo_v2_300", "prithvi_swin_B", "prithvi_swin_L", "prithvi_eo_v2_600"])
@pytest.mark.parametrize("case", ["fit", "test", "validate"])
def test_finetune_multiple_backbones(model_name, case):
command_list = [case, "-c", f"tests/resources/configs/manufactured-finetune_{model_name}.yaml"]
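
test_finetune.py still drives the run end to end through the Lightning CLI: each parametrized case assembles an argv-style list pointing at one of the manufactured configs, and build_lightning_cli is imported at the top of the file to consume it (the call itself is below the fold here). A rough sketch of that pattern, assuming build_lightning_cli accepts such a list, as the import and command_list construction suggest:

```python
from terratorch.cli_tools import build_lightning_cli

# "fit", "test" or "validate", as in the parametrized cases above.
command_list = ["fit", "-c", "tests/resources/configs/manufactured-finetune_prithvi_eo_v1_100.yaml"]
build_lightning_cli(command_list)
```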
19 changes: 7 additions & 12 deletions tests/test_prithvi_vit.py
@@ -3,26 +3,25 @@

from terratorch.models.backbones.prithvi_vit import PRETRAINED_BANDS
from terratorch.models.backbones.select_patch_embed_weights import select_patch_embed_weights
+from terratorch.registry import BACKBONE_REGISTRY

import gc

@pytest.mark.parametrize("patch_size", [4, 8, 16])
@pytest.mark.parametrize("patch_size_time,num_frames", [(1, 1), (1, 2), (1, 3), (2, 2), (3,3)])
def test_prithvi_vit_patch_embed_loading_compatible(patch_size, patch_size_time, num_frames):
-    model = timm.create_model(
+    model = BACKBONE_REGISTRY.build(
"prithvi_eo_v1_100",
pretrained=False,
num_frames=num_frames,
patch_size=[patch_size_time, 16, 16],
-        features_only=True,
)

-    weights = timm.create_model(
+    weights = BACKBONE_REGISTRY.build(
"prithvi_eo_v1_100",
pretrained=False,
num_frames=num_frames,
patch_size=[patch_size_time, 16, 16],
-        features_only=True,
).state_dict()

select_patch_embed_weights(weights, model, PRETRAINED_BANDS, PRETRAINED_BANDS)
@@ -31,20 +30,18 @@ def test_prithvi_vit_patch_embed_loading_compatible(patch_size, patch_size_time,

@pytest.mark.parametrize("patch_size_time,patch_size_time_other", [(1, 2), (2, 4)])
def test_prithvi_vit_patch_embed_loading_time_patch_size_other(patch_size_time,patch_size_time_other):
-    model = timm.create_model(
+    model = BACKBONE_REGISTRY.build(
"prithvi_eo_v1_100",
pretrained=False,
num_frames=4,
patch_size=[patch_size_time, 16, 16],
-        features_only=True,
)

-    weights = timm.create_model(
+    weights = BACKBONE_REGISTRY.build(
"prithvi_eo_v1_100",
pretrained=False,
num_frames=4,
patch_size=[patch_size_time_other, 16, 16],
-        features_only=True,
).state_dict()

# assert warning produced
@@ -55,20 +52,18 @@ def test_prithvi_vit_patch_embed_loading_time_patch_size_other(patch_size_time,p

@pytest.mark.parametrize("patch_size,patch_size_other", [(2, 4), (4, 8), (16, 4)])
def test_prithvi_vit_patch_embed_loading_not_compatible_patch(patch_size, patch_size_other):
-    model = timm.create_model(
+    model = BACKBONE_REGISTRY.build(
"prithvi_eo_v1_100",
pretrained=False,
num_frames=1,
patch_size=patch_size,
-        features_only=True,
)

-    weights = timm.create_model(
+    weights = BACKBONE_REGISTRY.build(
"prithvi_eo_v1_100",
pretrained=False,
num_frames=1,
patch_size=patch_size_other,
-        features_only=True,
).state_dict()

with pytest.warns(UserWarning):
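
The test_prithvi_vit.py changes keep the select_patch_embed_weights checks intact but source both the target model and the donor state dict from the registry. A condensed sketch of the compatible case, built only from calls that appear in this diff; with mismatched patch sizes the tests instead expect the same call to emit a UserWarning, asserted via pytest.warns:

```python
from terratorch.models.backbones.prithvi_vit import PRETRAINED_BANDS
from terratorch.models.backbones.select_patch_embed_weights import select_patch_embed_weights
from terratorch.registry import BACKBONE_REGISTRY

# Target model and donor state dict with identical patch geometry.
model = BACKBONE_REGISTRY.build(
    "prithvi_eo_v1_100", pretrained=False, num_frames=2, patch_size=[1, 16, 16]
)
weights = BACKBONE_REGISTRY.build(
    "prithvi_eo_v1_100", pretrained=False, num_frames=2, patch_size=[1, 16, 16]
).state_dict()

# Copies the patch-embedding weights for the requested bands into the model;
# the incompatible-patch tests expect a UserWarning from this call instead.
select_patch_embed_weights(weights, model, PRETRAINED_BANDS, PRETRAINED_BANDS)
```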
