Commit

add dynamic LIMe from Gerasimov et al., making sure it is compatible with hyper connections
lucidrains committed Feb 22, 2025
1 parent abee1d3 commit 5c53e58
Showing 4 changed files with 118 additions and 14 deletions.
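For context, a minimal usage sketch of the new option, mirroring the test added in this commit (`integrate_layers` enables dynamic LIMe, and `num_residual_streams > 1` enables hyper connections; all other arguments are standard x-transformers usage):

```python
import torch
from x_transformers import TransformerWrapper, Decoder

# dynamic LIMe stacked with 4 hyper-connection residual streams,
# the combination this commit makes compatible
model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 128,
        depth = 6,
        heads = 8,
        num_residual_streams = 4,  # hyper connections
        integrate_layers = True    # dynamic LIMe
    )
)

x = torch.randint(0, 20000, (2, 1024))
logits = model(x)
```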
README.md: 9 additions, 0 deletions
@@ -2378,4 +2378,13 @@ ids_out, num_out, is_number_mask = model.generate(start_ids, start_nums, 17)
}
```

```bibtex
@inproceedings{Gerasimov2025YouDN,
title = {You Do Not Fully Utilize Transformer's Representation Capacity},
author = {Gleb Gerasimov and Yaroslav Aksenov and Nikita Balagansky and Viacheslav Sinii and Daniil Gavrilov},
year = {2025},
url = {https://api.semanticscholar.org/CorpusID:276317819}
}
```

*solve intelligence... then use that to solve everything else.* - Demis Hassabis
pyproject.toml: 1 addition, 1 deletion
@@ -1,6 +1,6 @@
[project]
name = "x-transformers"
version = "2.0.5"
version = "2.1.1"
description = "X-Transformers"
authors = [
{ name = "Phil Wang", email = "[email protected]" }
tests/test_x_transformers.py: 22 additions, 0 deletions
@@ -671,3 +671,25 @@ def test_multi_latent_attention():
x = torch.randint(0, 20000, (2, 1024))

model(x)

@pytest.mark.parametrize('num_residual_streams', (1, 4))
@pytest.mark.parametrize('integrate_layers', (False, True))
def test_lime(
num_residual_streams,
integrate_layers
):
model = TransformerWrapper(
num_tokens = 20000,
max_seq_len = 1024,
attn_layers = Decoder(
dim = 128,
depth = 6,
heads = 8,
num_residual_streams = num_residual_streams,
integrate_layers = integrate_layers
)
)

x = torch.randint(0, 20000, (2, 1024))

model(x)
x_transformers/x_transformers.py: 86 additions, 13 deletions
@@ -9,7 +9,7 @@
import torch
from torch.amp import autocast
import torch.nn.functional as F
from torch import nn, einsum, Tensor, cat, stack, arange
from torch import nn, einsum, Tensor, cat, stack, arange, is_tensor
from torch.utils._pytree import tree_flatten, tree_unflatten
from torch.nn import Module, ModuleList, ModuleDict

@@ -962,6 +962,45 @@ def forward(self, x, residuals, *, beta):
residuals = einsum('b n d, b n s -> b n s d', x, beta) + residuals
return rearrange(residuals, 'b n s d -> (b s) n d')

# LIMe - layer integrated memory (dynamic version)

class DynamicLIMe(Module):
def __init__(
self,
dim,
num_layers,
num_views = 1,
use_softmax = True
):
super().__init__()
self.num_layers = num_layers
self.multiple_views = num_views > 1

self.to_weights = Sequential(
nn.Linear(dim, num_views * num_layers),
Rearrange('... (views layers) -> views ... layers', views = num_views),
nn.Softmax(dim = -1) if use_softmax else nn.ReLU()
)

def forward(
self,
x,
hiddens
):
if not is_tensor(hiddens):
hiddens = stack(hiddens)

assert hiddens.shape[0] == self.num_layers, f'expected hiddens to have {self.num_layers} layers but received {tuple(hiddens.shape)} instead (first dimension must be layers)'

weights = self.to_weights(x)

out = einsum('l b n d, v b n l -> v b n d', hiddens, weights)

if self.multiple_views:
return out

return rearrange(out, '1 ... -> ...')
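As a shape sanity check on the class above, here is the routing einsum from `DynamicLIMe.forward` in isolation, with toy sizes chosen purely for illustration:

```python
import torch

# toy sizes: layers = 4, batch = 2, seq = 8, dim = 16, views = 3
hiddens = torch.randn(4, 2, 8, 16)                   # stacked inputs of all past layers
weights = torch.randn(3, 2, 8, 4).softmax(dim = -1)  # per-token mixture over layers, one per view

# every view forms its own per-token weighted sum over the layer stack
out = torch.einsum('l b n d, v b n l -> v b n d', hiddens, weights)

assert out.shape == (3, 2, 8, 16)  # (views, batch, seq, dim)
```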

# token shifting

def shift(t, amount, mask = None):
@@ -1306,7 +1345,7 @@ def __init__(

self.merge_heads = Rearrange('b h n d -> b n (h d)')

# whether qkv receives different residual stream combinations from hyper connections
# whether qkv receives different residual stream combinations from hyper connections or lime

self.qkv_receive_diff_residuals = qkv_receive_diff_residuals

@@ -1869,6 +1908,8 @@ def __init__(
use_layerscale = False,
layerscale_init_value = 0.,
unet_skips = False,
integrate_layers = False,
layer_integrate_use_softmax = True,
num_residual_streams = 1,
qkv_receive_diff_residuals = False,
reinject_input = False, # seen first in DEQ paper https://arxiv.org/abs/1909.01377, but later used in a number of papers trying to achieve depthwise generalization https://arxiv.org/abs/2410.03020v1
@@ -1895,16 +1936,30 @@
self.causal = causal
self.layers = ModuleList([])

# greater than one residual stream, proposed in Hyper-Connections paper https://arxiv.org/abs/2409.19606
# routing related
# 1. greater than one residual stream, proposed in Hyper-Connections paper https://arxiv.org/abs/2409.19606
# 2. integrating more than one past layer, from LIMe paper https://arxiv.org/abs/2502.09245

qkv_receive_diff_residuals |= integrate_layers # qkv always receives different views if integrating layers

# hyper connections

assert num_residual_streams > 0
has_hyper_connections = num_residual_streams > 1

self.num_residual_streams = num_residual_streams
self.stream_emb = nn.Parameter(torch.zeros(num_residual_streams, dim)) if num_residual_streams > 1 else None

assert not (num_residual_streams > 1 and gate_residual)
assert not (has_hyper_connections and gate_residual)

hyper_conn_produce_diff_views = qkv_receive_diff_residuals and not integrate_layers

# LIMe

assert not (num_residual_streams == 1 and qkv_receive_diff_residuals)
hiddens_counter = 0
self.layer_integrators = ModuleList([])

assert not (qkv_receive_diff_residuals and not (hyper_conn_produce_diff_views or integrate_layers))

# positions related

@@ -2145,16 +2200,22 @@ def __init__(

# attention, cross attention, feedforward

layer_qkv_receives_diff_view = layer_type == 'a' and qkv_receive_diff_residuals and not (is_first_self_attn and integrate_layers)

if layer_type == 'a':
self_attn_learned_value_residual = learned_value_residual_mix and not is_first_self_attn
layer = Attention(dim, heads = heads, causal = causal, qkv_receive_diff_residuals = qkv_receive_diff_residuals, learned_value_residual_mix = self_attn_learned_value_residual, rotate_num_heads = rotate_num_heads, **attn_kwargs)

layer = Attention(dim, heads = heads, causal = causal, qkv_receive_diff_residuals = layer_qkv_receives_diff_view, learned_value_residual_mix = self_attn_learned_value_residual, rotate_num_heads = rotate_num_heads, **attn_kwargs)
is_first_self_attn = False

elif layer_type == 'c':
layer = Attention(dim, heads = heads, **{**attn_kwargs, **cross_attn_kwargs})
is_first_cross_attn = False

elif layer_type == 'f':
layer = FeedForward(dim, **ff_kwargs)
layer = layer if not macaron else Scale(0.5, layer)

else:
raise Exception(f'invalid layer type {layer_type}')

@@ -2166,10 +2227,18 @@
if exists(post_branch_fn):
layer = post_branch_fn(layer)

if num_residual_streams > 1:
layer_integrate = None

if integrate_layers:
num_layer_hiddens = ind + 1
layer_integrate_num_view = 3 if layer_qkv_receives_diff_view else 1

layer_integrate = DynamicLIMe(dim, num_layer_hiddens, num_views = layer_integrate_num_view, use_softmax = layer_integrate_use_softmax)

if has_hyper_connections:
residual_fn = partial(HyperConnection, num_residual_streams = num_residual_streams)

if layer_type == 'a' and qkv_receive_diff_residuals:
if layer_type == 'a' and hyper_conn_produce_diff_views:
residual_fn = partial(residual_fn, num_input_views = 3)

elif gate_residual:
@@ -2201,6 +2270,8 @@ def __init__(

self.skip_combines.append(skip_combine)

self.layer_integrators.append(layer_integrate)

self.layers.append(ModuleList([
norms,
layer,
@@ -2341,13 +2412,13 @@ def forward(
self.layer_types,
self.skip_combines,
self.layers,
self.layer_dropouts
self.layer_dropouts,
self.layer_integrators
)

# able to override the layers execution order on forward, for trying to depth extrapolate

layers_execute_order = default(layers_execute_order, self.layers_execute_order)

layer_variables = tuple(tuple(layer_variable[i] for i in layers_execute_order) for layer_variable in layer_variables)

# derived input for reinjection if needed
@@ -2377,7 +2448,7 @@

# go through the attention and feedforward layers

for ind, (layer_type, skip_combine, (norm, block, residual_fn), layer_dropout) in enumerate(zip(*layer_variables)):
for ind, (layer_type, skip_combine, (norm, block, residual_fn), layer_dropout, layer_integrator) in enumerate(zip(*layer_variables)):
is_last = ind == (len(self.layers) - 1)

# handle skip connections
@@ -2405,8 +2476,10 @@

x, inner_residual, residual_kwargs = residual_fn.prepare(x)

if return_hiddens:
layer_hiddens.append(x)
layer_hiddens.append(x)

if exists(layer_integrator):
x = layer_integrator(x, layer_hiddens)

pre_norm, post_branch_norm, post_main_norm = norm

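To summarize the forward-pass change above: every layer's input is now unconditionally appended to `layer_hiddens`, and layers that carry a `DynamicLIMe` module re-mix that growing stack before the block runs. A stripped-down sketch of the control flow (a simplification; the real forward also threads norms, residual functions, skips, and hyper-connection streams):

```python
# conceptual sketch only, not the actual x-transformers forward
layer_hiddens = []

for layer_integrator, block in zip(layer_integrators, blocks):
    layer_hiddens.append(x)

    if layer_integrator is not None:
        # per-token soft routing over all layer inputs so far (dynamic LIMe)
        x = layer_integrator(x, layer_hiddens)

    x = block(x) + x  # usual residual
```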
