Skip to content

Commit

Permalink
[Refactor] Move RevIN class to common module (#1083)
Browse files Browse the repository at this point in the history
* re-use the same MovingAvg classin fedformer

* re-use the same SeriesDecomp class in fedformer

* RevIn in patchtst replaced by commonly used Normalize class

* re-use the shared Normalize class in timellm and timemixer

* Rename Normalize to RevIN

* move module to common.salers

move to common.scalers

* Fix error due to unwanted change of cell-type

* Delete extra cell and duplicate declaration of functions

* Remove unwanted changes, unused imported class

* Fix error in fedformer

* Fix error due to formatting

* restore the unwanted deleted line

* Review: Move from _scalers.py to _modules.py

* RMok imports shared module of RevIN

---------

Co-authored-by: Olivier Sprangers <[email protected]>
  • Loading branch information
JQGoh and elephaint authored Sep 19, 2024
1 parent a3ab53d commit 5a86c7f
Show file tree
Hide file tree
Showing 13 changed files with 189 additions and 749 deletions.
79 changes: 79 additions & 0 deletions nbs/common.modules.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -612,6 +612,85 @@
" res = x - moving_mean\n",
" return res, moving_mean"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#| export\n",
"\n",
"class RevIN(nn.Module):\n",
" \"\"\" RevIN (Reversible-Instance-Normalization)\n",
" \"\"\"\n",
" def __init__(self, num_features: int, eps=1e-5, affine=False, subtract_last=False, non_norm=False):\n",
" \"\"\"\n",
" :param num_features: the number of features or channels\n",
" :param eps: a value added for numerical stability\n",
" :param affine: if True, RevIN has learnable affine parameters\n",
" :param substract_last: if True, the substraction is based on the last value \n",
" instead of the mean in normalization\n",
" :param non_norm: if True, no normalization performed.\n",
" \"\"\"\n",
" super(RevIN, self).__init__()\n",
" self.num_features = num_features\n",
" self.eps = eps\n",
" self.affine = affine\n",
" self.subtract_last = subtract_last\n",
" self.non_norm = non_norm\n",
" if self.affine:\n",
" self._init_params()\n",
"\n",
" def forward(self, x, mode: str):\n",
" if mode == 'norm':\n",
" self._get_statistics(x)\n",
" x = self._normalize(x)\n",
" elif mode == 'denorm':\n",
" x = self._denormalize(x)\n",
" else:\n",
" raise NotImplementedError\n",
" return x\n",
"\n",
" def _init_params(self):\n",
" # initialize RevIN params: (C,)\n",
" self.affine_weight = nn.Parameter(torch.ones(self.num_features))\n",
" self.affine_bias = nn.Parameter(torch.zeros(self.num_features))\n",
"\n",
" def _get_statistics(self, x):\n",
" dim2reduce = tuple(range(1, x.ndim - 1))\n",
" if self.subtract_last:\n",
" self.last = x[:, -1, :].unsqueeze(1)\n",
" else:\n",
" self.mean = torch.mean(x, dim=dim2reduce, keepdim=True).detach()\n",
" self.stdev = torch.sqrt(torch.var(x, dim=dim2reduce, keepdim=True, unbiased=False) + self.eps).detach()\n",
"\n",
" def _normalize(self, x):\n",
" if self.non_norm:\n",
" return x\n",
" if self.subtract_last:\n",
" x = x - self.last\n",
" else:\n",
" x = x - self.mean\n",
" x = x / self.stdev\n",
" if self.affine:\n",
" x = x * self.affine_weight\n",
" x = x + self.affine_bias\n",
" return x\n",
"\n",
" def _denormalize(self, x):\n",
" if self.non_norm:\n",
" return x\n",
" if self.affine:\n",
" x = x - self.affine_bias\n",
" x = x / (self.affine_weight + self.eps * self.eps)\n",
" x = x * self.stdev\n",
" if self.subtract_last:\n",
" x = x + self.last\n",
" else:\n",
" x = x + self.mean\n",
" return x"
]
}
],
"metadata": {
Expand Down
32 changes: 1 addition & 31 deletions nbs/models.fedformer.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@
"import torch.nn.functional as F\n",
"\n",
"from neuralforecast.common._modules import DataEmbedding\n",
"from neuralforecast.common._modules import SeriesDecomp\n",
"from neuralforecast.common._base_windows import BaseWindows\n",
"\n",
"from neuralforecast.losses.pytorch import MAE"
Expand All @@ -86,36 +87,6 @@
"outputs": [],
"source": [
"#| export\n",
"class MovingAvg(nn.Module):\n",
" \"\"\"\n",
" Moving average block to highlight the trend of time series\n",
" \"\"\"\n",
" def __init__(self, kernel_size, stride):\n",
" super(MovingAvg, self).__init__()\n",
" self.kernel_size = kernel_size\n",
" self.avg = nn.AvgPool1d(kernel_size=kernel_size, stride=stride, padding=0)\n",
"\n",
" def forward(self, x):\n",
" # padding on the both ends of time series\n",
" front = x[:, 0:1, :].repeat(1, (self.kernel_size - 1) // 2, 1)\n",
" end = x[:, -1:, :].repeat(1, (self.kernel_size - 1) // 2, 1)\n",
" x = torch.cat([front, x, end], dim=1)\n",
" x = self.avg(x.permute(0, 2, 1))\n",
" x = x.permute(0, 2, 1)\n",
" return x\n",
"\n",
"class SeriesDecomp(nn.Module):\n",
" \"\"\"\n",
" Series decomposition block\n",
" \"\"\"\n",
" def __init__(self, kernel_size):\n",
" super(SeriesDecomp, self).__init__()\n",
" self.MovingAvg = MovingAvg(kernel_size, stride=1)\n",
"\n",
" def forward(self, x):\n",
" moving_mean = self.MovingAvg(x)\n",
" res = x - moving_mean\n",
" return res, moving_mean\n",
" \n",
"class LayerNorm(nn.Module):\n",
" \"\"\"\n",
Expand Down Expand Up @@ -708,7 +679,6 @@
"\n",
"Y_train_df = AirPassengersPanel[AirPassengersPanel.ds<AirPassengersPanel['ds'].values[-12]] # 132 train\n",
"Y_test_df = AirPassengersPanel[AirPassengersPanel.ds>=AirPassengersPanel['ds'].values[-12]].reset_index(drop=True) # 12 test\n",
"\n",
"model = FEDformer(h=12,\n",
" input_size=24,\n",
" modes=64,\n",
Expand Down
81 changes: 2 additions & 79 deletions nbs/models.patchtst.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@
"import torch.nn.functional as F\n",
"\n",
"from neuralforecast.common._base_windows import BaseWindows\n",
"from neuralforecast.common._modules import RevIN\n",
"\n",
"from neuralforecast.losses.pytorch import MAE"
]
Expand Down Expand Up @@ -195,84 +196,6 @@
" return nn.Parameter(W_pos, requires_grad=learn_pe)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"### RevIN"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#| export\n",
"class RevIN(nn.Module):\n",
" \"\"\"\n",
" RevIN\n",
" \"\"\" \n",
" def __init__(self, num_features: int, eps=1e-5, affine=True, subtract_last=False):\n",
" \"\"\"\n",
" :param num_features: the number of features or channels\n",
" :param eps: a value added for numerical stability\n",
" :param affine: if True, RevIN has learnable affine parameters\n",
" \"\"\"\n",
" super(RevIN, self).__init__()\n",
" self.num_features = num_features\n",
" self.eps = eps\n",
" self.affine = affine\n",
" self.subtract_last = subtract_last\n",
" if self.affine:\n",
" self._init_params()\n",
"\n",
" def forward(self, x, mode:str):\n",
" if mode == 'norm':\n",
" self._get_statistics(x)\n",
" x = self._normalize(x)\n",
" elif mode == 'denorm':\n",
" x = self._denormalize(x)\n",
" else: raise NotImplementedError\n",
" return x\n",
"\n",
" def _init_params(self):\n",
" # initialize RevIN params: (C,)\n",
" self.affine_weight = nn.Parameter(torch.ones(self.num_features))\n",
" self.affine_bias = nn.Parameter(torch.zeros(self.num_features))\n",
"\n",
" def _get_statistics(self, x):\n",
" dim2reduce = tuple(range(1, x.ndim-1))\n",
" if self.subtract_last:\n",
" self.last = x[:,-1,:].unsqueeze(1)\n",
" else:\n",
" self.mean = torch.mean(x, dim=dim2reduce, keepdim=True).detach()\n",
" self.stdev = torch.sqrt(torch.var(x, dim=dim2reduce, keepdim=True, unbiased=False) + self.eps).detach()\n",
"\n",
" def _normalize(self, x):\n",
" if self.subtract_last:\n",
" x = x - self.last\n",
" else:\n",
" x = x - self.mean\n",
" x = x / self.stdev\n",
" if self.affine:\n",
" x = x * self.affine_weight\n",
" x = x + self.affine_bias\n",
" return x\n",
"\n",
" def _denormalize(self, x):\n",
" if self.affine:\n",
" x = x - self.affine_bias\n",
" x = x / (self.affine_weight + self.eps*self.eps)\n",
" x = x * self.stdev\n",
" if self.subtract_last:\n",
" x = x + self.last\n",
" else:\n",
" x = x + self.mean\n",
" return x"
]
},
{
"attachments": {},
"cell_type": "markdown",
Expand Down Expand Up @@ -929,7 +852,7 @@
"\n",
"from neuralforecast import NeuralForecast\n",
"from neuralforecast.models import PatchTST\n",
"from neuralforecast.losses.pytorch import MQLoss, DistributionLoss\n",
"from neuralforecast.losses.pytorch import DistributionLoss\n",
"from neuralforecast.utils import AirPassengersPanel, AirPassengersStatic, augment_calendar_df\n",
"\n",
"AirPassengersPanel, calendar_cols = augment_calendar_df(df=AirPassengersPanel, freq='M')\n",
Expand Down
76 changes: 2 additions & 74 deletions nbs/models.rmok.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,8 @@
"import torch.nn.functional as F\n",
"\n",
"from neuralforecast.losses.pytorch import MAE\n",
"from neuralforecast.common._base_multivariate import BaseMultivariate"
"from neuralforecast.common._base_multivariate import BaseMultivariate\n",
"from neuralforecast.common._modules import RevIN"
]
},
{
Expand Down Expand Up @@ -315,79 +316,6 @@
" return y"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 1.4 RevIN"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#| export\n",
"\n",
"class RevIN(nn.Module):\n",
" def __init__(self, num_features: int, eps=1e-5, affine=True):\n",
" \"\"\"\n",
" :param num_features: the number of features or channels\n",
" :param eps: a value added for numerical stability\n",
" :param affine: if True, RevIN has learnable affine parameters\n",
" \"\"\"\n",
" super(RevIN, self).__init__()\n",
"\n",
" self.num_features = num_features\n",
" self.eps = eps\n",
" self.affine = affine\n",
"\n",
" if self.affine:\n",
" self._init_params()\n",
"\n",
" def forward(self, x, mode: str):\n",
" if mode == 'norm':\n",
" self._get_statistics(x)\n",
" x = self._normalize(x)\n",
"\n",
" elif mode == 'denorm':\n",
" x = self._denormalize(x)\n",
"\n",
" else:\n",
" raise NotImplementedError\n",
"\n",
" return x\n",
"\n",
" def _init_params(self):\n",
" # initialize RevIN params: (C,)\n",
" self.affine_weight = nn.Parameter(torch.ones(self.num_features))\n",
" self.affine_bias = nn.Parameter(torch.zeros(self.num_features))\n",
"\n",
" def _get_statistics(self, x):\n",
" dim2reduce = tuple(range(1, x.ndim - 1))\n",
" self.mean = torch.mean(x, dim=dim2reduce, keepdim=True).detach()\n",
" self.stdev = torch.sqrt(torch.var(x, dim=dim2reduce, keepdim=True, unbiased=False) + self.eps).detach()\n",
"\n",
" def _normalize(self, x):\n",
" x = x - self.mean\n",
" x = x / self.stdev\n",
" if self.affine:\n",
" x = x * self.affine_weight\n",
" x = x + self.affine_bias\n",
"\n",
" return x\n",
"\n",
" def _denormalize(self, x):\n",
" if self.affine:\n",
" x = x - self.affine_bias\n",
" x = x / (self.affine_weight + self.eps * self.eps)\n",
" x = x * self.stdev\n",
" x = x + self.mean\n",
"\n",
" return x"
]
},
{
"cell_type": "markdown",
"metadata": {},
Expand Down
Loading

0 comments on commit 5a86c7f

Please sign in to comment.