RoPE embeddings #78

Open
wants to merge 66 commits into base: main

Commits (66)
16d7bc4
create notebook for dev
shaikh58 Jul 31, 2024
6f2c7ad
test update of notebook
shaikh58 Jul 31, 2024
b82c4d4
implement rope embedding
shaikh58 Aug 2, 2024
a07ea57
minor changes - add batch job file to repo
shaikh58 Aug 5, 2024
6d135fb
add local train run script, minor changes
shaikh58 Aug 5, 2024
4714aea
Update rope.ipynb
shaikh58 Aug 5, 2024
9c64789
refactor transformer encoder
shaikh58 Aug 6, 2024
67bf6e4
further changes for rope
shaikh58 Aug 6, 2024
fa61af0
complete encoder section of rope
shaikh58 Aug 6, 2024
55f5f25
setup batch training
shaikh58 Aug 7, 2024
9cec3a2
remove batch run commands from repo
shaikh58 Aug 7, 2024
f02a173
remove batch training script
shaikh58 Aug 7, 2024
287c475
Update base.yaml
shaikh58 Aug 7, 2024
a7e3a56
Merge branch 'mustafa-rope' of https://github.com/talmolab/dreem into…
shaikh58 Aug 7, 2024
5d4bf5e
Update run_trainer.py
shaikh58 Aug 7, 2024
f23ef5c
Update .gitignore
shaikh58 Aug 7, 2024
785df8f
comments for tracker.py
shaikh58 Aug 7, 2024
3d3f2ca
embedding bug fixes for encoder
shaikh58 Aug 8, 2024
6af9e17
implement rope for decoder
shaikh58 Aug 9, 2024
6928078
final attn head supports stack embeddings
shaikh58 Aug 9, 2024
c4b1124
Update tests, add new unit tests for rope
shaikh58 Aug 10, 2024
62f2c03
rope bug fixes
shaikh58 Aug 12, 2024
9292bbc
minor update to previous commit
shaikh58 Aug 12, 2024
3751de0
fix device mismatch in mlp module
shaikh58 Aug 15, 2024
3d1a35e
support for adding embedding to instance
shaikh58 Aug 15, 2024
c4abac2
bug fixes to pass unit tests
shaikh58 Aug 16, 2024
5a7e86b
minor updates from PR review
shaikh58 Aug 16, 2024
9eddead
allow batch eval/inference flexibility rather than just different mod…
aaprasad Aug 16, 2024
d5993a9
linting
shaikh58 Aug 19, 2024
bcb661a
add cross attn for rope-stack before final asso matrix output
shaikh58 Aug 26, 2024
fd77ded
minor bug fix in rope embedding for single instance clips
shaikh58 Aug 27, 2024
41454f7
use `sleap-io` as video backend instead of imageio
aaprasad Aug 30, 2024
64c970b
lint
aaprasad Aug 30, 2024
b63f24f
create notebook for dev
shaikh58 Jul 31, 2024
c320eea
test update of notebook
shaikh58 Jul 31, 2024
21035fb
implement rope embedding
shaikh58 Aug 2, 2024
4d27914
minor changes - add batch job file to repo
shaikh58 Aug 5, 2024
be5e630
add local train run script, minor changes
shaikh58 Aug 5, 2024
dba9f08
Update rope.ipynb
shaikh58 Aug 5, 2024
0dd6a60
refactor transformer encoder
shaikh58 Aug 6, 2024
e492909
further changes for rope
shaikh58 Aug 6, 2024
4140524
complete encoder section of rope
shaikh58 Aug 6, 2024
a1ca23e
setup batch training
shaikh58 Aug 7, 2024
b5fa58d
remove batch run commands from repo
shaikh58 Aug 7, 2024
c721e90
Update base.yaml
shaikh58 Aug 7, 2024
6711697
remove batch training script
shaikh58 Aug 7, 2024
20fd4a7
Update run_trainer.py
shaikh58 Aug 7, 2024
9ac41a8
Update .gitignore
shaikh58 Aug 7, 2024
c43ee75
comments for tracker.py
shaikh58 Aug 7, 2024
fe1eeca
embedding bug fixes for encoder
shaikh58 Aug 8, 2024
2da8c09
implement rope for decoder
shaikh58 Aug 9, 2024
65a4ae0
final attn head supports stack embeddings
shaikh58 Aug 9, 2024
7c38ad4
Update tests, add new unit tests for rope
shaikh58 Aug 10, 2024
8b552ef
rope bug fixes
shaikh58 Aug 12, 2024
8fdfba1
minor update to previous commit
shaikh58 Aug 12, 2024
03df33f
fix device mismatch in mlp module
shaikh58 Aug 15, 2024
1d2f5a5
support for adding embedding to instance
shaikh58 Aug 15, 2024
5a5f75f
bug fixes to pass unit tests
shaikh58 Aug 16, 2024
3ff1ab0
minor updates from PR review
shaikh58 Aug 16, 2024
fe2c88e
linting
shaikh58 Aug 19, 2024
de2ace9
add cross attn for rope-stack before final asso matrix output
shaikh58 Aug 26, 2024
9b29171
minor bug fix in rope embedding for single instance clips
shaikh58 Aug 27, 2024
3bc9fef
Merge branch 'mustafa-rope' of https://github.com/talmolab/dreem into…
shaikh58 Sep 27, 2024
1998f6f
- Started implementation for post processing fixes; no logic changes
shaikh58 Oct 9, 2024
2792d13
Merge branch 'main' into mustafa-rope
shaikh58 Dec 13, 2024
e4ce29c
merge bug fixes
shaikh58 Dec 13, 2024
5 changes: 5 additions & 0 deletions .gitignore
@@ -142,3 +142,8 @@ dreem/training/models/*

# docs
site/
*.xml
dreem/training/configs/base.yaml
dreem/training/configs/override.yaml
dreem/training/configs/override.yaml
dreem/training/configs/base.yaml
2 changes: 1 addition & 1 deletion dreem/datasets/sleap_dataset.py
@@ -415,4 +415,4 @@ def get_instances(self, label_idx: list[int], frame_idx: list[int]) -> list[Fram
def __del__(self):
"""Handle file closing before garbage collection."""
for reader in self.videos:
reader.close()
reader.close()
2 changes: 1 addition & 1 deletion dreem/inference/eval.py
@@ -77,4 +77,4 @@ def run(cfg: DictConfig) -> dict[int, sio.Labels]:

# override with params config, and specific params:
# python eval.py --config-dir=./configs --config-name=inference +params_config=configs/params.yaml dataset.train_dataset.padding=10
run()
run()
12 changes: 9 additions & 3 deletions dreem/inference/post_processing.py
@@ -126,6 +126,8 @@ def filter_max_center_dist(
k_boxes: torch.Tensor | None = None,
nonk_boxes: torch.Tensor | None = None,
id_inds: torch.Tensor | None = None,
h: int = None,
w: int = None
Comment on lines +129 to +130

⚠️ Potential issue

Implement the usage of new parameters h and w

The new parameters h and w have been added to the function signature, but they are not used in the function body. These parameters are intended to represent the height and width of the image for scaling the distance calculation.

To address this, implement the scaling of the distance calculation using h and w. Here's a suggested implementation:

# After calculating dist
scaled_dist = dist * torch.tensor([h, w], device=dist.device)
norm_dist = scaled_dist.mean(axis=-1)  # n_k x Np

This change will convert the distance from normalized coordinates to pixel units.

Also applies to: 140-141

) -> torch.Tensor:
"""Filter trajectory score by distances between objects across frames.

@@ -135,6 +137,8 @@ def filter_max_center_dist(
k_boxes: The bounding boxes in the current frame
nonk_boxes: the boxes not in the current frame
id_inds: track ids
h: height of image
w: width of image

Returns:
An N_t x N association matrix
@@ -147,13 +151,15 @@
k_s = ((k_boxes[:, :, 2:] - k_boxes[:, :, :2]) ** 2).sum(dim=2) # n_k

nonk_ct = (nonk_boxes[:, :, :2] + nonk_boxes[:, :, 2:]) / 2

# TODO: nonk_boxes should be only from previous frame rather than entire window

⚠️ Potential issue

Address TODO comments regarding distance calculation

The TODO comments indicate that:

  1. nonk_boxes should only be from the previous frame rather than the entire window.
  2. The distance calculation needs to be scaled by the original image size.

To address these issues:

  1. Modify the nonk_boxes input to only include boxes from the previous frame.
  2. Implement the scaling of the distance calculation as suggested in the previous comment.

Please update the implementation to reflect these changes and remove the TODO comments once addressed.

Also applies to: 158-160
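
A standalone sketch of one way to do the proposed pixel scaling, assuming box centers are normalized to [0, 1] (as the TODO below states) and that h, w are the image height and width; all tensors here are hypothetical:

import torch

# Hypothetical sizes: n_k query instances, Np window instances, one anchor each.
k_ct = torch.rand(3, 1, 2)     # (n_k, n_anchors, 2), normalized (x, y) centers
nonk_ct = torch.rand(5, 1, 2)  # (Np, n_anchors, 2)
h, w = 512, 640

# Scale x by width and y by height before squaring, so distances come out
# in squared pixels rather than fractions of the image.
scale = torch.tensor([w, h], dtype=k_ct.dtype, device=k_ct.device)
dist_px = ((k_ct[:, None] * scale - nonk_ct[None] * scale) ** 2).sum(dim=-1)
print(dist_px.shape)  # (n_k, Np, n_anchors)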

dist = ((k_ct[:, None, :, :] - nonk_ct[None, :, :, :]) ** 2).sum(
dim=-1
) # n_k x Np

norm_dist = dist / (k_s[:, None, :] + 1e-8)
# TODO: note that dist is in units of fraction of the height and width of the image;
# TODO: need to scale it by the original image size so that its in units of pixels
# norm_dist = dist / (k_s[:, None, :] + 1e-8)
norm_dist = dist.mean(axis=-1) # n_k x Np
# norm_dist =
Comment on lines +160 to +162

💡 Codebase verification

Missing Tests for filter_max_center_dist

No existing tests found for filter_max_center_dist. Please add appropriate test cases to ensure that the new distance normalization aligns with the intended functionality.

🔗 Analysis chain

Review the change in distance normalization

The distance normalization has been changed from dividing by the box size to taking the mean across the last dimension.

This change might affect the behavior of the function. Please verify if this new approach aligns with the intended functionality of filter_max_center_dist. Consider running the following test:

If the change is intentional and correct, please add a comment explaining the rationale behind this modification.

🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Compare the output of the old and new implementations

# Test: Search for test cases of filter_max_center_dist
rg --type python -A 10 "def test_filter_max_center_dist"

# If no tests are found, consider adding a new test case
if [ $? -ne 0 ]; then
    echo "No existing tests found for filter_max_center_dist. Consider adding a new test case."
fi

Length of output: 286
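
A minimal test sketch for the missing coverage, assuming the positional argument order used in tracker.py (traj_score, max_center_dist, k_boxes, nonk_boxes, id_inds, h, w), normalized xyxy boxes, and that the returned association matrix keeps its input shape; the tensors are hypothetical:

import torch
from dreem.inference import post_processing

def test_filter_max_center_dist_shape():
    n_query, n_window, n_traj = 2, 3, 2
    traj_score = torch.rand(n_query, n_traj)
    k_boxes = torch.rand(n_query, 1, 4)      # (n_query, n_anchors, 4), normalized xyxy
    nonk_boxes = torch.rand(n_window, 1, 4)  # (n_window, n_anchors, 4)
    id_inds = torch.tensor([[1.0, 0.0], [0.0, 1.0], [1.0, 0.0]])  # (n_window, n_traj)

    out = post_processing.filter_max_center_dist(
        traj_score, 1.0, k_boxes, nonk_boxes, id_inds, 512, 512
    )
    # Filtering should zero out far-apart pairs without changing the matrix shape.
    assert out.shape == traj_score.shape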


valid = norm_dist < max_center_dist # n_k x Np
valid_assn = (
2 changes: 1 addition & 1 deletion dreem/inference/track.py
@@ -163,4 +163,4 @@ def run(cfg: DictConfig) -> dict[int, sio.Labels]:

# override with params config, and specific params:
# python train.py --config-dir=./configs --config-name=inference +params_config=configs/params.yaml dataset.train_dataset.padding=10
run()
run()
27 changes: 21 additions & 6 deletions dreem/inference/tracker.py
@@ -138,8 +138,10 @@ def track(
# asso_preds, pred_boxes, pred_time, embeddings = self.model(
# instances, reid_features
# )
# get reference and query instances from TrackQueue and calls _run_global_tracker()
instances_pred = self.sliding_inference(model, frames)

# e.g. during train/val, don't track across batches so persistent_tracking is switched off
if not self.persistent_tracking:
logger.debug(f"Clearing Queue after tracking")
self.track_queue.end_tracks()
@@ -164,7 +166,9 @@ def sliding_inference(
# H: height.
# W: width.

# frames is untracked clip for inference
for batch_idx, frame_to_track in enumerate(frames):
# tracked_frames is a list of reference frames that have been tracked (associated)

💡 Codebase verification

⚠️ Potential issue

Potential AttributeError Confirmed: frame_id is inconsistently handled as both tensors and integers.

  • In dreem/inference/tracker.py, frame_to_track.frame_id is accessed as a tensor (frame_to_track.frame_id.device).
  • Multiple instances in other files assign frame_id as an integer (e.g., frame_id=0), which would cause an AttributeError when .device is accessed.
🔗 Analysis chain

Potential AttributeError when accessing frame_id.device.

In the line device=frame_to_track.frame_id.device, ensure that frame_to_track.frame_id is a tensor. If frame_id is an integer, accessing .device will raise an AttributeError. Please verify the type of frame_id.

You can run the following script to check the type of frame_id across the codebase:

🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Verify that 'frame_id' is consistently a tensor.

# Test: Search for 'frame_id' assignments. Expect: 'frame_id' should be a tensor.
rg --type python --context 2 'frame_id\s*='

Length of output: 76


Script:

#!/bin/bash
# Description: Verify that 'frame_id' is consistently a tensor.

# Test: Search for 'frame_id' assignments in Python files. Expect: 'frame_id' should be a tensor.
rg --glob "*.py" --context 2 'frame_id\s*='

Length of output: 7061
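
One defensive option, sketched here rather than taken from the PR, is to normalize frame_id before touching .device:

import torch

def _frame_id_device(frame_id) -> torch.device:
    # Hypothetical helper: tolerate both int and tensor frame ids.
    if isinstance(frame_id, torch.Tensor):
        return frame_id.device
    return torch.device("cpu")  # plain ints carry no device; fall back to CPU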

tracked_frames = self.track_queue.collate_tracks(
device=frame_to_track.frame_id.device
)
@@ -188,10 +192,11 @@
)

curr_track_id = 0
# if track ids exist from another tracking program i.e. sleap, init with those
for i, instance in enumerate(frames[batch_idx].instances):
instance.pred_track_id = instance.gt_track_id
curr_track_id = max(curr_track_id, instance.pred_track_id)

# if no track ids, then assign new ones
shaikh58 marked this conversation as resolved.
for i, instance in enumerate(frames[batch_idx].instances):
if instance.pred_track_id == -1:
curr_track_id += 1
@@ -201,6 +206,7 @@
if (
frame_to_track.has_instances()
): # Check if there are detections. If there are skip and increment gap count
# combine the tracked frames with the latest frame; inference pipeline uses latest frame as pred
frames_to_track = tracked_frames + [
frame_to_track
] # better var name?
@@ -217,7 +223,7 @@
self.track_queue.add_frame(frame_to_track)
else:
self.track_queue.increment_gaps([])

# update the frame object from the input inference untracked clip
frames[batch_idx] = frame_to_track
return frames

@@ -252,7 +258,7 @@ def _run_global_tracker(
# E.g.: instances_per_frame: [4, 5, 6, 7]; window of length 4 with 4 detected instances in the first frame of the window.

_ = model.eval()

# get the last frame in the clip to perform inference on
query_frame = frames[query_ind]

query_instances = query_frame.instances
@@ -279,8 +285,10 @@

# (L=1, n_query, total_instances)
with torch.no_grad():
# GTR knows this is for inference since query_instances is not None
asso_matrix = model(all_instances, query_instances)

# GTR output is n_query x n_instances - split this into per-frame to softmax each frame separately
asso_output = asso_matrix[-1].matrix.split(
instances_per_frame, dim=1
) # (window_size, n_query, N_i)
@@ -296,7 +304,7 @@

asso_output_df.index.name = "Instances"
asso_output_df.columns.name = "Instances"

# save the association matrix to the Frame object
query_frame.add_traj_score("asso_output", asso_output_df)
query_frame.asso_output = asso_matrix[-1]

@@ -343,6 +351,8 @@

query_frame.add_traj_score("asso_nonquery", asso_nonquery_df)

# need frame height and width to scale boxes during post-processing
_, h, w = query_frame.img_shape.flatten()
pred_boxes = model_utils.get_boxes(all_instances)
query_boxes = pred_boxes[query_inds] # n_k x 4
nonquery_boxes = pred_boxes[nonquery_inds] # n_nonquery x 4
@@ -374,7 +384,7 @@

query_frame.add_traj_score("decay_time", decay_time_traj_score)
################################################################################

# reduce association matrix - aggregating reference instance association scores by tracks
# (n_query x n_nonquery) x (n_nonquery x n_traj) --> n_query x n_traj
traj_score = torch.mm(traj_score, id_inds.cpu()) # (n_query, n_traj)

@@ -387,6 +397,7 @@

query_frame.add_traj_score("traj_score", traj_score_df)
################################################################################
# IOU-based post-processing; add a weighted IOU across successive frames to association scores

# with iou -> combining with location in tracker, they set to True
# todo -> should also work without pos_embed
@@ -421,11 +432,12 @@

query_frame.add_traj_score("weight_iou", iou_traj_score)
################################################################################
# filters association matrix such that instances too far from each other get scores=0

# threshold for continuing a tracking or starting a new track -> they use 1.0
# todo -> should also work without pos_embed
traj_score = post_processing.filter_max_center_dist(
traj_score, self.max_center_dist, query_boxes, nonquery_boxes, id_inds
traj_score, self.max_center_dist, query_boxes, nonquery_boxes, id_inds, h, w
)

if self.max_center_dist is not None and self.max_center_dist > 0:
@@ -439,6 +451,7 @@
query_frame.add_traj_score("max_center_dist", max_center_dist_traj_score)

################################################################################
# softmax along tracks for each instance, for interpretability
scaled_traj_score = torch.softmax(traj_score, dim=1)
scaled_traj_score_df = pd.DataFrame(
scaled_traj_score.numpy(), columns=unique_ids.cpu().numpy()
Expand All @@ -449,6 +462,7 @@ def _run_global_tracker(
query_frame.add_traj_score("scaled", scaled_traj_score_df)
################################################################################

# hungarian matching
match_i, match_j = linear_sum_assignment((-traj_score))

track_ids = instance_ids.new_full((n_query,), -1)
@@ -462,6 +476,7 @@
thresh = (
overlap_thresh * id_inds[:, j].sum() if mult_thresh else overlap_thresh
)
# if the association score for a query instance is lower than the threshold, create a new track for it
if n_traj >= self.max_tracks or traj_score[i, j] > thresh:
logger.debug(
f"Assigning instance {i} to track {j} with id {unique_ids[j]}"
21 changes: 16 additions & 5 deletions dreem/io/config.py
@@ -226,7 +226,7 @@ def get_dataset(
mode: str,
label_files: list[str] | None = None,
vid_files: list[str | list[str]] = None,
) -> "SleapDataset" | "MicroscopyDataset" | "CellTrackingDataset":
) -> "SleapDataset" | "MicroscopyDataset" | "CellTrackingDataset" | None:

⚠️ Potential issue

Undefined type names in return type annotation

The return type annotation of get_dataset() includes "SleapDataset" | "MicroscopyDataset" | "CellTrackingDataset" | None, but these classes are not imported at the module level. This leads to undefined names, as indicated by the static analysis tool.

Consider importing the dataset classes at the top of the module:

+from dreem.datasets import MicroscopyDataset, SleapDataset, CellTrackingDataset

Alternatively, if you wish to avoid importing at the top level due to potential circular imports, you can use a TYPE_CHECKING block:

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from dreem.datasets import MicroscopyDataset, SleapDataset, CellTrackingDataset
🧰 Tools: 🪛 Ruff
202-202: Undefined name SleapDataset (F821)
202-202: Undefined name MicroscopyDataset (F821)
202-202: Undefined name CellTrackingDataset (F821)

"""Getter for datasets.

Args:
@@ -301,6 +301,10 @@ def get_dataset(
"Could not resolve dataset type from Config! Please include \
either `slp_files` or `tracks`/`source`"
)
if len(dataset) == 0:
logger.warn(f"Length of {mode} dataset is {len(dataset)}! Returning None")
return None
return dataset
Comment on lines +304 to +307

⚠️ Potential issue

Use logger.warning instead of deprecated logger.warn

The logger.warn() method is deprecated since Python 3.3 and has been replaced with logger.warning(). Update the logging calls to use logger.warning().

Apply this diff:

-            logger.warn(f"Length of {mode} dataset is {len(dataset)}! Returning None")
+            logger.warning(f"Length of {mode} dataset is {len(dataset)}! Returning None")


@property
def data_paths(self):
@@ -319,9 +323,9 @@ def data_paths(self, paths: tuple[str, list[str]]):

def get_dataloader(
self,
dataset: "SleapDataset" | "MicroscopyDataset" | "CellTrackingDataset",
dataset: "SleapDataset" | "MicroscopyDataset" | "CellTrackingDataset" | None,

⚠️ Potential issue

Undefined type names in parameter annotation

In the parameter annotation of get_dataloader(), the type dataset: "SleapDataset" | "MicroscopyDataset" | "CellTrackingDataset" | None includes types that are not imported at the module level. This results in undefined names, as reported by the static analysis tool.

To fix this, import the dataset classes at the top of the module:

+from dreem.datasets import MicroscopyDataset, SleapDataset, CellTrackingDataset

Or, if you prefer to avoid top-level imports:

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from dreem.datasets import MicroscopyDataset, SleapDataset, CellTrackingDataset
🧰 Tools: 🪛 Ruff
263-263: Undefined name SleapDataset (F821)
263-263: Undefined name MicroscopyDataset (F821)
263-263: Undefined name CellTrackingDataset (F821)

mode: str,
) -> torch.utils.data.DataLoader:
) -> torch.utils.data.DataLoader | None:
"""Getter for dataloader.

Args:
@@ -350,14 +354,21 @@ def get_dataloader(
else:
pin_memory = False

return torch.utils.data.DataLoader(
dataloader = torch.utils.data.DataLoader(

🛠️ Refactor suggestion

Consider making batch_size configurable

Currently, batch_size is hardcoded to 1 in the DataLoader. If you intend to allow different batch sizes, consider retrieving batch_size from dataloader_params to make it configurable.

Apply this diff to use the batch_size from dataloader_params:

-            dataloader = torch.utils.data.DataLoader(
-                dataset=dataset,
-                batch_size=1,
-                pin_memory=pin_memory,
-                collate_fn=dataset.no_batching_fn,
-                **dataloader_params,
-            )
+            dataloader = torch.utils.data.DataLoader(
+                dataset=dataset,
+                pin_memory=pin_memory,
+                collate_fn=dataset.no_batching_fn,
+                **dataloader_params,
+            )

dataset=dataset,
batch_size=1,
pin_memory=pin_memory,
collate_fn=dataset.no_batching_fn,
**dataloader_params,
)

if len(dataloader) == 0:
logger.warn(
f"Length of {mode} dataloader is {len(dataloader)}! Returning `None`"
)
Comment on lines +366 to +368

⚠️ Potential issue

Replace deprecated logger.warn with logger.warning

The logger.warn() method is deprecated. Update the logging call to logger.warning().

Apply this diff:

-        logger.warn(
+        logger.warning(

return None
return dataloader

def get_optimizer(self, params: Iterable) -> torch.optim.Optimizer:
"""Getter for optimizer.

@@ -492,7 +503,7 @@ def get_checkpointing(self) -> pl.callbacks.ModelCheckpoint:
filename=f"{{epoch}}-{{{metric}}}",
**checkpoint_params,
)
checkpointer.CHECKPOINT_NAME_LAST = f"{{epoch}}-best-{{{metric}}}"
checkpointer.CHECKPOINT_NAME_LAST = f"{{epoch}}-final-{{{metric}}}"
checkpointers.append(checkpointer)
return checkpointers

6 changes: 5 additions & 1 deletion dreem/io/instance.py
@@ -565,7 +565,11 @@ def add_embedding(self, emb_type: str, embedding: torch.Tensor) -> None:
emb_type: Key/embedding type to be saved to dictionary
embedding: The actual torch tensor embedding.
"""
embedding = _expand_to_rank(embedding, 2)
if (
type(embedding) != dict
): # for embedding agg method "average", input is array
# for method stack and concatenate, input is dict
embedding = _expand_to_rank(embedding, 2)
self._embeddings[emb_type] = embedding

@property
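A usage sketch of the two input forms this change supports, given an already-constructed dreem Instance; the per-axis dict keys below are illustrative, since the diff only states that "stack"/"concatenate" pass a dict:

import torch

# "average" aggregation: a single tensor is expanded to rank 2 and stored.
instance.add_embedding("pos", torch.rand(1024))

# "stack"/"concatenate" aggregation: a dict of per-axis tensors is stored as-is.
# Keys are hypothetical here.
instance.add_embedding(
    "pos", {"t": torch.rand(1024), "x": torch.rand(1024), "y": torch.rand(1024)}
)
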
76 changes: 65 additions & 11 deletions dreem/models/attention_head.py
@@ -9,23 +9,40 @@
class ATTWeightHead(torch.nn.Module):
"""Single attention head."""

def __init__(
self,
feature_dim: int,
num_layers: int,
dropout: float,
):
def __init__(self, feature_dim: int, num_layers: int, dropout: float, **kwargs):
"""Initialize an instance of ATTWeightHead.

Args:
feature_dim: The dimensionality of input features.
num_layers: The number of hidden layers in the MLP.
dropout: Dropout probability.
embedding_agg_method: how the embeddings are aggregated; average/stack/concatenate
"""
super().__init__()
if "embedding_agg_method" in kwargs:
self.embedding_agg_method = kwargs["embedding_agg_method"]
else:
self.embedding_agg_method = None
Comment on lines +22 to +25

Simplify the initialization of embedding_agg_method.

Instead of using an if block, you can use the get method for simplicity and clarity.

- if "embedding_agg_method" in kwargs:
-     self.embedding_agg_method = kwargs["embedding_agg_method"]
- else:
-     self.embedding_agg_method = None
+ self.embedding_agg_method = kwargs.get("embedding_agg_method", None)
🧰 Tools: 🪛 Ruff
22-25: Use self.embedding_agg_method = kwargs.get("embedding_agg_method", None) instead of an if block (SIM401)


self.q_proj = MLP(feature_dim, feature_dim, feature_dim, num_layers, dropout)
self.k_proj = MLP(feature_dim, feature_dim, feature_dim, num_layers, dropout)
# if using stacked embeddings, use 1x1 conv with x,y,t embeddings as channels
# ensures output represents ref instances by query instances
if self.embedding_agg_method == "stack":
self.q_proj = torch.nn.Conv1d(
in_channels=3, out_channels=1, kernel_size=1, stride=1, padding=0
)
self.k_proj = torch.nn.Conv1d(
in_channels=3, out_channels=1, kernel_size=1, stride=1, padding=0
)
self.attn_x = torch.nn.MultiheadAttention(feature_dim, 1)
self.attn_y = torch.nn.MultiheadAttention(feature_dim, 1)
self.attn_t = torch.nn.MultiheadAttention(feature_dim, 1)
else:
self.q_proj = MLP(
feature_dim, feature_dim, feature_dim, num_layers, dropout
)
self.k_proj = MLP(
feature_dim, feature_dim, feature_dim, num_layers, dropout
)

def forward(
self,
@@ -41,8 +58,45 @@ def forward(
Returns:
Output tensor of shape (batch_size, num_frame_instances, num_window_instances).
"""
k = self.k_proj(key)
q = self.q_proj(query)
attn_weights = torch.bmm(q, k.transpose(1, 2))
batch_size, num_query_instances, feature_dim = query.size()
num_window_instances = key.shape[1]

# if stacked embeddings, create channels for each x,y,t embedding dimension
# maps shape (1,num_instances*3,feature_dim) -> (num_instances,3,feature_dim)
if self.embedding_agg_method == "stack":
key_stacked = (
key
.view(batch_size, 3, num_window_instances // 3, feature_dim)
.permute(0, 2, 1, 3)
.squeeze(0) # keep as (num_instances*3, feature_dim)
)
key_orig = key.squeeze(0) # keep as (num_instances*3, feature_dim)

query = (
query.view(batch_size, 3, num_query_instances // 3, feature_dim)
.permute(0, 2, 1, 3)
.squeeze(0)
)
# pass t,x,y frame features through cross attention with entire encoder 3*num_window_instances tokens before MLP;
# note order is t,x,y
out_t, _ = self.attn_t(query=query[:,0,:], key=key_orig, value=key_orig)
out_x, _ = self.attn_x(query=query[:,1,:], key=key_orig, value=key_orig)
out_y, _ = self.attn_y(query=query[:,2,:], key=key_orig, value=key_orig)
# combine each attention output to (num_instances, 3, feature_dim)
collated = torch.stack((out_t, out_x, out_y), dim=0).permute(1,0,2)
# mlp_out has shape (1, num_window_instances, feature_dim)
mlp_out = self.q_proj(collated).transpose(1,0)

# key, query of shape (num_instances, 3, feature_dim)
# TODO: uncomment this if not using modified attention heads for t,x,y
k = self.k_proj(key_stacked).transpose(1, 0)
# q = self.q_proj(query).transpose(1, 0)
# k,q of shape (num_instances, feature_dim)
attn_weights = torch.bmm(mlp_out, k.transpose(1, 2))
else:
k = self.k_proj(key)
q = self.q_proj(query)
attn_weights = torch.bmm(q, k.transpose(1, 2))


return attn_weights # (B, N_t, N)
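
A self-contained sketch of the "stack" aggregation idea above: a 1x1 Conv1d treats the t, x, y embeddings as three input channels and learns a weighted sum of the three channels, shared across the feature dimension:

import torch

n_instances, feature_dim = 5, 256
# Stacked embeddings: one channel each for the t, x, and y components.
stacked = torch.rand(n_instances, 3, feature_dim)

proj = torch.nn.Conv1d(in_channels=3, out_channels=1, kernel_size=1)
collapsed = proj(stacked)  # (n_instances, 1, feature_dim): learned mix of t/x/y
print(collapsed.shape)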