add a norm for lime
lucidrains committed Feb 22, 2025
1 parent 5c53e58 · commit 9b8b489
Showing 2 changed files with 4 additions and 2 deletions.
pyproject.toml (3 changes: 1 addition & 2 deletions)
@@ -1,6 +1,6 @@
 [project]
 name = "x-transformers"
-version = "2.1.1"
+version = "2.1.2"
 description = "X-Transformers"
 authors = [
     { name = "Phil Wang", email = "[email protected]" }
@@ -37,7 +37,6 @@ Repository = "https://github.com/lucidrains/x-transformers"
 examples = [
     "lion-pytorch",
     "tqdm",
-    "torchvision"
 ]

 test = [
x_transformers/x_transformers.py (3 changes: 3 additions & 0 deletions)
@@ -970,13 +970,15 @@ def __init__(
         dim,
         num_layers,
         num_views = 1,
+        norm = True,
         use_softmax = True
     ):
         super().__init__()
         self.num_layers = num_layers
         self.multiple_views = num_views > 1

         self.to_weights = Sequential(
+            RMSNorm(dim) if norm else None,
             nn.Linear(dim, num_views * num_layers),
             Rearrange('... (views layers) -> views ... layers', views = num_views),
             nn.Softmax(dim = -1) if use_softmax else nn.ReLU()
@@ -987,6 +989,7 @@ def forward(
         x,
         hiddens
     ):
+
         if not is_tensor(hiddens):
             hiddens = stack(hiddens)
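For context, here is a minimal standalone sketch of the pattern this diff changes: normalizing the input before the linear projection that produces per-layer mixing weights over stored hidden states. The class name WeightHead and the hand-rolled RMSNorm below are illustrative assumptions, not the actual x_transformers definitions. Note also that the library's own Sequential helper skips None entries, which is why the diff can write `RMSNorm(dim) if norm else None`; this sketch uses a plain conditional list with torch's nn.Sequential instead.

import torch
from torch import nn
import torch.nn.functional as F
from einops.layers.torch import Rearrange

class RMSNorm(nn.Module):
    # gain-parameterized RMSNorm in the style commonly used in x-transformers code
    def __init__(self, dim):
        super().__init__()
        self.scale = dim ** 0.5
        self.gamma = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        return F.normalize(x, dim = -1) * self.scale * self.gamma

class WeightHead(nn.Module):
    # hypothetical reconstruction of the module the diff touches
    def __init__(self, dim, num_layers, num_views = 1, norm = True, use_softmax = True):
        super().__init__()
        layers = [RMSNorm(dim)] if norm else []  # the normalization this commit adds
        layers += [
            nn.Linear(dim, num_views * num_layers),
            Rearrange('... (views layers) -> views ... layers', views = num_views),
            nn.Softmax(dim = -1) if use_softmax else nn.ReLU()
        ]
        self.to_weights = nn.Sequential(*layers)

    def forward(self, x):
        return self.to_weights(x)

# usage: per-token mixing weights over 12 stored layer outputs
head = WeightHead(dim = 512, num_layers = 12)
x = torch.randn(2, 16, 512)   # (batch, seq, dim)
weights = head(x)             # (views = 1, batch, seq, num_layers)
print(weights.shape)          # torch.Size([1, 2, 16, 12])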
