diff --git a/.gitignore b/.gitignore index fb6ee365..3af13992 100644 --- a/.gitignore +++ b/.gitignore @@ -142,3 +142,8 @@ dreem/training/models/* # docs site/ +*.xml +dreem/training/configs/base.yaml +dreem/training/configs/override.yaml +dreem/training/configs/override.yaml +dreem/training/configs/base.yaml diff --git a/dreem/datasets/sleap_dataset.py b/dreem/datasets/sleap_dataset.py index 7a98fa4f..a0ded6ed 100644 --- a/dreem/datasets/sleap_dataset.py +++ b/dreem/datasets/sleap_dataset.py @@ -415,4 +415,4 @@ def get_instances(self, label_idx: list[int], frame_idx: list[int]) -> list[Fram def __del__(self): """Handle file closing before garbage collection.""" for reader in self.videos: - reader.close() + reader.close() \ No newline at end of file diff --git a/dreem/inference/eval.py b/dreem/inference/eval.py index 1262099d..000a9b88 100644 --- a/dreem/inference/eval.py +++ b/dreem/inference/eval.py @@ -77,4 +77,4 @@ def run(cfg: DictConfig) -> dict[int, sio.Labels]: # override with params config, and specific params: # python eval.py --config-dir=./configs --config-name=inference +params_config=configs/params.yaml dataset.train_dataset.padding=10 - run() + run() \ No newline at end of file diff --git a/dreem/inference/post_processing.py b/dreem/inference/post_processing.py index 09fd8ff3..a64739ce 100644 --- a/dreem/inference/post_processing.py +++ b/dreem/inference/post_processing.py @@ -126,6 +126,8 @@ def filter_max_center_dist( k_boxes: torch.Tensor | None = None, nonk_boxes: torch.Tensor | None = None, id_inds: torch.Tensor | None = None, + h: int = None, + w: int = None ) -> torch.Tensor: """Filter trajectory score by distances between objects across frames. @@ -135,6 +137,8 @@ def filter_max_center_dist( k_boxes: The bounding boxes in the current frame nonk_boxes: the boxes not in the current frame id_inds: track ids + h: height of image + w: width of image Returns: An N_t x N association matrix @@ -147,13 +151,15 @@ def filter_max_center_dist( k_s = ((k_boxes[:, :, 2:] - k_boxes[:, :, :2]) ** 2).sum(dim=2) # n_k nonk_ct = (nonk_boxes[:, :, :2] + nonk_boxes[:, :, 2:]) / 2 - + # TODO: nonk_boxes should be only from previous frame rather than entire window dist = ((k_ct[:, None, :, :] - nonk_ct[None, :, :, :]) ** 2).sum( dim=-1 ) # n_k x Np - - norm_dist = dist / (k_s[:, None, :] + 1e-8) + # TODO: note that dist is in units of fraction of the height and width of the image; + # TODO: need to scale it by the original image size so that its in units of pixels + # norm_dist = dist / (k_s[:, None, :] + 1e-8) norm_dist = dist.mean(axis=-1) # n_k x Np + # norm_dist = valid = norm_dist < max_center_dist # n_k x Np valid_assn = ( diff --git a/dreem/inference/track.py b/dreem/inference/track.py index 6818ecad..91c289d4 100644 --- a/dreem/inference/track.py +++ b/dreem/inference/track.py @@ -163,4 +163,4 @@ def run(cfg: DictConfig) -> dict[int, sio.Labels]: # override with params config, and specific params: # python train.py --config-dir=./configs --config-name=inference +params_config=configs/params.yaml dataset.train_dataset.padding=10 - run() + run() \ No newline at end of file diff --git a/dreem/inference/tracker.py b/dreem/inference/tracker.py index 9f2e7137..279a24de 100644 --- a/dreem/inference/tracker.py +++ b/dreem/inference/tracker.py @@ -138,8 +138,10 @@ def track( # asso_preds, pred_boxes, pred_time, embeddings = self.model( # instances, reid_features # ) + # get reference and query instances from TrackQueue and calls _run_global_tracker() instances_pred = 
self.sliding_inference(model, frames) + # e.g. during train/val, don't track across batches so persistent_tracking is switched off if not self.persistent_tracking: logger.debug(f"Clearing Queue after tracking") self.track_queue.end_tracks() @@ -164,7 +166,9 @@ def sliding_inference( # H: height. # W: width. + # frames is untracked clip for inference for batch_idx, frame_to_track in enumerate(frames): + # tracked_frames is a list of reference frames that have been tracked (associated) tracked_frames = self.track_queue.collate_tracks( device=frame_to_track.frame_id.device ) @@ -188,10 +192,11 @@ def sliding_inference( ) curr_track_id = 0 + # if track ids exist from another tracking program i.e. sleap, init with those for i, instance in enumerate(frames[batch_idx].instances): instance.pred_track_id = instance.gt_track_id curr_track_id = max(curr_track_id, instance.pred_track_id) - + # if no track ids, then assign new ones for i, instance in enumerate(frames[batch_idx].instances): if instance.pred_track_id == -1: curr_track_id += 1 @@ -201,6 +206,7 @@ def sliding_inference( if ( frame_to_track.has_instances() ): # Check if there are detections. If there are skip and increment gap count + # combine the tracked frames with the latest frame; inference pipeline uses latest frame as pred frames_to_track = tracked_frames + [ frame_to_track ] # better var name? @@ -217,7 +223,7 @@ def sliding_inference( self.track_queue.add_frame(frame_to_track) else: self.track_queue.increment_gaps([]) - + # update the frame object from the input inference untracked clip frames[batch_idx] = frame_to_track return frames @@ -252,7 +258,7 @@ def _run_global_tracker( # E.g.: instances_per_frame: [4, 5, 6, 7]; window of length 4 with 4 detected instances in the first frame of the window. 
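The comments added to `_run_global_tracker` in the hunks that follow describe how the GTR association matrix becomes track assignments: instance-level scores are aggregated into per-track scores through the one-hot `id_inds` matrix, and Hungarian matching with an overlap threshold decides whether each query instance continues an existing track or starts a new one. The snippet below is a minimal, self-contained sketch of that reduce-then-assign step; the shapes, threshold value, and variable names are illustrative assumptions, not the tracker's actual code.

```python
# Sketch only (not the dreem API): reduce instance-level association scores to
# track-level scores, then assign query instances to tracks with the Hungarian
# algorithm. Shapes and the threshold are assumed for illustration.
import torch
from scipy.optimize import linear_sum_assignment

n_query, n_nonquery, n_traj = 4, 12, 3
asso_scores = torch.rand(n_query, n_nonquery)   # per-instance association scores
id_inds = torch.zeros(n_nonquery, n_traj)       # one-hot track membership of reference instances
id_inds[torch.arange(n_nonquery), torch.randint(0, n_traj, (n_nonquery,))] = 1.0

# (n_query x n_nonquery) @ (n_nonquery x n_traj) -> per-track scores
traj_score = asso_scores @ id_inds

# Hungarian matching maximizes total score (scipy minimizes, hence the negation)
match_i, match_j = linear_sum_assignment(-traj_score.numpy())

overlap_thresh = 0.5  # assumed; the tracker scales this by id_inds[:, j].sum() when mult_thresh is on
track_ids = torch.full((n_query,), -1, dtype=torch.long)
next_track_id = n_traj
for i, j in zip(match_i, match_j):
    if traj_score[i, j] > overlap_thresh:
        track_ids[int(i)] = int(j)          # score high enough: continue the existing track
    else:
        track_ids[int(i)] = next_track_id   # score too low: start a new track
        next_track_id += 1
print(track_ids)
```

The full implementation additionally applies decay-time, IOU, and max-center-distance post-processing to `traj_score` before this assignment step, as the later hunks in this file show.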
_ = model.eval() - + # get the last frame in the clip to perform inference on query_frame = frames[query_ind] query_instances = query_frame.instances @@ -279,8 +285,10 @@ def _run_global_tracker( # (L=1, n_query, total_instances) with torch.no_grad(): + # GTR knows this is for inference since query_instances is not None asso_matrix = model(all_instances, query_instances) + # GTR output is n_query x n_instances - split this into per-frame to softmax each frame separately asso_output = asso_matrix[-1].matrix.split( instances_per_frame, dim=1 ) # (window_size, n_query, N_i) @@ -296,7 +304,7 @@ def _run_global_tracker( asso_output_df.index.name = "Instances" asso_output_df.columns.name = "Instances" - + # save the association matrix to the Frame object query_frame.add_traj_score("asso_output", asso_output_df) query_frame.asso_output = asso_matrix[-1] @@ -343,6 +351,8 @@ def _run_global_tracker( query_frame.add_traj_score("asso_nonquery", asso_nonquery_df) + # need frame height and width to scale boxes during post-processing + _, h, w = query_frame.img_shape.flatten() pred_boxes = model_utils.get_boxes(all_instances) query_boxes = pred_boxes[query_inds] # n_k x 4 nonquery_boxes = pred_boxes[nonquery_inds] # n_nonquery x 4 @@ -374,7 +384,7 @@ def _run_global_tracker( query_frame.add_traj_score("decay_time", decay_time_traj_score) ################################################################################ - + # reduce association matrix - aggregating reference instance association scores by tracks # (n_query x n_nonquery) x (n_nonquery x n_traj) --> n_query x n_traj traj_score = torch.mm(traj_score, id_inds.cpu()) # (n_query, n_traj) @@ -387,6 +397,7 @@ def _run_global_tracker( query_frame.add_traj_score("traj_score", traj_score_df) ################################################################################ + # IOU-based post-processing; add a weighted IOU across successive frames to association scores # with iou -> combining with location in tracker, they set to True # todo -> should also work without pos_embed @@ -421,11 +432,12 @@ def _run_global_tracker( query_frame.add_traj_score("weight_iou", iou_traj_score) ################################################################################ + # filters association matrix such that instances too far from each other get scores=0 # threshold for continuing a tracking or starting a new track -> they use 1.0 # todo -> should also work without pos_embed traj_score = post_processing.filter_max_center_dist( - traj_score, self.max_center_dist, query_boxes, nonquery_boxes, id_inds + traj_score, self.max_center_dist, query_boxes, nonquery_boxes, id_inds, h, w ) if self.max_center_dist is not None and self.max_center_dist > 0: @@ -439,6 +451,7 @@ def _run_global_tracker( query_frame.add_traj_score("max_center_dist", max_center_dist_traj_score) ################################################################################ + # softmax along tracks for each instance, for interpretability scaled_traj_score = torch.softmax(traj_score, dim=1) scaled_traj_score_df = pd.DataFrame( scaled_traj_score.numpy(), columns=unique_ids.cpu().numpy() @@ -449,6 +462,7 @@ def _run_global_tracker( query_frame.add_traj_score("scaled", scaled_traj_score_df) ################################################################################ + # hungarian matching match_i, match_j = linear_sum_assignment((-traj_score)) track_ids = instance_ids.new_full((n_query,), -1) @@ -462,6 +476,7 @@ def _run_global_tracker( thresh = ( overlap_thresh * id_inds[:, j].sum() if 
mult_thresh else overlap_thresh ) + # if the association score for a query instance is lower than the threshold, create a new track for it if n_traj >= self.max_tracks or traj_score[i, j] > thresh: logger.debug( f"Assigning instance {i} to track {j} with id {unique_ids[j]}" diff --git a/dreem/io/config.py b/dreem/io/config.py index 4ac8105a..7f7b5477 100644 --- a/dreem/io/config.py +++ b/dreem/io/config.py @@ -226,7 +226,7 @@ def get_dataset( mode: str, label_files: list[str] | None = None, vid_files: list[str | list[str]] = None, - ) -> "SleapDataset" | "MicroscopyDataset" | "CellTrackingDataset": + ) -> "SleapDataset" | "MicroscopyDataset" | "CellTrackingDataset" | None: """Getter for datasets. Args: @@ -301,6 +301,10 @@ def get_dataset( "Could not resolve dataset type from Config! Please include \ either `slp_files` or `tracks`/`source`" ) + if len(dataset) == 0: + logger.warn(f"Length of {mode} dataset is {len(dataset)}! Returning None") + return None + return dataset @property def data_paths(self): @@ -319,9 +323,9 @@ def data_paths(self, paths: tuple[str, list[str]]): def get_dataloader( self, - dataset: "SleapDataset" | "MicroscopyDataset" | "CellTrackingDataset", + dataset: "SleapDataset" | "MicroscopyDataset" | "CellTrackingDataset" | None, mode: str, - ) -> torch.utils.data.DataLoader: + ) -> torch.utils.data.DataLoader | None: """Getter for dataloader. Args: @@ -350,7 +354,7 @@ def get_dataloader( else: pin_memory = False - return torch.utils.data.DataLoader( + dataloader = torch.utils.data.DataLoader( dataset=dataset, batch_size=1, pin_memory=pin_memory, @@ -358,6 +362,13 @@ def get_dataloader( **dataloader_params, ) + if len(dataloader) == 0: + logger.warn( + f"Length of {mode} dataloader is {len(dataloader)}! Returning `None`" + ) + return None + return dataloader + def get_optimizer(self, params: Iterable) -> torch.optim.Optimizer: """Getter for optimizer. @@ -492,7 +503,7 @@ def get_checkpointing(self) -> pl.callbacks.ModelCheckpoint: filename=f"{{epoch}}-{{{metric}}}", **checkpoint_params, ) - checkpointer.CHECKPOINT_NAME_LAST = f"{{epoch}}-best-{{{metric}}}" + checkpointer.CHECKPOINT_NAME_LAST = f"{{epoch}}-final-{{{metric}}}" checkpointers.append(checkpointer) return checkpointers diff --git a/dreem/io/instance.py b/dreem/io/instance.py index 65be3c02..ba97182b 100644 --- a/dreem/io/instance.py +++ b/dreem/io/instance.py @@ -565,7 +565,11 @@ def add_embedding(self, emb_type: str, embedding: torch.Tensor) -> None: emb_type: Key/embedding type to be saved to dictionary embedding: The actual torch tensor embedding. """ - embedding = _expand_to_rank(embedding, 2) + if ( + type(embedding) != dict + ): # for embedding agg method "average", input is array + # for method stack and concatenate, input is dict + embedding = _expand_to_rank(embedding, 2) self._embeddings[emb_type] = embedding @property diff --git a/dreem/models/attention_head.py b/dreem/models/attention_head.py index 2b160552..701dc6f6 100644 --- a/dreem/models/attention_head.py +++ b/dreem/models/attention_head.py @@ -9,23 +9,40 @@ class ATTWeightHead(torch.nn.Module): """Single attention head.""" - def __init__( - self, - feature_dim: int, - num_layers: int, - dropout: float, - ): + def __init__(self, feature_dim: int, num_layers: int, dropout: float, **kwargs): """Initialize an instance of ATTWeightHead. Args: feature_dim: The dimensionality of input features. num_layers: The number of hidden layers in the MLP. dropout: Dropout probability. 
+ embedding_agg_method: how the embeddings are aggregated; average/stack/concatenate """ super().__init__() + if "embedding_agg_method" in kwargs: + self.embedding_agg_method = kwargs["embedding_agg_method"] + else: + self.embedding_agg_method = None - self.q_proj = MLP(feature_dim, feature_dim, feature_dim, num_layers, dropout) - self.k_proj = MLP(feature_dim, feature_dim, feature_dim, num_layers, dropout) + # if using stacked embeddings, use 1x1 conv with x,y,t embeddings as channels + # ensures output represents ref instances by query instances + if self.embedding_agg_method == "stack": + self.q_proj = torch.nn.Conv1d( + in_channels=3, out_channels=1, kernel_size=1, stride=1, padding=0 + ) + self.k_proj = torch.nn.Conv1d( + in_channels=3, out_channels=1, kernel_size=1, stride=1, padding=0 + ) + self.attn_x = torch.nn.MultiheadAttention(feature_dim, 1) + self.attn_y = torch.nn.MultiheadAttention(feature_dim, 1) + self.attn_t = torch.nn.MultiheadAttention(feature_dim, 1) + else: + self.q_proj = MLP( + feature_dim, feature_dim, feature_dim, num_layers, dropout + ) + self.k_proj = MLP( + feature_dim, feature_dim, feature_dim, num_layers, dropout + ) def forward( self, @@ -41,8 +58,45 @@ def forward( Returns: Output tensor of shape (batch_size, num_frame_instances, num_window_instances). """ - k = self.k_proj(key) - q = self.q_proj(query) - attn_weights = torch.bmm(q, k.transpose(1, 2)) + batch_size, num_query_instances, feature_dim = query.size() + num_window_instances = key.shape[1] + + # if stacked embeddings, create channels for each x,y,t embedding dimension + # maps shape (1,num_instances*3,feature_dim) -> (num_instances,3,feature_dim) + if self.embedding_agg_method == "stack": + key_stacked = ( + key + .view(batch_size, 3, num_window_instances // 3, feature_dim) + .permute(0, 2, 1, 3) + .squeeze(0) # keep as (num_instances*3, feature_dim) + ) + key_orig = key.squeeze(0) # keep as (num_instances*3, feature_dim) + + query = ( + query.view(batch_size, 3, num_query_instances // 3, feature_dim) + .permute(0, 2, 1, 3) + .squeeze(0) + ) + # pass t,x,y frame features through cross attention with entire encoder 3*num_window_instances tokens before MLP; + # note order is t,x,y + out_t, _ = self.attn_t(query=query[:,0,:], key=key_orig, value=key_orig) + out_x, _ = self.attn_x(query=query[:,1,:], key=key_orig, value=key_orig) + out_y, _ = self.attn_y(query=query[:,2,:], key=key_orig, value=key_orig) + # combine each attention output to (num_instances, 3, feature_dim) + collated = torch.stack((out_t, out_x, out_y), dim=0).permute(1,0,2) + # mlp_out has shape (1, num_window_instances, feature_dim) + mlp_out = self.q_proj(collated).transpose(1,0) + + # key, query of shape (num_instances, 3, feature_dim) + # TODO: uncomment this if not using modified attention heads for t,x,y + k = self.k_proj(key_stacked).transpose(1, 0) + # q = self.q_proj(query).transpose(1, 0) + # k,q of shape (num_instances, feature_dim) + attn_weights = torch.bmm(mlp_out, k.transpose(1, 2)) + else: + k = self.k_proj(key) + q = self.q_proj(query) + attn_weights = torch.bmm(q, k.transpose(1, 2)) + return attn_weights # (B, N_t, N) diff --git a/dreem/models/embedding.py b/dreem/models/embedding.py index 8dd44577..bbae71d4 100644 --- a/dreem/models/embedding.py +++ b/dreem/models/embedding.py @@ -3,15 +3,128 @@ import math import torch import logging +from torch import nn, Tensor +from typing import Optional from dreem.models.mlp import MLP logger = logging.getLogger("dreem.models") # todo: add named tensors, clean variable names 
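The hunks that follow add a `RotaryPositionalEmbeddings` module (closely following the reference implementation its docstring cites) whose forward pass returns a cos/sin rotation cache rather than rotated features; the rotation itself is applied to the queries later in this diff by `Embedding._apply_rope`. The snippet below is a rough sketch of that pairwise rotation under assumed shapes and names; it is an illustration, not the code added here.

```python
# Sketch of applying a RoPE cos/sin cache to query features, mirroring the
# pairwise rotation used by _apply_rope later in this diff. Assumed shapes:
# x is (batch, n_query, embed_dim); cache is (batch, n_query, 1, embed_dim // 2, 2).
import torch

def apply_rope_sketch(x: torch.Tensor, cache: torch.Tensor) -> torch.Tensor:
    # group the feature dimension into (embed_dim // 2) pairs: (..., 1, embed_dim // 2, 2)
    x_pairs = x.float().reshape(*x.shape[:-1], 1, -1, 2)
    cos, sin = cache[..., 0], cache[..., 1]
    x0, x1 = x_pairs[..., 0], x_pairs[..., 1]
    # rotate each pair by its position-dependent angle
    rotated = torch.stack((x0 * cos - x1 * sin, x1 * cos + x0 * sin), dim=-1)
    # flatten the pairs (and the singleton head dim) back to (batch, n_query, embed_dim)
    return rotated.flatten(-3)

batch, n_query, embed_dim = 1, 6, 8
x = torch.rand(batch, n_query, embed_dim)
angles = torch.rand(batch, n_query, 1, embed_dim // 2)  # assumed per-position rotation angles
cache = torch.stack((torch.cos(angles), torch.sin(angles)), dim=-1)
out = apply_rope_sketch(x, cache)
assert out.shape == x.shape
```

As the module's comments note, the singleton head dimension exists only for compatibility with the torch RoPE formulation, so it can be flattened back out without losing information.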
+class RotaryPositionalEmbeddings(nn.Module): + """ + This class implements Rotary Positional Embeddings (RoPE) + proposed in https://arxiv.org/abs/2104.09864. + + Reference implementation (used for correctness verfication) + can be found here: + https://github.com/meta-llama/llama/blob/main/llama/model.py#L80 + + In this implementation we cache the embeddings for each position upto + ``max_seq_len`` by computing this during init. + + Args: + dim (int): Embedding dimension. This is usually set to the dim of each + head in the attention module computed as ````embed_dim`` // ``num_heads```` + max_seq_len (int): Maximum expected sequence length for the + model, if exceeded the cached freqs will be recomputed + base (int): The base for the geometric progression used to compute + the rotation angles + """ + + def __init__( + self, + dim: int, + # max_seq_len: int, + base: int = 10000, + ) -> None: + super().__init__() + self.dim = dim + self.base = base + # self.max_seq_len = max_seq_len + self._rope_init() + + # We need to explicitly define reset_parameters for FSDP initialization, see + # https://github.com/pytorch/pytorch/blob/797d4fbdf423dd9320ebe383fb57ffb1135c4a99/torch/distributed/fsdp/_init_utils.py#L885 + def reset_parameters(self): + self._rope_init() + + def _rope_init(self): + theta = 1.0 / ( + self.base + ** (torch.arange(0, self.dim, 2)[: (self.dim // 2)].float() / self.dim) + ) + self.register_buffer("theta", theta, persistent=False) + + def build_rope_cache(self, max_seq_len: int) -> None: + # Create position indexes `[0, 1, ..., max_seq_len - 1]` + + seq_idx = torch.arange( + max_seq_len, dtype=self.theta.dtype, device=self.theta.device + ) + + # Outer product of theta and position index; output tensor has + # a shape of [max_seq_len, dim // 2] + idx_theta = torch.einsum("i, j -> ij", seq_idx, self.theta).float() + + # cache includes both the cos and sin components and so the output shape is + # [max_seq_len, dim // 2, 2] + cache = torch.stack([torch.cos(idx_theta), torch.sin(idx_theta)], dim=-1) + self.register_buffer("cache", cache, persistent=False) + + def forward(self, x: Tensor, input_pos: Optional[Tensor] = None) -> Tensor: + """ + Args: + x (Tensor): input tensor with shape + [b, s, n_h, h_d] + input_pos (Optional[Tensor]): Optional tensor which contains the position ids + of each token. During training, this is used to indicate the positions + of each token relative to its sample when packed, shape [b, s]. + During inference, this indicates the position of the current token. + If none, assume the index of the token is its position id. Default is None. + + Returns: + Tensor: output tensor with RoPE applied + + Notation used for tensor shapes: + - b: batch size + - s: sequence length + - n_h: num heads + - h_d: head dim + + """ + # input tensor has shape [b, s, n_h, h_d] + seq_len = x.size(1) + + # create the lookup array based on how many instances there are + # max(101, seq_len) is for positional vs temporal; pos can only have idx up to + # 100 since it's a fraction of [0,1]*100. 
temp is from [0, clip_len]; since clip_len + # not available, we use the last value in the indexing array since this will be the + # last possible frame that we would need to index since no instances in a frame after that + if input_pos.dim() <= 1: input_pos = input_pos.unsqueeze(0) + self.build_rope_cache(max(101, input_pos[:, -1].max() + 1)) # registers cache + self.cache = self.cache.to(input_pos.device) + # extract the values based on whether input_pos is set or not + rope_cache = ( + self.cache[:seq_len] if input_pos is None else self.cache[input_pos] + ) + + # reshape input; the last dimension is used for computing the output. + # Cast to float to match the reference implementation + # tensor has shape [b, s, n_h, h_d // 2, 2] + xshaped = x.float().reshape(*x.shape[:-1], -1, 2) + + # reshape the cache for broadcasting + # tensor has shape [b, s, 1, h_d // 2, 2] if packed samples, + # otherwise has shape [1, s, 1, h_d // 2, 2] + rope_cache = rope_cache.view(-1, xshaped.size(1), 1, xshaped.size(3), 2) + + return rope_cache + + class Embedding(torch.nn.Module): """Class that wraps around different embedding types. - + Creates embedding array and transforms the input data Used for both learned and fixed embeddings. """ @@ -24,6 +137,7 @@ class Embedding(torch.nn.Module): EMB_MODES = { "fixed": {"temperature", "scale", "normalize"}, "learned": {"emb_num"}, + "rope": {"embedding_agg_method"}, "off": {}, } # dict of valid args:keyword params @@ -39,6 +153,7 @@ def __init__( normalize: bool = False, scale: float | None = None, mlp_cfg: dict | None = None, + embedding_agg_method: str = "average", ): """Initialize embeddings. @@ -57,12 +172,12 @@ def __init__( mlp_cfg: A dictionary of mlp hyperparameters for projecting embedding to correct space. Example: {"hidden_dims": 256, "num_layers":3, "dropout": 0.3} """ - self._check_init_args(emb_type, mode) super().__init__() self.emb_type = emb_type self.mode = mode + self.embedding_agg_method = embedding_agg_method self.features = features self.emb_num = emb_num self.over_boxes = over_boxes @@ -71,6 +186,8 @@ def __init__( self.scale = scale self.n_points = n_points + self._check_init_args(emb_type, mode) + if self.normalize and self.scale is None: self.scale = 2 * math.pi @@ -91,8 +208,8 @@ def __init__( else: self.mlp = torch.nn.Identity() - self._emb_func = lambda tensor: torch.zeros( - (tensor.shape[0], self.features), dtype=tensor.dtype, device=tensor.device + self._emb_func = lambda seq, x: torch.zeros( + (seq.shape[0], self.features), dtype=seq.dtype, device=seq.device ) # turn off embedding by returning zeros self.lookup = None @@ -109,10 +226,19 @@ def __init__( elif self.mode == "fixed": if self.emb_type == "pos": - self._emb_func = self._sine_box_embedding + if self.embedding_agg_method == "average": + self._emb_func = self._sine_box_embedding + else: # if using stacked/concatenated agg method + self._emb_func = self._sine_pos_embedding elif self.emb_type == "temp": self._emb_func = self._sine_temp_embedding + elif self.mode == "rope": + # pos/temp embeddings processed the same way with different embedding array inputs + self._emb_func = self._rope_embedding + # create instance so embedding lookup array is created only once + self.rope_instance = RotaryPositionalEmbeddings(self.features) + def _check_init_args(self, emb_type: str, mode: str): """Check whether the correct arguments were passed to initialization. 
@@ -136,27 +262,86 @@ def _check_init_args(self, emb_type: str, mode: str): f"Embedding `mode` must be one of {self.EMB_MODES} not {mode}" ) - def forward(self, seq_positions: torch.Tensor) -> torch.Tensor: + if mode.lower() == "rope" and self.embedding_agg_method == "average": + raise ValueError( + f"Cannot use aggregation method 'average' for rope embedding; must use 'stack' or 'concatenate'" + ) + + def _transform(self, x, emb): + """Routes to the relevant embedding function to transform the input queries + + Args: + x: Input queries of shape (batch_size, N, embed_dim) + emb: Embedding array to apply to data; can be (N, embed_dim) or + (batch_size, n_query, num_heads, embed_dim // 2, 2) if using RoPE + """ + if self._emb_func == self._rope_embedding: + return self._apply_rope(x, emb) + else: + return self._apply_additive_embeddings(x, emb) + + def _apply_rope(self, x, emb): + """Applies Rotary Positional Embedding to input queries + + Args: + x: Input queries of shape (batch_size, n_query, embed_dim) + emb: Rotation matrix of shape (batch_size, n_query, num_heads, embed_dim // 2, 2) + + Returns: + Tensor of input queries transformed by RoPE + """ + + xout = torch.unsqueeze(x, 2) + # input needs shape [batch_size, n_query, num_heads, embed_dim // 2, 2] + xout = xout.float().reshape(*xout.shape[:-1], -1, 2) + # apply RoPE to each query token + xout = torch.stack( + [ + xout[..., 0] * emb[..., 0] - xout[..., 1] * emb[..., 1], + xout[..., 1] * emb[..., 0] + xout[..., 0] * emb[..., 1], + ], + -1, + ) + # output has shape [batch_size, n_query, num_heads, embed_dim] + xout = xout.flatten(3).squeeze(2) + + return xout + + def _apply_additive_embeddings(self, x, emb): + """Applies additive embeddings to input queries + + Args: + x: Input tensor of shape (batch_size, N, embed_dim) + emb: Embedding array of shape (N, embed_dim) + + Returns: + Tensor: Input queries with embeddings added - shape (batch_size, N, embed_dim) + """ + _emb = emb.unsqueeze(0) + return x + _emb + + def forward(self, x, seq_positions: torch.Tensor) -> torch.Tensor: """Get the sequence positional embeddings. Args: seq_positions: - * An (`N`, 1) tensor where seq_positions[i] represents the temporal position of instance_i in the sequence. - * An (`N`, n_anchors x 4) tensor where seq_positions[i, j, :] represents the [y1, x1, y2, x2] spatial locations of jth point of instance_i in the sequence. + * An (N,) tensor where seq_positions[i] represents the temporal position of instance_i in the sequence. + * An (N, n_anchors x 4) tensor where seq_positions[i, j, :] represents the [y1, x1, y2, x2] spatial locations of jth point of instance_i in the sequence. + x: Input data of shape ((batch_size, N, embed_dim)) Returns: - An `N` x `self.features` tensor representing the corresponding spatial or temporal embedding. + - Tensor: input queries transformed by embedding + - An `N` x `self.features` tensor representing the corresponding spatial or temporal embedding. """ - emb = self._emb_func(seq_positions) - if emb.shape[-1] != self.features: - raise RuntimeError( - ( - f"Output embedding dimension is {emb.shape[-1]} but requested {self.features} dimensions! \n" - f"hint: Try turning the MLP on by passing `mlp_cfg` to the constructor to project to the correct embedding dimensions." 
- ) - ) - return emb + # create embedding array; either rotation matrix of shape + # (batch_size, n_query, num_heads, embed_dim // 2, 2), + # or (N, embed_dim) array + emb = self._emb_func(seq_positions, x.size()) + # transform the input data with the embedding + xout = self._transform(x, emb) + + return xout, emb def _torch_int_div( self, tensor1: torch.Tensor, tensor2: torch.Tensor @@ -172,7 +357,61 @@ def _torch_int_div( """ return torch.div(tensor1, tensor2, rounding_mode="floor") - def _sine_box_embedding(self, boxes: torch.Tensor) -> torch.Tensor: + def _rope_embedding( + self, seq_positions: torch.Tensor, input_shape: torch.Size + ) -> torch.Tensor: + """Computes the rotation matrix to apply RoPE to input queries + Args: + seq_positions: Pos array of shape (embed_dim,) used to compute rotational embedding + input_shape: Shape of the input queries; needed for rope + Returns: + Tensor: (N, embed_dim) rotation matrix + """ + # create dummy input of shape (num_batches, num_instances, num_attn_heads, embed_dim) + # use num_heads=1 for compatibility with torch ROPE + x_rope = torch.rand(input_shape).unsqueeze(2) + # infer whether it is a positional or temporal embedding + is_pos_emb = 1 if seq_positions.max() < 1 else 0 + # if it is positional, scale seq_positions since these are fractions + # in [0,1] and we need int indexes for embedding lookup + seq_positions = seq_positions * 100 if is_pos_emb else seq_positions + seq_positions = seq_positions.unsqueeze(0).int() + # RoPE module takes in dimension, num_queries as input to calculate rotation matrix + rot_mat = self.rope_instance(x_rope, seq_positions) + + return rot_mat + + def _sine_pos_embedding(self, centroids: torch.Tensor, *args) -> torch.Tensor: + """Compute fixed sine temporal embeddings per dimension (x,y) + + Args: + centroids: the input centroids for either the x,y dimension represented + by fraction of distance of original image that the instance centroid lies at; + of shape (N,) or (N,1) where N = # of query tokens (i.e. instances) + values between [0,1] + + Returns: + an n_instances x D embedding representing the temporal embedding. + """ + d = self.features + n = self.temperature + + positions = centroids.unsqueeze(1) + temp_lookup = torch.zeros(len(centroids), d, device=centroids.device) + + denominators = torch.pow( + n, 2 * torch.arange(0, d // 2, device=centroids.device) / d + ) # 10000^(2i/d_model), i is the index of embedding + temp_lookup[:, 0::2] = torch.sin( + positions / denominators + ) # sin(pos/10000^(2i/d_model)) + temp_lookup[:, 1::2] = torch.cos( + positions / denominators + ) # cos(pos/10000^(2i/d_model)) + + return temp_lookup # .view(len(times), self.features) + + def _sine_box_embedding(self, boxes: torch.Tensor, *args) -> torch.Tensor: """Compute sine positional embeddings for boxes using given parameters. Args: @@ -217,7 +456,7 @@ def _sine_box_embedding(self, boxes: torch.Tensor) -> torch.Tensor: return pos_emb - def _sine_temp_embedding(self, times: torch.Tensor) -> torch.Tensor: + def _sine_temp_embedding(self, times: torch.Tensor, *args) -> torch.Tensor: """Compute fixed sine temporal embeddings. 
Args: @@ -249,7 +488,7 @@ def _sine_temp_embedding(self, times: torch.Tensor) -> torch.Tensor: temp_emb = temp_lookup[times.int()] return temp_emb # .view(len(times), self.features) - def _learned_pos_embedding(self, boxes: torch.Tensor) -> torch.Tensor: + def _learned_pos_embedding(self, boxes: torch.Tensor, *args) -> torch.Tensor: """Compute learned positional embeddings for boxes using given parameters. Args: @@ -309,7 +548,7 @@ def _learned_pos_embedding(self, boxes: torch.Tensor) -> torch.Tensor: return pos_emb.view(N, self.features) - def _learned_temp_embedding(self, times: torch.Tensor) -> torch.Tensor: + def _learned_temp_embedding(self, times: torch.Tensor, *args) -> torch.Tensor: """Compute learned temporal embeddings for times using given parameters. Args: @@ -323,6 +562,7 @@ def _learned_temp_embedding(self, times: torch.Tensor) -> torch.Tensor: """ temp_lookup = self.lookup N = times.shape[0] + times = times / times.max() left_ind, right_ind, left_weight, right_weight = self._compute_weights(times) @@ -337,7 +577,7 @@ def _learned_temp_embedding(self, times: torch.Tensor) -> torch.Tensor: return temp_emb.view(N, self.features) - def _compute_weights(self, data: torch.Tensor) -> tuple[torch.Tensor, ...]: + def _compute_weights(self, data: torch.Tensor, *args) -> tuple[torch.Tensor, ...]: """Compute left and right learned embedding weights. Args: diff --git a/dreem/models/gtr_runner.py b/dreem/models/gtr_runner.py index e9fd9f84..162099dd 100644 --- a/dreem/models/gtr_runner.py +++ b/dreem/models/gtr_runner.py @@ -300,7 +300,7 @@ def on_test_epoch_end(self): avg_result = results_df[key].mean() results_file.attrs.create(key, avg_result) for i, (metrics, frames) in enumerate(zip(metrics_dict, preds)): - vid_name = frames[0].vid_name.split("/")[-1].split(".")[0] + vid_name = frames[0].vid_name.split("/")[-1] vid_group = results_file.require_group(vid_name) clip_group = vid_group.require_group(f"clip_{i}") for key, val in metrics.items(): @@ -309,11 +309,18 @@ def on_test_epoch_end(self): if metrics.get("num_switches", 0) > 0: _ = frame.to_h5( clip_group, - frame.get_gt_track_ids().cpu().numpy(), + [ + instance.gt_track_id.item() + for instance in frame.instances + ], save={"crop": True, "features": True, "embeddings": True}, ) else: _ = frame.to_h5( - clip_group, frame.get_gt_track_ids().cpu().numpy() + clip_group, + [ + instance.gt_track_id.item() + for instance in frame.instances + ], ) self.test_results = {"metrics": [], "preds": [], "save_path": fname} diff --git a/dreem/models/mlp.py b/dreem/models/mlp.py index 872d7150..c497ab87 100644 --- a/dreem/models/mlp.py +++ b/dreem/models/mlp.py @@ -34,6 +34,7 @@ def __init__( self.layers = torch.nn.ModuleList( [ torch.nn.Linear(n, k) + # list concatenations to ensure layer shape compability for n, k in zip([input_dim] + h, h + [output_dim]) ] ) @@ -54,6 +55,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: Output tensor of shape (batch_size, num_instances, output_dim). 
""" for i, layer in enumerate(self.layers): + layer.to(x.device) x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) if i < self.num_layers - 1 and self.dropout > 0.0: x = self.dropouts[i](x) diff --git a/dreem/models/transformer.py b/dreem/models/transformer.py index 8db00e0e..ce91b356 100644 --- a/dreem/models/transformer.py +++ b/dreem/models/transformer.py @@ -14,11 +14,13 @@ from dreem.io import AssociationMatrix from dreem.models.attention_head import ATTWeightHead from dreem.models import Embedding, FourierPositionalEmbeddings +from dreem.models.mlp import MLP from dreem.models.model_utils import get_boxes, get_times from torch import nn import copy import torch import torch.nn.functional as F +from typing import Dict, Tuple # todo: add named tensors # todo: add flash attention @@ -82,22 +84,32 @@ def __init__( self.temp_emb = Embedding(emb_type="off", mode="off", features=self.d_model) if self.embedding_meta: + self.embedding_agg_method = ( + embedding_meta["embedding_agg_method"] + if "embedding_agg_method" in embedding_meta + else "average" + ) if "pos" in self.embedding_meta: pos_emb_cfg = self.embedding_meta["pos"] if pos_emb_cfg: self.pos_emb = Embedding( - emb_type="pos", features=self.d_model, **pos_emb_cfg - ) + emb_type="pos", + features=self.d_model, + embedding_agg_method=self.embedding_agg_method, + **pos_emb_cfg, + ) # agg method must be the same for pos and temp embeddings if "temp" in self.embedding_meta: temp_emb_cfg = self.embedding_meta["temp"] if temp_emb_cfg: self.temp_emb = Embedding( - emb_type="temp", features=self.d_model, **temp_emb_cfg + emb_type="temp", + features=self.d_model, + embedding_agg_method=self.embedding_agg_method, + **temp_emb_cfg, ) - - self.fourier_embeddings = FourierPositionalEmbeddings( - n_components=8, d_model=d_model - ) + else: + self.embedding_meta = {} + self.embedding_agg_method = None # Transformer Encoder encoder_layer = TransformerEncoderLayer( @@ -140,6 +152,7 @@ def __init__( feature_dim=feature_dim_attn_head, num_layers=num_layers_attn_head, dropout=dropout_attn_head, + **self.embedding_meta, ) self._reset_parameters() @@ -175,7 +188,6 @@ def forward( [instance.features for instance in ref_instances], dim=0 ).unsqueeze(0) - # window_length = len(frames) # instances_per_frame = [frame.num_detected for frame in frames] total_instances = len(ref_instances) embed_dim = self.d_model @@ -184,29 +196,10 @@ def forward( ref_boxes = torch.nan_to_num(ref_boxes, -1.0) ref_times, query_times = get_times(ref_instances, query_instances) - window_length = len(ref_times.unique()) - - ref_temp_emb = self.temp_emb(ref_times) - - ref_pos_emb = self.pos_emb(ref_boxes) - - if self.return_embedding: - for i, instance in enumerate(ref_instances): - instance.add_embedding("pos", ref_pos_emb[i]) - instance.add_embedding("temp", ref_temp_emb[i]) - - ref_emb = (ref_pos_emb + ref_temp_emb) / 2.0 - - ref_emb = ref_emb.view(1, total_instances, embed_dim) - - ref_emb = ref_emb.permute(1, 0, 2) # (total_instances, batch_size, embed_dim) - - batch_size, total_instances = ref_features.shape[:-1] - + batch_size, total_instances, embed_dim = ref_features.shape ref_features = ref_features.permute( 1, 0, 2 ) # (total_instances, batch_size, embed_dim) - encoder_queries = ref_features # apply fourier embeddings if using fourier rope, OR if using descriptor (compact) visual encoder @@ -228,20 +221,34 @@ def forward( self.fourier_norm, ) - encoder_features = self.encoder( - encoder_queries, pos_emb=ref_emb - ) # (total_instances, batch_size, embed_dim) + # 
(encoder_features, ref_pos_emb, ref_temp_emb) \ + encoder_features, pos_emb_traceback, temp_emb_traceback = self.encoder( + encoder_queries, + embedding_map={"pos": self.pos_emb, "temp": self.temp_emb}, + boxes=ref_boxes, + times=ref_times, + embedding_agg_method=self.embedding_agg_method, + ) # (total_instances, batch_size, embed_dim) or + # (3*total_instances,batch_size,embed_dim) if using stacked embeddings + + if self.return_embedding: + for i, instance in enumerate(ref_instances): + if self.embedding_agg_method == "average": + ref_pos_emb = pos_emb_traceback[0][i] # array + else: + ref_pos_emb = { + "x": pos_emb_traceback[0][0][i], + "y": pos_emb_traceback[1][0][i], + } # dict - n_query = total_instances + instance.add_embedding("pos", ref_pos_emb) # can be an array or a dict + instance.add_embedding("temp", temp_emb_traceback) - query_features = ref_features - query_pos_emb = ref_pos_emb - query_temp_emb = ref_temp_emb - query_emb = ref_emb + # -------------- Begin decoder --------------- # + # for inference, query_instances is not None if query_instances is not None: n_query = len(query_instances) - query_features = torch.cat( [instance.features for instance in query_instances], dim=0 ).unsqueeze(0) @@ -250,25 +257,16 @@ def forward( 1, 0, 2 ) # (n_query, batch_size, embed_dim) + # just get boxes, we already have query_times from above query_boxes = get_boxes(query_instances) query_boxes = torch.nan_to_num(query_boxes, -1.0) - query_temp_emb = self.temp_emb(query_times) - - query_pos_emb = self.pos_emb(query_boxes) - - query_emb = (query_pos_emb + query_temp_emb) / 2.0 - query_emb = query_emb.view(1, n_query, embed_dim) - query_emb = query_emb.permute(1, 0, 2) # (n_query, batch_size, embed_dim) - - else: + else: # for training, query_instances is None so just pass in the ref data + n_query = total_instances query_instances = ref_instances + query_features = ref_features + query_boxes = ref_boxes query_times = ref_times - if self.return_embedding: - for i, instance in enumerate(query_instances): - instance.add_embedding("pos", query_pos_emb[i]) - instance.add_embedding("temp", query_temp_emb[i]) - # apply fourier embeddings if using fourier rope, OR if using descriptor (compact) visual encoder if ( self.embedding_meta @@ -288,25 +286,44 @@ def forward( self.fourier_norm, ) - decoder_features = self.decoder( + decoder_features, pos_emb_traceback, temp_emb_traceback = self.decoder( query_features, encoder_features, - ref_pos_emb=ref_emb, - query_pos_emb=query_emb, + embedding_map={"pos": self.pos_emb, "temp": self.temp_emb}, + enc_boxes=ref_boxes, + enc_times=ref_times, + boxes=query_boxes, + times=query_times, + embedding_agg_method=self.embedding_agg_method, ) # (L, n_query, batch_size, embed_dim) + if self.return_embedding: + for i, instance in enumerate(ref_instances): + if self.embedding_agg_method == "average": + ref_pos_emb = pos_emb_traceback[0][i] # array + else: + ref_pos_emb = { + "x": pos_emb_traceback[0][0][i], + "y": pos_emb_traceback[1][0][i], + } # dict + + instance.add_embedding("pos", ref_pos_emb) # can be an array or a dict + instance.add_embedding("temp", temp_emb_traceback) + decoder_features = decoder_features.transpose( 1, 2 - ) # # (L, batch_size, n_query, embed_dim) - encoder_features = encoder_features.permute(1, 0, 2).view( - batch_size, total_instances, embed_dim - ) # (batch_size, total_instances, embed_dim) + ) # # (L, batch_size, n_query, embed_dim) or ((L, batch_size, 3*n_query, embed_dim)) if using stacked embeddings + encoder_features = 
encoder_features.permute(1, 0, 2) + # (batch_size, total_instances, embed_dim) or (batch_size, 3*total_instances, embed_dim) asso_output = [] for frame_features in decoder_features: + # attn_head handles the 3x queries that can come out of the encoder/decoder if using stacked embeddings + # n_query should be the number of instances in the last frame if running inference, + # or number of ref instances for training. total_instances is always the number of reference instances asso_matrix = self.attn_head(frame_features, encoder_features).view( n_query, total_instances - ) + ) # call to view() just removes the batch dimension; output of attn_head is (1,n_query,total_instances) asso_matrix = AssociationMatrix(asso_matrix, ref_instances, query_instances) asso_output.append(asso_matrix) @@ -381,24 +398,16 @@ def __init__( self.activation = _get_activation_fn(activation) - def forward( - self, queries: torch.Tensor, pos_emb: torch.Tensor = None - ) -> torch.Tensor: + def forward(self, queries: torch.Tensor) -> torch.Tensor: """Execute a forward pass of the encoder layer. Args: - queries: Input sequence for encoder (n_query, batch_size, embed_dim). - pos_emb: Position embedding, if provided is added to src + queries: Input sequence for encoder (n_query, batch_size, embed_dim); + data is already transformed with embedding Returns: The output tensor of shape (n_query, batch_size, embed_dim). """ - if pos_emb is None: - pos_emb = torch.zeros_like(queries) - - queries = queries + pos_emb - - # q = k = src attn_features = self.self_attn( query=queries, @@ -465,8 +474,6 @@ def forward( self, decoder_queries: torch.Tensor, encoder_features: torch.Tensor, - ref_pos_emb: torch.Tensor | None = None, - query_pos_emb: torch.Tensor | None = None, ) -> torch.Tensor: """Execute forward pass of decoder layer. @@ -474,19 +481,10 @@ def forward( decoder_queries: Target sequence for decoder to generate (n_query, batch_size, embed_dim). encoder_features: Output from encoder, that decoder uses to attend to relevant parts of input sequence (total_instances, batch_size, embed_dim) - ref_pos_emb: The input positional embedding tensor of shape (n_query, embed_dim). - query_pos_emb: The target positional embedding of shape (n_query, embed_dim) Returns: The output tensor of shape (n_query, batch_size, embed_dim). """ - if query_pos_emb is None: - query_pos_emb = torch.zeros_like(decoder_queries) - if ref_pos_emb is None: - ref_pos_emb = torch.zeros_like(encoder_features) - - decoder_queries = decoder_queries + query_pos_emb - encoder_features = encoder_features + ref_pos_emb if self.decoder_self_attn: self_attn_features = self.self_attn( @@ -495,6 +493,7 @@ def forward( decoder_queries = decoder_queries + self.dropout1(self_attn_features) decoder_queries = self.norm1(decoder_queries) + # cross attention x_attn_features = self.multihead_attn( query=decoder_queries, # (n_query, batch_size, embed_dim) key=encoder_features, # (total_instances, batch_size, embed_dim) @@ -543,22 +542,38 @@ def __init__( self.norm = norm if norm is not None else nn.Identity() def forward( - self, queries: torch.Tensor, pos_emb: torch.Tensor = None - ) -> torch.Tensor: - """Execute a forward pass of encoder layer. + self, + queries: torch.Tensor, + embedding_map: Dict[str, Embedding], + boxes: torch.Tensor, + times: torch.Tensor, + embedding_agg_method: str = None, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Execute a forward pass of encoder layer. 
Computes and applies embeddings before input to EncoderLayer Args: queries: The input tensor of shape (n_query, batch_size, embed_dim). - pos_emb: The positional embedding tensor of shape (n_query, embed_dim). + embedding_map: Dict of Embedding objects defining the pos/temp embeddings to be applied to + the input data before it passes to the EncoderLayer + boxes: Bounding box based embedding ids of shape (n_query, batch_size, 4) + times: + embedding_agg_method: Returns: The output tensor of shape (n_query, batch_size, embed_dim). """ + for layer in self.layers: - queries = layer(queries, pos_emb=pos_emb) + # compute embeddings and apply to the input queries + queries, pos_emb_traceback, temp_emb_traceback = apply_embeddings( + queries, embedding_map, boxes, times, embedding_agg_method + ) + # pass through EncoderLayer + queries = layer(queries) encoder_features = self.norm(queries) - return encoder_features + + return encoder_features, pos_emb_traceback, temp_emb_traceback class TransformerDecoder(nn.Module): @@ -589,8 +604,12 @@ def forward( self, decoder_queries: torch.Tensor, encoder_features: torch.Tensor, - ref_pos_emb: torch.Tensor | None = None, - query_pos_emb: torch.Tensor | None = None, + embedding_map: Dict[str, Embedding], + enc_boxes: torch.Tensor, + enc_times: torch.Tensor, + boxes: torch.Tensor, + times: torch.Tensor, + embedding_agg_method: str = None, ) -> torch.Tensor: """Execute a forward pass of the decoder block. @@ -598,23 +617,33 @@ def forward( decoder_queries: Query sequence for decoder to generate (n_query, batch_size, embed_dim). encoder_features: Output from encoder, that decoder uses to attend to relevant parts of input sequence (total_instances, batch_size, embed_dim) - ref_pos_emb: The input positional embedding tensor of shape (total_instances, batch_size, embed_dim). - query_pos_emb: The query positional embedding of shape (n_query, batch_size, embed_dim) + Returns: The output tensor of shape (L, n_query, batch_size, embed_dim). """ decoder_features = decoder_queries - intermediate = [] - for layer in self.layers: - decoder_features = layer( - decoder_features, + # since the encoder output doesn't change for any number of decoder layer inputs, + # we can process its embedding outside the loop + if embedding_agg_method == "average": + encoder_features, *_ = apply_embeddings( encoder_features, - ref_pos_emb=ref_pos_emb, - query_pos_emb=query_pos_emb, + embedding_map, + enc_boxes, + enc_times, + embedding_agg_method, ) + # TODO: ^ should embeddings really be applied to encoder output again before cross attention? + # switched off for stack and concatenate methods as those further split the tokens. 
Kept for "average" + # for backward compatibility + + for layer in self.layers: + decoder_features, pos_emb_traceback, temp_emb_traceback = apply_embeddings( + decoder_features, embedding_map, boxes, times, embedding_agg_method + ) + decoder_features = layer(decoder_features, encoder_features) if self.return_intermediate: intermediate.append(self.norm(decoder_features)) @@ -622,10 +651,66 @@ def forward( if self.return_intermediate: intermediate.pop() intermediate.append(decoder_features) + return torch.stack(intermediate), pos_emb_traceback, temp_emb_traceback + + return decoder_features.unsqueeze(0), pos_emb_traceback, temp_emb_traceback + + +def apply_embeddings( + queries: torch.Tensor, + embedding_map: Dict[str, Embedding], + boxes: torch.Tensor, + times: torch.Tensor, + embedding_agg_method: str, +): + """Applies embeddings to input queries for various aggregation methods. This function + is called from the transformer encoder and decoder + + Args: + queries: The input tensor of shape (n_query, batch_size, embed_dim). + embedding_map: Dict of Embedding objects defining the pos/temp embeddings to be applied + to the input data + boxes: Bounding box based embedding ids of shape (n_query, n_anchors, 4) + times: Times based embedding ids of shape (n_query,) + embedding_agg_method: method of aggregation of embeddings e.g. stack/concatenate/average + """ + + pos_emb, temp_emb = embedding_map["pos"], embedding_map["temp"] + # queries is of shape (n_query, batch_size, embed_dim); transpose for embeddings + queries = queries.permute( + 1, 0, 2 + ) # queries is shape (batch_size, n_query, embed_dim) + # calculate temporal embeddings and transform queries + queries_t, ref_temp_emb = temp_emb(queries, times) + + if embedding_agg_method is None: + pos_emb_traceback = (torch.zeros_like(queries),) + queries_avg = queries_t = queries_x = queries_y = None + else: + # if avg. of temp and pos, need bounding boxes; bb only used for method "average" + if embedding_agg_method == "average": + _, ref_pos_emb = pos_emb(queries, boxes) + ref_emb = (ref_pos_emb + ref_temp_emb) / 2 + queries_avg = queries + ref_emb + queries_t = queries_x = queries_y = None + pos_emb_traceback = (ref_pos_emb,) + else: + # calculate embedding array for x,y from bb centroids; ref_x, ref_y of shape (n_query,) + ref_x, ref_y = spatial_emb_from_bb(boxes) + # forward pass of Embedding object transforms input queries with embeddings + queries_x, ref_pos_emb_x = pos_emb(queries, ref_x) + queries_y, ref_pos_emb_y = pos_emb(queries, ref_y) + queries_avg = None # pass dummy var in to collate_queries + pos_emb_traceback = (ref_pos_emb_x, ref_pos_emb_y) - return torch.stack(intermediate) + # concatenate or stack the queries (avg. 
method done above since it applies differently) + queries = collate_queries( + (queries_avg, queries_t, queries_x, queries_y, queries), embedding_agg_method + ) + # transpose for input to EncoderLayer to (n_queries, batch_size, embed_dim) + queries = queries.permute(1, 0, 2) - return decoder_features.unsqueeze(0) + return queries, pos_emb_traceback, ref_temp_emb def _get_clones(module: nn.Module, N: int) -> nn.ModuleList: @@ -657,3 +742,62 @@ def _get_activation_fn(activation: str) -> callable: if activation == "glu": return F.glu raise RuntimeError(f"activation should be relu/gelu/glu, not {activation}.") + + +def collate_queries( + queries: Tuple[torch.Tensor], embedding_agg_method: str +) -> torch.Tensor: + """Aggregates queries transformed by embeddings + + Args: + _queries: 5-tuple of queries (already transformed by embeddings) for _, x, y, t, original input + each of shape (batch_size, n_query, embed_dim) + embedding_agg_method: String representing the aggregation method for embeddings + + Returns: + Tensor of aggregated queries of shape; can be concatenated (increased length of tokens), + stacked (increased number of tokens), or averaged (original token number and length) + """ + + queries_avg, queries_t, queries_x, queries_y, orig_queries = queries + + if embedding_agg_method == "average": + collated_queries = queries_avg + elif embedding_agg_method == "stack": + # (t1,t2,t3...),(x1,x2,x3...),(y1,y2,y3...) + # stacked is of shape (batch_size, 3*n_query, embed_dim) + collated_queries = torch.cat((queries_t, queries_x, queries_y), dim=1) + elif embedding_agg_method == "concatenate": + mlp = MLP( + input_dim=queries_t.shape[-1] * 3, + hidden_dim=queries_t.shape[-1] * 2, + output_dim=queries_t.shape[-1], + num_layers=1, + dropout=0.0, + ) + # concatenated is of shape (batch_size, n_query, 3*embed_dim) + collated_queries = torch.cat((queries_t, queries_x, queries_y), dim=2) + # pass through MLP to project into space of (batch_size, n_query, embed_dim) + collated_queries = mlp(collated_queries) + else: + collated_queries = orig_queries + + return collated_queries + + +def spatial_emb_from_bb(bb: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Computes embedding arrays for x,y spatial dimensions using centroids from bounding boxes + + Args: + bb: Bounding boxes of shape (n_query, n_anchors, 4) from which to compute x,y centroids; + each bounding box is [ymin, xmin, ymax, xmax] + + Returns: + A tuple of tensors containing the emebdding array for x,y dimensions, each of shape (n_query,) + """ + # compute avg of xmin,xmax and ymin,ymax + return ( + bb[:, :, [1, 3]].mean(axis=2).squeeze(), + bb[:, :, [0, 2]].mean(axis=2).squeeze(), + ) diff --git a/dreem/training/configs/base.yaml b/dreem/training/configs/base.yaml index af5eb9b4..41bdbe6a 100644 --- a/dreem/training/configs/base.yaml +++ b/dreem/training/configs/base.yaml @@ -29,10 +29,11 @@ model: dropout_attn_head: 0.1 embedding_meta: pos: - mode: "fixed" + mode: "fixed" # supports fixed, learned, rope normalize: true temp: - mode: "fixed" + mode: "fixed" # supports fixed, learned, rope + embedding_agg_method: "stack" # supports stack, average, concatenate return_embedding: False decoder_self_attn: False @@ -97,6 +98,7 @@ dataset: crop_size: 128 chunk: true clip_length: 32 + mode: "train" val_dataset: slp_files: ["../../tests/data/sleap/two_flies.slp"] @@ -105,6 +107,7 @@ dataset: crop_size: 128 chunk: True clip_length: 32 + mode: "val" test_dataset: slp_files: ["../../tests/data/sleap/two_flies.slp"] @@ -113,6 +116,7 @@ 
dataset: crop_size: 128 chunk: True clip_length: 32 + mode: "test" dataloader: train_dataloader: @@ -135,7 +139,7 @@ logging: group: "example" save_dir: './logs' project: "GTR" - log_model: "all" + log_model: null early_stopping: monitor: "val_loss" @@ -152,9 +156,12 @@ checkpointing: save_last: true dirpath: null auto_insert_metric_name: true - every_n_epochs: 10 + every_n_epochs: 1 trainer: + # only use this for local apple silicon runs; change for cluster runs + # accelerator: "mps" + # devices: 1 check_val_every_n_epoch: 1 enable_checkpointing: true gradient_clip_val: null @@ -162,8 +169,8 @@ trainer: limit_test_batches: 1.0 limit_val_batches: 1.0 log_every_n_steps: 1 - max_epochs: 100 - min_epochs: 10 + max_epochs: 1 + min_epochs: 1 view_batch: enable: False diff --git a/dreem/training/configs/override.yaml b/dreem/training/configs/override.yaml new file mode 100644 index 00000000..6d1ccb48 --- /dev/null +++ b/dreem/training/configs/override.yaml @@ -0,0 +1,142 @@ +model: + ckpt_path: null + encoder_cfg: + model_name: "resnet18" + in_chans: 3 + backend: "torchvision" + pretrained: false + d_model: 128 + nhead: 1 + num_encoder_layers: 1 + num_decoder_layers: 1 + dropout: 0.1 + activation: "relu" + return_intermediate_dec: True + norm: False + num_layers_attn_head: 1 + dropout_attn_head: 0.1 + embedding_meta: + pos: + mode: "fixed" + normalize: true + n_points: 1 + temp: + mode: "fixed" + return_embedding: False + decoder_self_attn: True + +loss: + epsilon: 0.0001 + asso_weight: 10.0 + +optimizer: + lr: 0.0001 + weight_decay: 0 + +scheduler: + factor: 0.5 + patience: 5 + threshold: 0.001 + +dataset: + train_dataset: + dir: + # note: if using batch runner, use format: /home/runner/talmodata-smb/... + # if using interactive, use format: "/home/jovyan/talmolab-smb/datasets/..." + path: "/home/runner/talmolab-smb/datasets/mot/animal/sleap/btc/large_run/als/train" + labels_suffix: ".slp" + vid_suffix: ".mp4" + clip_length: 32 + crop_size: 64 + padding: 0 + anchors: "centroid" + augmentations: + Rotate: + limit: 45 + p: 0.3 + GaussianBlur: + blur_limit: [3,7] + sigma_limit: 0 + p: 0.3 + RandomBrightnessContrast: + brightness_limit: 0.1 + contrast_limit: 0.3 + p: 0.3 + MotionBlur: + blur_limit: [3,7] + p: 0.3 + NodeDropout: + p: 0.3 + n: 5 + InstanceDropout: + p: 0.3 + n: 1 + n_chunks: 1000 + handle_missing: "centroid" + + val_dataset: + dir: + # note: if using batch runner, use format: /home/runner/talmodata-smb/... 
+ path: "/home/runner/talmolab-smb/datasets/mot/animal/sleap/btc/large_run/als/val" + labels_suffix: ".slp" + vid_suffix: ".mp4" + crop_size: 64 + padding: 0 + anchors: "centroid" + n_chunks: 300 + handle_missing: "centroid" + + # to not run test, just use empty lists to override the paths in the base.yaml + test_dataset: + slp_files: [] + video_files: [] + +dataloader: + train_dataloader: + num_workers: 0 + val_dataloader: + num_workers: 0 + test_dataloader: + num_workers: 0 + +checkpointing: + save_top_k: -1 + +trainer: + max_epochs: 50 + min_epochs: -1 + # limit_train_batches: 0.001 + # limit_test_batches: 1.0 + # limit_val_batches: 0.004 + # profiler: "advanced" + + +logging: + project: "dreem" + group: "test-batch-job" # experiment/test + entity: "mushaikh" + name: "sample-efficiency" # name of the run (within a group) + notes: "test `dreem-train" + logger_type: "WandbLogger" + +tracker: + window_size: 8 + use_vis_feats: true + overlap_thresh: 0.1 + mult_thresh: true + decay_time: null + iou: null + max_center_dist: null + +runner: + persistent_tracking: + train: false + val: false + test: false + metrics: + train: [] + +# view_batch: +# enable: True +# num_frames: 5 +# no_train: True \ No newline at end of file diff --git a/dreem/training/configs/test_batch_train.csv b/dreem/training/configs/test_batch_train.csv deleted file mode 100644 index a0303c7b..00000000 --- a/dreem/training/configs/test_batch_train.csv +++ /dev/null @@ -1,4 +0,0 @@ -model.d_model,model.dim_feedforward,model.feature_dim_attn_head,model.num_encoder_layers,model.num_decoder_layers -256,256,256,1,1 -512,512,512,2,2 -1024,1024,1024,4,4 diff --git a/dreem/training/train.py b/dreem/training/train.py index b6fbd3d4..dc05a294 100644 --- a/dreem/training/train.py +++ b/dreem/training/train.py @@ -54,6 +54,7 @@ def run(cfg: DictConfig): logger.info(f"Final train config: {train_cfg}") model = train_cfg.get_model() + train_dataset = train_cfg.get_dataset(mode="train") train_dataloader = train_cfg.get_dataloader(train_dataset, mode="train") @@ -96,6 +97,7 @@ def run(cfg: DictConfig): callbacks.append(early_stopping) accelerator = "gpu" if torch.cuda.is_available() else "cpu" + devices = torch.cuda.device_count() if torch.cuda.is_available() else cpu_count() trainer = train_cfg.get_trainer( diff --git a/scripts/run_eval.py b/scripts/run_eval.py new file mode 100644 index 00000000..a433852b --- /dev/null +++ b/scripts/run_eval.py @@ -0,0 +1,12 @@ +from dreem.training import train +from omegaconf import OmegaConf + +# /Users/mustafashaikh/dreem/dreem/training +# /Users/main/Documents/GitHub/dreem/dreem/training + + +inference_config = "tests/configs/inference.yaml" + +cfg = OmegaConf.load(inference_config) + +eval.run(cfg) \ No newline at end of file diff --git a/scripts/run_tracker.py b/scripts/run_tracker.py new file mode 100644 index 00000000..f4aefe94 --- /dev/null +++ b/scripts/run_tracker.py @@ -0,0 +1,12 @@ +from dreem.inference import track +from omegaconf import OmegaConf +import os + +# /Users/mustafashaikh/dreem/dreem/training +# /Users/main/Documents/GitHub/dreem/dreem/training +# os.chdir("/Users/main/Documents/GitHub/dreem/dreem/training") +config = "/root/vast/mustafa/dreem-experiments/run/lysosome-baselines/debug/configs/inference.yaml" + +cfg = OmegaConf.load(config) + +track.run(cfg) \ No newline at end of file diff --git a/scripts/run_trainer.py b/scripts/run_trainer.py new file mode 100644 index 00000000..f6829e32 --- /dev/null +++ b/scripts/run_trainer.py @@ -0,0 +1,16 @@ +from dreem.training import train 
+from omegaconf import OmegaConf +import os + +# /Users/mustafashaikh/dreem/dreem/training +# /Users/main/Documents/GitHub/dreem/dreem/training +# os.chdir("/Users/main/Documents/GitHub/dreem/dreem/training") +base_config = "/root/vast/mustafa/dreem-experiments/run/lysosome-baselines/debug/configs/base-updated.yaml" +params_config = "/root/vast/mustafa/dreem-experiments/run/lysosome-baselines/debug/configs/override-updated.yaml" + +cfg = OmegaConf.load(base_config) +# Load and merge override config +override_cfg = OmegaConf.load(params_config) +cfg = OmegaConf.merge(cfg, override_cfg) + +train.run(cfg) \ No newline at end of file diff --git a/tests/test_inference.py b/tests/test_inference.py index 2b55484e..7f580ddb 100644 --- a/tests/test_inference.py +++ b/tests/test_inference.py @@ -215,6 +215,8 @@ def test_post_processing(): # set_default_device k_boxes=k_boxes, nonk_boxes=nonk_boxes, id_inds=id_inds, + h=im_size, + w=im_size ) ).all() @@ -226,6 +228,8 @@ def test_post_processing(): # set_default_device k_boxes=k_boxes, nonk_boxes=nonk_boxes, id_inds=id_inds, + h=im_size, + w=im_size ) ).all() diff --git a/tests/test_models.py b/tests/test_models.py index 3eaf9c22..bdf17f0d 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -14,6 +14,8 @@ from dreem.models.transformer import ( TransformerEncoderLayer, TransformerDecoderLayer, + spatial_emb_from_bb, + apply_embeddings, ) @@ -33,7 +35,9 @@ def test_att_weight_head(): """Test self-attention head logic.""" b, n, f = 1, 10, 1024 # batch size, num instances, features - att_weight_head = ATTWeightHead(feature_dim=f, num_layers=2, dropout=0.1) + att_weight_head = ATTWeightHead( + feature_dim=f, num_layers=2, dropout=0.1, embedding_agg_method="average" + ) q = k = torch.rand(size=(b, n, f)) @@ -161,10 +165,87 @@ def test_embedding_validity(): with pytest.raises(Exception): _ = Embedding(emb_type="temporal", mode="learn", features=128) + with pytest.raises(Exception): + # embedding_agg_method cannot be average for rope + _ = Embedding( + emb_type="pos", mode="rope", features=128, embedding_agg_method="average" + ) + _ = Embedding( + emb_type="pos", mode="rope", features=128, embedding_agg_method="stacked" + ) + + _ = Embedding( + emb_type="pos", mode="rope", features=128, embedding_agg_method="stack" + ) + _ = Embedding( + emb_type="pos", mode="rope", features=128, embedding_agg_method="concatenate" + ) + + _ = Embedding( + emb_type="pos", mode="fixed", features=128, embedding_agg_method="average" + ) + _ = Embedding( + emb_type="pos", mode="fixed", features=128, embedding_agg_method="stack" + ) + _ = Embedding( + emb_type="pos", mode="fixed", features=128, embedding_agg_method="concatenate" + ) + + _ = Embedding( + emb_type="pos", mode="learned", features=128, embedding_agg_method="average" + ) + _ = Embedding( + emb_type="pos", mode="learned", features=128, embedding_agg_method="stack" + ) + _ = Embedding( + emb_type="pos", mode="learned", features=128, embedding_agg_method="concatenate" + ) + _ = Embedding(emb_type="temp", mode="learned", features=128) _ = Embedding(emb_type="pos", mode="learned", features=128) - _ = Embedding(emb_type="pos", mode="learned", features=128) + +def test_rope_embedding(): + "Test RoPE embedding" + frames = 32 + objects = 10 + d_model = 256 + n_anchors = 1 + + N = frames * objects + + boxes = torch.rand(size=(N, n_anchors, 4)) + times = torch.rand(size=(N,)) + # input data of shape (batch_size, N, num_heads, embed_dim) + x = torch.rand(size=(1, N, d_model)) + + pos_emb = Embedding( + emb_type="pos", 
mode="rope", features=d_model, embedding_agg_method="stack" + ) + temp_emb = Embedding( + emb_type="temp", mode="rope", features=d_model, embedding_agg_method="stack" + ) + + ref_x, ref_y = spatial_emb_from_bb(boxes) + x_rope, rot_mat_x = pos_emb(x, ref_x) + y_rope, rot_mat_y = pos_emb(x, ref_y) + t_rope, ref_temp_emb = temp_emb(x, times) + + assert x_rope.size() == (1, N, d_model) + assert y_rope.size() == (1, N, d_model) + assert t_rope.size() == (1, N, d_model) + + assert not torch.equal(x, x_rope) + assert not torch.equal(x, y_rope) + assert not torch.equal(x, t_rope) + + assert not torch.equal(x_rope, y_rope) + assert not torch.equal(x_rope, t_rope) + assert not torch.equal(y_rope, t_rope) + + assert ref_x.size() == ref_y.size() + assert x_rope.size() == x.size() + assert y_rope.size() == x.size() def test_embedding_basic(): @@ -179,6 +260,8 @@ def test_embedding_basic(): boxes = torch.rand(size=(N, n_anchors, 4)) times = torch.rand(size=(N,)) + # input data of shape (batch_size, N, embed_dim) + x = torch.rand(size=(1, N, d_model)) pos_emb = Embedding( emb_type="pos", @@ -189,31 +272,31 @@ def test_embedding_basic(): scale=10, ) - sine_pos_emb = pos_emb(boxes) + _, sine_pos_emb = pos_emb(x, boxes) pos_emb = Embedding(emb_type="pos", mode="learned", features=d_model, emb_num=100) - learned_pos_emb = pos_emb(boxes) + _, learned_pos_emb = pos_emb(x, boxes) temp_emb = Embedding(emb_type="temp", mode="learned", features=d_model, emb_num=16) - learned_temp_emb = temp_emb(times) + _, learned_temp_emb = temp_emb(x, times) pos_emb_off = Embedding(emb_type="pos", mode="off", features=d_model) - off_pos_emb = pos_emb_off(boxes) + _, off_pos_emb = pos_emb_off(x, boxes) temp_emb_off = Embedding(emb_type="temp", mode="off", features=d_model) - off_temp_emb = temp_emb_off(times) + _, off_temp_emb = temp_emb_off(x, times) learned_emb_off = Embedding(emb_type="off", mode="learned", features=d_model) - off_learned_emb_boxes = learned_emb_off(boxes) - off_learned_emb_times = learned_emb_off(times) + _, off_learned_emb_boxes = learned_emb_off(x, boxes) + _, off_learned_emb_times = learned_emb_off(x, times) fixed_emb_off = Embedding(emb_type="off", mode="fixed", features=d_model) - off_fixed_emb_boxes = fixed_emb_off(boxes) - off_fixed_emb_times = fixed_emb_off(times) + _, off_fixed_emb_boxes = fixed_emb_off(x, boxes) + _, off_fixed_emb_times = fixed_emb_off(x, times) off_emb = Embedding(emb_type="off", mode="off", features=d_model) - off_emb_boxes = off_emb(boxes) - off_emb_times = off_emb(times) + _, off_emb_boxes = off_emb(x, boxes) + _, off_emb_times = off_emb(x, times) assert sine_pos_emb.size() == (N, d_model) assert learned_pos_emb.size() == (N, d_model) @@ -247,12 +330,14 @@ def test_embedding_kwargs(): frames = 32 objects = 10 + d_model = 128 N = frames * objects n_anchors = 1 boxes = torch.rand(N, n_anchors, 4) - + # input data of shape (batch_size, N, embed_dim) + x = torch.rand(size=(1, N, d_model)) # sine embedding sine_args = { @@ -260,32 +345,32 @@ def test_embedding_kwargs(): "scale": frames, "normalize": True, } - sine_no_args = Embedding("pos", "fixed", 128) - sine_with_args = Embedding("pos", "fixed", 128, **sine_args) + sine_no_args = Embedding("pos", "fixed", d_model) + sine_with_args = Embedding("pos", "fixed", d_model, **sine_args) assert sine_no_args.temperature != sine_with_args.temperature - sine_no_args = sine_no_args(boxes) - sine_with_args = sine_with_args(boxes) + _, sine_no_args = sine_no_args(x, boxes) + _, sine_with_args = sine_with_args(x, boxes) assert not 
torch.equal(sine_no_args, sine_with_args) # learned pos embedding - lp_no_args = Embedding("pos", "learned", 128) + lp_no_args = Embedding("pos", "learned", d_model) lp_args = {"emb_num": 100, "over_boxes": False} - lp_with_args = Embedding("pos", "learned", 128, **lp_args) + lp_with_args = Embedding("pos", "learned", d_model, **lp_args) assert lp_no_args.lookup.weight.shape != lp_with_args.lookup.weight.shape # learned temp embedding - lt_no_args = Embedding("temp", "learned", 128) + lt_no_args = Embedding("temp", "learned", d_model) lt_args = {"emb_num": 100} - lt_with_args = Embedding("temp", "learned", 128, **lt_args) + lt_with_args = Embedding("temp", "learned", d_model, **lt_args) assert lt_no_args.lookup.weight.shape != lt_with_args.lookup.weight.shape @@ -299,6 +384,8 @@ def test_multianchor_embedding(): N = frames * objects boxes = torch.rand(size=(N, n_anchors, 4)) + # input data of shape (batch_size, N, embed_dim) + x = torch.rand(size=(1, N, d_model)) fixed_emb = Embedding( "pos", @@ -317,18 +404,18 @@ def test_multianchor_embedding(): assert not isinstance(fixed_emb.mlp, torch.nn.Identity) assert not isinstance(learned_emb.mlp, torch.nn.Identity) - emb = fixed_emb(boxes) + _, emb = fixed_emb(x, boxes) assert emb.size() == (N, features) - emb = learned_emb(boxes) + _, emb = learned_emb(x, boxes) assert emb.size() == (N, features) fixed_emb = Embedding("pos", "fixed", features=features) learned_emb = Embedding("pos", "learned", features=features) with pytest.raises(RuntimeError): - _ = fixed_emb(boxes) + _, _ = fixed_emb(x, boxes) with pytest.raises(RuntimeError): - _ = learned_emb(boxes) + _, _ = learned_emb(x, boxes) def test_transformer_encoder(): @@ -351,7 +438,7 @@ def test_transformer_encoder(): # with position pos_emb = torch.ones_like(queries) - encoder_features = transformer_encoder(queries, pos_emb=pos_emb) + encoder_features = transformer_encoder(queries) assert encoder_features.size() == encoder_features.size() @@ -381,12 +468,7 @@ def test_transformer_decoder(): # with position pos_emb = query_pos_emb = torch.ones_like(encoder_features) - decoder_features = transformer_decoder( - decoder_queries, - encoder_features, - ref_pos_emb=pos_emb, - query_pos_emb=query_pos_emb, - ) + decoder_features = transformer_decoder(decoder_queries, encoder_features) assert decoder_features.size() == decoder_queries.size() @@ -397,8 +479,13 @@ def test_transformer_basic(): num_frames = 32 num_detected = 10 img_shape = (1, 100, 100) - - transformer = Transformer(d_model=feats, num_encoder_layers=1, num_decoder_layers=1) + embedding_meta = {"embedding_agg_method": "stack"} + transformer = Transformer( + d_model=feats, + num_encoder_layers=1, + num_decoder_layers=1, + embedding_meta=embedding_meta, + ) frames = [] @@ -444,6 +531,7 @@ def test_transformer_embedding(): embedding_meta = { "pos": {"mode": "learned", "emb_num": 16, "normalize": True}, "temp": {"mode": "learned", "emb_num": 16, "normalize": True}, + "embedding_agg_method": "average", } transformer = Transformer( diff --git a/tests/test_training.py b/tests/test_training.py index bd8bbe75..57295100 100644 --- a/tests/test_training.py +++ b/tests/test_training.py @@ -142,5 +142,5 @@ def test_config_gtr_runner(tmp_path, base_config, params_config, two_flies): } cfg.set_hparams(hparams) - with torch.autograd.set_detect_anomaly(True): - run(cfg.cfg) + # with torch.autograd.set_detect_anomaly(True): + run(cfg.cfg)
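Note on the new runner scripts: as committed, scripts/run_eval.py imports train from dreem.training but then calls eval.run(cfg), so the script fails with a NameError at runtime. Below is a minimal corrected sketch; the dreem.inference.eval import location is an assumption made by analogy with run_tracker.py's `from dreem.inference import track`, and the repo-relative config path is the one used above.

    from omegaconf import OmegaConf

    # assumed module path, mirroring run_tracker.py's import of `track`
    from dreem.inference import eval

    # repo-relative inference config used by the test suite
    inference_config = "tests/configs/inference.yaml"

    cfg = OmegaConf.load(inference_config)

    eval.run(cfg)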
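The test changes above also exercise the updated Embedding call signature: forward now takes the input features together with the boxes/times and returns a (transformed_input, embedding) tuple, and RoPE mode rejects embedding_agg_method="average" while returning rotation matrices alongside the rotated features. The sketch below mirrors those tests; the Embedding import path is an assumption (only spatial_emb_from_bb's location in dreem.models.transformer is confirmed by the patch), so treat module paths as placeholders.

    import torch

    from dreem.models import Embedding  # import path assumed
    from dreem.models.transformer import spatial_emb_from_bb

    frames, objects, d_model, n_anchors = 32, 10, 256, 1
    N = frames * objects

    x = torch.rand(1, N, d_model)        # (batch, n_instances, features)
    boxes = torch.rand(N, n_anchors, 4)  # normalized bounding boxes
    times = torch.rand(N)                # normalized frame times

    # fixed sinusoidal positional embedding: the embedding is returned second
    pos_emb = Embedding(emb_type="pos", mode="fixed", features=d_model)
    _, sine_emb = pos_emb(x, boxes)      # sine_emb: (N, d_model)

    # RoPE: needs a non-"average" aggregation method (e.g. "stack") and
    # returns the rotated features plus the rotation matrices
    rope_emb = Embedding(
        emb_type="pos", mode="rope", features=d_model, embedding_agg_method="stack"
    )
    ref_x, ref_y = spatial_emb_from_bb(boxes)
    x_rope, rot_mat_x = rope_emb(x, ref_x)  # x_rope: (1, N, d_model)
    y_rope, rot_mat_y = rope_emb(x, ref_y)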