add ability to set slopes on alibi as well as custom position across heads
lucidrains committed Nov 8, 2024
1 parent 55148ba commit 144d9ba
Showing 3 changed files with 41 additions and 9 deletions.
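In brief, this commit lets AlibiPositionalBias accept explicit per-head slopes, threads a new rel_pos_kwargs dict from AttentionLayers down to the relative-position modules, and lets the custom-position path handle positions supplied per head. A minimal usage sketch, mirroring the new test added below (imports assumed to come from the x-transformers package):

import torch
from x_transformers import Decoder

# ALiBi with explicit slopes, one per head, routed through the new rel_pos_kwargs
model = Decoder(
    dim = 512,
    depth = 2,
    heads = 2,
    alibi_pos_bias = True,
    rel_pos_kwargs = dict(slopes = [1, 1]),
)

x = torch.randn(2, 4, 512)

# custom positions supplied separately for each head: shape (batch, heads, seq)
pos = torch.tensor([
    [[0, 1, 2, 4], [1, 3, 5, 7]],
    [[2, 3, 4, 5], [6, 8, 9, 10]]
])

embed = model(x, pos = pos)  # embeddings of shape (2, 4, 512)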
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
setup(
name = 'x-transformers',
packages = find_packages(exclude=['examples']),
version = '1.42.8',
version = '1.42.9',
license='MIT',
description = 'X-Transformers - Pytorch',
author = 'Phil Wang',
21 changes: 21 additions & 0 deletions tests/test_x_transformers.py
@@ -381,6 +381,7 @@ def test_neo_mlp():
assert out.shape == (3, 7)

def test_custom_alibi():

model = TransformerWrapper(
num_tokens = 20_000,
max_seq_len = 1024,
@@ -398,6 +399,26 @@ def test_custom_alibi():

logits = model(x, pos = pos)

def test_custom_alibi_across_heads():

model = Decoder(
dim = 512,
depth = 2,
heads = 2,
alibi_pos_bias = True,
rel_pos_kwargs = dict(
slopes = [1, 1]
),
)

x = torch.randn(2, 4, 512)

pos = torch.tensor([
[[0, 1, 2, 4], [1, 3, 5, 7]],
[[2, 3, 4, 5], [6, 8, 9, 10]]
])

embed = model(x, pos = pos)

@pytest.mark.parametrize('embedder_type', ('embedding', 'none', 'custom'))
def test_embedder(embedder_type):
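The second test exercises positions given per head rather than per sequence. For intuition, here is a small standalone sketch in plain torch (no einx) that recomputes the same per-head ALiBi bias that forward_custom_pos builds further down, under the assumption that pos has shape (batch, heads, seq):

import torch

slopes = torch.tensor([1., 1.]).view(-1, 1, 1)  # (heads, 1, 1), the custom slopes from the test

pos = torch.tensor([
    [[0, 1, 2, 4], [1, 3, 5, 7]],
    [[2, 3, 4, 5], [6, 8, 9, 10]]
]).float()                                      # (batch, heads, seq) positions, one row per head

# bias[b, h, i, j] = -|pos[b, h, j] - pos[b, h, i]|, then scaled by that head's slope
rel = pos.unsqueeze(-2) - pos.unsqueeze(-1)     # (batch, heads, i, j) relative distances
bias = -rel.abs() * slopes                      # slopes broadcast across batch and positions

print(bias.shape)                               # torch.Size([2, 2, 4, 4])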
27 changes: 19 additions & 8 deletions x_transformers/x_transformers.py
@@ -452,13 +452,20 @@ def forward(self, i, j):
return bias

class AlibiPositionalBias(Module):
def __init__(self, heads, total_heads = None, **kwargs):
def __init__(
self,
heads,
total_heads = None,
slopes: list[int] | None = None,
**kwargs
):
super().__init__()
self.heads = heads
self.total_heads = default(total_heads, heads)

slopes = Tensor(self._get_slopes(heads))
slopes = Tensor(default(slopes, self._get_slopes(heads)))
slopes = rearrange(slopes, 'h -> h 1 1')

self.register_buffer('slopes', slopes, persistent = False)
self.register_buffer('bias', None, persistent = False)

@@ -487,7 +494,10 @@ def forward_custom_pos(
h, device = self.total_heads, self.device

pos_j = default(pos_j, pos_i)
bias = -einx.subtract('... j, ... i -> ... 1 i j', pos_j, pos_i).abs()
bias = -einx.subtract('... j, ... i -> ... i j', pos_j, pos_i).abs()

if bias.ndim == 3:
bias = rearrange(bias, 'b i j -> b 1 i j')

bias = bias * self.slopes
num_heads_unalibied = h - bias.shape[-3]
@@ -1531,8 +1541,9 @@ def __init__(
use_layerscale = False,
layerscale_init_value = 0.,
unet_skips = False,
reinject_input = False, # seen first in DEQ paper https://arxiv.org/abs/1909.01377, but later used in a number of papers trying to achieve depthwise generalization https://arxiv.org/abs/2410.03020v1
add_value_residual = False, # resformer from Zhou et al - https://arxiv.org/abs/2410.17897v1
reinject_input = False, # seen first in DEQ paper https://arxiv.org/abs/1909.01377, but later used in a number of papers trying to achieve depthwise generalization https://arxiv.org/abs/2410.03020v1
add_value_residual = False, # resformer from Zhou et al - https://arxiv.org/abs/2410.17897v1
rel_pos_kwargs: dict = dict(),
**kwargs
):
super().__init__()
@@ -1573,14 +1584,14 @@ def __init__(

if rel_pos_bias:
assert not flash_attn, 'flash attention not compatible with t5 relative positional bias'
self.rel_pos = RelativePositionBias(scale = dim_head ** 0.5, causal = causal, heads = heads, num_buckets = rel_pos_num_buckets, max_distance = rel_pos_max_distance)
self.rel_pos = RelativePositionBias(scale = dim_head ** 0.5, causal = causal, heads = heads, num_buckets = rel_pos_num_buckets, max_distance = rel_pos_max_distance, **rel_pos_kwargs)
elif dynamic_pos_bias:
assert not flash_attn, 'flash attention not compatible with dynamic positional bias'
self.rel_pos = DynamicPositionBias(dim = dim // 4, heads = heads, log_distance = dynamic_pos_bias_log_distance, depth = dynamic_pos_bias_mlp_depth, norm = dynamic_pos_bias_norm)
self.rel_pos = DynamicPositionBias(dim = dim // 4, heads = heads, log_distance = dynamic_pos_bias_log_distance, depth = dynamic_pos_bias_mlp_depth, norm = dynamic_pos_bias_norm, **rel_pos_kwargs)
elif alibi_pos_bias:
alibi_num_heads = default(alibi_num_heads, heads)
assert alibi_num_heads <= heads, 'number of ALiBi heads must be less than the total number of heads'
self.rel_pos = AlibiPositionalBias(heads = alibi_num_heads, total_heads = heads)
self.rel_pos = AlibiPositionalBias(heads = alibi_num_heads, total_heads = heads, **rel_pos_kwargs)

assert at_most_one_of(sandwich_norm, resi_dual), 'either sandwich norm or resiDual is selected, but not both'
assert not (not pre_norm and sandwich_norm), 'sandwich norm cannot be used when not using prenorm'
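For reference, the explicit slopes argument overrides the values that _get_slopes would otherwise produce. Below is a sketch of the usual ALiBi schedule for a power-of-two head count; the repository's own helper is not shown in this diff and may differ in detail (e.g. for non-power-of-two head counts):

import math

def alibi_default_slopes(heads: int) -> list[float]:
    # standard ALiBi recipe: a geometric sequence starting at 2 ** (-8 / heads)
    # (assumes heads is a power of two)
    start = 2 ** (-2 ** -(math.log2(heads) - 3))
    return [start * start ** i for i in range(heads)]

print(alibi_default_slopes(8))  # [0.5, 0.25, 0.125, ..., 0.00390625]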
