Fix/val_sample_weight error for models inherited from RegressionModel #2626
base: master
Looks good, thank you for your contribution.
Can you please add unit tests so that we can make sure the fix works as expected in all situations?
@@ -7,6 +7,8 @@ but cannot always guarantee backwards compatibility. Changes that may **break co

[Full Changelog](https://github.com/unit8co/darts/compare/0.31.0...master)

- Fix the bug in [#2579](https://github.com/unit8co/darts/issues/2579) that causes an error when `val_sample_weight` is set in the CatBoost and XGBoost models.
Suggested change, replacing:
- Fix the bug in [#2579](https://github.com/unit8co/darts/issues/2579) that causes an error when `val_sample_weight` is set in the CatBoost and XGBoost models.
with:
- Fix a bug in `RegressionModel` when `val_sample_weight` is used with a single timeseries. [#2626](https://github.com/unit8co/darts/pull/2626) by [Kylin Schmidt](https://github.com/kylinschmidt).
@@ -1403,7 +1405,7 @@ ts: TimeSeries = AirPassengers().load()
```python
# Assuming a multivariate TimeSeries named series with 3 columns or variables.
# To apply fn to columns with names '0' and '2':
```
Can you revert this change and the others below?
@@ -588,7 +588,7 @@ def _add_val_set_to_kwargs(
         val_weights = val_weights or None
     else:
         val_sets = [(val_samples, val_labels)]
-        val_weights = val_weight
+        val_weights = [val_weight]
Looks neat!
Codecov Report
All modified and coverable lines are covered by tests ✅
Additional details and impacted files
@@            Coverage Diff             @@
##           master    #2626      +/-   ##
==========================================
- Coverage   94.24%   94.19%   -0.06%
==========================================
  Files         141      141
  Lines       15463    15463
==========================================
- Hits        14573    14565       -8
- Misses        890      898       +8
Fix the bug in #2579 that causes an error when `val_sample_weight` is set in the CatBoost and XGBoost models.

Summary

The CatBoost and XGBoost models, which inherit from the `RegressionModel` class, raise an error when `val_sample_weight` is set for a single validation series. This is fixed in the `_add_val_set_to_kwargs` method of `RegressionModel` by changing `val_weights = val_weight` to `val_weights = [val_weight]`, so that in the single-series branch `val_weights`, like `val_sets`, is a list with one entry per validation set.
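To illustrate why the single weight array has to be wrapped, here is a minimal, stdlib-only sketch. `single_series_val_kwargs` and `weights_match_sets` are hypothetical stand-ins, not darts functions; the sketch assumes the downstream estimator expects one weight array per validation set, in the way XGBoost pairs `eval_set` with `sample_weight_eval_set`.

```python
from collections.abc import Sequence
from typing import Any, List, Optional, Tuple

ValSet = Tuple[Any, List[float]]  # (feature rows, labels) for one validation series


def single_series_val_kwargs(
    val_samples: Any,
    val_labels: List[float],
    val_weight: Optional[List[float]],
    wrap_weight: bool = True,
) -> Tuple[List[ValSet], Optional[list]]:
    """Hypothetical stand-in for the single-series branch of _add_val_set_to_kwargs.

    wrap_weight=True mirrors the fix (val_weights = [val_weight]);
    wrap_weight=False mirrors the old code (val_weights = val_weight).
    """
    # One validation series still yields a *list* containing one (X, y) set.
    val_sets = [(val_samples, val_labels)]
    if val_weight is None:
        # Assumption for this sketch: no weights means nothing to pass on.
        return val_sets, None
    val_weights = [val_weight] if wrap_weight else val_weight
    return val_sets, val_weights


def weights_match_sets(val_sets: List[ValSet], val_weights: Optional[list]) -> bool:
    """Assumed consumer-side check: exactly one weight array per validation set,
    each as long as that set's labels."""
    if val_weights is None:
        return True
    if len(val_weights) != len(val_sets):
        return False
    return all(
        isinstance(w, Sequence) and len(w) == len(y)
        for (_, y), w in zip(val_sets, val_weights)
    )


# Usage: one validation series with three samples and per-sample weights.
X = [[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]]
y = [3.0, 5.0, 7.0]
w = [0.5, 1.0, 1.0]

fixed_sets, fixed_weights = single_series_val_kwargs(X, y, w)
old_sets, old_weights = single_series_val_kwargs(X, y, w, wrap_weight=False)
print(weights_match_sets(fixed_sets, fixed_weights))  # True
print(weights_match_sets(old_sets, old_weights))      # False: 3 floats vs 1 set
```

In this sketch, the old assignment makes the single weight array itself play the role of the list of per-set weights, so its length (the number of samples) no longer lines up with the one validation set in `val_sets`.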
Other Information