Merge pull request #234 from Nixtla/fix/api_auto
fix/api auto
AzulGarza authored Apr 21, 2022
2 parents b9e4063 + c84efb7 commit 763faf3
Showing 3 changed files with 130 additions and 1,004 deletions.
938 changes: 79 additions & 859 deletions nbs/auto.ipynb

Large diffs are not rendered by default.

118 changes: 6 additions & 112 deletions nbs/experiments__utils.ipynb
@@ -1350,28 +1350,9 @@
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"INFO:hyperopt.tpe:build_posterior_wrapper took 0.013117 seconds\n",
"INFO:hyperopt.tpe:TPE using 0 trials\n",
"/Users/cchallu/opt/anaconda3/envs/neuralforecast/lib/python3.7/site-packages/pytorch_lightning/trainer/data_loading.py:133: UserWarning: The dataloader, val_dataloader 0, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 12 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.\n",
" f\"The dataloader, {name}, does not have many workers which may be a bottleneck.\"\n",
"/Users/cchallu/opt/anaconda3/envs/neuralforecast/lib/python3.7/site-packages/torch/nn/functional.py:3635: UserWarning: Default upsampling behavior when mode=linear is changed to align_corners=False since 0.4.0. Please specify align_corners=True if the old behavior is desired. See the documentation of nn.Upsample for details.\n",
" \"See the documentation of nn.Upsample for details.\".format(mode)\n",
"/Users/cchallu/opt/anaconda3/envs/neuralforecast/lib/python3.7/site-packages/pytorch_lightning/trainer/data_loading.py:133: UserWarning: The dataloader, train_dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 12 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.\n",
" f\"The dataloader, {name}, does not have many workers which may be a bottleneck.\"\n",
"/Users/cchallu/opt/anaconda3/envs/neuralforecast/lib/python3.7/site-packages/pytorch_lightning/trainer/data_loading.py:133: UserWarning: The dataloader, predict_dataloader 0, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 12 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.\n",
" f\"The dataloader, {name}, does not have many workers which may be a bottleneck.\"\n",
"INFO:hyperopt.tpe:build_posterior_wrapper took 0.010917 seconds\n",
"INFO:hyperopt.tpe:TPE using 1/1 trials with best loss 8.598013\n"
]
}
],
"outputs": [],
"source": [
"space = nhits_space(n_time_out=24) #, n_series=1, n_x=1, n_s=0, frequency='H')\n",
"space = nhits_space(horizon=24) #, n_series=1, n_x=1, n_s=0, frequency='H')\n",
"space['max_steps'] = hp.choice('max_steps', [1]) # Override max_steps for faster example\n",
"# The suggested spaces are partial, here we complete them with data specific information\n",
"space['n_series'] = hp.choice('n_series', [ Y_df['unique_id'].nunique() ])\n",
@@ -1395,83 +1376,7 @@
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"NHITS(\n",
" (model): _NHITS(\n",
" (blocks): ModuleList(\n",
" (0): _NHITSBlock(\n",
" (pooling_layer): MaxPool1d(kernel_size=1, stride=1, padding=0, dilation=1, ceil_mode=True)\n",
" (static_encoder): _StaticFeaturesEncoder(\n",
" (encoder): Sequential(\n",
" (0): Dropout(p=0.5, inplace=False)\n",
" (1): Linear(in_features=1, out_features=1, bias=True)\n",
" (2): ReLU()\n",
" )\n",
" )\n",
" (layers): Sequential(\n",
" (0): Linear(in_features=169, out_features=256, bias=True)\n",
" (1): ReLU()\n",
" (2): Linear(in_features=256, out_features=256, bias=True)\n",
" (3): ReLU()\n",
" (4): Linear(in_features=256, out_features=256, bias=True)\n",
" (5): ReLU()\n",
" (6): Linear(in_features=256, out_features=73, bias=True)\n",
" )\n",
" (basis): _IdentityBasis()\n",
" )\n",
" (1): _NHITSBlock(\n",
" (pooling_layer): MaxPool1d(kernel_size=1, stride=1, padding=0, dilation=1, ceil_mode=True)\n",
" (static_encoder): _StaticFeaturesEncoder(\n",
" (encoder): Sequential(\n",
" (0): Dropout(p=0.5, inplace=False)\n",
" (1): Linear(in_features=1, out_features=1, bias=True)\n",
" (2): ReLU()\n",
" )\n",
" )\n",
" (layers): Sequential(\n",
" (0): Linear(in_features=169, out_features=256, bias=True)\n",
" (1): ReLU()\n",
" (2): Linear(in_features=256, out_features=256, bias=True)\n",
" (3): ReLU()\n",
" (4): Linear(in_features=256, out_features=256, bias=True)\n",
" (5): ReLU()\n",
" (6): Linear(in_features=256, out_features=73, bias=True)\n",
" )\n",
" (basis): _IdentityBasis()\n",
" )\n",
" (2): _NHITSBlock(\n",
" (pooling_layer): MaxPool1d(kernel_size=1, stride=1, padding=0, dilation=1, ceil_mode=True)\n",
" (static_encoder): _StaticFeaturesEncoder(\n",
" (encoder): Sequential(\n",
" (0): Dropout(p=0.5, inplace=False)\n",
" (1): Linear(in_features=1, out_features=1, bias=True)\n",
" (2): ReLU()\n",
" )\n",
" )\n",
" (layers): Sequential(\n",
" (0): Linear(in_features=169, out_features=256, bias=True)\n",
" (1): ReLU()\n",
" (2): Linear(in_features=256, out_features=256, bias=True)\n",
" (3): ReLU()\n",
" (4): Linear(in_features=256, out_features=256, bias=True)\n",
" (5): ReLU()\n",
" (6): Linear(in_features=256, out_features=96, bias=True)\n",
" )\n",
" (basis): _IdentityBasis()\n",
" )\n",
" )\n",
" )\n",
")"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"outputs": [],
"source": [
"model"
]
@@ -1487,20 +1392,9 @@
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"INFO:hyperopt.tpe:build_posterior_wrapper took 0.010783 seconds\n",
"INFO:hyperopt.tpe:TPE using 0 trials\n",
"INFO:hyperopt.tpe:build_posterior_wrapper took 0.016312 seconds\n",
"INFO:hyperopt.tpe:TPE using 1/1 trials with best loss 8.788279\n"
]
}
],
"outputs": [],
"source": [
"space = nbeats_space(n_time_out=24)\n",
"space = nbeats_space(horizon=24)\n",
"space['max_steps'] = hp.choice('max_steps', [1]) # Override max_steps for faster example\n",
" # The suggested spaces are partial, here we complete them with data specific information\n",
"space['n_series'] = hp.choice('n_series', [ Y_df['unique_id'].nunique() ])\n",
@@ -1529,7 +1423,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
}
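The notebook cells above show the only API difference in this file: the horizon keyword replaces n_time_out. The suggested spaces remain partial and are completed with data-specific entries before tuning, as in this minimal sketch (it assumes a Y_df frame with ['unique_id', 'ds', 'y'] columns is already loaded as in the experiment notebook; the max_steps override only keeps the run fast):

from hyperopt import hp
from neuralforecast.auto import nhits_space

space = nhits_space(horizon=24)  # previously nhits_space(n_time_out=24)
space['max_steps'] = hp.choice('max_steps', [1])  # override max_steps for a faster example
# The suggested spaces are partial; complete them with data-specific information
space['n_series'] = hp.choice('n_series', [Y_df['unique_id'].nunique()])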
78 changes: 45 additions & 33 deletions neuralforecast/auto.py
@@ -13,10 +13,10 @@

# Cell
class AutoBaseModel(object):
-def __init__(self, n_time_out):
+def __init__(self, horizon):
super(AutoBaseModel, self).__init__()

-self.n_time_out = n_time_out
+self.horizon = horizon

def fit(self, Y_df, X_df, S_df, hyperopt_steps, loss_function_val, n_ts_val, results_dir,
save_trials=False, loss_functions_test=None, n_ts_test=0, return_test_forecast=False, verbose=False):
@@ -55,21 +55,21 @@ def forecast(self, Y_df: pd.DataFrame, X_df: pd.DataFrame = None, S_df: pd.DataF

# Cell
class NHITS(AutoBaseModel):
-def __init__(self, n_time_out, space=None):
-super(NHITS, self).__init__(n_time_out)
+def __init__(self, horizon, space=None):
+super(NHITS, self).__init__(horizon)

if space is None:
-space = nhits_space(n_time_out=n_time_out)
+space = nhits_space(horizon=horizon)
self.space = space


-def nhits_space(n_time_out: int) -> dict:
+def nhits_space(horizon: int) -> dict:
"""
Suggested hyperparameters search space for tuning. To be used with hyperopt library.
Parameters
----------
-n_time_out: int
+horizon: int
Forecasting horizon.
Returns
@@ -81,8 +81,8 @@ def nhits_space(n_time_out: int) -> dict:
space= {# Architecture parameters
'model':'nhits',
'mode': 'simple',
-'n_time_in': hp.choice('n_time_in', [2*n_time_out, 3*n_time_out, 5*n_time_out]),
-'n_time_out': hp.choice('n_time_out', [n_time_out]),
+'n_time_in': hp.choice('n_time_in', [2*horizon, 3*horizon, 5*horizon]),
+'n_time_out': hp.choice('n_time_out', [horizon]),
'shared_weights': hp.choice('shared_weights', [False]),
'activation': hp.choice('activation', ['ReLU']),
'initialization': hp.choice('initialization', ['lecun_normal']),
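Because nhits_space returns a plain hyperopt dictionary, it can also be searched directly with hyperopt's fmin outside the Auto classes. A hedged sketch, not part of this commit, where the constant objective is a placeholder for the validation loss the library computes during fitting:

from hyperopt import Trials, fmin, space_eval, tpe
from neuralforecast.auto import nhits_space

def objective(params):
    # Placeholder: a real objective would train the model configured by
    # `params` and return its validation loss.
    return 0.0

space = nhits_space(horizon=24)
trials = Trials()
best = fmin(fn=objective, space=space, algo=tpe.suggest, max_evals=10, trials=trials)
print(space_eval(space, best))  # resolve hp.choice indices back to concrete values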
@@ -124,20 +124,20 @@ def nhits_space(n_time_out: int) -> dict:

# Cell
class NBEATS(AutoBaseModel):
-def __init__(self, n_time_out, space=None):
-super(NBEATS, self).__init__(n_time_out)
+def __init__(self, horizon, space=None):
+super(NBEATS, self).__init__(horizon)

if space is None:
-space = nbeats_space(n_time_out=n_time_out)
+space = nbeats_space(horizon=horizon)
self.space = space

-def nbeats_space(n_time_out: int) -> dict:
+def nbeats_space(horizon: int) -> dict:
"""
Suggested hyperparameters search space for tuning. To be used with hyperopt library.
Parameters
----------
-n_time_out: int
+horizon: int
Forecasting horizon.
Returns
@@ -149,8 +149,8 @@ def nbeats_space(n_time_out: int) -> dict:
space= {# Architecture parameters
'model':'nbeats',
'mode': 'simple',
-'n_time_in': hp.choice('n_time_in', [2*n_time_out, 3*n_time_out, 5*n_time_out]),
-'n_time_out': hp.choice('n_time_out', [n_time_out]),
+'n_time_in': hp.choice('n_time_in', [2*horizon, 3*horizon, 5*horizon]),
+'n_time_out': hp.choice('n_time_out', [horizon]),
'shared_weights': hp.choice('shared_weights', [False]),
'activation': hp.choice('activation', ['ReLU']),
'initialization': hp.choice('initialization', ['lecun_normal']),
@@ -186,22 +186,22 @@ def nbeats_space(n_time_out: int) -> dict:

# Cell
class RNN(AutoBaseModel):
-def __init__(self, n_time_out, space=None):
-super(RNN, self).__init__(n_time_out)
+def __init__(self, horizon, space=None):
+super(RNN, self).__init__(horizon)

if space is None:
-space = rnn_space(n_time_out=n_time_out)
+space = rnn_space(horizon=horizon)
self.space = space

-def rnn_space(n_time_out: int) -> dict:
+def rnn_space(horizon: int) -> dict:
"""
Suggested hyperparameters search space for tuning. To be used with hyperopt library.
This space is not complete for training, will be completed automatically within
the fit method of the AutoBaseModels.
Parameters
----------
-n_time_out: int
+horizon: int
Forecasting horizon
Returns
@@ -213,8 +213,8 @@ def rnn_space(n_time_out: int) -> dict:
space= {# Architecture parameters
'model':'rnn',
'mode': 'full',
-'n_time_in': hp.choice('n_time_in', [1*n_time_out, 2*n_time_out, 3*n_time_out]),
-'n_time_out': hp.choice('n_time_out', [n_time_out]),
+'n_time_in': hp.choice('n_time_in', [1*horizon, 2*horizon, 3*horizon]),
+'n_time_out': hp.choice('n_time_out', [horizon]),
'cell_type': hp.choice('cell_type', ['LSTM', 'GRU']),
'state_hsize': hp.choice('state_hsize', [10, 20, 50, 100]),
'dilations': hp.choice('dilations', [ [[1, 2]], [[1, 2, 4, 8]], [[1,2],[4,8]] ]),
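All three wrappers (NHITS, NBEATS, RNN) follow the same AutoBaseModel pattern: construct with the horizon, optionally passing a custom space to replace the suggested one. A brief sketch of the renamed constructors; restricting the search to GRU cells is purely illustrative:

from hyperopt import hp
from neuralforecast.auto import NHITS, RNN, rnn_space

model = NHITS(horizon=24)  # defaults to nhits_space(horizon=24)

custom = rnn_space(horizon=24)
custom['cell_type'] = hp.choice('cell_type', ['GRU'])  # narrow the suggested search
rnn_model = RNN(horizon=24, space=custom)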
@@ -254,11 +254,13 @@ def rnn_space(n_time_out: int) -> dict:

# Cell
class AutoNF(object):
-def __init__(self, config_dict, n_time_out):
+def __init__(self, models, horizon):
super(AutoNF, self).__init__()

-self.config_dict = config_dict
-self.n_time_out = n_time_out
+if isinstance(models, list):
+self.config_dict = {model: dict(space=None) for model in models}
+else:
+self.config_dict = models
+self.horizon = horizon

"""
The AutoNF class is an automated machine learning class that simultaneously explores hyperparameters
@@ -270,13 +272,24 @@ def __init__(self, config_dict, n_time_out):
available for non-Machine Learning experts.
The AutoNF class inherits the optimized neural forecast `fit` and `predict` methods.
+Parameters
+----------
+models: List or Dict
+List of models or Dictionary with configuration.
+Keys should be name of models.
+For each model specify the hyperparameter space
+(None will use default suggested space), hyperopt steps and timeout.
+horizon: int
+Forecast horizon
"""

def fit(self,
Y_df: pd.DataFrame, X_df: pd.DataFrame, S_df: pd.DataFrame,
loss_function_val: callable, loss_functions_test: dict,
n_ts_val: int, n_ts_test: int,
results_dir: str,
+hyperopt_steps: int = None,
return_forecasts: bool = False,
verbose: bool = False):
"""
@@ -285,10 +298,6 @@ def fit(self,
Parameters
----------
-config_dict: Dict
-Dictionary with configuration. Keys should be name of models.
-For each model specify the hyperparameter space
-(None will use default suggested space), hyperopt steps and timeout.
Y_df: pd.DataFrame
Target time series with columns ['unique_id', 'ds', 'y'].
X_df: pd.DataFrame
@@ -305,6 +314,8 @@
Number of timestamps in validation.
ts_in_test: int
Number of timestamps in test.
+hyperopt_steps: int
+Number of hyperopt steps.
return_forecasts: bool
If true return forecast on test.
verbose:
Expand All @@ -324,9 +335,10 @@ def fit(self,
model_config = self.config_dict[model_str]

# Run automated hyperparameter optimization
-hyperopt_steps = model_config['hyperopt_steps']
+if hyperopt_steps is None:
+hyperopt_steps = model_config['hyperopt_steps']
results_dir_model = f'{results_dir}/{model_str}'
-model = MODEL_DICT[model_str](n_time_out=self.n_time_out, space=model_config['space'])
+model = MODEL_DICT[model_str](horizon=self.horizon, space=model_config['space'])

model.fit(Y_df=Y_df, X_df=X_df, S_df=S_df, hyperopt_steps=hyperopt_steps,
n_ts_val=n_ts_val,
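Taken together, the reworked AutoNF API accepts either a plain list of model names, each paired with its default suggested space, or the older configuration dictionary, and hyperopt_steps moves from the per-model config to fit. A hedged end-to-end sketch: Y_df and the mae loss callable are assumed to exist, and the list form appears to require passing hyperopt_steps to fit, since list-built per-model configs carry no 'hyperopt_steps' entry:

from neuralforecast.auto import AutoNF

auto = AutoNF(models=['nhits', 'nbeats'], horizon=24)  # was AutoNF(config_dict, n_time_out)

auto.fit(Y_df=Y_df, X_df=None, S_df=None,
         loss_function_val=mae,             # hypothetical validation loss callable
         loss_functions_test={'mae': mae},  # dict of named test losses
         n_ts_val=24, n_ts_test=24,
         results_dir='./results',
         hyperopt_steps=5,                  # needed with the list API (see note above)
         return_forecasts=True)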
