-
-
Notifications
You must be signed in to change notification settings - Fork 48
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Kye Gomez
authored and
Kye Gomez
committed
Jun 14, 2024
1 parent
a59b7f6
commit 41e1f0a
Showing
5 changed files
with
308 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,6 @@ | ||
[tool.poetry] | ||
name = "zetascale" | ||
version = "2.5.5" | ||
version = "2.5.6" | ||
description = "Rapidly Build, Optimize, and Deploy SOTA AI Models" | ||
authors = ["Zeta Team <[email protected]>"] | ||
license = "MIT" | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
import torch | ||
from torch import nn, Tensor | ||
|
||
|
||
class CoPE(nn.Module):
    """Contextual position encoding (CoPE) term for attention logits.

    Instead of fixed token offsets, every query/key pair receives a
    fractional position: the sum of sigmoid gates of the attention logits
    from that key to the end of the last axis. The returned logit is a linear
    interpolation between the learned embeddings of the two nearest integer
    positions.

    Args:
        npos_max: Number of learnable integer positions. Computed positions
            are clamped to ``npos_max - 1``.
        dim: Size of the last axis of ``query``. Required; the ``None``
            default is kept only so the signature stays unchanged.

    Raises:
        TypeError: If ``dim`` is not provided.
    """

    def __init__(self, npos_max: int, dim: int = None):
        super().__init__()
        if dim is None:
            # Fail here with a readable message rather than inside
            # torch.zeros, which rejects ``None`` with an opaque error.
            raise TypeError(
                "CoPE requires `dim` (the query feature size); got None"
            )
        self.npos_max = npos_max
        # One embedding column per integer position: [1, dim, npos_max].
        self.pos_emb = nn.parameter.Parameter(torch.zeros(1, dim, npos_max))

    def forward(self, query: Tensor, attn_logits: Tensor) -> Tensor:
        """Return the position logits for ``query`` given ``attn_logits``.

        Args:
            query: Tensor whose last axis has size ``dim``.
            attn_logits: Raw attention logits; the result has this shape.

        Returns:
            Interpolated position logits, shaped like ``attn_logits``.
        """
        # Gate every key, then accumulate from the end of the last axis so
        # each entry holds the gated count of keys from itself onward.
        gates = torch.sigmoid(attn_logits)
        pos = gates.flip(-1).cumsum(dim=-1).flip(-1)
        pos = pos.clamp(max=self.npos_max - 1)

        # Positions are fractional: look up both integer neighbours.
        pos_ceil = pos.ceil().long()
        pos_floor = pos.floor().long()
        logits_int = torch.matmul(query, self.pos_emb)
        logits_ceil = logits_int.gather(-1, pos_ceil)
        logits_floor = logits_int.gather(-1, pos_floor)

        # Linear interpolation weight towards the upper neighbour.
        w = pos - pos_floor
        return logits_ceil * w + logits_floor * (1 - w)
