diff --git a/pyproject.toml b/pyproject.toml
index f8d376cd..24fcad2d 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "zetascale"
-version = "2.5.5"
+version = "2.5.6"
 description = "Rapidly Build, Optimize, and Deploy SOTA AI Models"
 authors = ["Zeta Team "]
 license = "MIT"
diff --git a/zeta/nn/modules/__init__.py b/zeta/nn/modules/__init__.py
index 1b67c747..01d9a867 100644
--- a/zeta/nn/modules/__init__.py
+++ b/zeta/nn/modules/__init__.py
@@ -219,6 +219,7 @@
 )
 from zeta.nn.modules.simple_lstm import SimpleLSTM
 from zeta.nn.modules.simple_rnn import SimpleRNN
+from zeta.nn.modules.cope import CoPE

 # from zeta.nn.modules.img_reshape import image_reshape
 # from zeta.nn.modules.flatten_features import flatten_features
@@ -440,4 +441,5 @@
     "SparseChannelIntegration",
     "SimpleLSTM",
     "SimpleRNN",
+    "CoPE",
 ]
diff --git a/zeta/nn/modules/cope.py b/zeta/nn/modules/cope.py
new file mode 100644
index 00000000..e888c937
--- /dev/null
+++ b/zeta/nn/modules/cope.py
@@ -0,0 +1,31 @@
+import torch
+from torch import nn, Tensor
+
+
+class CoPE(nn.Module):
+    def __init__(self, npos_max: int, dim: int):
+        super().__init__()
+        self.npos_max = npos_max
+        self.pos_emb = nn.Parameter(torch.zeros(1, dim, npos_max))
+
+    def forward(self, query: Tensor, attn_logits: Tensor) -> Tensor:
+        # compute fractional positions from the gated attention logits
+        gates = torch.sigmoid(attn_logits)
+        pos = gates.flip(-1).cumsum(dim=-1).flip(-1)
+        pos = pos.clamp(max=self.npos_max - 1)
+        # interpolate between the two nearest integer position embeddings
+        pos_ceil = pos.ceil().long()
+        pos_floor = pos.floor().long()
+        logits_int = torch.matmul(query, self.pos_emb)
+        logits_ceil = logits_int.gather(-1, pos_ceil)
+        logits_floor = logits_int.gather(-1, pos_floor)
+        w = pos - pos_floor
+        return logits_ceil * w + logits_floor * (1 - w)
+
+
+# x = torch.randn(1, 5, 10)
+# attn_logits = torch.randn(1, 5, 10)
+
+# cope = CoPE(5, 10)
+# out = cope(x, attn_logits)
+# print(out)
diff --git a/zeta/nn/modules/sparc_alignment.py b/zeta/nn/modules/sparc_alignment.py
new file mode 100644
index 00000000..eb1bc28c
--- /dev/null
+++ b/zeta/nn/modules/sparc_alignment.py
@@ -0,0 +1,153 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from einops import rearrange
+
+
+class SparseFineGrainedContrastiveAlignment(nn.Module):
+    def __init__(
+        self,
+        vision_adapter: nn.Module,
+        text_adapter: nn.Module,
+        hidden_dim: int,
+        tau: float = 0.07,
+    ):
+        super(SparseFineGrainedContrastiveAlignment, self).__init__()
+        self.vision_adapter = vision_adapter
+        self.text_adapter = text_adapter
+        self.hidden_dim = hidden_dim
+        self.tau = tau
+
+    def forward(
+        self, image_patches: torch.Tensor, text_tokens: torch.Tensor
+    ) -> torch.Tensor:
+        # Assume image_patches: [b, c, h, w] and text_tokens: [b, s, d] are already encoded
+
+        # Flatten image patches for easier processing
+        b, c, h, w = image_patches.shape
+        image_patches = rearrange(
+            image_patches, "b c h w -> b (h w) c"
+        )  # shape: [b, hw, c]
+
+        # Apply adapters
+        image_patches = self.vision_adapter(image_patches)  # shape: [b, hw, d]
+        text_tokens = self.text_adapter(text_tokens)  # shape: [b, s, d]
+
+        # Compute global embeddings by pooling the already-adapted features
+        # (no second adapter pass is needed here)
+        global_image_embedding = F.adaptive_avg_pool1d(
+            rearrange(image_patches, "b p d -> b d p"), 1
+        ).squeeze(-1)  # shape: [b, d]
+        global_text_embedding = F.adaptive_avg_pool1d(
+            rearrange(text_tokens, "b s d -> b d s"), 1
+        ).squeeze(-1)  # shape: [b, d]
+
+        # Global contrastive loss
+        global_loss = self.global_contrastive_loss(
+            global_image_embedding, global_text_embedding
+        )
+
+        # Fine-grained alignment
+        fine_grained_loss = self.fine_grained_alignment(
+            image_patches, text_tokens
+        )
+
+        # Overall loss
+        overall_loss = global_loss + fine_grained_loss
+
+        return overall_loss
+
+    def global_contrastive_loss(
+        self,
+        global_image_embedding: torch.Tensor,
+        global_text_embedding: torch.Tensor,
+    ) -> torch.Tensor:
+        b, d = global_image_embedding.shape
+        sim_matrix = (
+            F.cosine_similarity(
+                global_image_embedding.unsqueeze(1),
+                global_text_embedding.unsqueeze(0),
+                dim=-1,
+            )
+            / self.tau
+        )
+        labels = torch.arange(b).long().to(global_image_embedding.device)
+        loss_i = F.cross_entropy(sim_matrix, labels)
+        loss_t = F.cross_entropy(sim_matrix.T, labels)
+        loss = (loss_i + loss_t) / 2
+        return loss
+
+    def fine_grained_alignment(
+        self, image_patches: torch.Tensor, text_tokens: torch.Tensor
+    ) -> torch.Tensor:
+        b, hw, d = image_patches.shape
+        _, s, _ = text_tokens.shape
+
+        # Compute similarity matrix
+        sim_matrix = torch.einsum(
+            "bpd,bsd->bps", image_patches, text_tokens
+        )  # shape: [b, hw, s]
+
+        # Min-max normalization
+        sim_matrix = (sim_matrix - sim_matrix.min(dim=1, keepdim=True)[0]) / (
+            sim_matrix.max(dim=1, keepdim=True)[0]
+            - sim_matrix.min(dim=1, keepdim=True)[0]
+            + 1e-8
+        )
+
+        # Sparsification
+        sigma = 1 / hw
+        sim_matrix[sim_matrix < sigma] = 0
+
+        # Compute alignment weights
+        alignment_weights = F.normalize(
+            sim_matrix, p=1, dim=1
+        )  # shape: [b, hw, s]
+
+        # Compute language-grouped vision embeddings
+        language_grouped_vision_embeddings = torch.einsum(
+            "bps,bpd->bsd", alignment_weights, image_patches
+        )  # shape: [b, s, d]
+
+        # Fine-grained contrastive loss
+        fine_grained_loss = self.fine_grained_contrastive_loss(
+            language_grouped_vision_embeddings, text_tokens
+        )
+
+        return fine_grained_loss
+
+    def fine_grained_contrastive_loss(
+        self,
+        language_grouped_vision_embeddings: torch.Tensor,
+        text_tokens: torch.Tensor,
+    ) -> torch.Tensor:
+        b, s, d = language_grouped_vision_embeddings.shape
+        sim_matrix = (
+            F.cosine_similarity(
+                language_grouped_vision_embeddings.unsqueeze(2),
+                text_tokens.unsqueeze(1),
+                dim=-1,
+            )
+            / self.tau
+        )
+        # one target position per token, replicated across the batch
+        labels = torch.arange(s, device=language_grouped_vision_embeddings.device)
+        labels = labels.unsqueeze(0).expand(b, -1)
+        loss_c = F.cross_entropy(sim_matrix.permute(0, 2, 1), labels)
+        loss_t = F.cross_entropy(sim_matrix, labels)
+        loss = (loss_c + loss_t) / 2
+        return loss
+
+
+# # Example usage:
+# # Assuming vision_adapter and text_adapter are defined elsewhere
+# model = SparseFineGrainedContrastiveAlignment(
+#     vision_adapter, text_adapter, hidden_dim=768
+# )
+# image_patches = torch.randn(32, 3, 224, 224)  # Example image batch
+# text_tokens = torch.randn(32, 128, 768)  # Example text batch
+# loss = model(image_patches, text_tokens)
+# print(loss)
diff --git a/zeta/nn/modules/tensor_shape.py b/zeta/nn/modules/tensor_shape.py
new file mode 100644
index 00000000..296a9d52
--- /dev/null
+++ b/zeta/nn/modules/tensor_shape.py
@@ -0,0 +1,121 @@
+import torch
+from torch import Tensor
+
+
+# Define the TensorShape class
+class TensorShape(Tensor):
+    """
+    Represents the shape of a tensor.
+
+    Args:
+        data (array-like): The data of the tensor.
+        shape_string (str): The string representation of the shape.
+
+    Attributes:
+        shape_string (str): The string representation of the shape.
+        shape_dict (dict): A dictionary mapping dimensions to sizes.
+
+    Raises:
+        ValueError: If the shape string does not match the actual shape.
+
+    Example:
+        >>> data = [[1, 2], [3, 4]]
+        >>> shape_string = "2 2"
+        >>> tensor_shape = TensorShape(data, shape_string)
+        >>> print(tensor_shape)
+        TensorShape(shape_string='2 2', actual_shape=torch.Size([2, 2]))
+    """
+
+    def __new__(cls, data, shape_string):
+        instance = torch.as_tensor(data).as_subclass(cls)
+        instance.shape_string = shape_string
+        instance.shape_dict = cls.parse_shape_string(
+            shape_string, instance.shape
+        )
+        return instance
+
+    @staticmethod
+    def parse_shape_string(shape_string, actual_shape):
+        """
+        Parses the shape string and returns a dictionary mapping dimensions to sizes.
+
+        Args:
+            shape_string (str): The string representation of the shape.
+            actual_shape (tuple): The actual shape of the tensor.
+
+        Returns:
+            dict: A dictionary mapping dimensions to sizes.
+
+        Raises:
+            ValueError: If the number of dimensions in the shape string does not match the actual shape.
+        """
+        dimensions = shape_string.split()
+        if len(dimensions) != len(actual_shape):
+            raise ValueError(
+                f"Shape string {shape_string} does not match actual shape {actual_shape}"
+            )
+        return {dim: size for dim, size in zip(dimensions, actual_shape)}
+
+    def __repr__(self):
+        return f"TensorShape(shape_string={self.shape_string!r}, actual_shape={super().shape})"
+
+    @staticmethod
+    def check_shape(tensor, shape_string):
+        """
+        Checks if the shape of the given tensor matches the specified shape string.
+
+        Args:
+            tensor (Tensor): The tensor to check the shape of.
+            shape_string (str): The string representation of the expected shape.
+
+        Raises:
+            ValueError: If the shape of the tensor does not match the expected shape.
+        """
+        shape_dict = TensorShape.parse_shape_string(shape_string, tensor.shape)
+        if tensor.shape != tuple(shape_dict.values()):
+            raise ValueError(
+                f"Expected shape {shape_dict}, but got {tensor.shape}"
+            )
+
+
+# Define a decorator for shape checking
+def check_tensor_shape(shape_string: str = None):
+    """
+    Decorator function that checks if the shape of a tensor matches the specified shape string.
+
+    Args:
+        shape_string (str): A string representing the desired shape of the tensor.
+
+    Returns:
+        function: A decorator function that wraps the original function and performs the shape check.
+
+    Example:
+        @check_tensor_shape("B S D")
+        def my_function(self, tensor):
+            # Function implementation
+            pass
+
+    The above example ensures that the tensor passed to `my_function` is 3-dimensional, matching the "B S D" pattern.
+    """
+
+    def decorator(func):
+        def wrapper(*args, **kwargs):
+            # Assuming the tensor is the first argument after self
+            tensor = args[1]
+            TensorShape.check_shape(tensor, shape_string)
+            return func(*args, **kwargs)
+
+        return wrapper
+
+    return decorator
+
+
+# Define a helper function to create TensorShape objects
+def create_tensor(
+    data: Tensor = None, shape_string: str = None, random_on: bool = False
+):
+    # When random_on is True, `data` is interpreted as a shape for torch.randn
+    if random_on:
+        data = torch.randn(data)
+        return TensorShape(data, shape_string)
+    else:
+        return TensorShape(data, shape_string)
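
Usage sketch (not part of the diff above): a minimal, hedged example of how the new CoPE module and the TensorShape helpers might be exercised. The shapes, the attention-logit layout, and the TinyBlock class below are illustrative assumptions, not guarantees about the library's intended API.

import torch
from torch import nn

from zeta.nn.modules.cope import CoPE
from zeta.nn.modules.tensor_shape import TensorShape, check_tensor_shape

# CoPE: produce contextual position logits to add to raw attention logits
# before the softmax. Assumed shapes: query [batch, seq, dim],
# attn_logits [batch, seq, seq].
batch, seq, dim, npos_max = 2, 8, 16, 8
query = torch.randn(batch, seq, dim)
attn_logits = torch.randn(batch, seq, seq)

cope = CoPE(npos_max=npos_max, dim=dim)
pos_logits = cope(query, attn_logits)  # [batch, seq, seq]
attn = torch.softmax(attn_logits + pos_logits, dim=-1)

# TensorShape: attach a dimension-name string to a tensor and validate
# that its rank matches the pattern.
x = TensorShape(torch.randn(2, 3), "B S")
print(x.shape_dict)  # {'B': 2, 'S': 3}


# check_tensor_shape inspects args[1], i.e. the first argument after self,
# so it is used here on a method rather than a free function.
class TinyBlock(nn.Module):
    @check_tensor_shape("B S D")
    def forward(self, tensor):
        return tensor * 2


out = TinyBlock()(torch.randn(2, 4, 8))  # passes: 3 dims match "B S D"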