From 5a86c7f8217eee0f9d8e3d943d15738dee04d9f5 Mon Sep 17 00:00:00 2001
From: t-minus
Date: Fri, 20 Sep 2024 02:20:19 +0800
Subject: [PATCH] [Refactor] Move RevIN class to common module (#1083)

* re-use the same MovingAvg class in fedformer

* re-use the same SeriesDecomp class in fedformer

* RevIN in patchtst replaced by commonly used Normalize class

* re-use the shared Normalize class in timellm and timemixer

* Rename Normalize to RevIN

* move module to common.scalers

* Fix error due to unwanted change of cell-type

* Delete extra cell and duplicate declaration of functions

* Remove unwanted changes, unused imported class

* Fix error in fedformer

* Fix error due to formatting

* restore the unwanted deleted line

* Review: Move from _scalers.py to _modules.py

* RMoK imports the shared RevIN module

---------

Co-authored-by: Olivier Sprangers <45119856+elephaint@users.noreply.github.com>
---
 nbs/common.modules.ipynb           | 79 ++++++++++++++++++++++++++
 nbs/models.fedformer.ipynb         | 32 +----------
 nbs/models.patchtst.ipynb          | 81 +--------------------------
 nbs/models.rmok.ipynb              | 76 +------------------------
 nbs/models.timellm.ipynb           | 82 ++-------------------------
 nbs/models.timemixer.ipynb         | 89 +----------------------------
 neuralforecast/_modidx.py          | 67 ----------------------
 neuralforecast/common/_modules.py  | 83 ++++++++++++++++++++++++++-
 neuralforecast/models/fedformer.py | 40 +------------
 neuralforecast/models/patchtst.py  | 72 +-----------------------
 neuralforecast/models/rmok.py      | 63 +--------------------
 neuralforecast/models/timellm.py   | 84 +---------------------------
 neuralforecast/models/timemixer.py | 90 ++----------------------------
 13 files changed, 189 insertions(+), 749 deletions(-)

diff --git a/nbs/common.modules.ipynb b/nbs/common.modules.ipynb
index 8fc350dac..f90e936da 100644
--- a/nbs/common.modules.ipynb
+++ b/nbs/common.modules.ipynb
@@ -612,6 +612,85 @@
     "    res = x - moving_mean\n",
     "    return res, moving_mean"
    ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "#| export\n",
+    "\n",
+    "class RevIN(nn.Module):\n",
+    "    \"\"\" RevIN (Reversible-Instance-Normalization)\n",
+    "    \"\"\"\n",
+    "    def __init__(self, num_features: int, eps=1e-5, affine=False, subtract_last=False, non_norm=False):\n",
+    "        \"\"\"\n",
+    "        :param num_features: the number of features or channels\n",
+    "        :param eps: a value added for numerical stability\n",
+    "        :param affine: if True, RevIN has learnable affine parameters\n",
+    "        :param subtract_last: if True, the subtraction is based on the last value \n",
+    "            instead of the mean in normalization\n",
+    "        :param non_norm: if True, no normalization is performed.\n",
+    "        \"\"\"\n",
+    "        super(RevIN, self).__init__()\n",
+    "        self.num_features = num_features\n",
+    "        self.eps = eps\n",
+    "        self.affine = affine\n",
+    "        self.subtract_last = subtract_last\n",
+    "        self.non_norm = non_norm\n",
+    "        if self.affine:\n",
+    "            self._init_params()\n",
+    "\n",
+    "    def forward(self, x, mode: str):\n",
+    "        if mode == 'norm':\n",
+    "            self._get_statistics(x)\n",
+    "            x = self._normalize(x)\n",
+    "        elif mode == 'denorm':\n",
+    "            x = self._denormalize(x)\n",
+    "        else:\n",
+    "            raise NotImplementedError\n",
+    "        return x\n",
+    "\n",
+    "    def _init_params(self):\n",
+    "        # initialize RevIN params: (C,)\n",
+    "        self.affine_weight = nn.Parameter(torch.ones(self.num_features))\n",
+    "        self.affine_bias = nn.Parameter(torch.zeros(self.num_features))\n",
+    "\n",
+    "    def _get_statistics(self, x):\n",
+    "        dim2reduce = 
tuple(range(1, x.ndim - 1))\n",
+    "        if self.subtract_last:\n",
+    "            self.last = x[:, -1, :].unsqueeze(1)\n",
+    "        else:\n",
+    "            self.mean = torch.mean(x, dim=dim2reduce, keepdim=True).detach()\n",
+    "        self.stdev = torch.sqrt(torch.var(x, dim=dim2reduce, keepdim=True, unbiased=False) + self.eps).detach()\n",
+    "\n",
+    "    def _normalize(self, x):\n",
+    "        if self.non_norm:\n",
+    "            return x\n",
+    "        if self.subtract_last:\n",
+    "            x = x - self.last\n",
+    "        else:\n",
+    "            x = x - self.mean\n",
+    "        x = x / self.stdev\n",
+    "        if self.affine:\n",
+    "            x = x * self.affine_weight\n",
+    "            x = x + self.affine_bias\n",
+    "        return x\n",
+    "\n",
+    "    def _denormalize(self, x):\n",
+    "        if self.non_norm:\n",
+    "            return x\n",
+    "        if self.affine:\n",
+    "            x = x - self.affine_bias\n",
+    "            x = x / (self.affine_weight + self.eps * self.eps)\n",
+    "        x = x * self.stdev\n",
+    "        if self.subtract_last:\n",
+    "            x = x + self.last\n",
+    "        else:\n",
+    "            x = x + self.mean\n",
+    "        return x"
+   ]
  }
 ],
 "metadata": {
diff --git a/nbs/models.fedformer.ipynb b/nbs/models.fedformer.ipynb
index 5126a26f9..f7c5e9f14 100644
--- a/nbs/models.fedformer.ipynb
+++ b/nbs/models.fedformer.ipynb
@@ -66,6 +66,7 @@
     "import torch.nn.functional as F\n",
     "\n",
     "from neuralforecast.common._modules import DataEmbedding\n",
+    "from neuralforecast.common._modules import SeriesDecomp\n",
     "from neuralforecast.common._base_windows import BaseWindows\n",
     "\n",
     "from neuralforecast.losses.pytorch import MAE"
@@ -86,36 +87,6 @@
    "outputs": [],
    "source": [
     "#| export\n",
-    "class MovingAvg(nn.Module):\n",
-    "    \"\"\"\n",
-    "    Moving average block to highlight the trend of time series\n",
-    "    \"\"\"\n",
-    "    def __init__(self, kernel_size, stride):\n",
-    "        super(MovingAvg, self).__init__()\n",
-    "        self.kernel_size = kernel_size\n",
-    "        self.avg = nn.AvgPool1d(kernel_size=kernel_size, stride=stride, padding=0)\n",
-    "\n",
-    "    def forward(self, x):\n",
-    "        # padding on the both ends of time series\n",
-    "        front = x[:, 0:1, :].repeat(1, (self.kernel_size - 1) // 2, 1)\n",
-    "        end = x[:, -1:, :].repeat(1, (self.kernel_size - 1) // 2, 1)\n",
-    "        x = torch.cat([front, x, end], dim=1)\n",
-    "        x = self.avg(x.permute(0, 2, 1))\n",
-    "        x = x.permute(0, 2, 1)\n",
-    "        return x\n",
-    "\n",
-    "class SeriesDecomp(nn.Module):\n",
-    "    \"\"\"\n",
-    "    Series decomposition block\n",
-    "    \"\"\"\n",
-    "    def __init__(self, kernel_size):\n",
-    "        super(SeriesDecomp, self).__init__()\n",
-    "        self.MovingAvg = MovingAvg(kernel_size, stride=1)\n",
-    "\n",
-    "    def forward(self, x):\n",
-    "        moving_mean = self.MovingAvg(x)\n",
-    "        res = x - moving_mean\n",
-    "        return res, moving_mean\n",
     " \n",
     "class LayerNorm(nn.Module):\n",
     "    \"\"\"\n",
@@ -708,7 +679,6 @@
    "\n",
    "Y_train_df = AirPassengersPanel[AirPassengersPanel.ds<AirPassengersPanel['ds'].values[-12]].reset_index(drop=True) # 132 train\n",
    "Y_test_df = AirPassengersPanel[AirPassengersPanel.ds>=AirPassengersPanel['ds'].values[-12]].reset_index(drop=True) # 12 test\n",
-   "\n",
    "model = FEDformer(h=12,\n",
    "                  input_size=24,\n",
    "                  modes=64,\n",
diff --git a/nbs/models.patchtst.ipynb b/nbs/models.patchtst.ipynb
index bd801a573..ff579bf86 100644
--- a/nbs/models.patchtst.ipynb
+++ b/nbs/models.patchtst.ipynb
@@ -62,6 +62,7 @@
     "import torch.nn.functional as F\n",
     "\n",
     "from neuralforecast.common._base_windows import BaseWindows\n",
+    "from neuralforecast.common._modules import RevIN\n",
     "\n",
     "from neuralforecast.losses.pytorch import MAE"
    ]
@@ -195,84 +196,6 @@
     "    return nn.Parameter(W_pos, requires_grad=learn_pe)"
    ]
   },
-  {
-   "attachments": {},
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "### RevIN"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   
"metadata": {}, - "outputs": [], - "source": [ - "#| export\n", - "class RevIN(nn.Module):\n", - " \"\"\"\n", - " RevIN\n", - " \"\"\" \n", - " def __init__(self, num_features: int, eps=1e-5, affine=True, subtract_last=False):\n", - " \"\"\"\n", - " :param num_features: the number of features or channels\n", - " :param eps: a value added for numerical stability\n", - " :param affine: if True, RevIN has learnable affine parameters\n", - " \"\"\"\n", - " super(RevIN, self).__init__()\n", - " self.num_features = num_features\n", - " self.eps = eps\n", - " self.affine = affine\n", - " self.subtract_last = subtract_last\n", - " if self.affine:\n", - " self._init_params()\n", - "\n", - " def forward(self, x, mode:str):\n", - " if mode == 'norm':\n", - " self._get_statistics(x)\n", - " x = self._normalize(x)\n", - " elif mode == 'denorm':\n", - " x = self._denormalize(x)\n", - " else: raise NotImplementedError\n", - " return x\n", - "\n", - " def _init_params(self):\n", - " # initialize RevIN params: (C,)\n", - " self.affine_weight = nn.Parameter(torch.ones(self.num_features))\n", - " self.affine_bias = nn.Parameter(torch.zeros(self.num_features))\n", - "\n", - " def _get_statistics(self, x):\n", - " dim2reduce = tuple(range(1, x.ndim-1))\n", - " if self.subtract_last:\n", - " self.last = x[:,-1,:].unsqueeze(1)\n", - " else:\n", - " self.mean = torch.mean(x, dim=dim2reduce, keepdim=True).detach()\n", - " self.stdev = torch.sqrt(torch.var(x, dim=dim2reduce, keepdim=True, unbiased=False) + self.eps).detach()\n", - "\n", - " def _normalize(self, x):\n", - " if self.subtract_last:\n", - " x = x - self.last\n", - " else:\n", - " x = x - self.mean\n", - " x = x / self.stdev\n", - " if self.affine:\n", - " x = x * self.affine_weight\n", - " x = x + self.affine_bias\n", - " return x\n", - "\n", - " def _denormalize(self, x):\n", - " if self.affine:\n", - " x = x - self.affine_bias\n", - " x = x / (self.affine_weight + self.eps*self.eps)\n", - " x = x * self.stdev\n", - " if self.subtract_last:\n", - " x = x + self.last\n", - " else:\n", - " x = x + self.mean\n", - " return x" - ] - }, { "attachments": {}, "cell_type": "markdown", @@ -929,7 +852,7 @@ "\n", "from neuralforecast import NeuralForecast\n", "from neuralforecast.models import PatchTST\n", - "from neuralforecast.losses.pytorch import MQLoss, DistributionLoss\n", + "from neuralforecast.losses.pytorch import DistributionLoss\n", "from neuralforecast.utils import AirPassengersPanel, AirPassengersStatic, augment_calendar_df\n", "\n", "AirPassengersPanel, calendar_cols = augment_calendar_df(df=AirPassengersPanel, freq='M')\n", diff --git a/nbs/models.rmok.ipynb b/nbs/models.rmok.ipynb index 08fee40c0..7f4160e09 100644 --- a/nbs/models.rmok.ipynb +++ b/nbs/models.rmok.ipynb @@ -73,7 +73,8 @@ "import torch.nn.functional as F\n", "\n", "from neuralforecast.losses.pytorch import MAE\n", - "from neuralforecast.common._base_multivariate import BaseMultivariate" + "from neuralforecast.common._base_multivariate import BaseMultivariate\n", + "from neuralforecast.common._modules import RevIN" ] }, { @@ -315,79 +316,6 @@ " return y" ] }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### 1.4 RevIN" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "#| export\n", - "\n", - "class RevIN(nn.Module):\n", - " def __init__(self, num_features: int, eps=1e-5, affine=True):\n", - " \"\"\"\n", - " :param num_features: the number of features or channels\n", - " :param eps: a value added for 
numerical stability\n", - " :param affine: if True, RevIN has learnable affine parameters\n", - " \"\"\"\n", - " super(RevIN, self).__init__()\n", - "\n", - " self.num_features = num_features\n", - " self.eps = eps\n", - " self.affine = affine\n", - "\n", - " if self.affine:\n", - " self._init_params()\n", - "\n", - " def forward(self, x, mode: str):\n", - " if mode == 'norm':\n", - " self._get_statistics(x)\n", - " x = self._normalize(x)\n", - "\n", - " elif mode == 'denorm':\n", - " x = self._denormalize(x)\n", - "\n", - " else:\n", - " raise NotImplementedError\n", - "\n", - " return x\n", - "\n", - " def _init_params(self):\n", - " # initialize RevIN params: (C,)\n", - " self.affine_weight = nn.Parameter(torch.ones(self.num_features))\n", - " self.affine_bias = nn.Parameter(torch.zeros(self.num_features))\n", - "\n", - " def _get_statistics(self, x):\n", - " dim2reduce = tuple(range(1, x.ndim - 1))\n", - " self.mean = torch.mean(x, dim=dim2reduce, keepdim=True).detach()\n", - " self.stdev = torch.sqrt(torch.var(x, dim=dim2reduce, keepdim=True, unbiased=False) + self.eps).detach()\n", - "\n", - " def _normalize(self, x):\n", - " x = x - self.mean\n", - " x = x / self.stdev\n", - " if self.affine:\n", - " x = x * self.affine_weight\n", - " x = x + self.affine_bias\n", - "\n", - " return x\n", - "\n", - " def _denormalize(self, x):\n", - " if self.affine:\n", - " x = x - self.affine_bias\n", - " x = x / (self.affine_weight + self.eps * self.eps)\n", - " x = x * self.stdev\n", - " x = x + self.mean\n", - "\n", - " return x" - ] - }, { "cell_type": "markdown", "metadata": {}, diff --git a/nbs/models.timellm.ipynb b/nbs/models.timellm.ipynb index 00d0bab0f..dbc35a7bd 100755 --- a/nbs/models.timellm.ipynb +++ b/nbs/models.timellm.ipynb @@ -64,6 +64,7 @@ "import torch.nn as nn\n", "\n", "from neuralforecast.common._base_windows import BaseWindows\n", + "from neuralforecast.common._modules import RevIN\n", "\n", "from neuralforecast.losses.pytorch import MAE\n", "\n", @@ -222,74 +223,7 @@ " reprogramming_embedding = torch.einsum(\"bhls,she->blhe\", A, value_embedding)\n", "\n", " return reprogramming_embedding\n", - " \n", - "class Normalize(nn.Module):\n", - " \"\"\"\n", - " Normalize\n", - " \"\"\" \n", - " def __init__(self, num_features: int, eps=1e-5, affine=False, subtract_last=False, non_norm=False):\n", - " \"\"\"\n", - " :param num_features: the number of features or channels\n", - " :param eps: a value added for numerical stability\n", - " :param affine: if True, RevIN has learnable affine parameters\n", - " \"\"\"\n", - " super(Normalize, self).__init__()\n", - " self.num_features = num_features\n", - " self.eps = eps\n", - " self.affine = affine\n", - " self.subtract_last = subtract_last\n", - " self.non_norm = non_norm\n", - " if self.affine:\n", - " self._init_params()\n", - "\n", - " def forward(self, x, mode: str):\n", - " if mode == 'norm':\n", - " self._get_statistics(x)\n", - " x = self._normalize(x)\n", - " elif mode == 'denorm':\n", - " x = self._denormalize(x)\n", - " else:\n", - " raise NotImplementedError\n", - " return x\n", - "\n", - " def _init_params(self):\n", - " # initialize RevIN params: (C,)\n", - " self.affine_weight = nn.Parameter(torch.ones(self.num_features))\n", - " self.affine_bias = nn.Parameter(torch.zeros(self.num_features))\n", - "\n", - " def _get_statistics(self, x):\n", - " dim2reduce = tuple(range(1, x.ndim - 1))\n", - " if self.subtract_last:\n", - " self.last = x[:, -1, :].unsqueeze(1)\n", - " else:\n", - " self.mean = torch.mean(x, 
dim=dim2reduce, keepdim=True).detach()\n", - " self.stdev = torch.sqrt(torch.var(x, dim=dim2reduce, keepdim=True, unbiased=False) + self.eps).detach()\n", - "\n", - " def _normalize(self, x):\n", - " if self.non_norm:\n", - " return x\n", - " if self.subtract_last:\n", - " x = x - self.last\n", - " else:\n", - " x = x - self.mean\n", - " x = x / self.stdev\n", - " if self.affine:\n", - " x = x * self.affine_weight\n", - " x = x + self.affine_bias\n", - " return x\n", - "\n", - " def _denormalize(self, x):\n", - " if self.non_norm:\n", - " return x\n", - " if self.affine:\n", - " x = x - self.affine_bias\n", - " x = x / (self.affine_weight + self.eps * self.eps)\n", - " x = x * self.stdev\n", - " if self.subtract_last:\n", - " x = x + self.last\n", - " else:\n", - " x = x + self.mean\n", - " return x" + " \n" ] }, { @@ -517,7 +451,7 @@ "\n", " self.output_projection = FlattenHead(self.enc_in, self.head_nf, self.h, head_dropout=self.dropout)\n", "\n", - " self.normalize_layers = Normalize(self.enc_in, affine=False)\n", + " self.normalize_layers = RevIN(self.enc_in, affine=False)\n", "\n", " def forecast(self, x_enc):\n", "\n", @@ -594,8 +528,7 @@ " y_pred = y_pred[:, -self.h:, :]\n", " y_pred = self.loss.domain_map(y_pred)\n", " \n", - " return y_pred\n", - "\n" + " return y_pred\n" ] }, { @@ -666,13 +599,6 @@ "nf.fit(df=Y_train_df, val_size=12)\n", "forecasts = nf.predict(futr_df=Y_test_df)" ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] } ], "metadata": { diff --git a/nbs/models.timemixer.ipynb b/nbs/models.timemixer.ipynb index 0aacd694c..00e974d39 100644 --- a/nbs/models.timemixer.ipynb +++ b/nbs/models.timemixer.ipynb @@ -42,7 +42,7 @@ "import torch.nn as nn\n", "\n", "from neuralforecast.common._base_multivariate import BaseMultivariate\n", - "from neuralforecast.common._modules import PositionalEmbedding, TokenEmbedding, TemporalEmbedding, SeriesDecomp\n", + "from neuralforecast.common._modules import PositionalEmbedding, TokenEmbedding, TemporalEmbedding, SeriesDecomp, RevIN\n", "\n", "from neuralforecast.losses.pytorch import MAE" ] @@ -58,91 +58,6 @@ "from nbdev.showdoc import show_doc" ] }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 1. 
Auxiliary functions\n", - "### Normalization" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "#| export\n", - "\n", - "class Normalize(nn.Module):\n", - " \"\"\"\n", - " Normalize\n", - " \"\"\"\n", - " def __init__(self, num_features: int, eps=1e-5, affine=False, subtract_last=False, non_norm=False):\n", - " \"\"\"\n", - " :param num_features: the number of features or channels\n", - " :param eps: a value added for numerical stability\n", - " :param affine: if True, RevIN has learnable affine parameters\n", - " \"\"\"\n", - " super(Normalize, self).__init__()\n", - " self.num_features = num_features\n", - " self.eps = eps\n", - " self.affine = affine\n", - " self.subtract_last = subtract_last\n", - " self.non_norm = non_norm\n", - " if self.affine:\n", - " self._init_params()\n", - "\n", - " def forward(self, x, mode: str):\n", - " if mode == 'norm':\n", - " self._get_statistics(x)\n", - " x = self._normalize(x)\n", - " elif mode == 'denorm':\n", - " x = self._denormalize(x)\n", - " else:\n", - " raise NotImplementedError\n", - " return x\n", - "\n", - " def _init_params(self):\n", - " # initialize RevIN params: (C,)\n", - " self.affine_weight = nn.Parameter(torch.ones(self.num_features))\n", - " self.affine_bias = nn.Parameter(torch.zeros(self.num_features))\n", - "\n", - " def _get_statistics(self, x):\n", - " dim2reduce = tuple(range(1, x.ndim - 1))\n", - " if self.subtract_last:\n", - " self.last = x[:, -1, :].unsqueeze(1)\n", - " else:\n", - " self.mean = torch.mean(x, dim=dim2reduce, keepdim=True).detach()\n", - " self.stdev = torch.sqrt(torch.var(x, dim=dim2reduce, keepdim=True, unbiased=False) + self.eps).detach()\n", - "\n", - " def _normalize(self, x):\n", - " if self.non_norm:\n", - " return x\n", - " if self.subtract_last:\n", - " x = x - self.last\n", - " else:\n", - " x = x - self.mean\n", - " x = x / self.stdev\n", - " if self.affine:\n", - " x = x * self.affine_weight\n", - " x = x + self.affine_bias\n", - " return x\n", - "\n", - " def _denormalize(self, x):\n", - " if self.non_norm:\n", - " return x\n", - " if self.affine:\n", - " x = x - self.affine_bias\n", - " x = x / (self.affine_weight + self.eps * self.eps)\n", - " x = x * self.stdev\n", - " if self.subtract_last:\n", - " x = x + self.last\n", - " else:\n", - " x = x + self.mean\n", - " return x" - ] - }, { "cell_type": "markdown", "metadata": {}, @@ -565,7 +480,7 @@ "\n", " self.normalize_layers = torch.nn.ModuleList(\n", " [\n", - " Normalize(self.enc_in, affine=True, non_norm=False if self.use_norm else True)\n", + " RevIN(self.enc_in, affine=True, non_norm=False if self.use_norm else True)\n", " for i in range(self.down_sampling_layers + 1)\n", " ]\n", " )\n", diff --git a/neuralforecast/_modidx.py b/neuralforecast/_modidx.py index add079b58..6fcba7d61 100644 --- a/neuralforecast/_modidx.py +++ b/neuralforecast/_modidx.py @@ -726,18 +726,6 @@ 'neuralforecast/models/fedformer.py'), 'neuralforecast.models.fedformer.LayerNorm.forward': ( 'models.fedformer.html#layernorm.forward', 'neuralforecast/models/fedformer.py'), - 'neuralforecast.models.fedformer.MovingAvg': ( 'models.fedformer.html#movingavg', - 'neuralforecast/models/fedformer.py'), - 'neuralforecast.models.fedformer.MovingAvg.__init__': ( 'models.fedformer.html#movingavg.__init__', - 'neuralforecast/models/fedformer.py'), - 'neuralforecast.models.fedformer.MovingAvg.forward': ( 'models.fedformer.html#movingavg.forward', - 'neuralforecast/models/fedformer.py'), - 
'neuralforecast.models.fedformer.SeriesDecomp': ( 'models.fedformer.html#seriesdecomp', - 'neuralforecast/models/fedformer.py'), - 'neuralforecast.models.fedformer.SeriesDecomp.__init__': ( 'models.fedformer.html#seriesdecomp.__init__', - 'neuralforecast/models/fedformer.py'), - 'neuralforecast.models.fedformer.SeriesDecomp.forward': ( 'models.fedformer.html#seriesdecomp.forward', - 'neuralforecast/models/fedformer.py'), 'neuralforecast.models.fedformer.get_frequency_modes': ( 'models.fedformer.html#get_frequency_modes', 'neuralforecast/models/fedformer.py')}, 'neuralforecast.models.gru': { 'neuralforecast.models.gru.GRU': ('models.gru.html#gru', 'neuralforecast/models/gru.py'), @@ -987,20 +975,6 @@ 'neuralforecast/models/patchtst.py'), 'neuralforecast.models.patchtst.PositionalEncoding': ( 'models.patchtst.html#positionalencoding', 'neuralforecast/models/patchtst.py'), - 'neuralforecast.models.patchtst.RevIN': ( 'models.patchtst.html#revin', - 'neuralforecast/models/patchtst.py'), - 'neuralforecast.models.patchtst.RevIN.__init__': ( 'models.patchtst.html#revin.__init__', - 'neuralforecast/models/patchtst.py'), - 'neuralforecast.models.patchtst.RevIN._denormalize': ( 'models.patchtst.html#revin._denormalize', - 'neuralforecast/models/patchtst.py'), - 'neuralforecast.models.patchtst.RevIN._get_statistics': ( 'models.patchtst.html#revin._get_statistics', - 'neuralforecast/models/patchtst.py'), - 'neuralforecast.models.patchtst.RevIN._init_params': ( 'models.patchtst.html#revin._init_params', - 'neuralforecast/models/patchtst.py'), - 'neuralforecast.models.patchtst.RevIN._normalize': ( 'models.patchtst.html#revin._normalize', - 'neuralforecast/models/patchtst.py'), - 'neuralforecast.models.patchtst.RevIN.forward': ( 'models.patchtst.html#revin.forward', - 'neuralforecast/models/patchtst.py'), 'neuralforecast.models.patchtst.TSTEncoder': ( 'models.patchtst.html#tstencoder', 'neuralforecast/models/patchtst.py'), 'neuralforecast.models.patchtst.TSTEncoder.__init__': ( 'models.patchtst.html#tstencoder.__init__', @@ -1052,19 +1026,6 @@ 'neuralforecast/models/rmok.py'), 'neuralforecast.models.rmok.RMoK.forward': ( 'models.rmok.html#rmok.forward', 'neuralforecast/models/rmok.py'), - 'neuralforecast.models.rmok.RevIN': ('models.rmok.html#revin', 'neuralforecast/models/rmok.py'), - 'neuralforecast.models.rmok.RevIN.__init__': ( 'models.rmok.html#revin.__init__', - 'neuralforecast/models/rmok.py'), - 'neuralforecast.models.rmok.RevIN._denormalize': ( 'models.rmok.html#revin._denormalize', - 'neuralforecast/models/rmok.py'), - 'neuralforecast.models.rmok.RevIN._get_statistics': ( 'models.rmok.html#revin._get_statistics', - 'neuralforecast/models/rmok.py'), - 'neuralforecast.models.rmok.RevIN._init_params': ( 'models.rmok.html#revin._init_params', - 'neuralforecast/models/rmok.py'), - 'neuralforecast.models.rmok.RevIN._normalize': ( 'models.rmok.html#revin._normalize', - 'neuralforecast/models/rmok.py'), - 'neuralforecast.models.rmok.RevIN.forward': ( 'models.rmok.html#revin.forward', - 'neuralforecast/models/rmok.py'), 'neuralforecast.models.rmok.TaylorKANLayer': ( 'models.rmok.html#taylorkanlayer', 'neuralforecast/models/rmok.py'), 'neuralforecast.models.rmok.TaylorKANLayer.__init__': ( 'models.rmok.html#taylorkanlayer.__init__', @@ -1223,20 +1184,6 @@ 'neuralforecast/models/timellm.py'), 'neuralforecast.models.timellm.FlattenHead.forward': ( 'models.timellm.html#flattenhead.forward', 'neuralforecast/models/timellm.py'), - 'neuralforecast.models.timellm.Normalize': ( 'models.timellm.html#normalize', - 
'neuralforecast/models/timellm.py'), - 'neuralforecast.models.timellm.Normalize.__init__': ( 'models.timellm.html#normalize.__init__', - 'neuralforecast/models/timellm.py'), - 'neuralforecast.models.timellm.Normalize._denormalize': ( 'models.timellm.html#normalize._denormalize', - 'neuralforecast/models/timellm.py'), - 'neuralforecast.models.timellm.Normalize._get_statistics': ( 'models.timellm.html#normalize._get_statistics', - 'neuralforecast/models/timellm.py'), - 'neuralforecast.models.timellm.Normalize._init_params': ( 'models.timellm.html#normalize._init_params', - 'neuralforecast/models/timellm.py'), - 'neuralforecast.models.timellm.Normalize._normalize': ( 'models.timellm.html#normalize._normalize', - 'neuralforecast/models/timellm.py'), - 'neuralforecast.models.timellm.Normalize.forward': ( 'models.timellm.html#normalize.forward', - 'neuralforecast/models/timellm.py'), 'neuralforecast.models.timellm.PatchEmbedding': ( 'models.timellm.html#patchembedding', 'neuralforecast/models/timellm.py'), 'neuralforecast.models.timellm.PatchEmbedding.__init__': ( 'models.timellm.html#patchembedding.__init__', @@ -1297,20 +1244,6 @@ 'neuralforecast/models/timemixer.py'), 'neuralforecast.models.timemixer.MultiScaleTrendMixing.forward': ( 'models.timemixer.html#multiscaletrendmixing.forward', 'neuralforecast/models/timemixer.py'), - 'neuralforecast.models.timemixer.Normalize': ( 'models.timemixer.html#normalize', - 'neuralforecast/models/timemixer.py'), - 'neuralforecast.models.timemixer.Normalize.__init__': ( 'models.timemixer.html#normalize.__init__', - 'neuralforecast/models/timemixer.py'), - 'neuralforecast.models.timemixer.Normalize._denormalize': ( 'models.timemixer.html#normalize._denormalize', - 'neuralforecast/models/timemixer.py'), - 'neuralforecast.models.timemixer.Normalize._get_statistics': ( 'models.timemixer.html#normalize._get_statistics', - 'neuralforecast/models/timemixer.py'), - 'neuralforecast.models.timemixer.Normalize._init_params': ( 'models.timemixer.html#normalize._init_params', - 'neuralforecast/models/timemixer.py'), - 'neuralforecast.models.timemixer.Normalize._normalize': ( 'models.timemixer.html#normalize._normalize', - 'neuralforecast/models/timemixer.py'), - 'neuralforecast.models.timemixer.Normalize.forward': ( 'models.timemixer.html#normalize.forward', - 'neuralforecast/models/timemixer.py'), 'neuralforecast.models.timemixer.PastDecomposableMixing': ( 'models.timemixer.html#pastdecomposablemixing', 'neuralforecast/models/timemixer.py'), 'neuralforecast.models.timemixer.PastDecomposableMixing.__init__': ( 'models.timemixer.html#pastdecomposablemixing.__init__', diff --git a/neuralforecast/common/_modules.py b/neuralforecast/common/_modules.py index c235612ac..d50228b87 100644 --- a/neuralforecast/common/_modules.py +++ b/neuralforecast/common/_modules.py @@ -3,7 +3,8 @@ # %% auto 0 __all__ = ['ACTIVATIONS', 'MLP', 'Chomp1d', 'CausalConv1d', 'TemporalConvolutionEncoder', 'TransEncoderLayer', 'TransEncoder', 'TransDecoderLayer', 'TransDecoder', 'AttentionLayer', 'PositionalEmbedding', 'TokenEmbedding', - 'TimeFeatureEmbedding', 'FixedEmbedding', 'TemporalEmbedding', 'DataEmbedding', 'MovingAvg', 'SeriesDecomp'] + 'TimeFeatureEmbedding', 'FixedEmbedding', 'TemporalEmbedding', 'DataEmbedding', 'MovingAvg', 'SeriesDecomp', + 'RevIN'] # %% ../../nbs/common.modules.ipynb 3 import math @@ -520,3 +521,83 @@ def forward(self, x): moving_mean = self.MovingAvg(x) res = x - moving_mean return res, moving_mean + +# %% ../../nbs/common.modules.ipynb 20 +class RevIN(nn.Module): + 
"""RevIN (Reversible-Instance-Normalization)""" + + def __init__( + self, + num_features: int, + eps=1e-5, + affine=False, + subtract_last=False, + non_norm=False, + ): + """ + :param num_features: the number of features or channels + :param eps: a value added for numerical stability + :param affine: if True, RevIN has learnable affine parameters + :param substract_last: if True, the substraction is based on the last value + instead of the mean in normalization + :param non_norm: if True, no normalization performed. + """ + super(RevIN, self).__init__() + self.num_features = num_features + self.eps = eps + self.affine = affine + self.subtract_last = subtract_last + self.non_norm = non_norm + if self.affine: + self._init_params() + + def forward(self, x, mode: str): + if mode == "norm": + self._get_statistics(x) + x = self._normalize(x) + elif mode == "denorm": + x = self._denormalize(x) + else: + raise NotImplementedError + return x + + def _init_params(self): + # initialize RevIN params: (C,) + self.affine_weight = nn.Parameter(torch.ones(self.num_features)) + self.affine_bias = nn.Parameter(torch.zeros(self.num_features)) + + def _get_statistics(self, x): + dim2reduce = tuple(range(1, x.ndim - 1)) + if self.subtract_last: + self.last = x[:, -1, :].unsqueeze(1) + else: + self.mean = torch.mean(x, dim=dim2reduce, keepdim=True).detach() + self.stdev = torch.sqrt( + torch.var(x, dim=dim2reduce, keepdim=True, unbiased=False) + self.eps + ).detach() + + def _normalize(self, x): + if self.non_norm: + return x + if self.subtract_last: + x = x - self.last + else: + x = x - self.mean + x = x / self.stdev + if self.affine: + x = x * self.affine_weight + x = x + self.affine_bias + return x + + def _denormalize(self, x): + if self.non_norm: + return x + if self.affine: + x = x - self.affine_bias + x = x / (self.affine_weight + self.eps * self.eps) + x = x * self.stdev + if self.subtract_last: + x = x + self.last + else: + x = x + self.mean + return x diff --git a/neuralforecast/models/fedformer.py b/neuralforecast/models/fedformer.py index d7c03261f..d811b6dce 100644 --- a/neuralforecast/models/fedformer.py +++ b/neuralforecast/models/fedformer.py @@ -1,8 +1,8 @@ # AUTOGENERATED! DO NOT EDIT! File to edit: ../../nbs/models.fedformer.ipynb. 
# %% auto 0 -__all__ = ['MovingAvg', 'SeriesDecomp', 'LayerNorm', 'AutoCorrelationLayer', 'EncoderLayer', 'Encoder', 'DecoderLayer', 'Decoder', - 'get_frequency_modes', 'FourierBlock', 'FourierCrossAttention', 'FEDformer'] +__all__ = ['LayerNorm', 'AutoCorrelationLayer', 'EncoderLayer', 'Encoder', 'DecoderLayer', 'Decoder', 'get_frequency_modes', + 'FourierBlock', 'FourierCrossAttention', 'FEDformer'] # %% ../../nbs/models.fedformer.ipynb 5 import numpy as np @@ -13,46 +13,12 @@ import torch.nn.functional as F from ..common._modules import DataEmbedding +from ..common._modules import SeriesDecomp from ..common._base_windows import BaseWindows from ..losses.pytorch import MAE # %% ../../nbs/models.fedformer.ipynb 7 -class MovingAvg(nn.Module): - """ - Moving average block to highlight the trend of time series - """ - - def __init__(self, kernel_size, stride): - super(MovingAvg, self).__init__() - self.kernel_size = kernel_size - self.avg = nn.AvgPool1d(kernel_size=kernel_size, stride=stride, padding=0) - - def forward(self, x): - # padding on the both ends of time series - front = x[:, 0:1, :].repeat(1, (self.kernel_size - 1) // 2, 1) - end = x[:, -1:, :].repeat(1, (self.kernel_size - 1) // 2, 1) - x = torch.cat([front, x, end], dim=1) - x = self.avg(x.permute(0, 2, 1)) - x = x.permute(0, 2, 1) - return x - - -class SeriesDecomp(nn.Module): - """ - Series decomposition block - """ - - def __init__(self, kernel_size): - super(SeriesDecomp, self).__init__() - self.MovingAvg = MovingAvg(kernel_size, stride=1) - - def forward(self, x): - moving_mean = self.MovingAvg(x) - res = x - moving_mean - return res, moving_mean - - class LayerNorm(nn.Module): """ Special designed layernorm for the seasonal part diff --git a/neuralforecast/models/patchtst.py b/neuralforecast/models/patchtst.py index 3090ee2f0..add87d623 100644 --- a/neuralforecast/models/patchtst.py +++ b/neuralforecast/models/patchtst.py @@ -2,7 +2,7 @@ # %% auto 0 __all__ = ['SinCosPosEncoding', 'Transpose', 'get_activation_fn', 'PositionalEncoding', 'Coord2dPosEncoding', - 'Coord1dPosEncoding', 'positional_encoding', 'RevIN', 'PatchTST_backbone', 'Flatten_Head', 'TSTiEncoder', + 'Coord1dPosEncoding', 'positional_encoding', 'PatchTST_backbone', 'Flatten_Head', 'TSTiEncoder', 'TSTEncoder', 'TSTEncoderLayer', 'PatchTST'] # %% ../../nbs/models.patchtst.ipynb 5 @@ -15,6 +15,7 @@ import torch.nn.functional as F from ..common._base_windows import BaseWindows +from ..common._modules import RevIN from ..losses.pytorch import MAE @@ -138,73 +139,6 @@ def positional_encoding(pe, learn_pe, q_len, hidden_size): return nn.Parameter(W_pos, requires_grad=learn_pe) # %% ../../nbs/models.patchtst.ipynb 13 -class RevIN(nn.Module): - """ - RevIN - """ - - def __init__(self, num_features: int, eps=1e-5, affine=True, subtract_last=False): - """ - :param num_features: the number of features or channels - :param eps: a value added for numerical stability - :param affine: if True, RevIN has learnable affine parameters - """ - super(RevIN, self).__init__() - self.num_features = num_features - self.eps = eps - self.affine = affine - self.subtract_last = subtract_last - if self.affine: - self._init_params() - - def forward(self, x, mode: str): - if mode == "norm": - self._get_statistics(x) - x = self._normalize(x) - elif mode == "denorm": - x = self._denormalize(x) - else: - raise NotImplementedError - return x - - def _init_params(self): - # initialize RevIN params: (C,) - self.affine_weight = nn.Parameter(torch.ones(self.num_features)) - self.affine_bias = 
nn.Parameter(torch.zeros(self.num_features)) - - def _get_statistics(self, x): - dim2reduce = tuple(range(1, x.ndim - 1)) - if self.subtract_last: - self.last = x[:, -1, :].unsqueeze(1) - else: - self.mean = torch.mean(x, dim=dim2reduce, keepdim=True).detach() - self.stdev = torch.sqrt( - torch.var(x, dim=dim2reduce, keepdim=True, unbiased=False) + self.eps - ).detach() - - def _normalize(self, x): - if self.subtract_last: - x = x - self.last - else: - x = x - self.mean - x = x / self.stdev - if self.affine: - x = x * self.affine_weight - x = x + self.affine_bias - return x - - def _denormalize(self, x): - if self.affine: - x = x - self.affine_bias - x = x / (self.affine_weight + self.eps * self.eps) - x = x * self.stdev - if self.subtract_last: - x = x + self.last - else: - x = x + self.mean - return x - -# %% ../../nbs/models.patchtst.ipynb 15 class PatchTST_backbone(nn.Module): """ PatchTST_backbone @@ -850,7 +784,7 @@ def forward( else: return output, attn_weights -# %% ../../nbs/models.patchtst.ipynb 17 +# %% ../../nbs/models.patchtst.ipynb 15 class PatchTST(BaseWindows): """PatchTST diff --git a/neuralforecast/models/rmok.py b/neuralforecast/models/rmok.py index c83585e1b..7f9e5718b 100644 --- a/neuralforecast/models/rmok.py +++ b/neuralforecast/models/rmok.py @@ -1,7 +1,7 @@ # AUTOGENERATED! DO NOT EDIT! File to edit: ../../nbs/models.rmok.ipynb. # %% auto 0 -__all__ = ['WaveKANLayer', 'TaylorKANLayer', 'JacobiKANLayer', 'RevIN', 'RMoK'] +__all__ = ['WaveKANLayer', 'TaylorKANLayer', 'JacobiKANLayer', 'RMoK'] # %% ../../nbs/models.rmok.ipynb 6 import math @@ -12,6 +12,7 @@ from ..losses.pytorch import MAE from ..common._base_multivariate import BaseMultivariate +from ..common._modules import RevIN # %% ../../nbs/models.rmok.ipynb 8 class WaveKANLayer(nn.Module): @@ -255,66 +256,6 @@ def forward(self, x): return y # %% ../../nbs/models.rmok.ipynb 14 -class RevIN(nn.Module): - def __init__(self, num_features: int, eps=1e-5, affine=True): - """ - :param num_features: the number of features or channels - :param eps: a value added for numerical stability - :param affine: if True, RevIN has learnable affine parameters - """ - super(RevIN, self).__init__() - - self.num_features = num_features - self.eps = eps - self.affine = affine - - if self.affine: - self._init_params() - - def forward(self, x, mode: str): - if mode == "norm": - self._get_statistics(x) - x = self._normalize(x) - - elif mode == "denorm": - x = self._denormalize(x) - - else: - raise NotImplementedError - - return x - - def _init_params(self): - # initialize RevIN params: (C,) - self.affine_weight = nn.Parameter(torch.ones(self.num_features)) - self.affine_bias = nn.Parameter(torch.zeros(self.num_features)) - - def _get_statistics(self, x): - dim2reduce = tuple(range(1, x.ndim - 1)) - self.mean = torch.mean(x, dim=dim2reduce, keepdim=True).detach() - self.stdev = torch.sqrt( - torch.var(x, dim=dim2reduce, keepdim=True, unbiased=False) + self.eps - ).detach() - - def _normalize(self, x): - x = x - self.mean - x = x / self.stdev - if self.affine: - x = x * self.affine_weight - x = x + self.affine_bias - - return x - - def _denormalize(self, x): - if self.affine: - x = x - self.affine_bias - x = x / (self.affine_weight + self.eps * self.eps) - x = x * self.stdev - x = x + self.mean - - return x - -# %% ../../nbs/models.rmok.ipynb 16 class RMoK(BaseMultivariate): """Reversible Mixture of KAN **Parameters**
diff --git a/neuralforecast/models/timellm.py b/neuralforecast/models/timellm.py index 603eb6265..4c58a4b23 100644 --- a/neuralforecast/models/timellm.py +++ b/neuralforecast/models/timellm.py @@ -1,7 +1,7 @@ # AUTOGENERATED! DO NOT EDIT! File to edit: ../../nbs/models.timellm.ipynb. # %% auto 0 -__all__ = ['ReplicationPad1d', 'TokenEmbedding', 'PatchEmbedding', 'FlattenHead', 'ReprogrammingLayer', 'Normalize', 'TimeLLM'] +__all__ = ['ReplicationPad1d', 'TokenEmbedding', 'PatchEmbedding', 'FlattenHead', 'ReprogrammingLayer', 'TimeLLM'] # %% ../../nbs/models.timellm.ipynb 6 import math @@ -11,6 +11,7 @@ import torch.nn as nn from ..common._base_windows import BaseWindows +from ..common._modules import RevIN from ..losses.pytorch import MAE @@ -163,85 +164,6 @@ def reprogramming(self, target_embedding, source_embedding, value_embedding): return reprogramming_embedding - -class Normalize(nn.Module): - """ - Normalize - """ - - def __init__( - self, - num_features: int, - eps=1e-5, - affine=False, - subtract_last=False, - non_norm=False, - ): - """ - :param num_features: the number of features or channels - :param eps: a value added for numerical stability - :param affine: if True, RevIN has learnable affine parameters - """ - super(Normalize, self).__init__() - self.num_features = num_features - self.eps = eps - self.affine = affine - self.subtract_last = subtract_last - self.non_norm = non_norm - if self.affine: - self._init_params() - - def forward(self, x, mode: str): - if mode == "norm": - self._get_statistics(x) - x = self._normalize(x) - elif mode == "denorm": - x = self._denormalize(x) - else: - raise NotImplementedError - return x - - def _init_params(self): - # initialize RevIN params: (C,) - self.affine_weight = nn.Parameter(torch.ones(self.num_features)) - self.affine_bias = nn.Parameter(torch.zeros(self.num_features)) - - def _get_statistics(self, x): - dim2reduce = tuple(range(1, x.ndim - 1)) - if self.subtract_last: - self.last = x[:, -1, :].unsqueeze(1) - else: - self.mean = torch.mean(x, dim=dim2reduce, keepdim=True).detach() - self.stdev = torch.sqrt( - torch.var(x, dim=dim2reduce, keepdim=True, unbiased=False) + self.eps - ).detach() - - def _normalize(self, x): - if self.non_norm: - return x - if self.subtract_last: - x = x - self.last - else: - x = x - self.mean - x = x / self.stdev - if self.affine: - x = x * self.affine_weight - x = x + self.affine_bias - return x - - def _denormalize(self, x): - if self.non_norm: - return x - if self.affine: - x = x - self.affine_bias - x = x / (self.affine_weight + self.eps * self.eps) - x = x * self.stdev - if self.subtract_last: - x = x + self.last - else: - x = x + self.mean - return x - # %% ../../nbs/models.timellm.ipynb 11 class TimeLLM(BaseWindows): """TimeLLM @@ -465,7 +387,7 @@ def __init__( self.enc_in, self.head_nf, self.h, head_dropout=self.dropout ) - self.normalize_layers = Normalize(self.enc_in, affine=False) + self.normalize_layers = RevIN(self.enc_in, affine=False) def forecast(self, x_enc): diff --git a/neuralforecast/models/timemixer.py b/neuralforecast/models/timemixer.py index 571c4e96e..602e602c7 100644 --- a/neuralforecast/models/timemixer.py +++ b/neuralforecast/models/timemixer.py @@ -1,7 +1,7 @@ # AUTOGENERATED! DO NOT EDIT! File to edit: ../../nbs/models.timemixer.ipynb. 
# %% auto 0 -__all__ = ['Normalize', 'DataEmbedding_wo_pos', 'DFT_series_decomp', 'MultiScaleSeasonMixing', 'MultiScaleTrendMixing', +__all__ = ['DataEmbedding_wo_pos', 'DFT_series_decomp', 'MultiScaleSeasonMixing', 'MultiScaleTrendMixing', 'PastDecomposableMixing', 'TimeMixer'] # %% ../../nbs/models.timemixer.ipynb 3 @@ -17,90 +17,12 @@ TokenEmbedding, TemporalEmbedding, SeriesDecomp, + RevIN, ) from ..losses.pytorch import MAE # %% ../../nbs/models.timemixer.ipynb 6 -class Normalize(nn.Module): - """ - Normalize - """ - - def __init__( - self, - num_features: int, - eps=1e-5, - affine=False, - subtract_last=False, - non_norm=False, - ): - """ - :param num_features: the number of features or channels - :param eps: a value added for numerical stability - :param affine: if True, RevIN has learnable affine parameters - """ - super(Normalize, self).__init__() - self.num_features = num_features - self.eps = eps - self.affine = affine - self.subtract_last = subtract_last - self.non_norm = non_norm - if self.affine: - self._init_params() - - def forward(self, x, mode: str): - if mode == "norm": - self._get_statistics(x) - x = self._normalize(x) - elif mode == "denorm": - x = self._denormalize(x) - else: - raise NotImplementedError - return x - - def _init_params(self): - # initialize RevIN params: (C,) - self.affine_weight = nn.Parameter(torch.ones(self.num_features)) - self.affine_bias = nn.Parameter(torch.zeros(self.num_features)) - - def _get_statistics(self, x): - dim2reduce = tuple(range(1, x.ndim - 1)) - if self.subtract_last: - self.last = x[:, -1, :].unsqueeze(1) - else: - self.mean = torch.mean(x, dim=dim2reduce, keepdim=True).detach() - self.stdev = torch.sqrt( - torch.var(x, dim=dim2reduce, keepdim=True, unbiased=False) + self.eps - ).detach() - - def _normalize(self, x): - if self.non_norm: - return x - if self.subtract_last: - x = x - self.last - else: - x = x - self.mean - x = x / self.stdev - if self.affine: - x = x * self.affine_weight - x = x + self.affine_bias - return x - - def _denormalize(self, x): - if self.non_norm: - return x - if self.affine: - x = x - self.affine_bias - x = x / (self.affine_weight + self.eps * self.eps) - x = x * self.stdev - if self.subtract_last: - x = x + self.last - else: - x = x + self.mean - return x - -# %% ../../nbs/models.timemixer.ipynb 8 class DataEmbedding_wo_pos(nn.Module): """ DataEmbedding_wo_pos @@ -125,7 +47,7 @@ def forward(self, x, x_mark): x = self.value_embedding(x) + self.temporal_embedding(x_mark) return self.dropout(x) -# %% ../../nbs/models.timemixer.ipynb 10 +# %% ../../nbs/models.timemixer.ipynb 8 class DFT_series_decomp(nn.Module): """ Series decomposition block @@ -145,7 +67,7 @@ def forward(self, x): x_trend = x - x_season return x_season, x_trend -# %% ../../nbs/models.timemixer.ipynb 12 +# %% ../../nbs/models.timemixer.ipynb 10 class MultiScaleSeasonMixing(nn.Module): """ Bottom-up mixing season pattern @@ -326,7 +248,7 @@ def forward(self, x_list): out_list.append(out[:, :length, :]) return out_list -# %% ../../nbs/models.timemixer.ipynb 14 +# %% ../../nbs/models.timemixer.ipynb 12 class TimeMixer(BaseMultivariate): """TimeMixer **Parameters**
@@ -507,7 +429,7 @@ def __init__( self.normalize_layers = torch.nn.ModuleList( [ - Normalize( + RevIN( self.enc_in, affine=True, non_norm=False if self.use_norm else True ) for i in range(self.down_sampling_layers + 1)
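
For orientation, a minimal usage sketch of the two modules this patch consolidates into neuralforecast.common._modules. It is not part of the diff: the import path is taken from the changes above, the (batch, time, channels) shapes follow the convention of the models touched here, and the assertions only illustrate the intended norm/denorm contract.

import torch

# After this patch, both modules live in neuralforecast.common._modules.
from neuralforecast.common._modules import RevIN, SeriesDecomp

batch, time, channels = 4, 36, 7
x = torch.randn(batch, time, channels)

# RevIN stores the statistics computed in 'norm' mode on the module instance,
# so the same instance must be reused later for 'denorm'.
revin = RevIN(num_features=channels, affine=False)
x_norm = revin(x, 'norm')         # per-series zero mean, unit variance
x_back = revin(x_norm, 'denorm')  # statistics reapplied
assert torch.allclose(x_back, x, atol=1e-5)

# SeriesDecomp splits a series into a residual (seasonal) component and a
# moving-average (trend) component that sum back to the input.
decomp = SeriesDecomp(kernel_size=25)
res, moving_mean = decomp(x)
assert torch.allclose(res + moving_mean, x, atol=1e-6)

Note that non_norm=True turns both RevIN modes into identity mappings; TimeMixer relies on this to honor its use_norm flag without changing its forward path.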