Commit
[FEAT][Cope]
Kye Gomez authored and Kye Gomez committed Jun 14, 2024
1 parent a59b7f6 commit 41e1f0a
Showing 5 changed files with 308 additions and 1 deletion.
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "zetascale"
version = "2.5.5"
version = "2.5.6"
description = "Rapidly Build, Optimize, and Deploy SOTA AI Models"
authors = ["Zeta Team <[email protected]>"]
license = "MIT"
2 changes: 2 additions & 0 deletions zeta/nn/modules/__init__.py
@@ -219,6 +219,7 @@
)
from zeta.nn.modules.simple_lstm import SimpleLSTM
from zeta.nn.modules.simple_rnn import SimpleRNN
from zeta.nn.modules.cope import CoPE

# from zeta.nn.modules.img_reshape import image_reshape
# from zeta.nn.modules.flatten_features import flatten_features
@@ -440,4 +441,5 @@
"SparseChannelIntegration",
"SimpleLSTM",
"SimpleRNN",
"CoPE",
]
31 changes: 31 additions & 0 deletions zeta/nn/modules/cope.py
@@ -0,0 +1,31 @@
import torch
from torch import nn, Tensor


class CoPE(nn.Module):
    """Contextual Position Encoding (CoPE).

    Learns position embeddings from context-dependent gate values rather
    than fixed token indices, following "Contextual Position Encoding:
    Learning to Count What's Important" (Golovneva et al., 2024).

    Args:
        npos_max (int): Maximum number of distinct positions.
        dim (int): Dimension of the query vectors.
    """

    def __init__(self, npos_max: int, dim: int):
        super().__init__()
        self.npos_max = npos_max
        # One learnable embedding per integer position: [1, dim, npos_max]
        self.pos_emb = nn.Parameter(torch.zeros(1, dim, npos_max))

    def forward(self, query: Tensor, attn_logits: Tensor) -> Tensor:
        # Gate each key with the attention logits, then accumulate gates
        # from the current token backwards to get fractional positions
        gates = torch.sigmoid(attn_logits)
        pos = gates.flip(-1).cumsum(dim=-1).flip(-1)
        pos = pos.clamp(max=self.npos_max - 1)
        # Interpolate between the two nearest integer position embeddings
        pos_ceil = pos.ceil().long()
        pos_floor = pos.floor().long()
        logits_int = torch.matmul(query, self.pos_emb)
        logits_ceil = logits_int.gather(-1, pos_ceil)
        logits_floor = logits_int.gather(-1, pos_floor)
        w = pos - pos_floor
        return logits_ceil * w + logits_floor * (1 - w)


# Smoke test: query is [batch, seq, dim] and the attention logits are
# [batch, seq, seq], so the returned position logits share their shape.
# x = torch.randn(1, 5, 10)
# attn_logits = torch.randn(1, 5, 5)

# cope = CoPE(npos_max=5, dim=10)
# out = cope(x, attn_logits)
# print(out.shape)  # torch.Size([1, 5, 5])
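
For context, a minimal sketch of how CoPE-style position logits can be folded into a single-head attention block; the wrapper class below and its dimensions are illustrative assumptions, not part of this commit:

import torch
from torch import nn
from zeta.nn.modules.cope import CoPE


class CoPEAttention(nn.Module):
    # Hypothetical single-head attention wrapper; illustrative only
    def __init__(self, dim: int, npos_max: int):
        super().__init__()
        self.q = nn.Linear(dim, dim)
        self.k = nn.Linear(dim, dim)
        self.v = nn.Linear(dim, dim)
        self.cope = CoPE(npos_max, dim)
        self.scale = dim**-0.5

    def forward(self, x):
        q, k, v = self.q(x), self.k(x), self.v(x)
        attn_logits = torch.matmul(q, k.transpose(-1, -2)) * self.scale
        # Position logits are added to the content logits before softmax
        attn = (attn_logits + self.cope(q, attn_logits)).softmax(dim=-1)
        return torch.matmul(attn, v)


x = torch.randn(2, 16, 64)
print(CoPEAttention(dim=64, npos_max=16)(x).shape)  # torch.Size([2, 16, 64])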
153 changes: 153 additions & 0 deletions zeta/nn/modules/sparc_alignment.py
@@ -0,0 +1,153 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange


class SparseFineGrainedContrastiveAlignment(nn.Module):
    """Sparse fine-grained contrastive alignment (SPARC-style) between
    image patches and text tokens, combining a global image-text
    contrastive loss with a fine-grained patch-token alignment loss.

    Args:
        vision_adapter (nn.Module): Projects patch features into the shared space.
        text_adapter (nn.Module): Projects token features into the shared space.
        hidden_dim (int): Dimension of the shared embedding space.
        tau (float): Temperature for the contrastive losses.
    """

    def __init__(
        self,
        vision_adapter: nn.Module,
        text_adapter: nn.Module,
        hidden_dim: int,
        tau: float = 0.07,
    ):
        super().__init__()
        self.vision_adapter = vision_adapter
        self.text_adapter = text_adapter
        self.hidden_dim = hidden_dim
        self.tau = tau

def forward(
self, image_patches: torch.Tensor, text_tokens: torch.Tensor
) -> torch.Tensor:
# Assume image_patches: [b, c, h, w] and text_tokens: [b, s, d] are already encoded

# Flatten image patches for easier processing
b, c, h, w = image_patches.shape
image_patches = rearrange(
image_patches, "b c h w -> b (h w) c"
) # shape: [b, hw, c]

# Apply adapters
image_patches = self.vision_adapter(image_patches) # shape: [b, hw, d]
text_tokens = self.text_adapter(text_tokens) # shape: [b, s, d]

        # Compute global embeddings by mean-pooling the adapted features
        global_image_embedding = image_patches.mean(dim=1)  # shape: [b, d]
        global_text_embedding = text_tokens.mean(dim=1)  # shape: [b, d]

# Global contrastive loss
global_loss = self.global_contrastive_loss(
global_image_embedding, global_text_embedding
)

# Fine-grained alignment
fine_grained_loss = self.fine_grained_alignment(
image_patches, text_tokens
)

# Overall loss
overall_loss = global_loss + fine_grained_loss

return overall_loss

def global_contrastive_loss(
self,
global_image_embedding: torch.Tensor,
global_text_embedding: torch.Tensor,
) -> torch.Tensor:
b, d = global_image_embedding.shape
sim_matrix = (
F.cosine_similarity(
global_image_embedding.unsqueeze(1),
global_text_embedding.unsqueeze(0),
dim=-1,
)
/ self.tau
)
labels = torch.arange(b).long().to(global_image_embedding.device)
loss_i = F.cross_entropy(sim_matrix, labels)
loss_t = F.cross_entropy(sim_matrix.T, labels)
loss = (loss_i + loss_t) / 2
return loss

def fine_grained_alignment(
self, image_patches: torch.Tensor, text_tokens: torch.Tensor
) -> torch.Tensor:
b, hw, d = image_patches.shape
_, s, _ = text_tokens.shape

# Compute similarity matrix
sim_matrix = torch.einsum(
"bpd,bsd->bps", image_patches, text_tokens
) # shape: [b, hw, s]

# Min-max normalization
sim_matrix = (sim_matrix - sim_matrix.min(dim=1, keepdim=True)[0]) / (
sim_matrix.max(dim=1, keepdim=True)[0]
- sim_matrix.min(dim=1, keepdim=True)[0]
+ 1e-8
)

        # Sparsification: drop patch-token similarities below 1 / (num patches)
        sigma = 1 / hw
        sim_matrix = sim_matrix.masked_fill(sim_matrix < sigma, 0.0)

# Compute alignment weights
alignment_weights = F.normalize(
sim_matrix, p=1, dim=1
) # shape: [b, hw, s]

# Compute language-grouped vision embeddings
language_grouped_vision_embeddings = torch.einsum(
"bps,bpd->bsd", alignment_weights, image_patches
) # shape: [b, s, d]

# Fine-grained contrastive loss
fine_grained_loss = self.fine_grained_contrastive_loss(
language_grouped_vision_embeddings, text_tokens
)

return fine_grained_loss

def fine_grained_contrastive_loss(
self,
language_grouped_vision_embeddings: torch.Tensor,
text_tokens: torch.Tensor,
) -> torch.Tensor:
b, s, d = language_grouped_vision_embeddings.shape
sim_matrix = (
F.cosine_similarity(
language_grouped_vision_embeddings.unsqueeze(2),
text_tokens.unsqueeze(1),
dim=-1,
)
/ self.tau
)
        # One label row per batch element: token i aligns with grouped embedding i
        labels = (
            torch.arange(s, device=language_grouped_vision_embeddings.device)
            .unsqueeze(0)
            .expand(b, s)
        )
loss_c = F.cross_entropy(sim_matrix.permute(0, 2, 1), labels)
loss_t = F.cross_entropy(sim_matrix, labels)
loss = (loss_c + loss_t) / 2
return loss


# # Example usage:
# # Assuming vision_adapter and text_adapter are defined elsewhere
# model = SparseFineGrainedContrastiveAlignment(
# vision_adapter, text_adapter, hidden_dim=768
# )
# image_patches = torch.randn(32, 3, 224, 224) # Example image batch
# text_tokens = torch.randn(32, 128, 768) # Example text batch
# loss = model(image_patches, text_tokens)
# print(loss)
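
The adapters are not defined in this commit. A minimal runnable sketch, assuming pre-extracted 512-dim encoder features and plain linear projections (both are illustrative assumptions):

import torch
import torch.nn as nn
from zeta.nn.modules.sparc_alignment import (
    SparseFineGrainedContrastiveAlignment,
)

# Assumed: 512-dim encoder features projected into a 768-dim shared space
vision_adapter = nn.Linear(512, 768)
text_adapter = nn.Linear(512, 768)
model = SparseFineGrainedContrastiveAlignment(
    vision_adapter, text_adapter, hidden_dim=768
)

image_patches = torch.randn(8, 512, 16, 16)  # [b, c, h, w] feature map
text_tokens = torch.randn(8, 32, 512)  # [b, s, d] token features
loss = model(image_patches, text_tokens)  # scalar: global + fine-grained
print(loss)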
121 changes: 121 additions & 0 deletions zeta/nn/modules/tensor_shape.py
@@ -0,0 +1,121 @@
import torch
from torch import Tensor


# Define the TensorShape class
class TensorShape(Tensor):
"""
Represents the shape of a tensor.
Args:
data (array-like): The data of the tensor.
shape_string (str): The string representation of the shape.
Attributes:
shape_string (str): The string representation of the shape.
shape_dict (dict): A dictionary mapping dimensions to sizes.
Raises:
ValueError: If the shape string does not match the actual shape.
Example:
    >>> data = [[1, 2], [3, 4]]
>>> shape_string = "2 2"
>>> tensor_shape = TensorShape(data, shape_string)
>>> print(tensor_shape)
TensorShape(shape_string='2 2', actual_shape=(2, 2))
"""

def __new__(cls, data, shape_string):
instance = torch.as_tensor(data).as_subclass(cls)
instance.shape_string = shape_string
instance.shape_dict = cls.parse_shape_string(
shape_string, instance.shape
)
return instance

@staticmethod
def parse_shape_string(shape_string, actual_shape):
"""
Parses the shape string and returns a dictionary mapping dimensions to sizes.
Args:
shape_string (str): The string representation of the shape.
actual_shape (tuple): The actual shape of the tensor.
Returns:
dict: A dictionary mapping dimensions to sizes.
Raises:
ValueError: If the number of dimensions in the shape string does not match the actual shape.
"""
dimensions = shape_string.split()
if len(dimensions) != len(actual_shape):
raise ValueError(
f"Shape string {shape_string} does not match actual shape {actual_shape}"
)
return {dim: size for dim, size in zip(dimensions, actual_shape)}

    def __repr__(self):
        return f"TensorShape(shape_string={self.shape_string!r}, actual_shape={tuple(self.shape)})"

@staticmethod
def check_shape(tensor, shape_string):
"""
Checks if the shape of the given tensor matches the specified shape string.
Args:
tensor (Tensor): The tensor to check the shape of.
shape_string (str): The string representation of the expected shape.
Raises:
ValueError: If the shape of the tensor does not match the expected shape.
"""
shape_dict = TensorShape.parse_shape_string(shape_string, tensor.shape)
if tensor.shape != tuple(shape_dict.values()):
raise ValueError(
f"Expected shape {shape_dict}, but got {tensor.shape}"
)


# Define a decorator for shape checking
def check_tensor_shape(shape_string: str = None):
"""
Decorator function that checks if the shape of a tensor matches the specified shape string.
Args:
shape_string (str): A string representing the desired shape of the tensor.
Returns:
function: A decorator function that wraps the original function and performs the shape check.
Example:
@check_tensor_shape("B S D")
def my_function(tensor):
# Function implementation
pass
    The above example ensures that the tensor passed to `my_function`
    matches the "B S D" pattern, i.e. is 3-dimensional.
"""

def decorator(func):
        def wrapper(*args, **kwargs):
            # Use the first Tensor positional argument (works for plain
            # functions and for methods, where args[0] is `self`)
            tensor = next(a for a in args if isinstance(a, Tensor))
TensorShape.check_shape(tensor, shape_string)
return func(*args, **kwargs)

return wrapper

return decorator


# Define a helper function to create TensorShape objects
def create_tensor(
    data=None, shape_string: str = None, random_on: bool = False
):
    """Creates a TensorShape object. If `random_on` is True, `data` is
    interpreted as a size (a sequence of ints) and filled via `torch.randn`."""
    if random_on:
        data = torch.randn(data)
    return TensorShape(data, shape_string)
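
A brief usage sketch of the decorator and helper; `scale_tokens` is a hypothetical function used only for illustration:

import torch
from zeta.nn.modules.tensor_shape import check_tensor_shape, create_tensor


# Hypothetical function guarded by the "B S D" (3-dimensional) pattern
@check_tensor_shape("B S D")
def scale_tokens(tensor):
    return tensor * 2


x = torch.randn(2, 16, 64)
print(scale_tokens(x).shape)  # torch.Size([2, 16, 64])

# Build a random TensorShape from a size and a shape string
t = create_tensor(data=(4, 8), shape_string="H W", random_on=True)
print(t)  # TensorShape(shape_string='H W', actual_shape=(4, 8))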
