
Alpa

Documentation | Slack

Alpa is a system for training large-scale neural networks. Scaling neural networks to hundreds of billions of parameters has enabled dramatic breakthroughs such as GPT-3, but training these large-scale neural networks requires complicated distributed training techniques. Alpa aims to automate large-scale distributed training with just a few lines of code.

The key features of Alpa include:

💻 Automatic Parallelization. Alpa automatically parallelizes users' single-device code on distributed clusters with data, operator, and pipeline parallelism.

🚀 Excellent Performance. Alpa achieves linear scaling when training models with billions of parameters on distributed clusters.

Tight Integration with the Machine Learning Ecosystem. Alpa is backed by open-source, high-performance, and production-ready libraries such as JAX, XLA, and Ray.
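To give a sense of what Alpa automates, here is a minimal sketch of just one of those strategies, data parallelism, written by hand in plain JAX with `jax.pmap`. Each device receives a shard of the batch, computes gradients locally, and averages them with `jax.lax.pmean` before a plain SGD update. The linear model and the names `loss_fn`, `data_parallel_step`, and `lr` are hypothetical and for illustration only. This is not Alpa's implementation or API.

```python
from functools import partial

import jax
import jax.numpy as jnp


def loss_fn(params, x, y):
    # Hypothetical linear model with a mean squared error loss
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)


@partial(jax.pmap, axis_name="devices")
def data_parallel_step(params, x, y):
    # Each device computes gradients on its own shard of the batch
    grads = jax.grad(loss_fn)(params, x, y)
    # All-reduce: average gradients across devices so the replicas stay identical
    grads = jax.lax.pmean(grads, axis_name="devices")
    lr = 0.1  # illustrative SGD learning rate
    return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)


n_dev = jax.local_device_count()
per_dev = 4  # examples per device
params = {"w": jnp.zeros((3, 1)), "b": jnp.zeros((1,))}
# Replicate the parameters: one copy per device along a new leading axis
replicated = jax.tree_util.tree_map(lambda p: jnp.stack([p] * n_dev), params)

# Shard a global batch of n_dev * per_dev examples: (N, ...) -> (n_dev, per_dev, ...)
x = jnp.ones((n_dev * per_dev, 3)).reshape(n_dev, per_dev, 3)
y = jnp.ones((n_dev * per_dev, 1)).reshape(n_dev, per_dev, 1)

replicated = data_parallel_step(replicated, x, y)
```

Note that every device holds a full copy of the parameters here. This is why pure data parallelism stops being enough once a model no longer fits on one device, and why operator and pipeline parallelism also come into play.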

Quick Start

Use Alpa's @alpa.parallelize decorator to scale your single-device training code to distributed clusters.

import alpa
import jax.numpy as jnp
from jax import grad

# Parallelize the training step in JAX by simply using a decorator
@alpa.parallelize
def train_step(model_state, batch):
    def loss_func(params):
        # forward() stands in for your model's forward pass
        out = model_state.forward(params, batch["x"])
        # Mean squared error against the targets
        return jnp.mean((out - batch["y"]) ** 2)

    grads = grad(loss_func)(model_state.params)
    new_model_state = model_state.apply_gradient(grads)
    return new_model_state

# The training loop now automatically runs on your designated cluster.
# create_train_state() and data_loader are user-defined placeholders.
model_state = create_train_state()
for batch in data_loader:
    model_state = train_step(model_state, batch)
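The Quick Start leaves `create_train_state`, `data_loader`, and the `forward`/`apply_gradient` methods to the user. Below is a minimal sketch of what such single-device code could look like. It assumes a hypothetical linear-regression model, a hypothetical `TrainState` NamedTuple that provides those names, synthetic data, and plain SGD. `jax.jit` stands in for `@alpa.parallelize` so the sketch runs on one machine without a cluster. The Quick Start's premise is that code of this shape is what the decorator scales out.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp
from jax import grad

LEARNING_RATE = 0.1  # illustrative SGD step size


class TrainState(NamedTuple):
    # Hypothetical state matching the params/forward/apply_gradient interface
    # used in the Quick Start; NamedTuples are JAX pytrees, so jit accepts them.
    params: dict

    def forward(self, params, x):
        # Linear model: x @ w + b
        return x @ params["w"] + params["b"]

    def apply_gradient(self, grads):
        # Plain SGD update that returns a new immutable state
        new_params = jax.tree_util.tree_map(
            lambda p, g: p - LEARNING_RATE * g, self.params, grads)
        return TrainState(params=new_params)


def create_train_state():
    return TrainState(params={"w": jnp.zeros((3, 1)), "b": jnp.zeros((1,))})


# jax.jit stands in for @alpa.parallelize so the sketch runs on a single device
@jax.jit
def train_step(model_state, batch):
    def loss_func(params):
        out = model_state.forward(params, batch["x"])
        return jnp.mean((out - batch["y"]) ** 2)

    grads = grad(loss_func)(model_state.params)
    return model_state.apply_gradient(grads)


# Synthetic, noise-free data: y = x @ [1, 2, 3]^T + 0.5
key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (64, 3))
y = x @ jnp.array([[1.0], [2.0], [3.0]]) + 0.5
data_loader = [{"x": x[i:i + 16], "y": y[i:i + 16]} for i in range(0, 64, 16)]

model_state = create_train_state()
for _ in range(50):  # 50 passes over 4 mini-batches
    for batch in data_loader:
        model_state = train_step(model_state, batch)
```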

Check out the Alpa Documentation site for installation instructions, tutorials, examples, and more.

More Information

Getting Involved

License

Alpa is licensed under the Apache-2.0 license.