- add support for multiple learning rate schedulers in sequential mod… #99
Conversation
…e; params can take list of schedulers in scheduler.name, with lr params in scheduler.name."0" etc. based on number of schedulers
- reduce step frequency of lr schedule to each epoch down from 10 epochs
- patch in eval key retrieval
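For orientation, here is a minimal sketch of what such a scheduler config might look like, assuming the scheduler section takes a list under name, per-scheduler parameter blocks keyed by "0", "1", etc., and a milestones list for the switch points; the key names and values below are illustrative, inferred from this description rather than copied from base.yaml:

# Hypothetical scheduler config (key names and values are assumptions, not taken from base.yaml)
scheduler_cfg = {
    "name": ["LinearLR", "CosineAnnealingLR"],     # applied sequentially
    "milestones": [10],                            # epoch at which to switch to the second scheduler
    "0": {"start_factor": 0.1, "total_iters": 10}, # params for LinearLR
    "1": {"T_max": 50},                            # params for CosineAnnealingLR
}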
Walkthrough
The pull request introduces several modifications across multiple files.
📜 Recent review details
Configuration used: CodeRabbit UI
📒 Files selected for processing (1)
🧰 Additional context used
🪛 yamllint (1.29.0-1)
dreem/training/configs/base.yaml: [error] 50-50: trailing spaces (trailing-spaces)
Actionable comments posted: 0
🧹 Outside diff range and nitpick comments (4)
dreem/inference/eval.py (2)
Line range hint 28-33: Critical: Replace interactive input with a proper fallback
Using input() as a fallback for missing environment variables is problematic in production environments as it can cause the process to hang indefinitely waiting for user input. Consider using a default value or raising an appropriate exception instead.

-    except KeyError:
-        index = input("Pod Index Not found! Please choose a pod index: ")
+    except KeyError:
+        logger.error("POD_INDEX environment variable not set")
+        raise ValueError("POD_INDEX environment variable is required for checkpoint selection")

🧰 Tools
🪛 Ruff (0.7.0)
54-54: f-string without any placeholders. Remove extraneous f prefix (F541)
Line range hint 67-69: Critical: Function doesn't match return type annotation
The function is annotated to return dict[int, sio.Labels] but doesn't return anything. Additionally, the metrics from trainer.test() are being discarded. Consider either:
- Collecting and returning the results as per the type annotation
- Updating the type annotation if returning results isn't required

     for label_file, vid_file in zip(labels_files, vid_files):
         dataset = eval_cfg.get_dataset(
             label_files=[label_file], vid_files=[vid_file], mode="test"
         )
         dataloader = eval_cfg.get_dataloader(dataset, mode="test")
         metrics = trainer.test(model, dataloader)
+        logger.info(f"Test metrics: {metrics}")
dreem/models/gtr_runner.py (1)
Line range hint 1-375: Consider architectural improvements
The implementation is solid, but consider the following improvements:
- Document the scheduler frequency change in docstrings and possibly add a configuration option for it
- Make the test results saving more configurable by:
  - Adding options for which metrics to save
  - Allowing customization of save conditions beyond just "num_switches"
  - Providing configuration for what frame data to save
dreem/models/model_utils.py (1)
165-165: Remove redundant default value in get() method
The default value for dict.get() is None, so specifying None explicitly is unnecessary. Simplify the code by omitting the redundant None argument. Apply this diff to simplify the code:

- milestones = scheduler_params.get("milestones", None)
+ milestones = scheduler_params.get("milestones")

🧰 Tools
🪛 Ruff (0.7.0)
165-165: Use scheduler_params.get("milestones") instead of scheduler_params.get("milestones", None) (SIM910)
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
📒 Files selected for processing (3)
- dreem/inference/eval.py (1 hunks)
- dreem/models/gtr_runner.py (1 hunks)
- dreem/models/model_utils.py (1 hunks)
🧰 Additional context used
🪛 Ruff (0.7.0)
dreem/models/model_utils.py
165-165: Use scheduler_params.get("milestones") instead of scheduler_params.get("milestones", None) (SIM910)
🔇 Additional comments (5)
dreem/inference/eval.py (2)
55-55: LGTM! Consistent dictionary access notation
The change from dot notation to bracket notation is correct and maintains consistency with how the metrics dictionary is accessed elsewhere in the code (see line 41).
Line range hint 41-43: Verify metric configuration compatibility with multiple schedulers
Since this PR adds support for multiple learning rate schedulers, verify that the test metrics configuration properly handles cases where different schedulers are used.
dreem/models/gtr_runner.py (1)
250-250: Verify impact of more frequent learning rate updates
Changing the scheduler frequency from 10 to 1 means the learning rate will be updated every epoch instead of every 10 epochs. While this provides more responsive learning rate adjustments, it could potentially:
- Make training more unstable if validation metrics are noisy
- Increase computational overhead due to more frequent scheduler updates
Let's verify the scheduler configuration across the codebase:
✅ Verification successful
Frequency change is safe and aligns with scheduler configuration
The change from 10 to 1 for scheduler frequency is safe and well-supported by the codebase evidence:
- The default scheduler is ReduceLROnPlateau, which is designed to handle frequent monitoring of metrics
- The scheduler configuration in both test and training configs (tests/configs/base.yaml and dreem/training/configs/base.yaml) shows:
  - patience: 10 - prevents premature LR changes by waiting 10 epochs before reducing LR
  - threshold: 1e-4 - provides stability by requiring significant metric changes
- These safeguards ensure that more frequent monitoring (every epoch) won't lead to unstable training
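To make the frequency semantics concrete, here is a minimal, self-contained sketch of a Lightning-style configure_optimizers return value with the scheduler stepped every epoch, assuming the runner follows PyTorch Lightning's lr_scheduler dictionary convention; the optimizer, learning rate, and monitored metric name below are illustrative, and only patience/threshold mirror the config values quoted above.

import torch
from torch import nn


def configure_optimizers_sketch(model: nn.Module) -> dict:
    """Illustrative only: not the project's actual configure_optimizers."""
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)  # assumed optimizer/lr
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="min", factor=0.5, patience=10, threshold=1e-4
    )
    return {
        "optimizer": optimizer,
        "lr_scheduler": {
            "scheduler": scheduler,
            "monitor": "val_loss",  # assumed metric key
            "interval": "epoch",
            "frequency": 1,  # step the scheduler every epoch (previously every 10)
        },
    }


if __name__ == "__main__":
    cfg = configure_optimizers_sketch(nn.Linear(4, 1))
    print(type(cfg["lr_scheduler"]["scheduler"]).__name__)  # ReduceLROnPlateau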
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:
#!/bin/bash
# Description: Check for any custom scheduler configurations that might be affected
# Search for scheduler configurations in yaml files
echo "Checking scheduler configurations in yaml files:"
rg -A 5 "scheduler:" --type yaml
# Search for other scheduler frequency references
echo "Checking for other scheduler frequency references:"
rg -A 5 "frequency.*[0-9]+" --type python
Length of output: 870
Script:
#!/bin/bash
# Let's try a more specific search for scheduler-related configurations and frequency settings
# Search for scheduler configurations in Python files
echo "Checking scheduler configurations in Python files:"
rg -t py -A 5 "configure_optimizers"
# Search for frequency settings in Python files
echo "Checking frequency settings in Python files:"
rg -t py "frequency.*=.*[0-9]+"
# Search for ReduceLROnPlateau usage
echo "Checking ReduceLROnPlateau usage:"
rg -t py "ReduceLROnPlateau"
Length of output: 1970
dreem/models/model_utils.py (2)
163-171: Multiple scheduler support implemented correctly
The addition of multiple scheduler support using SequentialLR is well-executed. The code correctly handles a list of schedulers and applies them sequentially with the specified milestones. This enhancement increases the flexibility of the learning rate scheduling.
162-173: Initialize scheduler_class to prevent potential UnboundLocalError
If an AttributeError occurs before scheduler_class is assigned (e.g., if getattr fails), referencing scheduler_class in the except block will raise an UnboundLocalError. To prevent this, initialize scheduler_class to None before the try-except block.
Apply this diff to fix the issue:
+ scheduler_class = None
  try:
      # if a list is provided, apply each one sequentially
      if isinstance(scheduler, list):
          schedulers = []
          milestones = scheduler_params.get("milestones")
          for ix, s in enumerate(scheduler):
              params = scheduler_params[str(ix)]
              schedulers.append(getattr(torch.optim.lr_scheduler, s)(optimizer, **params))
          scheduler_class = torch.optim.lr_scheduler.SequentialLR(optimizer, schedulers, milestones)
          return scheduler_class
      else:
          scheduler_class = getattr(torch.optim.lr_scheduler, scheduler)
Likely invalid or redundant comment.
Actionable comments posted: 1
🧹 Outside diff range and nitpick comments (1)
dreem/models/model_utils.py (1)
162-165: Consider optimizing the milestones parameter retrieval.
The implementation for multiple schedulers looks good, but there's a minor optimization opportunity.

- milestones = scheduler_params.get("milestones", None)
+ milestones = scheduler_params.get("milestones")

🧰 Tools
🪛 Ruff (0.7.0)
165-165: Use scheduler_params.get("milestones") instead of scheduler_params.get("milestones", None) (SIM910)
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
📒 Files selected for processing (2)
- dreem/inference/eval.py (1 hunks)
- dreem/models/model_utils.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
- dreem/inference/eval.py
🧰 Additional context used
🪛 Ruff (0.7.0)
dreem/models/model_utils.py
165-165: Use scheduler_params.get("milestones") instead of scheduler_params.get("milestones", None) (SIM910)
🔇 Additional comments (1)
dreem/models/model_utils.py (1)
176-177: LGTM: Error handling remains robust.
The error handling for scheduler instantiation is well-maintained, with graceful fallbacks for different capitalizations.
for ix, s in enumerate(scheduler):
    params = scheduler_params[str(ix)]
    schedulers.append(
        getattr(torch.optim.lr_scheduler, s)(optimizer, **params)
    )
scheduler_class = torch.optim.lr_scheduler.SequentialLR(
    optimizer, schedulers, milestones
)
return scheduler_class
💡 Codebase verification
🛠️ Refactor suggestion
Validation for sequential schedulers is essential, but documentation exists
The codebase shows that the default configuration uses a single scheduler (ReduceLROnPlateau), and the new sequential scheduler feature is an extension. While validation is important, the suggested changes should be more flexible:
- The milestone validation should be optional since single schedulers don't need milestones
- The type checking is already handled by PyTorch's scheduler implementations
- The configuration format is consistent with the existing YAML structure
if isinstance(scheduler, list):
schedulers = []
milestones = scheduler_params.get("milestones")
+ if milestones is None:
+ raise ValueError("'milestones' parameter is required when using multiple schedulers")
+ if len(milestones) != len(scheduler) - 1:
+ raise ValueError(
+ f"Expected {len(scheduler)-1} milestones for {len(scheduler)} schedulers, "
+ f"got {len(milestones)}"
+ )
for ix, s in enumerate(scheduler):
params = scheduler_params[str(ix)]
🔗 Analysis chain
Add validation and documentation for scheduler configuration.
While the implementation for multiple schedulers is functional, consider adding:
- Validation to ensure milestones length matches the number of schedulers minus one
- Type checking for scheduler parameters
- Documentation about the expected format of the scheduler configuration
Let's check if there are any configuration examples or tests that demonstrate the expected format:
Consider adding these improvements:
if isinstance(scheduler, list):
schedulers = []
milestones = scheduler_params.get("milestones")
+ if milestones is None or len(milestones) != len(scheduler) - 1:
+ raise ValueError(
+ f"Expected {len(scheduler)-1} milestones for {len(scheduler)} schedulers, "
+ f"got {len(milestones) if milestones else 0}"
+ )
for ix, s in enumerate(scheduler):
params = scheduler_params[str(ix)]
+ if not isinstance(params, dict):
+ raise TypeError(f"Scheduler parameters for index {ix} must be a dictionary")
Also, add docstring examples:
"""
Example config for multiple schedulers:
{
"name": ["LinearLR", "CosineAnnealingLR"],
"milestones": [5], # Switch from LinearLR to CosineAnnealingLR after epoch 5
"0": {"start_factor": 0.1, "total_iters": 5}, # LinearLR params
"1": {"T_max": 10} # CosineAnnealingLR params
}
"""
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:
#!/bin/bash
# Search for scheduler configuration examples
rg -A 10 "scheduler.*name.*\["
# Search for related tests
rg -l "init_scheduler.*test"
Length of output: 63
Script:
#!/bin/bash
# Search for scheduler configuration patterns more broadly
rg -A 10 "scheduler.*=.*\["
# Search for scheduler tests with different patterns
rg -l "scheduler.*test"
# Search for configuration examples in test files
rg -A 10 "scheduler.*:.*{"
# Look for any YAML/JSON config files that might contain scheduler examples
fd -e yaml -e yml -e json | xargs rg -l "scheduler"
Length of output: 1086
Script:
#!/bin/bash
# Check the scheduler configuration in the base config files
cat tests/configs/base.yaml dreem/training/configs/base.yaml
# Look for scheduler-related test files
fd -e py | xargs rg -l "test.*scheduler"
# Search for scheduler configuration examples with fixed pattern
rg -A 5 '"scheduler":\s*{'
rg -A 5 'scheduler:' tests/configs/base.yaml dreem/training/configs/base.yaml
Length of output: 6281
Merged as part of #100