xval document (#196)
* add xval wrapper and autoregressive wrapper

* remove post emb norm option in xval wrapper, may not make sense

* document xval
lucidrains committed Oct 19, 2023
1 parent d4cc232 commit 61e8a21
Showing 4 changed files with 56 additions and 7 deletions.
57 changes: 55 additions & 2 deletions README.md
@@ -1317,7 +1317,7 @@ loss.backward()

## Miscellaneous

Cross Attention
### Cross Attention

```python
import torch
@@ -1337,7 +1337,7 @@ model(nodes, context = encoded_neighbors, mask = node_masks, context_mask = neig

```

Pass in continuous values
### Continuous Embeddings

```python
import torch
@@ -1397,6 +1397,59 @@ start_emb = torch.randn(1, 777)
generated = model.generate(start_emb, 17) # (17, 777)
```

### xVal - Continuous and Discrete

<img src="./images/xval.png" width="400px"></img>

This is promising work that came out of a collaboration across multiple institutes (collectively known as Polymathic AI). The authors found that by representing every number with a single token whose embedding is continuously scaled by the number's value, the transformer generalized better on arithmetic and forecasting tasks than with alternative number-encoding schemes.
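
The core idea can be sketched in a few lines (a hedged illustration only, not the library's internals): every number shares one `[NUM]` token, and that token's embedding is scaled by the numeric value before it enters the attention layers.

```python
import torch
import torch.nn as nn

# hypothetical toy setup - token id 3 stands in for the shared [NUM] token

dim = 512
num_tokens = 4
numerical_token_id = 3

token_emb = nn.Embedding(num_tokens, dim)

ids  = torch.tensor([[0, 3, 1, 3]])           # discrete token ids
nums = torch.tensor([[1., 2.5, 1., -0.7]])    # continuous values (ignored at non-[NUM] positions)

embeds = token_emb(ids)                       # (1, 4, 512)

is_num = ids == numerical_token_id            # which positions are numerical
scale  = torch.where(is_num, nums, torch.ones_like(nums))

x = embeds * scale.unsqueeze(-1)              # scaled embeddings, ready for the attention layers
```

The full `XValTransformerWrapper` usage from the repository follows.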

```python
import torch

from x_transformers import (
Decoder,
XValTransformerWrapper,
XValAutoregressiveWrapper
)

model = XValTransformerWrapper(
num_tokens = 4,
numerical_token_id = 3,
max_seq_len = 1024,
attn_layers = Decoder(
dim = 512,
depth = 12,
heads = 8
)
)

# wrap it with the xval autoregressive wrapper

model = XValAutoregressiveWrapper(model)

# mock data

ids = torch.randint(0, 4, (1, 777))
nums = torch.randn(1, 777)
mask = torch.ones(1, 777).bool()

# train on a lot of data above

loss = model(ids, nums, mask = mask)
loss.backward()

# then generate

start_ids = torch.randint(0, 4, (1, 1))
start_nums = torch.randn(1, 1)

ids_out, num_out, is_number_mask = model.generate(start_ids, start_nums, 17)

# (1, 17), (1, 17), (1, 17)
# discrete tokens, continuous values, mask for which positions are numerical
```
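
Since `is_number_mask` flags which generated positions are the numerical token, the two output streams can be merged into a single readout. A small sketch under the same variable names as above (the merged tensor is illustrative, not part of the API):

```python
# positions flagged by is_number_mask take their value from num_out,
# the remaining positions keep their discrete token id from ids_out

readout = torch.where(is_number_mask, num_out, ids_out.float())  # (1, 17)
```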

## Citations

```bibtex
Binary file added images/xval.png
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
setup(
name = 'x-transformers',
packages = find_packages(exclude=['examples']),
version = '1.24.0',
version = '1.24.2',
license='MIT',
description = 'X-Transformers - Pytorch',
author = 'Phil Wang',
4 changes: 0 additions & 4 deletions x_transformers/xval.py
@@ -57,7 +57,6 @@ def __init__(
tie_embedding = False,
max_mem_len = 0,
num_memory_tokens = None,
post_emb_norm = False,
emb_dropout = 0.,
use_abs_pos_emb = True,
scaled_sinu_pos_emb = False
@@ -82,7 +81,6 @@ def __init__(
else:
self.pos_emb = AbsolutePositionalEmbedding(dim, max_seq_len)

self.post_emb_norm = nn.LayerNorm(dim) if post_emb_norm else nn.Identity()
self.emb_dropout = nn.Dropout(emb_dropout)

# memory tokens
@@ -136,8 +134,6 @@ def forward(

x = x + self.pos_emb(x, pos = pos)

x = self.post_emb_norm(x)

# memory tokens

if self.has_memory_tokens:
