Feat: Implement option to have multiple learning rates #329

Merged: 9 commits, Jan 17, 2025
Get params from task instead of model
Signed-off-by: Francesc Marti Escofet <f.martiescofet@gmail.com>
fmartiescofet committed Jan 17, 2025
commit 934da81921b6b7c594c70512d8fad072f5d3b006
terratorch/tasks/base_task.py (2 additions, 2 deletions)

@@ -58,11 +58,11 @@ def configure_optimizers(
         if self.hparams.get("lr_overrides", None) is not None and len(self.hparams["lr_overrides"]) > 0:
             parameters = []
             for param_name, custom_lr in self.hparams["lr_overrides"].items():
-                p = [p for n, p in self.model.named_parameters() if param_name in n]
+                p = [p for n, p in self.named_parameters() if param_name in n]
                 parameters.append({"params": p, "lr": custom_lr})
             rest_p = [
                 p
-                for n, p in self.model.named_parameters()
+                for n, p in self.named_parameters()
                 if all(param_name not in n for param_name in self.hparams["lr_overrides"])
             ]
             parameters.append({"params": rest_p})
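The grouping logic in this diff can be sketched in isolation. The following is a minimal, hypothetical illustration of how `lr_overrides` splits parameters into optimizer groups by substring match on parameter names; plain `(name, value)` pairs stand in for torch `named_parameters()`, and `build_param_groups` and `default_lr` are names invented for this sketch, not part of the terratorch API.

```python
def build_param_groups(named_params, lr_overrides, default_lr=1e-3):
    """Build optimizer parameter groups: one group per override entry
    (matched by substring on the parameter name), plus one group holding
    every parameter that matches no override."""
    groups = []
    for param_name, custom_lr in lr_overrides.items():
        # Collect all parameters whose name contains this override key.
        matched = [p for n, p in named_params if param_name in n]
        groups.append({"params": matched, "lr": custom_lr})
    # Remaining parameters fall back to the default learning rate.
    rest = [
        p
        for n, p in named_params
        if all(param_name not in n for param_name in lr_overrides)
    ]
    groups.append({"params": rest, "lr": default_lr})
    return groups


named = [("encoder.weight", "w1"), ("decoder.weight", "w2"), ("head.bias", "b1")]
groups = build_param_groups(named, {"encoder": 1e-5})
# groups[0] holds the encoder parameter at lr=1e-5;
# groups[1] holds the remaining parameters at the default lr.
```

In real use, the resulting list would be passed straight to a torch optimizer, e.g. `torch.optim.AdamW(groups)`; the PR's fix is that parameters are now gathered from the task (`self.named_parameters()`) rather than only from `self.model`, so overrides can also target parameters living outside the wrapped model.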