diff --git a/src/benchmark_encoder.py b/src/benchmark_encoder.py deleted file mode 100644 index 08c2b0ac..00000000 --- a/src/benchmark_encoder.py +++ /dev/null @@ -1,80 +0,0 @@ -import argparse -import time -import warnings - -import torch - -warnings.filterwarnings("ignore") - -DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") - - -def get_data(): - """ - Generate random data tensors for model input. - """ - cube = torch.randn(128, 3, 256, 256).to(DEVICE) - timestep = torch.randn(128, 4).to(DEVICE) - latlon = torch.randn(128, 4).to(DEVICE) - waves = torch.randn(3).to(DEVICE) - gsd = torch.randn(1).to(DEVICE) - return cube, timestep, latlon, waves, gsd - - -def load_exported_model(eager=True): - """ - Load the exported model from a file. - - Args: - eager (bool): Flag to decide whether to use eager mode or compiled mode. - """ - print("Loading exported model") - ep = torch.export.load("checkpoints/compiled/encoder.pt") - if eager: - model = ep.module() - else: - model = torch.compile(ep.module(), backend="inductor") - return model - - -def benchmark_model(model): - """ - Benchmark the model by running inference on randomly generated data. - - Args: - model: The model to benchmark. - """ - print("Benchmarking model") - start = time.time() - for i in range(20): - cube, timestep, latlon, waves, gsd = get_data() - with torch.inference_mode(): - out = model(cube, timestep, latlon, waves, gsd) - print( - f"Iteration {i}: Output shapes - {out[0].shape}, {out[1].shape}, {out[2].shape}, {out[3].shape}" # noqa E501 - ) - print("Time taken for inference: ", time.time() - start) - - -def run(eager=True): - """ - Run the exported model and benchmark it. - - Args: - eager (bool): Flag to decide whether to use eager mode or compiled mode. - """ - print("Running model") - model = load_exported_model(eager=eager) - benchmark_model(model) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser( - description="Run benchmark for the exported model." - ) - parser.add_argument( - "--eager", action="store_true", help="Use eager mode for running the model." - ) - args = parser.parse_args() - - run(args.eager) diff --git a/src/export.py b/src/export.py deleted file mode 100644 index 70a65f1e..00000000 --- a/src/export.py +++ /dev/null @@ -1,73 +0,0 @@ -import warnings -from pathlib import Path - -import torch -from torch.export import Dim - -from src.model import ClayMAEModule - -warnings.filterwarnings("ignore") - -CHECKPOINT_PATH = "checkpoints/clay-v1-base.ckpt" -DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") -CHIP_SIZE = 256 - - -def get_data(): - """ - Generate random data tensors for model input. - """ - cube = torch.randn(128, 3, CHIP_SIZE, CHIP_SIZE).to(DEVICE) - timestep = torch.randn(128, 4).to(DEVICE) - latlon = torch.randn(128, 4).to(DEVICE) - waves = torch.randn(3).to(DEVICE) - gsd = torch.randn(1).to(DEVICE) - return cube, timestep, latlon, waves, gsd - - -def load_model(): - """ - Load the model from a checkpoint and prepare it for evaluation. - """ - module = ClayMAEModule.load_from_checkpoint( - CHECKPOINT_PATH, shuffle=False, mask_ratio=0.0 - ) - encoder = module.model.encoder.eval() # Get the encoder in eval mode - encoder = encoder.to(DEVICE) # Move to the appropriate device - return encoder - - -def export_model(): - """ - Export the model with dynamic shapes for deployment. - """ - cube, timestep, latlon, waves, gsd = get_data() - encoder = load_model() - - # Define dynamic shapes for model export - batch_size = Dim("batch_size", min=32, max=1200) - channel_bands = Dim("channel_bands", min=1, max=10) - - dynamic_shapes = { - "cube": {0: batch_size, 1: channel_bands}, - "time": {0: batch_size}, - "latlon": {0: batch_size}, - "waves": {0: channel_bands}, - "gsd": {0: None}, - } - - # Export model - ep = torch.export.export( - mod=encoder, - args=(cube, timestep, latlon, waves, gsd), - dynamic_shapes=dynamic_shapes, - strict=True, - ) - - # Save the exported model - Path("checkpoints/compiled").mkdir(parents=True, exist_ok=True) - torch.export.save(ep, "checkpoints/compiled/encoder.pt") - - -if __name__ == "__main__": - export_model() diff --git a/src/test_encoder.py b/src/test_encoder.py deleted file mode 100644 index 14839197..00000000 --- a/src/test_encoder.py +++ /dev/null @@ -1,86 +0,0 @@ -import torch - -from src.datamodule import ClayDataModule - -# Load the pre-trained Clay encoder model -clay_encoder = torch.export.load("checkpoints/compiled/encoder.pt").module() - - -def load_batch(): - # Initialize the data module with appropriate parameters - dm = ClayDataModule( - data_dir="/home/ubuntu/data", - size=256, - metadata_path="configs/metadata.yaml", - batch_size=1, - num_workers=1, - ) - - # Setup the data module for the 'fit' stage - dm.setup(stage="fit") - metadata = dm.metadata - - # Get the training data loader and create an iterator - trn_dl = dm.train_dataloader() - iter_dl = iter(trn_dl) - - return iter_dl, metadata - - -def prepare_data(sensor, metadata, device): - """ - Load data from the sensor and transfer it to the specified device. - - Args: - - sensor (dict): Sensor data containing 'pixels', 'time', 'latlon', and 'platform'. - - metadata (dict): Metadata information for different platforms. - - device (torch.device): The device to which the data should be transferred. - - Returns: - - tuple: Transferred cube, timestep, latlon, waves, and gsd tensors. - """ - cube = sensor["pixels"] - timestep = sensor["time"] - latlon = sensor["latlon"] - platform = sensor["platform"][0] - - # Get wavelengths and ground sampling distance (gsd) from metadata - waves = torch.tensor(list(metadata[platform].bands.wavelength.values())) - gsd = torch.tensor([metadata[platform].gsd]) - - # Transfer data to the specified device - cube, timestep, latlon, waves, gsd = map( - lambda x: x.to(device), (cube, timestep, latlon, waves, gsd) - ) - return cube, timestep, latlon, waves, gsd - - -def main(): - dl, metadata = load_batch() - - # Fetch samples from the data loader - l8_c2l1 = next(dl) - l8_c2l2 = next(dl) - linz = next(dl) - naip = next(dl) - s1 = next(dl) - s2 = next(dl) - - # Perform inference with the Clay encoder model - with torch.no_grad(): - for sensor in (l8_c2l1, l8_c2l2, linz, naip, s1, s2): - # Load data and transfer to GPU - batch = prepare_data(sensor, metadata, torch.device("cuda")) - - # Get patch embeddings from the encoder model - patch_embeddings, *_ = clay_encoder(*batch) - - # Extract the class (CLS) embedding - cls_embedding = patch_embeddings[:, 0, :] - - # Print the platform and the shape of the CLS embedding - print(sensor["platform"][0], cls_embedding.shape) - - -if __name__ == "__main__": - main() diff --git a/src/utils.py b/src/utils.py index 1e35731f..b0f2bcce 100644 --- a/src/utils.py +++ b/src/utils.py @@ -11,7 +11,6 @@ def posemb_sincos_2d(h, w, dim, temperature: int = 10000, dtype=torch.float32): assert (dim % 4) == 0, "feature dimension must be multiple of 4 for sincos emb" omega = torch.arange(dim // 4) / (dim // 4 - 1) omega = 1.0 / (temperature**omega) - omega = omega.to(y.device) y = y.flatten()[:, None] * omega[None, :] x = x.flatten()[:, None] * omega[None, :] @@ -25,9 +24,9 @@ def posemb_sincos_2d_with_gsd( y, x = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij") assert (dim % 4) == 0, "feature dimension must be multiple of 4 for sincos emb" - omega = torch.arange(dim // 4, device=gsd.device) / (dim // 4 - 1) + gsd = gsd.to(x.device) + omega = torch.arange(dim // 4) / (dim // 4 - 1) omega = 1.0 / (temperature ** (2 * omega / dim)) * (gsd / 1.0) # Adjusted for g - omega = omega.to(y.device) y = y.flatten()[:, None] * omega[None, :] x = x.flatten()[:, None] * omega[None, :] @@ -35,17 +34,16 @@ def posemb_sincos_2d_with_gsd( return pe.type(dtype) -def posemb_sincos_1d(pos, dim, temperature: int = 10000, dtype=torch.float32): +def posemb_sincos_1d(waves, dim, temperature: int = 10000, dtype=torch.float32): assert ( dim % 2 == 0 ), "Feature dimension must be a multiple of 2 for sincos embedding" - pos = torch.arange(pos) if isinstance(pos, int) else pos + waves = torch.arange(waves) if isinstance(waves, int) else waves - omega = torch.arange(dim // 2) / (dim // 2 - 1) + omega = torch.arange(dim // 2, device=waves.device) / (dim // 2 - 1) omega = 1.0 / (temperature**omega) - omega = omega.to(pos.device) - scaled_pos = pos[:, None] * omega[None, :] - pe = torch.cat((scaled_pos.sin(), scaled_pos.cos()), dim=1) + scaled_waves = waves[:, None] * omega[None, :] + pe = torch.cat((scaled_waves.sin(), scaled_waves.cos()), dim=1) return pe.type(dtype)