|
||
|
||
# x = torch.randn(1, 5, 10) | ||
# attn_logits = torch.randn(1, 5, 10) | ||
|
||
# cope = CoPE(5, 10) | ||
# out = cope(x, attn_logits) | ||
# print(out) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,153 @@ | ||
import torch | ||
import torch.nn as nn | ||
import torch.nn.functional as F | ||
from einops import rearrange | ||
|
||
|
||
class SparseFineGrainedContrastiveAlignment(nn.Module):
    """Global plus sparse fine-grained image/text contrastive loss.

    The forward pass returns the sum of two terms:

    * a global contrastive loss between pooled image and pooled text
      embeddings across the batch, and
    * a fine-grained loss that, per sample, groups image patches by their
      sparsified similarity to each text token and contrasts the grouped
      patch embeddings against the tokens.

    Args:
        vision_adapter: Module applied to image features; it receives both
            ``[b, hw, c]`` patch features and ``[b, c]`` pooled features.
        text_adapter: Module applied to text features; it receives both
            ``[b, s, d]`` token features and ``[b, d]`` pooled features.
        hidden_dim: Stored on the module; not used by the computations here.
        tau: Temperature dividing the cosine similarities.
    """

    def __init__(
        self,
        vision_adapter: nn.Module,
        text_adapter: nn.Module,
        hidden_dim: int,
        tau: float = 0.07,
    ):
        super().__init__()
        self.vision_adapter = vision_adapter
        self.text_adapter = text_adapter
        self.hidden_dim = hidden_dim
        self.tau = tau

    def forward(
        self, image_patches: torch.Tensor, text_tokens: torch.Tensor
    ) -> torch.Tensor:
        """Compute the combined loss.

        Args:
            image_patches: Encoded image features shaped ``[b, c, h, w]``.
            text_tokens: Encoded text features shaped ``[b, s, d]``.

        Returns:
            Scalar tensor: global loss plus fine-grained loss.

        Raises:
            ValueError: If ``image_patches`` is not 4-dimensional.
        """
        if image_patches.dim() != 4:
            raise ValueError(
                "image_patches must have shape [b, c, h, w], got "
                f"{tuple(image_patches.shape)}"
            )

        # [b, c, h, w] -> [b, hw, c]
        patch_features = image_patches.flatten(2).transpose(1, 2)

        # Local (per-patch / per-token) embeddings.
        local_image = self.vision_adapter(patch_features)  # [b, hw, d]
        local_text = self.text_adapter(text_tokens)  # [b, s, d]

        # Global embeddings: pool the raw features over patches/tokens and
        # run each adapter once. Pooling the already-adapted features and
        # adapting again breaks whenever an adapter's input and output
        # widths differ, and 2-D pooling of a 3-D tensor yields [b, 1].
        global_image = self.vision_adapter(patch_features.mean(dim=1))  # [b, d]
        global_text = self.text_adapter(text_tokens.mean(dim=1))  # [b, d]

        global_loss = self.global_contrastive_loss(global_image, global_text)
        fine_grained_loss = self.fine_grained_alignment(local_image, local_text)
        return global_loss + fine_grained_loss

    def global_contrastive_loss(
        self,
        global_image_embedding: torch.Tensor,
        global_text_embedding: torch.Tensor,
    ) -> torch.Tensor:
        """Symmetric cross-entropy over the batch similarity matrix.

        Args:
            global_image_embedding: ``[b, d]`` pooled image embeddings.
            global_text_embedding: ``[b, d]`` pooled text embeddings.

        Returns:
            Mean of the image-to-text and text-to-image cross-entropies,
            where sample ``i`` of one modality should match sample ``i`` of
            the other.
        """
        b, _ = global_image_embedding.shape
        sim_matrix = (
            F.cosine_similarity(
                global_image_embedding.unsqueeze(1),
                global_text_embedding.unsqueeze(0),
                dim=-1,
            )
            / self.tau
        )  # [b, b]
        labels = torch.arange(b, device=sim_matrix.device)
        loss_i = F.cross_entropy(sim_matrix, labels)
        loss_t = F.cross_entropy(sim_matrix.T, labels)
        return (loss_i + loss_t) / 2

    def fine_grained_alignment(
        self, image_patches: torch.Tensor, text_tokens: torch.Tensor
    ) -> torch.Tensor:
        """Group patches per token and return the fine-grained loss.

        Args:
            image_patches: ``[b, hw, d]`` patch embeddings.
            text_tokens: ``[b, s, d]`` token embeddings.

        Returns:
            Scalar fine-grained contrastive loss.
        """
        hw = image_patches.shape[1]

        # Patch/token dot products: [b, hw, s].
        sim_matrix = torch.einsum("bpd,bsd->bps", image_patches, text_tokens)

        # Min-max normalise over the patch axis, separately for each token.
        sim_min = sim_matrix.min(dim=1, keepdim=True).values
        sim_max = sim_matrix.max(dim=1, keepdim=True).values
        sim_matrix = (sim_matrix - sim_min) / (sim_max - sim_min + 1e-8)

        # Sparsify: drop similarities below 1 / (number of patches). Done
        # out of place so no autograd intermediate is mutated.
        sigma = 1 / hw
        sim_matrix = sim_matrix.masked_fill(sim_matrix < sigma, 0.0)

        # L1-normalise over patches so each token's weights sum to one
        # (or stay all-zero when nothing survived the threshold).
        alignment_weights = F.normalize(sim_matrix, p=1, dim=1)  # [b, hw, s]

        # Weighted sum of patch embeddings per token: [b, s, d].
        language_grouped_vision_embeddings = torch.einsum(
            "bps,bpd->bsd", alignment_weights, image_patches
        )

        return self.fine_grained_contrastive_loss(
            language_grouped_vision_embeddings, text_tokens
        )

    def fine_grained_contrastive_loss(
        self,
        language_grouped_vision_embeddings: torch.Tensor,
        text_tokens: torch.Tensor,
    ) -> torch.Tensor:
        """Symmetric cross-entropy between grouped patches and tokens.

        Within each sample, grouped embedding ``i`` should match token ``i``.

        Args:
            language_grouped_vision_embeddings: ``[b, s, d]`` tensor.
            text_tokens: ``[b, s, d]`` tensor.

        Returns:
            Mean of both matching directions, averaged over all ``b * s``
            rows.
        """
        b, s, _ = language_grouped_vision_embeddings.shape
        sim_matrix = (
            F.cosine_similarity(
                language_grouped_vision_embeddings.unsqueeze(2),
                text_tokens.unsqueeze(1),
                dim=-1,
            )
            / self.tau
        )  # [b, s, s]

        # cross_entropy expects one target per row, so flatten the batch and
        # sequence axes into rows and repeat the 0..s-1 targets per sample.
        labels = torch.arange(s, device=sim_matrix.device).repeat(b)
        loss_c = F.cross_entropy(sim_matrix.reshape(b * s, s), labels)
        loss_t = F.cross_entropy(
            sim_matrix.transpose(1, 2).reshape(b * s, s), labels
        )
        return (loss_c + loss_t) / 2
