forked from Dao-AILab/flash-attention. This commit contains 139 changed files with 5,699 additions and 0 deletions.
# Inspired by https://github.com/anibali/docker-pytorch/blob/master/dockerfiles/1.10.0-cuda11.3-ubuntu20.04/Dockerfile
# ARG COMPAT=0
ARG PERSONAL=0
# FROM nvidia/cuda:11.3.1-devel-ubuntu20.04 as base-0
FROM nvcr.io/nvidia/pytorch:22.11-py3 as base

ENV HOST docker
ENV LANG=C.UTF-8 LC_ALL=C.UTF-8
# https://serverfault.com/questions/683605/docker-container-time-timezone-will-not-reflect-changes
ENV TZ America/Los_Angeles
RUN ln -snf /usr/share/zoneinfo/$TZ /etc/localtime && echo $TZ > /etc/timezone

# git for installing dependencies
# tzdata to set time zone
# wget and unzip to download data
# [2021-09-09] TD: zsh, stow, subversion, fasd are for setting up my personal environment.
# [2021-12-07] TD: openmpi-bin for MPI (multi-node training)
RUN apt-get update && apt-get install -y --no-install-recommends \
    build-essential \
    cmake \
    curl \
    ca-certificates \
    sudo \
    less \
    htop \
    git \
    tzdata \
    wget \
    tmux \
    zip \
    unzip \
    zsh stow subversion fasd \
    && rm -rf /var/lib/apt/lists/*
    # openmpi-bin \

# Allow running mpirun as root
# ENV OMPI_ALLOW_RUN_AS_ROOT=1 OMPI_ALLOW_RUN_AS_ROOT_CONFIRM=1

# # Create a non-root user and switch to it
# RUN adduser --disabled-password --gecos '' --shell /bin/bash user \
#     && echo "user ALL=(ALL) NOPASSWD:ALL" > /etc/sudoers.d/90-user
# USER user

# All users can use /home/user as their home directory
ENV HOME=/home/user
RUN mkdir -p /home/user && chmod 777 /home/user
WORKDIR /home/user

# Set up personal environment
# FROM base-${COMPAT} as env-0
FROM base as env-0
FROM env-0 as env-1
# Use ONBUILD so that the dotfiles dir doesn't need to exist unless we're building a personal image
# https://stackoverflow.com/questions/31528384/conditional-copy-add-in-dockerfile
ONBUILD COPY dotfiles ./dotfiles
ONBUILD RUN cd ~/dotfiles && stow bash zsh tmux && sudo chsh -s /usr/bin/zsh $(whoami)
# nvcr pytorch image sets SHELL=/bin/bash
ONBUILD ENV SHELL=/bin/zsh

FROM env-${PERSONAL} as packages

# Disable pip cache: https://stackoverflow.com/questions/45594707/what-is-pips-no-cache-dir-good-for
ENV PIP_NO_CACHE_DIR=1

# # apex and pytorch-fast-transformers take a while to compile so we install them first
# TD [2022-04-28] apex is already installed. In case we need a newer commit:
# RUN pip install --upgrade --force-reinstall --global-option="--cpp_ext" --global-option="--cuda_ext" --global-option="--fast_multihead_attn" --global-option="--fmha" --global-option="--fast_layer_norm" --global-option="--xentropy" git+https://github.com/NVIDIA/apex.git#egg=apex
# TD [2021-10-28] pytorch-fast-transformers doesn't have a wheel compatible with CUDA 11.3 and Pytorch 1.10
# So we install from source, and change compiler flag -arch=compute_60 -> -arch=compute_70 for V100
# RUN pip install pytorch-fast-transformers==0.4.0
# RUN pip install git+git://github.com/idiap/[email protected]  # doesn't work on V100
RUN git clone https://github.com/idiap/fast-transformers \
    && sed -i 's/\["-arch=compute_60"\]/\["-arch=compute_70"\]/' fast-transformers/setup.py \
    && pip install fast-transformers/ \
    && rm -rf fast-transformers

# xgboost conflicts with deepspeed
RUN pip uninstall -y xgboost && DS_BUILD_UTILS=1 DS_BUILD_FUSED_LAMB=1 pip install deepspeed==0.7.5

# General packages that we don't care about the version
# zstandard to extract the_pile dataset
# psutil to get the number of cpu physical cores
# twine to upload package to PyPI
# ninja is broken for some reason, it returns error code 245
RUN pip uninstall -y ninja && pip install ninja
RUN pip install pytest matplotlib jupyter ipython ipdb gpustat scikit-learn spacy munch einops opt_einsum fvcore gsutil cmake pykeops zstandard psutil h5py twine \
    && python -m spacy download en_core_web_sm
# hydra
RUN pip install hydra-core==1.2.0 hydra-colorlog==1.2.0 hydra-optuna-sweeper==1.2.0 pyrootutils rich
# Core packages
RUN pip install transformers==4.24.0 datasets==2.7.1 pytorch-lightning==1.7.7 triton==2.0.0.dev20221120 wandb==0.13.5 timm==0.6.12 torchmetrics==0.10.3

# For MLPerf
RUN pip install git+https://github.com/mlcommons/[email protected]

# Install FlashAttention
RUN pip install flash-attn==0.2.2

# Install CUDA extensions for cross-entropy, fused dense, layer norm
RUN git clone https://github.com/HazyResearch/flash-attention \
    && cd flash-attention && git checkout v0.2.2 \
    && cd csrc/fused_softmax && pip install . && cd ../../ \
    && cd csrc/rotary && pip install . && cd ../../ \
    && cd csrc/xentropy && pip install . && cd ../../ \
    && cd csrc/layer_norm && pip install . && cd ../../ \
    && cd csrc/fused_dense_lib && pip install . && cd ../../ \
    && cd .. && rm -rf flash-attention
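A typical way to build and launch this image might look like the sketch below; the image tag (`flash-attn-training`) and the mounted path are illustrative choices, not something the Dockerfile itself prescribes.
```sh
# Build the non-personal variant (PERSONAL=1 would additionally copy ./dotfiles via the ONBUILD stage).
docker build --build-arg PERSONAL=0 -t flash-attn-training .
# Run with GPU access, mounting the current checkout inside the container.
docker run --gpus all -it --rm -v $(pwd):/home/user/workdir flash-attn-training
```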
---
Examples of how FlashAttention can be integrated into a model (e.g., GPT, ViT)
and trained end-to-end.
We also added optimized implementations of other layers (e.g., MLP, LayerNorm,
cross-entropy loss, rotary embedding).

Goals:
- Performance: we optimize for model speed and memory, especially on 1-node
  (e.g., with 8 A100s).
- Flexibility: we provide optimized building blocks (MLP, attention, LayerNorm),
  and the model code illustrates how these components can be put together.
  The training code also aims to be model- & task-agnostic.

Non-goals (and other resources):
- Support as many models as possible: Huggingface's
  [transformers](https://github.com/huggingface/transformers) and
  [timm](https://github.com/rwightman/pytorch-image-models/) are great for this.
- Large-scale distributed training: our codebase has been used for multi-GPU and multi-node
  training for models up to 2.7B parameters. However, if you're looking for large-scale distributed
  training techniques (e.g., pipeline parallelism, tensor parallelism),
  check out [Megatron-LM](https://github.com/NVIDIA/Megatron-LM/) and
  [DeepSpeed](https://github.com/microsoft/deepspeed).
- Inference: we currently focus on training (this might change in the future).
  If you want fast inference, take a look at
  [FasterTransformer](https://github.com/NVIDIA/FasterTransformer).
- Production: this codebase was written during several research projects to validate ideas
  on speeding up ML models.

## Model Components

The GPT model is implemented
[here](https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/models/gpt.py).
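As a quick orientation, instantiating that GPT implementation typically looks like the sketch below. The import path `flash_attn.models.gpt.GPTLMHeadModel`, the `use_flash_attn` config flag, and the output attribute are assumptions based on `flash_attn/models/gpt.py`; check the version you have installed before relying on them.

```python
import torch
from transformers import GPT2Config
# Assumed import path; see flash_attn/models/gpt.py for the exact API in your version.
from flash_attn.models.gpt import GPTLMHeadModel

# A small GPT2-style config; flags such as `use_flash_attn` (assumed name) are what
# the model code reads to switch on the optimized components listed below.
config = GPT2Config(n_embd=768, n_head=12, n_layer=12, vocab_size=50257)
config.use_flash_attn = True

model = GPTLMHeadModel(config).to(device="cuda", dtype=torch.float16)
input_ids = torch.randint(0, config.vocab_size, (2, 512), device="cuda")
out = model(input_ids)  # expected to expose logits of shape (2, 512, vocab_size)
```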
We provide the following optimized components (a combined usage sketch follows this list):

- FlashAttention: fast and memory-efficient exact attention. This makes
  attention much faster and saves a lot of activation memory. As a result we don't need
  to use any activation checkpointing.
  ```sh
  pip install flash-attn
  ```

- Fused matmul + bias (forward and backward), and fused matmul + bias + gelu
  (forward and backward), adapted from Apex's
  [FusedDense](https://github.com/NVIDIA/apex/tree/master/apex/fused_dense). We
  make it work for bfloat16. For best performance, you should use CUDA >= 11.8. CuBLAS versions before
  this don't have the best matmul + bias + gelu performance for bfloat16.
  ```sh
  cd ../csrc/fused_dense_lib && pip install .
  ```
- Optimized cross-entropy loss, adapted from Apex's
  [Xentropy](https://github.com/NVIDIA/apex/tree/master/apex/contrib/xentropy). We make it work for bfloat16 and support in-place backward to save memory.
  ```sh
  cd ../csrc/xentropy && pip install .
  ```
- Fused rotary embedding:
  ```sh
  cd ../csrc/rotary && pip install .
  ```
- Fused dropout + residual + LayerNorm, adapted from Apex's
  [FastLayerNorm](https://github.com/NVIDIA/apex/tree/master/apex/contrib/layer_norm). We add dropout and residual, and make it work for both pre-norm and post-norm architectures.
  This only supports a limited set of dimensions, see `csrc/layer_norm/ln_fwd_cuda_kernel.cu`.
  ```sh
  cd ../csrc/layer_norm && pip install .
  ```
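To make the intent of these extensions concrete, here is a minimal sketch of calling two of them directly. The import paths and signatures (`flash_attn.flash_attention.FlashAttention`, `flash_attn.losses.cross_entropy.CrossEntropyLoss`, the packed-qkv layout, and the `inplace_backward` flag) are assumptions based on the `flash_attn` package at roughly the version pinned in the Dockerfile; verify them against the installed package.

```python
import torch
# Assumed import paths; they may move between flash-attn releases.
from flash_attn.flash_attention import FlashAttention
from flash_attn.losses.cross_entropy import CrossEntropyLoss

batch, seqlen, nheads, headdim, vocab = 2, 1024, 12, 64, 50257

# FlashAttention consumes packed qkv of shape (batch, seqlen, 3, nheads, headdim) in fp16/bf16 on GPU.
qkv = torch.randn(batch, seqlen, 3, nheads, headdim, dtype=torch.float16, device="cuda")
attn = FlashAttention(attention_dropout=0.0)
out, _ = attn(qkv, causal=True)  # assumed to return (output, attention_weights)

# The optimized cross-entropy supports in-place backward to save activation memory.
logits = torch.randn(batch * seqlen, vocab, dtype=torch.float16, device="cuda", requires_grad=True)
labels = torch.randint(0, vocab, (batch * seqlen,), device="cuda")
loss = CrossEntropyLoss(inplace_backward=True)(logits, labels)
loss.backward()
```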
## Training

Feel free to use the model in your training setup. We also provide training
scripts here to train GPT2 on Openwebtext and GPT3 on The Pile as examples.

We use [Hydra](https://hydra.cc/) for configuration,
[Pytorch-Lightning](https://github.com/Lightning-AI/lightning) for training, and
[Wandb](https://wandb.ai/) for logging.

We use the template from `https://github.com/ashleve/lightning-hydra-template`.
Please read the instructions there to understand the repo structure.

### Dataset preparation

Running the training command will automatically download the datasets
(Openwebtext, Pile), tokenize with the GPT2 tokenizer, concatenate all the
tokens, then save this cache to disk. Alternatively, you can also prepare the
datasets as a separate step.

The cached datasets are saved to `${DATA_DIR}/openwebtext` and
`${DATA_DIR}/the_pile`. If `${DATA_DIR}` is not set, they will be saved to
`./data/{openwebtext,the_pile}`.
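For example, to cache the processed data somewhere other than `./data`, you can export the directory before running the preparation commands below (the path here is only an illustration):
```sh
export DATA_DIR=/path/to/datasets
```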
- Openwebtext:
  ```sh
  export PYTHONPATH=$PWD:$PYTHONPATH
  pytest -q -s tests/datamodules/test_language_modeling_hf.py -k "openwebtext"
  ```
  This takes around 1h on a 64-core CPU. The processed dataset has size 17GB.

- The Pile:
  ```sh
  export PYTHONPATH=$PWD:$PYTHONPATH
  pytest -q -s tests/datamodules/test_language_modeling_hf.py -k "pile"
  ```
  This takes around 20h on a 96-core CPU. The processed dataset has size 699GB.

### GPT2 training on Openwebtext
To train GPT2 on Openwebtext with 8 GPUs:
```sh
python run.py experiment=owt/gpt2s-flash trainer.devices=8
python run.py experiment=owt/gpt2m-flash trainer.devices=8
python run.py experiment=owt/gpt2l-flash trainer.devices=8
python run.py experiment=owt/gpt2xl-flash trainer.devices=8
```
The default parameters are set for 8 x A100 80GB.

To train with bf16 instead of fp16, add `trainer.precision=bf16`.
To adjust device batch size to fit GPU memory (the global batch size stays the
same, and gradient accumulation is calculated automatically), set `datamodule.batch_size=blah`.
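For instance, a run that switches to bf16 and lowers the per-device batch size might look like the following (the value 16 is only an illustration; pick whatever fits your GPUs):
```sh
python run.py experiment=owt/gpt2m-flash trainer.devices=8 trainer.precision=bf16 datamodule.batch_size=16
```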
### GPT3 training on The Pile
To train GPT3 on The Pile with 8 GPUs:
```sh
python run.py experiment=pile/gpt3s-flash trainer.devices=8
python run.py experiment=pile/gpt3m-flash trainer.devices=8
python run.py experiment=pile/gpt3l-flash trainer.devices=8
python run.py experiment=pile/gpt3xl-flash trainer.devices=8
```
The default parameters are set for 8 x A100 80GB.

## Requirements

Python 3.8+, Pytorch 1.12+, torchvision, einops, timm, hydra-core,
hydra-colorlog, python-dotenv, rich, pytorch-lightning, triton, flash-attn.
We recommend CUDA 11.8 (e.g., using Nvidia's Pytorch Docker image from https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch).
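Assuming a CUDA-enabled Pytorch is already installed (for example from the NGC image above), one way to pull in the remaining Python dependencies is a single pip command like the sketch below; the list is unpinned, with exact versions left to the Dockerfile:
```sh
pip install torchvision einops timm hydra-core hydra-colorlog python-dotenv rich pytorch-lightning triton flash-attn
```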
We provide a Dockerfile that lists all the required packages.
---
causality-monitor:
  _target_: src.callbacks.causality_monitor.CausalityMonitor
---
# rich_progress_bar:
#   _target_: pytorch_lightning.callbacks.RichProgressBar

rich_model_summary:
  _target_: pytorch_lightning.callbacks.RichModelSummary

model_checkpoint:
  _target_: pytorch_lightning.callbacks.ModelCheckpoint
  monitor: "val/acc"  # name of the logged metric which determines when model is improving
  mode: "max"  # can be "max" or "min"
  save_top_k: 1  # save k best models (determined by above metric)
  save_last: True  # additionally always save model from last epoch
  verbose: False
  dirpath: ${oc.env:CHECKPOINT_DIR,checkpoints}/${oc.select:name,''}
  filename: "epoch_{epoch:03d}"
  auto_insert_metric_name: False

early_stopping:
  _target_: pytorch_lightning.callbacks.EarlyStopping
  monitor: "val/acc"  # name of the logged metric which determines when model is improving
  mode: "max"  # can be "max" or "min"
  patience: 100  # how many epochs of not improving until training stops
  min_delta: 0  # minimum change in the monitored metric needed to qualify as an improvement

learning_rate_monitor:
  _target_: pytorch_lightning.callbacks.LearningRateMonitor
  logging_interval: step

speed_monitor:
  _target_: src.callbacks.speed_monitor.SpeedMonitor
  intra_step_time: True
  inter_step_time: True
  epoch_time: True

loss_scale_monitor:
  _target_: src.callbacks.loss_scale_monitor.LossScaleMonitor

params_log:
  _target_: src.callbacks.params_log.ParamsLog
  total_params_log: True
  trainable_params_log: True
  non_trainable_params_log: True

gpu_affinity:
  _target_: src.callbacks.gpu_affinity.GpuAffinity
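These callback configs are plain Hydra `_target_` specs, so each entry can be turned into a live Pytorch-Lightning callback with `hydra.utils.instantiate`. The sketch below shows the mechanism with an inline config; the keys and values are illustrative, mirroring the entries above rather than reproducing any particular file.

```python
from omegaconf import OmegaConf
from hydra.utils import instantiate

# Illustrative inline config; in the repo these entries come from the callbacks config group.
cfg = OmegaConf.create({
    "model_checkpoint": {
        "_target_": "pytorch_lightning.callbacks.ModelCheckpoint",
        "monitor": "val/acc",
        "mode": "max",
        "save_top_k": 1,
    },
    "learning_rate_monitor": {
        "_target_": "pytorch_lightning.callbacks.LearningRateMonitor",
        "logging_interval": "step",
    },
})
# Instantiate every `_target_` entry and hand the resulting list to the Trainer.
callbacks = [instantiate(c) for c in cfg.values()]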
---
ema:
  _target_: src.callbacks.ema.EMACallback
  decay: ???
  use_num_updates: False
---
flop_count:
  _target_: src.callbacks.flop_count.FlopCount
  profilers: ['fvcore']
  input_size: [3, 224, 224]
  device: null
---
defaults:
  - default.yaml

gpu_stats_monitor:
  _target_: pytorch_lightning.callbacks.GPUStatsMonitor
  # [2021-08-13] TD: I just want the intra_step_size but it'll error if I
  # don't have memory_utilization and gpu_utilization.
  # Maybe I should write a callback with just the intra_step_size.
  memory_utilization: True
  gpu_utilization: True
  intra_step_time: True
---
model_summary:
  _target_: pytorch_lightning.callbacks.RichModelSummary
---
norm_monitor:
  _target_: src.callbacks.norm_monitor.NormMonitor
---
params_log:
  _target_: src.callbacks.params_log.ParamsLog
  total_params_log: True
  trainable_params_log: True
  non_trainable_params_log: True
---
defaults:
  - default.yaml

watch_model:
  _target_: src.callbacks.wandb_callbacks.WatchModel
  log: "all"
  log_freq: 100

upload_code_as_artifact:
  _target_: src.callbacks.wandb_callbacks.UploadCodeAsArtifact
  code_dir: ${work_dir}/src

upload_ckpts_as_artifact:
  _target_: src.callbacks.wandb_callbacks.UploadCheckpointsAsArtifact
  ckpt_dir: "checkpoints/"
  upload_best_only: True

log_f1_precision_recall_heatmap:
  _target_: src.callbacks.wandb_callbacks.LogF1PrecRecHeatmap

log_confusion_matrix:
  _target_: src.callbacks.wandb_callbacks.LogConfusionMatrix

log_image_predictions:
  _target_: src.callbacks.wandb_callbacks.LogImagePredictions
  num_samples: 8