Adjusting the weights keys when necessary #411

Closed · wants to merge 14 commits
5 changes: 5 additions & 0 deletions terratorch/cli_tools.py
@@ -81,6 +81,11 @@ def is_one_band(img):

def write_tiff(img_wrt, filename, metadata):

    # Adjust the band count in the metadata so it matches the first
    # dimension of the array being written.
count = img_wrt.shape[0]
metadata['count'] = count

with rasterio.open(filename, "w", **metadata) as dest:
if is_one_band(img_wrt):
img_wrt = img_wrt[None]
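For context, the added lines keep the GeoTIFF band count in sync with the array actually being written, so a profile copied from a multi-band input can still be used to write a single-band prediction. A minimal sketch of the same idea (the shapes, file name, and metadata values below are illustrative, not taken from this PR):

import numpy as np
import rasterio

# Hypothetical single-band prediction written with a profile copied from a 6-band input.
prediction = np.zeros((1, 256, 256), dtype="float32")
metadata = {"driver": "GTiff", "height": 256, "width": 256, "count": 6, "dtype": "float32"}

# Same adjustment as in write_tiff: the band count follows the array, not the stale profile.
metadata["count"] = prediction.shape[0]

with rasterio.open("prediction.tif", "w", **metadata) as dest:
    dest.write(prediction)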
91 changes: 86 additions & 5 deletions terratorch/models/backbones/prithvi_swin.py
@@ -97,6 +97,13 @@ def weights_are_swin_implementation(state_dict: dict[str, torch.Tensor]):
return True
return False

def identify_prefix(state_dict, model):

state_dict_ = model.state_dict()

prefix = list(state_dict.keys())[0].replace(list(state_dict_.keys())[0], "")

return prefix
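As an illustration of what identify_prefix computes (the keys below are invented): when the checkpoint stores its weights under an extra top-level prefix that the model does not use, the prefix is recovered by comparing the first key of each state dict.

# Hypothetical key layouts, not taken from a real checkpoint.
checkpoint_keys = ["encoder.patch_embed.proj.weight", "encoder.stages.0.blocks.0.norm1.weight"]
model_keys = ["patch_embed.proj.weight", "stages.0.blocks.0.norm1.weight"]

prefix = checkpoint_keys[0].replace(model_keys[0], "")  # what identify_prefix returns
assert prefix == "encoder."

Note that this relies on the two state dicts listing corresponding parameters first; if their ordering diverged, the detected prefix would be wrong.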

def checkpoint_filter_fn(state_dict: dict[str, torch.Tensor], model: torch.nn.Module, pretrained_bands, model_bands):
"""convert patch embedding weight from manual patchify + linear proj to conv"""
@@ -114,6 +121,7 @@ def checkpoint_filter_fn(state_dict: dict[str, torch.Tensor], model: torch.nn.Module, pretrained_bands, model_bands):
if next(iter(_state_dict.keys())).startswith("module."):
_state_dict = {k[7:]: v for k, v in _state_dict.items()}


if weights_are_swin_implementation(_state_dict):
# keep only encoder weights
state_dict = OrderedDict()
@@ -134,9 +142,26 @@ def checkpoint_filter_fn(state_dict: dict[str, torch.Tensor], model: torch.nn.Module, pretrained_bands, model_bands):
state_dict[k] = v

relative_position_bias_table_keys = [k for k in state_dict.keys() if "relative_position_bias_table" in k]
    # If the checkpoint keys carry a prefix, detect it so it can be stripped below.
prefix = identify_prefix(state_dict, model)

for table_key in relative_position_bias_table_keys:

        # If the keys carry a prefix, strip it so they can be matched
        # against the model's state dict.
if prefix:
table_key_ = table_key.replace(prefix, "")
else:
table_key_ = table_key

        # Older checkpoints name the stages "stages_<i>" instead of "stages.<i>";
        # normalize the key before looking it up in the model's state dict.
if table_key_.startswith("stages_"):
table_key_ = table_key_.replace("stages_", "stages.")

table_pretrained = state_dict[table_key]
table_current = model.state_dict()[table_key]
table_current = model.state_dict()[table_key_]

L1, nH1 = table_pretrained.size()
L2, nH2 = table_current.size()
if nH1 != nH2:
@@ -151,12 +176,18 @@ def checkpoint_filter_fn(state_dict: dict[str, torch.Tensor], model: torch.nn.Module, pretrained_bands, model_bands):
)
state_dict[table_key] = table_pretrained_resized.view(nH2, L2).permute(1, 0).contiguous()
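The collapsed lines above resize the pretrained relative_position_bias_table when the pretrained and current window sizes differ. As a rough sketch, the usual bicubic resizing looks like the following (reconstructed from common Swin implementations, not copied from the hidden part of this diff):

import torch
import torch.nn.functional as F

def resize_rel_pos_bias(table_pretrained: torch.Tensor, L2: int, nH2: int) -> torch.Tensor:
    # table_pretrained has shape (L1, nH1), where L1 = (2 * window_size - 1) ** 2.
    L1, nH1 = table_pretrained.size()
    S1, S2 = int(L1 ** 0.5), int(L2 ** 0.5)
    resized = F.interpolate(
        table_pretrained.permute(1, 0).reshape(1, nH1, S1, S1),
        size=(S2, S2),
        mode="bicubic",
    )
    # Back to (L2, nH2), matching the layout the surrounding loop writes into state_dict.
    return resized.view(nH2, L2).permute(1, 0).contiguous()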

if hasattr(model.head.fc, "weight"):
state_dict["head.fc.weight"] = model.head.fc.weight.detach().clone()
state_dict["head.fc.bias"] = model.head.fc.bias.detach().clone()
#if hasattr(model.head.fc, "weight"):
# state_dict["head.fc.weight"] = model.head.fc.weight.detach().clone()
# state_dict["head.fc.bias"] = model.head.fc.bias.detach().clone()

state_dict = select_patch_embed_weights(state_dict, model, pretrained_bands, model_bands)
return state_dict

    # Normalize legacy "stages_<i>" keys to the "stages.<i>" naming used by the model.
    state_dict_ = {}
    for k, v in state_dict.items():
        if k.startswith("stages_"):
            state_dict_[k.replace("stages_", "stages.")] = v
        else:
            # Keep every other key (e.g. patch_embed, norm layers) unchanged.
            state_dict_[k] = v

    return state_dict_
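A quick, hypothetical example of the renaming performed by this final pass (keys and placeholder values are invented for illustration):

legacy = {
    "stages_0.blocks.0.norm1.weight": "tensor-0",  # placeholders instead of real tensors
    "patch_embed.proj.weight": "tensor-1",
}
renamed = {k.replace("stages_", "stages.") if k.startswith("stages_") else k: v for k, v in legacy.items()}
assert "stages.0.blocks.0.norm1.weight" in renamed and "patch_embed.proj.weight" in renamed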


def _create_swin_mmseg_transformer(
@@ -166,6 +197,8 @@ def _create_swin_mmseg_transformer(
pretrained: bool = False, # noqa: FBT002, FBT001
**kwargs,
):
    ckpt_path = kwargs.pop("ckpt_path", None)

if pretrained_bands is None:
pretrained_bands = PRETRAINED_BANDS

@@ -192,6 +225,7 @@ def checkpoint_filter_wrapper_fn(state_dict, model):

# When the pretrained configuration is not available in HF, we shift to
# pretrained=False
"""
try:
model: MMSegSwinTransformer = build_model_with_cfg(
MMSegSwinTransformer,
@@ -213,6 +247,53 @@ def checkpoint_filter_wrapper_fn(state_dict, model):
feature_cfg={"flatten_sequential": True, "out_indices": out_indices},
**kwargs,
)
"""

# Backwards compatibility from timm (pretrained_cfg_overlay={"file": "<path to weights>"}) TODO: Remove before v1.0
if "pretrained_cfg_overlay" in kwargs:
warnings.warn(f"pretrained_cfg_overlay is deprecated and will be removed in a future version, "
f"use ckpt_path=<file path> instead.", DeprecationWarning, stacklevel=2)
if ckpt_path is not None:
warnings.warn(f"pretrained_cfg_overlay and ckpt_path are provided, ignoring pretrained_cfg_overlay.")
elif "file" not in kwargs["pretrained_cfg_overlay"]:
warnings.warn("pretrained_cfg_overlay does not include 'file path', ignoring pretrained_cfg_overlay.")
else:
ckpt_path = kwargs.pop("pretrained_cfg_overlay")["file"]

_ = kwargs.pop("pretrained_cfg")
_ = kwargs.pop("pretrained_cfg_overlay")
_ = kwargs.pop("features_only")

model = MMSegSwinTransformer(**kwargs)

if pretrained:
if ckpt_path is not None:
# Load model from checkpoint
state_dict = torch.load(ckpt_path, map_location="cpu", weights_only=True)
state_dict = checkpoint_filter_wrapper_fn(state_dict, model)
loaded_keys = model.load_state_dict(state_dict, strict=False)
if loaded_keys.missing_keys:
logger.warning(f"Missing keys in ckpt_path {ckpt_path}: {loaded_keys.missing_keys}")
            if loaded_keys.unexpected_keys:
                logger.warning(f"Unexpected keys in ckpt_path {ckpt_path}: {loaded_keys.unexpected_keys}")
else:
assert variant in pretrained_weights, (f"No pre-trained model found for variant {variant} "
f"(pretrained models: {pretrained_weights.keys()})")

try:
# Download config.json to count model downloads
_ = hf_hub_download(repo_id=pretrained_weights[variant]["hf_hub_id"], filename="config.json")
# Load model from Hugging Face
pretrained_path = hf_hub_download(repo_id=pretrained_weights[variant]["hf_hub_id"],
filename=pretrained_weights[variant]["hf_hub_filename"])
state_dict = torch.load(pretrained_path, map_location="cpu", weights_only=True)
                state_dict = checkpoint_filter_wrapper_fn(state_dict, model)
model.load_state_dict(state_dict, strict=True)
except RuntimeError as e:
logger.error(f"Failed to load the pre-trained weights for {variant}.")
raise e
elif ckpt_path is not None:
logger.warning(f"ckpt_path is provided but pretrained is set to False, ignoring ckpt_path {ckpt_path}.")

model.pretrained_bands = pretrained_bands
model.model_bands = model_bands
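Taken together, the new loading path can be exercised roughly as follows. This is a hypothetical usage sketch: the registry import path, the build signature, and the file path are assumptions and may not match the actual terratorch API exactly.

from terratorch.registry import BACKBONE_REGISTRY  # assumed import path

# New style: point the backbone at a local checkpoint file.
backbone = BACKBONE_REGISTRY.build(
    "prithvi_swin_B", pretrained=True, ckpt_path="weights/prithvi_swin_B.pt"  # hypothetical path
)

# Deprecated style: still accepted, but emits a DeprecationWarning and is mapped onto ckpt_path.
backbone = BACKBONE_REGISTRY.build(
    "prithvi_swin_B", pretrained=True, pretrained_cfg_overlay={"file": "weights/prithvi_swin_B.pt"}
)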
53 changes: 48 additions & 5 deletions terratorch/models/backbones/select_patch_embed_weights.py
@@ -18,6 +18,39 @@ def patch_embed_weights_are_compatible(model_patch_embed: torch.Tensor, checkpoint_patch_embed: torch.Tensor):
checkpoint_shape = [checkpoint_patch_embed.shape[i] for i in range(len(checkpoint_patch_embed.shape)) if i != 1]
return model_shape == checkpoint_shape

def get_state_dict(state_dict):

if "state_dict" in state_dict.keys():
return state_dict["state_dict"]
else:
return state_dict

def get_proj_key(state_dict, return_prefix=False):

proj_key = None

for key in state_dict.keys():
if key.endswith('patch_embed.proj.weight') or key.endswith('patch_embed.projection.weight'):
proj_key = key
break

if return_prefix and proj_key:

        for suffix in ['patch_embed.proj.weight', 'patch_embed.projection.weight']:
            if proj_key.endswith(suffix):
                prefix = proj_key.replace(suffix, "")
                break
else:
prefix = None

return proj_key, prefix

def remove_prefixes(state_dict, prefix):
new_state_dict = {}
for k, v in state_dict.items():
new_state_dict[k.replace(prefix, "")] = v
return new_state_dict
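To make the intent of these helpers concrete, here is a small, hypothetical round trip; the nested "state_dict" entry and the "model.encoder." prefix are invented for illustration.

# A Lightning-style checkpoint that nests the weights and prefixes every key.
ckpt = {"state_dict": {
    "model.encoder.patch_embed.proj.weight": "w0",   # placeholders instead of real tensors
    "model.encoder.stages.0.blocks.0.norm1.weight": "w1",
}}

weights = get_state_dict(ckpt)  # unwraps the nested "state_dict"
proj_key, prefix = get_proj_key(weights, return_prefix=True)
assert proj_key == "model.encoder.patch_embed.proj.weight"
assert prefix == "model.encoder."

clean = remove_prefixes(weights, prefix)  # keys now match the bare model layout
assert "patch_embed.proj.weight" in clean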

def select_patch_embed_weights(
state_dict: dict, model: nn.Module, pretrained_bands: list[HLSBands | int | OpticalBands| SARBands], model_bands: list[HLSBands | int | OpticalBands| SARBands], proj_key: str | None = None
) -> dict:
@@ -38,18 +71,25 @@ def select_patch_embed_weights(
"""
if (type(pretrained_bands) == type(model_bands)) | (type(pretrained_bands) == int) | (type(model_bands) == int):

state_dict = get_state_dict(state_dict)
        prefix = None  # assume the checkpoint keys carry no prefix unless one is detected below

if proj_key is None:
# Search for patch embedding weight in state dict
for key in state_dict.keys():
if key.endswith('patch_embed.proj.weight') or key.endswith('patch_embed.projection.weight'):
proj_key = key
break
proj_key, prefix = get_proj_key(state_dict, return_prefix=True)
if proj_key is None or proj_key not in state_dict:
raise Exception("Could not find key for patch embed weight in state_dict.")

patch_embed_weight = state_dict[proj_key]

temp_weight = model.state_dict()[proj_key].clone()
        # The patch-embed weight can be stored under a different key in the
        # checkpoint than in the model instance, so look it up on both sides.
proj_key_, _ = get_proj_key(model.state_dict())

if proj_key_:
temp_weight = model.state_dict()[proj_key_].clone()
else:
temp_weight = model.state_dict()[proj_key].clone()

# only do this if the patch size and tubelet size match. If not, start with random weights
if patch_embed_weights_are_compatible(temp_weight, patch_embed_weight):
@@ -68,4 +108,7 @@

state_dict[proj_key] = temp_weight

if prefix:
state_dict = remove_prefixes(state_dict, prefix)

return state_dict
13 changes: 12 additions & 1 deletion terratorch/tasks/segmentation_tasks.py
@@ -66,6 +66,7 @@ def __init__(
tiled_inference_parameters: TiledInferenceParameters = None,
test_dataloaders_names: list[str] | None = None,
lr_overrides: dict[str, float] | None = None,
output_most_probable: bool = True,
) -> None:
"""Constructor

@@ -112,6 +113,8 @@ def __init__(
lr_overrides (dict[str, float] | None, optional): Dictionary to override the default lr in specific
parameters. The key should be a substring of the parameter names (it will check the substring is
                contained in the parameter name) and the value should be the new lr. Defaults to None.
            output_most_probable (bool): Whether the prediction output contains only the most probable
                class (argmax over the class dimension) or the full per-class scores. Defaults to True.
"""
self.tiled_inference_parameters = tiled_inference_parameters
self.aux_loss = aux_loss
@@ -138,6 +141,12 @@ def __init__(
self.val_loss_handler = LossHandler(self.val_metrics.prefix)
self.monitor = f"{self.val_metrics.prefix}loss"
self.plot_on_val = int(plot_on_val)
self.output_most_probable = output_most_probable

if output_most_probable:
self.select_classes = lambda y: y.argmax(dim=1)
else:
self.select_classes = lambda y: y
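For reference, select_classes only changes the shape of the prediction tensor; a quick sketch with made-up dimensions:

import torch

logits = torch.randn(2, 5, 64, 64)  # (batch, classes, height, width), hypothetical sizes

as_classes = logits.argmax(dim=1)   # output_most_probable=True  -> (2, 64, 64) class indices
as_scores = logits                  # output_most_probable=False -> (2, 5, 64, 64) per-class scores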

def configure_losses(self) -> None:
"""Initialize the loss criterion.
@@ -351,5 +360,7 @@ def model_forward(x):
)
else:
y_hat: Tensor = self(x, **rest).output
y_hat = y_hat.argmax(dim=1)

y_hat = self.select_classes(y_hat)

return y_hat, file_names
@@ -51,12 +51,12 @@ data:
- 2
- 1
- 0
train_data_root: tests/
train_label_data_root: tests/
val_data_root: tests/
val_label_data_root: tests/
test_data_root: tests/
test_label_data_root: tests/
train_data_root: tests/resources/inputs
train_label_data_root: tests/resources/inputs
val_data_root: tests/resources/inputs
val_label_data_root: tests/resources/inputs
test_data_root: tests/resources/inputs
test_label_data_root: tests/resources/inputs
img_grep: "segmentation*input*.tif"
label_grep: "segmentation*label*.tif"
means:
@@ -83,8 +83,8 @@ model:
decoder: UperNetDecoder
pretrained: true
backbone: prithvi_swin_B
backbone_pretrained_cfg_overlay:
file: tests/prithvi_swin_B.pt
#backbone_pretrained_cfg_overlay:
#file: tests/prithvi_swin_B.pt
backbone_drop_path_rate: 0.3
# backbone_window_size: 8
decoder_channels: 256
@@ -99,6 +99,7 @@ model:
num_frames: 1
num_classes: 2
head_dropout: 0.5708022831486758
output_most_probable: false
loss: ce
#aux_heads:
# - name: aux_head