diff --git a/.gitignore b/.gitignore
index 4e1fa69..3af1399 100644
--- a/.gitignore
+++ b/.gitignore
@@ -143,3 +143,5 @@ dreem/training/models/*
 # docs
 site/
 *.xml
+dreem/training/configs/base.yaml
+dreem/training/configs/override.yaml
diff --git a/dreem/models/embedding.py b/dreem/models/embedding.py
index 16dd8da..0c68a24 100644
--- a/dreem/models/embedding.py
+++ b/dreem/models/embedding.py
@@ -146,6 +146,7 @@ def __init__(
         normalize: bool = False,
         scale: float | None = None,
         mlp_cfg: dict | None = None,
+        embedding_agg_method: str = "average",
     ):
         """Initialize embeddings.
 
@@ -164,12 +165,17 @@
             mlp_cfg: A dictionary of mlp hyperparameters for projecting
                 embedding to correct space.
                 Example: {"hidden_dims": 256, "num_layers":3, "dropout": 0.3}
+            embedding_agg_method: how the embeddings are aggregated;
+                one of "average", "stack", or "concatenate".
         """
+        self._check_init_args(emb_type, mode)
+
         super().__init__()
 
         self.emb_type = emb_type
         self.mode = mode
+        self.embedding_agg_method = embedding_agg_method
         self.features = features
         self.emb_num = emb_num
         self.over_boxes = over_boxes
 
@@ -216,10 +222,13 @@
 
         elif self.mode == "fixed":
             if self.emb_type == "pos":
-                self._emb_func = self._sine_box_embedding
+                if self.embedding_agg_method == "average":
+                    self._emb_func = self._sine_box_embedding
+                else:
+                    self._emb_func = self._sine_pos_embedding
             elif self.emb_type == "temp":
                 self._emb_func = self._sine_temp_embedding
 
         elif self.mode == "rope":
-            # TODO: pos/temp uses the same processing but takes the input differently
+            # pos and temp embeddings share the same processing; only their inputs differ
            self._emb_func = self._rope_embedding
 
@@ -363,6 +372,37 @@ def _rope_embedding(self, x: torch.Tensor) -> torch.Tensor:
 
         return rot_mat
-
+
+    def _sine_pos_embedding(self, centroids: torch.Tensor) -> torch.Tensor:
+        """Compute fixed sine positional embeddings for one dimension (x or y).
+
+        Args:
+            centroids: the input centroids for either the x or y dimension,
+                represented as the fraction of the original image size at
+                which the instance centroid lies; of shape (N,) or (N, 1),
+                where N = # of query tokens (i.e. instances), with values
+                in [0, 1].
+
+        Returns:
+            An (n_instances, D) tensor representing the positional embedding.
+        """
+        d = self.features
+        n = self.temperature
+
+        positions = centroids.reshape(-1, 1)  # handles both (N,) and (N, 1) inputs
+        pos_lookup = torch.zeros(len(centroids), d, device=centroids.device)
+
+        denominators = torch.pow(
+            n, 2 * torch.arange(0, d // 2, device=centroids.device) / d
+        )  # n^(2i/d_model), where i is the embedding dimension index
+        pos_lookup[:, 0::2] = torch.sin(
+            positions / denominators
+        )  # sin(pos/n^(2i/d_model))
+        pos_lookup[:, 1::2] = torch.cos(
+            positions / denominators
+        )  # cos(pos/n^(2i/d_model))
+
+        return pos_lookup
+
     def _sine_box_embedding(self, boxes: torch.Tensor) -> torch.Tensor:
         """Compute sine positional embeddings for boxes using given parameters.
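Note: the `_sine_pos_embedding` added above applies the standard fixed sinusoidal formula to a normalized scalar coordinate rather than to a whole box or a frame index. A minimal standalone sketch of that computation follows; the function name and the `d_model`/`temperature` defaults are illustrative stand-ins for `self.features` and `self.temperature`, not names from this diff:

    import torch

    def sine_pos_embedding(centroids: torch.Tensor, d_model: int = 256,
                           temperature: float = 10000.0) -> torch.Tensor:
        """Embed normalized coordinates of shape (N,), values in [0, 1], into (N, d_model)."""
        positions = centroids.reshape(-1, 1)  # (N, 1)
        emb = torch.zeros(len(centroids), d_model)
        # temperature^(2i/d_model) for i = 0 .. d_model//2 - 1
        denominators = torch.pow(temperature, 2 * torch.arange(d_model // 2) / d_model)
        emb[:, 0::2] = torch.sin(positions / denominators)  # even dims
        emb[:, 1::2] = torch.cos(positions / denominators)  # odd dims
        return emb

    # e.g. x- and y-centroids of 5 instances, as fractions of image width/height
    xs, ys = torch.rand(5), torch.rand(5)
    emb_x, emb_y = sine_pos_embedding(xs), sine_pos_embedding(ys)
    print(emb_x.shape)  # torch.Size([5, 256])

Under the non-"average" aggregation methods, each instance thus gets three D-dimensional vectors (x, y, t) instead of a single averaged pos+temp embedding.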
diff --git a/dreem/models/transformer.py b/dreem/models/transformer.py
index c40cc33..6ff0eee 100644
--- a/dreem/models/transformer.py
+++ b/dreem/models/transformer.py
@@ -85,13 +85,17 @@ def __init__(
 
                 pos_emb_cfg = self.embedding_meta["pos"]
                 if pos_emb_cfg:
                     self.pos_emb = Embedding(
-                        emb_type="pos", features=self.d_model, **pos_emb_cfg
-                    )
+                        emb_type="pos", features=self.d_model,
+                        embedding_agg_method=self.embedding_meta["embedding_agg_method"],
+                        **pos_emb_cfg
+                    )  # the agg method must be the same for pos and temp embeddings
 
            if "temp" in self.embedding_meta:
                 temp_emb_cfg = self.embedding_meta["temp"]
                 if temp_emb_cfg:
                     self.temp_emb = Embedding(
-                        emb_type="temp", features=self.d_model, **temp_emb_cfg
+                        emb_type="temp", features=self.d_model,
+                        embedding_agg_method=self.embedding_meta["embedding_agg_method"],
+                        **temp_emb_cfg
                     )
 
@@ -178,7 +182,8 @@ def forward(
 
         encoder_queries = ref_features
 
-        encoder_features, ref_pos_emb, ref_temp_emb = self.encoder(
+        # the encoder now returns only the features, not the pos/temp embeddings
+        encoder_features = self.encoder(
             encoder_queries,
             embedding_map={"pos": self.pos_emb, "temp": self.temp_emb},
             ref_boxes=ref_boxes,
 
@@ -187,10 +192,11 @@ def forward(
         ) # (total_instances, batch_size, embed_dim)
 
         # TODO: check if instance.add_embedding() supports rotation matrices
-        if self.return_embedding:
-            for i, instance in enumerate(ref_instances):
-                instance.add_embedding("pos", ref_pos_emb[i])
-                instance.add_embedding("temp", ref_temp_emb[i])
+        # TODO: add support for x, y, t embeddings to the instance, then uncomment
+        # if self.return_embedding:
+        #     for i, instance in enumerate(ref_instances):
+        #         instance.add_embedding("pos", ref_pos_emb[i])
+        #         instance.add_embedding("temp", ref_temp_emb[i])
 
         # -------------- Begin decoder pre-processing --------------- #
 
@@ -225,10 +231,11 @@ def forward(
         else:
             query_instances = ref_instances
 
-        if self.return_embedding:
-            for i, instance in enumerate(query_instances):
-                instance.add_embedding("pos", query_pos_emb[i])
-                instance.add_embedding("temp", query_temp_emb[i])
+        # TODO: add support for x, y, t embeddings and uncomment this block
+        # if self.return_embedding:
+        #     for i, instance in enumerate(query_instances):
+        #         instance.add_embedding("pos", query_pos_emb[i])
+        #         instance.add_embedding("temp", query_temp_emb[i])
 
         decoder_features = self.decoder(
             query_features,
 
@@ -481,7 +488,7 @@ def forward(
         queries = queries.permute(1,0,2) # queries is shape (batch_size, n_query, embed_dim)
         # calculate temporal embeddings and transform queries
         queries_t, ref_temp_emb = temp_emb(queries, ref_times)
-        # if avg. of temp and pos, need bounding boxes
+        # bounding boxes are only needed for the "average" aggregation method
         if embedding_agg_method == "average":
             _, ref_pos_emb = pos_emb(queries, ref_boxes)
             ref_emb = (ref_pos_emb + ref_temp_emb) / 2
 
@@ -495,7 +502,7 @@ def forward(
 
         # concatenate or stack the queries (avg. method done above since it applies differently)
         queries = self.collate_queries(
-            (queries, queries_t, queries_x, queries_y),
+            (queries_t, queries_x, queries_y),
             embedding_agg_method)
         # transpose for input to EncoderLayer to (n_queries, batch_size, embed_dim)
         queries = queries.permute(1, 0, 2)
 
@@ -504,7 +511,7 @@ def forward(
 
         encoder_features = self.norm(queries)
 
-        return encoder_features, ref_pos_emb, ref_temp_emb
+        return encoder_features  # ref_pos_emb and ref_temp_emb are no longer returned
 
     def collate_queries(
         self, queries: Tuple[torch.Tensor], embedding_agg_method: str
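Note: the body of collate_queries is not part of this diff, so the following is a hedged sketch, inferred only from its call site above, of how the "stack" and "concatenate" branches might combine the three per-axis query streams. The projection MLP here is an assumption for illustration, not the module's actual layer:

    import torch
    from torch import nn

    def collate_queries(queries: tuple, embedding_agg_method: str, mlp: nn.Module) -> torch.Tensor:
        """queries = (queries_t, queries_x, queries_y), each (batch_size, n_query, embed_dim)."""
        queries_t, queries_x, queries_y = queries
        if embedding_agg_method == "stack":
            # stack along the token axis -> (batch_size, 3 * n_query, embed_dim)
            return torch.cat((queries_t, queries_x, queries_y), dim=1)
        elif embedding_agg_method == "concatenate":
            # concatenate along channels, then project back down to embed_dim
            collated = torch.cat((queries_t, queries_x, queries_y), dim=2)  # (B, n_query, 3 * embed_dim)
            return mlp(collated)
        raise ValueError(f"Unsupported embedding_agg_method: {embedding_agg_method}")

    embed_dim = 256
    mlp = nn.Linear(3 * embed_dim, embed_dim)  # illustrative projection
    qt = qx = qy = torch.rand(1, 8, embed_dim)
    print(collate_queries((qt, qx, qy), "stack", mlp).shape)        # torch.Size([1, 24, 256])
    print(collate_queries((qt, qx, qy), "concatenate", mlp).shape)  # torch.Size([1, 8, 256])

The constructor changes above also assume that embedding_meta carries the new key, e.g. {"pos": {...}, "temp": {...}, "embedding_agg_method": "stack"}, since both Embedding instances read self.embedding_meta["embedding_agg_method"] directly.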