pszemraj/samba-pytorch

samba-pytorch

A PyTorch implementation of Samba, the hybrid state space model from Microsoft (Ren et al., 2024).

It aims to be a simpler implementation than the original repo.

Installation

Tip: The pip install command below should install the package and all of its dependencies, but some CUDA-heavy dependencies are better installed separately. See "Installing custom kernel packages first" for details.

git clone https://github.com/pszemraj/samba-pytorch.git
cd samba-pytorch
pip install -e .

Installing custom kernel packages first

After installing torch, xformers, and flash-attn, you may want to install mamba-ssm, causal-conv1d, and fla (flash-linear-attention) from source:

pip install --upgrade pip ninja
pip install git+https://github.com/state-spaces/mamba.git --no-build-isolation
pip install git+https://github.com/Dao-AILab/causal-conv1d.git --no-build-isolation
pip install git+https://github.com/sustcsonglin/flash-linear-attention@98c176e --no-build-isolation

Then clone this repo and install it with the commands above.
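Because these kernel packages are built separately, it can help to confirm that they are visible to your Python environment before installing this repo. The sketch below is a hypothetical helper (check_kernels is not part of samba-pytorch); it assumes the usual import names for each pip package and uses the standard library's importlib.util.find_spec, which only checks that a package can be located, not that its compiled CUDA extension actually loads.

```python
import importlib.util

# Assumed import names for the packages mentioned above
# (pip name -> import name: flash-attn -> flash_attn, mamba-ssm -> mamba_ssm,
#  causal-conv1d -> causal_conv1d, flash-linear-attention -> fla).
OPTIONAL_KERNELS = ("torch", "xformers", "flash_attn", "mamba_ssm", "causal_conv1d", "fla")


def check_kernels(names=OPTIONAL_KERNELS) -> dict[str, bool]:
    """Hypothetical helper: report which top-level packages can be located.

    find_spec() does not import the package, so a True here does not
    guarantee that a CUDA extension inside it loads without errors.
    """
    return {name: importlib.util.find_spec(name) is not None for name in names}


if __name__ == "__main__":
    for name, ok in check_kernels().items():
        print(f"{name:15s} {'found' if ok else 'MISSING'}")
```

If a package shows as found but fails at import time, the build most likely targeted a different torch or CUDA version; reinstalling it with --no-build-isolation against the installed torch is usually the fix.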

Usage

A basic example of creating a random model from a named config:

from samba_pytorch import Config, GPT

cfg = Config.from_name('Samba_421M_1k_window')  # look up a predefined config by name
print(cfg)
model = GPT(cfg)  # randomly initialized weights
print(model)
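In the Samba paper, attention layers use sliding-window attention interleaved with Mamba layers, so each token attends only to a fixed number of recent positions. The 1k_window suffix in the config name presumably refers to a window of roughly 1k tokens (an assumption; check samba_pytorch/config.py for the actual value). As a minimal illustration of the idea, the pure-Python sketch below builds a causal sliding-window mask; the function name sliding_window_causal_mask is hypothetical and does not come from this package, which may implement windowing inside its attention kernels instead.

```python
def sliding_window_causal_mask(seq_len: int, window: int) -> list[list[bool]]:
    """Hypothetical helper: mask[i][j] is True when query i may attend to key j.

    Causal: keys must not lie in the future (j <= i).
    Sliding window: only the `window` most recent positions, including i
    itself, are visible (i - window < j).
    """
    if window < 1:
        raise ValueError("window must be >= 1")
    return [[i - window < j <= i for j in range(seq_len)] for i in range(seq_len)]


if __name__ == "__main__":
    # Tiny example: 5 tokens with a window of 3 ("x" = allowed, "." = masked).
    for row in sliding_window_causal_mask(5, 3):
        print("".join("x" if allowed else "." for allowed in row))
```

With 5 tokens and a window of 3, the printed rows are x...., xx..., xxx.., .xxx., and ..xxx: once a sequence is longer than the window, older tokens drop out of view, and the Mamba layers carry longer-range context.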

Training

A minimalist training script for a character-level language model on enwik8:

python train.py

Credit to nGPT-pytorch for the enwik8 dataset and the training script, which was adapted for this repo.
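In a character-level setup like this, each byte of enwik8 is one token, so the vocabulary has 256 entries. The sketch below shows one common way to sample training windows for next-token prediction: draw a random span of seq_len + 1 bytes and use it shifted by one position as inputs and targets. The function sample_batch is hypothetical; the actual train.py may use different split sizes, sampling, or tensor types.

```python
import random


def sample_batch(data: bytes, seq_len: int, batch_size: int, rng: random.Random):
    """Hypothetical helper: sample (inputs, targets) for next-byte prediction.

    Each byte is a token in [0, 255]. A window of seq_len + 1 bytes is drawn;
    inputs are its first seq_len bytes, targets are the same window shifted
    left by one position.
    """
    if len(data) < seq_len + 1:
        raise ValueError("data is shorter than seq_len + 1")
    inputs, targets = [], []
    for _ in range(batch_size):
        # Valid starts are 0 .. len(data) - seq_len - 1 so the window fits.
        start = rng.randrange(len(data) - seq_len)
        window = data[start : start + seq_len + 1]
        inputs.append(list(window[:-1]))
        targets.append(list(window[1:]))
    return inputs, targets


if __name__ == "__main__":
    # Stand-in for enwik8: any bytes object works here.
    text = b"Samba combines Mamba with sliding-window attention. " * 10
    x, y = sample_batch(text, seq_len=16, batch_size=2, rng=random.Random(0))
    print(bytes(x[0]).decode(), "->", bytes(y[0]).decode())
```

Sampling random windows rather than iterating over fixed chunks keeps the data loader trivial and is a common choice in small-scale enwik8 experiments.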

Repo structure

samba-pytorch/
├── pyproject.toml
├── README.md
├── samba_pytorch/
│   ├── __init__.py
│   ├── config.py
│   ├── modules/
│   │   ├── __init__.py
│   │   ├── fused_rotary_embedding.py
│   │   ├── gla.py
│   │   ├── mamba_simple.py
│   │   ├── multiscale_retention.py
│   │   ├── rmsnorm.py
│   │   └── rotary.py
│   ├── samba.py
│   ├── tokenizer.py
│   └── utils.py

Citations

@article{ren2024samba,
      title={Samba: Simple Hybrid State Space Models for Efficient Unlimited Context Language Modeling},
      author={Liliang Ren and Yang Liu and Yadong Lu and Yelong Shen and Chen Liang and Weizhu Chen},
      journal={arXiv preprint},
      year={2024},
      url={https://arxiv.org/abs/2406.07522}
}