Commit

minor updates from PR review
shaikh58 committed Sep 27, 2024
1 parent 5a5f75f commit 3ff1ab0
Showing 4 changed files with 30 additions and 615 deletions.
22 changes: 13 additions & 9 deletions dreem/models/embedding.py
@@ -99,8 +99,9 @@ def forward(self, x: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
# create the lookup array based on how many instances there are
# max(101, seq_len) is for positional vs temporal; pos can only have idx up to
# 100 since it's a fraction of [0,1]*100. temp is from [0, clip_len]; since clip_len
# not available, we use # of instances from input x; this is always >= clip_len
self.build_rope_cache(max(101, seq_len)) # registers cache
# not available, we use the last value in the indexing array since this is the
# last frame index we could ever need to look up; no instances appear in frames after it
self.build_rope_cache(max(101, input_pos[:, -1].max() + 1)) # registers cache
self.cache = self.cache.to(input_pos.device)
# extract the values based on whether input_pos is set or not
rope_cache = (
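
The new cache-sizing rule can be read as: positional ids are fractions of [0, 1] scaled to [0, 100], so 101 entries always suffice, while temporal ids are frame indices bounded by the largest value in input_pos. A minimal sketch of that rule (hypothetical helper name, assuming input_pos is sorted so its last column holds the largest index):

    import torch

    def required_rope_cache_size(input_pos: torch.Tensor) -> int:
        # input_pos: (batch_size, n_query) ids; the last column is assumed to hold
        # the largest frame index we will ever look up
        return max(101, int(input_pos[:, -1].max().item()) + 1)

    # e.g. frame ids 0..249 -> 250 cache rows; fractional position ids -> 101 rows
    print(required_rope_cache_size(torch.arange(250).reshape(1, -1)))  # 250
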
@@ -269,16 +270,21 @@ def _check_init_args(self, emb_type: str, mode: str):


def _transform(self, x, emb):

"""Routes to the relevant embedding function to transform the input queries
Args:
x: Input queries of shape (batch_size, N, embed_dim)
emb: Embedding array to apply to data; can be (N, embed_dim) or
(batch_size, n_query, num_heads, embed_dim // 2, 2) if using RoPE
"""
if self._emb_func == self._rope_embedding:
return self._apply_rope(x, emb)
else:
return self._apply_additive_embeddings(x, emb)


def _apply_rope(self, x, emb):
"""
Applies Rotary Positional Embedding to input queries
"""Applies Rotary Positional Embedding to input queries
Args:
x: Input queries of shape (batch_size, n_query, embed_dim)
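
For context, a standard RoPE rotation looks like the sketch below; shapes and the (cos, sin) cache layout are assumptions based on common RoPE implementations, not taken verbatim from dreem:

    import torch

    def apply_rope_sketch(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor:
        # x: (batch_size, n_query, num_heads, head_dim)
        # rope_cache: broadcastable to (batch_size, n_query, num_heads, head_dim // 2, 2),
        # holding (cos, sin) per feature pair
        x_pairs = x.float().reshape(*x.shape[:-1], -1, 2)       # pair adjacent features
        cos, sin = rope_cache[..., 0], rope_cache[..., 1]
        rotated = torch.stack(
            [
                x_pairs[..., 0] * cos - x_pairs[..., 1] * sin,  # 2-D rotation of each pair
                x_pairs[..., 1] * cos + x_pairs[..., 0] * sin,
            ],
            dim=-1,
        )
        return rotated.flatten(-2).type_as(x)                   # back to head_dim
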
@@ -308,8 +314,7 @@ def _apply_rope(self, x, emb):


def _apply_additive_embeddings(self, x, emb):
"""
Applies additive embeddings to input queries
"""Applies additive embeddings to input queries
Args:
x: Input tensor of shape (batch_size, N, embed_dim)
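
The additive path is, at its core, a broadcast add; a minimal sketch assuming the shapes in the docstring (the repo's implementation may include scaling or reshaping not shown here):

    import torch

    def apply_additive_sketch(x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor:
        # x: (batch_size, N, embed_dim); emb: (N, embed_dim), broadcast over the batch
        return x + emb.unsqueeze(0)
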
@@ -361,8 +366,7 @@ def _torch_int_div(


def _rope_embedding(self, seq_positions: torch.Tensor, input_shape: torch.Size) -> torch.Tensor:
"""
Computes the rotation matrix to apply RoPE to input queries
"""Computes the rotation matrix to apply RoPE to input queries
Args:
seq_positions: Pos array of shape (embed_dim,) used to compute rotational embedding
input_shape: Shape of the input queries; needed for rope
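
A standard way to build the rotation terms from a position array, shown only as a reference sketch; the base frequency of 10000 is the usual RoPE default and is assumed here, not taken from the repo:

    import torch

    def rope_rotation_sketch(seq_positions: torch.Tensor, head_dim: int, base: float = 10000.0) -> torch.Tensor:
        # seq_positions: 1-D tensor of positions (e.g. frame indices or scaled box coords)
        theta = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
        angles = torch.outer(seq_positions.float(), theta)                  # (n_pos, head_dim // 2)
        return torch.stack([torch.cos(angles), torch.sin(angles)], dim=-1)  # (cos, sin) per pair
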
23 changes: 16 additions & 7 deletions dreem/models/transformer.py
@@ -229,7 +229,6 @@ def forward(
query_boxes = ref_boxes
query_times = ref_times


decoder_features, pos_emb_traceback, temp_emb_traceback = self.decoder(
query_features, encoder_features,
embedding_map={"pos": self.pos_emb, "temp": self.temp_emb},
@@ -553,16 +552,24 @@ def forward(
if self.return_intermediate:
intermediate.pop()
intermediate.append(decoder_features)
return torch.stack(intermediate)
return torch.stack(intermediate), pos_emb_traceback, temp_emb_traceback

return decoder_features.unsqueeze(0), pos_emb_traceback, temp_emb_traceback


def apply_embeddings(queries: torch.Tensor, embedding_map: Dict[str, Embedding],
boxes: torch.Tensor, times: torch.Tensor,
embedding_agg_method: str):
"""
Enter docstring here
""" Applies embeddings to input queries for various aggregation methods. This function
is called from the transformer encoder and decoder
Args:
queries: The input tensor of shape (n_query, batch_size, embed_dim).
embedding_map: Dict of Embedding objects defining the pos/temp embeddings to be applied
to the input data
boxes: Bounding box based embedding ids of shape (n_query, n_anchors, 4)
times: Times based embedding ids of shape (n_query,)
embedding_agg_method: method of aggregation of embeddings e.g. stack/concatenate/average
"""

pos_emb, temp_emb = embedding_map["pos"], embedding_map["temp"]
@@ -635,14 +642,15 @@ def _get_activation_fn(activation: str) -> callable:

def collate_queries(queries: Tuple[torch.Tensor], embedding_agg_method: str
) -> torch.Tensor:
"""
Aggregates queries transformed by embeddings
"""Aggregates queries transformed by embeddings
Args:
_queries: 5-tuple of queries (already transformed by embeddings) for _, x, y, t, original input
each of shape (batch_size, n_query, embed_dim)
embedding_agg_method: String representing the aggregation method for embeddings
Returns: Tensor of aggregated queries of shape; can be concatenated (increased length of tokens),
Returns:
Tensor of aggregated queries of shape; can be concatenated (increased length of tokens),
stacked (increased number of tokens), or averaged (original token number and length)
"""

@@ -670,6 +678,7 @@ def collate_queries(queries: Tuple[torch.Tensor], embedding_agg_method: str
def spatial_emb_from_bb(bb: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Computes embedding arrays for x,y spatial dimensions using centroids from bounding boxes
Args:
bb: Bounding boxes of shape (n_query, n_anchors, 4) from which to compute x,y centroids;
each bounding box is [ymin, xmin, ymax, xmax]
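
As a rough illustration of the centroid computation the docstring describes (hypothetical helper name; any normalization or anchor handling in the repo is not shown):

    import torch
    from typing import Tuple

    def centroids_from_bb_sketch(bb: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        # bb: (n_query, n_anchors, 4), each box as [ymin, xmin, ymax, xmax]
        ymin, xmin, ymax, xmax = bb.unbind(dim=-1)
        cx = (xmin + xmax) / 2.0        # (n_query, n_anchors) x centroids
        cy = (ymin + ymax) / 2.0        # (n_query, n_anchors) y centroids
        return cx, cy
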