Skip to content

Commit

Permalink
Add layer norm
Browse files Browse the repository at this point in the history
  • Loading branch information
jloveric committed Jun 6, 2024
1 parent f8e4458 commit c307da6
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 1 deletion.
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,8 @@ python examples/variational_autoencoder.py -m
```

## Invariant MNIST (fully connected)
Some comparisons using parameter scans maxabs normalization as default
Some comparisons using parameter scans maxabs normalization as default. piecewise polynomial cases use 2 segments. I only
did one run each.
```
python3 examples/invariant_mnist.py -m mlp.n=2,3,4,5,6 mlp.hidden.width=128 mlp.layer_type=polynomial optimizer=sophia
```
Expand Down
25 changes: 25 additions & 0 deletions high_order_layers_torch/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import torch.nn as nn
from torch.nn import Linear
import torch.nn.functional as F

from .FunctionalConvolution import *
from .FunctionalConvolutionTranspose import *
Expand All @@ -17,6 +18,28 @@
)


class LazyLayerNormLastDim(nn.Module):
"""
Lazily initialize the layer norm to the last dimension of the input
variable. Assumes dimension remains constant.
"""

def __init__(self, bias=True):
super().__init__()
self.weight = None
self.bias = bias

def forward(self, input):
if self.weight is None:
ndim = input.shape[-1]
self.weight = nn.Parameter(torch.ones(ndim)).to(input.device)
self.bias = (
nn.Parameter(torch.zeros(ndim)).to(input.device) if self.bias else None
)

return F.layer_norm(input, self.weight.shape, self.weight, self.bias, 1e-5)


class MaxAbsNormalizationLast(nn.Module):
"""
Normalize the last dimension of the input variable
Expand Down Expand Up @@ -101,7 +124,9 @@ def forward(self, x):

normalization_layers = {
"max_abs": MaxAbsNormalization,
"max_center": MaxCenterNormalization,
"l2": L2Normalization,
"layer_norm": LazyLayerNormLastDim,
}


Expand Down

0 comments on commit c307da6

Please sign in to comment.