|
||
|
||
# # Example usage: | ||
# # Assuming vision_adapter and text_adapter are defined elsewhere | ||
# model = SparseFineGrainedContrastiveAlignment( | ||
# vision_adapter, text_adapter, hidden_dim=768 | ||
# ) | ||
# image_patches = torch.randn(32, 3, 224, 224) # Example image batch | ||
# text_tokens = torch.randn(32, 128, 768) # Example text batch | ||
# loss = model(image_patches, text_tokens) | ||
# print(loss) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,121 @@ | ||
import torch | ||
from torch import Tensor | ||
|
||
|
||
# Define the TensorShape class | ||
class TensorShape(Tensor):
    """A tensor that carries named dimensions.

    Args:
        data (array-like): Anything accepted by ``torch.as_tensor``.
        shape_string (str): Whitespace-separated dimension names, one per
            tensor dimension (for example ``"B S D"``).

    Attributes:
        shape_string (str): The shape string given at construction.
        shape_dict (dict): Mapping from dimension name to its size.

    Raises:
        ValueError: If the shape string does not match the actual shape.

    Example:
        >>> tensor_shape = TensorShape([[1, 2], [3, 4]], "H W")
        >>> print(tensor_shape)
        TensorShape(shape_string='H W', actual_shape=(2, 2))
    """

    def __new__(cls, data, shape_string):
        instance = torch.as_tensor(data).as_subclass(cls)
        instance.shape_string = shape_string
        instance.shape_dict = cls.parse_shape_string(
            shape_string, instance.shape
        )
        return instance

    @staticmethod
    def parse_shape_string(shape_string, actual_shape):
        """
        Parses the shape string and returns a dictionary mapping dimensions to sizes.

        Args:
            shape_string (str): Whitespace-separated dimension names.
            actual_shape (tuple): The actual shape of the tensor.

        Returns:
            dict: A dictionary mapping dimension names to sizes.

        Raises:
            ValueError: If the number of names differs from the number of
                dimensions, or if a repeated name is bound to two different
                sizes.
        """
        dimensions = shape_string.split()
        if len(dimensions) != len(actual_shape):
            raise ValueError(
                f"Shape string {shape_string} does not match actual shape {actual_shape}"
            )
        shape_dict = {}
        for dim, size in zip(dimensions, actual_shape):
            # A repeated name must always refer to the same size; otherwise
            # the later size would silently overwrite the earlier one.
            bound = shape_dict.setdefault(dim, size)
            if bound != size:
                raise ValueError(
                    f"Dimension {dim!r} in shape string {shape_string!r} is "
                    f"bound to both {bound} and {size} in actual shape "
                    f"{tuple(actual_shape)}"
                )
        return shape_dict

    def __repr__(self):
        # Instances that did not go through __new__ (for example results of
        # tensor operations or ``as_subclass``) have no shape_string.
        shape_string = getattr(self, "shape_string", None)
        return (
            f"TensorShape(shape_string={shape_string!r}, "
            f"actual_shape={tuple(self.shape)})"
        )

    @staticmethod
    def check_shape(tensor, shape_string):
        """
        Checks if the shape of the given tensor matches the specified shape string.

        Args:
            tensor (Tensor): The tensor to check the shape of.
            shape_string (str): Whitespace-separated dimension names.

        Raises:
            ValueError: If the number of dimensions differs, or a repeated
                name is bound to two different sizes.
        """
        # parse_shape_string performs the full validation. Comparing the
        # shape against the dict values afterwards would reject every spec
        # with a repeated name, because dict keys collapse.
        TensorShape.parse_shape_string(shape_string, tensor.shape)
