RoPE embeddings #78
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -126,6 +126,8 @@ def filter_max_center_dist( | |
k_boxes: torch.Tensor | None = None, | ||
nonk_boxes: torch.Tensor | None = None, | ||
id_inds: torch.Tensor | None = None, | ||
h: int = None, | ||
w: int = None | ||
) -> torch.Tensor: | ||
"""Filter trajectory score by distances between objects across frames. | ||
|
||
|
@@ -135,6 +137,8 @@ def filter_max_center_dist( | |
k_boxes: The bounding boxes in the current frame | ||
nonk_boxes: the boxes not in the current frame | ||
id_inds: track ids | ||
h: height of image | ||
w: width of image | ||
|
||
Returns: | ||
An N_t x N association matrix | ||
|
@@ -147,13 +151,15 @@ def filter_max_center_dist( | |
k_s = ((k_boxes[:, :, 2:] - k_boxes[:, :, :2]) ** 2).sum(dim=2) # n_k | ||
|
||
nonk_ct = (nonk_boxes[:, :, :2] + nonk_boxes[:, :, 2:]) / 2 | ||
|
||
# TODO: nonk_boxes should be only from previous frame rather than entire window | ||
Review comment: Address TODO comments regarding distance calculation. The TODO comments indicate that `nonk_boxes` should come only from the previous frame rather than the entire window, and that `dist` is in units of a fraction of the image height and width and needs to be scaled by the original image size so that it is in pixels. Please update the implementation to reflect these changes and remove the TODO comments once addressed. Also applies to: 158-160 |
||
dist = ((k_ct[:, None, :, :] - nonk_ct[None, :, :, :]) ** 2).sum( | ||
dim=-1 | ||
) # n_k x Np | ||
|
||
norm_dist = dist / (k_s[:, None, :] + 1e-8) | ||
# TODO: note that dist is in units of fraction of the height and width of the image; | ||
# TODO: need to scale it by the original image size so that its in units of pixels | ||
# norm_dist = dist / (k_s[:, None, :] + 1e-8) | ||
norm_dist = dist.mean(axis=-1) # n_k x Np | ||
# norm_dist = | ||
Comment on lines +160 to +162
Review comment (💡 Codebase verification): Missing tests for `filter_max_center_dist`. No existing tests were found for `filter_max_center_dist`.
🔗 Analysis chain: Review the change in distance normalization. The distance normalization has been changed from dividing by the box size to taking the mean across the last dimension. This change might affect the behavior of the function. Please verify that the new approach aligns with the intended functionality of `filter_max_center_dist`. If the change is intentional and correct, please add a comment explaining the rationale behind this modification.
🏁 Scripts executed. The following scripts were executed for the analysis:
Script:
#!/bin/bash
# Description: Compare the output of the old and new implementations
# Test: Search for test cases of filter_max_center_dist
rg --type python -A 10 "def test_filter_max_center_dist"
# If no tests are found, consider adding a new test case
if [ $? -ne 0 ]; then
echo "No existing tests found for filter_max_center_dist. Consider adding a new test case."
fi
Length of output: 286 |
||
|
||
valid = norm_dist < max_center_dist # n_k x Np | ||
valid_assn = ( | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -138,8 +138,10 @@ def track( | |
# asso_preds, pred_boxes, pred_time, embeddings = self.model( | ||
# instances, reid_features | ||
# ) | ||
# get reference and query instances from TrackQueue and calls _run_global_tracker() | ||
instances_pred = self.sliding_inference(model, frames) | ||
|
||
# e.g. during train/val, don't track across batches so persistent_tracking is switched off | ||
if not self.persistent_tracking: | ||
logger.debug(f"Clearing Queue after tracking") | ||
self.track_queue.end_tracks() | ||
|
@@ -164,7 +166,9 @@ def sliding_inference( | |
# H: height. | ||
# W: width. | ||
|
||
# frames is untracked clip for inference | ||
for batch_idx, frame_to_track in enumerate(frames): | ||
# tracked_frames is a list of reference frames that have been tracked (associated) | ||
Review comment (💡 Codebase verification): potential issue with `frame_id`.
🔗 Analysis chain: the following scripts were executed to check whether `frame_id` is consistently a tensor.
Script:
#!/bin/bash
# Description: Verify that 'frame_id' is consistently a tensor.
# Test: Search for 'frame_id' assignments. Expect: 'frame_id' should be a tensor.
rg --type python --context 2 'frame_id\s*='
Length of output: 76
Script:
#!/bin/bash
# Description: Verify that 'frame_id' is consistently a tensor.
# Test: Search for 'frame_id' assignments in Python files. Expect: 'frame_id' should be a tensor.
rg --glob "*.py" --context 2 'frame_id\s*='
Length of output: 7061 |
||
tracked_frames = self.track_queue.collate_tracks( | ||
device=frame_to_track.frame_id.device | ||
) | ||
|
@@ -188,10 +192,11 @@ def sliding_inference( | |
) | ||
|
||
curr_track_id = 0 | ||
# if track ids exist from another tracking program i.e. sleap, init with those | ||
for i, instance in enumerate(frames[batch_idx].instances): | ||
instance.pred_track_id = instance.gt_track_id | ||
curr_track_id = max(curr_track_id, instance.pred_track_id) | ||
|
||
# if no track ids, then assign new ones | ||
|
||
for i, instance in enumerate(frames[batch_idx].instances): | ||
if instance.pred_track_id == -1: | ||
curr_track_id += 1 | ||
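The two loops in this hunk can be read as a two-pass ID initialization: seed from any existing ground-truth IDs, then hand out fresh IDs to instances still marked `-1`. Below is a minimal plain-Python sketch of that reading. `InstanceSketch` and `init_track_ids` are hypothetical names, and the line that stores the new ID is an assumption, since the hunk is cut off right after the increment.

```python
from dataclasses import dataclass


@dataclass
class InstanceSketch:
    """Hypothetical stand-in for an instance with ground-truth and predicted IDs."""

    gt_track_id: int
    pred_track_id: int = -1


def init_track_ids(instances):
    """Seed predicted IDs from ground truth, then fill in any that are missing."""
    curr_track_id = 0
    # Pass 1: reuse IDs that came from another tracker (e.g. SLEAP).
    for inst in instances:
        inst.pred_track_id = inst.gt_track_id
        curr_track_id = max(curr_track_id, inst.pred_track_id)
    # Pass 2: instances still at -1 get the next unused ID.
    for inst in instances:
        if inst.pred_track_id == -1:
            curr_track_id += 1
            inst.pred_track_id = curr_track_id  # assumed; not visible in the hunk
    return instances
```

Under this reading, a frame with no existing IDs starts numbering at 1 rather than 0, because `curr_track_id` is incremented before it is assigned.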
|
@@ -201,6 +206,7 @@ def sliding_inference( | |
if ( | ||
frame_to_track.has_instances() | ||
): # Check if there are detections. If there are skip and increment gap count | ||
# combine the tracked frames with the latest frame; inference pipeline uses latest frame as pred | ||
frames_to_track = tracked_frames + [ | ||
frame_to_track | ||
] # better var name? | ||
|
@@ -217,7 +223,7 @@ def sliding_inference( | |
self.track_queue.add_frame(frame_to_track) | ||
else: | ||
self.track_queue.increment_gaps([]) | ||
|
||
# update the frame object from the input inference untracked clip | ||
frames[batch_idx] = frame_to_track | ||
return frames | ||
|
||
|
@@ -252,7 +258,7 @@ def _run_global_tracker( | |
# E.g.: instances_per_frame: [4, 5, 6, 7]; window of length 4 with 4 detected instances in the first frame of the window. | ||
|
||
_ = model.eval() | ||
|
||
# get the last frame in the clip to perform inference on | ||
query_frame = frames[query_ind] | ||
|
||
query_instances = query_frame.instances | ||
|
@@ -279,8 +285,10 @@ def _run_global_tracker( | |
|
||
# (L=1, n_query, total_instances) | ||
with torch.no_grad(): | ||
# GTR knows this is for inference since query_instances is not None | ||
asso_matrix = model(all_instances, query_instances) | ||
|
||
# GTR output is n_query x n_instances - split this into per-frame to softmax each frame separately | ||
asso_output = asso_matrix[-1].matrix.split( | ||
instances_per_frame, dim=1 | ||
) # (window_size, n_query, N_i) | ||
|
@@ -296,7 +304,7 @@ def _run_global_tracker( | |
|
||
asso_output_df.index.name = "Instances" | ||
asso_output_df.columns.name = "Instances" | ||
|
||
# save the association matrix to the Frame object | ||
query_frame.add_traj_score("asso_output", asso_output_df) | ||
query_frame.asso_output = asso_matrix[-1] | ||
|
||
|
@@ -343,6 +351,8 @@ def _run_global_tracker( | |
|
||
query_frame.add_traj_score("asso_nonquery", asso_nonquery_df) | ||
|
||
# need frame height and width to scale boxes during post-processing | ||
_, h, w = query_frame.img_shape.flatten() | ||
pred_boxes = model_utils.get_boxes(all_instances) | ||
query_boxes = pred_boxes[query_inds] # n_k x 4 | ||
nonquery_boxes = pred_boxes[nonquery_inds] # n_nonquery x 4 | ||
|
@@ -374,7 +384,7 @@ def _run_global_tracker( | |
|
||
query_frame.add_traj_score("decay_time", decay_time_traj_score) | ||
################################################################################ | ||
|
||
# reduce association matrix - aggregating reference instance association scores by tracks | ||
# (n_query x n_nonquery) x (n_nonquery x n_traj) --> n_query x n_traj | ||
traj_score = torch.mm(traj_score, id_inds.cpu()) # (n_query, n_traj) | ||
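The shape comment above describes a reduction from per-instance scores to per-track scores. As a hedged illustration, the sketch below uses plain Python lists instead of tensors and assumes `id_inds` is a 0/1 membership matrix (row = reference instance, column = track). The helper names are hypothetical.

```python
def one_hot_ids(instance_ids):
    """Build an (n_nonquery x n_traj) 0/1 membership matrix from track IDs."""
    unique_ids = sorted(set(instance_ids))
    column = {track_id: j for j, track_id in enumerate(unique_ids)}
    id_inds = [[0.0] * len(unique_ids) for _ in instance_ids]
    for i, track_id in enumerate(instance_ids):
        id_inds[i][column[track_id]] = 1.0
    return unique_ids, id_inds


def matmul(a, b):
    """Plain nested-list matrix product, standing in for torch.mm."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


# Two query instances scored against three reference instances;
# reference instances 0 and 2 belong to track 0, instance 1 to track 1.
scores = [[0.5, 0.25, 0.25], [0.0, 1.0, 0.5]]
unique_ids, id_inds = one_hot_ids([0, 1, 0])
traj_score = matmul(scores, id_inds)  # (n_query x n_traj)
```

Each column of the result is the sum of a query's scores over all reference instances that share a track, which is why a track seen in more frames can accumulate a larger score.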
|
||
|
@@ -387,6 +397,7 @@ def _run_global_tracker( | |
|
||
query_frame.add_traj_score("traj_score", traj_score_df) | ||
################################################################################ | ||
# IOU-based post-processing; add a weighted IOU across successive frames to association scores | ||
|
||
# with iou -> combining with location in tracker, they set to True | ||
# todo -> should also work without pos_embed | ||
|
@@ -421,11 +432,12 @@ def _run_global_tracker( | |
|
||
query_frame.add_traj_score("weight_iou", iou_traj_score) | ||
################################################################################ | ||
# filters association matrix such that instances too far from each other get scores=0 | ||
|
||
# threshold for continuing a tracking or starting a new track -> they use 1.0 | ||
# todo -> should also work without pos_embed | ||
traj_score = post_processing.filter_max_center_dist( | ||
traj_score, self.max_center_dist, query_boxes, nonquery_boxes, id_inds | ||
traj_score, self.max_center_dist, query_boxes, nonquery_boxes, id_inds, h, w | ||
) | ||
|
||
if self.max_center_dist is not None and self.max_center_dist > 0: | ||
|
@@ -439,6 +451,7 @@ def _run_global_tracker( | |
query_frame.add_traj_score("max_center_dist", max_center_dist_traj_score) | ||
|
||
################################################################################ | ||
# softmax along tracks for each instance, for interpretability | ||
scaled_traj_score = torch.softmax(traj_score, dim=1) | ||
scaled_traj_score_df = pd.DataFrame( | ||
scaled_traj_score.numpy(), columns=unique_ids.cpu().numpy() | ||
|
@@ -449,6 +462,7 @@ def _run_global_tracker( | |
query_frame.add_traj_score("scaled", scaled_traj_score_df) | ||
################################################################################ | ||
|
||
# hungarian matching | ||
match_i, match_j = linear_sum_assignment((-traj_score)) | ||
|
||
track_ids = instance_ids.new_full((n_query,), -1) | ||
|
@@ -462,6 +476,7 @@ def _run_global_tracker( | |
thresh = ( | ||
overlap_thresh * id_inds[:, j].sum() if mult_thresh else overlap_thresh | ||
) | ||
# if the association score for a query instance is lower than the threshold, create a new track for it | ||
if n_traj >= self.max_tracks or traj_score[i, j] > thresh: | ||
logger.debug( | ||
f"Assigning instance {i} to track {j} with id {unique_ids[j]}" | ||
|
Original file line number | Diff line number | Diff line change | ||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -226,7 +226,7 @@ def get_dataset( | |||||||||||||||||
mode: str, | ||||||||||||||||||
label_files: list[str] | None = None, | ||||||||||||||||||
vid_files: list[str | list[str]] = None, | ||||||||||||||||||
) -> "SleapDataset" | "MicroscopyDataset" | "CellTrackingDataset": | ||||||||||||||||||
) -> "SleapDataset" | "MicroscopyDataset" | "CellTrackingDataset" | None: | ||||||||||||||||||
Review comment: Undefined type names in return type annotation. The return type annotation of `get_dataset` refers to dataset classes that are not defined in this module. Consider importing the dataset classes at the top of the module:
+from dreem.datasets import MicroscopyDataset, SleapDataset, CellTrackingDataset
Alternatively, if you wish to avoid importing at the top level due to potential circular imports, you can use a `TYPE_CHECKING` guard:
from typing import TYPE_CHECKING
if TYPE_CHECKING:
    from dreem.datasets import MicroscopyDataset, SleapDataset, CellTrackingDataset
Tools: Ruff
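One more point may be worth checking here. The annotation in the diff joins separate string literals with `|`, which applies the operator to `str` objects. Unless annotations are postponed (for example with `from __future__ import annotations`, which this diff does not show either way), that raises `TypeError` when the `def` statement runs. A minimal sketch of the guarded-import pattern with a single quoted annotation follows. `get_dataset_stub` and `string_union_raises` are hypothetical names, and the import path is the one given in the review comment.

```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Seen only by static type checkers, never executed at runtime,
    # so it cannot trigger a circular import.
    from dreem.datasets import CellTrackingDataset, MicroscopyDataset, SleapDataset


def get_dataset_stub(
    mode: str,
) -> "SleapDataset | MicroscopyDataset | CellTrackingDataset | None":
    """Hypothetical stand-in for get_dataset; the whole union is one string."""
    return None


def string_union_raises() -> bool:
    """Return True if applying `|` to two plain strings fails at runtime."""
    left = "SleapDataset"
    try:
        left | "MicroscopyDataset"
    except TypeError:
        return True
    return False
```

Quoting the whole union keeps the annotation a single forward reference, so nothing is evaluated until a type checker resolves it.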
|
||||||||||||||||||
"""Getter for datasets. | ||||||||||||||||||
|
||||||||||||||||||
Args: | ||||||||||||||||||
|
@@ -301,6 +301,10 @@ def get_dataset( | |||||||||||||||||
"Could not resolve dataset type from Config! Please include \ | ||||||||||||||||||
either `slp_files` or `tracks`/`source`" | ||||||||||||||||||
) | ||||||||||||||||||
if len(dataset) == 0: | ||||||||||||||||||
logger.warn(f"Length of {mode} dataset is {len(dataset)}! Returning None") | ||||||||||||||||||
return None | ||||||||||||||||||
return dataset | ||||||||||||||||||
Comment on lines +304 to +307
Review comment: Use `logger.warning` instead of `logger.warn`. The `warn` method is deprecated. Apply this diff:
- logger.warn(f"Length of {mode} dataset is {len(dataset)}! Returning None")
+ logger.warning(f"Length of {mode} dataset is {len(dataset)}! Returning None")
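A minimal sketch of the empty-dataset guard using `Logger.warning`, assuming `logger` is a standard-library `logging.Logger` (the diff does not show how it is created). `warn_if_empty` and `ListHandler` are hypothetical helpers, added only so the emitted record can be inspected.

```python
import logging

logger = logging.getLogger("dataset_length_sketch")  # hypothetical logger name


class ListHandler(logging.Handler):
    """Collects (level name, message) pairs so the output can be inspected."""

    def __init__(self):
        super().__init__()
        self.messages = []

    def emit(self, record):
        self.messages.append((record.levelname, record.getMessage()))


def warn_if_empty(dataset, mode, log=logger):
    """Return None (with a warning) for an empty dataset, else the dataset."""
    if len(dataset) == 0:
        # Logger.warning is the supported spelling; Logger.warn is deprecated.
        log.warning("Length of %s dataset is %d! Returning None", mode, len(dataset))
        return None
    return dataset
```

Callers of the real `get_dataset` would then need to handle a `None` result, which is what the widened return annotation in this hunk signals.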
|
||||||||||||||||||
|
||||||||||||||||||
@property | ||||||||||||||||||
def data_paths(self): | ||||||||||||||||||
|
@@ -319,9 +323,9 @@ def data_paths(self, paths: tuple[str, list[str]]): | |||||||||||||||||
|
||||||||||||||||||
def get_dataloader( | ||||||||||||||||||
self, | ||||||||||||||||||
dataset: "SleapDataset" | "MicroscopyDataset" | "CellTrackingDataset", | ||||||||||||||||||
dataset: "SleapDataset" | "MicroscopyDataset" | "CellTrackingDataset" | None, | ||||||||||||||||||
Review comment: Undefined type names in parameter annotation. In the parameter annotation of `get_dataloader`, the dataset class names are not defined in this module. To fix this, import the dataset classes at the top of the module:
+from dreem.datasets import MicroscopyDataset, SleapDataset, CellTrackingDataset
Or, if you prefer to avoid top-level imports:
from typing import TYPE_CHECKING
if TYPE_CHECKING:
    from dreem.datasets import MicroscopyDataset, SleapDataset, CellTrackingDataset
Tools: Ruff
|
||||||||||||||||||
mode: str, | ||||||||||||||||||
) -> torch.utils.data.DataLoader: | ||||||||||||||||||
) -> torch.utils.data.DataLoader | None: | ||||||||||||||||||
"""Getter for dataloader. | ||||||||||||||||||
|
||||||||||||||||||
Args: | ||||||||||||||||||
|
@@ -350,14 +354,21 @@ def get_dataloader( | |||||||||||||||||
else: | ||||||||||||||||||
pin_memory = False | ||||||||||||||||||
|
||||||||||||||||||
return torch.utils.data.DataLoader( | ||||||||||||||||||
dataloader = torch.utils.data.DataLoader( | ||||||||||||||||||
Review comment (🛠️ Refactor suggestion): Consider making `batch_size` configurable. Currently, `batch_size` is hard-coded to 1. Apply this diff to take it from `dataloader_params` instead:
- dataloader = torch.utils.data.DataLoader(
-     dataset=dataset,
-     batch_size=1,
-     pin_memory=pin_memory,
-     collate_fn=dataset.no_batching_fn,
-     **dataloader_params,
- )
+ dataloader = torch.utils.data.DataLoader(
+     dataset=dataset,
+     pin_memory=pin_memory,
+     collate_fn=dataset.no_batching_fn,
+     **dataloader_params,
+ )
|
||||||||||||||||||
dataset=dataset, | ||||||||||||||||||
batch_size=1, | ||||||||||||||||||
pin_memory=pin_memory, | ||||||||||||||||||
collate_fn=dataset.no_batching_fn, | ||||||||||||||||||
**dataloader_params, | ||||||||||||||||||
) | ||||||||||||||||||
|
||||||||||||||||||
if len(dataloader) == 0: | ||||||||||||||||||
logger.warn( | ||||||||||||||||||
f"Length of {mode} dataloader is {len(dataloader)}! Returning `None`" | ||||||||||||||||||
) | ||||||||||||||||||
Comment on lines +366 to +368
Review comment: Replace deprecated `logger.warn` with `logger.warning`. The `warn` method is deprecated. Apply this diff:
- logger.warn(
+ logger.warning(
|
||||||||||||||||||
return None | ||||||||||||||||||
return dataloader | ||||||||||||||||||
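On the `batch_size` suggestion for this function: one way to keep a default of 1 while still letting the config override it is to merge defaults into the parameter dict before the call. The sketch below is an illustration only. `fake_dataloader` is a hypothetical stand-in for `torch.utils.data.DataLoader` so the example runs without torch, and the other names are likewise assumptions.

```python
def fake_dataloader(dataset, batch_size=1, **kwargs):
    """Hypothetical stand-in for DataLoader; just records what it was given."""
    return {"dataset": dataset, "batch_size": batch_size, **kwargs}


def build_loader(dataset, dataloader_params):
    """Default batch_size to 1, but let dataloader_params override it."""
    params = {"batch_size": 1, **dataloader_params}
    return fake_dataloader(dataset, **params)


def hardcoded_conflicts(dataset, dataloader_params):
    """Return True if a hard-coded batch_size clashes with the config dict."""
    try:
        fake_dataloader(dataset, batch_size=1, **dataloader_params)
    except TypeError:  # "got multiple values for keyword argument 'batch_size'"
        return True
    return False
```

The second helper shows why the hard-coded form is fragile: passing `batch_size=1` together with a `dataloader_params` dict that also contains `batch_size` raises `TypeError` instead of picking one value.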
|
||||||||||||||||||
def get_optimizer(self, params: Iterable) -> torch.optim.Optimizer: | ||||||||||||||||||
"""Getter for optimizer. | ||||||||||||||||||
|
||||||||||||||||||
|
@@ -492,7 +503,7 @@ def get_checkpointing(self) -> pl.callbacks.ModelCheckpoint: | |||||||||||||||||
filename=f"{{epoch}}-{{{metric}}}", | ||||||||||||||||||
**checkpoint_params, | ||||||||||||||||||
) | ||||||||||||||||||
checkpointer.CHECKPOINT_NAME_LAST = f"{{epoch}}-best-{{{metric}}}" | ||||||||||||||||||
checkpointer.CHECKPOINT_NAME_LAST = f"{{epoch}}-final-{{{metric}}}" | ||||||||||||||||||
checkpointers.append(checkpointer) | ||||||||||||||||||
return checkpointers | ||||||||||||||||||
|
||||||||||||||||||
|
Original file line number | Diff line number | Diff line change | ||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -9,23 +9,40 @@ | |||||||||||
class ATTWeightHead(torch.nn.Module): | ||||||||||||
"""Single attention head.""" | ||||||||||||
|
||||||||||||
def __init__( | ||||||||||||
self, | ||||||||||||
feature_dim: int, | ||||||||||||
num_layers: int, | ||||||||||||
dropout: float, | ||||||||||||
): | ||||||||||||
def __init__(self, feature_dim: int, num_layers: int, dropout: float, **kwargs): | ||||||||||||
"""Initialize an instance of ATTWeightHead. | ||||||||||||
|
||||||||||||
Args: | ||||||||||||
feature_dim: The dimensionality of input features. | ||||||||||||
num_layers: The number of hidden layers in the MLP. | ||||||||||||
dropout: Dropout probability. | ||||||||||||
embedding_agg_method: how the embeddings are aggregated; average/stack/concatenate | ||||||||||||
""" | ||||||||||||
super().__init__() | ||||||||||||
if "embedding_agg_method" in kwargs: | ||||||||||||
self.embedding_agg_method = kwargs["embedding_agg_method"] | ||||||||||||
else: | ||||||||||||
self.embedding_agg_method = None | ||||||||||||
Comment on lines +22 to +25
Review comment: Simplify the initialization of `embedding_agg_method`. Instead of using an `if`/`else` block, use `kwargs.get`:
- if "embedding_agg_method" in kwargs:
-     self.embedding_agg_method = kwargs["embedding_agg_method"]
- else:
-     self.embedding_agg_method = None
+ self.embedding_agg_method = kwargs.get("embedding_agg_method", None)
Tools: Ruff
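A small sketch of the `kwargs.get` form, using a hypothetical `HeadConfigSketch` class rather than the real `ATTWeightHead` so it runs without torch. `dict.get` already returns `None` for a missing key, so the explicit default is optional.

```python
class HeadConfigSketch:
    """Hypothetical stand-in showing only the keyword handling."""

    def __init__(self, feature_dim: int, **kwargs):
        self.feature_dim = feature_dim
        # None when the caller does not pass embedding_agg_method.
        self.embedding_agg_method = kwargs.get("embedding_agg_method")


default_head = HeadConfigSketch(8)
stacked_head = HeadConfigSketch(8, embedding_agg_method="stack")
```

An alternative design would be an explicit keyword-only parameter with a `None` default, which makes the option visible in the signature instead of hiding it in `**kwargs`.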
|
||||||||||||
|
||||||||||||
self.q_proj = MLP(feature_dim, feature_dim, feature_dim, num_layers, dropout) | ||||||||||||
self.k_proj = MLP(feature_dim, feature_dim, feature_dim, num_layers, dropout) | ||||||||||||
# if using stacked embeddings, use 1x1 conv with x,y,t embeddings as channels | ||||||||||||
# ensures output represents ref instances by query instances | ||||||||||||
if self.embedding_agg_method == "stack": | ||||||||||||
self.q_proj = torch.nn.Conv1d( | ||||||||||||
in_channels=3, out_channels=1, kernel_size=1, stride=1, padding=0 | ||||||||||||
) | ||||||||||||
self.k_proj = torch.nn.Conv1d( | ||||||||||||
in_channels=3, out_channels=1, kernel_size=1, stride=1, padding=0 | ||||||||||||
) | ||||||||||||
self.attn_x = torch.nn.MultiheadAttention(feature_dim, 1) | ||||||||||||
self.attn_y = torch.nn.MultiheadAttention(feature_dim, 1) | ||||||||||||
self.attn_t = torch.nn.MultiheadAttention(feature_dim, 1) | ||||||||||||
else: | ||||||||||||
self.q_proj = MLP( | ||||||||||||
feature_dim, feature_dim, feature_dim, num_layers, dropout | ||||||||||||
) | ||||||||||||
self.k_proj = MLP( | ||||||||||||
feature_dim, feature_dim, feature_dim, num_layers, dropout | ||||||||||||
) | ||||||||||||
|
||||||||||||
def forward( | ||||||||||||
self, | ||||||||||||
|
@@ -41,8 +58,45 @@ def forward( | |||||||||||
Returns: | ||||||||||||
Output tensor of shape (batch_size, num_frame_instances, num_window_instances). | ||||||||||||
""" | ||||||||||||
k = self.k_proj(key) | ||||||||||||
q = self.q_proj(query) | ||||||||||||
attn_weights = torch.bmm(q, k.transpose(1, 2)) | ||||||||||||
batch_size, num_query_instances, feature_dim = query.size() | ||||||||||||
num_window_instances = key.shape[1] | ||||||||||||
|
||||||||||||
# if stacked embeddings, create channels for each x,y,t embedding dimension | ||||||||||||
# maps shape (1,num_instances*3,feature_dim) -> (num_instances,3,feature_dim) | ||||||||||||
if self.embedding_agg_method == "stack": | ||||||||||||
key_stacked = ( | ||||||||||||
key | ||||||||||||
.view(batch_size, 3, num_window_instances // 3, feature_dim) | ||||||||||||
.permute(0, 2, 1, 3) | ||||||||||||
.squeeze(0) # keep as (num_instances*3, feature_dim) | ||||||||||||
) | ||||||||||||
key_orig = key.squeeze(0) # keep as (num_instances*3, feature_dim) | ||||||||||||
|
||||||||||||
query = ( | ||||||||||||
query.view(batch_size, 3, num_query_instances // 3, feature_dim) | ||||||||||||
.permute(0, 2, 1, 3) | ||||||||||||
.squeeze(0) | ||||||||||||
) | ||||||||||||
# pass t,x,y frame features through cross attention with entire encoder 3*num_window_instances tokens before MLP; | ||||||||||||
# note order is t,x,y | ||||||||||||
out_t, _ = self.attn_t(query=query[:,0,:], key=key_orig, value=key_orig) | ||||||||||||
out_x, _ = self.attn_x(query=query[:,1,:], key=key_orig, value=key_orig) | ||||||||||||
out_y, _ = self.attn_y(query=query[:,2,:], key=key_orig, value=key_orig) | ||||||||||||
# combine each attention output to (num_instances, 3, feature_dim) | ||||||||||||
collated = torch.stack((out_t, out_x, out_y), dim=0).permute(1,0,2) | ||||||||||||
# mlp_out has shape (1, num_window_instances, feature_dim) | ||||||||||||
mlp_out = self.q_proj(collated).transpose(1,0) | ||||||||||||
|
||||||||||||
# key, query of shape (num_instances, 3, feature_dim) | ||||||||||||
# TODO: uncomment this if not using modified attention heads for t,x,y | ||||||||||||
k = self.k_proj(key_stacked).transpose(1, 0) | ||||||||||||
# q = self.q_proj(query).transpose(1, 0) | ||||||||||||
# k,q of shape (num_instances, feature_dim) | ||||||||||||
attn_weights = torch.bmm(mlp_out, k.transpose(1, 2)) | ||||||||||||
else: | ||||||||||||
k = self.k_proj(key) | ||||||||||||
q = self.q_proj(query) | ||||||||||||
attn_weights = torch.bmm(q, k.transpose(1, 2)) | ||||||||||||
|
||||||||||||
|
||||||||||||
return attn_weights # (B, N_t, N) |
Review comment: Implement the usage of new parameters `h` and `w`. The new parameters `h` and `w` have been added to the function signature, but they are not used in the function body. These parameters are intended to represent the height and width of the image for scaling the distance calculation. To address this, scale the distance calculation using `h` and `w`, which converts the distance from normalized coordinates to pixel units.
Also applies to: 140-141
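As a hedged illustration of the scaling this comment describes (not the reviewer's own code), the sketch below computes a squared center distance in pixels for a single pair of boxes. It assumes boxes are `(x1, y1, x2, y2)` in normalized coordinates with x scaled by the width and y by the height. That ordering is an assumption, and `squared_center_dist_px` is a hypothetical helper written in plain Python to keep the arithmetic visible.

```python
def squared_center_dist_px(box_a, box_b, h, w):
    """Squared distance in pixels between the centers of two normalized boxes.

    Boxes are assumed to be (x1, y1, x2, y2) with coordinates in [0, 1].
    """
    # Box centers in normalized coordinates, scaled to pixels per axis.
    ax = (box_a[0] + box_a[2]) / 2 * w
    ay = (box_a[1] + box_a[3]) / 2 * h
    bx = (box_b[0] + box_b[2]) / 2 * w
    by = (box_b[1] + box_b[3]) / 2 * h
    return (ax - bx) ** 2 + (ay - by) ** 2


# Centers at (0.25, 0.25) and (0.75, 0.75) in a 100 x 200 (h x w) image:
# dx = 100 px and dy = 50 px.
example = squared_center_dist_px((0.0, 0.0, 0.5, 0.5), (0.5, 0.5, 1.0, 1.0), h=100, w=200)
```

Because the value is a squared distance, a threshold such as `max_center_dist` would need to be expressed in squared pixels, or compared against the square root, to stay in consistent units.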