
- add support for multiple learning rate schedulers in sequential mod… #99

Closed
wants to merge 3 commits into from

Conversation


@shaikh58 shaikh58 commented Nov 23, 2024

…e; params can take a list of schedulers in scheduler.name, with LR params in scheduler.name."0", etc., based on the number of schedulers

  • reduce the LR scheduler step interval from every 10 epochs to every epoch
  • patch in eval key retrieval
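The switching behavior this config enables can be sketched without PyTorch: `milestones` holds the epochs at which `SequentialLR` hands control to the next scheduler in the list. The helper below is illustrative only, not part of the PR.

```python
import bisect

def active_scheduler_index(epoch: int, milestones: list[int]) -> int:
    """Index of the scheduler in charge at `epoch`.

    SequentialLR switches to scheduler i once epoch >= milestones[i - 1],
    so the active index is the count of milestones already passed.
    """
    return bisect.bisect_right(milestones, epoch)

# Two schedulers listed in scheduler.name, switching at epoch 5:
# epochs 0-4 use scheduler "0", epochs 5+ use scheduler "1".
milestones = [5]
print([active_scheduler_index(e, milestones) for e in range(8)])
# [0, 0, 0, 0, 0, 1, 1, 1]
```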

Summary by CodeRabbit

  • New Features

    • Enhanced metric logging for improved clarity.
    • Added a fallback mechanism for handling missing environment variables.
    • Support for multiple learning rate schedulers in the scheduler initialization process.
  • Improvements

    • More frequent adjustments to the learning rate during training.
    • Detailed saving process for test results, including conditional metrics saving.
  • Bug Fixes

    • Improved error handling for scheduler instantiation.


coderabbitai bot commented Nov 23, 2024

Walkthrough

The pull request introduces several modifications across multiple files. In dreem/inference/eval.py, the logging method for test metrics is updated to use bracket notation, and error handling for the POD_INDEX environment variable is improved with a user prompt. In dreem/models/gtr_runner.py, the frequency of the learning rate scheduler is changed from 10 to 1, and the process for saving test results is enhanced. Lastly, dreem/models/model_utils.py updates the init_scheduler function to accept multiple learning rate schedulers and improves error handling during instantiation.

Changes

File Change Summary
dreem/inference/eval.py Updated metric logging from dot notation to bracket notation; improved error handling for POD_INDEX.
dreem/models/gtr_runner.py Changed learning rate scheduler frequency from 10 to 1; enhanced test results saving process.
dreem/models/model_utils.py Modified init_scheduler to accept a list of schedulers and improved error handling for scheduler instantiation.
dreem/training/configs/base.yaml Added comments for potential configuration of multiple schedulers; no logic changes made.

Possibly related PRs

  • Mustafa multiple data paths #97: also adjusts how paths are managed in the configuration, so its changes may interact with this PR's environment-variable error handling and logging.

Poem

🐰 In the meadow where metrics play,
Logging's changed in a clearer way.
Schedulers dance, their rates now fine,
With better handling, all's in line.
HDF5 saves our test results bright,
In the rabbit's world, everything feels right! 🌼

📜 Recent review details

Configuration used: CodeRabbit UI
Review profile: CHILL

📥 Commits

Reviewing files that changed from the base of the PR and between f028b17 and 580fd2b.

📒 Files selected for processing (1)
  • dreem/training/configs/base.yaml (1 hunks)
🧰 Additional context used
🪛 yamllint (1.29.0-1)
dreem/training/configs/base.yaml

[error] 50-50: trailing spaces

(trailing-spaces)




@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

🧹 Outside diff range and nitpick comments (4)
dreem/inference/eval.py (2)

Line range hint 28-33: Critical: Replace interactive input with a proper fallback

Using input() as a fallback for missing environment variables is problematic in production environments as it can cause the process to hang indefinitely waiting for user input. Consider using a default value or raising an appropriate exception instead.

- except KeyError:
-     index = input("Pod Index Not found! Please choose a pod index: ")
+ except KeyError:
+     logger.error("POD_INDEX environment variable not set")
+     raise ValueError("POD_INDEX environment variable is required for checkpoint selection")
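A non-blocking fallback along these lines keeps batch jobs from hanging (the helper name is illustrative, not part of the PR):

```python
import logging
import os

logger = logging.getLogger(__name__)

def get_pod_index() -> int:
    """Read POD_INDEX, failing fast instead of prompting interactively."""
    value = os.environ.get("POD_INDEX")
    if value is None:
        logger.error("POD_INDEX environment variable not set")
        raise ValueError("POD_INDEX is required for checkpoint selection")
    return int(value)
```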
🧰 Tools
🪛 Ruff (0.7.0)

54-54: f-string without any placeholders

Remove extraneous f prefix

(F541)


Line range hint 67-69: Critical: Function doesn't match return type annotation

The function is annotated to return dict[int, sio.Labels] but doesn't return anything. Additionally, the metrics from trainer.test() are being discarded.

Consider either:

  1. Collecting and returning the results as per the type annotation
  2. Updating the type annotation if returning results isn't required
 for label_file, vid_file in zip(labels_files, vid_files):
     dataset = eval_cfg.get_dataset(
         label_files=[label_file], vid_files=[vid_file], mode="test"
     )
     dataloader = eval_cfg.get_dataloader(dataset, mode="test")
     metrics = trainer.test(model, dataloader)
+    logger.info(f"Test metrics: {metrics}")

dreem/models/gtr_runner.py (1)

Line range hint 1-375: Consider architectural improvements

The implementation is solid, but consider the following improvements:

  1. Document the scheduler frequency change in docstrings and possibly add a configuration option for it
  2. Make the test results saving more configurable by:
    • Adding options for which metrics to save
    • Allowing customization of save conditions beyond just "num_switches"
    • Providing configuration for what frame data to save
dreem/models/model_utils.py (1)

165-165: Remove redundant default value in get() method

The default value for dict.get() is None, so specifying None explicitly is unnecessary. Simplify the code by omitting the redundant None argument.

Apply this diff to simplify the code:

-                milestones = scheduler_params.get("milestones", None)
+                milestones = scheduler_params.get("milestones")
🧰 Tools
🪛 Ruff (0.7.0)

165-165: Use scheduler_params.get("milestones") instead of scheduler_params.get("milestones", None)

Replace scheduler_params.get("milestones", None) with scheduler_params.get("milestones")

(SIM910)

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

📥 Commits

Reviewing files that changed from the base of the PR and between 9bf0af7 and 7ab066c.

📒 Files selected for processing (3)
  • dreem/inference/eval.py (1 hunks)
  • dreem/models/gtr_runner.py (1 hunks)
  • dreem/models/model_utils.py (1 hunks)

🔇 Additional comments (5)
dreem/inference/eval.py (2)

55-55: LGTM! Consistent dictionary access notation

The change from dot notation to bracket notation is correct and maintains consistency with how the metrics dictionary is accessed elsewhere in the code (see line 41).


Line range hint 41-43: Verify metric configuration compatibility with multiple schedulers

Since this PR adds support for multiple learning rate schedulers, verify that the test metrics configuration properly handles cases where different schedulers are used.


dreem/models/gtr_runner.py (1)

250-250: Verify impact of more frequent learning rate updates

Changing the scheduler frequency from 10 to 1 means the learning rate will be updated every epoch instead of every 10 epochs. While this provides more responsive learning rate adjustments, it could potentially:

  1. Make training more unstable if validation metrics are noisy
  2. Increase computational overhead due to more frequent scheduler updates

Let's verify the scheduler configuration across the codebase:

✅ Verification successful

Frequency change is safe and aligns with scheduler configuration

The change from 10 to 1 for scheduler frequency is safe and well-supported by the codebase evidence:

  1. The default scheduler is ReduceLROnPlateau, which is designed to handle frequent monitoring of metrics
  2. The scheduler configuration in both test and training configs (tests/configs/base.yaml and dreem/training/configs/base.yaml) shows:
    • patience: 10 - prevents premature LR changes by waiting 10 epochs before reducing LR
    • threshold: 1e-4 - provides stability by requiring significant metric changes
  3. These safeguards ensure that more frequent monitoring (every epoch) won't lead to unstable training
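The interaction of per-epoch stepping with `patience` can be illustrated with a simplified simulation. This mimics ReduceLROnPlateau's core rule under an absolute `threshold`; the real scheduler defaults to a relative threshold mode, so this is a sketch, not the PyTorch implementation.

```python
def simulate_plateau(losses, lr=1e-3, factor=0.1, patience=10, threshold=1e-4):
    """Cut `lr` by `factor` only after `patience` consecutive epochs
    without a significant (> threshold) improvement in the loss."""
    best = float("inf")
    bad_epochs = 0
    history = []
    for loss in losses:
        if loss < best - threshold:
            best = loss
            bad_epochs = 0
        else:
            bad_epochs += 1
        if bad_epochs > patience:
            lr *= factor
            bad_epochs = 0
        history.append(lr)
    return history

# Stepping every epoch on a flat loss: the LR survives 10 "bad" epochs
# before the first reduction, so per-epoch monitoring stays stable.
print(simulate_plateau([1.0] * 12)[-1])
```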
🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Check for any custom scheduler configurations that might be affected

# Search for scheduler configurations in yaml files
echo "Checking scheduler configurations in yaml files:"
rg -A 5 "scheduler:" --type yaml

# Search for other scheduler frequency references
echo "Checking for other scheduler frequency references:"
rg -A 5 "frequency.*[0-9]+" --type python

Length of output: 870


Script:

#!/bin/bash
# Let's try a more specific search for scheduler-related configurations and frequency settings

# Search for scheduler configurations in Python files
echo "Checking scheduler configurations in Python files:"
rg -t py -A 5 "configure_optimizers"

# Search for frequency settings in Python files
echo "Checking frequency settings in Python files:"
rg -t py "frequency.*=.*[0-9]+"

# Search for ReduceLROnPlateau usage
echo "Checking ReduceLROnPlateau usage:"
rg -t py "ReduceLROnPlateau"

Length of output: 1970

dreem/models/model_utils.py (2)

163-171: Multiple scheduler support implemented correctly

The addition of multiple scheduler support using SequentialLR is well-executed. The code correctly handles a list of schedulers and applies them sequentially with the specified milestones. This enhancement increases the flexibility of the learning rate scheduling.



162-173: ⚠️ Potential issue

Initialize scheduler_class to prevent potential UnboundLocalError

If an AttributeError occurs before scheduler_class is assigned (e.g., if getattr fails), referencing scheduler_class in the except block will raise an UnboundLocalError. To prevent this, initialize scheduler_class to None before the try-except block.

Apply this diff to fix the issue:

+        scheduler_class = None
         try:
             # if a list is provided, apply each one sequentially
             if isinstance(scheduler, list):
                 schedulers = []
                 milestones = scheduler_params.get("milestones")
                 for ix, s in enumerate(scheduler):
                     params = scheduler_params[str(ix)]
                     schedulers.append(getattr(torch.optim.lr_scheduler, s)(optimizer, **params))
                 scheduler_class = torch.optim.lr_scheduler.SequentialLR(optimizer, schedulers, milestones)
                 return scheduler_class
             else:
                 scheduler_class = getattr(torch.optim.lr_scheduler, scheduler)

Likely invalid or redundant comment.



@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

🧹 Outside diff range and nitpick comments (1)
dreem/models/model_utils.py (1)

162-165: Consider optimizing the milestones parameter retrieval.

The implementation for multiple schedulers looks good, but there's a minor optimization opportunity.

-            milestones = scheduler_params.get("milestones", None)
+            milestones = scheduler_params.get("milestones")

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

📥 Commits

Reviewing files that changed from the base of the PR and between 7ab066c and f028b17.

📒 Files selected for processing (2)
  • dreem/inference/eval.py (1 hunks)
  • dreem/models/model_utils.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • dreem/inference/eval.py

🔇 Additional comments (1)
dreem/models/model_utils.py (1)

176-177: LGTM: Error handling remains robust.

The error handling for scheduler instantiation is well-maintained, with graceful fallbacks for different capitalizations.

Comment on lines +166 to +174
for ix, s in enumerate(scheduler):
    params = scheduler_params[str(ix)]
    schedulers.append(
        getattr(torch.optim.lr_scheduler, s)(optimizer, **params)
    )
scheduler_class = torch.optim.lr_scheduler.SequentialLR(
    optimizer, schedulers, milestones
)
return scheduler_class

💡 Codebase verification

🛠️ Refactor suggestion

Validation for sequential schedulers is essential, but documentation exists

The codebase shows that the default configuration uses a single scheduler (ReduceLROnPlateau), and the new sequential scheduler feature is an extension. While validation is important, the suggested changes should be more flexible:

  1. The milestone validation should be optional since single schedulers don't need milestones
  2. The type checking is already handled by PyTorch's scheduler implementations
  3. The configuration format is consistent with the existing YAML structure
         if isinstance(scheduler, list):
             schedulers = []
             milestones = scheduler_params.get("milestones")
+            if milestones is None:
+                raise ValueError("'milestones' parameter is required when using multiple schedulers")
+            if len(milestones) != len(scheduler) - 1:
+                raise ValueError(
+                    f"Expected {len(scheduler)-1} milestones for {len(scheduler)} schedulers, "
+                    f"got {len(milestones)}"
+                )
             for ix, s in enumerate(scheduler):
                 params = scheduler_params[str(ix)]
🔗 Analysis chain

Add validation and documentation for scheduler configuration.

While the implementation for multiple schedulers is functional, consider adding:

  1. Validation to ensure milestones length matches the number of schedulers minus one
  2. Type checking for scheduler parameters
  3. Documentation about the expected format of the scheduler configuration

Let's check if there are any configuration examples or tests that demonstrate the expected format:

Consider adding these improvements:

         if isinstance(scheduler, list):
             schedulers = []
             milestones = scheduler_params.get("milestones")
+            if milestones is None or len(milestones) != len(scheduler) - 1:
+                raise ValueError(
+                    f"Expected {len(scheduler)-1} milestones for {len(scheduler)} schedulers, "
+                    f"got {len(milestones) if milestones else 0}"
+                )
             for ix, s in enumerate(scheduler):
                 params = scheduler_params[str(ix)]
+                if not isinstance(params, dict):
+                    raise TypeError(f"Scheduler parameters for index {ix} must be a dictionary")

Also, add docstring examples:

    """
    Example config for multiple schedulers:
    {
        "name": ["LinearLR", "CosineAnnealingLR"],
        "milestones": [5],  # Switch from LinearLR to CosineAnnealingLR after epoch 5
        "0": {"start_factor": 0.1, "total_iters": 5},  # LinearLR params
        "1": {"T_max": 10}  # CosineAnnealingLR params
    }
    """
🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Search for scheduler configuration examples
rg -A 10 "scheduler.*name.*\[" 

# Search for related tests
rg -l "init_scheduler.*test" 

Length of output: 63


Script:

#!/bin/bash
# Search for scheduler configuration patterns more broadly
rg -A 10 "scheduler.*=.*\[" 

# Search for scheduler tests with different patterns
rg -l "scheduler.*test"

# Search for configuration examples in test files
rg -A 10 "scheduler.*:.*{" 

# Look for any YAML/JSON config files that might contain scheduler examples
fd -e yaml -e yml -e json | xargs rg -l "scheduler"

Length of output: 1086


Script:

#!/bin/bash
# Check the scheduler configuration in the base config files
cat tests/configs/base.yaml dreem/training/configs/base.yaml

# Look for scheduler-related test files
fd -e py | xargs rg -l "test.*scheduler"

# Search for scheduler configuration examples with fixed pattern
rg -A 5 '"scheduler":\s*{' 
rg -A 5 'scheduler:' tests/configs/base.yaml dreem/training/configs/base.yaml

Length of output: 6281

@shaikh58
Contributor Author

Merged as part of #100

@shaikh58 shaikh58 closed this Nov 25, 2024
@shaikh58 shaikh58 deleted the mustafa-sequential-lr branch November 25, 2024 19:05