|
||
|
||
# Define a decorator for shape checking | ||
def check_tensor_shape(shape_string: str = None):
    """
    Decorator factory that validates a tensor argument against a shape string.

    On every call the wrapper takes the first ``Tensor`` found among the
    positional arguments (then the keyword arguments) and passes it to
    ``TensorShape.check_shape``. A leading non-tensor ``self`` is skipped, so
    the decorator works on both methods and free functions.

    Args:
        shape_string (str): Whitespace-separated dimension names, one per
            expected tensor dimension. When ``None``, no check is performed.

    Returns:
        function: A decorator that wraps the original function and performs
        the shape check before calling it.

    Raises:
        TypeError: At call time, if a shape string was given but no tensor
            argument was passed.
        ValueError: At call time, propagated from ``TensorShape.check_shape``
            when the tensor does not match the shape string.

    Example:
        @check_tensor_shape("B S D")
        def my_function(tensor):
            # Function implementation
            pass

    The above example will ensure that the tensor passed to `my_function`
    has exactly three dimensions.
    """
    # Local import: keeps this block self-contained without touching the
    # file's import section.
    from functools import wraps

    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            if shape_string is not None:
                tensor = next(
                    (
                        value
                        for value in (*args, *kwargs.values())
                        if isinstance(value, Tensor)
                    ),
                    None,
                )
                if tensor is None:
                    raise TypeError(
                        f"{func.__name__}() was called without a tensor "
                        f"argument to check against {shape_string!r}"
                    )
                TensorShape.check_shape(tensor, shape_string)
            return func(*args, **kwargs)

        return wrapper

    return decorator
|
||
|
||
# Define a helper function to create TensorShape objects | ||
def create_tensor(
    data: Tensor = None, shape_string: str = None, random_on: bool = False
):
    """Build a ``TensorShape`` from ``data``, optionally sampling it randomly.

    Args:
        data: Wrapped as-is when ``random_on`` is false. When ``random_on``
            is true it is handed to ``torch.randn`` as the size argument
            and the sampled tensor is wrapped.
        shape_string: Shape string forwarded to ``TensorShape``.
        random_on: Whether to replace ``data`` with a ``torch.randn`` sample.

    Returns:
        The resulting ``TensorShape`` instance.
    """
    payload = torch.randn(data) if random_on else data
    return TensorShape(payload, shape_string)