Skip to content

Commit

Permalink
fix: typo
Browse files Browse the repository at this point in the history
  • Loading branch information
madtoinou committed Nov 17, 2024
1 parent 367829d commit ef416bb
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 11 deletions.
19 changes: 10 additions & 9 deletions darts/models/forecasting/forecasting_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1777,19 +1777,20 @@ def gridsearch(

if "model" in parameters:
valid_model_list = isinstance(parameters["model"], list)
valid_nested_params = parameters["model"].get(
"wrapped_model_class"
) and all(
isinstance(params, (list, np.ndarray))
for p_name, params in parameters["model"].items()
if p_name != "wrapped_model_class"
valid_nested_params = (
not valid_model_list
and parameters["model"].get("wrapped_model_class")
and all(
isinstance(params, (list, np.ndarray))
for p_name, params in parameters["model"].items()
if p_name != "wrapped_model_class"
)
)
if not (valid_model_list or valid_nested_params):
raise_log(
ValueError(
"The 'model' entry in `parameters` must either be a list of instantiated models or "
"a dictionary containing as keys hyperparameter names, and as values lists of values "
"plus a 'wrapped_model_class': model_cls item.",
"When the 'model' key is set as a dictionary, it must contain the 'wrapped_model_class' key, "
"which represents the class of the model to be wrapped.",
logger,
)
)
Expand Down
4 changes: 2 additions & 2 deletions darts/tests/models/forecasting/test_regression_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -3253,7 +3253,7 @@ def test_grid_search(self):
# Create grid over wrapped model parameters too
parameters = {
"model": {
"model_class": RandomForestRegressor,
"wrapped_model_class": RandomForestRegressor,
"min_samples_split": [2, 3],
},
"lags": [1],
Expand Down Expand Up @@ -3297,7 +3297,7 @@ def test_grid_search_invalid_wrapped_model_dict(self):
with pytest.raises(
ValueError,
match="When the 'model' key is set as a dictionary, it must contain "
"the 'model_class' key, which represents the class of the model "
"the 'wrapped_model_class' key, which represents the class of the model "
"to be wrapped.",
):
RegressionModel.gridsearch(
Expand Down

0 comments on commit ef416bb

Please sign in to comment.