RoPE embeddings #78
base: main
Conversation
- changes to embedding class - add apply() function to Embedding class - remove references to embedding from encoderlayer fwd pass
- add support for stack/avg/concatenate - move embedding processing out of transformer and into encoder
- get centroid from x,y for spatial embedding - complete stack agg method - add docstrings
- concatenation method with mlp - complete pre-processing for input to EncoderLayer - fix shape issues in rope/additive_embedding/forward modules in embedding.py
- bounding box embedding only for method "average" - modify emb_funcs routing - temporarily remove support for adding embeddings into instance objects - need to make compatible with x,y,t embeddings - remove config yamls from updates - current versions serve as templates - runs through to end of encoder forward pass
- implement embeddings for decoder + refactor - add 1x1 conv to final attn head to deal with stacked embeddings (3x tokens) and create channels for each dim - bug fix in rope rotation matrix product with input data
- 1x1 conv for stack embedding - stack into 3 channels for x,y,t
- add unit tests for rope - Update existing tests to use new args/return params related to tfmr - Modify test to remove return_embedding=True support - need to address this
- create rope instance once rather than each fwd pass - construct embedding lookup array each fwd pass based on num instances passed in to embedding - scale only pos embs * 100 rather than also temp embs
- times array for embedding for encoder queries inside decoder was of query size rather than ref size
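For readers unfamiliar with RoPE, the sketch below illustrates the rotation idea behind the commits above: token features are rotated in pairs by a position-dependent angle. This is a minimal, generic formulation for illustration only, not the dreem `Embedding`/`RotaryPositionalEmbeddings` implementation.

```python
import torch

def rotary_embed(x: torch.Tensor, positions: torch.Tensor) -> torch.Tensor:
    """Rotate feature pairs of x (n, d) by position-dependent angles."""
    n, d = x.shape
    half = d // 2
    # standard 1/10000^(2i/d) frequency schedule
    inv_freq = 1.0 / (10000 ** (torch.arange(half, dtype=torch.float32) / half))
    angles = positions.float().unsqueeze(-1) * inv_freq  # (n, d/2)
    cos, sin = angles.cos(), angles.sin()
    x1, x2 = x[:, :half], x[:, half:]
    # 2D rotation applied to each (x1_i, x2_i) pair
    return torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)

tokens = torch.randn(5, 8)                      # 5 tokens, 8 features
rotated = rotary_embed(tokens, torch.arange(5))
print(rotated.shape)                            # torch.Size([5, 8])
```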
Actionable comments posted: 28
Outside diff range, codebase verification and nitpick comments (4)
run_trainer.py (1)
10-10: Avoid commented-out code.
Instead of commenting out the override configuration, consider using a command-line argument or environment variable to specify which configuration file to use.
- # params_config = "./configs/override.yaml"
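A minimal sketch of that suggestion is shown below; the flag and environment-variable names (`--params-config`, `DREEM_PARAMS_CONFIG`) are hypothetical, not existing dreem options.

```python
import argparse
import os

parser = argparse.ArgumentParser()
parser.add_argument(
    "--params-config",
    default=os.environ.get("DREEM_PARAMS_CONFIG", "./configs/override.yaml"),
    help="path to the override/params YAML",
)
args = parser.parse_args()
params_config = args.params_config  # no need to edit run_trainer.py to switch configs
```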
tests/test_models.py (2)
247-248: Clarify the shape of input data.
The comment on line 247 specifies the shape of the input data. Ensure this comment remains accurate and update it if the input shape changes in future modifications.
317-324: Ensure consistency in input data shape comments.
The comment on line 323 specifies the shape of the input data. Ensure consistency across the file by maintaining accurate comments about data shapes.
rope.ipynb (1)
307-311: Avoid hardcoded dataset paths.
The dataset path is hardcoded, which can lead to issues when running the notebook in different environments. Consider parameterizing the path or using a configuration file.

- train_path = "/home/jovyan/talmolab-smb/datasets/mot/microscopy/airyscan_proofread/Final/dreem-train"
+ train_path = os.getenv("DREEM_TRAIN_PATH", "/default/path/to/dreem-train")
Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Files ignored due to path filters (1)
dreem/training/configs/test_batch_train.csv is excluded by !**/*.csv
Files selected for processing (15)
- .gitignore (1 hunks)
- dreem/inference/tracker.py (14 hunks)
- dreem/io/config.py (1 hunks)
- dreem/io/instance.py (1 hunks)
- dreem/models/attention_head.py (2 hunks)
- dreem/models/embedding.py (13 hunks)
- dreem/models/mlp.py (2 hunks)
- dreem/models/transformer.py (14 hunks)
- dreem/training/configs/base.yaml (6 hunks)
- dreem/training/configs/override.yaml (1 hunks)
- dreem/training/train.py (2 hunks)
- rope.ipynb (1 hunks)
- run_trainer.py (1 hunks)
- tests/test_models.py (12 hunks)
- tests/test_training.py (1 hunks)
Files skipped from review due to trivial changes (2)
- .gitignore
- dreem/training/train.py
Additional context used
yamllint
dreem/training/configs/base.yaml
[warning] 19-19: wrong indentation: expected 6 but found 8
(indentation)
[warning] 22-22: wrong indentation: expected 6 but found 8
(indentation)
dreem/training/configs/override.yaml
[error] 3-3: trailing spaces
(trailing-spaces)
[warning] 4-4: wrong indentation: expected 4 but found 6
(indentation)
[warning] 20-20: wrong indentation: expected 6 but found 8
(indentation)
[warning] 24-24: wrong indentation: expected 6 but found 8
(indentation)
[warning] 29-29: too many spaces after colon
(colons)
[error] 43-43: trailing spaces
(trailing-spaces)
[warning] 46-46: wrong indentation: expected 6 but found 8
(indentation)
[warning] 58-58: too few spaces after comma
(commas)
[warning] 66-66: too few spaces after comma
(commas)
[error] 76-76: trailing spaces
(trailing-spaces)
[warning] 78-78: wrong indentation: expected 4 but found 6
(indentation)
[error] 78-78: trailing spaces
(trailing-spaces)
[error] 99-99: trailing spaces
(trailing-spaces)
[error] 112-112: trailing spaces
(trailing-spaces)
[error] 130-130: trailing spaces
(trailing-spaces)
[warning] 132-132: wrong indentation: expected 2 but found 4
(indentation)
[warning] 133-133: wrong indentation: expected 6 but found 8
(indentation)
[warning] 137-137: wrong indentation: expected 6 but found 8
(indentation)
[error] 142-142: no new line character at the end of file
(new-line-at-end-of-file)
Ruff
dreem/models/attention_head.py
28-31: Use `self.embedding_agg_method = kwargs.get('embedding_agg_method', None)` instead of an `if` block. Replace with `self.embedding_agg_method = kwargs.get('embedding_agg_method', None)` (SIM401)
tests/test_models.py
18-18: `dreem.models.transformer.apply_embeddings` imported but unused. Remove unused import: `dreem.models.transformer.apply_embeddings` (F401)
166-166: `pytest.raises(Exception)` should be considered evil (B017)
dreem/inference/tracker.py
196-196: Loop control variable `i` not used within loop body. Rename unused `i` to `_i` (B007)
dreem/io/instance.py
568-568: Use `is` and `is not` for type comparisons, or `isinstance()` for isinstance checks (E721)
dreem/models/embedding.py
267-267: f-string without any placeholders. Remove extraneous `f` prefix (F541)
dreem/models/transformer.py
84-85: Use `embedding_meta.get("embedding_agg_method", "average")` instead of an `if` block. Replace with `embedding_meta.get("embedding_agg_method", "average")` (SIM401)
Additional comments not posted (39)
run_trainer.py (1)
15-15: LGTM! The training function is correctly invoked with the configuration.
dreem/models/mlp.py (2)
37-37: Comment improves readability. The added comment clarifies the purpose of list concatenations, enhancing code readability.
59-59: Ensures device compatibility. Moving each layer to the device of the input tensor is a good practice to avoid device mismatch errors.
dreem/training/configs/base.yaml (5)
23-23: LGTM! The `embedding_agg_method` is correctly set to "stack".
Line range hint 76-94: Improves configuration clarity. The addition of mode settings for datasets enhances clarity and explicitness in configuration.
Tools
yamllint
[error] 82-82: trailing spaces
(trailing-spaces)
117-117: LGTM! Setting `log_model` to null aligns with the goal of reducing logging verbosity.
134-134: Enhances checkpointing strategy. Setting `every_n_epochs` to 1 improves recovery options during training.
147-148: LGTM! Setting `max_epochs` and `min_epochs` to 1 aligns with the goal of quick iterations or testing.
dreem/models/attention_head.py (2)
33-44: Ensure correct handling of embedding aggregation methods. The logic for handling the "stack" method using `torch.nn.Conv1d` is correctly implemented. Ensure that the reshaping and permutation logic in the `forward` method aligns with the expected input shapes for the convolutional layers.
60-79: Verify tensor reshaping and permutation in `forward`. The reshaping and permutation logic for the "stack" method should be verified to ensure it correctly aligns with the input requirements of the convolutional layers. Ensure that the dimensions and operations are consistent with the intended design.
Verification successful
Tensor reshaping and permutation logic is verified. The reshaping and permutation logic in the `forward` method aligns with the input requirements of the `Conv1d` layers used in `k_proj` and `q_proj`. The reshaped tensors have the correct channel dimension of 3, consistent with the `Conv1d` configuration. The use of `squeeze(0)` is appropriate given the expected batch size of 1. No issues found with the current implementation.
Scripts executed
The following scripts were executed for the analysis:
Script:
#!/bin/bash
# Description: Verify the tensor reshaping and permutation logic in the `forward` method.
# Test: Search for the `forward` method to review its implementation.
ast-grep --lang python --pattern $'def forward(self, query: torch.Tensor, key: torch.Tensor) -> torch.Tensor:\n $$$' dreem/models/attention_head.py
Length of output: 2923
Script:
#!/bin/bash
# Description: Search for the definitions of `k_proj` and `q_proj` to verify compatibility with reshaped tensors.
# Find the definition of `k_proj` in the codebase.
ast-grep --lang python --pattern 'self.k_proj = $_' dreem/models/attention_head.py
# Find the definition of `q_proj` in the codebase.
ast-grep --lang python --pattern 'self.q_proj = $_' dreem/models/attention_head.py
Length of output: 1009
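To make the "stack" aggregation discussed above concrete, here is a shape-level sketch: the x, y and t embeddings are stacked as three channels and a 1x1 `Conv1d` mixes them back into a single embedding per token. The exact reshaping/permutation in `attention_head.py` may differ; this is an illustration under assumed shapes, not the dreem code.

```python
import torch

n_tokens, d_model = 10, 256
t_emb = torch.randn(n_tokens, d_model)  # temporal embedding
x_emb = torch.randn(n_tokens, d_model)  # spatial (x) embedding
y_emb = torch.randn(n_tokens, d_model)  # spatial (y) embedding

# stack as 3 channels: (n_tokens, 3, d_model), i.e. Conv1d's (N, C_in, L) layout
stacked = torch.stack([t_emb, x_emb, y_emb], dim=1)

# a 1x1 convolution collapses the 3 channels back to one embedding per token
mix = torch.nn.Conv1d(in_channels=3, out_channels=1, kernel_size=1)
mixed = mix(stacked).squeeze(1)          # (n_tokens, d_model)
print(mixed.shape)
```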
tests/test_training.py (1)
145-146: Reconsider commenting out anomaly detection. Commenting out `torch.autograd.set_detect_anomaly(True)` reduces the ability to debug and catch subtle errors during the execution of `run(cfg.cfg)`. Consider re-enabling it, especially during development or when debugging complex issues.
Verification successful
Reconsider commenting out anomaly detection in `tests/test_training.py`. The `torch.autograd.set_detect_anomaly(True)` line is commented out and not used elsewhere in the test suite. Re-enabling it could enhance debugging capabilities by catching subtle errors during test execution.
- File: `tests/test_training.py`, line 145
Scripts executed
The following scripts were executed for the analysis:
Script:
#!/bin/bash
# Description: Verify if anomaly detection is used elsewhere in the tests.
# Test: Search for usage of `set_detect_anomaly` in the test directory.
rg --type python 'set_detect_anomaly'
Length of output: 68
Script:
#!/bin/bash
# Description: Verify if anomaly detection is used elsewhere in the tests.
# Test: Search for usage of `set_detect_anomaly` in the test directory.
rg 'set_detect_anomaly' tests/
Length of output: 105
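For context, the toggle under discussion is standard PyTorch; a minimal usage sketch (unrelated to dreem's test fixtures) looks like this:

```python
import torch

# globally enable anomaly detection: backward-pass errors get a forward-pass traceback
torch.autograd.set_detect_anomaly(True)

x = torch.randn(3, requires_grad=True)
loss = (x ** 2).sum()
loss.backward()  # with anomaly detection on, NaN/inf gradients raise informative errors

# it can also be scoped to a block, which limits the (significant) runtime overhead
with torch.autograd.detect_anomaly():
    y = torch.randn(3, requires_grad=True)
    (y * 2).sum().backward()
```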
dreem/io/config.py (1)
43-43: Consider the implications of making the configuration mutable. The addition of `OmegaConf.set_struct(self.cfg, False)` allows dynamic changes to the configuration object. Ensure this change aligns with the intended design and does not introduce unintended side effects, such as accidental modifications to the configuration during runtime.
tests/test_models.py (2)
38-38: Ensure proper testing of `embedding_agg_method`. The addition of `embedding_agg_method="average"` in `ATTWeightHead` tests the new parameter. Ensure that tests cover all possible values and edge cases for this parameter to validate its behavior thoroughly.
186-232: New test for RoPE embeddings is comprehensive. The `test_rope_embedding` function effectively tests the new RoPE embedding feature, checking output sizes and ensuring distinctness from input data. This test enhances the robustness of the embedding logic.
dreem/inference/tracker.py (5)
141-144: Improved clarity in tracking logic. The added comments clarify the purpose of the `sliding_inference` function and the conditions under which the track queue is cleared. This enhances the readability and maintainability of the code.
169-171: Enhanced understanding of frame processing. The comments provide context for the processing of frames in the `sliding_inference` function, explaining the roles of `tracked_frames` and `frame_to_track`. This improves the clarity of the tracking logic.
261-261: Clarify inference target. The comment about getting the last frame for inference clarifies the target of the inference process. This helps in understanding the flow of data during tracking.
385-385: Clarify association matrix reduction. The comment explains the reduction of the association matrix, which is crucial for understanding the trajectory scoring process. This enhances the code's readability.
433-433: Clarify filtering logic. The comment about filtering the association matrix based on distance provides insight into the logic for maintaining or initiating tracks. This improves the understanding of the tracking process.
rope.ipynb (20)
12-25: Imports look good. The imported libraries and modules appear necessary for the operations in this notebook.
39-39: Logger setup is appropriate. The logger is set up correctly for the `dreem.models` module.
61-73: Constructor parameters are well-documented. The constructor of the `Embedding` class is well-documented with clear descriptions of each parameter.
91-92: Good use of argument validation. The `_check_init_args` method ensures that the embedding type and mode are valid, which helps prevent runtime errors.
105-106: Conditional logic for scale initialization is appropriate. The logic to initialize `self.scale` when `normalize` is true and `scale` is None is correctly implemented.
125-127: Embedding function defaults to zero embeddings. The default `_emb_func` returns zero embeddings, effectively disabling embeddings when not configured otherwise. This is a sensible default.
147-168: Argument validation in `_check_init_args` is robust. The method properly raises exceptions for invalid `emb_type` and `mode` values, enhancing robustness.
170-190: Forward method correctly handles embedding dimension mismatch. The `forward` method raises a `RuntimeError` if the output embedding dimension doesn't match the expected features, providing a helpful hint for resolution.
192-204: Integer division method uses `rounding_mode` correctly. The `_torch_int_div` method uses `rounding_mode="floor"` to ensure correct integer division, which is appropriate for the use case.
206-249: Sine box embedding method is well-implemented. The `_sine_box_embedding` method computes sine positional embeddings correctly, with appropriate handling for normalization and scaling.
251-281: Sine temporal embedding method is efficient. The `_sine_temp_embedding` method efficiently computes fixed sine temporal embeddings, leveraging tensor operations.
293-295: Embedding object creation is appropriate. The parameters used for creating `Embedding` objects for temporal and positional embeddings are appropriate and consistent with the class definition.
323-327: Data extraction logic is correct. The logic for extracting instances from the dataset is appropriate and efficient.
339-341: Time vector extraction logic is correct. The logic for extracting time vectors from the instances is appropriate.
378-379: Visualization logic is appropriate. The logic for visualizing the temporal embedding using `matplotlib` is correct and provides meaningful insights.
391-393: Encoder instantiation is appropriate. The parameters used for instantiating `TransformerEncoderLayer` and `VisualEncoder` are appropriate and consistent with their respective class definitions.
405-436: Feature extraction function is well-implemented. The `extract_features` function correctly extracts features using a visual encoder, with appropriate handling for different scenarios.
450-493: Transformer preparation function is well-implemented. The `prepare_for_xfmr` function correctly prepares data for a transformer, with appropriate handling for embeddings and feature vectors.
505-506: Feature extraction invocation is correct. The invocation of the `extract_features` function is correct and uses appropriate arguments.
526-563: Transformer instantiation and evaluation logic is appropriate. The logic for instantiating a `Transformer` and evaluating it using `GTRRunner` is correct and consistent with the intended use.
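As a companion to the `_sine_temp_embedding` comment above, here is a generic fixed sine/cosine temporal embedding in the standard transformer style; the actual notebook/embedding.py formulation may differ in scaling and layout.

```python
import math
import torch

def sine_temp_embedding(times: torch.Tensor, d_model: int) -> torch.Tensor:
    """Map (n,) frame indices to (n, d_model) fixed sinusoidal embeddings."""
    half = d_model // 2
    freqs = torch.exp(-math.log(10000.0) * torch.arange(half, dtype=torch.float32) / half)
    angles = times.float().unsqueeze(-1) * freqs  # (n, d_model/2)
    return torch.cat([angles.sin(), angles.cos()], dim=-1)

emb = sine_temp_embedding(torch.arange(8), d_model=64)
print(emb.shape)  # torch.Size([8, 64])
```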
Actionable comments posted: 19
🧹 Outside diff range and nitpick comments (2)
dreem/inference/eval.py (1)
Line range hint 68-70: Verify that `labels_files` and `vid_files` have the same length. Before zipping `labels_files` and `vid_files`, it's important to ensure they have the same length to prevent potential `IndexError` exceptions or mismatches between labels and videos.
🧰 Tools
🪛 Ruff
30-30: Use `key in dict` instead of `key in dict.keys()`. Remove `.keys()` (SIM118)
56-56: f-string without any placeholders. Remove extraneous `f` prefix (F541)
dreem/inference/track.py (1)
104-110: Consider using `os.getenv` for environment variable retrieval. Using `os.getenv` allows you to provide a default value if the environment variable is not set, which can simplify the code. Apply this code change to use `os.getenv`:

index = int(os.getenv("POD_INDEX", default=-1))
if index == -1:
    # Handle the case where POD_INDEX is not set
    index = int(input("POD_INDEX not found. Please input task index to run: "))
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
📒 Files selected for processing (6)
- dreem/datasets/sleap_dataset.py (3 hunks)
- dreem/inference/eval.py (1 hunks)
- dreem/inference/track.py (2 hunks)
- dreem/io/config.py (7 hunks)
- dreem/models/embedding.py (14 hunks)
- dreem/models/gtr_runner.py (2 hunks)
🧰 Additional context used
🪛 Ruff
dreem/datasets/sleap_dataset.py
163-163: Local variable `e` is assigned to but never used. Remove assignment to unused variable `e` (F841)
dreem/inference/eval.py
30-30: Use `key in dict` instead of `key in dict.keys()`. Remove `.keys()` (SIM118)
56-56: f-string without any placeholders. Remove extraneous `f` prefix (F541)
dreem/inference/track.py
100-100: Use `key in dict` instead of `key in dict.keys()`. Remove `.keys()` (SIM118)
dreem/io/config.py
202-202: Undefined name
SleapDataset
(F821)
202-202: Undefined name
MicroscopyDataset
(F821)
202-202: Undefined name
CellTrackingDataset
(F821)
263-263: Undefined name
SleapDataset
(F821)
263-263: Undefined name
MicroscopyDataset
(F821)
263-263: Undefined name
CellTrackingDataset
(F821)
dreem/models/embedding.py
104-104: Multiple statements on one line (colon)
(E701)
267-267: f-string without any placeholders. Remove extraneous `f` prefix (F541)
🔇 Additional comments (9)
dreem/models/gtr_runner.py (4)
304-304
: Improved video name extractionThe simplification of video name extraction by removing the file extension is a good change. It makes the code more robust to different file extensions and simplifies the naming convention in the HDF5 file.
Line range hint
304-325
: Consistent implementation of changesThe changes to video name processing and ground truth track ID extraction have been applied consistently within the
on_test_epoch_end
method. This consistency is crucial for maintaining the integrity of the data processing pipeline.
Line range hint
304-325
: Overall assessment of changesThe modifications in the
on_test_epoch_end
method align well with the PR objectives. They contribute to the refactoring of the Transformer architecture by updating how track IDs are accessed, which is likely related to changes in the embedding handling. The code changes are focused, consistent, and improve the robustness of the data processing pipeline.
313-316
: Updated ground truth track ID extractionThe new method of extracting ground truth track IDs by iterating over
frame.instances
and accessinggt_track_id
directly is more straightforward. This change appears to be consistent with updates to theInstance
class structure or modifications in how track IDs are stored.To ensure this change is consistent across the codebase, please run the following script:
This will help identify any other locations where similar updates might be needed.
Also applies to: 322-325
dreem/datasets/sleap_dataset.py (1)
109-109
: Initialization ofself.vid_readers
is appropriateThe addition of
self.vid_readers = {}
correctly initializes a dictionary to store video readers, facilitating efficient access and management of video files within the dataset.dreem/io/config.py (2)
179-181
: Improved reproducibility by sorting file listsSorting
label_files
andvid_files
ensures consistent ordering of files, which is beneficial for reproducibility and consistent data loading. Good practice.
235-249
: 🛠️ Refactor suggestionRedundant imports inside the method
You're importing
MicroscopyDataset
,SleapDataset
, andCellTrackingDataset
inside theget_dataset()
method, but their types are used in type annotations outside the method. This can cause issues with type checking and static analysis.Consider moving the imports to the top of the module to ensure that the type annotations are recognized properly.
+from dreem.datasets import MicroscopyDataset, SleapDataset, CellTrackingDataset
Alternatively, if you need to avoid top-level imports, use forward references in type annotations or import under
if TYPE_CHECKING
.Likely invalid or redundant comment.
dreem/models/embedding.py (1)
75-122: Verify the correctness of the RoPE implementation. The `RotaryPositionalEmbeddings` class's `forward` method returns `rope_cache`, which appears to be used as rotation matrices in `_apply_rope`. Please verify that the dimensions and the application of these rotation matrices align with the RoPE mechanism described in the original paper. Ensuring the rotation is correctly applied to the input tensor `x` is crucial for the embeddings to function as intended.
🧰 Tools
🪛 Ruff
104-104: Multiple statements on one line (colon)
(E701)
dreem/inference/track.py (1)
137-144: Ensure that output directory exists before saving results. If the `outdir` directory does not exist, saving the results may fail. Although `os.makedirs(outdir, exist_ok=True)` is called earlier, ensure that any subdirectories needed are also created. Run the following script to confirm that the output directory exists and is writable:
dreem/inference/eval.py
Outdated
index = int(
    input(f"{e}. Assuming single run!\nPlease input task index to run:")
)
🛠️ Refactor suggestion
Add input validation for user input
When prompting the user for the task index, consider adding input validation to handle non-integer inputs gracefully. This prevents the program from crashing if the user enters invalid data.
dreem/inference/eval.py
Outdated
hparams = {}
Handle failure of set_hparams
appropriately
If eval_cfg.set_hparams(hparams)
returns False
, simply setting hparams = {}
may not be sufficient. Consider logging an error message or raising an exception to indicate that the hyperparameters were not set correctly.
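One possible shape for that handling is sketched below; `eval_cfg`, `hparams`, and `logger` are assumed from the surrounding eval.py context, and whether to log or raise is a judgment call, not an existing dreem behavior.

```python
if not eval_cfg.set_hparams(hparams):
    # surface the failure instead of silently continuing with defaults
    logger.error("Failed to apply hyperparameters: %s", hparams)
    raise ValueError("Hyperparameters from batch_config were not applied; aborting evaluation.")
```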
dreem/inference/eval.py
Outdated
logger.info(f"Using the following tracker:") | ||
|
||
print(model.tracker) |
Use logger instead of print statement
To maintain consistent logging practices, replace the print
statement with logger.info
. This ensures that all output is handled by the logging system.
- print(model.tracker)
+ logger.info(model.tracker)
dreem/inference/eval.py
Outdated
@@ -26,24 +26,35 @@
    """
    eval_cfg = Config(cfg)

    if "checkpoints" in cfg.keys():
        # update with parameters for batch train job
        if "batch_config" in cfg.keys():
Simplify dictionary key check
You can simplify the condition by removing .keys()
as it's unnecessary. The more Pythonic way is:
- if "batch_config" in cfg.keys():
+ if "batch_config" in cfg:
🧰 Tools
🪛 Ruff
30-30: Use `key in dict` instead of `key in dict.keys()`. Remove `.keys()` (SIM118)
dreem/inference/eval.py
Outdated
model.tracker_cfg = eval_cfg.cfg.tracker
model.tracker = Tracker(**model.tracker_cfg)

logger.info(f"Using the following tracker:")
Remove unnecessary f-string
The log message does not contain any variables to interpolate, so the f
prefix is unnecessary. You can simplify the line as follows:
- logger.info(f"Using the following tracker:")
+ logger.info("Using the following tracker:")
🧰 Tools
🪛 Ruff
56-56: f-string without any placeholders. Remove extraneous `f` prefix (F541)
@@ -323,6 +562,7 @@ def _learned_temp_embedding(self, times: torch.Tensor) -> torch.Tensor:
    """
    temp_lookup = self.lookup
    N = times.shape[0]
    times = times / times.max()
Handle potential division by zero in time normalization
In the _learned_temp_embedding
method, times
is normalized by dividing by times.max()
. If times.max()
is zero, this will result in a division by zero error. To prevent this, consider adding a check to ensure that times.max()
is not zero before performing the division.
Apply this diff to handle the potential division by zero:
def _learned_temp_embedding(self, times: torch.Tensor, *args) -> torch.Tensor:
# ...
- times = times / times.max()
+ max_time = times.max()
+ if max_time > 0:
+ times = times / max_time
+ else:
+ times = times.clone() # All times are zero, no normalization needed
x_rope = torch.rand(input_shape).unsqueeze(2)
# infer whether it is a positional or temporal embedding
is_pos_emb = 1 if seq_positions.max() < 1 else 0
# if it is positional, scale seq_positions since these are fractions
# in [0,1] and we need int indexes for embedding lookup
seq_positions = seq_positions * 100 if is_pos_emb else seq_positions
seq_positions = seq_positions.unsqueeze(0).int()
# RoPE module takes in dimension, num_queries as input to calculate rotation matrix
rot_mat = self.rope_instance(x_rope, seq_positions)
Avoid using random tensor in _rope_embedding
In the _rope_embedding
method, x_rope
is initialized as a random tensor. This tensor is used to compute the rotation matrix but introduces unnecessary randomness. Consider using a tensor with consistent values, such as zeros, to avoid any unintended side effects and ensure deterministic behavior.
Apply this diff to use a zero tensor instead:
def _rope_embedding(self, seq_positions: torch.Tensor, input_shape: torch.Size) -> torch.Tensor:
# ...
- x_rope = torch.rand(input_shape).unsqueeze(2)
+ x_rope = torch.zeros(input_shape, dtype=torch.float32, device=seq_positions.device).unsqueeze(2)
dreem/inference/track.py
Outdated
index = int(
    input(f"{e}. Assuming single run!\nPlease input task index to run:")
)
Handle non-integer input when prompting for task index
When the environment variable POD_INDEX
is not set, the code prompts the user for input. If the user enters a non-integer value, a ValueError
will be raised. Consider adding error handling to ensure that the program handles invalid input gracefully.
Apply this code change to handle invalid inputs:
except KeyError as e:
print(f"{e}. Assuming single run!")
while True:
try:
index = int(input("Please input task index to run: "))
break
except ValueError:
print("Invalid input. Please enter an integer value for the task index.")
dreem/inference/track.py
Outdated
hparams_df = pd.read_csv(cfg.batch_config)
hparams = hparams_df.iloc[index].to_dict()
_ = hparams.pop("Unnamed: 0", None)
Add boundary check for 'index' to prevent IndexError
When accessing hparams_df.iloc[index]
, if index
is out of bounds, an IndexError
will be raised. Consider adding a check to ensure that the index
is within the valid range of the DataFrame.
Apply this code change to add boundary checking:
hparams_df = pd.read_csv(cfg.batch_config)
if not 0 <= index < len(hparams_df):
raise IndexError(f"Task index {index} is out of bounds. Available indices are from 0 to {len(hparams_df) - 1}.")
hparams = hparams_df.iloc[index].to_dict()
dreem/inference/track.py
Outdated
@@ -96,25 +96,35 @@
    """
    pred_cfg = Config(cfg)

    if "checkpoints" in cfg.keys():
        # update with parameters for batch train job
        if "batch_config" in cfg.keys():
Simplify dictionary key check by removing .keys()
In Python, you can directly check if a key exists in a dictionary without using .keys()
. This is more efficient and idiomatic.
Apply this diff to simplify the condition:
-if "batch_config" in cfg.keys():
+if "batch_config" in cfg:
🧰 Tools
🪛 Ruff
100-100: Use `key in dict` instead of `key in dict.keys()`. Remove `.keys()` (SIM118)
- added tracker debugging script
Actionable comments posted: 14
🧹 Outside diff range and nitpick comments (3)
scripts/run_tracker.py (1)
5-7
: Remove commented-out codeThe commented-out lines related to changing directories are no longer needed and can be safely removed. Keeping unused code comments can lead to confusion and clutter in the codebase.
Apply this diff to remove the commented-out lines:
-# /Users/mustafashaikh/dreem/dreem/training -# /Users/main/Documents/GitHub/dreem/dreem/training -# os.chdir("/Users/main/Documents/GitHub/dreem/dreem/training")dreem/inference/post_processing.py (1)
Line range hint
126-174
: Summary of changes tofilter_max_center_dist
The modifications to this function introduce new parameters (
h
andw
) and change the distance calculation logic. While these changes aim to improve functionality, there are several issues that need to be addressed:
- The new parameters
h
andw
are not used in the function body.- TODO comments indicate necessary changes that haven't been implemented yet.
- The distance normalization logic has been modified, which may affect the function's behavior.
To improve this implementation:
- Implement the usage of
h
andw
parameters to scale the distance calculation.- Address the TODO comments by modifying
nonk_boxes
to only include the previous frame and scaling the distance calculation.- Review and verify the change in distance normalization to ensure it aligns with the intended functionality.
- Add or update unit tests to cover these changes and verify the function's behavior.
Please make these adjustments to ensure the function works as intended and maintains consistency with the rest of the codebase.
🧰 Tools
🪛 Ruff
151-151: Local variable
k_s
is assigned to but never usedRemove assignment to unused variable
k_s
(F841)
tests/test_inference.py (1)
218-219
: LGTM! Consider clarifying image dimensions assumption.The addition of
h
andw
parameters to thefilter_max_center_dist
function calls is a good improvement. It allows the function to handle explicit image dimensions, which is more flexible and robust.For improved clarity, consider adding a comment explaining the assumption of a square image, or consider using separate variables for height and width to make the test more general:
im_height = 128 im_width = 128 # ... (rest of the code) h=im_height, w=im_widthThis change would make the test more explicit about handling potentially non-square images, which might be beneficial for future test cases or modifications.
Also applies to: 231-232
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
📒 Files selected for processing (5)
- dreem/inference/post_processing.py (3 hunks)
- dreem/inference/tracker.py (14 hunks)
- scripts/run_tracker.py (1 hunks)
- scripts/run_trainer.py (1 hunks)
- tests/test_inference.py (2 hunks)
🧰 Additional context used
🪛 Ruff
dreem/inference/tracker.py
196-196: Loop control variable `i` not used within loop body. Rename unused `i` to `_i` (B007)
200-200: Loop control variable `i` not used within loop body. Rename unused `i` to `_i` (B007)
scripts/run_tracker.py
3-3: `os` imported but unused. Remove unused import: `os` (F401)
scripts/run_trainer.py
3-3: `os` imported but unused. Remove unused import: `os` (F401)
🔇 Additional comments (4)
scripts/run_tracker.py (1)
1-12: Overall assessment: Good foundation with room for improvement. The script provides a simple and straightforward way to run the tracking process using a configuration file. It's a good starting point, but there are several areas where it can be improved:
- Remove unused imports and commented-out code.
- Make the configuration file path more flexible.
- Add error handling for both configuration loading and tracking execution.
- Consider adding logging for better visibility into the script's execution.
Implementing these suggestions will make the script more robust, flexible, and easier to maintain.
🧰 Tools
🪛 Ruff
3-3: `os` imported but unused. Remove unused import: `os` (F401)
dreem/inference/tracker.py (3)
288-291: Ensure correct indexing of `asso_matrix`. When accessing `asso_matrix[-1]` and `asso_matrix[-1].matrix`, confirm that `asso_matrix` is a sequence (e.g., list or tuple) with at least one element. If `asso_matrix` is a tensor or has an unexpected structure, this could raise an `IndexError`.
355-355: Verify the structure of `img_shape` when unpacking. The line `_, h, w = query_frame.img_shape.flatten()` assumes that `img_shape` can be flattened into at least three values. Ensure that `img_shape` has the correct dimensions and that `h` and `w` correspond to the height and width of the image. Misalignment could lead to incorrect scaling in post-processing.
Line range hint 465-479: Review the logic in track assignment condition. The condition in the `if` statement may not correctly handle all cases:
if n_traj >= self.max_tracks or traj_score[i, j] > thresh:
Consider the following:
- If `n_traj` exceeds `max_tracks`, the tracker assigns to existing tracks regardless of the `traj_score`.
- Comparing `traj_score[i, j]` (which could be negative) with `thresh` may not behave as expected, especially if `traj_score` contains negative values due to earlier operations like taking the negative for `linear_sum_assignment`.
Please ensure that `traj_score` values are on the appropriate scale and sign for this comparison, and that the logic correctly determines when to assign to existing tracks versus creating new ones.
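For reference, the sign convention matters because `scipy.optimize.linear_sum_assignment` minimizes cost; a toy example (values made up) of assigning instances to tracks by maximizing a score matrix:

```python
import numpy as np
from scipy.optimize import linear_sum_assignment

traj_score = np.array([[0.9, 0.1],
                       [0.2, 0.8]])             # instances x tracks, higher is better
rows, cols = linear_sum_assignment(-traj_score)  # negate to turn maximization into minimization
print(list(zip(rows.tolist(), cols.tolist())))   # [(0, 0), (1, 1)]
```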
# /Users/mustafashaikh/dreem/dreem/training
# /Users/main/Documents/GitHub/dreem/dreem/training
# os.chdir("/Users/main/Documents/GitHub/dreem/dreem/training")
config = "/root/vast/mustafa/dreem-experiments/run/lysosome-baselines/debug/configs/inference.yaml"
🛠️ Refactor suggestion
Make configuration file path more flexible
The configuration file path is currently hardcoded, which might cause issues when running the script in different environments or by different users. Consider making this path more flexible by using environment variables or command-line arguments.
Here's an example of how you could make the path more flexible using environment variables:
import os
config = os.environ.get('DREEM_CONFIG_PATH', "/root/vast/mustafa/dreem-experiments/run/lysosome-baselines/debug/configs/inference.yaml")
This allows users to set the DREEM_CONFIG_PATH
environment variable to override the default path.
# os.chdir("/Users/main/Documents/GitHub/dreem/dreem/training")
config = "/root/vast/mustafa/dreem-experiments/run/lysosome-baselines/debug/configs/inference.yaml"

cfg = OmegaConf.load(config)
🛠️ Refactor suggestion
Add error handling for configuration loading
While using OmegaConf for configuration management is good practice, it's important to handle potential errors that may occur during the loading process. This could include file not found errors or parsing issues with the YAML file.
Consider wrapping the configuration loading in a try-except block to handle potential errors gracefully. Here's an example:
import sys
try:
cfg = OmegaConf.load(config)
except FileNotFoundError:
print(f"Error: Configuration file not found at {config}")
sys.exit(1)
except Exception as e:
print(f"Error loading configuration: {str(e)}")
sys.exit(1)
This will provide more informative error messages and prevent the script from crashing unexpectedly if there are issues with the configuration file.
cfg = OmegaConf.load(config)

track.run(cfg)
🛠️ Refactor suggestion
Add error handling for tracking execution
The current implementation doesn't handle potential errors that may occur during the tracking process. Adding error handling can improve the script's robustness and provide more informative feedback if issues arise.
Consider wrapping the tracking execution in a try-except block to handle potential errors gracefully. Here's an example:
try:
track.run(cfg)
except Exception as e:
print(f"Error during tracking execution: {str(e)}")
sys.exit(1)
This will catch any exceptions thrown during the tracking process, print an informative error message, and exit the script with a non-zero status code to indicate failure.
from dreem.inference import track
from omegaconf import OmegaConf
import os
Remove unused import
The os
module is imported but not used in the script. To keep the imports clean and avoid potential confusion, it's recommended to remove unused imports.
Apply this diff to remove the unused import:
from dreem.inference import track
from omegaconf import OmegaConf
-import os
🧰 Tools
🪛 Ruff
3-3: `os` imported but unused. Remove unused import: `os` (F401)
cfg = OmegaConf.load(base_config)
# Load and merge override config
override_cfg = OmegaConf.load(params_config)
cfg = OmegaConf.merge(cfg, override_cfg)
🛠️ Refactor suggestion
Add error handling for configuration loading.
While the configuration loading process is correct, it lacks error handling. This could lead to unclear error messages if there are issues with the configuration files.
Consider wrapping the configuration loading in a try-except block to handle potential errors:
try:
cfg = OmegaConf.load(base_config)
override_cfg = OmegaConf.load(params_config)
cfg = OmegaConf.merge(cfg, override_cfg)
except Exception as e:
print(f"Error loading configuration: {e}")
exit(1)
This change will provide clearer error messages if there are issues with loading or merging the configuration files.
h: int = None,
w: int = None
Implement the usage of new parameters h
and w
The new parameters h
and w
have been added to the function signature, but they are not used in the function body. These parameters are intended to represent the height and width of the image for scaling the distance calculation.
To address this, implement the scaling of the distance calculation using h
and w
. Here's a suggested implementation:
# After calculating dist
scaled_dist = dist * torch.tensor([h, w], device=dist.device)
norm_dist = scaled_dist.mean(axis=-1) # n_k x Np
This change will convert the distance from normalized coordinates to pixel units.
Also applies to: 140-141
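A shape-level sketch of that scaling, with made-up tensor values (only the broadcasting pattern matters here):

```python
import torch

h, w = 512, 640                                              # image height and width in pixels
dist = torch.rand(3, 5, 2)                                    # n_k x Np x (dy, dx), normalized to [0, 1]
scaled_dist = dist * torch.tensor([h, w], dtype=dist.dtype)   # convert to pixel units
norm_dist = scaled_dist.mean(dim=-1)                          # n_k x Np, as in the suggestion above
print(norm_dist.shape)                                        # torch.Size([3, 5])
```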
@@ -147,13 +151,15 @@
    k_s = ((k_boxes[:, :, 2:] - k_boxes[:, :, :2]) ** 2).sum(dim=2)  # n_k

    nonk_ct = (nonk_boxes[:, :, :2] + nonk_boxes[:, :, 2:]) / 2

    # TODO: nonk_boxes should be only from previous frame rather than entire window
Address TODO comments regarding distance calculation
The TODO comments indicate that:
nonk_boxes
should only be from the previous frame rather than the entire window.- The distance calculation needs to be scaled by the original image size.
To address these issues:
- Modify the
nonk_boxes
input to only include boxes from the previous frame. - Implement the scaling of the distance calculation as suggested in the previous comment.
Please update the implementation to reflect these changes and remove the TODO comments once addressed.
Also applies to: 158-160
# norm_dist = dist / (k_s[:, None, :] + 1e-8)
norm_dist = dist.mean(axis=-1)  # n_k x Np
# norm_dist =
💡 Codebase verification
Missing Tests for filter_max_center_dist
No existing tests found for filter_max_center_dist
. Please add appropriate test cases to ensure that the new distance normalization aligns with the intended functionality.
🔗 Analysis chain
Review the change in distance normalization
The distance normalization has been changed from dividing by the box size to taking the mean across the last dimension.
This change might affect the behavior of the function. Please verify if this new approach aligns with the intended functionality of filter_max_center_dist
. Consider running the following test:
If the change is intentional and correct, please add a comment explaining the rationale behind this modification.
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:
#!/bin/bash
# Description: Compare the output of the old and new implementations
# Test: Search for test cases of filter_max_center_dist
rg --type python -A 10 "def test_filter_max_center_dist"
# If no tests are found, consider adding a new test case
if [ $? -ne 0 ]; then
echo "No existing tests found for filter_max_center_dist. Consider adding a new test case."
fi
Length of output: 286
dreem/inference/tracker.py
Outdated
# if track ids exist from another tracking program i.e. sleap, init with those
for i, instance in enumerate(frames[batch_idx].instances):
    instance.pred_track_id = instance.gt_track_id
    curr_track_id = max(curr_track_id, instance.pred_track_id)

# if no track ids, then assign new ones
for i, instance in enumerate(frames[batch_idx].instances):
    if instance.pred_track_id == -1:
        curr_track += 1
        curr_track_id += 1
🛠️ Refactor suggestion
Simplify code by using frame_to_track
directly.
Instead of accessing frames[batch_idx].instances
, you can use frame_to_track.instances
, which is already available from the loop. This makes the code cleaner and avoids redundant indexing.
Apply this diff to simplify the code:
- for i, instance in enumerate(frames[batch_idx].instances):
+ for i, instance in enumerate(frame_to_track.instances):
instance.pred_track_id = instance.gt_track_id
curr_track_id = max(curr_track_id, instance.pred_track_id)
- for i, instance in enumerate(frames[batch_idx].instances):
+ for i, instance in enumerate(frame_to_track.instances):
if instance.pred_track_id == -1:
curr_track_id += 1
instance.pred_track_id = curr_track_id
🧰 Tools
🪛 Ruff
196-196: Loop control variable `i` not used within loop body. Rename unused `i` to `_i` (B007)
200-200: Loop control variable `i` not used within loop body. Rename unused `i` to `_i` (B007)
for batch_idx, frame_to_track in enumerate(frames):
    # tracked_frames is a list of reference frames that have been tracked (associated)
💡 Codebase verification
Potential AttributeError
Confirmed: frame_id
is inconsistently handled as both tensors and integers.
- In
dreem/inference/tracker.py
,frame_to_track.frame_id
is accessed as a tensor (frame_to_track.frame_id.device
). - Multiple instances in other files assign
frame_id
as an integer (e.g.,frame_id=0
), which would cause anAttributeError
when.device
is accessed.
🔗 Analysis chain
Potential AttributeError
when accessing frame_id.device
.
In the line device=frame_to_track.frame_id.device
, ensure that frame_to_track.frame_id
is a tensor. If frame_id
is an integer, accessing .device
will raise an AttributeError
. Please verify the type of frame_id
.
You can run the following script to check the type of frame_id
across the codebase:
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:
#!/bin/bash
# Description: Verify that 'frame_id' is consistently a tensor.
# Test: Search for 'frame_id' assignments. Expect: 'frame_id' should be a tensor.
rg --type python --context 2 'frame_id\s*='
Length of output: 76
Script:
#!/bin/bash
# Description: Verify that 'frame_id' is consistently a tensor.
# Test: Search for 'frame_id' assignments in Python files. Expect: 'frame_id' should be a tensor.
rg --glob "*.py" --context 2 'frame_id\s*='
Length of output: 7061
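A two-line illustration of the type issue flagged above (independent of dreem): `.device` exists on tensors but not on plain Python ints.

```python
import torch

tensor_id = torch.tensor(7)
print(tensor_id.device)   # e.g. cpu

int_id = 7
try:
    int_id.device         # plain int has no .device attribute
except AttributeError as err:
    print(f"AttributeError: {err}")
```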
Actionable comments posted: 3
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (9)
dreem/datasets/sleap_dataset.py
(3 hunks)dreem/inference/eval.py
(1 hunks)dreem/inference/track.py
(2 hunks)dreem/io/config.py
(5 hunks)dreem/models/embedding.py
(14 hunks)dreem/models/gtr_runner.py
(2 hunks)dreem/models/transformer.py
(12 hunks)dreem/training/configs/base.yaml
(6 hunks)dreem/training/train.py
(2 hunks)
🧰 Additional context used
🪛 Ruff (0.8.2)
dreem/inference/eval.py
30-30: Use key in dict
instead of key in dict.keys()
Remove .keys()
(SIM118)
50-50: Undefined name checkpoint
(F821)
54-54: f-string without any placeholders
Remove extraneous f
prefix
(F541)
dreem/inference/track.py
112-112: Use key in dict
instead of key in dict.keys()
Remove .keys()
(SIM118)
132-132: Undefined name checkpoint
(F821)
dreem/datasets/sleap_dataset.py
208-208: Local variable e
is assigned to but never used
Remove assignment to unused variable e
(F841)
dreem/io/config.py
229-229: Undefined name SleapDataset
(F821)
229-229: Undefined name MicroscopyDataset
(F821)
229-229: Undefined name CellTrackingDataset
(F821)
304-304: Undefined name dataset
(F821)
305-305: Undefined name dataset
(F821)
307-307: Undefined name dataset
(F821)
326-326: Undefined name SleapDataset
(F821)
326-326: Undefined name MicroscopyDataset
(F821)
326-326: Undefined name CellTrackingDataset
(F821)
dreem/models/embedding.py
104-104: Multiple statements on one line (colon)
(E701)
267-267: f-string without any placeholders
Remove extraneous f
prefix
(F541)
dreem/models/transformer.py
301-301: SyntaxError: Expected 'else', found ':'
301-302: SyntaxError: Expected ')', found newline
302-302: SyntaxError: Unexpected indentation
🪛 yamllint (1.35.1)
dreem/training/configs/base.yaml
[warning] 32-32: wrong indentation: expected 6 but found 8
(indentation)
[warning] 35-35: wrong indentation: expected 6 but found 8
(indentation)
🔇 Additional comments (26)
dreem/inference/eval.py (4)
30-30
: Simplify dictionary key check
You can simplify the condition by removing .keys()
, as it's unnecessary. The more Pythonic way is:
- if "batch_config" in cfg.keys():
+ if "batch_config" in cfg:
🧰 Tools
🪛 Ruff (0.8.2)
30-30: Use key in dict
instead of key in dict.keys()
Remove .keys()
(SIM118)
34-36
: Add input validation for user input
When prompting the user for the task index, consider adding input validation to handle non-integer inputs gracefully. This prevents the program from crashing if the user enters invalid data.
53-53
: Remove unnecessary f
prefix in log message
The log message does not contain any variables to interpolate, so the f
prefix is unnecessary. You can simplify the line as follows:
- logger.info(f"Using the following tracker:")
+ logger.info("Using the following tracker:")
55-55
: Use logger instead of print
statement
To maintain consistent logging practices, replace the print
statement with logger.info
. This ensures that all output is handled by the logging system.
- print(model.tracker)
+ logger.info(model.tracker)
dreem/inference/track.py (3)
112-112: Simplify dictionary key check
You can simplify the condition by removing `.keys()`, as it's unnecessary. The more Pythonic way is:
- if "batch_config" in cfg.keys():
+ if "batch_config" in cfg:
🧰 Tools
🪛 Ruff (0.8.2)
112-112: Use `key in dict` instead of `key in dict.keys()`; remove `.keys()` (SIM118)
116-118: Add input validation for task index input
When the environment variable `POD_INDEX` is not set, the code prompts the user for input. If the user enters a non-integer value, a `ValueError` will be raised. Consider adding error handling to ensure the program handles invalid input gracefully.
120-122: Add boundary check for index to prevent IndexError
When accessing `hparams_df.iloc[index]`, if `index` is out of bounds, an `IndexError` will be raised. Consider adding a check to ensure that `index` is within the valid range of the DataFrame.
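A hedged sketch of such a guard, assuming `hparams_df` is the pandas DataFrame referenced above:
if not 0 <= index < len(hparams_df):
    raise IndexError(f"Pod index {index} is out of range [0, {len(hparams_df) - 1}]")
row = hparams_df.iloc[index]  # safe to index once the bound check has passed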
dreem/models/embedding.py (3)
104-104: Split multiple statements into separate lines
For better readability and to comply with PEP 8 style guidelines, split multiple statements into separate lines:
- if input_pos.dim() <= 1: input_pos = input_pos.unsqueeze(0)
+ if input_pos.dim() <= 1:
+ input_pos = input_pos.unsqueeze(0)
🧰 Tools
🪛 Ruff (0.8.2)
104-104: Multiple statements on one line (colon) (E701)
267-267: Remove unnecessary `f` prefix in exception message
The string does not contain any placeholders, so the `f` prefix is unnecessary.
- f"Cannot use aggregation method 'average' for rope embedding; must use 'stack' or 'concatenate'"
+ "Cannot use aggregation method 'average' for rope embedding; must use 'stack' or 'concatenate'"
🧰 Tools
🪛 Ruff (0.8.2)
267-267: f-string without any placeholders; remove extraneous `f` prefix (F541)
565-568: Handle potential division by zero in time normalization
In the `_learned_temp_embedding` method, `times` is normalized by dividing by `times.max()`. If `times.max()` is zero, this will result in a division by zero error. To prevent this, add a check to ensure that `times.max()` is not zero before performing the division.
Apply this diff to handle the potential division by zero:
def _learned_temp_embedding(self, times: torch.Tensor, *args) -> torch.Tensor:
# ...
+ max_time = times.max()
+ if max_time > 0:
+ times = times / max_time
+ else:
+ times = times.clone() # All times are zero, no normalization needed
Likely invalid or redundant comment.
dreem/models/transformer.py (1)
87-91: Simplify code using dict.get()
Simplify the retrieval of `embedding_agg_method` by using the `get` method:
- self.embedding_agg_method = (
- embedding_meta["embedding_agg_method"]
- if "embedding_agg_method" in embedding_meta
- else "average"
- )
+ self.embedding_agg_method = embedding_meta.get("embedding_agg_method", "average")
This improves readability and follows Python best practices.
dreem/training/configs/base.yaml (3)
101-101: LGTM: Dataset mode configuration.
The addition of mode fields to dataset configurations properly identifies the purpose of each dataset.
Also applies to: 110-110, 119-119
32-36: Fix duplicate configuration and indentation.
The embedding configuration has duplicate mode declarations and incorrect indentation.
Apply this diff to fix the issues:
- mode: "fixed" # supports fixed, learned, rope
- normalize: true
- temp:
- mode: "fixed" # supports fixed, learned, rope
- embedding_agg_method: "stack" # supports stack, average, concatenate
+ mode: "fixed" # supports fixed, learned, rope
+ normalize: true
+ temp:
+ mode: "fixed" # supports fixed, learned, rope
+ embedding_agg_method: "stack" # supports stack, average, concatenate
Likely invalid or redundant comment.
🧰 Tools
🪛 yamllint (1.35.1)
[warning] 32-32: wrong indentation: expected 6 but found 8 (indentation)
[warning] 35-35: wrong indentation: expected 6 but found 8 (indentation)
142-142: Verify impact of disabled model logging.
Setting `log_model` to null disables model logging. Please ensure this won't affect model checkpointing functionality.
✅ Verification successful
Model logging disabled but checkpointing remains functional
Based on the search results, the `log_model` configuration appears to be independent of the model checkpointing functionality. The checkpointing is controlled by separate configurations that are still enabled:
- `enable_checkpointing: true` in trainer settings
- A dedicated `checkpointing` section with proper monitoring and save settings
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:
#!/bin/bash
# Check if there are any other model logging configurations
rg -l "log_model" --type yaml
Length of output: 29
Script:
#!/bin/bash
# Let's try a broader search for model logging related patterns
rg -i "model.*log|log.*model" --type yaml -A 3 -B 3
# Also search for checkpointing related configurations
rg -i "checkpoint|save.*model|model.*save" --type yaml -A 3 -B 3
Length of output: 947
Script:
#!/bin/bash
# Let's check the actual implementation to understand how log_model is used
rg -A 5 -B 5 "log_model" --type python
# Also search for any model logging related code
ast-grep --pattern 'def $_(self, $$$):
$$$
self.log($$$)
$$$'
Length of output: 139
dreem/training/train.py (2)
Line range hint 57-65: LGTM: Structured dataset handling.
The implementation of mode-specific dataset and dataloader creation improves code organization and aligns with the configuration changes.
100-102: LGTM: Flexible device configuration.
The implementation properly handles both GPU and CPU scenarios, with efficient multiprocessing support for CPU-only environments.
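As a rough sketch of that kind of branching (the exact trainer and dataloader arguments used by train.py are not shown here, so treat these names as assumptions):
import os
import torch

if torch.cuda.is_available():
    accelerator, devices, num_workers = "gpu", torch.cuda.device_count(), 0
else:
    # CPU-only environment: rely on dataloader worker processes instead of GPU parallelism
    accelerator, devices, num_workers = "cpu", 1, os.cpu_count()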
dreem/models/gtr_runner.py (2)
303-303: LGTM: Simplified video name extraction.
The change to use path splitting for video name extraction is more robust than relying on file extensions.
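For illustration only, one way to derive the name by splitting the path rather than trusting a known extension (the actual attribute names in gtr_runner.py may differ):
from pathlib import Path

video_path = "/data/videos/session_01/flies.mp4"  # hypothetical path
video_name = Path(video_path).stem  # "flies", regardless of the file extension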
312-315: LGTM: Consistent track ID handling.
The implementation now consistently accesses track IDs directly from instances, improving code clarity and maintainability.
Also applies to: 321-324
dreem/datasets/sleap_dataset.py (4)
140-140: LGTM: Good architectural improvement
The change to use a dictionary for video readers improves resource management and provides clearer mapping between video names and their readers.
207-212: LGTM: Improved frame retrieval logic
The new implementation efficiently handles frame retrieval by first attempting to get the image directly from the label frame and falling back to the video reader if needed.
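A hedged sketch of that fallback pattern; the attribute and reader names are assumptions, not the dataset's actual API:
try:
    img = lf.image  # image available directly on the labeled frame
except FileNotFoundError:
    img = self.vid_readers[video_name][frame_idx]  # fall back to the cached video reader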
🧰 Tools
🪛 Ruff (0.8.2)
208-208: Local variable `e` is assigned to but never used; remove assignment to unused variable `e` (F841)
208-208: Remove unused variable e in exception handling
The variable `e` in the `except FileNotFoundError as e:` block is assigned but never used.
🧰 Tools
🪛 Ruff (0.8.2)
208-208: Local variable `e` is assigned to but never used; remove assignment to unused variable `e` (F841)
417-418: Correct iteration over dictionary to access video readers
Iterating directly over `self.vid_readers` accesses the dictionary keys instead of the reader objects.
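For example, if the goal is to close every reader, iterate over the values (or items); a minimal sketch assuming each reader exposes a close() method:
for video_name, reader in self.vid_readers.items():
    reader.close()  # operate on the reader object, not on the dictionary key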
dreem/io/config.py (4)
229-229: LGTM: Improved type safety with explicit None return type
The return type annotation now correctly reflects that the method can return None, improving type safety and code clarity.
🧰 Tools
🪛 Ruff (0.8.2)
229-229: Undefined names `SleapDataset`, `MicroscopyDataset`, `CellTrackingDataset` (F821)
304-307: Use logger.warning instead of deprecated logger.warn
The `logger.warn()` method is deprecated since Python 3.3.
🧰 Tools
🪛 Ruff (0.8.2)
304-304, 305-305, 307-307: Undefined name `dataset` (F821)
326-328: LGTM: Consistent type safety improvements
The parameter and return type annotations now correctly handle None cases, maintaining consistency with the dataset changes.
🧰 Tools
🪛 Ruff (0.8.2)
326-326: Undefined names `SleapDataset`, `MicroscopyDataset`, `CellTrackingDataset` (F821)
506-506: LGTM: Better checkpoint naming
Including the metric in the final checkpoint name improves checkpoint identification and traceability.
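A sketch of what a metric-aware checkpoint name can look like with PyTorch Lightning's ModelCheckpoint; the monitored metric name here is an assumption, not necessarily the one used in config.py:
from pytorch_lightning.callbacks import ModelCheckpoint

checkpoint_callback = ModelCheckpoint(
    monitor="val_loss",                           # hypothetical metric name
    filename="final-{epoch:02d}-{val_loss:.3f}",  # metric value baked into the checkpoint file name
    save_top_k=1,
)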
dreem/inference/eval.py (outdated)
else:
    checkpoint = eval_cfg.cfg.ckpt_path
    hparams = {}

logging.getLogger().setLevel(level=cfg.get("log_level", "INFO").upper())

model = GTRRunner.load_from_checkpoint(checkpoint, strict=False)
Undefined variable checkpoint
The variable `checkpoint` is used here but has not been defined, leading to a `NameError` at runtime.
Please define `checkpoint` before using it. If it was previously set from `eval_cfg.cfg.ckpt_path`, you may need to reintroduce that logic or ensure that `checkpoint` is correctly initialized.
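A minimal sketch of reintroducing that fallback, mirroring the snippet above (variable names assumed to match the surrounding run() function):
if "checkpoints" in cfg:
    checkpoints = pd.read_csv(cfg.checkpoints)
    checkpoint = checkpoints.iloc[index]  # row selected by the pod index
else:
    checkpoint = eval_cfg.cfg.ckpt_path   # fall back to the configured checkpoint path
    hparams = {}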
🧰 Tools
🪛 Ruff (0.8.2)
50-50: Undefined name `checkpoint` (F821)
dreem/inference/track.py (outdated)
else:
    checkpoint = pred_cfg.cfg.ckpt_path
    hparams = {}

logging.getLogger().setLevel(level=cfg.get("log_level", "INFO").upper())

model = GTRRunner.load_from_checkpoint(checkpoint, strict=False)
Undefined variable checkpoint
The variable `checkpoint` is used here but has not been defined, leading to a `NameError` at runtime.
Please define `checkpoint` before using it. If it was previously set from `pred_cfg.cfg.ckpt_path`, you may need to reintroduce that logic or ensure that `checkpoint` is correctly initialized.
🧰 Tools
🪛 Ruff (0.8.2)
132-132: Undefined name `checkpoint` (F821)
max_epochs: 1
min_epochs: 1
Verify training epoch reduction impact.
Setting both max_epochs and min_epochs to 1 seems unusually low for training a transformer model, especially with RoPE embeddings.
Consider increasing the epochs to allow proper model convergence, typically:
- max_epochs: 1
- min_epochs: 1
+ max_epochs: 100
+ min_epochs: 10
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
max_epochs: 100
min_epochs: 10
Actionable comments posted: 2
🧹 Outside diff range and nitpick comments (4)
dreem/inference/track.py (2)
Line range hint 31-63: Add error handling for DataFrame operations.
The function should handle potential errors when accessing instance attributes and performing DataFrame operations.
Consider wrapping the DataFrame operations in a try-except block:
def export_trajectories(
    frames_pred: list["dreem.io.Frame"], save_path: str | None = None
) -> pd.DataFrame:
    save_dict = {}
    frame_ids = []
    X, Y = [], []
    pred_track_ids = []
    track_scores = []
-    for frame in frames_pred:
-        for i, instance in enumerate(frame.instances):
-            frame_ids.append(frame.frame_id.item())
-            bbox = instance.bbox.squeeze()
-            y = (bbox[2] + bbox[0]) / 2
-            x = (bbox[3] + bbox[1]) / 2
-            X.append(x.item())
-            Y.append(y.item())
-            track_scores.append(instance.track_score)
-            pred_track_ids.append(instance.pred_track_id.item())
+    try:
+        for frame in frames_pred:
+            for i, instance in enumerate(frame.instances):
+                frame_ids.append(frame.frame_id.item())
+                bbox = instance.bbox.squeeze()
+                y = (bbox[2] + bbox[0]) / 2
+                x = (bbox[3] + bbox[1]) / 2
+                X.append(x.item())
+                Y.append(y.item())
+                track_scores.append(instance.track_score)
+                pred_track_ids.append(instance.pred_track_id.item())
+    except (AttributeError, IndexError) as e:
+        logger.error(f"Error processing frames: {e}")
+        raise
Line range hint 66-93: Add detailed logging and improve video object creation.
The function would benefit from more detailed logging and safer video object creation.
Consider these improvements:
def track(
    model: GTRRunner, trainer: pl.Trainer, dataloader: torch.utils.data.DataLoader
) -> list[pd.DataFrame]:
+    logger.info("Starting inference process...")
    preds = trainer.predict(model, dataloader)
    pred_slp = []
    tracks = {}
    for batch in preds:
        for frame in batch:
            if frame.frame_id.item() == 0:
-                video = (
-                    sio.Video(frame.video)
-                    if isinstance(frame.video, str)
-                    else sio.Video
-                )
+                try:
+                    video = sio.Video(frame.video) if isinstance(frame.video, str) else sio.Video()
+                    logger.debug(f"Created video object for frame {frame.frame_id}")
+                except Exception as e:
+                    logger.error(f"Failed to create video object: {e}")
+                    raise
            lf, tracks = frame.to_slp(tracks, video=video)
            pred_slp.append(lf)
+    logger.info(f"Completed inference with {len(pred_slp)} predictions")
    pred_slp = sio.Labels(pred_slp)
-    print(pred_slp)
+    logger.info(f"Generated SLEAP labels: {pred_slp}")
    return pred_slp
dreem/models/transformer.py (2)
87-91: Simplify code using dict.get()
Use Python's built-in `dict.get()` method to simplify the code.
- self.embedding_agg_method = (
-     embedding_meta["embedding_agg_method"]
-     if "embedding_agg_method" in embedding_meta
-     else "average"
- )
+ self.embedding_agg_method = embedding_meta.get("embedding_agg_method", "average")
🧰 Tools
🪛 Ruff (0.8.2)
88-90: Use `embedding_meta.get("embedding_agg_method", "average")` instead of an `if` block; replace with `embedding_meta.get("embedding_agg_method", "average")` (SIM401)
788-803: Document bounding box format assumptions
The function assumes a specific bounding box format [ymin, xmin, ymax, xmax] but this critical detail is only mentioned in the docstring. Consider adding runtime validation to prevent silent errors.
 def spatial_emb_from_bb(bb: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+    if bb.shape[-1] != 4:
+        raise ValueError(f"Expected bounding boxes with 4 coordinates, got {bb.shape[-1]}")
     return (
         bb[:, :, [1, 3]].mean(axis=2).squeeze(),
         bb[:, :, [0, 2]].mean(axis=2).squeeze(),
     )
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (4)
dreem/datasets/sleap_dataset.py (1 hunks)
dreem/inference/eval.py (1 hunks)
dreem/inference/track.py (1 hunks)
dreem/models/transformer.py (12 hunks)
🚧 Files skipped from review as they are similar to previous changes (2)
- dreem/inference/eval.py
- dreem/datasets/sleap_dataset.py
🧰 Additional context used
🪛 Ruff (0.8.2)
dreem/models/transformer.py
88-90: Use `embedding_meta.get("embedding_agg_method", "average")` instead of an `if` block; replace with `embedding_meta.get("embedding_agg_method", "average")` (SIM401)
🔇 Additional comments (2)
dreem/inference/track.py (1)
Line range hint 96-164: Improve error handling and configuration validation.
The function needs better error handling and configuration validation, especially for the new RoPE embeddings integration.
Consider these improvements:
@hydra.main(config_path=None, config_name=None, version_base=None)
def run(cfg: DictConfig) -> dict[int, sio.Labels]:
+ logger.info("Starting inference with configuration...")
pred_cfg = Config(cfg)
- if "checkpoints" in cfg.keys():
+ if "checkpoints" in cfg:
try:
index = int(os.environ["POD_INDEX"])
- # For testing without deploying a job on runai
except KeyError:
- index = input("Pod Index Not found! Please choose a pod index: ")
+ while True:
+ try:
+ index = int(input("Pod Index Not found! Please choose a pod index: "))
+ break
+ except ValueError:
+ print("Invalid input. Please enter an integer value.")
logger.info(f"Pod Index: {index}")
checkpoints = pd.read_csv(cfg.checkpoints)
+ if not 0 <= index < len(checkpoints):
+ raise IndexError(f"Pod index {index} out of range [0, {len(checkpoints)-1}]")
checkpoint = checkpoints.iloc[index]
else:
checkpoint = pred_cfg.cfg.ckpt_path
+ if not os.path.exists(checkpoint):
+ raise FileNotFoundError(f"Checkpoint not found: {checkpoint}")
logging.getLogger().setLevel(level=cfg.get("log_level", "INFO").upper())
+ logger.info(f"Loading model from checkpoint: {checkpoint}")
model = GTRRunner.load_from_checkpoint(checkpoint, strict=False)
Additionally, since this PR introduces RoPE embeddings, we should verify the embedding configuration:
#!/bin/bash
# Description: Check for RoPE embedding configuration in the codebase
# Look for RoPE configuration in yaml files
rg -t yaml "rope|rotary|positional.*embedding"
# Look for RoPE implementation in Python files
ast-grep --pattern 'class $_(RoPE|RotaryPositionalEmbedding)'
dreem/models/transformer.py (1)
628-641: Review architectural decision about embedding application
The TODO comment raises an important question about applying embeddings to encoder output before cross attention. This architectural decision needs careful consideration:
- For "average" method, embeddings are applied twice (potentially redundant)
- For "stack" and "concatenate", embeddings are applied only once (potentially inconsistent)
Consider:
- Document the rationale for this difference in behavior
- Evaluate if this asymmetry affects model performance
- Consider unifying the behavior across all methods if possible
✅ Verification successful
Let me gather more information about the embedding application in the transformer model.
Based on the gathered information, I can now provide a final response:
Architectural decision about embedding application is justified
The different behavior between embedding methods is intentional and correct:
- For "average": the double application is needed because it averages positional and temporal embeddings into the same feature space, maintaining backward compatibility.
- For "stack" and "concatenate": these methods are explicitly prevented from double application because they already preserve the separate embedding information by design (shown in the `aggregate_queries` function):
  - For "stack", embeddings are handled as channels (t1,t2,t3...), (x1,x2,x3...), (y1,y2,y3...)
  - For "concatenate", embeddings are combined through an MLP
The TODO comment can be removed as the architectural decision is well-implemented and the behavior difference is by design.
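To make the three behaviors concrete, here is a minimal, self-contained sketch of the aggregation step; the function name and tensor shapes are illustrative assumptions, not the repository's actual aggregate_queries implementation:
import torch

def aggregate_queries_sketch(q_t: torch.Tensor, q_x: torch.Tensor, q_y: torch.Tensor, method: str = "stack") -> torch.Tensor:
    """q_t, q_x, q_y: (batch, n_query, d_model) queries with temporal/x/y embeddings already applied."""
    if method == "average":
        # fuse into a single representation of the same shape; embeddings can safely be re-applied later
        return (q_t + q_x + q_y) / 3
    if method == "stack":
        # keep t/x/y as separate token groups: (batch, 3 * n_query, d_model), collapsed later by a 1x1 conv
        return torch.cat((q_t, q_x, q_y), dim=1)
    if method == "concatenate":
        # keep t/x/y side by side in the feature dim: (batch, n_query, 3 * d_model), projected back by an MLP
        return torch.cat((q_t, q_x, q_y), dim=-1)
    raise ValueError(f"Unsupported aggregation method: {method}")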
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:
#!/bin/bash
# Search for any tests or documentation that might explain this architectural decision
rg -l "embedding.*cross.*attention" --type py
Length of output: 75
Script:
#!/bin/bash
# Search for the implementation of apply_embeddings and its usage
ast-grep --pattern 'def apply_embeddings'
# Look for tests related to embedding methods
rg -A 5 "embedding_agg_method" --type py
# Search for documentation or comments about embedding behavior
rg -A 3 "embedding.*average|embedding.*stack|embedding.*concatenate" --type py
Length of output: 25205
if self.return_embedding:
    for i, instance in enumerate(ref_instances):
        if self.embedding_agg_method == "average":
            ref_pos_emb = pos_emb_traceback[0][i]  # array
        else:
            ref_pos_emb = {
                "x": pos_emb_traceback[0][0][i],
                "y": pos_emb_traceback[1][0][i],
            }  # dict

        n_query = total_instances
        instance.add_embedding("pos", ref_pos_emb)  # can be an array or a dict
        instance.add_embedding("temp", temp_emb_traceback)
🛠️ Refactor suggestion
Refactor duplicate embedding handling code
The code block for handling embeddings in instances is duplicated. Consider extracting this into a helper method.
+ def _add_embeddings_to_instance(self, instance, pos_emb_traceback, temp_emb_traceback, index):
+ if self.embedding_agg_method == "average":
+ ref_pos_emb = pos_emb_traceback[0][index]
+ else:
+ ref_pos_emb = {
+ "x": pos_emb_traceback[0][0][index],
+ "y": pos_emb_traceback[1][0][index],
+ }
+ instance.add_embedding("pos", ref_pos_emb)
+ instance.add_embedding("temp", temp_emb_traceback)
- if self.return_embedding:
- for i, instance in enumerate(ref_instances):
- if self.embedding_agg_method == "average":
- ref_pos_emb = pos_emb_traceback[0][i]
- else:
- ref_pos_emb = {
- "x": pos_emb_traceback[0][0][i],
- "y": pos_emb_traceback[1][0][i],
- }
- instance.add_embedding("pos", ref_pos_emb)
- instance.add_embedding("temp", temp_emb_traceback)
+ if self.return_embedding:
+ for i, instance in enumerate(ref_instances):
+ self._add_embeddings_to_instance(instance, pos_emb_traceback, temp_emb_traceback, i)
Also applies to: 300-311
mlp = MLP(
    input_dim=queries_t.shape[-1] * 3,
    hidden_dim=queries_t.shape[-1] * 2,
    output_dim=queries_t.shape[-1],
    num_layers=1,
    dropout=0.0,
)
# concatenated is of shape (batch_size, n_query, 3*embed_dim)
collated_queries = torch.cat((queries_t, queries_x, queries_y), dim=2)
# pass through MLP to project into space of (batch_size, n_query, embed_dim)
collated_queries = mlp(collated_queries)
🛠️ Refactor suggestion
Consider caching or pre-initializing MLP for concatenate method
The current implementation creates a new MLP instance on every call when using the "concatenate" method. This could be inefficient, especially during inference.
Consider moving the MLP initialization to the constructor and reusing it:
+ def __init__(self, ...):
+ self.concat_mlp = MLP(
+ input_dim=d_model * 3,
+ hidden_dim=d_model * 2,
+ output_dim=d_model,
+ num_layers=1,
+ dropout=0.0,
+ )
- mlp = MLP(
- input_dim=queries_t.shape[-1] * 3,
- hidden_dim=queries_t.shape[-1] * 2,
- output_dim=queries_t.shape[-1],
- num_layers=1,
- dropout=0.0,
- )
- collated_queries = mlp(collated_queries)
+ collated_queries = self.concat_mlp(collated_queries)
Committable suggestion skipped: line range outside the PR's diff.
Summary by CodeRabbit
New Features
Bug Fixes
Documentation
Chores