final attn head supports stack embeddings
- 1x1 conv for stack embedding
- stack into 3 channels for x,y,t
shaikh58 committed Aug 9, 2024
1 parent 6af9e17 commit 6928078
Showing 3 changed files with 30 additions and 12 deletions.
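A minimal sketch of the idea in this commit, under assumed illustrative shapes (64 instances, feature_dim=1024): the x, y, t embeddings are stacked as three channels, and a learned 1x1 convolution collapses them back to a single channel per instance.

import torch

num_instances, feature_dim = 64, 1024
x_emb = torch.randn(num_instances, feature_dim)
y_emb = torch.randn(num_instances, feature_dim)
t_emb = torch.randn(num_instances, feature_dim)

# stack x, y, t as channels: (num_instances, 3, feature_dim)
stacked = torch.stack([x_emb, y_emb, t_emb], dim=1)
conv_1x1 = torch.nn.Conv1d(in_channels=3, out_channels=1, kernel_size=1)
# (num_instances, 1, feature_dim): a learned per-feature mix of x, y, t
collapsed = conv_1x1(stacked)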
34 changes: 25 additions & 9 deletions dreem/models/attention_head.py
@@ -27,12 +27,15 @@ def __init__(
        super().__init__()
        self.embedding_agg_method = embedding_agg_method

        # if using stacked embeddings, use 1x1 conv with x,y,t embeddings as channels;
        # ensures output represents ref instances by query instances
        if self.embedding_agg_method == "stack":
            self.q_proj = torch.nn.Conv1d(
                in_channels=3, out_channels=1, kernel_size=1, stride=1, padding=0
            )
            self.k_proj = torch.nn.Conv1d(
                in_channels=3, out_channels=1, kernel_size=1, stride=1, padding=0
            )
        else:
            self.q_proj = MLP(feature_dim, feature_dim, feature_dim, num_layers, dropout)
            self.k_proj = MLP(feature_dim, feature_dim, feature_dim, num_layers, dropout)
@@ -51,13 +54,26 @@ def forward(
        Returns:
            Output tensor of shape (batch_size, num_frame_instances, num_window_instances).
        """
        batch_size, num_query_instances, feature_dim = query.size()
        num_window_instances = key.shape[1]

        # if stacked embeddings, create channels for each x,y,t embedding dimension
        # maps shape (1, 192, 1024) -> (1, 64, 3, 1024), then squeezes to (64, 3, 1024)
        if self.embedding_agg_method == "stack":
            key = key.view(
                batch_size, 3, num_window_instances // 3, feature_dim
            ).permute(0, 2, 1, 3).squeeze(0)
            query = query.view(
                batch_size, 3, num_query_instances // 3, feature_dim
            ).permute(0, 2, 1, 3).squeeze(0)
            # key, query of shape (num_instances, 3, feature_dim)
            k = self.k_proj(key).transpose(1, 0)
            q = self.q_proj(query).transpose(1, 0)
            # k, q of shape (1, num_instances, feature_dim)
        else:
            k = self.k_proj(key)
            q = self.q_proj(query)

        attn_weights = torch.bmm(q, k.transpose(1, 2))

        return attn_weights  # (B, N_t, N)
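As a shape check of the stacked path above (batch_size=1, 64 instances per x/y/t group, and feature_dim=1024 are assumptions; a single shared conv is used here for brevity, whereas the module keeps separate q and k projections):

import torch

conv_1x1 = torch.nn.Conv1d(in_channels=3, out_channels=1, kernel_size=1)
query = torch.randn(1, 192, 1024)  # 3 * 64 stacked query instances
key = torch.randn(1, 192, 1024)

q = query.view(1, 3, 64, 1024).permute(0, 2, 1, 3).squeeze(0)  # (64, 3, 1024)
k = key.view(1, 3, 64, 1024).permute(0, 2, 1, 3).squeeze(0)    # (64, 3, 1024)
q = conv_1x1(q).transpose(1, 0)  # (1, 64, 1024)
k = conv_1x1(k).transpose(1, 0)  # (1, 64, 1024)
attn_weights = torch.bmm(q, k.transpose(1, 2))  # (1, 64, 64)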
4 changes: 2 additions & 2 deletions dreem/models/transformer.py
@@ -131,6 +131,7 @@ def __init__(
            feature_dim=feature_dim_attn_head,
            num_layers=num_layers_attn_head,
            dropout=dropout_attn_head,
            embedding_agg_method=self.embedding_meta["embedding_agg_method"],
        )

        self._reset_parameters()
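For context, constructing the head standalone with the new keyword might look like the following; the import path, class name, and values are assumptions inferred from this diff, not confirmed API.

from dreem.models.attention_head import ATTWeightHead  # assumed import path

attn_head = ATTWeightHead(
    feature_dim=1024,
    num_layers=2,
    dropout=0.1,
    embedding_agg_method="stack",  # assumed value; enables the Conv1d projections
)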
@@ -242,8 +243,7 @@ def forward(

        asso_output = []
        for frame_features in decoder_features:
            # attn_head handles the 3x queries that can come out of the encoder/decoder
            # when using stacked embeddings
            # n_query should be the number of instances in the last frame when running
            # inference, or the number of ref instances during training; total_instances
            # is always the number of reference instances
            asso_matrix = self.attn_head(frame_features, encoder_features).view(
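The hunk is truncated at the .view( call above; a hypothetical completion, using the n_query and total_instances described in the comment (the real arguments are cut off, so names and order here are assumptions):

asso_matrix = self.attn_head(frame_features, encoder_features).view(
    1, n_query, total_instances  # hypothetical arguments; truncated in this hunk
)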
4 changes: 3 additions & 1 deletion run_trainer.py
@@ -2,7 +2,9 @@
from omegaconf import OmegaConf
import os

os.chdir("/Users/main/Documents/GitHub/dreem/dreem/training")
# /Users/mustafashaikh/dreem/dreem/training
# /Users/main/Documents/GitHub/dreem/dreem/training
os.chdir("/Users/mustafashaikh/dreem/dreem/training")

base_config = "./configs/base.yaml"
# params_config = "./configs/override.yaml"
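A small sketch of the base-plus-override config pattern suggested by the commented line above (the paths are the ones in this file; the merge itself is standard OmegaConf, not code from this commit):

from omegaconf import OmegaConf

base_cfg = OmegaConf.load("./configs/base.yaml")
# override_cfg = OmegaConf.load("./configs/override.yaml")
# cfg = OmegaConf.merge(base_cfg, override_cfg)  # later configs take precedence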
