
RoPE embeddings #78

Open · shaikh58 wants to merge 66 commits into main
Conversation

@shaikh58 (Contributor) commented Aug 16, 2024

  • Implementation of Rotary Positional Embeddings (RoPE)
  • Refactored Transformer, TransformerDecoder, TransformerEncoder, Layers, and Embeddings to abstract out the application of the embedding; since this change supports multiple types of embedding aggregation, and RoPE transforms the feature vector through matrix multiplication rather than addition, all of the above needed refactoring to reduce the number of places in the code where embeddings are processed (a minimal sketch of the rotation idea follows this list)
  • Integrated into the config structure; supports stack and concatenate methods for the transformed feature vectors
  • Backward compatible with fixed embeddings and averaging of embeddings
  • Passes all tests locally
  • Supports adding embeddings to Instance objects, though retrieval is slightly different since the embedding can now have separate x and y components rather than a single array
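For readers unfamiliar with RoPE, here is a minimal, self-contained sketch of the pairwise rotation it performs (illustrative only: the function name, shapes, and base are assumptions, not this PR's actual RotaryPositionalEmbeddings API):

import torch

def rope_rotate(x: torch.Tensor, positions: torch.Tensor, base: float = 10000.0) -> torch.Tensor:
    """Rotate consecutive feature pairs of x by position-dependent angles."""
    # x: (n_tokens, d_model) with d_model even; positions: (n_tokens,) indices
    d = x.shape[-1]
    inv_freq = base ** (-torch.arange(0, d, 2, dtype=torch.float32) / d)  # (d/2,)
    angles = positions.float()[:, None] * inv_freq[None, :]               # (n_tokens, d/2)
    cos, sin = angles.cos(), angles.sin()
    x1, x2 = x[..., 0::2], x[..., 1::2]              # even/odd feature pairs
    rotated = torch.stack((x1 * cos - x2 * sin,      # 2-D rotation of each pair
                           x1 * sin + x2 * cos), dim=-1)
    return rotated.flatten(-2)                       # back to (n_tokens, d_model)

tokens = torch.randn(5, 8)
out = rope_rotate(tokens, torch.arange(5))  # multiplicative, unlike additive sine embeddings
assert out.shape == tokens.shape

Because the rotation multiplies the features rather than adding to them, it cannot be folded into the old additive embedding path, which is what motivated the refactor described above.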

Summary by CodeRabbit

  • New Features

    • Added support for Rotary Positional Embeddings in the embedding module, improving handling of temporal and positional data.
    • Introduced a new class for Rotary Positional Embeddings to optimize performance in models utilizing these embeddings.
    • Enhanced configuration management for batch training jobs, allowing for flexible hyperparameter setups.
    • Improved tracking functionality with enhanced instance management during inference.
    • Implemented a structured approach to dataset management in training scripts, including visualization options for training batches.
    • Updated methods to handle potential absence of datasets and dataloaders, providing clearer feedback during execution.
  • Bug Fixes

    • Adjusted the handling of embedding inputs to accommodate different types for increased flexibility.
    • Enhanced error handling in dataset and dataloader methods to manage cases where no valid data is available.
  • Documentation

    • Enhanced comments and documentation across various files to improve code readability and understanding.
  • Chores

    • Improved the structure and readability of code in multiple files, including whitespace adjustments.

shaikh58 and others added 26 commits July 31, 2024 15:23
- changes to embedding class
- add apply() function to Embedding class
- remove references to embedding from encoderlayer fwd pass
- add support for stack/avg/concatenate
- move embedding processing out of transformer and into encoder
- get centroid from x,y for spatial embedding
- complete stack agg method
- add docstrings
- concatenation method with mlp
- complete pre-processing for input to EncoderLayer
- fix shape issues in rope/additive_embedding/forward modules in embedding.py
- bounding box embedding only for method "average" - modify emb_funcs routing
- temporarily remove support for adding embeddings into instance objects - need to make compatible with x,y,t embeddings
- remove config yamls from updates - current versions serve as templates
- runs through to end of encoder forward pass
- implement embeddings for decoder + refactor
- add 1x1 conv to final attn head to deal with stacked embeddings (3x tokens) and create channels for each dim
- bug fix in rope rotation matrix product with input data
- 1x1 conv for stack embedding
- stack into 3 channels for x,y,t
- add unit tests for rope
- Update existing tests to use new args/return params related to tfmr
- Modify test to remove return_embedding=True support - need to address this
- create rope instance once rather than each fwd pass
- construct embedding lookup array each fwd pass based on num instances passed in to embedding
- scale only pos embs * 100 rather than also temp embs
- times array for embedding for encoder queries inside decoder was of query size rather than ref size
@shaikh58 linked an issue Aug 16, 2024 that may be closed by this pull request
@coderabbitai bot (Contributor) commented Aug 16, 2024

Walkthrough

The recent changes enhance the dreem framework by refining the SleapDataset class, modifying inference scripts for batch processing, and introducing a new embedding strategy with Rotary Positional Embeddings. Configuration management has been improved to handle datasets and hyperparameters more flexibly. The modifications to logging and error handling across various components contribute to a more robust training and evaluation process. Overall, these updates promote cleaner code and a more structured approach to model training and inference.

Changes

  • dreem/datasets/sleap_dataset.py: Renamed self.videos to self.vid_readers, updated video frame retrieval logic, and modified error handling. Adjusted anchors parameter handling.
  • dreem/inference/eval.py: Updated run function to load hyperparameters from batch_config, added user input fallback for pod index, and simplified checkpoint loading logic.
  • dreem/inference/track.py: Modified run function to check for batch_config, read hyperparameters from a CSV, and updated logging statements.
  • dreem/io/config.py: Enhanced get_dataset and get_dataloader methods to allow None return values with warning logs for empty datasets.
  • dreem/models/embedding.py: Introduced RotaryPositionalEmbeddings class and updated Embedding class to support new aggregation methods and embedding logic.
  • dreem/models/gtr_runner.py: Simplified extraction of video names and ground truth track IDs in on_test_epoch_end method.
  • dreem/models/transformer.py: Updated Transformer class to handle new embedding_agg_method, refactored embedding application logic, and adjusted forward methods for flexibility.
  • dreem/training/configs/base.yaml: Modified training configurations to include embedding_agg_method, dataset modes, and updated checkpointing strategy.
  • dreem/training/train.py: Enhanced dataset management and logging, added batch viewing functionality, and improved device configuration handling.

Possibly related PRs

  • Refactor visual encoder and features #101: The changes in this PR involve modifications to the SleapDataset class in sleap_dataset.py, specifically the addition of a normalize_image parameter, which is relevant to the main PR's updates to the same class, particularly regarding how video data is managed and potentially normalized.

🐰 In the garden, hopping around,
New changes sprout from the ground.
Embeddings and configs, all in a row,
Streamlined processes help us grow!
Repositories tidy, code shining bright,
A joyful journey, oh what a sight! 🌼✨



@coderabbitai bot (Contributor) left a comment

Actionable comments posted: 28

Outside diff range, codebase verification and nitpick comments (4)
run_trainer.py (1)

10-10: Avoid commented-out code.

Instead of commenting out the override configuration, consider using a command-line argument or environment variable to specify which configuration file to use.

- # params_config = "./configs/override.yaml"
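One way to realize that suggestion (a sketch only; the flag name and paths are illustrative, and the actual entry point may differ):

import argparse
from omegaconf import OmegaConf

parser = argparse.ArgumentParser()
parser.add_argument("--params-config", default=None,
                    help="optional override config, e.g. ./configs/override.yaml")
args = parser.parse_args()

cfg = OmegaConf.load("./configs/base.yaml")
if args.params_config:  # merge overrides only when explicitly requested
    cfg = OmegaConf.merge(cfg, OmegaConf.load(args.params_config))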
tests/test_models.py (2)

247-248: Clarify the shape of input data.

The comment on line 247 specifies the shape of the input data. Ensure this comment remains accurate and update it if the input shape changes in future modifications.


317-324: Ensure consistency in input data shape comments.

The comment on line 323 specifies the shape of the input data. Ensure consistency across the file by maintaining accurate comments about data shapes.

rope.ipynb (1)

307-311: Avoid hardcoded dataset paths.

The dataset path is hardcoded, which can lead to issues when running the notebook in different environments. Consider parameterizing the path or using a configuration file.

- train_path = "/home/jovyan/talmolab-smb/datasets/mot/microscopy/airyscan_proofread/Final/dreem-train"
+ train_path = os.getenv("DREEM_TRAIN_PATH", "/default/path/to/dreem-train")
Review details


Commits

Files that changed from the base of the PR and between 2af0dd5 and c4abac2.

Files ignored due to path filters (1)
  • dreem/training/configs/test_batch_train.csv is excluded by !**/*.csv
Files selected for processing (15)
  • .gitignore (1 hunks)
  • dreem/inference/tracker.py (14 hunks)
  • dreem/io/config.py (1 hunks)
  • dreem/io/instance.py (1 hunks)
  • dreem/models/attention_head.py (2 hunks)
  • dreem/models/embedding.py (13 hunks)
  • dreem/models/mlp.py (2 hunks)
  • dreem/models/transformer.py (14 hunks)
  • dreem/training/configs/base.yaml (6 hunks)
  • dreem/training/configs/override.yaml (1 hunks)
  • dreem/training/train.py (2 hunks)
  • rope.ipynb (1 hunks)
  • run_trainer.py (1 hunks)
  • tests/test_models.py (12 hunks)
  • tests/test_training.py (1 hunks)
Files skipped from review due to trivial changes (2)
  • .gitignore
  • dreem/training/train.py
Additional context used
yamllint
dreem/training/configs/base.yaml

[warning] 19-19: wrong indentation: expected 6 but found 8

(indentation)


[warning] 22-22: wrong indentation: expected 6 but found 8

(indentation)

dreem/training/configs/override.yaml

[error] 3-3: trailing spaces

(trailing-spaces)


[warning] 4-4: wrong indentation: expected 4 but found 6

(indentation)


[warning] 20-20: wrong indentation: expected 6 but found 8

(indentation)


[warning] 24-24: wrong indentation: expected 6 but found 8

(indentation)


[warning] 29-29: too many spaces after colon

(colons)


[error] 43-43: trailing spaces

(trailing-spaces)


[warning] 46-46: wrong indentation: expected 6 but found 8

(indentation)


[warning] 58-58: too few spaces after comma

(commas)


[warning] 66-66: too few spaces after comma

(commas)


[error] 76-76: trailing spaces

(trailing-spaces)


[warning] 78-78: wrong indentation: expected 4 but found 6

(indentation)


[error] 78-78: trailing spaces

(trailing-spaces)


[error] 99-99: trailing spaces

(trailing-spaces)


[error] 112-112: trailing spaces

(trailing-spaces)


[error] 130-130: trailing spaces

(trailing-spaces)


[warning] 132-132: wrong indentation: expected 2 but found 4

(indentation)


[warning] 133-133: wrong indentation: expected 6 but found 8

(indentation)


[warning] 137-137: wrong indentation: expected 6 but found 8

(indentation)


[error] 142-142: no new line character at the end of file

(new-line-at-end-of-file)

Ruff
dreem/models/attention_head.py

28-31: Use self.embedding_agg_method = kwargs.get('embedding_agg_method', None) instead of an if block

Replace with self.embedding_agg_method = kwargs.get('embedding_agg_method', None)

(SIM401)

tests/test_models.py

18-18: dreem.models.transformer.apply_embeddings imported but unused

Remove unused import: dreem.models.transformer.apply_embeddings

(F401)


166-166: pytest.raises(Exception) should be considered evil

(B017)

dreem/inference/tracker.py

196-196: Loop control variable i not used within loop body

Rename unused i to _i

(B007)

dreem/io/instance.py

568-568: Use is and is not for type comparisons, or isinstance() for isinstance checks

(E721)

dreem/models/embedding.py

267-267: f-string without any placeholders

Remove extraneous f prefix

(F541)

dreem/models/transformer.py

84-85: Use embedding_meta.get("embedding_agg_method", "average") instead of an if block

Replace with embedding_meta.get("embedding_agg_method", "average")

(SIM401)

Additional comments not posted (39)
run_trainer.py (1)

15-15: LGTM!

The training function is correctly invoked with the configuration.

dreem/models/mlp.py (2)

37-37: Comment improves readability.

The added comment clarifies the purpose of list concatenations, enhancing code readability.


59-59: Ensures device compatibility.

Moving each layer to the device of the input tensor is a good practice to avoid device mismatch errors.

dreem/training/configs/base.yaml (5)

23-23: LGTM!

The embedding_agg_method is correctly set to "stack".


Line range hint 76-94:
Improves configuration clarity.

The addition of mode settings for datasets enhances clarity and explicitness in configuration.

Tools
yamllint

[error] 82-82: trailing spaces

(trailing-spaces)


117-117: LGTM!

Setting log_model to null aligns with the goal of reducing logging verbosity.


134-134: Enhances checkpointing strategy.

Setting every_n_epochs to 1 improves recovery options during training.


147-148: LGTM!

Setting max_epochs and min_epochs to 1 aligns with the goal of quick iterations or testing.

dreem/models/attention_head.py (2)

33-44: Ensure correct handling of embedding aggregation methods.

The logic for handling the "stack" method using torch.nn.Conv1d is correctly implemented. Ensure that the reshaping and permutation logic in the forward method aligns with the expected input shapes for the convolutional layers.


60-79: Verify tensor reshaping and permutation in forward.

The reshaping and permutation logic for the "stack" method should be verified to ensure it correctly aligns with the input requirements of the convolutional layers. Ensure that the dimensions and operations are consistent with the intended design.

Verification successful

Tensor reshaping and permutation logic is verified.

The reshaping and permutation logic in the forward method aligns with the input requirements of the Conv1d layers used in k_proj and q_proj. The reshaped tensors have the correct channel dimension of 3, consistent with the Conv1d configuration. The use of squeeze(0) is appropriate given the expected batch size of 1. No issues found with the current implementation.
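For reference, a toy example of the shape contract being verified here (the dimensions are illustrative, not the model's actual sizes):

import torch

# stacking x, y, t embeddings yields 3 channels per token; a 1x1 Conv1d
# collapses them back to a single channel over the feature dimension
conv = torch.nn.Conv1d(in_channels=3, out_channels=1, kernel_size=1)
stacked = torch.randn(1, 3, 256)  # (batch=1, x/y/t channels, d_model)
collapsed = conv(stacked)         # -> shape (1, 1, 256)
assert collapsed.shape == (1, 1, 256)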

Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Verify the tensor reshaping and permutation logic in the `forward` method.

# Test: Search for the `forward` method to review its implementation.
ast-grep --lang python --pattern $'def forward(self, query: torch.Tensor, key: torch.Tensor) -> torch.Tensor:\n    $$$' dreem/models/attention_head.py

Length of output: 2923


Script:

#!/bin/bash
# Description: Search for the definitions of `k_proj` and `q_proj` to verify compatibility with reshaped tensors.

# Find the definition of `k_proj` in the codebase.
ast-grep --lang python --pattern 'self.k_proj = $_' dreem/models/attention_head.py

# Find the definition of `q_proj` in the codebase.
ast-grep --lang python --pattern 'self.q_proj = $_' dreem/models/attention_head.py

Length of output: 1009

tests/test_training.py (1)

145-146: Reconsider commenting out anomaly detection.

Commenting out torch.autograd.set_detect_anomaly(True) reduces the ability to debug and catch subtle errors during the execution of run(cfg.cfg). Consider re-enabling it, especially during development or when debugging complex issues.
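If the overhead is the concern, anomaly detection can also be scoped to a single block rather than enabled globally (a generic PyTorch sketch, not project code):

import torch

torch.autograd.set_detect_anomaly(True)   # global switch, as in the commented-out line
torch.autograd.set_detect_anomaly(False)

with torch.autograd.detect_anomaly():     # scoped alternative: only this block pays the cost
    loss = (torch.randn(3, requires_grad=True) * 2).sum()
    loss.backward()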

Verification successful

Reconsider commenting out anomaly detection in tests/test_training.py.

The torch.autograd.set_detect_anomaly(True) line is commented out and not used elsewhere in the test suite. Re-enabling it could enhance debugging capabilities by catching subtle errors during test execution.

  • File: tests/test_training.py, line 145
Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Verify if anomaly detection is used elsewhere in the tests.

# Test: Search for usage of `set_detect_anomaly` in the test directory.
rg --type python 'set_detect_anomaly'

Length of output: 68


Script:

#!/bin/bash
# Description: Verify if anomaly detection is used elsewhere in the tests.

# Test: Search for usage of `set_detect_anomaly` in the test directory.
rg 'set_detect_anomaly' tests/

Length of output: 105

dreem/io/config.py (1)

43-43: Consider the implications of making the configuration mutable.

The addition of OmegaConf.set_struct(self.cfg, False) allows dynamic changes to the configuration object. Ensure this change aligns with the intended design and does not introduce unintended side effects, such as accidental modifications to the configuration during runtime.
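For context, the behavioral difference being flagged (a minimal OmegaConf example, unrelated to dreem's actual config contents):

from omegaconf import OmegaConf

cfg = OmegaConf.create({"model": {"d_model": 256}})
OmegaConf.set_struct(cfg, True)   # struct mode: unknown keys raise an error
# cfg.new_key = 1                 # would raise ConfigAttributeError here
OmegaConf.set_struct(cfg, False)  # non-struct mode: keys can be added at runtime
cfg.new_key = 1                   # silently extends the config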

tests/test_models.py (2)

38-38: Ensure proper testing of embedding_agg_method.

The addition of embedding_agg_method="average" in ATTWeightHead tests the new parameter. Ensure that tests cover all possible values and edge cases for this parameter to validate its behavior thoroughly.


186-232: New test for RoPE embeddings is comprehensive.

The test_rope_embedding function effectively tests the new RoPE embedding feature, checking output sizes and ensuring distinctness from input data. This test enhances the robustness of the embedding logic.

dreem/inference/tracker.py (5)

141-144: Improved clarity in tracking logic.

The added comments clarify the purpose of the sliding_inference function and the conditions under which the track queue is cleared. This enhances the readability and maintainability of the code.


169-171: Enhanced understanding of frame processing.

The comments provide context for the processing of frames in the sliding_inference function, explaining the roles of tracked_frames and frame_to_track. This improves the clarity of the tracking logic.


261-261: Clarify inference target.

The comment about getting the last frame for inference clarifies the target of the inference process. This helps in understanding the flow of data during tracking.


385-385: Clarify association matrix reduction.

The comment explains the reduction of the association matrix, which is crucial for understanding the trajectory scoring process. This enhances the code's readability.


433-433: Clarify filtering logic.

The comment about filtering the association matrix based on distance provides insight into the logic for maintaining or initiating tracks. This improves the understanding of the tracking process.

rope.ipynb (20)

12-25: Imports look good.

The imported libraries and modules appear necessary for the operations in this notebook.


39-39: Logger setup is appropriate.

The logger is set up correctly for the dreem.models module.


61-73: Constructor parameters are well-documented.

The constructor of the Embedding class is well-documented with clear descriptions of each parameter.


91-92: Good use of argument validation.

The _check_init_args method ensures that the embedding type and mode are valid, which helps prevent runtime errors.


105-106: Conditional logic for scale initialization is appropriate.

The logic to initialize self.scale when normalize is true and scale is None is correctly implemented.


125-127: Embedding function defaults to zero embeddings.

The default _emb_func returns zero embeddings, effectively disabling embeddings when not configured otherwise. This is a sensible default.


147-168: Argument validation in _check_init_args is robust.

The method properly raises exceptions for invalid emb_type and mode values, enhancing robustness.


170-190: Forward method correctly handles embedding dimension mismatch.

The forward method raises a RuntimeError if the output embedding dimension doesn't match the expected features, providing a helpful hint for resolution.


192-204: Integer division method uses rounding_mode correctly.

The _torch_int_div method uses rounding_mode="floor" to ensure correct integer division, which is appropriate for the use case.
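The distinction matters for negative operands, as a quick example shows:

import torch

t = torch.tensor([7, -7])
torch.div(t, 2, rounding_mode="floor")  # tensor([ 3, -4]): true floor division
torch.div(t, 2, rounding_mode="trunc")  # tensor([ 3, -3]): rounds toward zero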


206-249: Sine box embedding method is well-implemented.

The _sine_box_embedding method computes sine positional embeddings correctly, with appropriate handling for normalization and scaling.
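For reference, the generic shape of a DETR-style sine embedding that this method presumably follows (a sketch, not the method's exact code):

import torch

def sine_embedding(pos: torch.Tensor, d: int = 128, temperature: float = 10000.0) -> torch.Tensor:
    # pos: (N,) positions (optionally normalized and scaled); returns (N, d)
    dim_t = temperature ** (2 * (torch.arange(d) // 2).float() / d)  # (d,)
    vals = pos[:, None] / dim_t[None, :]                             # (N, d)
    out = torch.zeros_like(vals)
    out[:, 0::2] = vals[:, 0::2].sin()   # sine on even dims
    out[:, 1::2] = vals[:, 1::2].cos()   # cosine on odd dims
    return out

emb = sine_embedding(torch.linspace(0, 1, 10) * 100)  # fractions scaled up, cf. the PR's *100
assert emb.shape == (10, 128)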


251-281: Sine temporal embedding method is efficient.

The _sine_temp_embedding method efficiently computes fixed sine temporal embeddings, leveraging tensor operations.


293-295: Embedding object creation is appropriate.

The parameters used for creating Embedding objects for temporal and positional embeddings are appropriate and consistent with the class definition.


323-327: Data extraction logic is correct.

The logic for extracting instances from the dataset is appropriate and efficient.


339-341: Time vector extraction logic is correct.

The logic for extracting time vectors from the instances is appropriate.


378-379: Visualization logic is appropriate.

The logic for visualizing the temporal embedding using matplotlib is correct and provides meaningful insights.


391-393: Encoder instantiation is appropriate.

The parameters used for instantiating TransformerEncoderLayer and VisualEncoder are appropriate and consistent with their respective class definitions.


405-436: Feature extraction function is well-implemented.

The extract_features function correctly extracts features using a visual encoder, with appropriate handling for different scenarios.


450-493: Transformer preparation function is well-implemented.

The prepare_for_xfmr function correctly prepares data for a transformer, with appropriate handling for embeddings and feature vectors.


505-506: Feature extraction invocation is correct.

The invocation of the extract_features function is correct and uses appropriate arguments.


526-563: Transformer instantiation and evaluation logic is appropriate.

The logic for instantiating a Transformer and evaluating it using GTRRunner is correct and consistent with the intended use.

Resolved review threads:

  • dreem/io/instance.py (1, outdated)
  • dreem/models/embedding.py (5; 4 outdated)
  • tests/test_models.py (2, both outdated)
  • dreem/models/transformer.py (5; 4 outdated)
  • dreem/training/configs/override.yaml (5)
  • rope.ipynb (1, outdated)
shaikh58 and others added 22 commits September 27, 2024 09:44
@coderabbitai bot (Contributor) left a comment

Actionable comments posted: 19

🧹 Outside diff range and nitpick comments (2)
dreem/inference/eval.py (1)

Line range hint 68-70: Verify that labels_files and vid_files have the same length

Before zipping labels_files and vid_files, it's important to ensure they have the same length to prevent potential IndexError exceptions or mismatches between labels and videos.

🧰 Tools
🪛 Ruff

30-30: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)


56-56: f-string without any placeholders

Remove extraneous f prefix

(F541)

dreem/inference/track.py (1)

104-110: Consider using os.getenv for environment variable retrieval

Using os.getenv allows you to provide a default value if the environment variable is not set, which can simplify the code.

Apply this code change to use os.getenv:

index = int(os.getenv("POD_INDEX", default=-1))
if index == -1:
    # Handle the case where POD_INDEX is not set
    index = int(input("POD_INDEX not found. Please input task index to run: "))
📜 Review details


📥 Commits

Files that changed from the base of the PR and between fd77ded and 3bc9fef.

📒 Files selected for processing (6)
  • dreem/datasets/sleap_dataset.py (3 hunks)
  • dreem/inference/eval.py (1 hunks)
  • dreem/inference/track.py (2 hunks)
  • dreem/io/config.py (7 hunks)
  • dreem/models/embedding.py (14 hunks)
  • dreem/models/gtr_runner.py (2 hunks)
🧰 Additional context used
🪛 Ruff
dreem/datasets/sleap_dataset.py

163-163: Local variable e is assigned to but never used

Remove assignment to unused variable e

(F841)

dreem/inference/eval.py

30-30: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)


56-56: f-string without any placeholders

Remove extraneous f prefix

(F541)

dreem/inference/track.py

100-100: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)

dreem/io/config.py

202-202: Undefined name SleapDataset

(F821)


202-202: Undefined name MicroscopyDataset

(F821)


202-202: Undefined name CellTrackingDataset

(F821)


263-263: Undefined name SleapDataset

(F821)


263-263: Undefined name MicroscopyDataset

(F821)


263-263: Undefined name CellTrackingDataset

(F821)

dreem/models/embedding.py

104-104: Multiple statements on one line (colon)

(E701)


267-267: f-string without any placeholders

Remove extraneous f prefix

(F541)

🔇 Additional comments (9)
dreem/models/gtr_runner.py (4)

304-304: Improved video name extraction

The simplification of video name extraction by removing the file extension is a good change. It makes the code more robust to different file extensions and simplifies the naming convention in the HDF5 file.
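(The usual idiom for this, shown for context; the actual variable names in gtr_runner.py may differ:)

import os

vid_name = os.path.splitext(os.path.basename("path/to/session1.mp4"))[0]
print(vid_name)  # "session1", regardless of the extension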


Line range hint 304-325: Consistent implementation of changes

The changes to video name processing and ground truth track ID extraction have been applied consistently within the on_test_epoch_end method. This consistency is crucial for maintaining the integrity of the data processing pipeline.


Line range hint 304-325: Overall assessment of changes

The modifications in the on_test_epoch_end method align well with the PR objectives. They contribute to the refactoring of the Transformer architecture by updating how track IDs are accessed, which is likely related to changes in the embedding handling. The code changes are focused, consistent, and improve the robustness of the data processing pipeline.


313-316: Updated ground truth track ID extraction

The new method of extracting ground truth track IDs by iterating over frame.instances and accessing gt_track_id directly is more straightforward. This change appears to be consistent with updates to the Instance class structure or modifications in how track IDs are stored.

To ensure this change is consistent across the codebase, searching for other usages of gt_track_id would help identify any other locations where similar updates might be needed.

Also applies to: 322-325

dreem/datasets/sleap_dataset.py (1)

109-109: Initialization of self.vid_readers is appropriate

The addition of self.vid_readers = {} correctly initializes a dictionary to store video readers, facilitating efficient access and management of video files within the dataset.

dreem/io/config.py (2)

179-181: Improved reproducibility by sorting file lists

Sorting label_files and vid_files ensures consistent ordering of files, which is beneficial for reproducibility and consistent data loading. Good practice.


235-249: 🛠️ Refactor suggestion

Redundant imports inside the method

You're importing MicroscopyDataset, SleapDataset, and CellTrackingDataset inside the get_dataset() method, but their types are used in type annotations outside the method. This can cause issues with type checking and static analysis.

Consider moving the imports to the top of the module to ensure that the type annotations are recognized properly.

+from dreem.datasets import MicroscopyDataset, SleapDataset, CellTrackingDataset

Alternatively, if you need to avoid top-level imports, use forward references in type annotations or import under if TYPE_CHECKING.

Likely invalid or redundant comment.

dreem/models/embedding.py (1)

75-122: Verify the correctness of the RoPE implementation

The RotaryPositionalEmbeddings class's forward method returns rope_cache, which appears to be used as rotation matrices in _apply_rope. Please verify that the dimensions and the application of these rotation matrices align with the RoPE mechanism described in the original paper. Ensuring the rotation is correctly applied to the input tensor x is crucial for the embeddings to function as intended.
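Two properties worth asserting in such a verification: a rotation preserves norms, and the dot product of two rotated vectors depends only on their relative position. A standalone sketch of that check (not wired to the PR's class; the complex-number form is equivalent to pairwise 2-D rotations):

import torch

def rope(x, pos, base=10000.0):
    d = x.shape[-1]
    freqs = pos.float()[:, None] * base ** (-torch.arange(0, d, 2).float() / d)
    xc = torch.view_as_complex(x.reshape(*x.shape[:-1], -1, 2))
    return torch.view_as_real(xc * torch.polar(torch.ones_like(freqs), freqs)).flatten(-2)

q, k = torch.randn(1, 8), torch.randn(1, 8)
# rotation preserves norms
assert torch.allclose(rope(q, torch.tensor([3])).norm(), q.norm(), atol=1e-5)
# <rope(q, m), rope(k, n)> depends only on m - n (here m - n = 3 in both cases)
a = (rope(q, torch.tensor([5])) * rope(k, torch.tensor([2]))).sum()
b = (rope(q, torch.tensor([8])) * rope(k, torch.tensor([5]))).sum()
assert torch.allclose(a, b, atol=1e-4)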

🧰 Tools
🪛 Ruff

104-104: Multiple statements on one line (colon)

(E701)

dreem/inference/track.py (1)

137-144: Ensure that output directory exists before saving results

If the outdir directory does not exist, saving the results may fail. Although os.makedirs(outdir, exist_ok=True) is called earlier, ensure that any subdirectories needed are also created.

Confirming that the output directory exists and is writable before saving would catch this failure early.

Comment on lines 34 to 36
index = int(
input(f"{e}. Assuming single run!\nPlease input task index to run:")
)

🛠️ Refactor suggestion

Add input validation for user input

When prompting the user for the task index, consider adding input validation to handle non-integer inputs gracefully. This prevents the program from crashing if the user enters invalid data.

Comment on lines 46 to 47
hparams = {}


⚠️ Potential issue

Handle failure of set_hparams appropriately

If eval_cfg.set_hparams(hparams) returns False, simply setting hparams = {} may not be sufficient. Consider logging an error message or raising an exception to indicate that the hyperparameters were not set correctly.

logger.info(f"Using the following tracker:")

print(model.tracker)

⚠️ Potential issue

Use logger instead of print statement

To maintain consistent logging practices, replace the print statement with logger.info. This ensures that all output is handled by the logging system.

- print(model.tracker)
+ logger.info(model.tracker)

@@ -26,24 +26,35 @@
"""
eval_cfg = Config(cfg)

if "checkpoints" in cfg.keys():
# update with parameters for batch train job
if "batch_config" in cfg.keys():

⚠️ Potential issue

Simplify dictionary key check

You can simplify the condition by removing .keys() as it's unnecessary. The more Pythonic way is:

- if "batch_config" in cfg.keys():
+ if "batch_config" in cfg:
🧰 Tools
🪛 Ruff

30-30: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)

model.tracker_cfg = eval_cfg.cfg.tracker
model.tracker = Tracker(**model.tracker_cfg)

logger.info(f"Using the following tracker:")

⚠️ Potential issue

Remove unnecessary f-string

The log message does not contain any variables to interpolate, so the f prefix is unnecessary. You can simplify the line as follows:

- logger.info(f"Using the following tracker:")
+ logger.info("Using the following tracker:")
🧰 Tools
🪛 Ruff

56-56: f-string without any placeholders

Remove extraneous f prefix

(F541)

@@ -323,6 +562,7 @@ def _learned_temp_embedding(self, times: torch.Tensor) -> torch.Tensor:
"""
temp_lookup = self.lookup
N = times.shape[0]
times = times / times.max()

⚠️ Potential issue

Handle potential division by zero in time normalization

In the _learned_temp_embedding method, times is normalized by dividing by times.max(). If times.max() is zero, this will result in a division by zero error. To prevent this, consider adding a check to ensure that times.max() is not zero before performing the division.

Apply this diff to handle the potential division by zero:

 def _learned_temp_embedding(self, times: torch.Tensor, *args) -> torch.Tensor:
     # ...
-        times = times / times.max()
+        max_time = times.max()
+        if max_time > 0:
+            times = times / max_time
+        else:
+            times = times.clone()  # All times are zero, no normalization needed

Comment on lines +372 to +380
x_rope = torch.rand(input_shape).unsqueeze(2)
# infer whether it is a positional or temporal embedding
is_pos_emb = 1 if seq_positions.max() < 1 else 0
# if it is positional, scale seq_positions since these are fractions
# in [0,1] and we need int indexes for embedding lookup
seq_positions = seq_positions * 100 if is_pos_emb else seq_positions
seq_positions = seq_positions.unsqueeze(0).int()
# RoPE module takes in dimension, num_queries as input to calculate rotation matrix
rot_mat = self.rope_instance(x_rope, seq_positions)

⚠️ Potential issue

Avoid using random tensor in _rope_embedding

In the _rope_embedding method, x_rope is initialized as a random tensor. This tensor is used to compute the rotation matrix but introduces unnecessary randomness. Consider using a tensor with consistent values, such as zeros, to avoid any unintended side effects and ensure deterministic behavior.

Apply this diff to use a zero tensor instead:

 def _rope_embedding(self, seq_positions: torch.Tensor, input_shape: torch.Size) -> torch.Tensor:
     # ...
-        x_rope = torch.rand(input_shape).unsqueeze(2)
+        x_rope = torch.zeros(input_shape, dtype=torch.float32, device=seq_positions.device).unsqueeze(2)

Comment on lines 104 to 106
index = int(
input(f"{e}. Assuming single run!\nPlease input task index to run:")
)

⚠️ Potential issue

Handle non-integer input when prompting for task index

When the environment variable POD_INDEX is not set, the code prompts the user for input. If the user enters a non-integer value, a ValueError will be raised. Consider adding error handling to ensure that the program handles invalid input gracefully.

Apply this code change to handle invalid inputs:

except KeyError as e:
    print(f"{e}. Assuming single run!")
    while True:
        try:
            index = int(input("Please input task index to run: "))
            break
        except ValueError:
            print("Invalid input. Please enter an integer value for the task index.")

Comment on lines 108 to 110
hparams_df = pd.read_csv(cfg.batch_config)
hparams = hparams_df.iloc[index].to_dict()
_ = hparams.pop("Unnamed: 0", None)

⚠️ Potential issue

Add boundary check for 'index' to prevent IndexError

When accessing hparams_df.iloc[index], if index is out of bounds, an IndexError will be raised. Consider adding a check to ensure that the index is within the valid range of the DataFrame.

Apply this code change to add boundary checking:

hparams_df = pd.read_csv(cfg.batch_config)
if not 0 <= index < len(hparams_df):
    raise IndexError(f"Task index {index} is out of bounds. Available indices are from 0 to {len(hparams_df) - 1}.")
hparams = hparams_df.iloc[index].to_dict()

@@ -96,25 +96,35 @@
"""
pred_cfg = Config(cfg)

if "checkpoints" in cfg.keys():
# update with parameters for batch train job
if "batch_config" in cfg.keys():

⚠️ Potential issue

Simplify dictionary key check by removing .keys()

In Python, you can directly check if a key exists in a dictionary without using .keys(). This is more efficient and idiomatic.

Apply this diff to simplify the condition:

-if "batch_config" in cfg.keys():
+if "batch_config" in cfg:
🧰 Tools
🪛 Ruff

100-100: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)

@shaikh58 linked an issue Sep 30, 2024 that may be closed by this pull request
@coderabbitai bot (Contributor) left a comment

Actionable comments posted: 14

🧹 Outside diff range and nitpick comments (3)
scripts/run_tracker.py (1)

5-7: Remove commented-out code

The commented-out lines related to changing directories are no longer needed and can be safely removed. Keeping unused code comments can lead to confusion and clutter in the codebase.

Apply this diff to remove the commented-out lines:

-# /Users/mustafashaikh/dreem/dreem/training
-# /Users/main/Documents/GitHub/dreem/dreem/training
-# os.chdir("/Users/main/Documents/GitHub/dreem/dreem/training")
dreem/inference/post_processing.py (1)

Line range hint 126-174: Summary of changes to filter_max_center_dist

The modifications to this function introduce new parameters (h and w) and change the distance calculation logic. While these changes aim to improve functionality, there are several issues that need to be addressed:

  1. The new parameters h and w are not used in the function body.
  2. TODO comments indicate necessary changes that haven't been implemented yet.
  3. The distance normalization logic has been modified, which may affect the function's behavior.

To improve this implementation:

  1. Implement the usage of h and w parameters to scale the distance calculation.
  2. Address the TODO comments by modifying nonk_boxes to only include the previous frame and scaling the distance calculation.
  3. Review and verify the change in distance normalization to ensure it aligns with the intended functionality.
  4. Add or update unit tests to cover these changes and verify the function's behavior.

Please make these adjustments to ensure the function works as intended and maintains consistency with the rest of the codebase.

🧰 Tools
🪛 Ruff

151-151: Local variable k_s is assigned to but never used

Remove assignment to unused variable k_s

(F841)

tests/test_inference.py (1)

218-219: LGTM! Consider clarifying image dimensions assumption.

The addition of h and w parameters to the filter_max_center_dist function calls is a good improvement. It allows the function to handle explicit image dimensions, which is more flexible and robust.

For improved clarity, consider adding a comment explaining the assumption of a square image, or consider using separate variables for height and width to make the test more general:

im_height = 128
im_width = 128
# ... (rest of the code)
h=im_height,
w=im_width

This change would make the test more explicit about handling potentially non-square images, which might be beneficial for future test cases or modifications.

Also applies to: 231-232

📜 Review details


📥 Commits

Files that changed from the base of the PR and between 3bc9fef and 1998f6f.

📒 Files selected for processing (5)
  • dreem/inference/post_processing.py (3 hunks)
  • dreem/inference/tracker.py (14 hunks)
  • scripts/run_tracker.py (1 hunks)
  • scripts/run_trainer.py (1 hunks)
  • tests/test_inference.py (2 hunks)
🧰 Additional context used
🪛 Ruff
dreem/inference/tracker.py

196-196: Loop control variable i not used within loop body

Rename unused i to _i

(B007)


200-200: Loop control variable i not used within loop body

Rename unused i to _i

(B007)

scripts/run_tracker.py

3-3: os imported but unused

Remove unused import: os

(F401)

scripts/run_trainer.py

3-3: os imported but unused

Remove unused import: os

(F401)

🔇 Additional comments (4)
scripts/run_tracker.py (1)

1-12: Overall assessment: Good foundation with room for improvement

The script provides a simple and straightforward way to run the tracking process using a configuration file. It's a good starting point, but there are several areas where it can be improved:

  1. Remove unused imports and commented-out code.
  2. Make the configuration file path more flexible.
  3. Add error handling for both configuration loading and tracking execution.
  4. Consider adding logging for better visibility into the script's execution.

Implementing these suggestions will make the script more robust, flexible, and easier to maintain.

🧰 Tools
🪛 Ruff

3-3: os imported but unused

Remove unused import: os

(F401)

dreem/inference/tracker.py (3)

288-291: Ensure correct indexing of asso_matrix.

When accessing asso_matrix[-1] and asso_matrix[-1].matrix, confirm that asso_matrix is a sequence (e.g., list or tuple) with at least one element. If asso_matrix is a tensor or has an unexpected structure, this could raise an IndexError.


355-355: Verify the structure of img_shape when unpacking.

The line _, h, w = query_frame.img_shape.flatten() assumes that img_shape can be flattened into at least three values. Ensure that img_shape has the correct dimensions and that h and w correspond to the height and width of the image. Misalignment could lead to incorrect scaling in post-processing.


Line range hint 465-479: Review the logic in track assignment condition.

The condition in the if statement may not correctly handle all cases:

if n_traj >= self.max_tracks or traj_score[i, j] > thresh:

Consider the following:

  • If n_traj exceeds max_tracks, the tracker assigns to existing tracks regardless of the traj_score.
  • Comparing traj_score[i, j] (which could be negative) with thresh may not behave as expected, especially if traj_score contains negative values due to earlier operations like taking the negative for linear_sum_assignment.

Please ensure that traj_score values are on the appropriate scale and sign for this comparison, and that the logic correctly determines when to assign to existing tracks versus creating new ones.
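For context on the sign question: maximizing a score matrix with scipy's solver is conventionally done by negating it, which is presumably where negative values enter, and any threshold must then be compared against the original scores (generic sketch):

import numpy as np
from scipy.optimize import linear_sum_assignment

traj_score = np.array([[0.9, 0.1],
                       [0.2, 0.8]])
rows, cols = linear_sum_assignment(-traj_score)  # minimizing the negation maximizes score
for i, j in zip(rows, cols):
    print(i, j, traj_score[i, j])  # threshold against traj_score[i, j], not its negation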

# /Users/mustafashaikh/dreem/dreem/training
# /Users/main/Documents/GitHub/dreem/dreem/training
# os.chdir("/Users/main/Documents/GitHub/dreem/dreem/training")
config = "/root/vast/mustafa/dreem-experiments/run/lysosome-baselines/debug/configs/inference.yaml"

🛠️ Refactor suggestion

Make configuration file path more flexible

The configuration file path is currently hardcoded, which might cause issues when running the script in different environments or by different users. Consider making this path more flexible by using environment variables or command-line arguments.

Here's an example of how you could make the path more flexible using environment variables:

import os

config = os.environ.get('DREEM_CONFIG_PATH', "/root/vast/mustafa/dreem-experiments/run/lysosome-baselines/debug/configs/inference.yaml")

This allows users to set the DREEM_CONFIG_PATH environment variable to override the default path.

# os.chdir("/Users/main/Documents/GitHub/dreem/dreem/training")
config = "/root/vast/mustafa/dreem-experiments/run/lysosome-baselines/debug/configs/inference.yaml"

cfg = OmegaConf.load(config)

🛠️ Refactor suggestion

Add error handling for configuration loading

While using OmegaConf for configuration management is good practice, it's important to handle potential errors that may occur during the loading process. This could include file not found errors or parsing issues with the YAML file.

Consider wrapping the configuration loading in a try-except block to handle potential errors gracefully. Here's an example:

import sys

try:
    cfg = OmegaConf.load(config)
except FileNotFoundError:
    print(f"Error: Configuration file not found at {config}")
    sys.exit(1)
except Exception as e:
    print(f"Error loading configuration: {str(e)}")
    sys.exit(1)

This will provide more informative error messages and prevent the script from crashing unexpectedly if there are issues with the configuration file.


cfg = OmegaConf.load(config)

track.run(cfg)

🛠️ Refactor suggestion

Add error handling for tracking execution

The current implementation doesn't handle potential errors that may occur during the tracking process. Adding error handling can improve the script's robustness and provide more informative feedback if issues arise.

Consider wrapping the tracking execution in a try-except block to handle potential errors gracefully. Here's an example:

try:
    track.run(cfg)
except Exception as e:
    print(f"Error during tracking execution: {str(e)}")
    sys.exit(1)

This will catch any exceptions thrown during the tracking process, print an informative error message, and exit the script with a non-zero status code to indicate failure.

Comment on lines +1 to +3
from dreem.inference import track
from omegaconf import OmegaConf
import os

⚠️ Potential issue

Remove unused import

The os module is imported but not used in the script. To keep the imports clean and avoid potential confusion, it's recommended to remove unused imports.

Apply this diff to remove the unused import:

from dreem.inference import track
from omegaconf import OmegaConf
-import os
🧰 Tools
🪛 Ruff

3-3: os imported but unused

Remove unused import: os

(F401)

Comment on lines +11 to +14
cfg = OmegaConf.load(base_config)
# Load and merge override config
override_cfg = OmegaConf.load(params_config)
cfg = OmegaConf.merge(cfg, override_cfg)

🛠️ Refactor suggestion

⚠️ Potential issue

Add error handling for configuration loading.

While the configuration loading process is correct, it lacks error handling. This could lead to unclear error messages if there are issues with the configuration files.

Consider wrapping the configuration loading in a try-except block to handle potential errors:

try:
    cfg = OmegaConf.load(base_config)
    override_cfg = OmegaConf.load(params_config)
    cfg = OmegaConf.merge(cfg, override_cfg)
except Exception as e:
    print(f"Error loading configuration: {e}")
    exit(1)

This change will provide clearer error messages if there are issues with loading or merging the configuration files.

Comment on lines +129 to +130
h: int = None,
w: int = None

⚠️ Potential issue

Implement the usage of new parameters h and w

The new parameters h and w have been added to the function signature, but they are not used in the function body. These parameters are intended to represent the height and width of the image for scaling the distance calculation.

To address this, implement the scaling of the distance calculation using h and w. Here's a suggested implementation:

# After calculating dist
scaled_dist = dist * torch.tensor([h, w], device=dist.device)
norm_dist = scaled_dist.mean(axis=-1)  # n_k x Np

This change will convert the distance from normalized coordinates to pixel units.

Also applies to: 140-141

@@ -147,13 +151,15 @@
k_s = ((k_boxes[:, :, 2:] - k_boxes[:, :, :2]) ** 2).sum(dim=2) # n_k

nonk_ct = (nonk_boxes[:, :, :2] + nonk_boxes[:, :, 2:]) / 2

# TODO: nonk_boxes should be only from previous frame rather than entire window

⚠️ Potential issue

Address TODO comments regarding distance calculation

The TODO comments indicate that:

  1. nonk_boxes should only be from the previous frame rather than the entire window.
  2. The distance calculation needs to be scaled by the original image size.

To address these issues:

  1. Modify the nonk_boxes input to only include boxes from the previous frame.
  2. Implement the scaling of the distance calculation as suggested in the previous comment.

Please update the implementation to reflect these changes and remove the TODO comments once addressed.

Also applies to: 158-160
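
As a starting point for the first item, a heavily hedged sketch; the window layout of nonk_boxes is an assumption and must be checked against the caller:

# Hypothetical: assumes nonk_boxes is shaped (n_nonk, n_frames, 4) with frames
# in ascending order, so the previous frame is the last slice of the window.
# Verify against the actual data layout before using.
prev_nonk_boxes = nonk_boxes[:, -1:, :]  # keep dims for downstream broadcasting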

Comment on lines +160 to +162
# norm_dist = dist / (k_s[:, None, :] + 1e-8)
norm_dist = dist.mean(axis=-1) # n_k x Np
# norm_dist =

💡 Codebase verification

Missing Tests for filter_max_center_dist

No existing tests found for filter_max_center_dist. Please add appropriate test cases to ensure that the new distance normalization aligns with the intended functionality.
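
A starting point could look like the following sketch; the import path, signature, and tensor shapes are assumptions that must be matched to the real filter_max_center_dist:

import torch
from dreem.inference.post_processing import filter_max_center_dist  # assumed path

def test_filter_max_center_dist_keeps_shape():
    # two query boxes and two candidate boxes in normalized xyxy coordinates
    asso = torch.ones(2, 2)
    k_boxes = torch.tensor([[[0.0, 0.0, 0.1, 0.1]], [[0.9, 0.9, 1.0, 1.0]]])
    nonk_boxes = k_boxes.clone()
    out = filter_max_center_dist(asso, 0.2, k_boxes, nonk_boxes)  # assumed signature
    assert out.shape == asso.shape  # masking must preserve the association shape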

🔗 Analysis chain

Review the change in distance normalization

The distance normalization has been changed from dividing by the box size to taking the mean across the last dimension.

This change might affect the behavior of the function. Please verify if this new approach aligns with the intended functionality of filter_max_center_dist. Consider running the following test:

If the change is intentional and correct, please add a comment explaining the rationale behind this modification.

🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Compare the output of the old and new implementations

# Test: Search for test cases of filter_max_center_dist
rg --type python -A 10 "def test_filter_max_center_dist"

# If no tests are found, consider adding a new test case
if [ $? -ne 0 ]; then
    echo "No existing tests found for filter_max_center_dist. Consider adding a new test case."
fi

Length of output: 286

Comment on lines 195 to 202
# if track ids exist from another tracking program i.e. sleap, init with those
for i, instance in enumerate(frames[batch_idx].instances):
    instance.pred_track_id = instance.gt_track_id
    curr_track_id = max(curr_track_id, instance.pred_track_id)

# if no track ids, then assign new ones
for i, instance in enumerate(frames[batch_idx].instances):
    if instance.pred_track_id == -1:
        curr_track += 1
        curr_track_id += 1

🛠️ Refactor suggestion

Simplify code by using frame_to_track directly.

Instead of accessing frames[batch_idx].instances, you can use frame_to_track.instances, which is already available from the loop. This makes the code cleaner and avoids redundant indexing.

Apply this diff to simplify the code:

-        for i, instance in enumerate(frames[batch_idx].instances):
+        for i, instance in enumerate(frame_to_track.instances):
             instance.pred_track_id = instance.gt_track_id
             curr_track_id = max(curr_track_id, instance.pred_track_id)

-        for i, instance in enumerate(frames[batch_idx].instances):
+        for i, instance in enumerate(frame_to_track.instances):
             if instance.pred_track_id == -1:
                 curr_track_id += 1
                 instance.pred_track_id = curr_track_id
📝 Committable suggestion


Suggested change
# if track ids exist from another tracking program i.e. sleap, init with those
for i, instance in enumerate(frames[batch_idx].instances):
    instance.pred_track_id = instance.gt_track_id
    curr_track_id = max(curr_track_id, instance.pred_track_id)
# if no track ids, then assign new ones
for i, instance in enumerate(frames[batch_idx].instances):
    if instance.pred_track_id == -1:
        curr_track += 1
        curr_track_id += 1
# if track ids exist from another tracking program i.e. sleap, init with those
for i, instance in enumerate(frame_to_track.instances):
    instance.pred_track_id = instance.gt_track_id
    curr_track_id = max(curr_track_id, instance.pred_track_id)
# if no track ids, then assign new ones
for i, instance in enumerate(frame_to_track.instances):
    if instance.pred_track_id == -1:
        curr_track_id += 1
        instance.pred_track_id = curr_track_id
🧰 Tools
🪛 Ruff

196-196: Loop control variable i not used within loop body

Rename unused i to _i

(B007)


200-200: Loop control variable i not used within loop body

Rename unused i to _i

(B007)

for batch_idx, frame_to_track in enumerate(frames):
    # tracked_frames is a list of reference frames that have been tracked (associated)

💡 Codebase verification

⚠️ Potential issue

Potential AttributeError Confirmed: frame_id is inconsistently handled as both tensors and integers.

  • In dreem/inference/tracker.py, frame_to_track.frame_id is accessed as a tensor (frame_to_track.frame_id.device).
  • Multiple instances in other files assign frame_id as an integer (e.g., frame_id=0), which would cause an AttributeError when .device is accessed.
🔗 Analysis chain

Potential AttributeError when accessing frame_id.device.

In the line device=frame_to_track.frame_id.device, ensure that frame_to_track.frame_id is a tensor. If frame_id is an integer, accessing .device will raise an AttributeError. Please verify the type of frame_id.
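
One defensive option, sketched under the assumption that frame_to_track is in scope (names as in the tracker code above):

import torch

frame_id = frame_to_track.frame_id
if not torch.is_tensor(frame_id):
    # integer frame ids have no .device attribute; wrap them in a tensor first
    frame_id = torch.tensor(frame_id)
device = frame_id.device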

You can run the following script to check the type of frame_id across the codebase:

🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Verify that 'frame_id' is consistently a tensor.

# Test: Search for 'frame_id' assignments. Expect: 'frame_id' should be a tensor.
rg --type python --context 2 'frame_id\s*='

Length of output: 76


Script:

#!/bin/bash
# Description: Verify that 'frame_id' is consistently a tensor.

# Test: Search for 'frame_id' assignments in Python files. Expect: 'frame_id' should be a tensor.
rg --glob "*.py" --context 2 'frame_id\s*='

Length of output: 7061

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 3

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 1998f6f and 2792d13.

📒 Files selected for processing (9)
  • dreem/datasets/sleap_dataset.py (3 hunks)
  • dreem/inference/eval.py (1 hunks)
  • dreem/inference/track.py (2 hunks)
  • dreem/io/config.py (5 hunks)
  • dreem/models/embedding.py (14 hunks)
  • dreem/models/gtr_runner.py (2 hunks)
  • dreem/models/transformer.py (12 hunks)
  • dreem/training/configs/base.yaml (6 hunks)
  • dreem/training/train.py (2 hunks)
🧰 Additional context used
🪛 Ruff (0.8.2)
dreem/inference/eval.py

30-30: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)


50-50: Undefined name checkpoint

(F821)


54-54: f-string without any placeholders

Remove extraneous f prefix

(F541)

dreem/inference/track.py

112-112: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)


132-132: Undefined name checkpoint

(F821)

dreem/datasets/sleap_dataset.py

208-208: Local variable e is assigned to but never used

Remove assignment to unused variable e

(F841)

dreem/io/config.py

229-229: Undefined name SleapDataset

(F821)


229-229: Undefined name MicroscopyDataset

(F821)


229-229: Undefined name CellTrackingDataset

(F821)


304-304: Undefined name dataset

(F821)


305-305: Undefined name dataset

(F821)


307-307: Undefined name dataset

(F821)


326-326: Undefined name SleapDataset

(F821)


326-326: Undefined name MicroscopyDataset

(F821)


326-326: Undefined name CellTrackingDataset

(F821)

dreem/models/embedding.py

104-104: Multiple statements on one line (colon)

(E701)


267-267: f-string without any placeholders

Remove extraneous f prefix

(F541)

dreem/models/transformer.py

301-301: SyntaxError: Expected 'else', found ':'


301-302: SyntaxError: Expected ')', found newline


302-302: SyntaxError: Unexpected indentation

🪛 yamllint (1.35.1)
dreem/training/configs/base.yaml

[warning] 32-32: wrong indentation: expected 6 but found 8

(indentation)


[warning] 35-35: wrong indentation: expected 6 but found 8

(indentation)

🔇 Additional comments (26)
dreem/inference/eval.py (4)

30-30: Simplify dictionary key check

You can simplify the condition by removing .keys(), as it's unnecessary. The more Pythonic way is:

- if "batch_config" in cfg.keys():
+ if "batch_config" in cfg:
🧰 Tools
🪛 Ruff (0.8.2)

30-30: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)


34-36: Add input validation for user input

When prompting the user for the task index, consider adding input validation to handle non-integer inputs gracefully. This prevents the program from crashing if the user enters invalid data.
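For example, a retry loop like the following sketch; the prompt text mirrors the track.py snippet later in this review, and the loop itself is illustrative:

while True:
    raw = input("Pod Index Not found! Please choose a pod index: ")
    try:
        index = int(raw)
        break
    except ValueError:
        print(f"Invalid input {raw!r}; please enter an integer.")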


53-53: Remove unnecessary f prefix in log message

The log message does not contain any variables to interpolate, so the f prefix is unnecessary. You can simplify the line as follows:

- logger.info(f"Using the following tracker:")
+ logger.info("Using the following tracker:")

55-55: Use logger instead of print statement

To maintain consistent logging practices, replace the print statement with logger.info. This ensures that all output is handled by the logging system.

- print(model.tracker)
+ logger.info(model.tracker)
dreem/inference/track.py (3)

112-112: Simplify dictionary key check

You can simplify the condition by removing .keys(), as it's unnecessary. The more Pythonic way is:

- if "batch_config" in cfg.keys():
+ if "batch_config" in cfg:
🧰 Tools
🪛 Ruff (0.8.2)

112-112: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)


116-118: Add input validation for task index input

When the environment variable POD_INDEX is not set, the code prompts the user for input. If the user enters a non-integer value, a ValueError will be raised. Consider adding error handling to ensure the program handles invalid input gracefully.


120-122: Add boundary check for index to prevent IndexError

When accessing hparams_df.iloc[index], if index is out of bounds, an IndexError will be raised. Consider adding a check to ensure that index is within the valid range of the DataFrame.
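
A minimal guard, as a sketch (hparams_df and index are the names referenced above):

if not 0 <= index < len(hparams_df):
    raise IndexError(f"Task index {index} out of range [0, {len(hparams_df) - 1}]")
hparams = hparams_df.iloc[index]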

dreem/models/embedding.py (3)

104-104: Split multiple statements into separate lines

For better readability and to comply with PEP 8 style guidelines, split multiple statements into separate lines:

- if input_pos.dim() <= 1: input_pos = input_pos.unsqueeze(0)
+ if input_pos.dim() <= 1:
+     input_pos = input_pos.unsqueeze(0)
🧰 Tools
🪛 Ruff (0.8.2)

104-104: Multiple statements on one line (colon)

(E701)


267-267: Remove unnecessary f prefix in exception message

The string does not contain any placeholders, so the f prefix is unnecessary.

- f"Cannot use aggregation method 'average' for rope embedding; must use 'stack' or 'concatenate'"
+ "Cannot use aggregation method 'average' for rope embedding; must use 'stack' or 'concatenate'"
🧰 Tools
🪛 Ruff (0.8.2)

267-267: f-string without any placeholders

Remove extraneous f prefix

(F541)


565-568: ⚠️ Potential issue

Handle potential division by zero in time normalization

In the _learned_temp_embedding method, times is normalized by dividing by times.max(). If times.max() is zero, this will result in a division by zero error. To prevent this, add a check to ensure that times.max() is not zero before performing the division.

Apply this diff to handle the potential division by zero:

 def _learned_temp_embedding(self, times: torch.Tensor, *args) -> torch.Tensor:
     # ...
+    max_time = times.max()
+    if max_time > 0:
+        times = times / max_time
+    else:
+        times = times.clone()  # All times are zero, no normalization needed

Likely invalid or redundant comment.

dreem/models/transformer.py (1)

87-91: Simplify code using dict.get()

Simplify the retrieval of embedding_agg_method by using the get method:

- self.embedding_agg_method = (
-     embedding_meta["embedding_agg_method"]
-     if "embedding_agg_method" in embedding_meta
-     else "average"
- )
+ self.embedding_agg_method = embedding_meta.get("embedding_agg_method", "average")

This improves readability and follows Python best practices.

dreem/training/configs/base.yaml (3)

101-101: LGTM: Dataset mode configuration.

The addition of mode fields to dataset configurations properly identifies the purpose of each dataset.

Also applies to: 110-110, 119-119


32-36: ⚠️ Potential issue

Fix duplicate configuration and indentation.

The embedding configuration has duplicate mode declarations and incorrect indentation.

Apply this diff to fix the issues:

-        mode: "fixed" # supports fixed, learned, rope
-        normalize: true
-    temp:
-        mode: "fixed" # supports fixed, learned, rope
-    embedding_agg_method: "stack" # supports stack, average, concatenate
+      mode: "fixed" # supports fixed, learned, rope
+      normalize: true
+    temp:
+      mode: "fixed" # supports fixed, learned, rope
+    embedding_agg_method: "stack" # supports stack, average, concatenate

Likely invalid or redundant comment.

🧰 Tools
🪛 yamllint (1.35.1)

[warning] 32-32: wrong indentation: expected 6 but found 8

(indentation)


[warning] 35-35: wrong indentation: expected 6 but found 8

(indentation)


142-142: Verify impact of disabled model logging.

Setting log_model to null disables model logging. Please ensure this won't affect model checkpointing functionality.

✅ Verification successful

Model logging disabled but checkpointing remains functional

Based on the search results, the log_model configuration appears to be independent of the model checkpointing functionality. The checkpointing is controlled by separate configurations that are still enabled:

  • enable_checkpointing: true in trainer settings
  • Dedicated checkpointing section with proper monitoring and save settings
🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Check if there are any other model logging configurations
rg -l "log_model" --type yaml

Length of output: 29


Script:

#!/bin/bash
# Let's try a broader search for model logging related patterns
rg -i "model.*log|log.*model" --type yaml -A 3 -B 3

# Also search for checkpointing related configurations
rg -i "checkpoint|save.*model|model.*save" --type yaml -A 3 -B 3

Length of output: 947


Script:

#!/bin/bash
# Let's check the actual implementation to understand how log_model is used
rg -A 5 -B 5 "log_model" --type python

# Also search for any model logging related code
ast-grep --pattern 'def $_(self, $$$):
  $$$
  self.log($$$)
  $$$'

Length of output: 139

dreem/training/train.py (2)

Line range hint 57-65: LGTM: Structured dataset handling.

The implementation of mode-specific dataset and dataloader creation improves code organization and aligns with the configuration changes.


100-102: LGTM: Flexible device configuration.

The implementation properly handles both GPU and CPU scenarios, with efficient multiprocessing support for CPU-only environments.

dreem/models/gtr_runner.py (2)

303-303: LGTM: Simplified video name extraction.

The change to use path splitting for video name extraction is more robust than relying on file extensions.


312-315: LGTM: Consistent track ID handling.

The implementation now consistently accesses track IDs directly from instances, improving code clarity and maintainability.

Also applies to: 321-324

dreem/datasets/sleap_dataset.py (4)

140-140: LGTM: Good architectural improvement

The change to use a dictionary for video readers improves resource management and provides clearer mapping between video names and their readers.


207-212: LGTM: Improved frame retrieval logic

The new implementation efficiently handles frame retrieval by first attempting to get the image directly from the label frame and falling back to the video reader if needed.

🧰 Tools
🪛 Ruff (0.8.2)

208-208: Local variable e is assigned to but never used

Remove assignment to unused variable e

(F841)


208-208: Remove unused variable e in exception handling

The variable e in the except FileNotFoundError as e: block is assigned but never used.

🧰 Tools
🪛 Ruff (0.8.2)

208-208: Local variable e is assigned to but never used

Remove assignment to unused variable e

(F841)


417-418: Correct iteration over dictionary to access video readers

Iterating directly over self.vid_readers accesses the dictionary keys instead of the reader objects.
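
For example, a sketch assuming each reader exposes a close() method (true for imageio readers):

# iterate over the reader objects rather than the dictionary keys
for vid_reader in self.vid_readers.values():
    vid_reader.close()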

dreem/io/config.py (4)

229-229: LGTM: Improved type safety with explicit None return type

The return type annotation now correctly reflects that the method can return None, improving type safety and code clarity.

🧰 Tools
🪛 Ruff (0.8.2)

229-229: Undefined name SleapDataset

(F821)


229-229: Undefined name MicroscopyDataset

(F821)


229-229: Undefined name CellTrackingDataset

(F821)


304-307: Use logger.warning instead of deprecated logger.warn

The logger.warn() method is deprecated since Python 3.3.
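
For example (the message text here is illustrative, not the repo's actual string):

# logger.warn is a deprecated alias; prefer logger.warning
logger.warning("No dataset found for this mode; returning None")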

🧰 Tools
🪛 Ruff (0.8.2)

304-304: Undefined name dataset

(F821)


305-305: Undefined name dataset

(F821)


307-307: Undefined name dataset

(F821)


326-328: LGTM: Consistent type safety improvements

The parameter and return type annotations now correctly handle None cases, maintaining consistency with the dataset changes.

🧰 Tools
🪛 Ruff (0.8.2)

326-326: Undefined name SleapDataset

(F821)


326-326: Undefined name MicroscopyDataset

(F821)


326-326: Undefined name CellTrackingDataset

(F821)


506-506: LGTM: Better checkpoint naming

Including the metric in the final checkpoint name improves checkpoint identification and traceability.

else:
    checkpoint = eval_cfg.cfg.ckpt_path
    hparams = {}

logging.getLogger().setLevel(level=cfg.get("log_level", "INFO").upper())

model = GTRRunner.load_from_checkpoint(checkpoint, strict=False)

⚠️ Potential issue

Undefined variable checkpoint

The variable checkpoint is used here but has not been defined, leading to a NameError at runtime.

Please define checkpoint before using it. If it was previously set from eval_cfg.cfg.ckpt_path, you may need to reintroduce that logic or ensure that checkpoint is correctly initialized.

🧰 Tools
🪛 Ruff (0.8.2)

50-50: Undefined name checkpoint

(F821)

else:
    checkpoint = pred_cfg.cfg.ckpt_path
    hparams = {}

logging.getLogger().setLevel(level=cfg.get("log_level", "INFO").upper())

model = GTRRunner.load_from_checkpoint(checkpoint, strict=False)

⚠️ Potential issue

Undefined variable checkpoint

The variable checkpoint is used here but has not been defined, leading to a NameError at runtime.

Please define checkpoint before using it. If it was previously set from pred_cfg.cfg.ckpt_path, you may need to reintroduce that logic or ensure that checkpoint is correctly initialized.

🧰 Tools
🪛 Ruff (0.8.2)

132-132: Undefined name checkpoint

(F821)

Comment on lines +172 to +173
max_epochs: 1
min_epochs: 1

⚠️ Potential issue

Verify training epoch reduction impact.

Setting both max_epochs and min_epochs to 1 seems unusually low for training a transformer model, especially with RoPE embeddings.

Consider increasing the epochs to allow proper model convergence, typically:

-  max_epochs: 1
-  min_epochs: 1
+  max_epochs: 100
+  min_epochs: 10
📝 Committable suggestion


Suggested change
max_epochs: 1
min_epochs: 1
max_epochs: 100
min_epochs: 10

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 2

🧹 Outside diff range and nitpick comments (4)
dreem/inference/track.py (2)

Line range hint 31-63: Add error handling for DataFrame operations.

The function should handle potential errors when accessing instance attributes and performing DataFrame operations.

Consider wrapping the DataFrame operations in a try-except block:

 def export_trajectories(
     frames_pred: list["dreem.io.Frame"], save_path: str | None = None
 ) -> pd.DataFrame:
     save_dict = {}
     frame_ids = []
     X, Y = [], []
     pred_track_ids = []
     track_scores = []
-    for frame in frames_pred:
-        for i, instance in enumerate(frame.instances):
-            frame_ids.append(frame.frame_id.item())
-            bbox = instance.bbox.squeeze()
-            y = (bbox[2] + bbox[0]) / 2
-            x = (bbox[3] + bbox[1]) / 2
-            X.append(x.item())
-            Y.append(y.item())
-            track_scores.append(instance.track_score)
-            pred_track_ids.append(instance.pred_track_id.item())
+    try:
+        for frame in frames_pred:
+            for i, instance in enumerate(frame.instances):
+                frame_ids.append(frame.frame_id.item())
+                bbox = instance.bbox.squeeze()
+                y = (bbox[2] + bbox[0]) / 2
+                x = (bbox[3] + bbox[1]) / 2
+                X.append(x.item())
+                Y.append(y.item())
+                track_scores.append(instance.track_score)
+                pred_track_ids.append(instance.pred_track_id.item())
+    except (AttributeError, IndexError) as e:
+        logger.error(f"Error processing frames: {e}")
+        raise

Line range hint 66-93: Add detailed logging and improve video object creation.

The function would benefit from more detailed logging and safer video object creation.

Consider these improvements:

 def track(
     model: GTRRunner, trainer: pl.Trainer, dataloader: torch.utils.data.DataLoader
 ) -> list[pd.DataFrame]:
+    logger.info("Starting inference process...")
     preds = trainer.predict(model, dataloader)
     pred_slp = []
     tracks = {}
     for batch in preds:
         for frame in batch:
             if frame.frame_id.item() == 0:
-                video = (
-                    sio.Video(frame.video)
-                    if isinstance(frame.video, str)
-                    else sio.Video
-                )
+                try:
+                    video = sio.Video(frame.video) if isinstance(frame.video, str) else sio.Video()
+                    logger.debug(f"Created video object for frame {frame.frame_id}")
+                except Exception as e:
+                    logger.error(f"Failed to create video object: {e}")
+                    raise
             lf, tracks = frame.to_slp(tracks, video=video)
             pred_slp.append(lf)
+    logger.info(f"Completed inference with {len(pred_slp)} predictions")
     pred_slp = sio.Labels(pred_slp)
-    print(pred_slp)
+    logger.info(f"Generated SLEAP labels: {pred_slp}")
     return pred_slp
dreem/models/transformer.py (2)

87-91: Simplify code using dict.get()

Use Python's built-in dict.get() method to simplify the code.

- self.embedding_agg_method = (
-     embedding_meta["embedding_agg_method"]
-     if "embedding_agg_method" in embedding_meta
-     else "average"
- )
+ self.embedding_agg_method = embedding_meta.get("embedding_agg_method", "average")
🧰 Tools
🪛 Ruff (0.8.2)

88-90: Use embedding_meta.get("embedding_agg_method", "average") instead of an if block

Replace with embedding_meta.get("embedding_agg_method", "average")

(SIM401)


788-803: Document bounding box format assumptions

The function assumes a specific bounding box format [ymin, xmin, ymax, xmax] but this critical detail is only mentioned in the docstring. Consider adding runtime validation to prevent silent errors.

 def spatial_emb_from_bb(bb: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+    if bb.shape[-1] != 4:
+        raise ValueError(f"Expected bounding boxes with 4 coordinates, got {bb.shape[-1]}")
     return (
         bb[:, :, [1, 3]].mean(axis=2).squeeze(),
         bb[:, :, [0, 2]].mean(axis=2).squeeze(),
     )
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 2792d13 and e4ce29c.

📒 Files selected for processing (4)
  • dreem/datasets/sleap_dataset.py (1 hunks)
  • dreem/inference/eval.py (1 hunks)
  • dreem/inference/track.py (1 hunks)
  • dreem/models/transformer.py (12 hunks)
🚧 Files skipped from review as they are similar to previous changes (2)
  • dreem/inference/eval.py
  • dreem/datasets/sleap_dataset.py
🧰 Additional context used
🪛 Ruff (0.8.2)
dreem/models/transformer.py

88-90: Use embedding_meta.get("embedding_agg_method", "average") instead of an if block

Replace with embedding_meta.get("embedding_agg_method", "average")

(SIM401)

🔇 Additional comments (2)
dreem/inference/track.py (1)

Line range hint 96-164: Improve error handling and configuration validation.

The function needs better error handling and configuration validation, especially for the new RoPE embeddings integration.

Consider these improvements:

 @hydra.main(config_path=None, config_name=None, version_base=None)
 def run(cfg: DictConfig) -> dict[int, sio.Labels]:
+    logger.info("Starting inference with configuration...")
     pred_cfg = Config(cfg)
 
-    if "checkpoints" in cfg.keys():
+    if "checkpoints" in cfg:
         try:
             index = int(os.environ["POD_INDEX"])
-        # For testing without deploying a job on runai
         except KeyError:
-            index = input("Pod Index Not found! Please choose a pod index: ")
+            while True:
+                try:
+                    index = int(input("Pod Index Not found! Please choose a pod index: "))
+                    break
+                except ValueError:
+                    print("Invalid input. Please enter an integer value.")
 
         logger.info(f"Pod Index: {index}")
 
         checkpoints = pd.read_csv(cfg.checkpoints)
+        if not 0 <= index < len(checkpoints):
+            raise IndexError(f"Pod index {index} out of range [0, {len(checkpoints)-1}]")
         checkpoint = checkpoints.iloc[index]
     else:
         checkpoint = pred_cfg.cfg.ckpt_path
+        if not os.path.exists(checkpoint):
+            raise FileNotFoundError(f"Checkpoint not found: {checkpoint}")
 
     logging.getLogger().setLevel(level=cfg.get("log_level", "INFO").upper())
 
+    logger.info(f"Loading model from checkpoint: {checkpoint}")
     model = GTRRunner.load_from_checkpoint(checkpoint, strict=False)

Additionally, since this PR introduces RoPE embeddings, we should verify the embedding configuration:

#!/bin/bash
# Description: Check for RoPE embedding configuration in the codebase
# Look for RoPE configuration in yaml files
rg -t yaml "rope|rotary|positional.*embedding" 

# Look for RoPE implementation in Python files
ast-grep --pattern 'class $_(RoPE|RotaryPositionalEmbedding)' 
dreem/models/transformer.py (1)

628-641: Review architectural decision about embedding application

The TODO comment raises an important question about applying embeddings to encoder output before cross attention. This architectural decision needs careful consideration:

  1. For "average" method, embeddings are applied twice (potentially redundant)
  2. For "stack" and "concatenate", embeddings are applied only once (potentially inconsistent)

Consider:

  • Document the rationale for this difference in behavior
  • Evaluate if this asymmetry affects model performance
  • Consider unifying the behavior across all methods if possible
✅ Verification successful

Let me gather more information about the embedding application in the transformer model.


Based on the gathered information, I can now provide a final response:

Architectural decision about embedding application is justified

The different behavior between embedding methods is intentional and correct:

  1. For "average" method: The double application is needed because it averages positional and temporal embeddings into the same feature space, maintaining backward compatibility.

  2. For "stack" and "concatenate": These methods are explicitly prevented from double application because:

    • They already preserve the separate embedding information by design (shown in aggregate_queries function)
    • For "stack", embeddings are handled as channels (t1,t2,t3...), (x1,x2,x3...), (y1,y2,y3...)
    • For "concatenate", embeddings are combined through an MLP

The TODO comment can be removed as the architectural decision is well-implemented and the behavior difference is by design.
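
For reference, an illustrative sketch of the "stack" layout described above; the shapes are assumptions for illustration, not the repo's actual values:

import torch

n_query, d_model = 8, 256
t_emb = torch.randn(n_query, d_model)  # temporal embedding per query
x_emb = torch.randn(n_query, d_model)  # x positional embedding per query
y_emb = torch.randn(n_query, d_model)  # y positional embedding per query

# "stack" keeps the three embeddings separate along a new axis instead of
# averaging them into a single vector
stacked = torch.stack((t_emb, x_emb, y_emb), dim=1)  # (n_query, 3, d_model)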

🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Search for any tests or documentation that might explain this architectural decision
rg -l "embedding.*cross.*attention" --type py

Length of output: 75


Script:

#!/bin/bash
# Search for the implementation of apply_embeddings and its usage
ast-grep --pattern 'def apply_embeddings'

# Look for tests related to embedding methods
rg -A 5 "embedding_agg_method" --type py

# Search for documentation or comments about embedding behavior
rg -A 3 "embedding.*average|embedding.*stack|embedding.*concatenate" --type py

Length of output: 25205

Comment on lines +234 to +245
if self.return_embedding:
    for i, instance in enumerate(ref_instances):
        if self.embedding_agg_method == "average":
            ref_pos_emb = pos_emb_traceback[0][i]  # array
        else:
            ref_pos_emb = {
                "x": pos_emb_traceback[0][0][i],
                "y": pos_emb_traceback[1][0][i],
            }  # dict

        n_query = total_instances
        instance.add_embedding("pos", ref_pos_emb)  # can be an array or a dict
        instance.add_embedding("temp", temp_emb_traceback)

🛠️ Refactor suggestion

Refactor duplicate embedding handling code

The code block for handling embeddings in instances is duplicated. Consider extracting this into a helper method.

+ def _add_embeddings_to_instance(self, instance, pos_emb_traceback, temp_emb_traceback, index):
+     if self.embedding_agg_method == "average":
+         ref_pos_emb = pos_emb_traceback[0][index]
+     else:
+         ref_pos_emb = {
+             "x": pos_emb_traceback[0][0][index],
+             "y": pos_emb_traceback[1][0][index],
+         }
+     instance.add_embedding("pos", ref_pos_emb)
+     instance.add_embedding("temp", temp_emb_traceback)

- if self.return_embedding:
-     for i, instance in enumerate(ref_instances):
-         if self.embedding_agg_method == "average":
-             ref_pos_emb = pos_emb_traceback[0][i]
-         else:
-             ref_pos_emb = {
-                 "x": pos_emb_traceback[0][0][i],
-                 "y": pos_emb_traceback[1][0][i],
-             }
-         instance.add_embedding("pos", ref_pos_emb)
-         instance.add_embedding("temp", temp_emb_traceback)
+ if self.return_embedding:
+     for i, instance in enumerate(ref_instances):
+         self._add_embeddings_to_instance(instance, pos_emb_traceback, temp_emb_traceback, i)

Also applies to: 300-311

Comment on lines +771 to +781
mlp = MLP(
    input_dim=queries_t.shape[-1] * 3,
    hidden_dim=queries_t.shape[-1] * 2,
    output_dim=queries_t.shape[-1],
    num_layers=1,
    dropout=0.0,
)
# concatenated is of shape (batch_size, n_query, 3*embed_dim)
collated_queries = torch.cat((queries_t, queries_x, queries_y), dim=2)
# pass through MLP to project into space of (batch_size, n_query, embed_dim)
collated_queries = mlp(collated_queries)

🛠️ Refactor suggestion

Consider caching or pre-initializing MLP for concatenate method

The current implementation creates a new MLP instance on every call when using the "concatenate" method. This could be inefficient, especially during inference.

Consider moving the MLP initialization to the constructor and reusing it:

+ def __init__(self, ...):
+     self.concat_mlp = MLP(
+         input_dim=d_model * 3,
+         hidden_dim=d_model * 2,
+         output_dim=d_model,
+         num_layers=1,
+         dropout=0.0,
+     )

- mlp = MLP(
-     input_dim=queries_t.shape[-1] * 3,
-     hidden_dim=queries_t.shape[-1] * 2,
-     output_dim=queries_t.shape[-1],
-     num_layers=1,
-     dropout=0.0,
- )
- collated_queries = mlp(collated_queries)
+ collated_queries = self.concat_mlp(collated_queries)

Committable suggestion skipped: line range outside the PR's diff.


Successfully merging this pull request may close these issues.

  • Fix implementation of max distance hyperparam for track linking
  • Implement Rotary Positional Embeddings