Skip to content

Commit

Permalink
This key can be a set
Browse files Browse the repository at this point in the history
Signed-off-by: João Lucas de Sousa Almeida <[email protected]>
  • Loading branch information
Joao-L-S-Almeida committed Jan 7, 2025
1 parent bf0e870 commit e580bd2
Showing 1 changed file with 5 additions and 24 deletions.
29 changes: 5 additions & 24 deletions terratorch/models/backbones/select_patch_embed_weights.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,16 +48,19 @@ def select_patch_embed_weights(
_possible_keys_for_proj_weight = {custom_proj_key}

patch_embed_proj_weight_key = state_dict.keys() & _possible_keys_for_proj_weight if (type(state_dict) in [collections.OrderedDict, dict]) else state_dict().keys() & _possible_keys_for_proj_weight

if len(patch_embed_proj_weight_key) == 0:
msg = "Could not find key for patch embed weight"
raise Exception(msg)
if len(patch_embed_proj_weight_key) > 1:
msg = "Too many matches for key for patch embed weight"
raise Exception(msg)

# extract the single element from the set
if isinstance(patch_embed_proj_weight_key, tuple):
(patch_embed_proj_weight_key,) = patch_embed_proj_weight_key
elif isinstance(patch_embed_proj_weight_key, set):
patch_embed_proj_weight_key = list(patch_embed_proj_weight_key)[0]

patch_embed_weight = state_dict[patch_embed_proj_weight_key]

Expand All @@ -80,26 +83,4 @@ def select_patch_embed_weights(

state_dict[patch_embed_proj_weight_key] = temp_weight

# extract the single element from the set
(patch_embed_proj_weight_key,) = patch_embed_proj_weight_key
patch_embed_weight = state_dict[patch_embed_proj_weight_key]

temp_weight = model.state_dict()[patch_embed_proj_weight_key].clone()

# only do this if the patch size and tubelet size match. If not, start with random weights
if patch_embed_weights_are_compatible(temp_weight, patch_embed_weight):
torch.nn.init.xavier_uniform_(temp_weight.view([temp_weight.shape[0], -1]))
for index, band in enumerate(model_bands):
if band in pretrained_bands:
logging.info(f"Loaded weights for {band} in position {index} of patch embed")
temp_weight[:, index] = patch_embed_weight[:, pretrained_bands.index(band)]
else:
warnings.warn(
f"Incompatible shapes between patch embedding of model {temp_weight.shape} and\
of checkpoint {patch_embed_weight.shape}",
category=UserWarning,
stacklevel=1,
)

state_dict[patch_embed_proj_weight_key] = temp_weight
return state_dict
return state_dict

0 comments on commit e580bd2

Please sign in to comment.