From 9d75e7f330f8032ef3599c4411dbbed0556a7d26 Mon Sep 17 00:00:00 2001 From: Yanam24 Date: Tue, 17 Dec 2024 19:01:55 +0100 Subject: [PATCH] Custom rnn layers for TFT --- nbs/models.tft.ipynb | 1590 +++++++++++++++++++++++++++------- neuralforecast/models/tft.py | 144 ++- 2 files changed, 1421 insertions(+), 313 deletions(-) diff --git a/nbs/models.tft.ipynb b/nbs/models.tft.ipynb index bae287acf..cc47e3dfe 100644 --- a/nbs/models.tft.ipynb +++ b/nbs/models.tft.ipynb @@ -4,7 +4,15 @@ "cell_type": "code", "execution_count": null, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "env: PYTORCH_ENABLE_MPS_FALLBACK=1\n" + ] + } + ], "source": [ "%set_env PYTORCH_ENABLE_MPS_FALLBACK=1" ] @@ -15,7 +23,7 @@ "metadata": {}, "outputs": [], "source": [ - "#| default_exp models.tft" + "# | default_exp models.tft" ] }, { @@ -52,17 +60,18 @@ "metadata": {}, "outputs": [], "source": [ - "#| export\n", - "from typing import Tuple, Optional, Callable\n", + "# | export\n", + "from typing import Callable, Optional, Tuple\n", "\n", + "import pandas as pd\n", "import torch\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", "from torch import Tensor\n", "from torch.nn import LayerNorm\n", - "import pandas as pd\n", - "from neuralforecast.losses.pytorch import MAE\n", - "from neuralforecast.common._base_windows import BaseWindows" + "\n", + "from neuralforecast.common._base_windows import BaseWindows\n", + "from neuralforecast.losses.pytorch import MAE" ] }, { @@ -71,11 +80,10 @@ "metadata": {}, "outputs": [], "source": [ - "#| hide\n", + "# | hide\n", "import logging\n", "import warnings\n", "\n", - "from fastcore.test import test_eq\n", "from nbdev.showdoc import show_doc" ] }, @@ -85,7 +93,7 @@ "metadata": {}, "outputs": [], "source": [ - "#| hide\n", + "# | hide\n", "logging.getLogger(\"pytorch_lightning\").setLevel(logging.ERROR)\n", "warnings.filterwarnings(\"ignore\")" ] @@ -132,32 +140,33 @@ 
"metadata": {}, "outputs": [], "source": [ - "#| exporti\n", + "# | exporti\n", "def get_activation_fn(activation_str: str) -> Callable:\n", " activation_map = {\n", - " 'ReLU': F.relu,\n", - " 'Softplus': F.softplus,\n", - " 'Tanh': F.tanh,\n", - " 'SELU': F.selu,\n", - " 'LeakyReLU': F.leaky_relu,\n", - " 'Sigmoid': F.sigmoid,\n", - " 'ELU': F.elu,\n", - " 'GLU': F.glu\n", - " }\n", + " \"ReLU\": F.relu,\n", + " \"Softplus\": F.softplus,\n", + " \"Tanh\": F.tanh,\n", + " \"SELU\": F.selu,\n", + " \"LeakyReLU\": F.leaky_relu,\n", + " \"Sigmoid\": F.sigmoid,\n", + " \"ELU\": F.elu,\n", + " \"GLU\": F.glu,\n", + " }\n", " return activation_map.get(activation_str, F.elu)\n", "\n", + "\n", "class MaybeLayerNorm(nn.Module):\n", " def __init__(self, output_size, hidden_size, eps):\n", " super().__init__()\n", " if output_size and output_size == 1:\n", " self.ln = nn.Identity()\n", " else:\n", - " self.ln = LayerNorm(output_size if output_size else hidden_size,\n", - " eps=eps)\n", + " self.ln = LayerNorm(output_size if output_size else hidden_size, eps=eps)\n", "\n", " def forward(self, x):\n", " return self.ln(x)\n", "\n", + "\n", "class GLU(nn.Module):\n", " def __init__(self, hidden_size, output_size):\n", " super().__init__()\n", @@ -168,14 +177,17 @@ " x = F.glu(x)\n", " return x\n", "\n", + "\n", "class GRN(nn.Module):\n", - " def __init__(self,\n", - " input_size,\n", - " hidden_size,\n", - " output_size=None,\n", - " context_hidden_size=None,\n", - " dropout=0,\n", - " activation='ELU',):\n", + " def __init__(\n", + " self,\n", + " input_size,\n", + " hidden_size,\n", + " output_size=None,\n", + " context_hidden_size=None,\n", + " dropout=0,\n", + " activation=\"ELU\",\n", + " ):\n", " super().__init__()\n", " self.layer_norm = MaybeLayerNorm(output_size, hidden_size, eps=1e-3)\n", " self.lin_a = nn.Linear(input_size, hidden_size)\n", @@ -186,7 +198,7 @@ " self.dropout = nn.Dropout(dropout)\n", " self.out_proj = nn.Linear(input_size, output_size) if output_size 
else None\n", " self.activation_fn = get_activation_fn(activation)\n", - " \n", + "\n", " def forward(self, a: Tensor, c: Optional[Tensor] = None):\n", " x = self.lin_a(a)\n", " if c is not None:\n", @@ -238,9 +250,11 @@ "metadata": {}, "outputs": [], "source": [ - "#| exporti\n", + "# | exporti\n", "class TFTEmbedding(nn.Module):\n", - " def __init__(self, hidden_size, stat_input_size, futr_input_size, hist_input_size, tgt_size):\n", + " def __init__(\n", + " self, hidden_size, stat_input_size, futr_input_size, hist_input_size, tgt_size\n", + " ):\n", " super().__init__()\n", " # There are 4 types of input:\n", " # 1. Static continuous\n", @@ -253,92 +267,111 @@ " self.stat_input_size = stat_input_size\n", " self.futr_input_size = futr_input_size\n", " self.hist_input_size = hist_input_size\n", - " self.tgt_size = tgt_size\n", + " self.tgt_size = tgt_size\n", "\n", " # Instantiate Continuous Embeddings if size is not None\n", - " for attr, size in [('stat_exog_embedding', stat_input_size), \n", - " ('futr_exog_embedding', futr_input_size),\n", - " ('hist_exog_embedding', hist_input_size),\n", - " ('tgt_embedding', tgt_size)]:\n", + " for attr, size in [\n", + " (\"stat_exog_embedding\", stat_input_size),\n", + " (\"futr_exog_embedding\", futr_input_size),\n", + " (\"hist_exog_embedding\", hist_input_size),\n", + " (\"tgt_embedding\", tgt_size),\n", + " ]:\n", " if size:\n", " vectors = nn.Parameter(torch.Tensor(size, hidden_size))\n", " bias = nn.Parameter(torch.zeros(size, hidden_size))\n", " torch.nn.init.xavier_normal_(vectors)\n", - " setattr(self, attr+'_vectors', vectors)\n", - " setattr(self, attr+'_bias', bias)\n", + " setattr(self, attr + \"_vectors\", vectors)\n", + " setattr(self, attr + \"_bias\", bias)\n", " else:\n", - " setattr(self, attr+'_vectors', None)\n", - " setattr(self, attr+'_bias', None)\n", - "\n", - " def _apply_embedding(self,\n", - " cont: Optional[Tensor],\n", - " cont_emb: Tensor,\n", - " cont_bias: Tensor,\n", - " ):\n", - "\n", - " 
if (cont is not None):\n", - " #the line below is equivalent to following einsums\n", - " #e_cont = torch.einsum('btf,fh->bthf', cont, cont_emb)\n", - " #e_cont = torch.einsum('bf,fh->bhf', cont, cont_emb) \n", + " setattr(self, attr + \"_vectors\", None)\n", + " setattr(self, attr + \"_bias\", None)\n", + "\n", + " def _apply_embedding(\n", + " self,\n", + " cont: Optional[Tensor],\n", + " cont_emb: Tensor,\n", + " cont_bias: Tensor,\n", + " ):\n", + " if cont is not None:\n", + " # the line below is equivalent to following einsums\n", + " # e_cont = torch.einsum('btf,fh->bthf', cont, cont_emb)\n", + " # e_cont = torch.einsum('bf,fh->bhf', cont, cont_emb)\n", " e_cont = torch.mul(cont.unsqueeze(-1), cont_emb)\n", " e_cont = e_cont + cont_bias\n", " return e_cont\n", - " \n", + "\n", " return None\n", "\n", - " def forward(self, target_inp, \n", - " stat_exog=None, futr_exog=None, hist_exog=None):\n", - " # temporal/static categorical/continuous known/observed input \n", + " def forward(self, target_inp, stat_exog=None, futr_exog=None, hist_exog=None):\n", + " # temporal/static categorical/continuous known/observed input\n", " # tries to get input, if fails returns None\n", "\n", " # Static inputs are expected to be equal for all timesteps\n", " # For memory efficiency there is no assert statement\n", - " stat_exog = stat_exog[:,:] if stat_exog is not None else None\n", - "\n", - " s_inp = self._apply_embedding(cont=stat_exog,\n", - " cont_emb=self.stat_exog_embedding_vectors,\n", - " cont_bias=self.stat_exog_embedding_bias)\n", - " k_inp = self._apply_embedding(cont=futr_exog,\n", - " cont_emb=self.futr_exog_embedding_vectors,\n", - " cont_bias=self.futr_exog_embedding_bias)\n", - " o_inp = self._apply_embedding(cont=hist_exog,\n", - " cont_emb=self.hist_exog_embedding_vectors,\n", - " cont_bias=self.hist_exog_embedding_bias)\n", + " stat_exog = stat_exog[:, :] if stat_exog is not None else None\n", + "\n", + " s_inp = self._apply_embedding(\n", + " 
cont=stat_exog,\n", + " cont_emb=self.stat_exog_embedding_vectors,\n", + " cont_bias=self.stat_exog_embedding_bias,\n", + " )\n", + " k_inp = self._apply_embedding(\n", + " cont=futr_exog,\n", + " cont_emb=self.futr_exog_embedding_vectors,\n", + " cont_bias=self.futr_exog_embedding_bias,\n", + " )\n", + " o_inp = self._apply_embedding(\n", + " cont=hist_exog,\n", + " cont_emb=self.hist_exog_embedding_vectors,\n", + " cont_bias=self.hist_exog_embedding_bias,\n", + " )\n", "\n", " # Temporal observed targets\n", - " # t_observed_tgt = torch.einsum('btf,fh->btfh', \n", - " # target_inp, self.tgt_embedding_vectors) \n", - " target_inp = torch.matmul(target_inp.unsqueeze(3).unsqueeze(4),\n", - " self.tgt_embedding_vectors.unsqueeze(1)).squeeze(3)\n", + " # t_observed_tgt = torch.einsum('btf,fh->btfh',\n", + " # target_inp, self.tgt_embedding_vectors)\n", + " target_inp = torch.matmul(\n", + " target_inp.unsqueeze(3).unsqueeze(4),\n", + " self.tgt_embedding_vectors.unsqueeze(1),\n", + " ).squeeze(3)\n", " target_inp = target_inp + self.tgt_embedding_bias\n", "\n", " return s_inp, k_inp, o_inp, target_inp\n", "\n", + "\n", "class VariableSelectionNetwork(nn.Module):\n", " def __init__(self, hidden_size, num_inputs, dropout, grn_activation):\n", " super().__init__()\n", - " self.joint_grn = GRN(input_size=hidden_size*num_inputs, \n", - " hidden_size=hidden_size, \n", - " output_size=num_inputs, \n", - " context_hidden_size=hidden_size,\n", - " activation=grn_activation)\n", + " self.joint_grn = GRN(\n", + " input_size=hidden_size * num_inputs,\n", + " hidden_size=hidden_size,\n", + " output_size=num_inputs,\n", + " context_hidden_size=hidden_size,\n", + " activation=grn_activation,\n", + " )\n", " self.var_grns = nn.ModuleList(\n", - " [GRN(input_size=hidden_size, \n", - " hidden_size=hidden_size, dropout=dropout, activation=grn_activation)\n", - " for _ in range(num_inputs)])\n", + " [\n", + " GRN(\n", + " input_size=hidden_size,\n", + " hidden_size=hidden_size,\n", + " 
dropout=dropout,\n", + " activation=grn_activation,\n", + " )\n", + " for _ in range(num_inputs)\n", + " ]\n", + " )\n", "\n", " def forward(self, x: Tensor, context: Optional[Tensor] = None):\n", " Xi = x.reshape(*x.shape[:-2], -1)\n", " grn_outputs = self.joint_grn(Xi, c=context)\n", " sparse_weights = F.softmax(grn_outputs, dim=-1)\n", - " transformed_embed_list = [m(x[...,i,:])\n", - " for i, m in enumerate(self.var_grns)]\n", + " transformed_embed_list = [m(x[..., i, :]) for i, m in enumerate(self.var_grns)]\n", " transformed_embed = torch.stack(transformed_embed_list, dim=-1)\n", - " #the line below performs batched matrix vector multiplication\n", - " #for temporal features it's bthf,btf->bth\n", - " #for static features it's bhf,bf->bh\n", - " variable_ctx = torch.matmul(transformed_embed, \n", - " sparse_weights.unsqueeze(-1)).squeeze(-1)\n", + " # the line below performs batched matrix vector multiplication\n", + " # for temporal features it's bthf,btf->bth\n", + " # for static features it's bhf,bf->bh\n", + " variable_ctx = torch.matmul(\n", + " transformed_embed, sparse_weights.unsqueeze(-1)\n", + " ).squeeze(-1)\n", "\n", " return variable_ctx, sparse_weights" ] @@ -375,7 +408,7 @@ "metadata": {}, "outputs": [], "source": [ - "#| exporti\n", + "# | exporti\n", "class InterpretableMultiHeadAttention(nn.Module):\n", " def __init__(self, n_head, hidden_size, example_length, attn_dropout, dropout):\n", " super().__init__()\n", @@ -467,17 +500,37 @@ "metadata": {}, "outputs": [], "source": [ - "#| exporti\n", + "# | exporti\n", "class StaticCovariateEncoder(nn.Module):\n", - " def __init__(self, hidden_size, num_static_vars, dropout, grn_activation):\n", + " def __init__(\n", + " self,\n", + " hidden_size,\n", + " num_static_vars,\n", + " dropout,\n", + " grn_activation,\n", + " rnn_type=\"lstm\",\n", + " n_rnn_layers=1,\n", + " one_rnn_initial_state=False,\n", + " ):\n", " super().__init__()\n", " self.vsn = VariableSelectionNetwork(\n", - " 
hidden_size=hidden_size, num_inputs=num_static_vars, dropout=dropout, grn_activation=grn_activation\n", + " hidden_size=hidden_size,\n", + " num_inputs=num_static_vars,\n", + " dropout=dropout,\n", + " grn_activation=grn_activation,\n", " )\n", + " self.rnn_type = rnn_type.lower()\n", + "\n", + " self.n_rnn_layers = n_rnn_layers\n", + "\n", + " self.n_states = 1 if one_rnn_initial_state else n_rnn_layers\n", + "\n", + " n_contexts = 2 + 2 * self.n_states if rnn_type == \"lstm\" else 2 + self.n_states\n", + "\n", " self.context_grns = nn.ModuleList(\n", " [\n", " GRN(input_size=hidden_size, hidden_size=hidden_size, dropout=dropout)\n", - " for _ in range(4)\n", + " for _ in range(n_contexts)\n", " ]\n", " )\n", "\n", @@ -489,9 +542,46 @@ " # enrichment context\n", " # state_c context\n", " # state_h context\n", - " cs, ce, ch, cc = tuple(m(variable_ctx) for m in self.context_grns) # type: ignore\n", "\n", - " return cs, ce, ch, cc, sparse_weights # type: ignore" + " cs, ce = list(m(variable_ctx) for m in self.context_grns[:2]) # type: ignore\n", + "\n", + " if self.n_states == 1:\n", + " ch = torch.cat(\n", + " self.n_rnn_layers\n", + " * list(\n", + " m(variable_ctx).unsqueeze(0)\n", + " for m in self.context_grns[2 : self.n_states + 2]\n", + " )\n", + " )\n", + "\n", + " if self.rnn_type == \"lstm\":\n", + " cc = torch.cat(\n", + " self.n_rnn_layers\n", + " * list(\n", + " m(variable_ctx).unsqueeze(0)\n", + " for m in self.context_grns[self.n_states + 2 :]\n", + " )\n", + " )\n", + "\n", + " else:\n", + " ch = torch.cat(\n", + " list(\n", + " m(variable_ctx).unsqueeze(0)\n", + " for m in self.context_grns[2 : self.n_states + 2]\n", + " )\n", + " )\n", + "\n", + " if self.rnn_type == \"lstm\":\n", + " cc = torch.cat(\n", + " list(\n", + " m(variable_ctx).unsqueeze(0)\n", + " for m in self.context_grns[self.n_states + 2 :]\n", + " )\n", + " )\n", + " if self.rnn_type != \"lstm\":\n", + " cc = ch\n", + "\n", + " return cs, ce, ch, cc, sparse_weights # type: ignore" ] 
}, { @@ -524,23 +614,64 @@ "metadata": {}, "outputs": [], "source": [ - "#| exporti\n", + "# | exporti\n", "class TemporalCovariateEncoder(nn.Module):\n", - " def __init__(self, hidden_size, num_historic_vars, num_future_vars, dropout, grn_activation):\n", + " def __init__(\n", + " self,\n", + " hidden_size,\n", + " num_historic_vars,\n", + " num_future_vars,\n", + " dropout,\n", + " grn_activation,\n", + " rnn_type=\"lstm\",\n", + " n_rnn_layers=1,\n", + " ):\n", " super(TemporalCovariateEncoder, self).__init__()\n", + " self.rnn_type = rnn_type.lower()\n", + " self.n_rnn_layers = n_rnn_layers\n", "\n", " self.history_vsn = VariableSelectionNetwork(\n", - " hidden_size=hidden_size, num_inputs=num_historic_vars, dropout=dropout, grn_activation=grn_activation\n", - " )\n", - " self.history_encoder = nn.LSTM(\n", - " input_size=hidden_size, hidden_size=hidden_size, batch_first=True\n", + " hidden_size=hidden_size,\n", + " num_inputs=num_historic_vars,\n", + " dropout=dropout,\n", + " grn_activation=grn_activation,\n", " )\n", + " if self.rnn_type == \"lstm\":\n", + " self.history_encoder = nn.LSTM(\n", + " input_size=hidden_size,\n", + " hidden_size=hidden_size,\n", + " batch_first=True,\n", + " num_layers=n_rnn_layers,\n", + " )\n", + "\n", + " self.future_encoder = nn.LSTM(\n", + " input_size=hidden_size,\n", + " hidden_size=hidden_size,\n", + " batch_first=True,\n", + " num_layers=n_rnn_layers,\n", + " )\n", + "\n", + " elif self.rnn_type == \"gru\":\n", + " self.history_encoder = nn.GRU(\n", + " input_size=hidden_size,\n", + " hidden_size=hidden_size,\n", + " batch_first=True,\n", + " num_layers=n_rnn_layers,\n", + " )\n", + " self.future_encoder = nn.GRU(\n", + " input_size=hidden_size,\n", + " hidden_size=hidden_size,\n", + " batch_first=True,\n", + " num_layers=n_rnn_layers,\n", + " )\n", + " else:\n", + " raise ValueError('RNN type should be in [\"lstm\",\"gru\"] !')\n", "\n", " self.future_vsn = VariableSelectionNetwork(\n", - " hidden_size=hidden_size, 
num_inputs=num_future_vars, dropout=dropout, grn_activation=grn_activation\n", - " )\n", - " self.future_encoder = nn.LSTM(\n", - " input_size=hidden_size, hidden_size=hidden_size, batch_first=True\n", + " hidden_size=hidden_size,\n", + " num_inputs=num_future_vars,\n", + " dropout=dropout,\n", + " grn_activation=grn_activation,\n", " )\n", "\n", " # Shared Gated-Skip Connection\n", @@ -552,7 +683,11 @@ " historical_features, history_vsn_sparse_weights = self.history_vsn(\n", " historical_inputs, cs\n", " )\n", - " history, state = self.history_encoder(historical_features, (ch, cc))\n", + " if self.rnn_type == \"lstm\":\n", + " history, state = self.history_encoder(historical_features, (ch, cc))\n", + "\n", + " elif self.rnn_type == \"gru\":\n", + " history, state = self.history_encoder(historical_features, ch)\n", "\n", " future_features, future_vsn_sparse_weights = self.future_vsn(future_inputs, cs)\n", " future, _ = self.future_encoder(future_features, state)\n", @@ -588,10 +723,17 @@ "metadata": {}, "outputs": [], "source": [ - "#| exporti\n", + "# | exporti\n", "class TemporalFusionDecoder(nn.Module):\n", " def __init__(\n", - " self, n_head, hidden_size, example_length, encoder_length, attn_dropout, dropout, grn_activation\n", + " self,\n", + " n_head,\n", + " hidden_size,\n", + " example_length,\n", + " encoder_length,\n", + " attn_dropout,\n", + " dropout,\n", + " grn_activation,\n", " ):\n", " super(TemporalFusionDecoder, self).__init__()\n", " self.encoder_length = encoder_length\n", @@ -602,7 +744,7 @@ " hidden_size=hidden_size,\n", " context_hidden_size=hidden_size,\n", " dropout=dropout,\n", - " activation=grn_activation\n", + " activation=grn_activation,\n", " )\n", " self.attention = InterpretableMultiHeadAttention(\n", " n_head=n_head,\n", @@ -615,7 +757,10 @@ " self.attention_ln = LayerNorm(normalized_shape=hidden_size, eps=1e-3)\n", "\n", " self.positionwise_grn = GRN(\n", - " input_size=hidden_size, hidden_size=hidden_size, dropout=dropout, 
activation=grn_activation\n", + " input_size=hidden_size,\n", + " hidden_size=hidden_size,\n", + " dropout=dropout,\n", + " activation=grn_activation,\n", " )\n", "\n", " # ---------------------- Decoder -----------------------#\n", @@ -657,7 +802,7 @@ "metadata": {}, "outputs": [], "source": [ - "#| export\n", + "# | export\n", "class TFT(BaseWindows):\n", " \"\"\"TFT\n", "\n", @@ -678,6 +823,9 @@ " `n_head`: int=4, number of attention heads in temporal fusion decoder.
\n", " `attn_dropout`: float (0, 1), dropout of fusion decoder's attention layer.
\n", " `grn_activation`: str, activation for the GRN module from ['ReLU', 'Softplus', 'Tanh', 'SELU', 'LeakyReLU', 'Sigmoid', 'ELU', 'GLU'].
\n", + " `rnn_type`: str=\"LSTM\", recurrent neural network (RNN) layer type from [\"LSTM\",\"GRU\"].
\n", + " `n_rnn_layers`: int=1, number of RNN layers.
\n", + " `one_rnn_initial_state`: bool=False, initialize all RNN layers with the same initial states computed from static covariates.
\n", " `loss`: PyTorch module, instantiated train loss class from [losses collection](https://nixtla.github.io/neuralforecast/losses.pytorch.html).
\n", " `valid_loss`: PyTorch module=`loss`, instantiated valid loss class from [losses collection](https://nixtla.github.io/neuralforecast/losses.pytorch.html).
\n", " `max_steps`: int=1000, maximum number of training steps.
\n", @@ -725,7 +873,10 @@ " hidden_size: int = 128,\n", " n_head: int = 4,\n", " attn_dropout: float = 0.0,\n", - " grn_activation: str = 'ELU',\n", + " grn_activation: str = \"ELU\",\n", + " n_rnn_layers: int = 1,\n", + " rnn_type: str = \"LSTM\",\n", + " one_rnn_initial_state: bool = False,\n", " dropout: float = 0.1,\n", " loss=MAE(),\n", " valid_loss=None,\n", @@ -748,10 +899,9 @@ " optimizer_kwargs=None,\n", " lr_scheduler=None,\n", " lr_scheduler_kwargs=None,\n", - " dataloader_kwargs = None,\n", + " dataloader_kwargs=None,\n", " **trainer_kwargs,\n", " ):\n", - "\n", " # Inherit BaseWindows class\n", " super(TFT, self).__init__(\n", " h=h,\n", @@ -784,32 +934,40 @@ " **trainer_kwargs,\n", " )\n", " self.example_length = input_size + h\n", - " self.interpretability_params = dict([]) # type: ignore\n", + " self.interpretability_params = dict([]) # type: ignore\n", " self.tgt_size = tgt_size\n", " self.grn_activation = grn_activation\n", " futr_exog_size = max(self.futr_exog_size, 1)\n", " num_historic_vars = futr_exog_size + self.hist_exog_size + tgt_size\n", + " self.n_rnn_layers = n_rnn_layers\n", + " # ------------------------------- Encoders -----------------------------#\n", + " self.embedding = TFTEmbedding(\n", + " hidden_size=hidden_size,\n", + " stat_input_size=self.stat_exog_size,\n", + " futr_input_size=futr_exog_size,\n", + " hist_input_size=self.hist_exog_size,\n", + " tgt_size=tgt_size,\n", + " )\n", "\n", - " #------------------------------- Encoders -----------------------------#\n", - " self.embedding = TFTEmbedding(hidden_size=hidden_size,\n", - " stat_input_size=self.stat_exog_size,\n", - " futr_input_size=futr_exog_size,\n", - " hist_input_size=self.hist_exog_size,\n", - " tgt_size=tgt_size)\n", - " \n", " if self.stat_exog_size > 0:\n", " self.static_encoder = StaticCovariateEncoder(\n", - " hidden_size=hidden_size,\n", - " num_static_vars=self.stat_exog_size,\n", - " dropout=dropout,\n", - " grn_activation=self.grn_activation)\n", + " 
hidden_size=hidden_size,\n", + " num_static_vars=self.stat_exog_size,\n", + " dropout=dropout,\n", + " grn_activation=self.grn_activation,\n", + " rnn_type=rnn_type,\n", + " n_rnn_layers=n_rnn_layers,\n", + " one_rnn_initial_state=one_rnn_initial_state,\n", + " )\n", "\n", " self.temporal_encoder = TemporalCovariateEncoder(\n", " hidden_size=hidden_size,\n", " num_historic_vars=num_historic_vars,\n", " num_future_vars=futr_exog_size,\n", " dropout=dropout,\n", - " grn_activation=self.grn_activation\n", + " grn_activation=self.grn_activation,\n", + " n_rnn_layers=n_rnn_layers,\n", + " rnn_type=rnn_type,\n", " )\n", "\n", " # ------------------------------ Decoders -----------------------------#\n", @@ -820,7 +978,7 @@ " encoder_length=self.input_size,\n", " attn_dropout=attn_dropout,\n", " dropout=dropout,\n", - " grn_activation=self.grn_activation\n", + " grn_activation=self.grn_activation,\n", " )\n", "\n", " # Adapter with Loss dependent dimensions\n", @@ -829,7 +987,6 @@ " )\n", "\n", " def forward(self, windows_batch):\n", - "\n", " # Parsiw windows_batch\n", " y_insample = windows_batch[\"insample_y\"][:, :, None] # <- [B,T,1]\n", " futr_exog = windows_batch[\"futr_exog\"]\n", @@ -851,17 +1008,19 @@ " # Static context\n", " if s_inp is not None:\n", " cs, ce, ch, cc, static_encoder_sparse_weights = self.static_encoder(s_inp)\n", - " ch, cc = ch.unsqueeze(0), cc.unsqueeze(0) # LSTM initial states\n", + " # ch, cc = ch.unsqueeze(0), cc.unsqueeze(0) # LSTM initial states\n", " else:\n", " # If None add zeros\n", " batch_size, example_length, target_size, hidden_size = t_observed_tgt.shape\n", " cs = torch.zeros(size=(batch_size, hidden_size), device=y_insample.device)\n", " ce = torch.zeros(size=(batch_size, hidden_size), device=y_insample.device)\n", " ch = torch.zeros(\n", - " size=(1, batch_size, hidden_size), device=y_insample.device\n", + " size=(self.n_rnn_layers, batch_size, hidden_size),\n", + " device=y_insample.device,\n", " )\n", " cc = 
torch.zeros(\n", - " size=(1, batch_size, hidden_size), device=y_insample.device\n", + " size=(self.n_rnn_layers, batch_size, hidden_size),\n", + " device=y_insample.device,\n", " )\n", " static_encoder_sparse_weights = []\n", "\n", @@ -942,16 +1101,17 @@ " self.mean_on_batch(hist_vsn_wgts).cpu().numpy(), columns=hist_exog_list\n", " )\n", " importances[\"Past variable importance over time\"] = hist_vsn_imp\n", - " # importances[\"Past variable importance\"] = hist_vsn_imp.mean(axis=0).sort_values()\n", + " # importances[\"Past variable importance\"] = hist_vsn_imp.mean(axis=0).sort_values()\n", "\n", " # Future feature importances\n", " if self.futr_exog_size > 0:\n", " future_vsn_wgts = self.interpretability_params.get(\"future_vsn_wgts\")\n", " future_vsn_imp = pd.DataFrame(\n", - " self.mean_on_batch(future_vsn_wgts).cpu().numpy(), columns=self.futr_exog_list\n", + " self.mean_on_batch(future_vsn_wgts).cpu().numpy(),\n", + " columns=self.futr_exog_list,\n", " )\n", " importances[\"Future variable importance over time\"] = future_vsn_imp\n", - " # importances[\"Future variable importance\"] = future_vsn_imp.mean(axis=0).sort_values()\n", + " # importances[\"Future variable importance\"] = future_vsn_imp.mean(axis=0).sort_values()\n", "\n", " # Static feature importances\n", " if self.stat_exog_size > 0:\n", @@ -969,16 +1129,16 @@ " )\n", "\n", " return importances\n", - " \n", + "\n", " def attention_weights(self):\n", - " \"\"\" \n", + " \"\"\"\n", " Batch average attention weights\n", - " \n", + "\n", " Returns:\n", " np.ndarray: A 1D array containing the attention weights for each time step.\n", - " \n", + "\n", " \"\"\"\n", - " \n", + "\n", " attention = (\n", " self.mean_on_batch(self.interpretability_params[\"attn_wts\"])\n", " .mean(dim=0)\n", @@ -987,11 +1147,11 @@ " )\n", "\n", " return attention\n", - " \n", - " def feature_importance_correlations(self)-> pd.DataFrame:\n", + "\n", + " def feature_importance_correlations(self) -> pd.DataFrame:\n", " 
\"\"\"\n", " Compute the correlation between the past and future feature importances and the mean attention weights.\n", - " \n", + "\n", " Returns:\n", " pd.DataFrame: A DataFrame containing the correlation coefficients between the past feature importances and the mean attention weights.\n", " \"\"\"\n", @@ -1013,54 +1173,318 @@ "cell_type": "code", "execution_count": null, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/markdown": [ + "---\n", + "\n", + "### TFT.fit\n", + "\n", + "> TFT.fit (dataset, val_size=0, test_size=0, random_seed=None,\n", + "> distributed_config=None)\n", + "\n", + "*Fit.\n", + "\n", + "The `fit` method, optimizes the neural network's weights using the\n", + "initialization parameters (`learning_rate`, `windows_batch_size`, ...)\n", + "and the `loss` function as defined during the initialization.\n", + "Within `fit` we use a PyTorch Lightning `Trainer` that\n", + "inherits the initialization's `self.trainer_kwargs`, to customize\n", + "its inputs, see [PL's trainer arguments](https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.trainer.trainer.Trainer.html?highlight=trainer).\n", + "\n", + "The method is designed to be compatible with SKLearn-like classes\n", + "and in particular to be compatible with the StatsForecast library.\n", + "\n", + "By default the `model` is not saving training checkpoints to protect\n", + "disk memory, to get them change `enable_checkpointing=True` in `__init__`.\n", + "\n", + "**Parameters:**
\n", + "`dataset`: NeuralForecast's `TimeSeriesDataset`, see [documentation](https://nixtla.github.io/neuralforecast/tsdataset.html).
\n", + "`val_size`: int, validation size for temporal cross-validation.
\n", + "`random_seed`: int=None, random_seed for pytorch initializer and numpy generators, overwrites model.__init__'s.
\n", + "`test_size`: int, test size for temporal cross-validation.
*" + ], + "text/plain": [ + "---\n", + "\n", + "### TFT.fit\n", + "\n", + "> TFT.fit (dataset, val_size=0, test_size=0, random_seed=None,\n", + "> distributed_config=None)\n", + "\n", + "*Fit.\n", + "\n", + "The `fit` method, optimizes the neural network's weights using the\n", + "initialization parameters (`learning_rate`, `windows_batch_size`, ...)\n", + "and the `loss` function as defined during the initialization.\n", + "Within `fit` we use a PyTorch Lightning `Trainer` that\n", + "inherits the initialization's `self.trainer_kwargs`, to customize\n", + "its inputs, see [PL's trainer arguments](https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.trainer.trainer.Trainer.html?highlight=trainer).\n", + "\n", + "The method is designed to be compatible with SKLearn-like classes\n", + "and in particular to be compatible with the StatsForecast library.\n", + "\n", + "By default the `model` is not saving training checkpoints to protect\n", + "disk memory, to get them change `enable_checkpointing=True` in `__init__`.\n", + "\n", + "**Parameters:**
\n", + "`dataset`: NeuralForecast's `TimeSeriesDataset`, see [documentation](https://nixtla.github.io/neuralforecast/tsdataset.html).
\n", + "`val_size`: int, validation size for temporal cross-validation.
\n", + "`random_seed`: int=None, random_seed for pytorch initializer and numpy generators, overwrites model.__init__'s.
\n", + "`test_size`: int, test size for temporal cross-validation.
*" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ - "show_doc(TFT.fit, name='TFT.fit', title_level=3)" + "show_doc(TFT.fit, name=\"TFT.fit\", title_level=3)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/markdown": [ + "---\n", + "\n", + "### TFT.predict\n", + "\n", + "> TFT.predict (dataset, test_size=None, step_size=1, random_seed=None,\n", + "> **data_module_kwargs)\n", + "\n", + "*Predict.\n", + "\n", + "Neural network prediction with PL's `Trainer` execution of `predict_step`.\n", + "\n", + "**Parameters:**
\n", + "`dataset`: NeuralForecast's `TimeSeriesDataset`, see [documentation](https://nixtla.github.io/neuralforecast/tsdataset.html).
\n", + "`test_size`: int=None, test size for temporal cross-validation.
\n", + "`step_size`: int=1, Step size between each window.
\n", + "`random_seed`: int=None, random_seed for pytorch initializer and numpy generators, overwrites model.__init__'s.
\n", + "`**data_module_kwargs`: PL's TimeSeriesDataModule args, see [documentation](https://pytorch-lightning.readthedocs.io/en/1.6.1/extensions/datamodules.html#using-a-datamodule).*" + ], + "text/plain": [ + "---\n", + "\n", + "### TFT.predict\n", + "\n", + "> TFT.predict (dataset, test_size=None, step_size=1, random_seed=None,\n", + "> **data_module_kwargs)\n", + "\n", + "*Predict.\n", + "\n", + "Neural network prediction with PL's `Trainer` execution of `predict_step`.\n", + "\n", + "**Parameters:**
\n", + "`dataset`: NeuralForecast's `TimeSeriesDataset`, see [documentation](https://nixtla.github.io/neuralforecast/tsdataset.html).
\n", + "`test_size`: int=None, test size for temporal cross-validation.
\n", + "`step_size`: int=1, Step size between each window.
\n", + "`random_seed`: int=None, random_seed for pytorch initializer and numpy generators, overwrites model.__init__'s.
\n", + "`**data_module_kwargs`: PL's TimeSeriesDataModule args, see [documentation](https://pytorch-lightning.readthedocs.io/en/1.6.1/extensions/datamodules.html#using-a-datamodule).*" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ - "show_doc(TFT.predict, name='TFT.predict', title_level=3)" + "show_doc(TFT.predict, name=\"TFT.predict\", title_level=3)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/markdown": [ + "---\n", + "\n", + "[source](https://github.com/Nixtla/neuralforecast/blob/main/neuralforecast/models/tft.py#L679){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "\n", + "### TFT.feature_importances,\n", + "\n", + "> TFT.feature_importances, ()\n", + "\n", + "*Compute the feature importances for historical, future, and static features.\n", + "\n", + "Returns:\n", + " dict: A dictionary containing the feature importances for each feature type.\n", + " The keys are 'hist_vsn', 'future_vsn', and 'static_vsn', and the values\n", + " are pandas DataFrames with the corresponding feature importances.*" + ], + "text/plain": [ + "---\n", + "\n", + "[source](https://github.com/Nixtla/neuralforecast/blob/main/neuralforecast/models/tft.py#L679){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "\n", + "### TFT.feature_importances,\n", + "\n", + "> TFT.feature_importances, ()\n", + "\n", + "*Compute the feature importances for historical, future, and static features.\n", + "\n", + "Returns:\n", + " dict: A dictionary containing the feature importances for each feature type.\n", + " The keys are 'hist_vsn', 'future_vsn', and 'static_vsn', and the values\n", + " are pandas DataFrames with the corresponding feature importances.*" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ - "show_doc(TFT.feature_importances, 
name='TFT.feature_importances,', title_level=3)" + "show_doc(TFT.feature_importances, name=\"TFT.feature_importances,\", title_level=3)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/markdown": [ + "---\n", + "\n", + "[source](https://github.com/Nixtla/neuralforecast/blob/main/neuralforecast/models/tft.py#L738){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "\n", + "### TFT.attention_weights\n", + "\n", + "> TFT.attention_weights ()\n", + "\n", + "*Batch average attention weights\n", + "\n", + "Returns:\n", + "np.ndarray: A 1D array containing the attention weights for each time step.*" + ], + "text/plain": [ + "---\n", + "\n", + "[source](https://github.com/Nixtla/neuralforecast/blob/main/neuralforecast/models/tft.py#L738){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "\n", + "### TFT.attention_weights\n", + "\n", + "> TFT.attention_weights ()\n", + "\n", + "*Batch average attention weights\n", + "\n", + "Returns:\n", + "np.ndarray: A 1D array containing the attention weights for each time step.*" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ - "show_doc(TFT.attention_weights , name='TFT.attention_weights', title_level=3)" + "show_doc(TFT.attention_weights, name=\"TFT.attention_weights\", title_level=3)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/markdown": [ + "---\n", + "\n", + "[source](https://github.com/Nixtla/neuralforecast/blob/main/neuralforecast/models/tft.py#L738){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "\n", + "### TFT.attention_weights\n", + "\n", + "> TFT.attention_weights ()\n", + "\n", + "*Batch average attention weights\n", + "\n", + "Returns:\n", + "np.ndarray: A 1D array containing the attention weights for each time step.*" + ], + "text/plain": [ + "---\n", + 
"\n", + "[source](https://github.com/Nixtla/neuralforecast/blob/main/neuralforecast/models/tft.py#L738){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "\n", + "### TFT.attention_weights\n", + "\n", + "> TFT.attention_weights ()\n", + "\n", + "*Batch average attention weights\n", + "\n", + "Returns:\n", + "np.ndarray: A 1D array containing the attention weights for each time step.*" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ - "show_doc(TFT.attention_weights , name='TFT.attention_weights', title_level=3)" + "show_doc(TFT.attention_weights, name=\"TFT.attention_weights\", title_level=3)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/markdown": [ + "---\n", + "\n", + "[source](https://github.com/Nixtla/neuralforecast/blob/main/neuralforecast/models/tft.py#L756){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "\n", + "### TFT.feature_importance_correlations\n", + "\n", + "> TFT.feature_importance_correlations ()\n", + "\n", + "*Compute the correlation between the past and future feature importances and the mean attention weights.\n", + "\n", + "Returns:\n", + "pd.DataFrame: A DataFrame containing the correlation coefficients between the past feature importances and the mean attention weights.*" + ], + "text/plain": [ + "---\n", + "\n", + "[source](https://github.com/Nixtla/neuralforecast/blob/main/neuralforecast/models/tft.py#L756){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "\n", + "### TFT.feature_importance_correlations\n", + "\n", + "> TFT.feature_importance_correlations ()\n", + "\n", + "*Compute the correlation between the past and future feature importances and the mean attention weights.\n", + "\n", + "Returns:\n", + "pd.DataFrame: A DataFrame containing the correlation coefficients between the past feature importances and the mean attention weights.*" + ] + }, 
+ "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ - "show_doc(TFT.feature_importance_correlations , name='TFT.feature_importance_correlations', title_level=3)" + "show_doc(\n", + " TFT.feature_importance_correlations,\n", + " name=\"TFT.feature_importance_correlations\",\n", + " title_level=3,\n", + ")" ] }, { @@ -1075,55 +1499,365 @@ "cell_type": "code", "execution_count": null, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Seed set to 1\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "e36d46990ccd41b592b8b41272f824a8", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Sanity Checking: | | 0/? [00:00" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], "source": [ - "#| eval: false\n", - "import pandas as pd\n", + "# | eval: false\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", + "import pandas as pd\n", + "\n", "from neuralforecast import NeuralForecast\n", - "from neuralforecast.models import TFT\n", + "\n", + "# from neuralforecast.models import TFT\n", "from neuralforecast.losses.pytorch import DistributionLoss\n", "from neuralforecast.utils import AirPassengersPanel, AirPassengersStatic\n", "\n", - "AirPassengersPanel['month']=AirPassengersPanel.ds.dt.month\n", - "Y_train_df = AirPassengersPanel[AirPassengersPanel.ds=AirPassengersPanel['ds'].values[-12]].reset_index(drop=True) # 12 test\n", + "AirPassengersPanel[\"month\"] = AirPassengersPanel.ds.dt.month\n", + "Y_train_df = AirPassengersPanel[\n", + " AirPassengersPanel.ds < AirPassengersPanel[\"ds\"].values[-12]\n", + "] # 132 train\n", + "Y_test_df = AirPassengersPanel[\n", + " AirPassengersPanel.ds >= AirPassengersPanel[\"ds\"].values[-12]\n", + "].reset_index(drop=True) # 12 test\n", "\n", "nf = NeuralForecast(\n", - " models=[TFT(h=12, input_size=48,\n", - " hidden_size=20,\n", - " 
grn_activation='ELU',\n", - " loss=DistributionLoss(distribution='StudentT', level=[80, 90]),\n", - " learning_rate=0.005,\n", - " stat_exog_list=['airline1'],\n", - " futr_exog_list=['y_[lag12]','month'],\n", - " hist_exog_list=['trend'],\n", - " max_steps=300,\n", - " val_check_steps=10,\n", - " early_stop_patience_steps=10,\n", - " scaler_type='robust',\n", - " windows_batch_size=None,\n", - " enable_progress_bar=True),\n", + " models=[\n", + " TFT(\n", + " h=12,\n", + " input_size=48,\n", + " hidden_size=20,\n", + " grn_activation=\"ELU\",\n", + " rnn_type=\"lstm\",\n", + " n_rnn_layers=1,\n", + " one_rnn_initial_state=False,\n", + " loss=DistributionLoss(distribution=\"StudentT\", level=[80, 90]),\n", + " learning_rate=0.005,\n", + " stat_exog_list=[\"airline1\"],\n", + " futr_exog_list=[\"y_[lag12]\", \"month\"],\n", + " hist_exog_list=[\"trend\"],\n", + " max_steps=300,\n", + " val_check_steps=10,\n", + " early_stop_patience_steps=10,\n", + " scaler_type=\"robust\",\n", + " windows_batch_size=None,\n", + " enable_progress_bar=True,\n", + " ),\n", " ],\n", - " freq='M'\n", + " freq=\"M\",\n", ")\n", "nf.fit(df=Y_train_df, static_df=AirPassengersStatic, val_size=12)\n", "Y_hat_df = nf.predict(futr_df=Y_test_df)\n", "\n", "# Plot quantile predictions\n", - "Y_hat_df = Y_hat_df.reset_index(drop=False).drop(columns=['unique_id','ds'])\n", + "Y_hat_df = Y_hat_df.reset_index(drop=False).drop(columns=[\"unique_id\", \"ds\"])\n", "plot_df = pd.concat([Y_test_df, Y_hat_df], axis=1)\n", "plot_df = pd.concat([Y_train_df, plot_df])\n", "\n", - "plot_df = plot_df[plot_df.unique_id=='Airline1'].drop('unique_id', axis=1)\n", - "plt.plot(plot_df['ds'], plot_df['y'], c='black', label='True')\n", - "plt.plot(plot_df['ds'], plot_df['TFT'], c='purple', label='mean')\n", - "plt.plot(plot_df['ds'], plot_df['TFT-median'], c='blue', label='median')\n", - "plt.fill_between(x=plot_df['ds'][-12:], \n", - " y1=plot_df['TFT-lo-90'][-12:].values, \n", - " 
y2=plot_df['TFT-hi-90'][-12:].values,\n", - " alpha=0.4, label='level 90')\n", + "plot_df = plot_df[plot_df.unique_id == \"Airline1\"].drop(\"unique_id\", axis=1)\n", + "plt.plot(plot_df[\"ds\"], plot_df[\"y\"], c=\"black\", label=\"True\")\n", + "plt.plot(plot_df[\"ds\"], plot_df[\"TFT\"], c=\"purple\", label=\"mean\")\n", + "plt.plot(plot_df[\"ds\"], plot_df[\"TFT-median\"], c=\"blue\", label=\"median\")\n", + "plt.fill_between(\n", + " x=plot_df[\"ds\"][-12:],\n", + " y1=plot_df[\"TFT-lo-90\"][-12:].values,\n", + " y2=plot_df[\"TFT-hi-90\"][-12:].values,\n", + " alpha=0.4,\n", + " label=\"level 90\",\n", + ")\n", "plt.legend()\n", "plt.grid()\n", "plt.plot()" @@ -1149,7 +1883,7 @@ "metadata": {}, "outputs": [], "source": [ - "#| eval: false\n", + "# | eval: false\n", "attention = nf.models[0].attention_weights()" ] }, @@ -1159,82 +1893,104 @@ "metadata": {}, "outputs": [], "source": [ - "#| eval: false\n", - "def plot_attention(self, plot:str=\"time\", output:str='plot', width:int=800, height:int=400):\n", - " \"\"\"\n", - " Plot the attention weights.\n", - "\n", - " Args:\n", - " plot (str, optional): The type of plot to generate. Can be one of the following:\n", - " - 'time': Display the mean attention weights over time.\n", - " - 'all': Display the attention weights for each horizon.\n", - " - 'heatmap': Display the attention weights as a heatmap.\n", - " - An integer in the range [1, model.h) to display the attention weights for a specific horizon.\n", - " output (str, optional): The type of output to generate. Can be one of the following:\n", - " - 'plot': Display the plot directly.\n", - " - 'figure': Return the plot as a figure object.\n", - " width (int, optional): Width of the plot in pixels. Default is 800.\n", - " height (int, optional): Height of the plot in pixels. 
Default is 400.\n", - "\n", - " Returns:\n", - " matplotlib.figure.Figure: If `output` is 'figure', the function returns the plot as a figure object.\n", - " \"\"\"\n", + "# | eval: false\n", + "def plot_attention(\n", + " self, plot: str = \"time\", output: str = \"plot\", width: int = 800, height: int = 400\n", + "):\n", + " \"\"\"\n", + " Plot the attention weights.\n", + "\n", + " Args:\n", + " plot (str, optional): The type of plot to generate. Can be one of the following:\n", + " - 'time': Display the mean attention weights over time.\n", + " - 'all': Display the attention weights for each horizon.\n", + " - 'heatmap': Display the attention weights as a heatmap.\n", + " - An integer in the range [1, model.h) to display the attention weights for a specific horizon.\n", + " output (str, optional): The type of output to generate. Can be one of the following:\n", + " - 'plot': Display the plot directly.\n", + " - 'figure': Return the plot as a figure object.\n", + " width (int, optional): Width of the plot in pixels. Default is 800.\n", + " height (int, optional): Height of the plot in pixels. 
Default is 400.\n", + "\n", + " Returns:\n", + " matplotlib.figure.Figure: If `output` is 'figure', the function returns the plot as a figure object.\n", + " \"\"\"\n", "\n", - " attention = (\n", - " self.mean_on_batch(self.interpretability_params[\"attn_wts\"])\n", - " .mean(dim=0)\n", - " .cpu()\n", - " .numpy()\n", - " )\n", + " attention = (\n", + " self.mean_on_batch(self.interpretability_params[\"attn_wts\"])\n", + " .mean(dim=0)\n", + " .cpu()\n", + " .numpy()\n", + " )\n", "\n", - " fig, ax = plt.subplots(figsize=(width / 100, height / 100))\n", - "\n", - " if plot == \"time\":\n", - " attention = attention[self.input_size:, :].mean(axis=0)\n", - " ax.plot(np.arange(-self.input_size, self.h), attention)\n", - " ax.axvline(x=0, color='black', linewidth=3, linestyle='--', label=\"prediction start\")\n", - " ax.set_title(\"Mean Attention\")\n", - " ax.set_xlabel(\"time\")\n", - " ax.set_ylabel(\"Attention\")\n", - " ax.legend()\n", - "\n", - " elif plot == \"all\":\n", - " for i in range(self.input_size, attention.shape[0]):\n", - " ax.plot(np.arange(-self.input_size, self.h), attention[i, :], label=f\"horizon {i-self.input_size+1}\")\n", - " ax.axvline(x=0, color='black', linewidth=3, linestyle='--', label=\"prediction start\")\n", - " ax.set_title(\"Attention per horizon\")\n", - " ax.set_xlabel(\"time\")\n", - " ax.set_ylabel(\"Attention\")\n", - " ax.legend()\n", - "\n", - " elif plot == \"heatmap\":\n", - " cax = ax.imshow(attention, aspect='auto', cmap='viridis',\n", - " extent=[-self.input_size, self.h, -self.input_size, self.h])\n", - " fig.colorbar(cax)\n", - " ax.set_title(\"Attention Heatmap\")\n", - " ax.set_xlabel(\"Attention (current time step)\")\n", - " ax.set_ylabel(\"Attention (previous time step)\")\n", - "\n", - " elif isinstance(plot, int) and (plot in np.arange(1, self.h + 1)):\n", - " i = self.input_size + plot - 1\n", - " ax.plot(np.arange(-self.input_size, self.h), attention[i, :], label=f\"horizon {plot}\")\n", - " ax.axvline(x=0, 
color='black', linewidth=3, linestyle='--', label=\"prediction start\")\n", - " ax.set_title(f\"Attention weight for horizon {plot}\")\n", - " ax.set_xlabel(\"time\")\n", - " ax.set_ylabel(\"Attention\")\n", - " ax.legend()\n", + " fig, ax = plt.subplots(figsize=(width / 100, height / 100))\n", "\n", - " else:\n", - " raise ValueError('plot has to be in [\"time\",\"all\",\"heatmap\"] or integer in range(1,model.h)')\n", + " if plot == \"time\":\n", + " attention = attention[self.input_size :, :].mean(axis=0)\n", + " ax.plot(np.arange(-self.input_size, self.h), attention)\n", + " ax.axvline(\n", + " x=0, color=\"black\", linewidth=3, linestyle=\"--\", label=\"prediction start\"\n", + " )\n", + " ax.set_title(\"Mean Attention\")\n", + " ax.set_xlabel(\"time\")\n", + " ax.set_ylabel(\"Attention\")\n", + " ax.legend()\n", + "\n", + " elif plot == \"all\":\n", + " for i in range(self.input_size, attention.shape[0]):\n", + " ax.plot(\n", + " np.arange(-self.input_size, self.h),\n", + " attention[i, :],\n", + " label=f\"horizon {i-self.input_size+1}\",\n", + " )\n", + " ax.axvline(\n", + " x=0, color=\"black\", linewidth=3, linestyle=\"--\", label=\"prediction start\"\n", + " )\n", + " ax.set_title(\"Attention per horizon\")\n", + " ax.set_xlabel(\"time\")\n", + " ax.set_ylabel(\"Attention\")\n", + " ax.legend()\n", + "\n", + " elif plot == \"heatmap\":\n", + " cax = ax.imshow(\n", + " attention,\n", + " aspect=\"auto\",\n", + " cmap=\"viridis\",\n", + " extent=[-self.input_size, self.h, -self.input_size, self.h],\n", + " )\n", + " fig.colorbar(cax)\n", + " ax.set_title(\"Attention Heatmap\")\n", + " ax.set_xlabel(\"Attention (current time step)\")\n", + " ax.set_ylabel(\"Attention (previous time step)\")\n", + "\n", + " elif isinstance(plot, int) and (plot in np.arange(1, self.h + 1)):\n", + " i = self.input_size + plot - 1\n", + " ax.plot(\n", + " np.arange(-self.input_size, self.h),\n", + " attention[i, :],\n", + " label=f\"horizon {plot}\",\n", + " )\n", + " 
ax.axvline(\n", + " x=0, color=\"black\", linewidth=3, linestyle=\"--\", label=\"prediction start\"\n", + " )\n", + " ax.set_title(f\"Attention weight for horizon {plot}\")\n", + " ax.set_xlabel(\"time\")\n", + " ax.set_ylabel(\"Attention\")\n", + " ax.legend()\n", + "\n", + " else:\n", + " raise ValueError(\n", + " 'plot has to be in [\"time\",\"all\",\"heatmap\"] or integer in range(1,model.h)'\n", + " )\n", "\n", - " plt.tight_layout()\n", + " plt.tight_layout()\n", "\n", - " if output == 'plot':\n", - " plt.show()\n", - " elif output == 'figure':\n", - " return fig\n", - " else:\n", - " raise ValueError(f\"Invalid output: {output}. Expected 'plot' or 'figure'.\")" + " if output == \"plot\":\n", + " plt.show()\n", + " elif output == \"figure\":\n", + " return fig\n", + " else:\n", + " raise ValueError(f\"Invalid output: {output}. Expected 'plot' or 'figure'.\")" ] }, { @@ -1248,9 +2004,20 @@ "cell_type": "code", "execution_count": null, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], "source": [ - "#| eval: false\n", + "# | eval: false\n", "plot_attention(nf.models[0], plot=\"time\")" ] }, @@ -1265,9 +2032,20 @@ "cell_type": "code", "execution_count": null, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], "source": [ - "#| eval: false\n", + "# | eval: false\n", "plot_attention(nf.models[0], plot=\"all\")" ] }, @@ -1282,9 +2060,20 @@ "cell_type": "code", "execution_count": null, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], "source": [ - "#| eval: false\n", + "# | eval: false\n", "plot_attention(nf.models[0], plot=8)" ] }, @@ -1300,9 +2089,20 @@ "cell_type": "code", "execution_count": null, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "dict_keys(['Past variable importance over time', 'Future variable importance over time', 'Static covariates'])" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ - "#| eval: false\n", + "# | eval: false\n", "\n", "feature_importances = nf.models[0].feature_importances()\n", "feature_importances.keys()" @@ -1319,10 +2119,31 @@ "cell_type": "code", "execution_count": null, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], "source": [ - "#| eval: false\n", - "feature_importances['Static covariates'].sort_values(by='importance').plot(kind='barh')" + "# | eval: false\n", + "feature_importances[\"Static covariates\"].sort_values(by=\"importance\").plot(kind=\"barh\")" ] }, { @@ -1336,10 +2157,33 @@ "cell_type": "code", "execution_count": null, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], "source": [ - "#| eval: false\n", - "feature_importances['Past variable importance over time'].mean().sort_values().plot(kind='barh')" + "# | eval: false\n", + "feature_importances[\"Past variable importance over time\"].mean().sort_values().plot(\n", + " kind=\"barh\"\n", + ")" ] }, { @@ -1353,10 +2197,33 @@ "cell_type": "code", "execution_count": null, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], "source": [ - "#| eval: false\n", - "feature_importances['Future variable importance over time'].mean().sort_values().plot(kind='barh')" + "# | eval: false\n", + "feature_importances[\"Future variable importance over time\"].mean().sort_values().plot(\n", + " kind=\"barh\"\n", + ")" ] }, { @@ -1378,18 +2245,29 @@ "cell_type": "code", "execution_count": null, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], "source": [ - "#| eval: false\n", - "df=feature_importances['Future variable importance over time']\n", + "# | eval: false\n", + "df = feature_importances[\"Future variable importance over time\"]\n", "\n", "\n", "fig, ax = plt.subplots(figsize=(20, 10))\n", "bottom = np.zeros(len(df.index))\n", "for col in df.columns:\n", - " p = ax.bar(np.arange(-len(df),0), df[col].values, 0.6, label=col, bottom=bottom)\n", + " p = ax.bar(np.arange(-len(df), 0), df[col].values, 0.6, label=col, bottom=bottom)\n", " bottom += df[col]\n", - "ax.set_title('Future variable importance over time ponderated by attention')\n", + "ax.set_title(\"Future variable importance over time ponderated by attention\")\n", "ax.set_ylabel(\"Importance\")\n", "ax.set_xlabel(\"Time\")\n", "ax.grid(True)\n", @@ -1415,18 +2293,29 @@ "cell_type": "code", "execution_count": null, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], "source": [ - "#| eval: false\n", - "df= feature_importances['Past variable importance over time']\n", + "# | eval: false\n", + "df = feature_importances[\"Past variable importance over time\"]\n", "\n", "fig, ax = plt.subplots(figsize=(20, 10))\n", "bottom = np.zeros(len(df.index))\n", "\n", "for col in df.columns:\n", - " p = ax.bar(np.arange(-len(df),0), df[col].values, 0.6, label=col, bottom=bottom)\n", + " p = ax.bar(np.arange(-len(df), 0), df[col].values, 0.6, label=col, bottom=bottom)\n", " bottom += df[col]\n", - "ax.set_title('Past variable importance over time')\n", + "ax.set_title(\"Past variable importance over time\")\n", "ax.set_ylabel(\"Importance\")\n", "ax.set_xlabel(\"Time\")\n", "ax.legend()\n", @@ -1447,25 +2336,48 @@ "cell_type": "code", "execution_count": null, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], "source": [ - "#| eval: false\n", - "df= feature_importances['Past variable importance over time']\n", - "mean_attention = nf.models[0].attention_weights()[nf.models[0].input_size:,:].mean(axis=0)[:nf.models[0].input_size]\n", + "# | eval: false\n", + "df = feature_importances[\"Past variable importance over time\"]\n", + "mean_attention = (\n", + " nf.models[0]\n", + " .attention_weights()[nf.models[0].input_size :, :]\n", + " .mean(axis=0)[: nf.models[0].input_size]\n", + ")\n", "df = df.multiply(mean_attention, axis=0)\n", "\n", "fig, ax = plt.subplots(figsize=(20, 10))\n", "bottom = np.zeros(len(df.index))\n", "\n", "for col in df.columns:\n", - " p = ax.bar(np.arange(-len(df),0), df[col].values, 0.6, label=col, bottom=bottom)\n", + " p = ax.bar(np.arange(-len(df), 0), df[col].values, 0.6, label=col, bottom=bottom)\n", " bottom += df[col]\n", - "ax.set_title('Past variable importance over time ponderated by attention')\n", + "ax.set_title(\"Past variable importance over time ponderated by attention\")\n", "ax.set_ylabel(\"Importance\")\n", "ax.set_xlabel(\"Time\")\n", "ax.legend()\n", "ax.grid(True)\n", - "plt.plot(np.arange(-len(df),0), mean_attention, color='black', marker='o', linestyle='-', linewidth=2, label='mean_attention')\n", + "plt.plot(\n", + " np.arange(-len(df), 0),\n", + " mean_attention,\n", + " color=\"black\",\n", + " marker=\"o\",\n", + " linestyle=\"-\",\n", + " linewidth=2,\n", + " label=\"mean_attention\",\n", + ")\n", "plt.legend()\n", "plt.show()" ] @@ -1482,9 +2394,103 @@ "cell_type": "code", "execution_count": null, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
trendy_[lag12]monthobserved_targetCorrelation with Mean Attention
trend1.00-0.45-0.29-0.41-0.43
y_[lag12]-0.451.00-0.56-0.180.68
month-0.29-0.561.000.18-0.38
observed_target-0.41-0.180.181.000.07
Correlation with Mean Attention-0.430.68-0.380.071.00
\n", + "
" + ], + "text/plain": [ + " trend y_[lag12] month observed_target \\\n", + "trend 1.00 -0.45 -0.29 -0.41 \n", + "y_[lag12] -0.45 1.00 -0.56 -0.18 \n", + "month -0.29 -0.56 1.00 0.18 \n", + "observed_target -0.41 -0.18 0.18 1.00 \n", + "Correlation with Mean Attention -0.43 0.68 -0.38 0.07 \n", + "\n", + " Correlation with Mean Attention \n", + "trend -0.43 \n", + "y_[lag12] 0.68 \n", + "month -0.38 \n", + "observed_target 0.07 \n", + "Correlation with Mean Attention 1.00 " + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ - "#| eval: false\n", + "# | eval: false\n", "nf.models[0].feature_importance_correlations()" ] } diff --git a/neuralforecast/models/tft.py b/neuralforecast/models/tft.py index f96d5646b..53b0c0cfc 100644 --- a/neuralforecast/models/tft.py +++ b/neuralforecast/models/tft.py @@ -4,16 +4,17 @@ __all__ = ['TFT'] # %% ../../nbs/models.tft.ipynb 5 -from typing import Tuple, Optional, Callable +from typing import Callable, Optional, Tuple +import pandas as pd import torch import torch.nn as nn import torch.nn.functional as F from torch import Tensor from torch.nn import LayerNorm -import pandas as pd -from ..losses.pytorch import MAE + from ..common._base_windows import BaseWindows +from ..losses.pytorch import MAE # %% ../../nbs/models.tft.ipynb 11 def get_activation_fn(activation_str: str) -> Callable: @@ -129,7 +130,6 @@ def _apply_embedding( cont_emb: Tensor, cont_bias: Tensor, ): - if cont is not None: # the line below is equivalent to following einsums # e_cont = torch.einsum('btf,fh->bthf', cont, cont_emb) @@ -270,7 +270,16 @@ def forward( # %% ../../nbs/models.tft.ipynb 19 class StaticCovariateEncoder(nn.Module): - def __init__(self, hidden_size, num_static_vars, dropout, grn_activation): + def __init__( + self, + hidden_size, + num_static_vars, + dropout, + grn_activation, + rnn_type="lstm", + n_rnn_layers=1, + one_rnn_initial_state=False, + ): super().__init__() self.vsn = 
VariableSelectionNetwork( hidden_size=hidden_size, @@ -278,10 +287,18 @@ def __init__(self, hidden_size, num_static_vars, dropout, grn_activation): dropout=dropout, grn_activation=grn_activation, ) + self.rnn_type = rnn_type.lower() + + self.n_rnn_layers = n_rnn_layers + + self.n_states = 1 if one_rnn_initial_state else n_rnn_layers + + n_contexts = 2 + 2 * self.n_states if rnn_type == "lstm" else 2 + self.n_states + self.context_grns = nn.ModuleList( [ GRN(input_size=hidden_size, hidden_size=hidden_size, dropout=dropout) - for _ in range(4) + for _ in range(n_contexts) ] ) @@ -293,16 +310,62 @@ def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]: # enrichment context # state_c context # state_h context - cs, ce, ch, cc = tuple(m(variable_ctx) for m in self.context_grns) # type: ignore + + cs, ce = list(m(variable_ctx) for m in self.context_grns[:2]) # type: ignore + + if self.n_states == 1: + ch = torch.cat( + self.n_rnn_layers + * list( + m(variable_ctx).unsqueeze(0) + for m in self.context_grns[2 : self.n_states + 2] + ) + ) + + if self.rnn_type == "lstm": + cc = torch.cat( + self.n_rnn_layers + * list( + m(variable_ctx).unsqueeze(0) + for m in self.context_grns[self.n_states + 2 :] + ) + ) + + else: + ch = torch.cat( + list( + m(variable_ctx).unsqueeze(0) + for m in self.context_grns[2 : self.n_states + 2] + ) + ) + + if self.rnn_type == "lstm": + cc = torch.cat( + list( + m(variable_ctx).unsqueeze(0) + for m in self.context_grns[self.n_states + 2 :] + ) + ) + if self.rnn_type != "lstm": + cc = ch return cs, ce, ch, cc, sparse_weights # type: ignore # %% ../../nbs/models.tft.ipynb 21 class TemporalCovariateEncoder(nn.Module): def __init__( - self, hidden_size, num_historic_vars, num_future_vars, dropout, grn_activation + self, + hidden_size, + num_historic_vars, + num_future_vars, + dropout, + grn_activation, + rnn_type="lstm", + n_rnn_layers=1, ): super(TemporalCovariateEncoder, self).__init__() + self.rnn_type = rnn_type.lower() + 
self.n_rnn_layers = n_rnn_layers self.history_vsn = VariableSelectionNetwork( hidden_size=hidden_size, @@ -310,9 +373,36 @@ def __init__( dropout=dropout, grn_activation=grn_activation, ) - self.history_encoder = nn.LSTM( - input_size=hidden_size, hidden_size=hidden_size, batch_first=True - ) + if self.rnn_type == "lstm": + self.history_encoder = nn.LSTM( + input_size=hidden_size, + hidden_size=hidden_size, + batch_first=True, + num_layers=n_rnn_layers, + ) + + self.future_encoder = nn.LSTM( + input_size=hidden_size, + hidden_size=hidden_size, + batch_first=True, + num_layers=n_rnn_layers, + ) + + elif self.rnn_type == "gru": + self.history_encoder = nn.GRU( + input_size=hidden_size, + hidden_size=hidden_size, + batch_first=True, + num_layers=n_rnn_layers, + ) + self.future_encoder = nn.GRU( + input_size=hidden_size, + hidden_size=hidden_size, + batch_first=True, + num_layers=n_rnn_layers, + ) + else: + raise ValueError('RNN type should be in ["lstm","gru"] !') self.future_vsn = VariableSelectionNetwork( hidden_size=hidden_size, @@ -320,9 +410,6 @@ def __init__( dropout=dropout, grn_activation=grn_activation, ) - self.future_encoder = nn.LSTM( - input_size=hidden_size, hidden_size=hidden_size, batch_first=True - ) # Shared Gated-Skip Connection self.input_gate = GLU(hidden_size, hidden_size) @@ -333,7 +420,11 @@ def forward(self, historical_inputs, future_inputs, cs, ch, cc): historical_features, history_vsn_sparse_weights = self.history_vsn( historical_inputs, cs ) - history, state = self.history_encoder(historical_features, (ch, cc)) + if self.rnn_type == "lstm": + history, state = self.history_encoder(historical_features, (ch, cc)) + + elif self.rnn_type == "gru": + history, state = self.history_encoder(historical_features, ch) future_features, future_vsn_sparse_weights = self.future_vsn(future_inputs, cs) future, _ = self.future_encoder(future_features, state) @@ -439,6 +530,9 @@ class TFT(BaseWindows): `n_head`: int=4, number of attention heads in temporal 
fusion decoder.
`attn_dropout`: float (0, 1), dropout of fusion decoder's attention layer.
`grn_activation`: str, activation for the GRN module from ['ReLU', 'Softplus', 'Tanh', 'SELU', 'LeakyReLU', 'Sigmoid', 'ELU', 'GLU'].
+ `rnn_type`: str="LSTM", recurrent neural network (RNN) layer type from ["LSTM","GRU"].
+ `n_rnn_layers`: int=1, number of RNN layers.
+ `one_rnn_initial_state`: bool=False, initialize all RNN layers with the same initial states computed from static covariates.
`loss`: PyTorch module, instantiated train loss class from [losses collection](https://nixtla.github.io/neuralforecast/losses.pytorch.html).
`valid_loss`: PyTorch module=`loss`, instantiated valid loss class from [losses collection](https://nixtla.github.io/neuralforecast/losses.pytorch.html).
`max_steps`: int=1000, maximum number of training steps.
@@ -487,6 +581,9 @@ def __init__( n_head: int = 4, attn_dropout: float = 0.0, grn_activation: str = "ELU", + n_rnn_layers: int = 1, + rnn_type: str = "LSTM", + one_rnn_initial_state: bool = False, dropout: float = 0.1, loss=MAE(), valid_loss=None, @@ -512,7 +609,6 @@ def __init__( dataloader_kwargs=None, **trainer_kwargs, ): - # Inherit BaseWindows class super(TFT, self).__init__( h=h, @@ -550,7 +646,7 @@ def __init__( self.grn_activation = grn_activation futr_exog_size = max(self.futr_exog_size, 1) num_historic_vars = futr_exog_size + self.hist_exog_size + tgt_size - + self.n_rnn_layers = n_rnn_layers # ------------------------------- Encoders -----------------------------# self.embedding = TFTEmbedding( hidden_size=hidden_size, @@ -566,6 +662,9 @@ def __init__( num_static_vars=self.stat_exog_size, dropout=dropout, grn_activation=self.grn_activation, + rnn_type=rnn_type, + n_rnn_layers=n_rnn_layers, + one_rnn_initial_state=one_rnn_initial_state, ) self.temporal_encoder = TemporalCovariateEncoder( @@ -574,6 +673,8 @@ def __init__( num_future_vars=futr_exog_size, dropout=dropout, grn_activation=self.grn_activation, + n_rnn_layers=n_rnn_layers, + rnn_type=rnn_type, ) # ------------------------------ Decoders -----------------------------# @@ -593,7 +694,6 @@ def __init__( ) def forward(self, windows_batch): - # Parsiw windows_batch y_insample = windows_batch["insample_y"][:, :, None] # <- [B,T,1] futr_exog = windows_batch["futr_exog"] @@ -615,17 +715,19 @@ def forward(self, windows_batch): # Static context if s_inp is not None: cs, ce, ch, cc, static_encoder_sparse_weights = self.static_encoder(s_inp) - ch, cc = ch.unsqueeze(0), cc.unsqueeze(0) # LSTM initial states + # ch, cc = ch.unsqueeze(0), cc.unsqueeze(0) # LSTM initial states else: # If None add zeros batch_size, example_length, target_size, hidden_size = t_observed_tgt.shape cs = torch.zeros(size=(batch_size, hidden_size), device=y_insample.device) ce = torch.zeros(size=(batch_size, hidden_size), 
device=y_insample.device) ch = torch.zeros( - size=(1, batch_size, hidden_size), device=y_insample.device + size=(self.n_rnn_layers, batch_size, hidden_size), + device=y_insample.device, ) cc = torch.zeros( - size=(1, batch_size, hidden_size), device=y_insample.device + size=(self.n_rnn_layers, batch_size, hidden_size), + device=y_insample.device, ) static_encoder_sparse_weights = []