From c307da65a1ce22c2020891e42a3a67c68d7780fc Mon Sep 17 00:00:00 2001
From: John Loverich
Date: Thu, 6 Jun 2024 05:13:11 -0700
Subject: [PATCH] Add layer norm

---
 README.md                         |  3 ++-
 high_order_layers_torch/layers.py | 25 +++++++++++++++++++++++++
 2 files changed, 27 insertions(+), 1 deletion(-)

diff --git a/README.md b/README.md
index 435af32..83fb344 100644
--- a/README.md
+++ b/README.md
@@ -234,7 +234,8 @@ python examples/variational_autoencoder.py -m
 ```
 
 ## Invariant MNIST (fully connected)
-Some comparisons using parameter scans maxabs normalization as default
+Some comparisons from parameter scans, with maxabs normalization as the default. Piecewise polynomial cases use 2 segments. I only
+did one run each.
 ```
 python3 examples/invariant_mnist.py -m mlp.n=2,3,4,5,6 mlp.hidden.width=128 mlp.layer_type=polynomial optimizer=sophia
 ```
diff --git a/high_order_layers_torch/layers.py b/high_order_layers_torch/layers.py
index 174e0d4..69f0426 100644
--- a/high_order_layers_torch/layers.py
+++ b/high_order_layers_torch/layers.py
@@ -2,6 +2,7 @@
 import torch.nn as nn
 from torch.nn import Linear
+import torch.nn.functional as F
 
 from .FunctionalConvolution import *
 from .FunctionalConvolutionTranspose import *
@@ -17,6 +18,28 @@
 )
 
 
+class LazyLayerNormLastDim(nn.Module):
+    """
+    Lazily initialize a layer norm over the last dimension of the input on
+    the first forward pass. Assumes that dimension remains constant.
+    """
+
+    def __init__(self, bias=True):
+        super().__init__()
+        self.weight = None
+        self.bias = bias
+
+    def forward(self, input):
+        if self.weight is None:
+            ndim = input.shape[-1]
+            # Build the Parameters directly on the input's device: calling
+            # .to(device) on a Parameter would return a plain non-leaf Tensor,
+            # so it would never be registered (missing from state_dict and
+            # invisible to optimizers / .to() / .cuda()).
+            self.weight = nn.Parameter(torch.ones(ndim, device=input.device))
+            self.bias = (
+                nn.Parameter(torch.zeros(ndim, device=input.device)) if self.bias else None
+            )
+
+        return F.layer_norm(input, self.weight.shape, self.weight, self.bias, 1e-5)
+
+
 class MaxAbsNormalizationLast(nn.Module):
     """
     Normalize the last dimension of the input variable
@@ -101,7 +124,9 @@ def forward(self, x):
 
 
 
 normalization_layers = {
     "max_abs": MaxAbsNormalization,
+    "max_center": MaxCenterNormalization,
     "l2": L2Normalization,
+    "layer_norm": LazyLayerNormLastDim,
 }