From 07d5f98bddc49015e4406b01af17d74dff70bf20 Mon Sep 17 00:00:00 2001
From: myscon
Date: Tue, 20 Aug 2024 15:19:44 -0700
Subject: [PATCH] added clay_v1_base_model

Signed-off-by: myscon
---
 terratorch/datamodules/__init__.py           |   6 +-
 terratorch/datamodules/openearthmap.py       |  58 +++
 terratorch/datasets/__init__.py              |   4 +
 terratorch/datasets/openearthmap.py          | 114 +++++
 terratorch/datasets/utils.py                 |   4 +-
 terratorch/models/__init__.py                |   1 +
 terratorch/models/backbones/__init__.py      |   1 +
 .../models/backbones/clay_v1/__init__.py     |   3 +
 .../models/backbones/clay_v1/embedder.py     | 168 +++++++
 .../models/backbones/clay_v1/modules.py      | 445 ++++++++++++++++++
 terratorch/models/backbones/clay_v1/utils.py |  44 ++
 terratorch/models/clay_model_factory.py      |  82 +---
 12 files changed, 861 insertions(+), 69 deletions(-)
 create mode 100644 terratorch/datamodules/openearthmap.py
 create mode 100644 terratorch/datasets/openearthmap.py
 create mode 100644 terratorch/models/backbones/clay_v1/__init__.py
 create mode 100644 terratorch/models/backbones/clay_v1/embedder.py
 create mode 100644 terratorch/models/backbones/clay_v1/modules.py
 create mode 100644 terratorch/models/backbones/clay_v1/utils.py

diff --git a/terratorch/datamodules/__init__.py b/terratorch/datamodules/__init__.py
index 3e35c579..1f8b3d78 100644
--- a/terratorch/datamodules/__init__.py
+++ b/terratorch/datamodules/__init__.py
@@ -29,6 +29,9 @@
 from terratorch.datamodules.sen1floods11 import Sen1Floods11NonGeoDataModule
 from terratorch.datamodules.torchgeo_data_module import TorchGeoDataModule, TorchNonGeoDataModule
 
+# miscellaneous datamodules
+from terratorch.datamodules.openearthmap import OpenEarthMapNonGeoDataModule
+
 __all__ = (
     "GenericNonGeoSegmentationDataModule",
     "GenericNonGeoPixelwiseRegressionDataModule",
@@ -50,5 +53,6 @@
     "MChesapeakeLandcoverNonGeoDataModule",
     "MPv4gerSegNonGeoDataModule",
     "MSACropTypeNonGeoDataModule",
-    "MNeonTreeNonGeoDataModule"
+    "MNeonTreeNonGeoDataModule",
+    "OpenEarthMapNonGeoDataModule"
 )
diff --git a/terratorch/datamodules/openearthmap.py b/terratorch/datamodules/openearthmap.py
new file mode 100644
index 00000000..613a6425
--- /dev/null
+++ b/terratorch/datamodules/openearthmap.py
@@ -0,0 +1,58 @@
+from typing import Any
+
+import albumentations as A
+import kornia.augmentation as K
+import torch
+from torchgeo.datamodules import NonGeoDataModule
+from torchgeo.transforms import AugmentationSequential
+
+from terratorch.datamodules.utils import wrap_in_compose_is_list
+from terratorch.datasets import OpenEarthMapNonGeo
+
+MEANS = {
+    "BLUE": 116.628328,
+    "GREEN": 119.65935,
+    "RED": 113.385309
+}
+
+STDS = {
+    "BLUE": 44.668890717415586,
+    "GREEN": 48.282311849967364,
+    "RED": 54.19692448815262,
+}
+
+
+class OpenEarthMapNonGeoDataModule(NonGeoDataModule):
+    def __init__(
+        self,
+        batch_size: int = 8,
+        num_workers: int = 0,
+        data_root: str = "./",
+        train_transform: A.Compose | None | list[A.BasicTransform] = None,
+        val_transform: A.Compose | None | list[A.BasicTransform] = None,
+        test_transform: A.Compose | None | list[A.BasicTransform] = None,
+        aug: AugmentationSequential | None = None,
+        **kwargs: Any
+    ) -> None:
+        super().__init__(OpenEarthMapNonGeo, batch_size, num_workers, **kwargs)
+
+        bands = kwargs.get("bands", OpenEarthMapNonGeo.all_band_names)
+        self.means = torch.tensor([MEANS[b] for b in bands])
+        self.stds = torch.tensor([STDS[b] for b in bands])
+        self.train_transform = wrap_in_compose_is_list(train_transform)
+        self.val_transform = wrap_in_compose_is_list(val_transform)
+        self.test_transform = wrap_in_compose_is_list(test_transform)
+        self.data_root = data_root
+        self.aug = AugmentationSequential(K.Normalize(self.means, self.stds), data_keys=["image"]) if aug is None else aug
+
+    def setup(self, stage: str) -> None:
+        if stage in ["fit"]:
+            self.train_dataset = self.dataset_class(
+                split="train", data_root=self.data_root, transform=self.train_transform, **self.kwargs
+            )
+        if stage in ["fit", "validate"]:
+            self.val_dataset = self.dataset_class(
+                split="val", data_root=self.data_root, transform=self.val_transform, **self.kwargs
+            )
+        if stage in ["test"]:
+            self.test_dataset = self.dataset_class(
+                split="test", data_root=self.data_root, transform=self.test_transform, **self.kwargs
+            )
\ No newline at end of file
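
A minimal usage sketch for the new datamodule (the data_root path is hypothetical; it assumes the OpenEarthMap split files train.txt/val.txt/test.txt exist under that root):

    from terratorch.datamodules import OpenEarthMapNonGeoDataModule

    dm = OpenEarthMapNonGeoDataModule(batch_size=4, num_workers=2, data_root="/data/openearthmap")
    dm.setup("fit")
    batch = next(iter(dm.train_dataloader()))
    print(batch["image"].shape, batch["mask"].shape)  # [4, 3, 256, 256] and [4, 1, 256, 256] with the default crop_size
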
diff --git a/terratorch/datasets/__init__.py b/terratorch/datasets/__init__.py
index ae19d1d8..b875f27a 100644
--- a/terratorch/datasets/__init__.py
+++ b/terratorch/datasets/__init__.py
@@ -33,6 +33,9 @@
 # TorchGeo RasterDatasets
 from terratorch.datasets.wsf import WSF2019, WSFEvolution
 
+# miscellaneous datasets
+from terratorch.datasets.openearthmap import OpenEarthMapNonGeo
+
 __all__ = (
     "GenericNonGeoSegmentationDataset",
     "GenericNonGeoPixelwiseRegressionDataset",
@@ -59,4 +62,5 @@
     "WSFEvolution",
     "HLSL30",
     "HLSS30",
+    "OpenEarthMapNonGeo"
 )
diff --git a/terratorch/datasets/openearthmap.py b/terratorch/datasets/openearthmap.py
new file mode 100644
index 00000000..bad3bbcd
--- /dev/null
+++ b/terratorch/datasets/openearthmap.py
@@ -0,0 +1,114 @@
+from collections.abc import Sequence
+from pathlib import Path
+
+import albumentations as A
+import matplotlib.pyplot as plt
+import numpy as np
+import rasterio
+import torch
+from torchgeo.datasets import NonGeoDataset
+
+from terratorch.datasets.utils import to_tensor
+
+
+class OpenEarthMapNonGeo(NonGeoDataset):
+
+    all_band_names = ("BLUE", "GREEN", "RED")
+
+    rgb_bands = ("RED", "GREEN", "BLUE")
+
+    BAND_SETS = {"all": all_band_names, "rgb": rgb_bands}
+
+    def __init__(self, data_root: str,
+                 bands: Sequence[str] = BAND_SETS["all"],
+                 transform: A.Compose | None = None,
+                 split="train",
+                 crop_size: int = 256,
+                 random_crop: bool = True) -> None:
+        super().__init__()
+        if split not in ["train", "test", "val"]:
+            msg = "Split must be one of train, test, val."
+            raise ValueError(msg)
+
+        self.transform = transform if transform else lambda **batch: to_tensor(batch, transpose=False)
+        self.split = split
+        self.data_root = data_root
+
+        # images in openearthmap are not all 1024x1024 and must be cropped
+        self.crop_size = crop_size
+        self.random_crop = random_crop
+
+        assert self.crop_size > 0, "Crop size must be greater than 0"
+
+        self.image_files = self._get_file_paths(Path(self.data_root, f"{split}.txt"))
+
+    def __getitem__(self, index: int) -> dict[str, torch.Tensor]:
+        image_path, label_path = self.image_files[index]
+
+        with rasterio.open(image_path) as src:
+            image = src.read()
+        with rasterio.open(label_path) as src:
+            mask = src.read()
+
+        # some images in the dataset are not perfect squares
+        # cropping to fit the prepare_features_for_image_model call
+        if self.random_crop:
+            image, mask = self._random_crop(image, mask)
+        else:
+            image, mask = self._center_crop(image, mask)
+
+        output = {
+            "image": image.astype(np.float32),
+            "mask": mask
+        }
+
+        output = self.transform(**output)
+        output["mask"] = output["mask"].long()
+
+        return output
+
+    def _parse_file_name(self, file_name: str):
+        underscore_pos = file_name.rfind("_")
+        folder_name = file_name[:underscore_pos]
+        region_path = Path(self.data_root, folder_name)
+        image_path = Path(region_path, "images", file_name)
+        label_path = Path(region_path, "labels", file_name)
+        return image_path, label_path
+
+    def _get_file_paths(self, text_file_path: str):
+        with open(text_file_path) as file:
+            lines = file.readlines()
+        file_paths = [self._parse_file_name(line.strip()) for line in lines]
+        return file_paths
+
+    def __len__(self):
+        return len(self.image_files)
+
+    def _random_crop(self, image, mask):
+        h, w = image.shape[1:]
+        # max(..., 1) keeps randint valid when the image is exactly crop_size
+        top = np.random.randint(0, max(h - self.crop_size, 1))
+        left = np.random.randint(0, max(w - self.crop_size, 1))
+
+        image = image[:, top: top + self.crop_size, left: left + self.crop_size]
+        mask = mask[:, top: top + self.crop_size, left: left + self.crop_size]
+
+        return image, mask
+
+    def _center_crop(self, image, mask):
+        h, w = image.shape[1:]
+        top = (h - self.crop_size) // 2
+        left = (w - self.crop_size) // 2
+
+        image = image[:, top: top + self.crop_size, left: left + self.crop_size]
+        mask = mask[:, top: top + self.crop_size, left: left + self.crop_size]
+
+        return image, mask
+
+    def plot(self, arg, suptitle: str | None = None) -> None:
+        pass
+
+    def plot_sample(self, sample, prediction=None, suptitle: str | None = None, class_names=None):
+        pass
\ No newline at end of file
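
The split files are expected to contain one tile file name per line; `_parse_file_name` maps each name to its region folder by cutting at the last underscore. A small sketch of the convention (the file name and root are illustrative):

    from pathlib import Path

    data_root = "/data/openearthmap"                 # hypothetical root
    file_name = "aachen_23.tif"                      # hypothetical line from train.txt
    folder_name = file_name[:file_name.rfind("_")]   # -> "aachen"
    image_path = Path(data_root, folder_name, "images", file_name)
    label_path = Path(data_root, folder_name, "labels", file_name)
    print(image_path)  # /data/openearthmap/aachen/images/aachen_23.tif
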
diff --git a/terratorch/datasets/utils.py b/terratorch/datasets/utils.py
index 0d4065a8..4eda8350 100644
--- a/terratorch/datasets/utils.py
+++ b/terratorch/datasets/utils.py
@@ -71,13 +71,13 @@ def _split_filter_function(file_name, valid_files: list[str], ignore_extensions=
     return False
 
 
-def to_tensor(d):
+def to_tensor(d, transpose=True):
     new_dict = {}
     for k, v in d.items():
         if not isinstance(v, np.ndarray):
             new_dict[k] = v
         else:
-            if k == "image":
+            if k == "image" and transpose:
                 v = np.moveaxis(v, -1, 0)
             new_dict[k] = torch.from_numpy(v)
     return new_dict
diff --git a/terratorch/models/__init__.py b/terratorch/models/__init__.py
index a04404ad..f740a1ea 100644
--- a/terratorch/models/__init__.py
+++ b/terratorch/models/__init__.py
@@ -1,5 +1,6 @@
 # Copyright contributors to the Terratorch project
 
+from terratorch.models.clay_model_factory import ClayModelFactory
 from terratorch.models.prithvi_model_factory import PrithviModelFactory
 from terratorch.models.satmae_model_factory import SatMAEModelFactory
 from terratorch.models.scalemae_model_factory import ScaleMAEModelFactory
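
The new `transpose` flag lets datasets that already produce channels-first arrays (such as rasterio reads) skip the HWC-to-CHW move. A quick sketch of both paths, with dummy arrays:

    import numpy as np
    from terratorch.datasets.utils import to_tensor

    hwc = {"image": np.zeros((256, 256, 3), dtype=np.float32)}
    chw = {"image": np.zeros((3, 256, 256), dtype=np.float32)}
    print(to_tensor(hwc)["image"].shape)                   # torch.Size([3, 256, 256])
    print(to_tensor(chw, transpose=False)["image"].shape)  # torch.Size([3, 256, 256])
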
diff --git a/terratorch/models/backbones/__init__.py b/terratorch/models/backbones/__init__.py
index 1f27dfd4..43ea6d74 100644
--- a/terratorch/models/backbones/__init__.py
+++ b/terratorch/models/backbones/__init__.py
@@ -3,3 +3,4 @@
 # import so they get registered
 import terratorch.models.backbones.prithvi_swin
 import terratorch.models.backbones.prithvi_vit
+import terratorch.models.backbones.clay_v1
diff --git a/terratorch/models/backbones/clay_v1/__init__.py b/terratorch/models/backbones/clay_v1/__init__.py
new file mode 100644
index 00000000..bdc08c0b
--- /dev/null
+++ b/terratorch/models/backbones/clay_v1/__init__.py
@@ -0,0 +1,3 @@
+import terratorch.models.backbones.clay_v1.embedder
+import terratorch.models.backbones.clay_v1.modules
+import terratorch.models.backbones.clay_v1.utils
\ No newline at end of file
diff --git a/terratorch/models/backbones/clay_v1/embedder.py b/terratorch/models/backbones/clay_v1/embedder.py
new file mode 100644
index 00000000..c00dc6ba
--- /dev/null
+++ b/terratorch/models/backbones/clay_v1/embedder.py
@@ -0,0 +1,168 @@
+import re
+import warnings
+
+import numpy as np
+import torch
+from torch import nn, Tensor
+from timm.models import FeatureInfo
+from timm.models._builder import build_model_with_cfg
+from timm.models._registry import generate_default_cfgs, register_model
+
+from terratorch.models.backbones.clay_v1.modules import EmbeddingEncoder, Datacuber
+
+warnings.filterwarnings("ignore", category=UserWarning)
+
+
+default_cfgs = generate_default_cfgs(
+    {
+        "clay_v1_base": {
+            "hf_hub_id": "made-with-clay/Clay",
+            "hf_hub_filename": "clay-v1-base.ckpt"
+        }
+    }
+)
+
+
+class Embedder(nn.Module):
+    default_out_indices = (0,)  # single out_indices for simplicity
+
+    def __init__(self,
+                 img_size=256,
+                 num_frames=1,
+                 ckpt_path=None,
+                 device="cuda",
+                 **kwargs):
+        super().__init__()
+        self.feature_info = []
+        self.img_size = img_size
+        self.num_frames = num_frames
+
+        if kwargs.get("datacuber", True):
+            self.datacuber = Datacuber()
+        else:
+            self.datacuber = None
+
+        # TODO: add support for various clay versions
+        self.clay_encoder = (
+            EmbeddingEncoder(  # default parameters for the Clay base model
+                img_size=img_size,
+                patch_size=8,
+                dim=768,
+                depth=12,
+                heads=12,
+                dim_head=64,
+                mlp_ratio=4.0,
+            ).to(device)
+        )
+
+        # for use in the features list; single-layer feature for simplicity
+        self.feature_info.append(
+            {"num_chs": 768, "reduction": 1, "module": "clay_encoder"})
+
+        # assuming this is used to fine-tune a network on top of the embeddings
+        self.device = torch.device(device)
+        if ckpt_path:
+            self.load_clay_weights(ckpt_path)
+
+    def load_clay_weights(self, ckpt_path):
+        """Load the weights from the Clay model encoder."""
+        ckpt = torch.load(ckpt_path, map_location=self.device)
+        state_dict = ckpt.get("state_dict")
+        state_dict = {
+            re.sub(r"^model\.encoder\.", "", name): param
+            for name, param in state_dict.items()
+            if name.startswith("model.encoder")
+        }
+
+        with torch.no_grad():
+            for name, param in self.clay_encoder.named_parameters():
+                if name in state_dict and param.size() == state_dict[name].size():
+                    param.data.copy_(state_dict[name])  # copy the weights
+                else:
+                    print(
+                        f"No matching parameter for {name} with size {param.size()}")
+
+        for param in self.clay_encoder.parameters():
+            param.requires_grad = False
+
+        self.clay_encoder.eval()
+
+    @staticmethod
+    def transform_state_dict(state_dict, model):
+        state_dict = state_dict.get("state_dict")
+        state_dict = {
+            re.sub(r"^model\.encoder\.", "clay_encoder.", name): param
+            for name, param in state_dict.items()
+            if name.startswith("model.encoder")
+        }
+        return state_dict
+
+    def forward_features(self, x):
+        datacube = self.datacuber(x) if self.datacuber is not None else x
+        embeddings = self.clay_encoder(datacube)
+
+        # TODO: actually return features individually
+        return [embeddings]
+
+    def fake_datacube(self):
+        """Generate a fake datacube for model export."""
+        dummy_datacube = {
+            "pixels": torch.randn(2, 3, self.img_size, self.img_size),
+            "time": torch.randn(2, 4),
+            "latlon": torch.randn(2, 4),
+            "waves": torch.randn(3),
+            "gsd": torch.randn(1),
+        }
+        dummy_datacube = {k: v.to(self.device)
+                          for k, v in dummy_datacube.items()}
+        return dummy_datacube
+
+    def prepare_features_for_image_model(self, features: list[Tensor]) -> list[Tensor]:
+        x_no_token = features[-1][:, 1:, :]
+        encoded = x_no_token.permute(0, 2, 1).reshape(
+            x_no_token.shape[0],
+            -1,
+            int(np.sqrt(x_no_token.shape[1] // self.num_frames)),
+            int(np.sqrt(x_no_token.shape[1] // self.num_frames)),
+        )
+
+        # return as a list for features-list compatibility
+        return [encoded]
+
+
+def _make_clay(
+    variant: str,
+    pretrained: bool,
+    **kwargs
+):
+    encoder_only = kwargs.pop("features_only", False)
+    model = build_model_with_cfg(
+        model_cls=Embedder,
+        variant=variant,
+        pretrained=pretrained,
+        pretrained_strict=True,
+        pretrained_filter_fn=Embedder.transform_state_dict,
+        **kwargs,
+    )
+    if encoder_only:
+        out_indices = kwargs.pop("out_indices", model.default_out_indices)
+        model.feature_info = FeatureInfo(model.feature_info, out_indices)
+        model.model_bands = kwargs.get("model_bands")
+
+        # TODO: split features according to typical TIMM outputs
+        model.forward = model.forward_features
+        model.pretrained_bands = kwargs.get("pretrained_bands")
+    return model
+
+
+@register_model
+def clay_v1_base(
+    pretrained: bool = False,
+    **kwargs,
+) -> Embedder:
+    return _make_clay(
+        "clay_v1_base",
+        pretrained=pretrained,
+        **kwargs
+    )
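
Since `clay_v1_base` is registered through timm's `@register_model`, the backbone can be created with `timm.create_model` once the module has been imported; a sketch on CPU with random weights (shapes follow the base config: patch size 8 on a 256x256 input gives 1024 patches plus a cls token):

    import timm
    import torch
    import terratorch.models.backbones.clay_v1  # registers clay_v1_base with timm

    model = timm.create_model("clay_v1_base", pretrained=False, device="cpu")
    x = torch.randn(2, 3, 256, 256)
    feats = model.forward_features(x)  # list with a single [B, 1 + L, 768] tensor
    print(feats[0].shape)              # torch.Size([2, 1025, 768])
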
diff --git a/terratorch/models/backbones/clay_v1/modules.py b/terratorch/models/backbones/clay_v1/modules.py
new file mode 100644
index 00000000..b35a39a2
--- /dev/null
+++ b/terratorch/models/backbones/clay_v1/modules.py
@@ -0,0 +1,445 @@
+import math
+import os
+
+import torch
+import torch.nn.functional as F
+from einops import rearrange, repeat
+from torch import nn, Tensor
+from vit_pytorch.simple_vit import Transformer
+
+from terratorch.models.backbones.clay_v1.utils import posemb_sincos_1d, posemb_sincos_2d_with_gsd
+
+os.environ["TORCH_CUDNN_V8_API_DISABLED"] = "1"
+
+
+class Encoder(nn.Module):
+    def __init__(
+        self,
+        mask_ratio,
+        patch_size,
+        shuffle,
+        dim,
+        depth,
+        heads,
+        dim_head,
+        mlp_ratio,
+    ):
+        super().__init__()
+        self.mask_ratio = mask_ratio
+        self.patch_size = patch_size
+        self.shuffle = shuffle
+        self.dim = dim
+        self.cls_token = nn.Parameter(torch.randn(1, 1, dim) * 0.02)
+
+        self.patch_embedding = DynamicEmbedding(
+            wave_dim=128,
+            num_latent_tokens=128,
+            patch_size=patch_size,
+            embed_dim=dim,
+            is_decoder=False,
+        )
+
+        self.transformer = Transformer(
+            dim=dim,
+            depth=depth,
+            heads=heads,
+            dim_head=dim_head,
+            mlp_dim=int(dim * mlp_ratio),
+        )
+
+    def to_patch_embed(self, cube, waves):
+        """Split the input cube into patches & create embeddings per patch"""
+        patches, waves_encoded = self.patch_embedding(cube, waves)  # [B L D]
+        return patches, waves_encoded  # ([B L D], [N D])
+
+    def add_encodings(self, patches, time, latlon, gsd):
+        """Add position encoding to the patches"""
+        B, L, D = patches.shape
+
+        grid_size = int(math.sqrt(L))
+        self.num_patches = grid_size**2
+
+        pos_encoding = (
+            posemb_sincos_2d_with_gsd(
+                h=grid_size,
+                w=grid_size,
+                dim=(self.dim - 8),
+                gsd=gsd,
+            )
+            .to(patches.device)
+            .detach()
+        )  # [L (D - 8)]
+
+        time_latlon = torch.hstack((time, latlon)).to(
+            patches.device).detach()  # [B 8]
+
+        pos_encoding = repeat(
+            pos_encoding, "L D -> B L D", B=B)  # [B L (D - 8)]
+        time_latlon = repeat(time_latlon, "B D -> B L D", L=L)  # [B L 8]
+        pos_metadata_encoding = torch.cat(
+            (pos_encoding, time_latlon), dim=-1
+        )  # [B L D]
+
+        # [B L D] + [B L D] -> [B L D]
+        patches = patches + pos_metadata_encoding
+        return patches  # [B L D]
+
+    def mask_out(self, patches):
+        """
+        Mask out patches randomly by shuffling the patches & masking out the
+        first N patches
+
+        Parameters
+        ----------
+        patches : torch.Tensor
+            A tensor of shape (B, L, D)
+
+        Returns
+        -------
+        unmasked_patches : torch.Tensor
+            A tensor of shape (B, L:(1 - mask_ratio), D) containing the
+            embeddings of the unmasked patches.
+        unmasked_indices : torch.Tensor
+            A tensor of shape (B, (1 - mask_ratio)) containing the indices of
+            the unmasked patches.
+        masked_indices : torch.Tensor
+            A tensor of shape (B, mask_ratio) containing the indices of the
+            masked patches.
+        masked_matrix : torch.Tensor
+            A tensor of shape (B, L) containing the mask matrix; 1 indicates a
+            masked patch & 0 indicates an unmasked patch.
+        """
+        B, L, D = patches.shape
+        # assert (
+        #     L == self.num_patches
+        # ), f"Expected {self.num_patches} patches, got {L} patches."
+
+        if self.shuffle:  # Shuffle the patches
+            noise = torch.randn((B, L), device=patches.device)  # [B L]
+        else:  # Don't shuffle, useful for interpolation & inspection of embeddings
+            noise = rearrange(
+                torch.arange(B * L, device=patches.device), "(B L) -> B L", B=B, L=L
+            )
+
+        random_indices = torch.argsort(noise, dim=-1)  # [B L]
+        reverse_indices = torch.argsort(random_indices, dim=-1)  # [B L]
+
+        num_masked_patches = int(
+            self.mask_ratio * self.num_patches
+        )  # number of patches to be masked out
+        masked_indices, unmasked_indices = (
+            random_indices[:, :num_masked_patches],  # [B mask_ratio * L]
+            random_indices[:, num_masked_patches:],  # [B (1 - mask_ratio) * L]
+        )
+
+        # create a mask of shape B L, where 1 indicates a masked patch
+        # and 0 indicates an unmasked patch
+        masked_matrix = torch.zeros((B, L), device=patches.device)  # [B L] = 0
+        masked_matrix[:, :num_masked_patches] = 1  # [B mask_ratio * L] = 1
+        masked_matrix = torch.gather(
+            masked_matrix, dim=1, index=reverse_indices
+        )  # [B L] -> [B L] - reorder the patches
+
+        # mask out the patches
+        batch_indices = rearrange(
+            torch.arange(B, device=patches.device), "B -> B 1"
+        )  # [B 1]
+        unmasked_patches = patches[
+            batch_indices, unmasked_indices, :
+        ]  # [B L:(1 - mask_ratio) D]
+        _ = patches[batch_indices, masked_indices, :]  # [B L:mask_ratio D]
+
+        return (
+            unmasked_patches,
+            unmasked_indices,
+            masked_indices,
+            masked_matrix,
+        )  # [B L:(1 - mask_ratio) D], [(1-mask_ratio)], [mask_ratio], [B L]
+
+    def forward(self, datacube):
+        cube, time, latlon, gsd, waves = (
+            datacube["pixels"],  # [B C H W]
+            datacube["time"],  # [B 2]
+            datacube["latlon"],  # [B 2]
+            datacube["gsd"],  # 1
+            datacube["waves"],  # [N]
+        )  # [B C H W]
+
+        B, C, H, W = cube.shape
+
+        patches, waves_encoded = self.to_patch_embed(
+            cube, waves
+        )  # [B L D] - patchify & create embeddings per patch
+        # TODO: Add time & latlon as encoding to patches
+        patches = self.add_encodings(
+            patches,
+            time,
+            latlon,
+            gsd,
+        )  # [B L D] - add position encoding to the embeddings
+
+        # mask out patches
+        (
+            unmasked_patches,
+            unmasked_indices,
+            masked_indices,
+            masked_matrix,
+        ) = self.mask_out(
+            patches
+        )  # [B L:(1 - mask_ratio) D], [(1-mask_ratio)], [mask_ratio], [B L]
+
+        # Add class tokens
+        cls_tokens = repeat(self.cls_token, "1 1 D -> B 1 D", B=B)  # [B 1 D]
+        unmasked_patches = torch.cat(
+            (cls_tokens, unmasked_patches), dim=1
+        )  # [B (1 + L) D]
+
+        # pass the unmasked patches through the transformer
+        encoded_unmasked_patches = self.transformer(
+            unmasked_patches
+        )  # [B ((1 + L):(1 - mask_ratio)) D]
+
+        return (
+            encoded_unmasked_patches,
+            unmasked_indices,
+            masked_indices,
+            masked_matrix,
+        )  # [B ((1 + L):(1 - mask_ratio)) D], [(1-mask_ratio)], [mask_ratio], [B L]
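
The double `argsort` in `mask_out` is the standard trick for shuffling patch indices and then recovering the original order; a toy sketch of the index bookkeeping with illustrative sizes:

    import torch

    B, L, mask_ratio = 1, 8, 0.5
    noise = torch.randn(B, L)
    random_indices = torch.argsort(noise, dim=-1)            # random permutation of 0..L-1
    reverse_indices = torch.argsort(random_indices, dim=-1)  # inverse permutation
    num_masked = int(mask_ratio * L)
    masked, unmasked = random_indices[:, :num_masked], random_indices[:, num_masked:]

    masked_matrix = torch.zeros(B, L)
    masked_matrix[:, :num_masked] = 1
    masked_matrix = torch.gather(masked_matrix, 1, reverse_indices)  # 1 marks masked positions
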
+
+
+class EmbeddingEncoder(Encoder):
+    """Clay Encoder without mask and shuffle."""
+
+    def __init__(  # noqa: PLR0913
+        self,
+        img_size,
+        patch_size,
+        dim,
+        depth,
+        heads,
+        dim_head,
+        mlp_ratio,
+    ):
+        super().__init__(
+            mask_ratio=0.0,
+            shuffle=False,
+            patch_size=patch_size,
+            dim=dim,
+            depth=depth,
+            heads=heads,
+            dim_head=dim_head,
+            mlp_ratio=mlp_ratio,
+        )
+        self.img_size = img_size
+
+        # using a fixed grid size for inference
+        self.grid_size = img_size // patch_size
+        self.num_patches = self.grid_size**2
+
+    def add_encodings(self, patches, time, latlon, gsd):
+        """Add position encoding to the patches"""
+        B, L, D = patches.shape
+
+        grid_size = self.grid_size
+
+        pos_encoding = (
+            posemb_sincos_2d_with_gsd(
+                h=grid_size,
+                w=grid_size,
+                dim=(self.dim - 8),
+                gsd=gsd,
+            )
+            .to(patches.device)
+            .detach()
+        )  # [L (D - 8)]
+
+        time_latlon = torch.hstack((time, latlon)).to(
+            patches.device).detach()  # [B 8]
+
+        pos_encoding = repeat(
+            pos_encoding, "L D -> B L D", B=B)  # [B L (D - 8)]
+        time_latlon = repeat(time_latlon, "B D -> B L D", L=L)  # [B L 8]
+        pos_metadata_encoding = torch.cat(
+            (pos_encoding, time_latlon), dim=-1
+        )  # [B L D]
+
+        # [B L D] + [B L D] -> [B L D]
+        patches = patches + pos_metadata_encoding
+        return patches  # [B L D]
+
+    def forward(self, datacube):
+        cube, time, latlon, gsd, waves = (
+            datacube["pixels"],  # [B C H W]
+            datacube["time"],  # [B 2]
+            datacube["latlon"],  # [B 2]
+            datacube["gsd"],  # 1
+            datacube["waves"],  # [N]
+        )  # [B C H W]
+        B, C, H, W = cube.shape
+
+        patches, _ = self.to_patch_embed(
+            cube, waves
+        )  # [B L D] - patchify & create embeddings per patch
+
+        # Add time & latlon as encoding to patches
+        patches = self.add_encodings(
+            patches,
+            time,
+            latlon,
+            gsd,
+        )  # [B L D] - add position encoding to the embeddings
+
+        # Add class tokens
+        cls_tokens = repeat(self.cls_token, "1 1 D -> B 1 D", B=B)  # [B 1 D]
+        patches = torch.cat((cls_tokens, patches), dim=1)  # [B (1 + L) D]
+
+        # pass the patches through the transformer
+        patches = self.transformer(patches)  # [B (1 + L) D]
+
+        # the cls token is kept; downstream consumers strip it as needed
+        # (e.g. prepare_features_for_image_model drops patches[:, 0, :])
+        return patches  # [B (1 + L) D]
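
Each patch token of width D carries a [L, D - 8] sincos position code concatenated with the eight time/lat-lon scalars, so the metadata always occupies the last 8 channels. A shape check under the base config (metadata values are dummies):

    import torch
    from terratorch.models.backbones.clay_v1.utils import posemb_sincos_2d_with_gsd

    B, D, grid = 2, 768, 32
    pos = posemb_sincos_2d_with_gsd(h=grid, w=grid, dim=D - 8, gsd=1.0)  # [1024, 760]
    time_latlon = torch.randn(B, 8)                                      # [B, 8]
    enc = torch.cat(
        (pos.expand(B, -1, -1), time_latlon[:, None, :].expand(B, grid * grid, 8)), dim=-1
    )
    print(enc.shape)  # torch.Size([2, 1024, 768])
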
+
+
+class FCBlock(nn.Module):
+    def __init__(self, size):
+        super().__init__()
+        self.l1 = nn.Linear(size, size)
+        self.l2 = nn.Linear(size, size)
+
+    def forward(self, x):
+        y = F.gelu(self.l1(x))
+        y = F.gelu(self.l2(y))
+        return x + y
+
+
+class WavesTransformer(nn.Module):
+    def __init__(  # noqa: PLR0913
+        self,
+        wave_dim,
+        output_dim,
+        num_latent_tokens,
+        embed_dim,
+        is_decoder,
+        num_heads=4,
+        num_layers=1,
+    ):
+        super().__init__()
+        self.num_latent_tokens = num_latent_tokens
+        self.is_decoder = is_decoder
+        layer = nn.TransformerEncoderLayer(
+            d_model=wave_dim,
+            nhead=num_heads,
+            activation="gelu",
+            dropout=0,
+            norm_first=False,
+            batch_first=False,
+        )
+        self.encoder = nn.TransformerEncoder(layer, num_layers)
+
+        self.fc_weight = nn.Linear(wave_dim, output_dim)
+        self.fc_bias = None if self.is_decoder else nn.Linear(
+            wave_dim, embed_dim)
+
+        self.weight_tokens = nn.Parameter(
+            torch.randn(self.num_latent_tokens, wave_dim) * 0.02
+        )
+        self.bias_token = nn.Parameter(torch.randn(1, wave_dim) * 0.02)
+
+    def forward(self, x):
+        x = torch.cat([self.weight_tokens, x, self.bias_token], dim=0)
+        out = self.encoder(x)
+        weights = self.fc_weight(
+            out[self.num_latent_tokens: -1] + x[self.num_latent_tokens: -1]
+        )
+        bias = None if self.is_decoder else self.fc_bias(out[-1])
+        return weights, bias
+
+
+class DynamicEmbedding(nn.Module):
+    def __init__(
+        self,
+        wave_dim,
+        num_latent_tokens,
+        patch_size,
+        embed_dim,
+        is_decoder=False,
+    ):
+        super().__init__()
+        self.wave_dim = wave_dim
+        self.num_latent_tokens = num_latent_tokens
+        self.patch_size = patch_size
+        self.embed_dim = embed_dim
+        self.is_decoder = is_decoder
+        self.output_dim = (patch_size**2) * embed_dim
+
+        self.weight_generator = WavesTransformer(
+            wave_dim,
+            self.output_dim,
+            self.num_latent_tokens,
+            self.embed_dim,
+            is_decoder,
+        )
+        self.fclayer = FCBlock(self.wave_dim)
+
+        self.initialize_weights()
+
+    def forward(self, batch, waves):
+        waves = posemb_sincos_1d(waves, self.wave_dim)
+        waves = waves.to(batch.device)
+        waves = self.fclayer(waves)
+        weight, bias = self.weight_generator(waves)
+
+        if self.is_decoder:
+            dynamic_weight = rearrange(
+                weight,
+                "cin (k1 k2 cout) -> (cin k1 k2) cout",
+                k1=self.patch_size,
+                k2=self.patch_size,
+                cout=self.embed_dim,
+            )
+            if bias is not None:
+                bias = rearrange(bias, "b -> (b)")
+            dynamic_out = F.linear(batch, dynamic_weight * 0.02, bias=bias)
+            x = dynamic_out
+        else:
+            dynamic_weight = rearrange(
+                weight,
+                "cin (cout k1 k2) -> cout cin k1 k2",
+                k1=self.patch_size,
+                k2=self.patch_size,
+            )
+            if bias is not None:
+                bias = rearrange(bias, "b -> (b)")
+            dynamic_out = F.conv2d(
+                batch, dynamic_weight * 0.02, bias=bias, stride=self.patch_size
+            )
+            x = rearrange(dynamic_out, "b c h w -> b (h w) c")
+
+        return x, waves
+
+    def initialize_weights(self):
+        # Initialize weights using Xavier initialization
+        for m in self.modules():
+            if isinstance(m, (nn.Linear, nn.Conv2d)):
+                nn.init.xavier_uniform_(m.weight)
+                if m.bias is not None:
+                    nn.init.constant_(m.bias, 0)
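
`DynamicEmbedding` generates the patch-embedding kernel from the wavelength encodings, which is what lets one backbone embed inputs with arbitrary band counts. A shape-only sketch of the encoder branch (the weights stand in for the WavesTransformer output; the 0.02 scaling is omitted):

    import torch
    import torch.nn.functional as F

    cin, patch, dim = 3, 8, 768
    weight = torch.randn(cin, dim * patch * patch)           # per-band generated weights
    kernel = weight.reshape(cin, dim, patch, patch).permute(1, 0, 2, 3)  # [768, 3, 8, 8]
    x = torch.randn(2, cin, 256, 256)
    out = F.conv2d(x, kernel, stride=patch)                  # [2, 768, 32, 32]
    print(out.flatten(2).transpose(1, 2).shape)              # [2, 1024, 768] token sequence
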
+
+
+class Datacuber(nn.Module):
+    def __init__(self) -> None:
+        super().__init__()
+
+    def forward(self, x, **kwargs):
+        # TODO: dynamic dict handling/parsing
+        if isinstance(x, dict):
+            return x
+        datacube = {}
+        datacube["pixels"] = x
+        datacube["time"] = torch.zeros((x.shape[0], 4))
+        datacube["latlon"] = torch.zeros((x.shape[0], 4))
+        datacube["gsd"] = 1.0
+        datacube["waves"] = torch.zeros(x.shape[1])
+        return datacube
diff --git a/terratorch/models/backbones/clay_v1/utils.py b/terratorch/models/backbones/clay_v1/utils.py
new file mode 100644
index 00000000..b8796bfd
--- /dev/null
+++ b/terratorch/models/backbones/clay_v1/utils.py
@@ -0,0 +1,44 @@
+import torch
+
+
+def posemb_sincos_2d(h, w, dim, temperature: int = 10000, dtype=torch.float32):
+    y, x = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij")
+    assert (dim % 4) == 0, "feature dimension must be multiple of 4 for sincos emb"
+    omega = torch.arange(dim // 4) / (dim // 4 - 1)
+    omega = 1.0 / (temperature**omega)
+
+    y = y.flatten()[:, None] * omega[None, :]
+    x = x.flatten()[:, None] * omega[None, :]
+    pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim=1)
+    return pe.type(dtype)
+
+
+def posemb_sincos_2d_with_gsd(
+    h, w, dim, gsd=1.0, temperature: int = 10000, dtype=torch.float32
+):
+    y, x = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij")
+    assert (dim % 4) == 0, "feature dimension must be multiple of 4 for sincos emb"
+
+    omega = torch.arange(dim // 4) / (dim // 4 - 1)
+    omega = 1.0 / (temperature ** (2 * omega / dim)) * \
+        (gsd / 1.0)  # adjusted for gsd
+
+    y = y.flatten()[:, None] * omega[None, :]
+    x = x.flatten()[:, None] * omega[None, :]
+    pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim=1)
+    return pe.type(dtype)
+
+
+def posemb_sincos_1d(pos, dim, temperature: int = 10000, dtype=torch.float32):
+    assert (
+        dim % 2 == 0
+    ), "Feature dimension must be a multiple of 2 for sincos embedding"
+    pos = torch.arange(pos) if isinstance(pos, int) else pos
+
+    omega = torch.arange(dim // 2) / (dim // 2 - 1)
+    omega = 1.0 / (temperature**omega)
+
+    scaled_pos = pos[:, None] * omega[None, :]
+    pe = torch.cat((scaled_pos.sin(), scaled_pos.cos()), dim=1)
+
+    return pe.type(dtype)
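
`posemb_sincos_1d` accepts either an integer length or an explicit position tensor; the latter is how per-band wavelengths are encoded before being fed to the WavesTransformer. A quick check (the band centers are illustrative):

    import torch
    from terratorch.models.backbones.clay_v1.utils import posemb_sincos_1d

    waves = torch.tensor([0.48, 0.56, 0.66])   # e.g. blue/green/red band centers in micrometers
    print(posemb_sincos_1d(waves, 128).shape)  # torch.Size([3, 128])
    print(posemb_sincos_1d(16, 128).shape)     # torch.Size([16, 128])
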
diff --git a/terratorch/models/clay_model_factory.py b/terratorch/models/clay_model_factory.py
index b196cc64..9b0d2a53 100644
--- a/terratorch/models/clay_model_factory.py
+++ b/terratorch/models/clay_model_factory.py
@@ -1,13 +1,11 @@
-import importlib
 from collections.abc import Callable
-import sys
 
-import numpy as np
 import timm
 import torch
 from torch import nn
 
 import terratorch.models.decoders as decoder_registry
+from terratorch.models.backbones.clay_v1.embedder import Embedder
 from terratorch.datasets import HLSBands
 from terratorch.models.model import (
     AuxiliaryHead,
@@ -28,30 +26,6 @@ class DecoderNotFoundError(Exception):
     pass
 
 
-class ModelWrapper(nn.Module):
-
-    def __init__(self, model: nn.Module = None) -> None:
-
-        super(ModelWrapper, self).__init__()
-
-        self.model = model
-
-        self.embedding_shape = self.model.model.state_dict()['decoder.embed_to_pixels.dem.bias'].shape[0]
-
-    def channels(self):
-        return (1, self.embedding_shape)
-
-    @property
-    def parameters(self):
-        return self.model.parameters
-
-    def forward(self, args, **kwargs):
-        datacube = {}
-        datacube['pixels'] = args
-        datacube['timestep'] = None
-        datacube['latlon'] = None
-        return self.model.forward(datacube)
-
 @register_factory
 class ClayModelFactory(ModelFactory):
     def build_model(
@@ -116,7 +90,7 @@ def build_model(
         bands = [HLSBands.try_convert_to_hls_bands_enum(b) for b in bands]
         # TODO: support auxiliary heads
         if not isinstance(backbone, nn.Module):
-            if not "Clay" in backbone:
+            if "clay" not in backbone:
                 msg = "This class only handles models for `Clay` encoders"
                 raise NotImplementedError(msg)
 
@@ -138,51 +112,23 @@ def build_model(
                 features_only=True,
                 **backbone_kwargs,
             )
-        except Exception:
-
-            # When the model is not on HG, it needs be restored locally.
-            print("This model is not available on HuggingFace. Trying to instantiate locally ...")
+        except Exception as e:
+            print(e, "Error loading from HF. Trying to instantiate locally ...")
 
             assert checkpoint_path, "A checkpoint must be provided to restore the model."
 
-            # The CLAY source code must be installed or available via PYTHONPATH.
-            try:  # TODO Inlcude the Clay source code into the tolkit in order to
-                # avoid issues with the modules paths or made it
-                # seamlessly accesible via configuration.
-            if self.syspath_kwarg in kwargs:
-                syspath_value = kwargs.get(self.syspath_kwarg)
-
-            else:
-
-                Exception(f"It is necessary to define the variable {self.syspath_kwarg} on yaml"
-                          "config for restoring local model.")
-
-            sys.path.insert(0, syspath_value)
-
-            from src.model_clay import CLAYModule
-
-        except ModuleNotFoundError:
-
-            print(f"It is better to review the field {self.syspath_kwarg} in the yaml file.")
-
-        backbone: nn.Module = ModelWrapper(model=CLAYModule(**backbone_kwargs))
-
-        if self.CPU_ONLY:
-            model_dict = torch.load(checkpoint_path, map_location="cpu")
-        else:
-            model_dict = torch.load(checkpoint_path)
-
-        backbone.model.load_state_dict(model_dict['state_dict'])
-
+            device = "cpu" if self.CPU_ONLY else "cuda"
+            backbone: nn.Module = Embedder(
+                ckpt_path=checkpoint_path, device=device, **backbone_kwargs)
             print("Model Clay was successfully restored.")
 
         # allow decoder to be a module passed directly
         decoder_cls = _get_decoder(decoder)
-
         decoder_kwargs = _extract_prefix_keys(kwargs, "decoder_")
 
         # TODO: remove this
-        decoder: nn.Module = decoder_cls(backbone.channels(), **decoder_kwargs)
+        decoder: nn.Module = decoder_cls(
+            backbone.feature_info.channels(), **decoder_kwargs)
         # decoder: nn.Module = decoder_cls([128, 256, 512, 1024], **decoder_kwargs)
 
         head_kwargs = _extract_prefix_keys(kwargs, "head_")
@@ -193,12 +139,14 @@ def build_model(
             task, backbone, decoder, head_kwargs, prepare_features_for_image_model, rescale=rescale
         )
 
-        to_be_aux_decoders: list[AuxiliaryHeadWithDecoderWithoutInstantiatedHead] = []
+        to_be_aux_decoders: list[AuxiliaryHeadWithDecoderWithoutInstantiatedHead] = [
+        ]
 
         for aux_decoder in aux_decoders:
             args = aux_decoder.decoder_args if aux_decoder.decoder_args else {}
             aux_decoder_cls: nn.Module = _get_decoder(aux_decoder.decoder)
             aux_decoder_kwargs = _extract_prefix_keys(args, "decoder_")
-            aux_decoder_instance = aux_decoder_cls(backbone.feature_info.channels(), **aux_decoder_kwargs)
+            aux_decoder_instance = aux_decoder_cls(
+                backbone.feature_info.channels(), **aux_decoder_kwargs)
             # aux_decoder_instance = aux_decoder_cls([128, 256, 512, 1024], **decoder_kwargs)
 
             aux_head_kwargs = _extract_prefix_keys(args, "head_")
@@ -207,7 +155,8 @@ def build_model(
             # aux_head: nn.Module = _get_head(task, aux_decoder_instance, num_classes=num_classes, **head_kwargs)
             # aux_decoder.decoder = nn.Sequential(aux_decoder_instance, aux_head)
             to_be_aux_decoders.append(
-                AuxiliaryHeadWithDecoderWithoutInstantiatedHead(aux_decoder.name, aux_decoder_instance, aux_head_kwargs)
+                AuxiliaryHeadWithDecoderWithoutInstantiatedHead(
+                    aux_decoder.name, aux_decoder_instance, aux_head_kwargs)
             )
 
         return _build_appropriate_model(
@@ -220,6 +169,7 @@ def build_model(
             task,
             auxiliary_heads=to_be_aux_decoders,
         )
 
+
 def _build_appropriate_model(
     task: str,
     backbone: nn.Module,
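
For completeness, a hedged sketch of how the reworked factory path might be exercised; the decoder name and keyword arguments follow the conventions of the other terratorch factories and are assumptions here, not verified against this factory's full signature:

    from terratorch.models import ClayModelFactory

    factory = ClayModelFactory()
    model = factory.build_model(
        task="segmentation",
        backbone="clay_v1_base",      # falls back to Embedder(ckpt_path=...) if the HF load fails
        decoder="FCNDecoder",         # hypothetical decoder choice
        bands=["BLUE", "GREEN", "RED"],
        num_classes=9,
    )
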