diff --git a/README.md b/README.md
index e553d4e2..2b766aec 100644
--- a/README.md
+++ b/README.md
@@ -1317,7 +1317,7 @@ loss.backward()
 
 ## Miscellaneous
 
-Cross Attention
+### Cross Attention
 
 ```python
 import torch
@@ -1337,7 +1337,7 @@ model(nodes, context = encoded_neighbors, mask = node_masks, context_mask = neig
 ```
 
-Pass in continuous values
+### Continuous Embeddings
 
 ```python
 import torch
@@ -1397,6 +1397,59 @@ start_emb = torch.randn(1, 777)
 generated = model.generate(start_emb, 17) # (17, 777)
 ```
 
+### xVal - Continuous and Discrete
+
+<img src="./images/xval.png" width="400px"></img>
+
+This is promising work that resulted from a collaboration across many institutes (collectively known as Polymathic AI). They found that by offering the transformer a continuously scaled number token, it was able to generalize on arithmetic and forecasting tasks better than alternative encoding schemes.
+
+```python
+import torch
+
+from x_transformers import (
+    Decoder,
+    XValTransformerWrapper,
+    XValAutoregressiveWrapper
+)
+
+model = XValTransformerWrapper(
+    num_tokens = 4,
+    numerical_token_id = 3,
+    max_seq_len = 1024,
+    attn_layers = Decoder(
+        dim = 512,
+        depth = 12,
+        heads = 8
+    )
+)
+
+# wrap it with the xval autoregressive wrapper
+
+model = XValAutoregressiveWrapper(model)
+
+# mock data
+
+ids = torch.randint(0, 4, (1, 777))
+nums = torch.randn(1, 777)
+mask = torch.ones(1, 777).bool()
+
+# train on a lot of data above
+
+loss = model(ids, nums, mask = mask)
+loss.backward()
+
+# then generate
+
+start_ids = torch.randint(0, 4, (1, 1))
+start_nums = torch.randn(1, 1)
+
+ids_out, num_out, is_number_mask = model.generate(start_ids, start_nums, 17)
+
+# (1, 17), (1, 17), (1, 17)
+
+# discrete, continuous, mask for discrete / continuous
+```
+
 ## Citations
 
 ```bibtex
diff --git a/images/xval.png b/images/xval.png
new file mode 100644
index 00000000..090e92f5
Binary files /dev/null and b/images/xval.png differ
diff --git a/setup.py b/setup.py
index 9f811a5f..5d3b8276 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'x-transformers',
   packages = find_packages(exclude=['examples']),
-  version = '1.24.0',
+  version = '1.24.2',
  license='MIT',
  description = 'X-Transformers - Pytorch',
  author = 'Phil Wang',
diff --git a/x_transformers/xval.py b/x_transformers/xval.py
index 9574356c..4e1fd25f 100644
--- a/x_transformers/xval.py
+++ b/x_transformers/xval.py
@@ -57,7 +57,6 @@ def __init__(
         tie_embedding = False,
         max_mem_len = 0,
         num_memory_tokens = None,
-        post_emb_norm = False,
         emb_dropout = 0.,
         use_abs_pos_emb = True,
         scaled_sinu_pos_emb = False
@@ -82,7 +81,6 @@
         else:
             self.pos_emb = AbsolutePositionalEmbedding(dim, max_seq_len)
 
-        self.post_emb_norm = nn.LayerNorm(dim) if post_emb_norm else nn.Identity()
         self.emb_dropout = nn.Dropout(emb_dropout)
 
         # memory tokens
@@ -136,8 +134,6 @@ def forward(
 
         x = x + self.pos_emb(x, pos = pos)
 
-        x = self.post_emb_norm(x)
-
         # memory tokens
 
         if self.has_memory_tokens:
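For context on the README addition above: the core of the xVal scheme is a single reserved number token whose embedding is multiplied by the continuous value it stands for, so magnitude is carried by the scale of one embedding rather than by a vocabulary of digit tokens. Below is a minimal standalone sketch of that idea; the dimensions and variable names are hypothetical and this is not the internals of `XValTransformerWrapper`.

```python
import torch
import torch.nn as nn

# hypothetical sizes, for illustration only
dim = 512
num_tokens = 4
numerical_token_id = 3

token_emb = nn.Embedding(num_tokens, dim)

ids = torch.randint(0, num_tokens, (1, 8))  # discrete token ids
nums = torch.randn(1, 8)                    # continuous values aligned with ids

x = token_emb(ids)                          # (1, 8, dim)

# scale the embedding of every numerical token by its continuous value,
# leaving all other token embeddings untouched
is_num = (ids == numerical_token_id).unsqueeze(-1)
x = torch.where(is_num, x * nums.unsqueeze(-1), x)

# x would then be fed to the attention layers as usual
```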