embedding bug fixes for encoder
- use the bounding box embedding only for method "average"; modify the emb_funcs routing accordingly (see the sketch below)
- temporarily remove support for adding embeddings to instance objects; needs to be made compatible with x,y,t embeddings
- stop tracking the config yamls (added to .gitignore); the current versions serve as templates
- runs through to the end of the encoder forward pass
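A minimal sketch of the routing change from the first bullet, based on the embedding.py diff below: with embedding_agg_method="average" a "pos" embedding keeps the bounding-box sine embedding, while any other value routes to the new per-dimension centroid embedding. The concrete constructor values (features, mode, the "stack" method name) are illustrative assumptions, not taken from the repository configs.

```python
# Illustrative only; argument names follow the embedding.py diff in this commit,
# concrete values are assumptions.
from dreem.models.embedding import Embedding

pos_emb_avg = Embedding(
    emb_type="pos", mode="fixed", features=1024,
    embedding_agg_method="average",  # routes to _sine_box_embedding (full bounding boxes)
)
pos_emb_xy = Embedding(
    emb_type="pos", mode="fixed", features=1024,
    embedding_agg_method="stack",    # any non-"average" value routes to _sine_pos_embedding (x/y centroids)
)
```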
shaikh58 committed Sep 27, 2024
1 parent c43ee75 commit fe1eeca
Showing 3 changed files with 65 additions and 18 deletions.
4 changes: 4 additions & 0 deletions .gitignore
@@ -143,3 +143,7 @@ dreem/training/models/*
# docs
site/
*.xml
dreem/training/configs/base.yaml
dreem/training/configs/override.yaml
dreem/training/configs/override.yaml
dreem/training/configs/base.yaml
42 changes: 39 additions & 3 deletions dreem/models/embedding.py
@@ -146,6 +146,7 @@ def __init__(
normalize: bool = False,
scale: float | None = None,
mlp_cfg: dict | None = None,
embedding_agg_method: str = "average"
):
"""Initialize embeddings.
@@ -164,12 +165,14 @@ def __init__(
mlp_cfg: A dictionary of mlp hyperparameters for projecting embedding to correct space.
Example: {"hidden_dims": 256, "num_layers":3, "dropout": 0.3}
embedding_agg_method: how embeddings are aggregated; "average" keeps the bounding-box
sine embedding for "pos" embeddings, while any other method routes to the
per-dimension (x, y) centroid sine embedding.
"""

self._check_init_args(emb_type, mode)

super().__init__()

self.emb_type = emb_type
self.mode = mode
self.embedding_agg_method = embedding_agg_method
self.features = features
self.emb_num = emb_num
self.over_boxes = over_boxes
@@ -216,12 +219,15 @@ def __init__(

elif self.mode == "fixed":
if self.emb_type == "pos":
self._emb_func = self._sine_box_embedding
if self.embedding_agg_method == "average":
self._emb_func = self._sine_box_embedding
else:
self._emb_func = self._sine_pos_embedding
elif self.emb_type == "temp":
self._emb_func = self._sine_temp_embedding

elif self.mode == "rope":
# TODO: pos/temp uses the same processing but takes the input differently
# pos/temp embeddings processed the same way with different embedding array inputs
self._emb_func = self._rope_embedding


@@ -363,7 +369,37 @@ def _rope_embedding(self, x: torch.Tensor) -> torch.Tensor:

return rot_mat



def _sine_pos_embedding(self, centroids: torch.Tensor) -> torch.Tensor:
"""Compute fixed sine positional embeddings per dimension (x or y).

Args:
centroids: the input centroids for either the x or y dimension, expressed as
the fraction of the original image size at which each instance centroid lies;
shape (N,) or (N,1) where N = # of query tokens (i.e. instances),
with values in [0,1].

Returns:
An (n_instances, D) tensor representing the positional embedding.
"""
d = self.features
n = self.temperature

positions = centroids.unsqueeze(1)
temp_lookup = torch.zeros(len(centroids), d, device=centroids.device)

denominators = torch.pow(
n, 2 * torch.arange(0, d // 2, device=centroids.device) / d
) # 10000^(2i/d_model), i is the index of embedding
temp_lookup[:, 0::2] = torch.sin(
positions / denominators
) # sin(pos/10000^(2i/d_model))
temp_lookup[:, 1::2] = torch.cos(
positions / denominators
) # cos(pos/10000^(2i/d_model))

return temp_lookup
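As a standalone sanity check of the math above, the same per-dimension sine embedding can be computed directly; d, n, and the centroid values below are example choices, not repository defaults.

```python
import torch

d, n = 8, 10000                              # feature dim and temperature (example values)
centroids = torch.tensor([0.25, 0.75])       # normalized x (or y) centroids in [0, 1]

positions = centroids.unsqueeze(1)                            # (N, 1)
denominators = torch.pow(n, 2 * torch.arange(0, d // 2) / d)  # n^(2i/d)
emb = torch.zeros(len(centroids), d)
emb[:, 0::2] = torch.sin(positions / denominators)            # even dims: sin
emb[:, 1::2] = torch.cos(positions / denominators)            # odd dims: cos
print(emb.shape)  # torch.Size([2, 8])
```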

def _sine_box_embedding(self, boxes: torch.Tensor) -> torch.Tensor:
"""Compute sine positional embeddings for boxes using given parameters.
37 changes: 22 additions & 15 deletions dreem/models/transformer.py
@@ -85,13 +85,17 @@ def __init__(
pos_emb_cfg = self.embedding_meta["pos"]
if pos_emb_cfg:
self.pos_emb = Embedding(
emb_type="pos", features=self.d_model, **pos_emb_cfg
)
emb_type="pos", features=self.d_model,
embedding_agg_method=self.embedding_meta["embedding_agg_method"],
**pos_emb_cfg
) # agg method must be the same for pos and temp embeddings
if "temp" in self.embedding_meta:
temp_emb_cfg = self.embedding_meta["temp"]
if temp_emb_cfg:
self.temp_emb = Embedding(
emb_type="temp", features=self.d_model, **temp_emb_cfg
emb_type="temp", features=self.d_model,
embedding_agg_method=self.embedding_meta["embedding_agg_method"],
**temp_emb_cfg
)

# Transformer Encoder
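Since the aggregation method is read once from embedding_meta and passed to both embeddings, a config that satisfies this layout might look like the dict below; only the top-level keys ("embedding_agg_method", "pos", "temp") come from the diff, the nested values are illustrative.

```python
# Hypothetical embedding_meta; nested option values are examples only.
embedding_meta = {
    "embedding_agg_method": "average",  # must be the same for pos and temp embeddings
    "pos": {"mode": "fixed", "normalize": True},
    "temp": {"mode": "fixed"},
}
```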
@@ -178,7 +182,8 @@ def forward(

encoder_queries = ref_features

encoder_features, ref_pos_emb, ref_temp_emb = self.encoder(
# (encoder_features, ref_pos_emb, ref_temp_emb) \
encoder_features = self.encoder(
encoder_queries,
embedding_map={"pos": self.pos_emb, "temp": self.temp_emb},
ref_boxes=ref_boxes,
Expand All @@ -187,10 +192,11 @@ def forward(
) # (total_instances, batch_size, embed_dim)

# TODO: check if instance.add_embedding() supports rotation matrices
if self.return_embedding:
for i, instance in enumerate(ref_instances):
instance.add_embedding("pos", ref_pos_emb[i])
instance.add_embedding("temp", ref_temp_emb[i])
# TODO: include support for adding x,y,t embeddings to the instance
# if self.return_embedding:
# for i, instance in enumerate(ref_instances):
# instance.add_embedding("pos", ref_pos_emb[i])
# instance.add_embedding("temp", ref_temp_emb[i])

# -------------- Begin decoder pre-processing --------------- #

@@ -225,10 +231,11 @@ def forward(
else:
query_instances = ref_instances

if self.return_embedding:
for i, instance in enumerate(query_instances):
instance.add_embedding("pos", query_pos_emb[i])
instance.add_embedding("temp", query_temp_emb[i])
# TODO: include support for x,y,t embeddings and uncomment this
# if self.return_embedding:
# for i, instance in enumerate(query_instances):
# instance.add_embedding("pos", query_pos_emb[i])
# instance.add_embedding("temp", query_temp_emb[i])

decoder_features = self.decoder(
query_features,
@@ -481,7 +488,7 @@ def forward(
queries = queries.permute(1,0,2) # queries is shape (batch_size, n_query, embed_dim)
# calculate temporal embeddings and transform queries
queries_t, ref_temp_emb = temp_emb(queries, ref_times)
# if avg. of temp and pos, need bounding boxes
# bounding boxes are only needed when averaging pos and temp embeddings (method "average")
if embedding_agg_method == "average":
_, ref_pos_emb = pos_emb(queries, ref_boxes)
ref_emb = (ref_pos_emb + ref_temp_emb) / 2
@@ -495,7 +502,7 @@

# concatenate or stack the queries (avg. method done above since it applies differently)
queries = self.collate_queries(
(queries, queries_t, queries_x, queries_y),
(queries_t, queries_x, queries_y),
embedding_agg_method)
# transpose for input to EncoderLayer to (n_queries, batch_size, embed_dim)
queries = queries.permute(1, 0, 2)
@@ -504,7 +511,7 @@ def forward(

encoder_features = self.norm(queries)

return encoder_features, ref_pos_emb, ref_temp_emb
return encoder_features  # , ref_pos_emb, ref_temp_emb


def collate_queries(self, queries: Tuple[torch.Tensor], embedding_agg_method: str
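The body of collate_queries is elided from this diff. Purely as a reading aid, a function with this signature might combine the per-dimension queries along either the token or the feature axis; every branch below is an assumption, not the repository implementation (the "average" case is handled earlier in forward, as noted above).

```python
import torch
from typing import Tuple

def collate_queries(queries: Tuple[torch.Tensor, ...], embedding_agg_method: str) -> torch.Tensor:
    """Assumed behavior only: combine (queries_t, queries_x, queries_y)."""
    queries_t, queries_x, queries_y = queries  # each (batch_size, n_query, embed_dim)
    if embedding_agg_method == "stack":
        # stack along the token axis -> (batch_size, 3 * n_query, embed_dim)
        return torch.cat([queries_t, queries_x, queries_y], dim=1)
    if embedding_agg_method == "concatenate":
        # concatenate along the feature axis -> (batch_size, n_query, 3 * embed_dim);
        # a projection back to embed_dim would be needed before the encoder layer
        return torch.cat([queries_t, queries_x, queries_y], dim=-1)
    raise ValueError(f"Unsupported embedding_agg_method: {embedding_agg_method}")
```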
