Official JAX implementation of xLSTM including fast and efficient training and inference code. 7B model available at https://huggingface.co/NX-AI/xLSTM-7b.

xLSTM-jax

Paper | Model | Docs | Citation

The official repository for the xLSTM model and training code in JAX.

About

This repository contains the code to train and evaluate xLSTM on language modeling using JAX. xLSTM is a new recurrent neural network architecture that builds on ideas from the original LSTM. Through exponential gating with appropriate normalization and stabilization techniques, together with a new matrix memory, it overcomes the limitations of the original LSTM and shows promising performance on language modeling compared to Transformers and State Space Models.
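To make the two key ideas concrete, here is a minimal JAX sketch of a single recurrent step of the mLSTM matrix memory with stabilized exponential gating, following the formulation in the xLSTM paper. The function name and shapes are illustrative, not this repository's API; the actual implementation lives in optimized Triton kernels.

```python
import jax
import jax.numpy as jnp


def mlstm_step(C, n, m, q, k, v, i_pre, f_pre):
    """One mLSTM recurrent step (illustrative sketch, not the repo's API).

    C: (d, d) matrix memory, n: (d,) normalizer state, m: () stabilizer state.
    q, k, v: (d,) query/key/value vectors.
    i_pre, f_pre: scalar pre-activations of the exponential input/forget gates.
    """
    # Stabilizer state keeps the exponential gates numerically safe:
    # subtracting the running max before exp() avoids overflow.
    m_new = jnp.maximum(f_pre + m, i_pre)
    i_gate = jnp.exp(i_pre - m_new)       # stabilized exponential input gate
    f_gate = jnp.exp(f_pre + m - m_new)   # stabilized forget gate

    # Matrix memory update: decay old memory, write the outer product v k^T.
    C_new = f_gate * C + i_gate * jnp.outer(v, k)
    n_new = f_gate * n + i_gate * k

    # Readout: retrieve with q, normalize by max(|n^T q|, 1) for stability.
    h = (C_new @ q) / jnp.maximum(jnp.abs(n_new @ q), 1.0)
    return C_new, n_new, m_new, h


d = 4
q, k, v = jax.random.normal(jax.random.PRNGKey(0), (3, d))
C, n, m = jnp.zeros((d, d)), jnp.zeros(d), jnp.array(0.0)
C, n, m, h = mlstm_step(C, n, m, q, k, v,
                        i_pre=jnp.array(0.5), f_pre=jnp.array(2.0))
print(h.shape)  # (4,)
```

In practice such a step would be scanned over the sequence (e.g. with `jax.lax.scan`) or computed in a chunkwise-parallel form, which is what the fused kernels provide.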

This codebase supports a 3D parallelization strategy and is optimized for training on large-scale distributed systems with hundreds or thousands of GPUs. We developed performant Triton kernels for xLSTM, resulting in much faster training and inference; the kernels are included in this repository as a submodule.
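A 3D parallelization strategy in JAX is typically expressed through a device mesh with three named axes. The sketch below shows the general pattern; the axis names and mesh shape here are hypothetical and may differ from this repository's actual configuration.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Arrange all available devices into a 3D mesh. The axis names below
# ("data", "fsdp", "tensor") are illustrative; real configs pick the
# mesh shape to match the cluster topology.
devices = np.array(jax.devices()).reshape(1, 1, -1)
mesh = Mesh(devices, axis_names=("data", "fsdp", "tensor"))

# Shard an activation tensor: batch dimension over the "data" axis,
# feature dimension over the "tensor" axis.
x = jnp.ones((8, 16))
sharding = NamedSharding(mesh, P("data", "tensor"))
x_sharded = jax.device_put(x, sharding)
print(x_sharded.shape)  # (8, 16) -- global shape; shards live per device
```

With such a mesh in place, `jax.jit` can propagate shardings through the whole training step, so data, fully-sharded parameters, and tensor-parallel layers compose without manual communication code.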

xLSTM-7B

We used xlstm-jax to train a 7B-parameter xLSTM model on 256 H100 GPUs. xLSTM-7B shows competitive performance on common benchmarks compared to other 7B LLMs, while achieving much better token throughput at larger sequence lengths.

Documentation

The documentation is available at https://xlstm-jax.readthedocs.io/.

Contributing

Contributions to this repository are welcome.

  • If you find bugs or have suggestions for improvements, please open an issue with a detailed description of the problem or suggestion.
  • If you want to contribute, please open a pull request with a detailed description of the changes you made.
  • More general questions and discussions can be posted in the Discussions section.

Citation

If you use this codebase, or otherwise find our work valuable, please cite the xLSTM paper and this repository:

@article{xlstm,
  title={xLSTM: Extended Long Short-Term Memory},
  author={Beck, Maximilian and P{\"o}ppel, Korbinian and Spanring, Markus and Auer, Andreas and Prudnikova, Oleksandra and Kopp, Michael and Klambauer, G{\"u}nter and Brandstetter, Johannes and Hochreiter, Sepp},
  journal={arXiv preprint arXiv:2405.04517},
  year={2024}
}
@misc{xlstm-jax,
  title={xLSTM-jax},
  author={NXAI GmbH},
  year={2024},
  url={https://github.com/NX-AI/xlstm-jax/},
}

License

This project is licensed under the NXAI Community License; please see LICENSE.
