diff --git a/csrc/layer_norm/README.md b/csrc/layer_norm/README.md index c5cd8ad43..79855c8f7 100644 --- a/csrc/layer_norm/README.md +++ b/csrc/layer_norm/README.md @@ -2,6 +2,8 @@ This CUDA extension implements fused dropout + residual + LayerNorm, based on Apex's [FastLayerNorm](https://github.com/NVIDIA/apex/tree/master/apex/contrib/layer_norm). We add dropout and residual, and make it work for both pre-norm and post-norm architecture. +This only supports a limited set of dimensions, see `csrc/layer_norm/ln_fwd_cuda_kernel.cu`. + It has only been tested on A100s. ```sh diff --git a/training/Dockerfile b/training/Dockerfile new file mode 100644 index 000000000..1dbabd5b4 --- /dev/null +++ b/training/Dockerfile @@ -0,0 +1,107 @@ +# Inspired by https://github.com/anibali/docker-pytorch/blob/master/dockerfiles/1.10.0-cuda11.3-ubuntu20.04/Dockerfile +# ARG COMPAT=0 +ARG PERSONAL=0 +# FROM nvidia/cuda:11.3.1-devel-ubuntu20.04 as base-0 +FROM nvcr.io/nvidia/pytorch:22.11-py3 as base + +ENV HOST docker +ENV LANG=C.UTF-8 LC_ALL=C.UTF-8 +# https://serverfault.com/questions/683605/docker-container-time-timezone-will-not-reflect-changes +ENV TZ America/Los_Angeles +RUN ln -snf /usr/share/zoneinfo/$TZ /etc/localtime && echo $TZ > /etc/timezone + +# git for installing dependencies +# tzdata to set time zone +# wget and unzip to download data +# [2021-09-09] TD: zsh, stow, subversion, fasd are for setting up my personal environment. +# [2021-12-07] TD: openmpi-bin for MPI (multi-node training) +RUN apt-get update && apt-get install -y --no-install-recommends \ + build-essential \ + cmake \ + curl \ + ca-certificates \ + sudo \ + less \ + htop \ + git \ + tzdata \ + wget \ + tmux \ + zip \ + unzip \ + zsh stow subversion fasd \ + && rm -rf /var/lib/apt/lists/* + # openmpi-bin \ + +# Allow running runmpi as root +# ENV OMPI_ALLOW_RUN_AS_ROOT=1 OMPI_ALLOW_RUN_AS_ROOT_CONFIRM=1 + +# # Create a non-root user and switch to it +# RUN adduser --disabled-password --gecos '' --shell /bin/bash user \ +# && echo "user ALL=(ALL) NOPASSWD:ALL" > /etc/sudoers.d/90-user +# USER user + +# All users can use /home/user as their home directory +ENV HOME=/home/user +RUN mkdir -p /home/user && chmod 777 /home/user +WORKDIR /home/user + +# Set up personal environment +# FROM base-${COMPAT} as env-0 +FROM base as env-0 +FROM env-0 as env-1 +# Use ONBUILD so that the dotfiles dir doesn't need to exist unless we're building a personal image +# https://stackoverflow.com/questions/31528384/conditional-copy-add-in-dockerfile +ONBUILD COPY dotfiles ./dotfiles +ONBUILD RUN cd ~/dotfiles && stow bash zsh tmux && sudo chsh -s /usr/bin/zsh $(whoami) +# nvcr pytorch image sets SHELL=/bin/bash +ONBUILD ENV SHELL=/bin/zsh + +FROM env-${PERSONAL} as packages + +# Disable pip cache: https://stackoverflow.com/questions/45594707/what-is-pips-no-cache-dir-good-for +ENV PIP_NO_CACHE_DIR=1 + +# # apex and pytorch-fast-transformers take a while to compile so we install them first +# TD [2022-04-28] apex is already installed. 
In case we need a newer commit: +# RUN pip install --upgrade --force-reinstall --global-option="--cpp_ext" --global-option="--cuda_ext" --global-option="--fast_multihead_attn" --global-option="--fmha" --global-option="--fast_layer_norm" --global-option="--xentropy" git+https://github.com/NVIDIA/apex.git#egg=apex +# TD [2021-10-28] pytorch-fast-transformers doesn't have a wheel compatible with CUDA 11.3 and Pytorch 1.10 +# So we install from source, and change compiler flag -arch=compute_60 -> -arch=compute_70 for V100 +# RUN pip install pytorch-fast-transformers==0.4.0 +# RUN pip install git+git://github.com/idiap/fast-transformers.git@v0.4.0 # doesn't work on V100 +RUN git clone https://github.com/idiap/fast-transformers \ + && sed -i 's/\["-arch=compute_60"\]/\["-arch=compute_70"\]/' fast-transformers/setup.py \ + && pip install fast-transformers/ \ + && rm -rf fast-transformers + +# xgboost conflicts with deepspeed +RUN pip uninstall -y xgboost && DS_BUILD_UTILS=1 DS_BUILD_FUSED_LAMB=1 pip install deepspeed==0.7.5 + +# General packages that we don't care about the version +# zstandard to extract the_pile dataset +# psutil to get the number of cpu physical cores +# twine to upload package to PyPI +# ninja is broken for some reason, it returns error code 245 +RUN pip uninstall -y ninja && pip install ninja +RUN pip install pytest matplotlib jupyter ipython ipdb gpustat scikit-learn spacy munch einops opt_einsum fvcore gsutil cmake pykeops zstandard psutil h5py twine \ + && python -m spacy download en_core_web_sm +# hydra +RUN pip install hydra-core==1.2.0 hydra-colorlog==1.2.0 hydra-optuna-sweeper==1.2.0 pyrootutils rich +# Core packages +RUN pip install transformers==4.24.0 datasets==2.7.1 pytorch-lightning==1.7.7 triton==2.0.0.dev20221120 wandb==0.13.5 timm==0.6.12 torchmetrics==0.10.3 + +# For MLPerf +RUN pip install git+https://github.com/mlcommons/logging.git@2.1.0 + +# Install FlashAttention +RUN pip install flash-attn==0.2.2 + +# Install CUDA extensions for cross-entropy, fused dense, layer norm +RUN git clone https://github.com/HazyResearch/flash-attention \ + && cd flash-attention && git checkout v0.2.2 \ + && cd csrc/fused_softmax && pip install . && cd ../../ \ + && cd csrc/rotary && pip install . && cd ../../ \ + && cd csrc/xentropy && pip install . && cd ../../ \ + && cd csrc/layer_norm && pip install . && cd ../../ \ + && cd csrc/fused_dense_lib && pip install . && cd ../../ \ + && cd .. && rm -rf flash-attention diff --git a/training/README.md b/training/README.md new file mode 100644 index 000000000..24ca990a2 --- /dev/null +++ b/training/README.md @@ -0,0 +1,133 @@ +Examples of how FlashAttention can be integrated into a model (e.g., GPT, ViT) +and trained end-to-end. +We also added optimized implementations of other layers (e.g., MLP, LayerNorm, +cross-entropy loss, rotary embedding). + +Goals: +- Performance: we optimize for model speed and memory, especially on 1-node + (e.g., with 8 A100s). +- Flexibility: we provide optimized building blocks (MLP, attention, LayerNorm), + and the model code illustrates how these components can be put together. + The training code also aims to be model- & task-agnostic. + +Non-goals (and other resources): +- Support as many models as possible: Huggingface's + [transformers](https://github.com/huggingface/transformers) and + [timm](https://github.com/rwightman/pytorch-image-models/) are great for this. 
+- Large-scale distributed training: our codebase has been used for multi-GPU and multi-node
+ training for models up to 2.7B parameters. However, if you're looking for large-scale distributed
+ training techniques (e.g., pipeline parallelism, tensor parallelism),
+ check out [Megatron-LM](https://github.com/NVIDIA/Megatron-LM/) and
+ [DeepSpeed](https://github.com/microsoft/deepspeed).
+- Inference: we currently focus on training (this might change in the future).
+ If you want fast inference, take a look at
+ [FasterTransformer](https://github.com/NVIDIA/FasterTransformer).
+- Production: this codebase was written during several research projects to validate ideas
+ on speeding up ML models.
+
+## Model Components
+
+The GPT model is implemented
+[here](https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/models/gpt.py).
+
+We provide the following optimized components:
+
+- FlashAttention: fast and memory-efficient exact attention. This makes
+attention much faster and saves a lot of activation memory. As a result we don't need
+to use any activation checkpointing.
+```sh
+pip install flash-attn
+```
+
+- Fused matmul + bias (forward and backward), and fused matmul + bias + gelu
+(forward and backward), adapted from Apex's
+[FusedDense](https://github.com/NVIDIA/apex/tree/master/apex/fused_dense). We
+make it work for bfloat16. For best performance, you should use CUDA >= 11.8. cuBLAS versions before
+this don't have the best matmul + bias + gelu performance for bfloat16.
+```sh
+cd ../csrc/fused_dense_lib && pip install .
+```
+- Optimized cross-entropy loss, adapted from Apex's
+[Xentropy](https://github.com/NVIDIA/apex/tree/master/apex/contrib/xentropy). We make it work for bfloat16 and support in-place backward to save memory.
+```sh
+cd ../csrc/xentropy && pip install .
+```
+- Fused rotary embedding:
+```sh
+cd ../csrc/rotary && pip install .
+```
+- Fused dropout + residual + LayerNorm, adapted from Apex's
+[FastLayerNorm](https://github.com/NVIDIA/apex/tree/master/apex/contrib/layer_norm). We add dropout and residual, and make it work for both pre-norm and post-norm architectures.
+This only supports a limited set of dimensions; see `csrc/layer_norm/ln_fwd_cuda_kernel.cu`.
+```sh
+cd ../csrc/layer_norm && pip install .
+```
+
+## Training
+
+Feel free to use the model in your training setup. We also provide training
+scripts here to train GPT2 on Openwebtext and GPT3 on The Pile as examples.
+
+We use [Hydra](https://hydra.cc/) for configuration,
+[Pytorch-Lightning](https://github.com/Lightning-AI/lightning) for training, and
+[Wandb](https://wandb.ai/) for logging.
+
+We use the template from `https://github.com/ashleve/lightning-hydra-template`.
+Please read the instructions there to understand the repo structure.
+
+### Dataset preparation
+
+Running the training command will automatically download the datasets
+(Openwebtext, Pile), tokenize them with the GPT2 tokenizer, concatenate all the
+tokens, and save this cache to disk. Alternatively, you can prepare the
+datasets as a separate step.
+
+The cached datasets are saved to `${DATA_DIR}/openwebtext` and
+`${DATA_DIR}/the_pile`. If `${DATA_DIR}` is not set, they will be saved to
+`./data/{openwebtext,the_pile}`.
+
+- Openwebtext:
+```sh
+export PYTHONPATH=$PWD:$PYTHONPATH
+pytest -q -s tests/datamodules/test_language_modeling_hf.py -k "openwebtext"
+```
+This takes around 1h on a 64-core CPU. The processed dataset has size 17GB.
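+If you'd rather trigger the same preparation from Python instead of pytest, here is a minimal
+sketch. It assumes `LMDataModule` follows the standard `LightningDataModule` interface
+(`prepare_data` / `setup`); the keyword arguments simply mirror
+`configs/datamodule/openwebtext.yaml`.
+```python
+from src.datamodules.language_modeling_hf import LMDataModule
+
+# Arguments mirror configs/datamodule/openwebtext.yaml; adjust cache_dir if DATA_DIR is set.
+datamodule = LMDataModule(
+    dataset_name="openwebtext",
+    dataset_config_name=None,
+    tokenizer_name="gpt2",
+    cache_dir="./data/openwebtext/cache",
+    max_length=1024,
+    val_ratio=0.0005,
+    val_split_seed=2357,
+    add_eos=True,
+    batch_size=8,
+    num_workers=32,
+)
+datamodule.prepare_data()      # download, tokenize, concatenate, and cache to disk
+datamodule.setup(stage="fit")
+train_loader = datamodule.train_dataloader()
+```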
+ +- The Pile: +```sh +export PYTHONPATH=$PWD:$PYTHONPATH +pytest -q -s tests/datamodules/test_language_modeling_hf.py -k "pile" +``` +This takes around 20h on a 96-core CPU. The processed dataset has size 699GB. + +### GPT2 training on Openwebtext +To train GPT2 on Openwebtext with 8 GPUs: +```sh +python run.py experiment=owt/gpt2s-flash trainer.devices=8 +python run.py experiment=owt/gpt2m-flash trainer.devices=8 +python run.py experiment=owt/gpt2l-flash trainer.devices=8 +python run.py experiment=owt/gpt2xl-flash trainer.devices=8 +``` +The default parameters are set for 8 x A100 80GB. + +To train with bf16 instead of fp16, add `trainer.precision=bf16`. +To adjust device batch size to fit GPU memory (the global batch size stays the +same, and gradient accumulation is calculated automatically), set `datamodule.batch_size=blah`. + +### GPT3 training on The Pile +To train GPT3 on The Pile with 8 GPUs: +```sh +python run.py experiment=pile/gpt3s-flash trainer.devices=8 +python run.py experiment=pile/gpt3m-flash trainer.devices=8 +python run.py experiment=pile/gpt3l-flash trainer.devices=8 +python run.py experiment=pile/gpt3xl-flash trainer.devices=8 +``` +The default parameters are set for 8 x A100 80GB. + +## Requirements + +Python 3.8+, Pytorch 1.12+, torchvision, einops, timm, hydra-core, +hydra-colorlog, python-dotenv, rich, pytorch-lightning, triton, flash-attn. +We recommend CUDA 11.8 (e.g., using the Nvidia's Pytorch Docker image from https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch) + +We provide a Dockerfile that lists all the required packages. diff --git a/training/configs/callbacks/causality-monitor.yaml b/training/configs/callbacks/causality-monitor.yaml new file mode 100644 index 000000000..fbac5b68e --- /dev/null +++ b/training/configs/callbacks/causality-monitor.yaml @@ -0,0 +1,2 @@ +causality-monitor: + _target_: src.callbacks.causality_monitor.CausalityMonitor \ No newline at end of file diff --git a/training/configs/callbacks/default.yaml b/training/configs/callbacks/default.yaml new file mode 100644 index 000000000..e351d9383 --- /dev/null +++ b/training/configs/callbacks/default.yaml @@ -0,0 +1,45 @@ +# rich_progress_bar: +# _target_: pytorch_lightning.callbacks.RichProgressBar + +rich_model_summary: + _target_: pytorch_lightning.callbacks.RichModelSummary + +model_checkpoint: + _target_: pytorch_lightning.callbacks.ModelCheckpoint + monitor: "val/acc" # name of the logged metric which determines when model is improving + mode: "max" # can be "max" or "min" + save_top_k: 1 # save k best models (determined by above metric) + save_last: True # additionaly always save model from last epoch + verbose: False + dirpath: ${oc.env:CHECKPOINT_DIR,checkpoints}/${oc.select:name,''} + filename: "epoch_{epoch:03d}" + auto_insert_metric_name: False + +early_stopping: + _target_: pytorch_lightning.callbacks.EarlyStopping + monitor: "val/acc" # name of the logged metric which determines when model is improving + mode: "max" # can be "max" or "min" + patience: 100 # how many epochs of not improving until training stops + min_delta: 0 # minimum change in the monitored metric needed to qualify as an improvement + +learning_rate_monitor: + _target_: pytorch_lightning.callbacks.LearningRateMonitor + logging_interval: step + +speed_monitor: + _target_: src.callbacks.speed_monitor.SpeedMonitor + intra_step_time: True + inter_step_time: True + epoch_time: True + +loss_scale_monitor: + _target_: src.callbacks.loss_scale_monitor.LossScaleMonitor + +params_log: + _target_: 
src.callbacks.params_log.ParamsLog + total_params_log: True + trainable_params_log: True + non_trainable_params_log: True + +gpu_affinity: + _target_: src.callbacks.gpu_affinity.GpuAffinity diff --git a/training/configs/callbacks/ema.yaml b/training/configs/callbacks/ema.yaml new file mode 100644 index 000000000..d5586db26 --- /dev/null +++ b/training/configs/callbacks/ema.yaml @@ -0,0 +1,4 @@ +ema: + _target_: src.callbacks.ema.EMACallback + decay: ??? + use_num_updates: False diff --git a/training/configs/callbacks/flop-count.yaml b/training/configs/callbacks/flop-count.yaml new file mode 100644 index 000000000..ee45b9158 --- /dev/null +++ b/training/configs/callbacks/flop-count.yaml @@ -0,0 +1,5 @@ +flop_count: + _target_: src.callbacks.flop_count.FlopCount + profilers: ['fvcore'] + input_size: [3, 224, 224] + device: null diff --git a/training/configs/callbacks/gpu-monitor.yaml b/training/configs/callbacks/gpu-monitor.yaml new file mode 100644 index 000000000..6780f6d1c --- /dev/null +++ b/training/configs/callbacks/gpu-monitor.yaml @@ -0,0 +1,11 @@ +defaults: + - default.yaml + +gpu_stats_monitor: + _target_: pytorch_lightning.callbacks.GPUStatsMonitor + # [2021-08-13] TD: I just want the intra_step_size but it'll error if I + # don't have memory_utilization and gpu_utilization. + # Maybe I should write a callback with just the intra_step_size. + memory_utilization: True + gpu_utilization: True + intra_step_time: True diff --git a/training/configs/callbacks/model-summary.yaml b/training/configs/callbacks/model-summary.yaml new file mode 100644 index 000000000..3dba049ad --- /dev/null +++ b/training/configs/callbacks/model-summary.yaml @@ -0,0 +1,2 @@ +model_summary: + _target_: pytorch_lightning.callbacks.RichModelSummary diff --git a/training/configs/callbacks/none.yaml b/training/configs/callbacks/none.yaml new file mode 100644 index 000000000..e69de29bb diff --git a/training/configs/callbacks/norm-monitor.yaml b/training/configs/callbacks/norm-monitor.yaml new file mode 100644 index 000000000..f4c6e2ccb --- /dev/null +++ b/training/configs/callbacks/norm-monitor.yaml @@ -0,0 +1,2 @@ +norm_monitor: + _target_: src.callbacks.norm_monitor.NormMonitor diff --git a/training/configs/callbacks/params-log.yaml b/training/configs/callbacks/params-log.yaml new file mode 100644 index 000000000..b2a49dd8d --- /dev/null +++ b/training/configs/callbacks/params-log.yaml @@ -0,0 +1,5 @@ +params_log: + _target_: src.callbacks.params_log.ParamsLog + total_params_log: True + trainable_params_log: True + non_trainable_params_log: True diff --git a/training/configs/callbacks/wandb.yaml b/training/configs/callbacks/wandb.yaml new file mode 100644 index 000000000..c6ae21d3a --- /dev/null +++ b/training/configs/callbacks/wandb.yaml @@ -0,0 +1,26 @@ +defaults: + - default.yaml + +watch_model: + _target_: src.callbacks.wandb_callbacks.WatchModel + log: "all" + log_freq: 100 + +upload_code_as_artifact: + _target_: src.callbacks.wandb_callbacks.UploadCodeAsArtifact + code_dir: ${work_dir}/src + +upload_ckpts_as_artifact: + _target_: src.callbacks.wandb_callbacks.UploadCheckpointsAsArtifact + ckpt_dir: "checkpoints/" + upload_best_only: True + +log_f1_precision_recall_heatmap: + _target_: src.callbacks.wandb_callbacks.LogF1PrecRecHeatmap + +log_confusion_matrix: + _target_: src.callbacks.wandb_callbacks.LogConfusionMatrix + +log_image_predictions: + _target_: src.callbacks.wandb_callbacks.LogImagePredictions + num_samples: 8 diff --git a/training/configs/config.yaml b/training/configs/config.yaml new file mode 
100644 index 000000000..f7c8f510f --- /dev/null +++ b/training/configs/config.yaml @@ -0,0 +1,50 @@ +# @package _global_ + +# specify here default training configuration +defaults: + - _self_ + - trainer: default + - optimizer: adamw + - scheduler: null + - task: sequence-model + - model: null + - datamodule: null + - callbacks: default # set this to null if you don't want to use callbacks + - metrics: null + - logger: null # set logger here or use command line (e.g. `python run.py logger=wandb`) + + - mode: default + + - experiment: null + - hparams_search: null + + # enable color logging + - override hydra/hydra_logging: colorlog + - override hydra/job_logging: colorlog + +# path to original working directory +# hydra hijacks working directory by changing it to the current log directory, +# so it's useful to have this path as a special variable +# https://hydra.cc/docs/next/tutorials/basic/running_your_app/working_directory +work_dir: ${hydra:runtime.cwd} + +# path to folder with data +data_dir: ${work_dir}/data/ + +# pretty print config at the start of the run using Rich library +print_config: True + +# disable python warnings if they annoy you +ignore_warnings: True + +# check performance on test set, using the best model achieved during training +# lightning chooses best model based on metric specified in checkpoint callback +test_after_training: True + +resume: False + +# seed for random number generators in pytorch, numpy and python.random +seed: null + +# name of the run, accessed by loggers +name: null diff --git a/training/configs/datamodule/openwebtext.yaml b/training/configs/datamodule/openwebtext.yaml new file mode 100644 index 000000000..327decbd8 --- /dev/null +++ b/training/configs/datamodule/openwebtext.yaml @@ -0,0 +1,15 @@ +_target_: src.datamodules.language_modeling_hf.LMDataModule +dataset_name: openwebtext +dataset_config_name: null +tokenizer_name: gpt2 +cache_dir: ${oc.env:DATA_DIR,${data_dir}}/openwebtext/cache +max_length: 1024 +val_ratio: 0.0005 +val_split_seed: 2357 +add_eos: True +batch_size: 8 # per GPU +batch_size_eval: ${eval:${.batch_size} * 2} +num_workers: 32 # For preprocessing only +shuffle: True +pin_memory: True +__train_len: ${div_up:9035582198, ${.max_length}} diff --git a/training/configs/datamodule/thepile.yaml b/training/configs/datamodule/thepile.yaml new file mode 100644 index 000000000..d0f93535c --- /dev/null +++ b/training/configs/datamodule/thepile.yaml @@ -0,0 +1,14 @@ +_target_: src.datamodules.language_modeling_hf.LMDataModule +dataset_name: the_pile +dataset_config_name: null +tokenizer_name: gpt2 +cache_dir: ${oc.env:DATA_DIR,${data_dir}}/the_pile/cache +max_length: 2048 +add_eos: True +batch_size: 4 # per GPU +batch_size_eval: ${eval:${.batch_size} * 2} +num_workers: 64 # For preprocessing only +use_shmem: False +shuffle: True +pin_memory: True +__train_len: ${div_up:374337375694, ${.max_length}} diff --git a/training/configs/experiment/owt/base.yaml b/training/configs/experiment/owt/base.yaml new file mode 100644 index 000000000..988e186d2 --- /dev/null +++ b/training/configs/experiment/owt/base.yaml @@ -0,0 +1,82 @@ +# @package _global_ +defaults: + - override /trainer: default # choose trainer from 'configs/trainer/' + - override /model: null + - override /datamodule: openwebtext + # FusedAdam from apex speeds up the optimizer step a bit, for GPT2-small time + # per global step (i.e. batch size 512) on 8 A100s goes from 376ms to 368ms. + # For GPT2-medium time per global goes from 997ms to 972ms. 
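+  # adamw-apex instantiates apex.optimizers.FusedAdam with adam_w_mode=True (see configs/optimizer/adamw-apex.yaml).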
+ - override /optimizer: adamw-apex + - override /scheduler: linear-warmup + - override /callbacks: [default, norm-monitor] + - override /metrics: [perplexity, num-tokens] + - override /logger: wandb + +# all parameters below will be merged with parameters from default configurations set above +# this allows you to overwrite only specified parameters + +task: + _target_: src.tasks.seq.SequenceLMModel + +seed: 1111 + +trainer: + accelerator: gpu + devices: 8 + num_nodes: 1 + accumulate_grad_batches: ${div_up:${train.global_batch_size}, ${eval:${trainer.devices} * ${datamodule.batch_size} * ${trainer.num_nodes}}} + max_steps: 400000 + val_check_interval: ${eval:1000 * ${.accumulate_grad_batches}} + check_val_every_n_epoch: null # We don't care about epoch boundary + precision: 16 + gradient_clip_val: 1.0 + strategy: null + +datamodule: + batch_size: 16 # Per GPU + batch_size_eval: ${.batch_size} # Fused dense only support batch size at most 64k + max_length: 1024 + fault_tolerant: True + ddp: ${eval:"${trainer.devices} > 1"} + +train: + gpu_mem: ${eval:"round(float(__import__('subprocess').check_output('nvidia-smi -i 0 --query-gpu=memory.total --format=csv,noheader,nounits', shell=True).strip().decode()) / 1000)"} + global_batch_size: 512 + optimizer: + lr: 6e-4 + weight_decay: 0.1 + optimizer_param_grouping: + bias_weight_decay: False + normalization_weight_decay: False + scheduler: + num_warmup_steps: ${eval:0.01 * ${trainer.max_steps}} + num_training_steps: ${trainer.max_steps} + loss_fn: + # This is faster and uses less memory than torch.nn.CrossEntropyLoss. + # It's also more numerically stable if we're using DeepSpeed 16 bits. + _target_: src.losses.cross_entropy_apex.CrossEntropyLossApex + inplace_backward: True # to save memory + +eval: + log_on_step: True # 1 training epoch takes too long, we want to see metrics per train step + +callbacks: + model_checkpoint: + monitor: val/loss + mode: min + save_top_k: 3 + save_last: True + every_n_train_steps: 1000 + dirpath: ${work_dir}/checkpoints/${oc.select:name,''} + filename: step_{step} + auto_insert_metric_name: False + model_checkpoint_progress: + _target_: src.callbacks.model_checkpoint.ModelCheckpointMine + fault_tolerant: True + every_n_train_steps: 50000 + save_last: False + save_top_k: -1 # Save all the checkpoints + dirpath: ${..model_checkpoint.dirpath} + filename: progress_step_{step} + auto_insert_metric_name: False + early_stopping: null diff --git a/training/configs/experiment/owt/gpt2l-flash.yaml b/training/configs/experiment/owt/gpt2l-flash.yaml new file mode 100644 index 000000000..5d81f10ce --- /dev/null +++ b/training/configs/experiment/owt/gpt2l-flash.yaml @@ -0,0 +1,41 @@ +# @package _global_ +defaults: + - /experiment/owt/gpt2m-flash.yaml + - override /model/gpt2model: gpt2-large + # TD [2022-08-03] Surprisingly it's faster to use the ZeRO optimizer than just AdamW. + # Still, fairscale is even faster and uses less memory. + # I think it's because Pytorch is using ZeRO stage 1 and fairscale is using ZeRO stage 2? + # However, fairscale has issues with saving checkpoint (either OOM or very + # slow since it goes through the CPU?). Fairscale says Pytorch ZeRO is the + # upstream version of OSS + # https://github.com/facebookresearch/fairscale/issues/937 + # Pytorch ZeRO as also very slow for saving checkpoints due to + # consolidate_state_dict(), but I've fixed it to save separate checkpoint per GPU. + - override /optimizer: adamw-zero + + # FusedAdam doesn't seem to speed things up here, time per global step + # (i.e. 
batch size 512) on 8 A100s is around 2056ms for both AdamW and FusedAdam. + # This could be because each GPU is only doing the optimizer step for 1 / + # world_size of the parameters. + # Maybe the bottleneck here is the NCCL call to exchange parameters (ZeRO). + # - override /optimizer: adamw-apex-zero + +# Can enable mlp_chekcpoint_lvl to fit batch_size 16 on A100 40GB +# model: +# config: +# # mlp_checkpoint_lvl: ${eval:"[1] * 18 + [2] * 18"} +# mlp_checkpoint_lvl: 1 + +datamodule: + # batch_size: 16 + batch_size: ${eval:"4 if ${train.gpu_mem} < 24 else (8 if ${train.gpu_mem} < 40 else 16)"} + +trainer: + # strategy: null + # strategy: ${eval:"None if ${trainer.devices} == 1 else 'ddp_sharded'"} + strategy: + _target_: src.utils.ddp_zero1.DDPStrategyZero1 + find_unused_parameters: False + gradient_as_bucket_view: True + # TD [2022-08-03] Deepspeed makes the ppl curve go wild + # strategy: deepspeed_stage_1 diff --git a/training/configs/experiment/owt/gpt2l-hf.yaml b/training/configs/experiment/owt/gpt2l-hf.yaml new file mode 100644 index 000000000..b8a292492 --- /dev/null +++ b/training/configs/experiment/owt/gpt2l-hf.yaml @@ -0,0 +1,14 @@ +# @package _global_ +defaults: + - /experiment/owt/gpt2m-hf.yaml + - override /model/gpt2model: gpt2-large + - override /optimizer: adamw-zero + +datamodule: + batch_size: 2 + +trainer: + strategy: + _target_: src.utils.ddp_zero1.DDPStrategyZero1 + find_unused_parameters: False + gradient_as_bucket_view: True diff --git a/training/configs/experiment/owt/gpt2l.yaml b/training/configs/experiment/owt/gpt2l.yaml new file mode 100644 index 000000000..83d3ccf25 --- /dev/null +++ b/training/configs/experiment/owt/gpt2l.yaml @@ -0,0 +1,14 @@ +# @package _global_ +defaults: + - /experiment/owt/gpt2m.yaml + - override /model/gpt2model: gpt2-large + - override /optimizer: adamw-zero + +datamodule: + batch_size: 4 # Per GPU + +trainer: + strategy: + _target_: src.utils.ddp_zero1.DDPStrategyZero1 + find_unused_parameters: False + gradient_as_bucket_view: True diff --git a/training/configs/experiment/owt/gpt2m-flash.yaml b/training/configs/experiment/owt/gpt2m-flash.yaml new file mode 100644 index 000000000..f3d93d917 --- /dev/null +++ b/training/configs/experiment/owt/gpt2m-flash.yaml @@ -0,0 +1,17 @@ +# @package _global_ +defaults: + - /experiment/owt/gpt2s-flash.yaml + - override /model/gpt2model: gpt2-medium + +# Can enable mlp_checkpoint_lvl to fit batch_size 32 to A100 40GB +model: + config: + mlp_checkpoint_lvl: 1 + +datamodule: + # batch_size: 32 + batch_size: ${eval:"8 if ${train.gpu_mem} < 24 else (16 if ${train.gpu_mem} < 40 else 32)"} + +train: + optimizer: + lr: 1.5e-4 diff --git a/training/configs/experiment/owt/gpt2m-hf.yaml b/training/configs/experiment/owt/gpt2m-hf.yaml new file mode 100644 index 000000000..1e570e21b --- /dev/null +++ b/training/configs/experiment/owt/gpt2m-hf.yaml @@ -0,0 +1,11 @@ +# @package _global_ +defaults: + - /experiment/owt/gpt2s-hf.yaml + - override /model/gpt2model: gpt2-medium + +datamodule: + batch_size: 4 + +train: + optimizer: + lr: 1.5e-4 diff --git a/training/configs/experiment/owt/gpt2m.yaml b/training/configs/experiment/owt/gpt2m.yaml new file mode 100644 index 000000000..4cc99335b --- /dev/null +++ b/training/configs/experiment/owt/gpt2m.yaml @@ -0,0 +1,11 @@ +# @package _global_ +defaults: + - /experiment/owt/gpt2s.yaml + - override /model/gpt2model: gpt2-medium + +datamodule: + batch_size: 8 # Per GPU + +train: + optimizer: + lr: 1.5e-4 diff --git a/training/configs/experiment/owt/gpt2s-flash.yaml 
b/training/configs/experiment/owt/gpt2s-flash.yaml new file mode 100644 index 000000000..0bcd4021f --- /dev/null +++ b/training/configs/experiment/owt/gpt2s-flash.yaml @@ -0,0 +1,18 @@ +# @package _global_ +defaults: + - /experiment/owt/base.yaml + - override /model: gpt2 + - override /model/gpt2model: gpt2-small + +model: + config: + # n_positions is already set to ${datamodule.max_length} + use_flash_attn: True + fused_bias_fc: True + fused_dense_gelu_dense: True + fused_dropout_add_ln: True + pad_vocab_size_multiple: 8 + +datamodule: + # batch_size: 64 + batch_size: ${eval:"16 if ${train.gpu_mem} < 24 else (32 if ${train.gpu_mem} < 40 else 64)"} diff --git a/training/configs/experiment/owt/gpt2s-hf.yaml b/training/configs/experiment/owt/gpt2s-hf.yaml new file mode 100644 index 000000000..9b0f65cab --- /dev/null +++ b/training/configs/experiment/owt/gpt2s-hf.yaml @@ -0,0 +1,23 @@ +# @package _global_ +defaults: + - /experiment/owt/base.yaml + - override /model: gpt2-hf + - override /model/gpt2model: gpt2-small + - override /callbacks: [default, norm-monitor, flop-count] + +datamodule: + batch_size: 8 + +train: + # Use the standard torch.nn.CrossEntropyLoss + loss_fn: null + +callbacks: + flop_count: + input_size: + - ${datamodule.max_length} + input_dtype: + # It's surprisingly hard to get hydra to return torch.long since it's not a callable + _target_: torch.__getattribute__ + _args_: + - long diff --git a/training/configs/experiment/owt/gpt2s.yaml b/training/configs/experiment/owt/gpt2s.yaml new file mode 100644 index 000000000..c9faf60b0 --- /dev/null +++ b/training/configs/experiment/owt/gpt2s.yaml @@ -0,0 +1,8 @@ +# @package _global_ +defaults: + - /experiment/owt/base.yaml + - override /model: gpt2 + - override /model/gpt2model: gpt2-small + +datamodule: + batch_size: ${eval:"4 if ${train.gpu_mem} < 24 else (8 if ${train.gpu_mem} < 40 else 16)"} diff --git a/training/configs/experiment/owt/gpt2xl-flash.yaml b/training/configs/experiment/owt/gpt2xl-flash.yaml new file mode 100644 index 000000000..b9d6ff0cc --- /dev/null +++ b/training/configs/experiment/owt/gpt2xl-flash.yaml @@ -0,0 +1,21 @@ +# @package _global_ +defaults: + - /experiment/owt/gpt2l-flash.yaml + - override /model/gpt2model: gpt2-xlarge + +# Can enable mlp_checkpoint_lvl to fit to A100 40GB +# model: +# config: +# # mlp_checkpoint_lvl: ${eval:"[1] * 18 + [2] * 18"} +# mlp_checkpoint_lvl: 1 + +datamodule: + batch_size: ${eval:"2 if ${train.gpu_mem} < 24 else (4 if ${train.gpu_mem} < 40 else 8)"} + # With adamw-zero optimizer: + # checkpoint_lvl=1, batch size = 4: mem 37GB, 4650ms / batch of 512 (285ms * 15 + 375ms * 1) + # checkpoint_lvl=1, batch size = 8: mem 46GB, 4330ms / batch of 512 (530ms * 7 + 620ms * 1) + # checkpoint_lvl=2, batch size = 8: mem 41GB, 4570ms / batch of 512 (560ms * 7 + 650ms * 1) + # With adamw-apex-distributed optimizer: + # checkpoint_lvl=1, batch size = 8: mem 41.5GB, 4500ms / batch of 512 (550ms * 7 + 650ms * 1) + # checkpoint_lvl=1 for 24 layers and checkpoint_lvl=2 for 24 layers, + # batch size = 8: mem 39GB, 4640ms / batch of 512 (565ms * 7 + 675ms * 1) diff --git a/training/configs/experiment/owt/gpt2xl.yaml b/training/configs/experiment/owt/gpt2xl.yaml new file mode 100644 index 000000000..a43db2f85 --- /dev/null +++ b/training/configs/experiment/owt/gpt2xl.yaml @@ -0,0 +1,14 @@ +# @package _global_ +defaults: + - /experiment/owt/gpt2m.yaml + - override /model/gpt2model: gpt2-xlarge + - override /optimizer: adamw-zero + +datamodule: + batch_size: 2 # Per GPU + +trainer: + strategy: + 
_target_: src.utils.ddp_zero1.DDPStrategyZero1 + find_unused_parameters: False + gradient_as_bucket_view: True diff --git a/training/configs/experiment/pile/base.yaml b/training/configs/experiment/pile/base.yaml new file mode 100644 index 000000000..ce46efd36 --- /dev/null +++ b/training/configs/experiment/pile/base.yaml @@ -0,0 +1,83 @@ +# @package _global_ +defaults: + - override /trainer: default # choose trainer from 'configs/trainer/' + - override /model: null + - override /datamodule: thepile + - override /optimizer: adamw-apex # slight speedup (1-2%) over Pytorch AdamW + - override /scheduler: cosine-warmup-timm + - override /callbacks: [default, norm-monitor] + - override /metrics: [perplexity, num-tokens] + - override /logger: wandb + +# all parameters below will be merged with parameters from default configurations set above +# this allows you to overwrite only specified parameters + +task: + _target_: src.tasks.seq.SequenceLMModel + +seed: 1111 + +trainer: + accelerator: gpu + devices: 8 + num_nodes: 1 + accumulate_grad_batches: ${div_up:${train.global_batch_size}, ${eval:${trainer.devices} * ${datamodule.batch_size} * ${trainer.num_nodes}}} + max_steps: 800000 + val_check_interval: ${eval:2000 * ${.accumulate_grad_batches}} + check_val_every_n_epoch: null # We don't care about epoch boundary + precision: bf16 + gradient_clip_val: 1.0 + strategy: null + +datamodule: + batch_size: 16 # Per GPU + batch_size_eval: ${.batch_size} # Fused dense only support batch size at most 64k + max_length: 2048 + fault_tolerant: True + ddp: ${eval:"${trainer.devices} > 1"} + +train: + gpu_mem: ${eval:"round(float(__import__('subprocess').check_output('nvidia-smi -i 0 --query-gpu=memory.total --format=csv,noheader,nounits', shell=True).strip().decode()) / 1000)"} + global_batch_size: 256 + optimizer: + lr: 6e-4 + weight_decay: 0.1 + optimizer_param_grouping: + bias_weight_decay: False + normalization_weight_decay: False + scheduler: + t_in_epochs: False + t_initial: 600000 + warmup_lr_init: 1e-6 + warmup_t: ${eval:0.01 * ${trainer.max_steps}} + lr_min: ${eval:0.1 * ${train.optimizer.lr}} + loss_fn: + # This is faster and uses less memory than torch.nn.CrossEntropyLoss. + # It's also more numerically stable if we're using DeepSpeed 16 bits. 
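+    # In-place backward writes the gradient w.r.t. the logits back into the logits tensor,
+    # avoiding an extra logits-sized allocation in the backward pass.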
+ _target_: src.losses.cross_entropy_apex.CrossEntropyLossApex + inplace_backward: True # to save memory + +eval: + log_on_step: True # 1 training epoch takes too long, we want to see metrics per train step + +callbacks: + model_checkpoint: + monitor: val/loss + mode: min + save_top_k: 3 + save_last: True + every_n_train_steps: 1000 + dirpath: ${work_dir}/checkpoints/${oc.select:name,''} + filename: step_{step} + auto_insert_metric_name: False + model_checkpoint_progress: + _target_: src.callbacks.model_checkpoint.ModelCheckpointMine + # fault_tolerant: True # The .pl_auto_save.ckpt doesn't get saved by all workers + every_n_train_steps: 50000 + save_last: False + save_top_k: -1 # Save all the checkpoints + dirpath: ${..model_checkpoint.dirpath} + filename: progress_step_{step} + auto_insert_metric_name: False + early_stopping: null + diff --git a/training/configs/experiment/pile/gpt3-2.7B-flash-8k.yaml b/training/configs/experiment/pile/gpt3-2.7B-flash-8k.yaml new file mode 100644 index 000000000..cb0ab0db9 --- /dev/null +++ b/training/configs/experiment/pile/gpt3-2.7B-flash-8k.yaml @@ -0,0 +1,18 @@ +# @package _global_ +defaults: + - /experiment/pile/gpt2xl-flash-8k.yaml + +model: + config: + n_embd: 2560 + n_head: 32 + n_layer: 32 + initializer_range: ${eval:"(2 / (${.n_embd} * 5)) ** 0.5"} + mlp_checkpoint_lvl: 0 + +datamodule: + batch_size: ${eval:"1 if ${train.gpu_mem} < 40 else (2 if ${train.gpu_mem} < 80 else 4)"} + +train: + optimizer: + lr: 1.6e-4 diff --git a/training/configs/experiment/pile/gpt3-2.7B-flash-hdim128-rotary-8k.yaml b/training/configs/experiment/pile/gpt3-2.7B-flash-hdim128-rotary-8k.yaml new file mode 100644 index 000000000..9dfaa827a --- /dev/null +++ b/training/configs/experiment/pile/gpt3-2.7B-flash-hdim128-rotary-8k.yaml @@ -0,0 +1,18 @@ +# @package _global_ +defaults: + - /experiment/pile/gpt2xl-flash-rotary-8k.yaml + +model: + config: + n_embd: 2560 + n_head: 20 + n_layer: 32 + initializer_range: ${eval:"(2 / (${.n_embd} * 5)) ** 0.5"} + mlp_checkpoint_lvl: 0 + +datamodule: + batch_size: ${eval:"1 if ${train.gpu_mem} < 24 else (2 if ${train.gpu_mem} < 40 else (4 if ${train.gpu} < 80 else 8))"} + +train: + optimizer: + lr: 1.6e-4 diff --git a/training/configs/experiment/pile/gpt3-2.7B-flash-hdim128-rotary.yaml b/training/configs/experiment/pile/gpt3-2.7B-flash-hdim128-rotary.yaml new file mode 100644 index 000000000..aab9c970e --- /dev/null +++ b/training/configs/experiment/pile/gpt3-2.7B-flash-hdim128-rotary.yaml @@ -0,0 +1,18 @@ +# @package _global_ +defaults: + - /experiment/pile/gpt2xl-flash-rotary.yaml + +model: + config: + n_embd: 2560 + n_head: 20 + n_layer: 32 + initializer_range: ${eval:"(2 / (${.n_embd} * 5)) ** 0.5"} + mlp_checkpoint_lvl: 0 + +datamodule: + batch_size: ${eval:"4 if ${train.gpu_mem} < 24 else (8 if ${train.gpu_mem} < 40 else (16 if ${train.gpu} < 80 else 32))"} + +train: + optimizer: + lr: 1.6e-4 diff --git a/training/configs/experiment/pile/gpt3-2.7B-flash-rotary-8k.yaml b/training/configs/experiment/pile/gpt3-2.7B-flash-rotary-8k.yaml new file mode 100644 index 000000000..6e56a2b06 --- /dev/null +++ b/training/configs/experiment/pile/gpt3-2.7B-flash-rotary-8k.yaml @@ -0,0 +1,18 @@ +# @package _global_ +defaults: + - /experiment/pile/gpt2xl-flash-rotary-8k.yaml + +model: + config: + n_embd: 2560 + n_head: 32 + n_layer: 32 + initializer_range: ${eval:"(2 / (${.n_embd} * 5)) ** 0.5"} + mlp_checkpoint_lvl: 0 + +datamodule: + batch_size: ${eval:"1 if ${train.gpu_mem} < 24 else (2 if ${train.gpu_mem} < 40 else (4 if ${train.gpu} < 80 
else 8))"} + +train: + optimizer: + lr: 1.6e-4 diff --git a/training/configs/experiment/pile/gpt3-2.7B-flash-rotary.yaml b/training/configs/experiment/pile/gpt3-2.7B-flash-rotary.yaml new file mode 100644 index 000000000..60853b850 --- /dev/null +++ b/training/configs/experiment/pile/gpt3-2.7B-flash-rotary.yaml @@ -0,0 +1,18 @@ +# @package _global_ +defaults: + - /experiment/pile/gpt2xl-flash-rotary.yaml + +model: + config: + n_embd: 2560 + n_head: 32 + n_layer: 32 + initializer_range: ${eval:"(2 / (${.n_embd} * 5)) ** 0.5"} + mlp_checkpoint_lvl: 0 + +datamodule: + batch_size: ${eval:"4 if ${train.gpu_mem} < 24 else (8 if ${train.gpu_mem} < 40 else (16 if ${train.gpu} < 80 else 32))"} + +train: + optimizer: + lr: 1.6e-4 diff --git a/training/configs/experiment/pile/gpt3l-flash-8k.yaml b/training/configs/experiment/pile/gpt3l-flash-8k.yaml new file mode 100644 index 000000000..ccbbebfd0 --- /dev/null +++ b/training/configs/experiment/pile/gpt3l-flash-8k.yaml @@ -0,0 +1,10 @@ +# @package _global_ +defaults: + - /experiment/pile/gpt3l-flash.yaml + +datamodule: + max_length: 8192 + batch_size: ${eval:"1 if ${train.gpu_mem} < 40 else (2 if ${train.gpu_mem} < 80 else 4)"} + +train: + global_batch_size: 64 diff --git a/training/configs/experiment/pile/gpt3l-flash-rotary-30B.yaml b/training/configs/experiment/pile/gpt3l-flash-rotary-30B.yaml new file mode 100644 index 000000000..74c6bb9ce --- /dev/null +++ b/training/configs/experiment/pile/gpt3l-flash-rotary-30B.yaml @@ -0,0 +1,10 @@ +# @package _global_ +defaults: + - /experiment/pile/gpt3l-flash-rotary.yaml + +trainer: + max_steps: 60000 + +train: + scheduler: + t_initial: ${trainer.max_steps} diff --git a/training/configs/experiment/pile/gpt3l-flash-rotary-8k.yaml b/training/configs/experiment/pile/gpt3l-flash-rotary-8k.yaml new file mode 100644 index 000000000..2b3ba3145 --- /dev/null +++ b/training/configs/experiment/pile/gpt3l-flash-rotary-8k.yaml @@ -0,0 +1,8 @@ +# @package _global_ +defaults: + - /experiment/pile/gpt3l-flash-8k.yaml + +model: + config: + max_position_embeddings: 0 # Disable absolute position embedding + rotary_emb_fraction: 0.5 diff --git a/training/configs/experiment/pile/gpt3l-flash-rotary.yaml b/training/configs/experiment/pile/gpt3l-flash-rotary.yaml new file mode 100644 index 000000000..f28563202 --- /dev/null +++ b/training/configs/experiment/pile/gpt3l-flash-rotary.yaml @@ -0,0 +1,8 @@ +# @package _global_ +defaults: + - /experiment/pile/gpt3l-flash.yaml + +model: + config: + max_position_embeddings: 0 # Disable absolute position embedding + rotary_emb_fraction: 0.5 diff --git a/training/configs/experiment/pile/gpt3l-flash.yaml b/training/configs/experiment/pile/gpt3l-flash.yaml new file mode 100644 index 000000000..eebc19a85 --- /dev/null +++ b/training/configs/experiment/pile/gpt3l-flash.yaml @@ -0,0 +1,24 @@ +# @package _global_ +defaults: + - /experiment/pile/gpt3s-flash.yaml + - override /optimizer: adamw-zero + +model: + config: + n_embd: 1536 + n_head: 16 + n_layer: 24 + # mlp_checkpoint_lvl: 1 # To fit batch_size 8 + +datamodule: + batch_size: ${eval:"2 if ${train.gpu_mem} < 24 else (4 if ${train.gpu_mem} < 40 else (8 if ${train.gpu_mem} < 80 else 16))"} + +train: + optimizer: + lr: 2.5e-4 + +trainer: + strategy: + _target_: src.utils.ddp_zero1.DDPStrategyZero1 + find_unused_parameters: False + gradient_as_bucket_view: True diff --git a/training/configs/experiment/pile/gpt3m-flash-8k.yaml b/training/configs/experiment/pile/gpt3m-flash-8k.yaml new file mode 100644 index 000000000..d75e6d3a3 --- /dev/null +++ 
b/training/configs/experiment/pile/gpt3m-flash-8k.yaml @@ -0,0 +1,10 @@ +# @package _global_ +defaults: + - /experiment/pile/gpt3m-flash.yaml + +datamodule: + max_length: 8192 + batch_size: ${eval:"2 if ${train.gpu_mem} < 24 else (4 if ${train.gpu_mem} < 40 else 8)"} + +train: + global_batch_size: 64 diff --git a/training/configs/experiment/pile/gpt3m-flash-rotary-30B.yaml b/training/configs/experiment/pile/gpt3m-flash-rotary-30B.yaml new file mode 100644 index 000000000..04630753e --- /dev/null +++ b/training/configs/experiment/pile/gpt3m-flash-rotary-30B.yaml @@ -0,0 +1,10 @@ +# @package _global_ +defaults: + - /experiment/pile/gpt3m-flash-rotary.yaml + +trainer: + max_steps: 60000 + +train: + scheduler: + t_initial: ${trainer.max_steps} diff --git a/training/configs/experiment/pile/gpt3m-flash-rotary-8k.yaml b/training/configs/experiment/pile/gpt3m-flash-rotary-8k.yaml new file mode 100644 index 000000000..f217ac521 --- /dev/null +++ b/training/configs/experiment/pile/gpt3m-flash-rotary-8k.yaml @@ -0,0 +1,8 @@ +# @package _global_ +defaults: + - /experiment/pile/gpt3m-flash-8k.yaml + +model: + config: + max_position_embeddings: 0 # Disable absolute position embedding + rotary_emb_fraction: 0.5 diff --git a/training/configs/experiment/pile/gpt3m-flash-rotary.yaml b/training/configs/experiment/pile/gpt3m-flash-rotary.yaml new file mode 100644 index 000000000..adb0cb614 --- /dev/null +++ b/training/configs/experiment/pile/gpt3m-flash-rotary.yaml @@ -0,0 +1,8 @@ +# @package _global_ +defaults: + - /experiment/pile/gpt3m-flash.yaml + +model: + config: + max_position_embeddings: 0 # Disable absolute position embedding + rotary_emb_fraction: 0.5 diff --git a/training/configs/experiment/pile/gpt3m-flash.yaml b/training/configs/experiment/pile/gpt3m-flash.yaml new file mode 100644 index 000000000..b1bfe5e03 --- /dev/null +++ b/training/configs/experiment/pile/gpt3m-flash.yaml @@ -0,0 +1,16 @@ +# @package _global_ +defaults: + - /experiment/pile/gpt3s-flash.yaml + - override /model/gpt2model: gpt2-medium + +# Can enable mlp_checkpoint_lvl to fit batch_size 16 to A100 40GB +# model: +# config: +# mlp_checkpoint_lvl: 1 + +datamodule: + batch_size: ${eval:"4 if ${train.gpu_mem} < 24 else (8 if ${train.gpu_mem} < 40 else 16)"} + +train: + optimizer: + lr: 3.0e-4 diff --git a/training/configs/experiment/pile/gpt3s-flash-8k.yaml b/training/configs/experiment/pile/gpt3s-flash-8k.yaml new file mode 100644 index 000000000..06ce6453d --- /dev/null +++ b/training/configs/experiment/pile/gpt3s-flash-8k.yaml @@ -0,0 +1,10 @@ +# @package _global_ +defaults: + - /experiment/pile/gpt3s-flash.yaml + +datamodule: + max_length: 8192 + batch_size: ${eval:"2 if ${train.gpu_mem} < 24 else (4 if ${train.gpu_mem} < 40 else 8)"} + +train: + global_batch_size: 64 diff --git a/training/configs/experiment/pile/gpt3s-flash-rotary-30B.yaml b/training/configs/experiment/pile/gpt3s-flash-rotary-30B.yaml new file mode 100644 index 000000000..d43448006 --- /dev/null +++ b/training/configs/experiment/pile/gpt3s-flash-rotary-30B.yaml @@ -0,0 +1,10 @@ +# @package _global_ +defaults: + - /experiment/pile/gpt3s-flash-rotary.yaml + +trainer: + max_steps: 60000 + +train: + scheduler: + t_initial: ${trainer.max_steps} diff --git a/training/configs/experiment/pile/gpt3s-flash-rotary-8k.yaml b/training/configs/experiment/pile/gpt3s-flash-rotary-8k.yaml new file mode 100644 index 000000000..bdee8766f --- /dev/null +++ b/training/configs/experiment/pile/gpt3s-flash-rotary-8k.yaml @@ -0,0 +1,8 @@ +# @package _global_ +defaults: + - 
/experiment/pile/gpt3s-flash-8k.yaml + +model: + config: + max_position_embeddings: 0 # Disable absolute position embedding + rotary_emb_fraction: 0.5 diff --git a/training/configs/experiment/pile/gpt3s-flash-rotary.yaml b/training/configs/experiment/pile/gpt3s-flash-rotary.yaml new file mode 100644 index 000000000..41176eea1 --- /dev/null +++ b/training/configs/experiment/pile/gpt3s-flash-rotary.yaml @@ -0,0 +1,8 @@ +# @package _global_ +defaults: + - /experiment/pile/gpt3s-flash.yaml + +model: + config: + max_position_embeddings: 0 # Disable absolute position embedding + rotary_emb_fraction: 0.5 diff --git a/training/configs/experiment/pile/gpt3s-flash.yaml b/training/configs/experiment/pile/gpt3s-flash.yaml new file mode 100644 index 000000000..3def2a8e8 --- /dev/null +++ b/training/configs/experiment/pile/gpt3s-flash.yaml @@ -0,0 +1,17 @@ +# @package _global_ +defaults: + - /experiment/pile/base.yaml + - override /model: gpt2 + - override /model/gpt2model: gpt2-small + +model: + config: + # n_positions is already set to ${datamodule.max_length} + use_flash_attn: True + fused_dropout_add_ln: True + fused_dense_gelu_dense: True + fused_bias_fc: True + pad_vocab_size_multiple: 8 + +datamodule: + batch_size: ${eval:"8 if ${train.gpu_mem} < 24 else (16 if ${train.gpu_mem} < 40 else 32)"} diff --git a/training/configs/experiment/pile/gpt3xl-flash-8k.yaml b/training/configs/experiment/pile/gpt3xl-flash-8k.yaml new file mode 100644 index 000000000..578fd1c3a --- /dev/null +++ b/training/configs/experiment/pile/gpt3xl-flash-8k.yaml @@ -0,0 +1,10 @@ +# @package _global_ +defaults: + - /experiment/pile/gpt2xl-flash.yaml + +datamodule: + max_length: 8192 + batch_size: ${eval:"1 if ${train.gpu_mem} < 40 else (2 if ${train.gpu_mem} < 80 else 4)"} + +train: + global_batch_size: 128 diff --git a/training/configs/experiment/pile/gpt3xl-flash-rotary-60B.yaml b/training/configs/experiment/pile/gpt3xl-flash-rotary-60B.yaml new file mode 100644 index 000000000..f32e96ce4 --- /dev/null +++ b/training/configs/experiment/pile/gpt3xl-flash-rotary-60B.yaml @@ -0,0 +1,10 @@ +# @package _global_ +defaults: + - /experiment/pile/gpt2xl-flash-rotary.yaml + +trainer: + max_steps: 60000 + +train: + scheduler: + t_initial: ${trainer.max_steps} diff --git a/training/configs/experiment/pile/gpt3xl-flash-rotary-8k.yaml b/training/configs/experiment/pile/gpt3xl-flash-rotary-8k.yaml new file mode 100644 index 000000000..9f5dd00e8 --- /dev/null +++ b/training/configs/experiment/pile/gpt3xl-flash-rotary-8k.yaml @@ -0,0 +1,8 @@ +# @package _global_ +defaults: + - /experiment/pile/gpt2xl-flash-8k.yaml + +model: + config: + max_position_embeddings: 0 # Disable absolute position embedding + rotary_emb_fraction: 0.5 diff --git a/training/configs/experiment/pile/gpt3xl-flash-rotary.yaml b/training/configs/experiment/pile/gpt3xl-flash-rotary.yaml new file mode 100644 index 000000000..b188eddeb --- /dev/null +++ b/training/configs/experiment/pile/gpt3xl-flash-rotary.yaml @@ -0,0 +1,8 @@ +# @package _global_ +defaults: + - /experiment/pile/gpt2xl-flash.yaml + +model: + config: + max_position_embeddings: 0 # Disable absolute position embedding + rotary_emb_fraction: 0.5 diff --git a/training/configs/experiment/pile/gpt3xl-flash.yaml b/training/configs/experiment/pile/gpt3xl-flash.yaml new file mode 100644 index 000000000..96165e00c --- /dev/null +++ b/training/configs/experiment/pile/gpt3xl-flash.yaml @@ -0,0 +1,35 @@ +# @package _global_ +defaults: + - /experiment/pile/gpt3s-flash.yaml + - override /optimizer: adamw-zero + +model: + 
config: + n_embd: 2048 + n_head: 16 + n_layer: 24 + +datamodule: + batch_size: ${eval:"1 if ${train.gpu_mem} < 24 else (2 if ${train.gpu_mem} < 40 else (4 if ${train.gpu} < 80 else 8))"} + +train: + global_batch_size: 512 + optimizer: + lr: 2.0e-4 + scheduler: + t_initial: 300000 + +trainer: + strategy: + _target_: src.utils.ddp_zero1.DDPStrategyZero1 + find_unused_parameters: False + gradient_as_bucket_view: True + max_steps: 400000 + val_check_interval: ${eval:1000 * ${.accumulate_grad_batches}} + +callbacks: + model_checkpoint: + every_n_train_steps: 1000 + model_checkpoint_progress: + every_n_train_steps: 12500 + fault_tolerant: False # Saving takes too long diff --git a/training/configs/logger/comet.yaml b/training/configs/logger/comet.yaml new file mode 100644 index 000000000..6ac99f46c --- /dev/null +++ b/training/configs/logger/comet.yaml @@ -0,0 +1,7 @@ +# https://www.comet.ml + +comet: + _target_: pytorch_lightning.loggers.comet.CometLogger + api_key: ${oc.env:COMET_API_TOKEN} # api key is loaded from environment variable + project_name: "template-tests" + experiment_name: ${name} diff --git a/training/configs/logger/csv.yaml b/training/configs/logger/csv.yaml new file mode 100644 index 000000000..0f917e89c --- /dev/null +++ b/training/configs/logger/csv.yaml @@ -0,0 +1,8 @@ +# csv logger built in lightning + +csv: + _target_: pytorch_lightning.loggers.csv_logs.CSVLogger + save_dir: "." + name: "csv/" + version: ${name} + prefix: "" diff --git a/training/configs/logger/many_loggers.yaml b/training/configs/logger/many_loggers.yaml new file mode 100644 index 000000000..7bc3d6762 --- /dev/null +++ b/training/configs/logger/many_loggers.yaml @@ -0,0 +1,9 @@ +# train with many loggers at once + +defaults: + # - comet.yaml + - csv.yaml + # - mlflow.yaml + # - neptune.yaml + # - tensorboard.yaml + - wandb.yaml diff --git a/training/configs/logger/mlflow.yaml b/training/configs/logger/mlflow.yaml new file mode 100644 index 000000000..bfb3781b1 --- /dev/null +++ b/training/configs/logger/mlflow.yaml @@ -0,0 +1,10 @@ +# https://mlflow.org + +mlflow: + _target_: pytorch_lightning.loggers.mlflow.MLFlowLogger + experiment_name: ${name} + tracking_uri: null + tags: null + save_dir: ./mlruns + prefix: "" + artifact_location: null diff --git a/training/configs/logger/neptune.yaml b/training/configs/logger/neptune.yaml new file mode 100644 index 000000000..117af9379 --- /dev/null +++ b/training/configs/logger/neptune.yaml @@ -0,0 +1,11 @@ +# https://neptune.ai + +neptune: + _target_: pytorch_lightning.loggers.neptune.NeptuneLogger + api_key: ${oc.env:NEPTUNE_API_TOKEN} # api key is loaded from environment variable + project_name: your_name/template-tests + close_after_fit: True + offline_mode: False + experiment_name: ${name} + experiment_id: null + prefix: "" diff --git a/training/configs/logger/tensorboard.yaml b/training/configs/logger/tensorboard.yaml new file mode 100644 index 000000000..acd1fa411 --- /dev/null +++ b/training/configs/logger/tensorboard.yaml @@ -0,0 +1,10 @@ +# https://www.tensorflow.org/tensorboard/ + +tensorboard: + _target_: pytorch_lightning.loggers.tensorboard.TensorBoardLogger + save_dir: "tensorboard/" + name: "default" + version: ${name} + log_graph: False + default_hp_metric: True + prefix: "" diff --git a/training/configs/logger/wandb.yaml b/training/configs/logger/wandb.yaml new file mode 100644 index 000000000..67527809e --- /dev/null +++ b/training/configs/logger/wandb.yaml @@ -0,0 +1,15 @@ +# https://wandb.ai + +wandb: + _target_: 
pytorch_lightning.loggers.wandb.WandbLogger + project: attention + name: ${name} + save_dir: "." + mode: online # set offline to store all logs only locally + id: ${oc.select:name} # pass correct id to resume experiment! + # entity: "" # set to name of your wandb team or just remove it + log_model: False + prefix: "" + job_type: "train" + group: "" + tags: [] diff --git a/training/configs/metrics/acc.yaml b/training/configs/metrics/acc.yaml new file mode 100644 index 000000000..fe7a63f5c --- /dev/null +++ b/training/configs/metrics/acc.yaml @@ -0,0 +1,3 @@ +# @package eval.metrics +acc: + _target_: src.metrics.accuracy.AccuracyMine diff --git a/training/configs/metrics/acc_ignore_index.yaml b/training/configs/metrics/acc_ignore_index.yaml new file mode 100644 index 000000000..03364aa1a --- /dev/null +++ b/training/configs/metrics/acc_ignore_index.yaml @@ -0,0 +1,4 @@ +# @package eval.metrics +acc: + _target_: torchmetrics.Accuracy + ignore_index: -100 diff --git a/training/configs/metrics/acctop5.yaml b/training/configs/metrics/acctop5.yaml new file mode 100644 index 000000000..5f798ae0c --- /dev/null +++ b/training/configs/metrics/acctop5.yaml @@ -0,0 +1,4 @@ +# @package eval.metrics +acctop5: + _target_: src.metrics.accuracy.AccuracyMine + top_k: 5 diff --git a/training/configs/metrics/mse.yaml b/training/configs/metrics/mse.yaml new file mode 100644 index 000000000..50b0484d9 --- /dev/null +++ b/training/configs/metrics/mse.yaml @@ -0,0 +1,3 @@ +# @package eval.metrics +mse: + _target_: torchmetrics.MeanSquaredError diff --git a/training/configs/metrics/num-tokens.yaml b/training/configs/metrics/num-tokens.yaml new file mode 100644 index 000000000..047d42354 --- /dev/null +++ b/training/configs/metrics/num-tokens.yaml @@ -0,0 +1,3 @@ +# @package eval.metrics +num-tokens: + _target_: src.metrics.num_tokens.NumTokens diff --git a/training/configs/metrics/perplexity.yaml b/training/configs/metrics/perplexity.yaml new file mode 100644 index 000000000..2edd21788 --- /dev/null +++ b/training/configs/metrics/perplexity.yaml @@ -0,0 +1,3 @@ +# @package eval.metrics +ppl: + _target_: src.metrics.perplexity.Perplexity diff --git a/training/configs/mode/debug.yaml b/training/configs/mode/debug.yaml new file mode 100644 index 000000000..b2335c981 --- /dev/null +++ b/training/configs/mode/debug.yaml @@ -0,0 +1,27 @@ +# @package _global_ + +# run in debug mode with: +# `python run.py mode=debug` + +defaults: + - override /trainer: debug.yaml + +debug_mode: True + +hydra: + # sets level of all command line loggers to 'DEBUG' + verbose: True + + # https://hydra.cc/docs/tutorials/basic/running_your_app/logging/ + # sets level of only chosen command line loggers to 'DEBUG' + # verbose: [src.train, src.utils.utils] + + # sets output paths for all file logs to 'logs/debug/' + run: + dir: ${oc.env:RESULT_DIR,${work_dir}/logs}/debug/${now:%Y-%m-%d}/${now:%H-%M-%S} + sweep: + dir: ${oc.env:RESULT_DIR,${work_dir}/logs}/debug/multirun_${now:%Y-%m-%d_%H-%M-%S} + subdir: ${hydra.job.num} + +# disable rich config printing, since it will be already printed by hydra when `verbose: True` +print_config: False diff --git a/training/configs/mode/default.yaml b/training/configs/mode/default.yaml new file mode 100644 index 000000000..ac0cae1ce --- /dev/null +++ b/training/configs/mode/default.yaml @@ -0,0 +1,13 @@ +# @package _global_ + +# default running mode + +default_mode: True + +hydra: + # default output paths for all file logs + run: + dir: ${oc.env:RESULT_DIR,${work_dir}/logs}/runs/${now:%Y-%m-%d}/${now:%H-%M-%S} 
+ sweep: + dir: ${oc.env:RESULT_DIR,${work_dir}/logs}/multiruns/${now:%Y-%m-%d_%H-%M-%S} + subdir: ${hydra.job.num} diff --git a/training/configs/mode/exp.yaml b/training/configs/mode/exp.yaml new file mode 100644 index 000000000..032aaa943 --- /dev/null +++ b/training/configs/mode/exp.yaml @@ -0,0 +1,17 @@ +# @package _global_ + +# run in experiment mode with: +# `python run.py mode=exp name=experiment_name` + +experiment_mode: True + +# allows for custom naming of the experiment +name: ??? + +hydra: + # sets output paths for all file logs to `logs/experiment/name' + run: + dir: ${oc.env:RESULT_DIR,${work_dir}/logs}/experiments/${name} + sweep: + dir: ${oc.env:RESULT_DIR,${work_dir}/logs}/experiments/${name} + subdir: ${hydra.job.num} diff --git a/training/configs/mode/profile.yaml b/training/configs/mode/profile.yaml new file mode 100644 index 000000000..f6c547a43 --- /dev/null +++ b/training/configs/mode/profile.yaml @@ -0,0 +1,31 @@ +# @package _global_ +# Run the Pytorch profiler + +trainer: + profiler: + _target_: pytorch_lightning.profilers.PyTorchProfiler + dirpath: ${hydra.run.dir} + schedule: + _target_: torch.profiler.schedule + wait: 5 + warmup: 5 + active: 5 + use_cuda: True + max_steps: 20 + +logger: + wandb: + mode: disabled + +callbacks: + model_checkpoint: null + model_checkpoint_progress: null + early_stopping: null + +hydra: + # sets output paths for all file logs to 'logs/profile/' + run: + dir: ${oc.env:RESULT_DIR,${work_dir}/logs}/profile/${now:%Y-%m-%d}/${now:%H-%M-%S} + sweep: + dir: ${oc.env:RESULT_DIR,${work_dir}/logs}/profile/multirun_${now:%Y-%m-%d_%H-%M-%S} + subdir: ${hydra.job.num} diff --git a/training/configs/mode/smoke.yaml b/training/configs/mode/smoke.yaml new file mode 100644 index 000000000..eac3dd286 --- /dev/null +++ b/training/configs/mode/smoke.yaml @@ -0,0 +1,22 @@ +# @package _global_ +# Smoke test: disable logging and model checkpointing + +logger: + wandb: + mode: disabled + +callbacks: + model_checkpoint: null + model_checkpoint_progress: null + +hydra: + # https://hydra.cc/docs/tutorials/basic/running_your_app/logging/ + # sets level of only chosen command line loggers to 'DEBUG' + # verbose: [src.train, src.utils.utils] + + # sets output paths for all file logs to 'logs/debug/' + run: + dir: ${oc.env:RESULT_DIR,${work_dir}/logs}/debug/${now:%Y-%m-%d}/${now:%H-%M-%S} + sweep: + dir: ${oc.env:RESULT_DIR,${work_dir}/logs}/debug/multirun_${now:%Y-%m-%d_%H-%M-%S} + subdir: ${hydra.job.num} diff --git a/training/configs/model/gpt2-hf.yaml b/training/configs/model/gpt2-hf.yaml new file mode 100644 index 000000000..d6cb22f0a --- /dev/null +++ b/training/configs/model/gpt2-hf.yaml @@ -0,0 +1,13 @@ +defaults: + - _self_ + - gpt2model: gpt2-small + +_target_: transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel +_recursive_: True +config: + _target_: transformers.GPT2Config + # Mistral's config: https://github.com/stanford-crfm/mistral/blob/main/conf/models/gpt2-small.yaml + # However, reorder_and_upcast_attn slows things down + reorder_and_upcast_attn: false + scale_attn_by_inverse_layer_idx: true + n_positions: ${datamodule.max_length} diff --git a/training/configs/model/gpt2.yaml b/training/configs/model/gpt2.yaml new file mode 100644 index 000000000..6c3868d06 --- /dev/null +++ b/training/configs/model/gpt2.yaml @@ -0,0 +1,13 @@ +defaults: + - _self_ + - gpt2model: gpt2-small + +_target_: flash_attn.models.gpt.GPTLMHeadModel +_recursive_: True +config: + _target_: transformers.GPT2Config + # Mistral's config: # 
https://github.com/stanford-crfm/mistral/blob/main/conf/models/mistral-small.yaml + # However, reorder_and_upcast_attn slows things down + reorder_and_upcast_attn: false + scale_attn_by_inverse_layer_idx: true + n_positions: ${datamodule.max_length} diff --git a/training/configs/model/gpt2model/gpt2-large.yaml b/training/configs/model/gpt2model/gpt2-large.yaml new file mode 100644 index 000000000..434a61eb9 --- /dev/null +++ b/training/configs/model/gpt2model/gpt2-large.yaml @@ -0,0 +1,6 @@ +# @package _global_ +model: + config: + n_embd: 1280 + n_head: 20 + n_layer: 36 diff --git a/training/configs/model/gpt2model/gpt2-medium.yaml b/training/configs/model/gpt2model/gpt2-medium.yaml new file mode 100644 index 000000000..786091836 --- /dev/null +++ b/training/configs/model/gpt2model/gpt2-medium.yaml @@ -0,0 +1,6 @@ +# @package _global_ +model: + config: + n_embd: 1024 + n_head: 16 + n_layer: 24 diff --git a/training/configs/model/gpt2model/gpt2-small.yaml b/training/configs/model/gpt2model/gpt2-small.yaml new file mode 100644 index 000000000..039c91802 --- /dev/null +++ b/training/configs/model/gpt2model/gpt2-small.yaml @@ -0,0 +1,6 @@ +# @package _global_ +model: + config: + n_embd: 768 + n_head: 12 + n_layer: 12 diff --git a/training/configs/model/gpt2model/gpt2-xlarge.yaml b/training/configs/model/gpt2model/gpt2-xlarge.yaml new file mode 100644 index 000000000..d67a0e418 --- /dev/null +++ b/training/configs/model/gpt2model/gpt2-xlarge.yaml @@ -0,0 +1,6 @@ +# @package _global_ +model: + config: + n_embd: 1600 + n_head: 25 + n_layer: 48 diff --git a/training/configs/optimizer/adam.yaml b/training/configs/optimizer/adam.yaml new file mode 100644 index 000000000..f8821d74c --- /dev/null +++ b/training/configs/optimizer/adam.yaml @@ -0,0 +1,2 @@ +# @package train.optimizer +_target_: torch.optim.Adam diff --git a/training/configs/optimizer/adamw-apex-distributed.yaml b/training/configs/optimizer/adamw-apex-distributed.yaml new file mode 100644 index 000000000..b7a5136eb --- /dev/null +++ b/training/configs/optimizer/adamw-apex-distributed.yaml @@ -0,0 +1,3 @@ +# @package train.optimizer +_target_: apex.contrib.optimizers.distributed_fused_adam.DistributedFusedAdam +adam_w_mode: True diff --git a/training/configs/optimizer/adamw-apex-zero.yaml b/training/configs/optimizer/adamw-apex-zero.yaml new file mode 100644 index 000000000..f19d7a044 --- /dev/null +++ b/training/configs/optimizer/adamw-apex-zero.yaml @@ -0,0 +1,7 @@ +# @package train.optimizer +_target_: torch.distributed.optim.ZeroRedundancyOptimizer +_recursive_: True +optimizer_class: + _target_: apex.optimizers.FusedAdam + _partial_: True + adam_w_mode: True diff --git a/training/configs/optimizer/adamw-apex.yaml b/training/configs/optimizer/adamw-apex.yaml new file mode 100644 index 000000000..fdbf90fdf --- /dev/null +++ b/training/configs/optimizer/adamw-apex.yaml @@ -0,0 +1,3 @@ +# @package train.optimizer +_target_: apex.optimizers.FusedAdam +adam_w_mode: True diff --git a/training/configs/optimizer/adamw-zero.yaml b/training/configs/optimizer/adamw-zero.yaml new file mode 100644 index 000000000..66ea2fd03 --- /dev/null +++ b/training/configs/optimizer/adamw-zero.yaml @@ -0,0 +1,7 @@ +# @package train.optimizer +_target_: torch.distributed.optim.ZeroRedundancyOptimizer +_recursive_: True +optimizer_class: + _target_: torch.optim.__getattribute__ + _args_: + - "AdamW" diff --git a/training/configs/optimizer/adamw.yaml b/training/configs/optimizer/adamw.yaml new file mode 100644 index 000000000..02252ec1c --- /dev/null +++ 
b/training/configs/optimizer/adamw.yaml @@ -0,0 +1,2 @@ +# @package train.optimizer +_target_: torch.optim.AdamW diff --git a/training/configs/optimizer/fusedlamb-ds.yaml b/training/configs/optimizer/fusedlamb-ds.yaml new file mode 100644 index 000000000..a4fffbfb3 --- /dev/null +++ b/training/configs/optimizer/fusedlamb-ds.yaml @@ -0,0 +1,2 @@ +# @package train.optimizer +_target_: deepspeed.ops.lamb.FusedLamb diff --git a/training/configs/optimizer/fusedlamb.yaml b/training/configs/optimizer/fusedlamb.yaml new file mode 100644 index 000000000..c8d7b2b8e --- /dev/null +++ b/training/configs/optimizer/fusedlamb.yaml @@ -0,0 +1,2 @@ +# @package train.optimizer +_target_: apex.optimizers.FusedLAMB diff --git a/training/configs/optimizer/sgd.yaml b/training/configs/optimizer/sgd.yaml new file mode 100644 index 000000000..43b834653 --- /dev/null +++ b/training/configs/optimizer/sgd.yaml @@ -0,0 +1,2 @@ +# @package train.optimizer +_target_: torch.optim.SGD diff --git a/training/configs/scheduler/cosine-warmup-timm.yaml b/training/configs/scheduler/cosine-warmup-timm.yaml new file mode 100644 index 000000000..f2bbbec01 --- /dev/null +++ b/training/configs/scheduler/cosine-warmup-timm.yaml @@ -0,0 +1,2 @@ +# @package train.scheduler +_target_: src.optim.timm_lr_scheduler.TimmCosineLRScheduler diff --git a/training/configs/scheduler/cosine-warmup.yaml b/training/configs/scheduler/cosine-warmup.yaml new file mode 100644 index 000000000..afaf0f618 --- /dev/null +++ b/training/configs/scheduler/cosine-warmup.yaml @@ -0,0 +1,2 @@ +# @package train.scheduler +_target_: transformers.get_cosine_schedule_with_warmup diff --git a/training/configs/scheduler/invsqrt.yaml b/training/configs/scheduler/invsqrt.yaml new file mode 100644 index 000000000..bb16f3c15 --- /dev/null +++ b/training/configs/scheduler/invsqrt.yaml @@ -0,0 +1,3 @@ +# @package train.scheduler +_target_: src.optim.lr_scheduler.InvSqrt +num_warmup_steps: ??? diff --git a/training/configs/scheduler/linear-warmup.yaml b/training/configs/scheduler/linear-warmup.yaml new file mode 100644 index 000000000..bb6a69896 --- /dev/null +++ b/training/configs/scheduler/linear-warmup.yaml @@ -0,0 +1,2 @@ +# @package train.scheduler +_target_: transformers.get_linear_schedule_with_warmup diff --git a/training/configs/scheduler/multi-step.yaml b/training/configs/scheduler/multi-step.yaml new file mode 100644 index 000000000..42cd60716 --- /dev/null +++ b/training/configs/scheduler/multi-step.yaml @@ -0,0 +1,2 @@ +# @package train.scheduler +_target_: torch.optim.lr_scheduler.MultiStepLR diff --git a/training/configs/scheduler/plateau.yaml b/training/configs/scheduler/plateau.yaml new file mode 100644 index 000000000..436c264dc --- /dev/null +++ b/training/configs/scheduler/plateau.yaml @@ -0,0 +1,9 @@ +# @package _global_ +train: + scheduler_interval: epoch + scheduler_monitor: ??? 
+ scheduler: + _target_: torch.optim.lr_scheduler.ReduceLROnPlateau + factor: 0.2 # Decay factor when ReduceLROnPlateau is used + patience: 20 + min_lr: 0.0 # Minimum learning rate during annealing diff --git a/training/configs/scheduler/poly-warmup.yaml b/training/configs/scheduler/poly-warmup.yaml new file mode 100644 index 000000000..79808ea42 --- /dev/null +++ b/training/configs/scheduler/poly-warmup.yaml @@ -0,0 +1,2 @@ +# @package train.scheduler +_target_: transformers.get_polynomial_decay_schedule_with_warmup diff --git a/training/configs/scheduler/step.yaml b/training/configs/scheduler/step.yaml new file mode 100644 index 000000000..e0d9a0ce8 --- /dev/null +++ b/training/configs/scheduler/step.yaml @@ -0,0 +1,3 @@ +# @package train.scheduler +_target_: torch.optim.lr_scheduler.StepLR +step_size: ??? diff --git a/training/configs/task/sequence-model.yaml b/training/configs/task/sequence-model.yaml new file mode 100644 index 000000000..435cf0501 --- /dev/null +++ b/training/configs/task/sequence-model.yaml @@ -0,0 +1 @@ +_target_: src.tasks.seq.SequenceModel diff --git a/training/configs/trainer/all_params.yaml b/training/configs/trainer/all_params.yaml new file mode 100644 index 000000000..24a0b5048 --- /dev/null +++ b/training/configs/trainer/all_params.yaml @@ -0,0 +1,49 @@ +_target_: pytorch_lightning.Trainer + +# default values for all trainer parameters +checkpoint_callback: True +default_root_dir: null +gradient_clip_val: 0.0 +process_position: 0 +num_nodes: 1 +num_processes: 1 +gpus: null +auto_select_gpus: False +tpu_cores: null +log_gpu_memory: null +overfit_batches: 0.0 +track_grad_norm: -1 +check_val_every_n_epoch: 1 +fast_dev_run: False +accumulate_grad_batches: 1 +max_epochs: 1 +min_epochs: 1 +max_steps: null +min_steps: null +limit_train_batches: 1.0 +limit_val_batches: 1.0 +limit_test_batches: 1.0 +val_check_interval: 1.0 +flush_logs_every_n_steps: 100 +log_every_n_steps: 50 +accelerator: null +sync_batchnorm: False +precision: 32 +weights_summary: "top" +weights_save_path: null +num_sanity_val_steps: 2 +truncated_bptt_steps: null +resume_from_checkpoint: null +profiler: null +benchmark: False +deterministic: False +reload_dataloaders_every_epoch: False +auto_lr_find: False +replace_sampler_ddp: True +terminate_on_nan: False +auto_scale_batch_size: False +prepare_data_per_node: True +plugins: null +amp_backend: "native" +amp_level: "O2" +move_metrics_to_cpu: False diff --git a/training/configs/trainer/ddp.yaml b/training/configs/trainer/ddp.yaml new file mode 100644 index 000000000..3c9544407 --- /dev/null +++ b/training/configs/trainer/ddp.yaml @@ -0,0 +1,6 @@ +defaults: + - default.yaml + +accelerator: gpu +devices: 4 +strategy: ddp diff --git a/training/configs/trainer/debug.yaml b/training/configs/trainer/debug.yaml new file mode 100644 index 000000000..8371d96ce --- /dev/null +++ b/training/configs/trainer/debug.yaml @@ -0,0 +1,21 @@ +defaults: + - default.yaml + +gpus: 0 + +min_epochs: 1 +max_epochs: 2 + +# prints +weights_summary: "full" +profiler: null + +# debugs +fast_dev_run: true +num_sanity_val_steps: 2 +overfit_batches: 0 +limit_train_batches: 1.0 +limit_val_batches: 1.0 +limit_test_batches: 1.0 +track_grad_norm: -1 +terminate_on_nan: true diff --git a/training/configs/trainer/default.yaml b/training/configs/trainer/default.yaml new file mode 100644 index 000000000..beab6c71b --- /dev/null +++ b/training/configs/trainer/default.yaml @@ -0,0 +1,7 @@ +_target_: pytorch_lightning.Trainer + +# set `gpu` to train on GPU, null to train on CPU only 
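The optimizer, scheduler, and trainer files in this directory are plain Hydra `_target_` configs, grouped under `train.optimizer` / `train.scheduler` by their `# @package` headers. A minimal sketch of how configs like these can be turned into live objects with `hydra.utils.instantiate`; the learning rate and step size below are made-up illustration values, not taken from these files:

```python
import torch.nn as nn
from omegaconf import OmegaConf
from hydra.utils import instantiate

# Made-up hyperparameters; only the `_target_` pattern mirrors the configs above.
cfg = OmegaConf.create({
    "optimizer": {"_target_": "torch.optim.AdamW", "lr": 3e-4, "weight_decay": 0.1},
    "scheduler": {"_target_": "torch.optim.lr_scheduler.StepLR", "step_size": 1000},
})

model = nn.Linear(8, 8)
# Positional arguments are forwarded to the target, so the parameters / optimizer
# can be supplied at instantiation time rather than stored in the config.
optimizer = instantiate(cfg.optimizer, model.parameters())
scheduler = instantiate(cfg.scheduler, optimizer)
print(type(optimizer).__name__, type(scheduler).__name__)  # AdamW StepLR
```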
+accelerator: null + +min_epochs: 1 +max_epochs: 1000 diff --git a/training/run.py b/training/run.py new file mode 100644 index 000000000..2b22d8e2c --- /dev/null +++ b/training/run.py @@ -0,0 +1,68 @@ +from typing import Callable + +import dotenv +import hydra +from omegaconf import OmegaConf, DictConfig + +# load environment variables from `.env` file if it exists +# recursively searches for `.env` in all folders starting from work dir +dotenv.load_dotenv(override=True) + +OmegaConf.register_new_resolver('eval', eval) +OmegaConf.register_new_resolver('div_up', lambda x, y: (x + y - 1) // y) +# Delay the evaluation until we have the datamodule +# So we want the resolver to yield the same string. +OmegaConf.register_new_resolver('datamodule', lambda attr: '${datamodule:' + str(attr) + '}') + +# Turn on TensorFloat32 +import torch.backends +torch.backends.cuda.matmul.allow_tf32 = True +torch.backends.cudnn.allow_tf32 = True + + +def dictconfig_filter_key(d: DictConfig, fn: Callable) -> DictConfig: + """Only keep keys where fn(key) is True. Support nested DictConfig. + """ + # Using d.items_ex(resolve=False) instead of d.items() since we want to keep the + # ${datamodule:foo} unresolved for now. + return DictConfig({k: dictconfig_filter_key(v, fn) if isinstance(v, DictConfig) else v + # for k, v in d.items_ex(resolve=False) if fn(k)}) + for k, v in d.items() if fn(k)}) + + +@hydra.main(config_path="configs/", config_name="config.yaml") +def main(config: DictConfig): + + # Remove config keys that start with '__'. These are meant to be used only in computing + # other entries in the config. + config = dictconfig_filter_key(config, lambda k: not k.startswith('__')) + + # Imports should be nested inside @hydra.main to optimize tab completion + # Read more here: https://github.com/facebookresearch/hydra/issues/934 + from src.train import train + from src.eval import evaluate + from src.utils import utils + + # A couple of optional utilities: + # - disabling python warnings + # - forcing debug-friendly configuration + # - verifying experiment name is set when running in experiment mode + # You can safely get rid of this line if you don't want those + utils.extras(config) + + # Pretty print config using Rich library + if config.get("print_config"): + utils.print_config(config, resolve=True) + + # Train model + mode = config.get('mode', 'train') + if mode not in ['train', 'eval']: + raise NotImplementedError(f'mode {mode} not supported') + if mode == 'train': + return train(config) + elif mode == 'eval': + return evaluate(config) + + +if __name__ == "__main__": + main() diff --git a/training/src/callbacks/__init__.py b/training/src/callbacks/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/training/src/callbacks/causality_monitor.py b/training/src/callbacks/causality_monitor.py new file mode 100644 index 000000000..58e212134 --- /dev/null +++ b/training/src/callbacks/causality_monitor.py @@ -0,0 +1,61 @@ + +import pytorch_lightning as pl +from pytorch_lightning import Callback +from pytorch_lightning.utilities import rank_zero_only + +import torch +from torch.autograd import grad + +class CausalityMonitor(Callback): + r"""Monitor causality of a model by tracking gradient leakage forward in time. + In a fully causal model, dy[k]du[s] ~= 0 for all k < s. + + Args: + seq_len (int): Length of the sequence to monitor. + input_dim (int): Dimension of the input to monitor. If 0, the callback assumes + the task to be language modeling, and skips the embedding layer. 
If > 0, + input_dim is interpreted as the input channel dimension, i.e. D with + dummy input of dimension [B, L, D]. + + Notes: + This callback assumes that `pl_module.model` has a `net` or `s4seq` attribute, + indicating the primary model to monitor. For LMs, `net` or `s4seq` should + be after the embedding layer. + """ + + def __init__(self, seq_len: int = 10, input_dim: int = 0): + super().__init__() + self.seq_len = seq_len + self.input_dim = input_dim + + @rank_zero_only + def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + model = pl_module.model + + with torch.enable_grad(): + if self.input_dim == 0: + # [MP] LongTensors cannot have gradients - we start from post + # embedding in the LM case + input_dim = model.d_model + x = torch.randn((2, self.seq_len, input_dim), \ + requires_grad=True).to(pl_module.device) + # [DF] HACK: we need to get the layer that comes after the embedding + if hasattr(model, 'net'): + y = model.net(x) + else: + y = model.s4seq(x) + else: + x = torch.randn(1, self.seq_len, self.input_dim, \ + requires_grad=True).to(pl_module.device) + y = model(x) + + stats = {} + for i in range(self.seq_len): + # total gradients flowing from y_i to x + g = grad(y[0,0,i].mean(), x, retain_graph=True, allow_unused=True)[0] + g = g[0,i+1:,:].abs().mean() + stats[f'stats/causality_{i}'] = g.item() + + if trainer.loggers is not None: + for logger in trainer.loggers: + logger.log_metrics(stats, step=trainer.global_step) diff --git a/training/src/callbacks/ema.py b/training/src/callbacks/ema.py new file mode 100644 index 000000000..14941715b --- /dev/null +++ b/training/src/callbacks/ema.py @@ -0,0 +1,82 @@ +# Inspired by https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pytorch_lightning/callbacks/stochastic_weight_avg.py +# https://github.com/PyTorchLightning/Lightning-Bolts/blob/master/pl_bolts/callbacks/byol_updates.py +# https://forums.pytorchlightning.ai/t/adopting-exponential-moving-average-ema-for-pl-pipeline/488/2 +# https://github.com/PyTorchLightning/pytorch-lightning/issues/8100 + +from typing import Dict, Any + +from pytorch_lightning import Callback, Trainer +from pytorch_lightning.utilities import rank_zero_only +from pytorch_lightning.utilities.parsing import AttributeDict +from pytorch_lightning.utilities.types import STEP_OUTPUT + +from src.utils.ema import ExponentialMovingAverage + + +class EMACallback(Callback): + """TD [2021-08-31]: saving and loading from checkpoint should work. + """ + def __init__(self, decay: float, use_num_updates: bool = True): + """ + decay: The exponential decay. + use_num_updates: Whether to use number of updates when computing + averages. + """ + super().__init__() + self.decay = decay + self.use_num_updates = use_num_updates + self.ema = None + + def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"): + # It's possible that we already loaded EMA from the checkpoint + if self.ema is None: + self.ema = ExponentialMovingAverage([p for p in pl_module.parameters() if p.requires_grad], + decay=self.decay, use_num_updates=self.use_num_updates) + + # Ideally we want on_after_optimizer_step but pytorch-lightning doesn't have it + # We only want to update when parameters are changing. + # Because of gradient accumulation, this doesn't happen every training step. 
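For reference, a self-contained sketch of the moving-average update that a callback like this maintains. `SimpleEMA` is an illustration only: the decay warm-up via `num_updates` follows the common formulation and is an assumption here, not necessarily identical to `src.utils.ema.ExponentialMovingAverage`.

```python
import torch
import torch.nn as nn

class SimpleEMA:
    """shadow = decay * shadow + (1 - decay) * param, applied once per optimizer step."""
    def __init__(self, params, decay=0.999, use_num_updates=True):
        self.decay = decay
        self.num_updates = 0 if use_num_updates else None
        self.params = list(params)
        self.shadow = [p.detach().clone() for p in self.params]

    @torch.no_grad()
    def update(self):
        decay = self.decay
        if self.num_updates is not None:
            self.num_updates += 1
            # Warm up the decay so early averages are not dominated by the random init.
            decay = min(decay, (1 + self.num_updates) / (10 + self.num_updates))
        for s, p in zip(self.shadow, self.params):
            s.mul_(decay).add_(p, alpha=1 - decay)

model = nn.Linear(4, 4)
opt = torch.optim.SGD(model.parameters(), lr=0.1)
ema = SimpleEMA(p for p in model.parameters() if p.requires_grad)
for _ in range(3):
    loss = model(torch.randn(8, 4)).pow(2).mean()
    opt.zero_grad(); loss.backward(); opt.step()
    ema.update()  # as in the callback: once per real optimizer step, after any accumulation
```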
+ # https://github.com/PyTorchLightning/pytorch-lightning/issues/11688 + def on_train_batch_end( + self, + trainer: "pl.Trainer", + pl_module: "pl.LightningModule", + outputs: STEP_OUTPUT, + batch: Any, + batch_idx: int, + ) -> None: + if (batch_idx + 1) % trainer.accumulate_grad_batches == 0: + self.ema.update() + + def on_validation_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + # During the initial validation we don't have self.ema yet + if self.ema is not None: + self.ema.store() + self.ema.copy_to() + + def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + if self.ema is not None: + self.ema.restore() + + def on_test_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + if self.ema is not None: + self.ema.store() + self.ema.copy_to() + + def on_test_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + if self.ema is not None: + self.ema.restore() + + def on_save_checkpoint( + self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", checkpoint: Dict[str, Any] + ) -> Dict[str, Any]: + return self.ema.state_dict() + + def on_load_checkpoint( + self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", + checkpoint: Dict[str, Any] + ) -> None: + if self.ema is None: + self.ema = ExponentialMovingAverage([p for p in pl_module.parameters() if p.requires_grad], + decay=self.decay, use_num_updates=self.use_num_updates) + self.ema.load_state_dict(checkpoint) diff --git a/training/src/callbacks/flop_count.py b/training/src/callbacks/flop_count.py new file mode 100644 index 000000000..93053bbf7 --- /dev/null +++ b/training/src/callbacks/flop_count.py @@ -0,0 +1,43 @@ +# Adapted from https://github.com/rwightman/pytorch-image-models/blob/master/benchmark.py +from typing import Any, List, Sequence + +import torch + +from pytorch_lightning import Callback, Trainer, LightningModule +from pytorch_lightning.utilities import rank_zero_only +from pytorch_lightning.utilities.parsing import AttributeDict + +from src.utils.flops import has_deepspeed_profiling, has_fvcore_profiling +from src.utils.flops import profile_deepspeed, profile_fvcore + + +class FlopCount(Callback): + """Counter the number of FLOPs used by the model + """ + def __init__(self, profilers: List[str] = ['fvcore', 'deepspeed'], + input_size: tuple = (3, 224, 224), input_dtype=torch.float32, device=None): + if not isinstance(profilers, Sequence): + profilers = [profilers] + if any(p not in ['fvcore', 'deepspeed'] for p in profilers): + raise NotImplementedError('Only support fvcore and deepspeed profilers') + if 'fvcore' in profilers and not has_fvcore_profiling: + raise ImportError('fvcore is not installed. 
Install it by running `pip install fvcore`') + elif 'deepspeed' in profilers and not has_deepspeed_profiling: + raise ImportError('deepspeed is not installed') + super().__init__() + self.profilers = profilers + self.input_size = tuple(input_size) + self.input_dtype = input_dtype + self.device = device + + @rank_zero_only + def on_fit_start(self, trainer: Trainer, pl_module: LightningModule) -> None: + if 'fvcore' in self.profilers: + _, macs, _, acts = profile_fvcore(pl_module.to(self.device), input_size=self.input_size, + input_dtype=self.input_dtype, detailed=True) + trainer.logger.log_hyperparams({'GMACs': macs * 1e-9, 'MActs': acts * 1e-6}) + if 'deepspeed' in self.profilers: + macs, _= profile_deepspeed(pl_module.to(self.device), input_size=self.input_size, + input_dtype=self.input_dtype, detailed=True) + if 'fvcore' not in self.profilers: # fvcore's MACs seem more accurate + trainer.logger.log_hyperparams({'GMACs': macs * 1e-9}) diff --git a/training/src/callbacks/gpu_affinity.py b/training/src/callbacks/gpu_affinity.py new file mode 100644 index 000000000..1e8a64ed3 --- /dev/null +++ b/training/src/callbacks/gpu_affinity.py @@ -0,0 +1,40 @@ +import torch + +from pytorch_lightning import Callback, Trainer, LightningModule + +import logging + +log = logging.getLogger(__name__) # We want a logger for each process, not just the rank 0 + + +def l2_promote(): + import ctypes + _libcudart = ctypes.CDLL('libcudart.so') + # Set device limit on the current device + # cudaLimitMaxL2FetchGranularity = 0x05 + pValue = ctypes.cast((ctypes.c_int*1)(), ctypes.POINTER(ctypes.c_int)) + _libcudart.cudaDeviceSetLimit(ctypes.c_int(0x05), ctypes.c_int(128)) + _libcudart.cudaDeviceGetLimit(pValue, ctypes.c_int(0x05)) + assert pValue.contents.value == 128 + + +def set_affinity(trainer): + try: + from src.utils.gpu_affinity import set_affinity + nproc_per_node = torch.cuda.device_count() + affinity = set_affinity(trainer.local_rank, nproc_per_node, 'socket_unique_continuous') + log.info(f'{trainer.local_rank}: thread affinity: {affinity}') + # TD [2022-05-07] Somehow calling this causes GPU 0 to allocate extra ~800MB of memory per + # number of GPUs (e.g., 6.4GB of extra memory in a 8-GPU setup). H/t Dan. + # l2_promote() + except: + pass + + +class GpuAffinity(Callback): + """Set GPU affinity and increase the L2 fetch granularity. + Adapted from https://github.com/NVIDIA/DeepLearningExamples/tree/master/PyTorch/LanguageModeling/Transformer-XL + """ + + def setup(self, trainer: Trainer, pl_module: LightningModule, stage=None) -> None: + set_affinity(trainer) diff --git a/training/src/callbacks/loss_scale_monitor.py b/training/src/callbacks/loss_scale_monitor.py new file mode 100644 index 000000000..81d325f4e --- /dev/null +++ b/training/src/callbacks/loss_scale_monitor.py @@ -0,0 +1,32 @@ +# Adapted from https://github.com/Lightning-AI/lightning/blob/master/src/pytorch_lightning/callbacks/lr_monitor.py. +from typing import Any + +from pytorch_lightning import Callback, Trainer +from pytorch_lightning.utilities import rank_zero_only +from pytorch_lightning.strategies import DeepSpeedStrategy + + +class LossScaleMonitor(Callback): + """Monitor the loss scale for AMP (fp16). + """ + + # Use on_before_optimizer_step instead of on_train_batch_start since there might be + # gradient accumulation and we only care about the loss scale when it could change (i.e., + # optimizer.step). 
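For context, a standalone sketch of where the dynamic loss scale comes from in native AMP, outside Lightning (assumes a CUDA device is available); the callback below reads the same quantity through `trainer.precision_plugin.scaler.get_scale()`.

```python
import torch

model = torch.nn.Linear(16, 16).cuda()
opt = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = torch.cuda.amp.GradScaler()

for step in range(3):
    x = torch.randn(8, 16, device='cuda')
    with torch.cuda.amp.autocast():
        loss = model(x).pow(2).mean()
    scaler.scale(loss).backward()    # gradients carry the current scale factor
    scaler.step(opt)                 # unscales, and skips the step if inf/nan gradients appear
    scaler.update()                  # grows or shrinks the scale dynamically
    opt.zero_grad()
    print(step, scaler.get_scale())  # the value logged as 'scaler/scale'
```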
+ @rank_zero_only + def on_before_optimizer_step(self, trainer: Trainer, *args: Any, **kwargs: Any) -> None: + if not trainer._logger_connector.should_update_logs: + return + stats = {} + if isinstance(trainer.strategy, DeepSpeedStrategy): + stats = {'scalar/scale': trainer.model.optimizer.loss_scale} + if hasattr(trainer, 'precision_plugin') and hasattr(trainer.precision_plugin, 'scaler'): + scaler = trainer.precision_plugin.scaler + if scaler is not None: + stats = { + 'scaler/scale': scaler.get_scale(), + 'scaler/growth_tracker': scaler._get_growth_tracker(), + } + if stats and trainer.loggers is not None: + for logger in trainer.loggers: + logger.log_metrics(stats, step=trainer.fit_loop.epoch_loop._batches_that_stepped) diff --git a/training/src/callbacks/model_checkpoint.py b/training/src/callbacks/model_checkpoint.py new file mode 100644 index 000000000..09a2e91ad --- /dev/null +++ b/training/src/callbacks/model_checkpoint.py @@ -0,0 +1,36 @@ +# Adapted from https://github.com/Lightning-AI/lightning/blob/master/src/pytorch_lightning/callbacks/fault_tolerance.py +from typing import Any +from pathlib import Path + +import pytorch_lightning as pl + + +class ModelCheckpointMine(pl.callbacks.model_checkpoint.ModelCheckpoint): + + def __init__(self, *args, fault_tolerant=False, **kwargs): + super().__init__(*args, **kwargs) + self.fault_tolerant = fault_tolerant + + def on_exception(self, trainer: "pl.Trainer", *_: Any, **__: Any) -> None: + if self.fault_tolerant: + # overwrite if necessary + trainer.save_checkpoint(str(Path(self.dirpath) / '.pl_auto_save.ckpt')) + + # def teardown(self, trainer: "pl.Trainer", *_: Any, **__: Any) -> None: + # if self.fault_tolerant: + # trainer.strategy.remove_checkpoint(str(Path(self.dirpath) / '.pl_auto_save.ckpt')) + + +# TD [2022-07-17] I was trying to make resuming from standard checkpoint fault-tolerant. +# However, when it resumes it's off by 1 iteration. My attempt to fix it in seq.py (below) didn't work. +# So I decided to just copy _FaultToleranceCheckpoint and just save on_exception. + + # def on_save_checkpoint(self, checkpoint): + # # TD [2022-07-12] The "completed" counter is off by 1 so when it resumes + # # it's off by 1 iteration. However, the data is still off by 1 iteration, probably + # # because the dataloader_state_dict['counter'] is off by @batch_size, and idk how + # # to fix it cleanly. + # checkpoint['loops']['fit_loop']['epoch_loop.batch_progress']['total']['completed'] += 1 + # checkpoint['loops']['fit_loop']['epoch_loop.batch_progress']['current']['completed'] += 1 + # checkpoint['loops']['fit_loop']['epoch_loop.state_dict']['_batches_that_stepped'] += 1 + # checkpoint['loops']['fit_loop']['epoch_loop.state_dict']['dataloader_state_dict'][0]['state'][0]['num_batches_fetched'] += 1 diff --git a/training/src/callbacks/norm_monitor.py b/training/src/callbacks/norm_monitor.py new file mode 100644 index 000000000..3a8943cb8 --- /dev/null +++ b/training/src/callbacks/norm_monitor.py @@ -0,0 +1,79 @@ +# Inspired by https://github.com/Lightning-AI/lightning/blob/master/src/pytorch_lightning/utilities/grads.py +# However, they compute grad at every iteration (I think), and the .item() calls incur a lot of overhead +# (6-7% slow down on GPT-2 small). Instead we only compute for iterations where we need to log, and don't +# call .item() explicitly. 
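A small sketch of the pattern described in the comment above: build the stats dict out of 0-dim tensors and let the logger convert them when it consumes them, instead of paying a GPU sync per explicit `.item()` call. The two-layer model here is a made-up stand-in.

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(32, 32), nn.LayerNorm(32))
model(torch.randn(4, 32)).sum().backward()

stats, param_l1 = {}, []
for name, p in model.named_parameters():
    p_abs = p.detach().abs()
    stats[f'stats/{name}_max'] = p_abs.max()                       # 0-dim tensors, no .item()
    stats[f'stats/{name}_mean'] = p_abs.mean(dtype=torch.float32)
    param_l1.append(stats[f'stats/{name}_mean'] * p.numel())
stats['total_param_l1_norm'] = torch.stack(param_l1).sum()

# Conversion to Python floats is deferred until the values are actually logged
# (the print below forces it, just to show the result).
print({k: float(v) for k, v in stats.items()})
```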
+ +from typing import Any +from collections import OrderedDict + +from pytorch_lightning import Callback, Trainer +from pytorch_lightning.utilities import rank_zero_only +from pytorch_lightning.strategies import DeepSpeedStrategy + +import torch +import torch.nn as nn + +try: + from apex.contrib.layer_norm import FastLayerNorm +except ImportError: + FastLayerNorm = None + + +class NormMonitor(Callback): + """Monitor the scales of weights and gradients. + """ + + def __init__(self, layer_norm_only: bool = False): + super().__init__() + self.layer_norm_only = layer_norm_only + + # Use on_before_optimizer_step instead of on_train_batch_start since there might be + # gradient accumulation and we only care about scale when it could change (i.e., optimizer.step). + @rank_zero_only + def on_before_optimizer_step(self, trainer: Trainer, pl_module, *args: Any, **kwargs: Any) -> None: + if not trainer._logger_connector.should_update_logs: + return + model = pl_module.model + named_parameters = {} + if self.layer_norm_only: + ln_modules = (nn.LayerNorm, nn.Embedding) + if FastLayerNorm is not None: + ln_modules += (FastLayerNorm,) + for mn, m in model.named_modules(): + if isinstance(m, ln_modules): + for pn, p in m.named_parameters(): + fpn = '%s.%s' % (mn, pn) if mn else pn # full param name + named_parameters[fpn] = p + else: + named_parameters = dict(model.named_parameters()) + + if isinstance(trainer.strategy, DeepSpeedStrategy): + loss_scale = trainer.model.optimizer.loss_scale + else: + loss_scale = 1.0 + + stats = {} + param_l1_norm, grad_l1_norm = [], [] + for param_name, param in named_parameters.items(): + param_abs = param.abs() + param_abs_mean = param_abs.mean(dtype=torch.float32) + stats[f'stats/{param_name}_max'] = param_abs.max() + stats[f'stats/{param_name}_mean'] = param_abs_mean + param_l1_norm.append(param_abs_mean * param.numel()) + if param.grad is not None: + # If using AMP, gradient is already unscaled by the AMP loss scaler at this point + # https://github.com/Lightning-AI/lightning/pull/9606 + # However, if using DeepSpeed, we need to scale it ourselves + param_grad_abs = param.grad.abs() + param_grad_abs_mean = param_grad_abs.mean(dtype=torch.float32) / loss_scale + stats[f'stats/{param_name}_grad_max'] = param_grad_abs.max() / loss_scale + stats[f'stats/{param_name}_grad_mean'] = param_grad_abs_mean + grad_l1_norm.append(param_grad_abs_mean * param.grad.numel()) + stats['total_param_l1_norm'] = torch.stack(param_l1_norm).sum() + if grad_l1_norm: + stats['total_grad_l1_norm'] = torch.stack(grad_l1_norm).sum() + # Sort by params name + stats = OrderedDict(sorted(stats.items())) + if trainer.loggers is not None: + for logger in trainer.loggers: + logger.log_metrics(stats, step=trainer.fit_loop.epoch_loop._batches_that_stepped) diff --git a/training/src/callbacks/params_log.py b/training/src/callbacks/params_log.py new file mode 100644 index 000000000..c594b19c8 --- /dev/null +++ b/training/src/callbacks/params_log.py @@ -0,0 +1,34 @@ +from typing import Any + +from pytorch_lightning import Callback, Trainer, LightningModule +from pytorch_lightning.utilities import rank_zero_only +from pytorch_lightning.utilities.parsing import AttributeDict + + +class ParamsLog(Callback): + """Log the number of parameters of the model + """ + def __init__(self, total_params_log: bool = True, trainable_params_log: bool = True, + non_trainable_params_log: bool = True): + super().__init__() + self._log_stats = AttributeDict( + { + 'total_params_log': total_params_log, + 'trainable_params_log': 
trainable_params_log, + 'non_trainable_params_log': non_trainable_params_log, + } + ) + + @rank_zero_only + def on_fit_start(self, trainer: Trainer, pl_module: LightningModule) -> None: + logs = {} + if self._log_stats.total_params_log: + logs["model/params_total"] = sum(p.numel() for p in pl_module.parameters()) + if self._log_stats.trainable_params_log: + logs["model/params_trainable"] = sum(p.numel() for p in pl_module.parameters() + if p.requires_grad) + if self._log_stats.non_trainable_params_log: + logs["model/params_not_trainable"] = sum(p.numel() for p in pl_module.parameters() + if not p.requires_grad) + if trainer.logger is not None: + trainer.logger.log_hyperparams(logs) diff --git a/training/src/callbacks/speed_monitor.py b/training/src/callbacks/speed_monitor.py new file mode 100644 index 000000000..b17a09cfb --- /dev/null +++ b/training/src/callbacks/speed_monitor.py @@ -0,0 +1,95 @@ +# Adapted from https://pytorch-lightning.readthedocs.io/en/latest/_modules/pytorch_lightning/callbacks/gpu_stats_monitor.html#GPUStatsMonitor +# We only need the speed monitoring, not the GPU monitoring +import time +from typing import Any + +from pytorch_lightning import Callback, Trainer +from pytorch_lightning.utilities import rank_zero_only +from pytorch_lightning.utilities.parsing import AttributeDict +from pytorch_lightning.utilities.types import STEP_OUTPUT + + +class SpeedMonitor(Callback): + """Monitor the speed of each step and each epoch. + """ + def __init__(self, intra_step_time: bool = True, inter_step_time: bool = True, + epoch_time: bool = True, verbose=False): + super().__init__() + self._log_stats = AttributeDict( + { + 'intra_step_time': intra_step_time, + 'inter_step_time': inter_step_time, + 'epoch_time': epoch_time, + } + ) + self.verbose = verbose + + def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + self._snap_epoch_time = None + + def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + self._snap_intra_step_time = None + self._snap_inter_step_time = None + self._snap_epoch_time = time.time() + + def on_validation_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + self._snap_inter_step_time = None + + def on_test_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + self._snap_inter_step_time = None + + @rank_zero_only + def on_train_batch_start( + self, + trainer: "pl.Trainer", + pl_module: "pl.LightningModule", + batch: Any, + batch_idx: int, + ) -> None: + if self._log_stats.intra_step_time: + self._snap_intra_step_time = time.time() + + if not trainer._logger_connector.should_update_logs: + return + + logs = {} + if self._log_stats.inter_step_time and self._snap_inter_step_time: + # First log at beginning of second step + logs["time/inter_step (ms)"] = (time.time() - self._snap_inter_step_time) * 1000 + + if trainer.logger is not None: + trainer.logger.log_metrics(logs, step=trainer.global_step) + + @rank_zero_only + def on_train_batch_end( + self, + trainer: "pl.Trainer", + pl_module: "pl.LightningModule", + outputs: STEP_OUTPUT, + batch: Any, + batch_idx: int, + ) -> None: + if self._log_stats.inter_step_time: + self._snap_inter_step_time = time.time() + + if self.verbose and self._log_stats.intra_step_time and self._snap_intra_step_time: + pl_module.print(f"time/intra_step (ms): {(time.time() - self._snap_intra_step_time) * 1000}") + + if not trainer._logger_connector.should_update_logs: + return + + logs = {} + if 
self._log_stats.intra_step_time and self._snap_intra_step_time: + logs["time/intra_step (ms)"] = (time.time() - self._snap_intra_step_time) * 1000 + + if trainer.logger is not None: + trainer.logger.log_metrics(logs, step=trainer.global_step) + + @rank_zero_only + def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule",) -> None: + logs = {} + if self._log_stats.epoch_time and self._snap_epoch_time: + logs["time/epoch (s)"] = time.time() - self._snap_epoch_time + if trainer.logger is not None: + trainer.logger.log_metrics(logs, step=trainer.global_step) + diff --git a/training/src/callbacks/wandb_callbacks.py b/training/src/callbacks/wandb_callbacks.py new file mode 100644 index 000000000..4e0a46523 --- /dev/null +++ b/training/src/callbacks/wandb_callbacks.py @@ -0,0 +1,289 @@ +import subprocess +from pathlib import Path +from typing import List + +import matplotlib.pyplot as plt +import seaborn as sn +import torch +import wandb +from pytorch_lightning import Callback, Trainer +from pytorch_lightning.loggers import LoggerCollection, WandbLogger +from pytorch_lightning.utilities import rank_zero_only +from sklearn import metrics +from sklearn.metrics import f1_score, precision_score, recall_score + + +def get_wandb_logger(trainer: Trainer) -> WandbLogger: + """Safely get Weights&Biases logger from Trainer.""" + + if trainer.fast_dev_run: + raise Exception( + "Cannot use wandb callbacks since pytorch lightning disables loggers in `fast_dev_run=true` mode." + ) + + if isinstance(trainer.logger, WandbLogger): + return trainer.logger + + if isinstance(trainer.logger, LoggerCollection): + for logger in trainer.logger: + if isinstance(logger, WandbLogger): + return logger + + raise Exception( + "You are using wandb related callback, but WandbLogger was not found for some reason..." + ) + + +class WatchModel(Callback): + """Make wandb watch model at the beginning of the run.""" + + def __init__(self, log: str = "gradients", log_freq: int = 100): + self.log = log + self.log_freq = log_freq + + @rank_zero_only + def on_train_start(self, trainer, pl_module): + logger = get_wandb_logger(trainer=trainer) + logger.watch(model=trainer.model, log=self.log, log_freq=self.log_freq) + + +class UploadCodeAsArtifact(Callback): + """Upload all code files to wandb as an artifact, at the beginning of the run.""" + + def __init__(self, code_dir: str, use_git: bool = True): + """ + + Args: + code_dir: the code directory + use_git: if using git, then upload all files that are not ignored by git. 
+ if not using git, then upload all '*.py' file + """ + self.code_dir = code_dir + self.use_git = use_git + + @rank_zero_only + def on_train_start(self, trainer, pl_module): + logger = get_wandb_logger(trainer=trainer) + experiment = logger.experiment + + code = wandb.Artifact("project-source", type="code") + + if self.use_git: + # get .git folder + # https://alexwlchan.net/2020/11/a-python-function-to-ignore-a-path-with-git-info-exclude/ + git_dir_path = Path( + subprocess.check_output(["git", "rev-parse", "--git-dir"]).strip().decode("utf8") + ).resolve() + + for path in Path(self.code_dir).resolve().rglob("*"): + if ( + path.is_file() + # ignore files in .git + and not str(path).startswith(str(git_dir_path)) # noqa: W503 + # ignore files ignored by git + and ( # noqa: W503 + subprocess.run(["git", "check-ignore", "-q", str(path)]).returncode == 1 + ) + ): + code.add_file(str(path), name=str(path.relative_to(self.code_dir))) + + else: + for path in Path(self.code_dir).resolve().rglob("*.py"): + code.add_file(str(path), name=str(path.relative_to(self.code_dir))) + + experiment.log_artifact(code) + + +class UploadCheckpointsAsArtifact(Callback): + """Upload checkpoints to wandb as an artifact, at the end of run.""" + + def __init__(self, ckpt_dir: str = "checkpoints/", upload_best_only: bool = False): + self.ckpt_dir = ckpt_dir + self.upload_best_only = upload_best_only + + @rank_zero_only + def on_keyboard_interrupt(self, trainer, pl_module): + self.on_train_end(trainer, pl_module) + + @rank_zero_only + def on_train_end(self, trainer, pl_module): + logger = get_wandb_logger(trainer=trainer) + experiment = logger.experiment + + ckpts = wandb.Artifact("experiment-ckpts", type="checkpoints") + + if self.upload_best_only: + ckpts.add_file(trainer.checkpoint_callback.best_model_path) + else: + for path in Path(self.ckpt_dir).rglob("*.ckpt"): + ckpts.add_file(str(path)) + + experiment.log_artifact(ckpts) + + +class LogConfusionMatrix(Callback): + """Generate confusion matrix every epoch and send it to wandb. + Expects validation step to return predictions and targets. 
+ """ + + def __init__(self): + self.preds = [] + self.targets = [] + self.ready = True + + def on_sanity_check_start(self, trainer, pl_module) -> None: + self.ready = False + + def on_sanity_check_end(self, trainer, pl_module): + """Start executing this callback only after all validation sanity checks end.""" + self.ready = True + + def on_validation_batch_end( + self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx + ): + """Gather data from single batch.""" + if self.ready: + self.preds.append(outputs["preds"]) + self.targets.append(outputs["targets"]) + + def on_validation_epoch_end(self, trainer, pl_module): + """Generate confusion matrix.""" + if self.ready: + logger = get_wandb_logger(trainer) + experiment = logger.experiment + + preds = torch.cat(self.preds).cpu().numpy() + targets = torch.cat(self.targets).cpu().numpy() + + confusion_matrix = metrics.confusion_matrix(y_true=targets, y_pred=preds) + + # set figure size + plt.figure(figsize=(14, 8)) + + # set labels size + sn.set(font_scale=1.4) + + # set font size + sn.heatmap(confusion_matrix, annot=True, annot_kws={"size": 8}, fmt="g") + + # names should be uniqe or else charts from different experiments in wandb will overlap + experiment.log({f"confusion_matrix/{experiment.name}": wandb.Image(plt)}, commit=False) + + # according to wandb docs this should also work but it crashes + # experiment.log(f{"confusion_matrix/{experiment.name}": plt}) + + # reset plot + plt.clf() + + self.preds.clear() + self.targets.clear() + + +class LogF1PrecRecHeatmap(Callback): + """Generate f1, precision, recall heatmap every epoch and send it to wandb. + Expects validation step to return predictions and targets. + """ + + def __init__(self, class_names: List[str] = None): + self.preds = [] + self.targets = [] + self.ready = True + + def on_sanity_check_start(self, trainer, pl_module): + self.ready = False + + def on_sanity_check_end(self, trainer, pl_module): + """Start executing this callback only after all validation sanity checks end.""" + self.ready = True + + def on_validation_batch_end( + self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx + ): + """Gather data from single batch.""" + if self.ready: + self.preds.append(outputs["preds"]) + self.targets.append(outputs["targets"]) + + def on_validation_epoch_end(self, trainer, pl_module): + """Generate f1, precision and recall heatmap.""" + if self.ready: + logger = get_wandb_logger(trainer=trainer) + experiment = logger.experiment + + preds = torch.cat(self.preds).cpu().numpy() + targets = torch.cat(self.targets).cpu().numpy() + f1 = f1_score(targets, preds, average=None) + r = recall_score(targets, preds, average=None) + p = precision_score(targets, preds, average=None) + data = [f1, p, r] + + # set figure size + plt.figure(figsize=(14, 3)) + + # set labels size + sn.set(font_scale=1.2) + + # set font size + sn.heatmap( + data, + annot=True, + annot_kws={"size": 10}, + fmt=".3f", + yticklabels=["F1", "Precision", "Recall"], + ) + + # names should be uniqe or else charts from different experiments in wandb will overlap + experiment.log({f"f1_p_r_heatmap/{experiment.name}": wandb.Image(plt)}, commit=False) + + # reset plot + plt.clf() + + self.preds.clear() + self.targets.clear() + + +class LogImagePredictions(Callback): + """Logs a validation batch and their predictions to wandb. 
+ Example adapted from: + https://wandb.ai/wandb/wandb-lightning/reports/Image-Classification-using-PyTorch-Lightning--VmlldzoyODk1NzY + """ + + def __init__(self, num_samples: int = 8): + super().__init__() + self.num_samples = num_samples + self.ready = True + + def on_sanity_check_start(self, trainer, pl_module): + self.ready = False + + def on_sanity_check_end(self, trainer, pl_module): + """Start executing this callback only after all validation sanity checks end.""" + self.ready = True + + def on_validation_epoch_end(self, trainer, pl_module): + if self.ready: + logger = get_wandb_logger(trainer=trainer) + experiment = logger.experiment + + # get a validation batch from the validation dat loader + val_samples = next(iter(trainer.datamodule.val_dataloader())) + val_imgs, val_labels = val_samples + + # run the batch through the network + val_imgs = val_imgs.to(device=pl_module.device) + logits = pl_module(val_imgs) + preds = torch.argmax(logits, dim=-1) + + # log the images as wandb Image + experiment.log( + { + f"Images/{experiment.name}": [ + wandb.Image(x, caption=f"Pred:{pred}, Label:{y}") + for x, pred, y in zip( + val_imgs[: self.num_samples], + preds[: self.num_samples], + val_labels[: self.num_samples], + ) + ] + } + ) diff --git a/training/src/datamodules/datasets/detokenizer.py b/training/src/datamodules/datasets/detokenizer.py new file mode 100644 index 000000000..fe0d1b77b --- /dev/null +++ b/training/src/datamodules/datasets/detokenizer.py @@ -0,0 +1,53 @@ +# Copied from https://github.com/stanford-crfm/mistral/blob/main/src/corpora/detokenization.py +# Which was originally from https://github.com/NVIDIA/Megatron-LM/blob/aed2f75e209e525c842aec7c044af7acae2a4614/tasks/zeroshot_gpt/detokenizer.py + +""" +Handle detokenization for different dataset for zero-shot LM evaluation. +""" +import re + + +def wikitext_detokenize(string: str) -> str: + """ + Wikitext is whitespace tokenized and we remove these whitespaces. + Taken from https://github.com/NVIDIA/Megatron-LM/blob/main/tasks/zeroshot_gpt2/detokenizer.py + """ + # Contractions + string = string.replace("s '", "s'") + string = re.sub(r"/' [0-9]/", r"/'[0-9]/", string) + + # Number Separators + string = string.replace(" @-@ ", "-") + string = string.replace(" @,@ ", ",") + string = string.replace(" @.@ ", ".") + + # Punctuation + string = string.replace(" : ", ": ") + string = string.replace(" ; ", "; ") + string = string.replace(" . ", ". ") + string = string.replace(" ! ", "! ") + string = string.replace(" ? ", "? 
") + string = string.replace(" , ", ", ") + + # Double Brackets + string = re.sub(r"\(\s*([^\)]*?)\s*\)", r"(\1)", string) + string = re.sub(r"\[\s*([^\]]*?)\s*\]", r"[\1]", string) + string = re.sub(r"{\s*([^}]*?)\s*}", r"{\1}", string) + string = re.sub(r"\"\s*([^\"]*?)\s*\"", r'"\1"', string) + string = re.sub(r"'\s*([^']*?)\s*'", r"'\1'", string) + + # Miscellaneous + string = string.replace("= = = =", "====") + string = string.replace("= = =", "===") + string = string.replace("= =", "==") + string = string.replace(" " + chr(176) + " ", chr(176)) + string = string.replace(" \n", "\n") + string = string.replace("\n ", "\n") + string = string.replace(" N ", " 1 ") + string = string.replace(" 's", "'s") + + return string + + +# Set Registry for Various Datasets +DATASET_TOKENIZATION_REGISTRY = {"wikitext": wikitext_detokenize} diff --git a/training/src/datamodules/datasets/lm_dataset.py b/training/src/datamodules/datasets/lm_dataset.py new file mode 100644 index 000000000..72bd956e8 --- /dev/null +++ b/training/src/datamodules/datasets/lm_dataset.py @@ -0,0 +1,32 @@ +# Inspired by https://github.com/NVIDIA/Megatron-LM/blob/main/tasks/zeroshot_gpt/datasets.py +# Except we don't pad the last block and don't use overlapping eval +# And we return both the input and the target +import math +import numpy as np + +import torch + + +class LMDataset(torch.utils.data.Dataset): + + def __init__(self, tokens, seq_len, drop_last=True): + """tokens should be a numpy array + """ + self.seq_len = seq_len + ntokens = len(tokens) + if drop_last: + ntokens = ((ntokens - 1) // seq_len) * seq_len + 1 + self.ntokens = ntokens + # We're careful not to slice tokens, since it could be a memmap'ed array or H5 dataset, + # and slicing would load it to memory. + self.tokens = tokens + self.total_sequences = math.ceil((self.ntokens - 1) / self.seq_len) + + def __len__(self): + return self.total_sequences + + def __getitem__(self, idx): + start_idx = idx * self.seq_len + seq_len = min(self.seq_len, self.ntokens - 1 - start_idx) + data = torch.as_tensor(self.tokens[start_idx:(start_idx + seq_len + 1)].astype(np.int64)) + return data[:-1], data[1:].clone() diff --git a/training/src/datamodules/fault_tolerant_sampler.py b/training/src/datamodules/fault_tolerant_sampler.py new file mode 100644 index 000000000..11a157182 --- /dev/null +++ b/training/src/datamodules/fault_tolerant_sampler.py @@ -0,0 +1,123 @@ +# Adapted from https://github.com/Lightning-AI/lightning/blob/2845e7565dbe6b765ae32870e7d2bc456529c30a/tests/tests_pytorch/utilities/test_auto_restart.py#L1397 +from typing import Iterator +import math + +import torch +from torch.utils.data import RandomSampler, DistributedSampler + + +class RandomFaultTolerantSampler(RandomSampler): + + def __init__(self, *args, generator=None, **kwargs): + # generator = torch.Generator().manual_seed(seed) + # super().__init__(*args, generator=generator, **kwargs) + # TD [2022-07-17]: We don't force the seed to be zero. We generate random seed, + # which should be reproducible if pl.seed_everything was called before hand. + # This means that changing the seed of the experiment will also change the + # sampling order. 
+ if generator is None: + seed = int(torch.empty((), dtype=torch.int64).random_().item()) + generator = torch.Generator().manual_seed(seed) + super().__init__(*args, generator=generator, **kwargs) + self.counter = 0 + # self.start_counter = 0 + self.restarting = False + + def state_dict(self): + return {"random_state": self.state, "counter": self.counter} + + def load_state_dict(self, state_dict): + self.generator.set_state(state_dict.get("random_state")) + self.counter = state_dict["counter"] + # self.start_counter = self.counter + self.restarting = True + + # TD [2022-08-28] Setting the len will cause PL to think there are only a few batches left per + # epoch, and subsequent epoch will have very few batches. + # def __len__(self): + # # We need a separate self.start_counter because PL seems to call len repeatedly. + # # If we use len(self.data_source) - self.counter then PL will think the epoch ends + # # when we're only half way through. + # return len(self.data_source) - self.start_counter + + def __iter__(self) -> Iterator[int]: + n = len(self.data_source) + + self.state = self.generator.get_state() + indices = torch.randperm(n, generator=self.generator).tolist() + + if not self.restarting: + self.counter = 0 + else: + indices = indices[self.counter:] + self.restarting = False + # self.start_counter = self.counter + + for index in indices: + self.counter += 1 + yield index + + self.counter = 0 + # self.start_counter = self.counter + + +class FaultTolerantDistributedSampler(DistributedSampler): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.counter = 0 + # self.start_counter = 0 + self.restarting = False + + def state_dict(self): + return {"epoch": self.epoch, "counter": self.counter} + + def load_state_dict(self, state_dict): + self.epoch = state_dict["epoch"] + self.counter = state_dict["counter"] + # self.start_counter = self.counter + self.restarting = True + + # TD [2022-08-28] Setting the len will cause PL to think there are only a few batches left per + # epoch, and subsequent epoch will have very few batches. + # def __len__(self) -> int: + # return self.num_samples - self.start_counter + + def __iter__(self): + if self.shuffle: + # deterministically shuffle based on epoch and seed + g = torch.Generator() + g.manual_seed(self.seed + self.epoch) + indices = torch.randperm(len(self.dataset), generator=g).tolist() # type: ignore[arg-type] + else: + indices = list(range(len(self.dataset))) # type: ignore[arg-type] + + if not self.drop_last: + # add extra samples to make it evenly divisible + padding_size = self.total_size - len(indices) + if padding_size <= len(indices): + indices += indices[:padding_size] + else: + indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size] + else: + # remove tail of data to make it evenly divisible. 
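The fault tolerance of the samplers in this file comes down to saving the generator state (or epoch) plus a consumed-sample counter and fast-forwarding after a restart. A toy illustration of that contract (simplified; the samplers here also re-snapshot the state at the start of every epoch in `__iter__`):

```python
import torch

g = torch.Generator().manual_seed(1234)
state = g.get_state()                             # what state_dict() saves as "random_state"
order = torch.randperm(10, generator=g).tolist()
consumed = 4                                      # pretend the job died after 4 samples

g_resume = torch.Generator()
g_resume.set_state(state)                         # what load_state_dict() restores
resumed = torch.randperm(10, generator=g_resume).tolist()[consumed:]
assert resumed == order[consumed:]                # training resumes on the same order
print(order, resumed)
```

The pad-or-trim and rank-striding arithmetic in this `__iter__` is also easy to check in isolation. A plain-Python toy (simplified: it pads by wrapping around once, whereas the code below also handles a padding size larger than the dataset):

```python
import math

def shard(indices, rank, num_replicas, drop_last):
    if drop_last:
        total = (len(indices) // num_replicas) * num_replicas
        indices = indices[:total]                            # drop the uneven tail
    else:
        total = math.ceil(len(indices) / num_replicas) * num_replicas
        indices = indices + indices[:total - len(indices)]   # pad by wrapping around
    return indices[rank:total:num_replicas]                  # disjoint stride per rank

data = list(range(10))
print([shard(data, r, num_replicas=3, drop_last=False) for r in range(3)])
# -> [[0, 3, 6, 9], [1, 4, 7, 0], [2, 5, 8, 1]]
print([shard(data, r, num_replicas=3, drop_last=True) for r in range(3)])
# -> [[0, 3, 6], [1, 4, 7], [2, 5, 8]]
```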
+ indices = indices[:self.total_size] + assert len(indices) == self.total_size + + # subsample + indices = indices[self.rank:self.total_size:self.num_replicas] + assert len(indices) == self.num_samples + + if not self.restarting: + self.counter = 0 + else: + indices = indices[self.counter:] + self.restarting = False + # self.start_counter = self.counter + + for index in indices: + self.counter += 1 + yield index + + self.counter = 0 + # self.start_counter = self.counter diff --git a/training/src/datamodules/imagenet.py b/training/src/datamodules/imagenet.py new file mode 100644 index 000000000..e36bca847 --- /dev/null +++ b/training/src/datamodules/imagenet.py @@ -0,0 +1,283 @@ +# Adapted from https://github.com/PyTorchLightning/lightning-bolts/blob/master/pl_bolts/datamodules/imagenet_datamodule.py +import os +from pathlib import Path +from typing import Any, List, Union, Callable, Optional + +import torch +from torch.utils.data import Dataset, DataLoader, SequentialSampler +from torch.utils.data.dataloader import default_collate +from torch.utils.data.distributed import DistributedSampler + +from pytorch_lightning import LightningDataModule + +from torchvision import transforms +from torchvision.datasets import ImageFolder + + +class DictDataset(Dataset): + + def __init__(self, dataset_dict, length=None): + """dataset_dict: dictionary mapping from index to batch + length is used in the case of DistributedSampler: e.g. the dataset could have size 1k, but + with 8 GPUs the dataset_dict would only have 125 items. + """ + super().__init__() + self.dataset_dict = dataset_dict + self.length = length or len(self.dataset_dict) + + def __getitem__(self, index): + return self.dataset_dict[index] + + def __len__(self): + return self.length + + +# From https://github.com/PyTorchLightning/lightning-bolts/blob/2415b49a2b405693cd499e09162c89f807abbdc4/pl_bolts/transforms/dataset_normalizations.py#L10 +def imagenet_normalization(): + return transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + + +class ImagenetDataModule(LightningDataModule): + """ + .. figure:: https://3qeqpr26caki16dnhd19sv6by6v-wpengine.netdna-ssl.com/wp-content/uploads/2017/08/ + Sample-of-Images-from-the-ImageNet-Dataset-used-in-the-ILSVRC-Challenge.png + :width: 400 + :alt: Imagenet + Specs: + - 1000 classes + - Each image is (3 x varies x varies) (here we default to 3 x 224 x 224) + Imagenet train, val and test dataloaders. + The train set is the imagenet train. + The val set is taken from the train set with `num_imgs_per_val_class` images per class. + For example if `num_imgs_per_val_class=2` then there will be 2,000 images in the validation set. + The test set is the official imagenet validation set. 
+ Example:: + from pl_bolts.datamodules import ImagenetDataModule + dm = ImagenetDataModule(IMAGENET_PATH) + model = LitModel() + Trainer().fit(model, datamodule=dm) + """ + + name = "imagenet" + + def __init__( + self, + data_dir: str, + image_size: int = 224, + train_transforms=None, + val_transforms=None, + test_transforms=None, + img_dtype='float32', # Using str since OmegaConf doesn't support non-primitive type + cache_val_dataset=False, + mixup: Optional[Callable] = None, + num_aug_repeats: int = 0, + num_workers: int = 0, + batch_size: int = 32, + batch_size_eval: Optional[int] = None, + shuffle: bool = True, + pin_memory: bool = True, + drop_last: bool = False, + *args: Any, + **kwargs: Any, + ) -> None: + """ + Args: + data_dir: path to the imagenet dataset file + num_imgs_per_val_class: how many images per class for the validation set + image_size: final image size + num_workers: how many data workers + batch_size: batch_size + shuffle: If true shuffles the data every epoch + pin_memory: If true, the data loader will copy Tensors into CUDA pinned memory before + returning them + drop_last: If true drops the last incomplete batch + """ + super().__init__(*args, **kwargs) + + self.image_size = image_size + self.train_transforms = train_transforms + self.val_transforms = val_transforms + self.test_transforms = test_transforms + assert img_dtype in ['float32', 'float16', 'bfloat16'] + self.img_dtype = torch.__getattribute__(img_dtype) + self.cache_val_dataset = cache_val_dataset + self.mixup = mixup + self.num_aug_repeats = num_aug_repeats + self.dims = (3, self.image_size, self.image_size) + self.data_dir = Path(data_dir).expanduser() + self.num_workers = num_workers + self.batch_size = batch_size + self.batch_size_eval = batch_size_eval if batch_size_eval is not None else self.batch_size + self.shuffle = shuffle + self.pin_memory = pin_memory + self.drop_last = drop_last + + @property + def num_classes(self) -> int: + """ + Return: + 1000 + """ + return 1000 + + def _verify_splits(self, data_dir: str, split: str) -> None: + dirs = os.listdir(data_dir) + + if split not in dirs: + raise FileNotFoundError( + f"a {split} Imagenet split was not found in {data_dir}," + f" make sure the folder contains a subfolder named {split}" + ) + + def prepare_data(self) -> None: + """This method already assumes you have imagenet2012 downloaded. It validates the data using the meta.bin. + .. warning:: Please download imagenet on your own first. 
+ """ + self._verify_splits(self.data_dir, "train") + self._verify_splits(self.data_dir, "val") + + def setup(self, stage: Optional[str] = None) -> None: + """Creates train, val, and test dataset.""" + if stage == "fit" or stage is None: + train_transforms = (self.train_transform() if self.train_transforms is None + else self.train_transforms) + val_transforms = (self.val_transform() if self.val_transforms is None + else self.val_transforms) + if self.img_dtype is not torch.float32: + assert isinstance(train_transforms, transforms.Compose) + assert isinstance(val_transforms, transforms.Compose) + convert_dtype = transforms.Lambda(lambda x: x.to(dtype=self.img_dtype)) + train_transforms.transforms.append(convert_dtype) + val_transforms.transforms.append(convert_dtype) + self.dataset_train = ImageFolder(self.data_dir / 'train', transform=train_transforms) + self.dataset_val = ImageFolder(self.data_dir / 'val', transform=val_transforms) + + if stage == "test" or stage is None: + test_transforms = (self.val_transform() if self.test_transforms is None + else self.test_transforms) + if self.img_dtype is not torch.float32: + assert isinstance(test_transforms, transforms.Compose) + convert_dtype = transforms.Lambda(lambda x: x.to(dtype=self.img_dtype)) + test_transforms.transforms.append(convert_dtype) + self.dataset_test = ImageFolder(self.data_dir / 'val', transform=test_transforms) + + def train_transform(self) -> Callable: + """The standard imagenet transforms. + .. code-block:: python + transforms.Compose([ + transforms.RandomResizedCrop(self.image_size), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + transforms.Normalize( + mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225] + ), + ]) + """ + preprocessing = transforms.Compose( + [ + transforms.RandomResizedCrop(self.image_size), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + imagenet_normalization(), + ] + ) + + return preprocessing + + def val_transform(self) -> Callable: + """The standard imagenet transforms for validation. + .. code-block:: python + transforms.Compose([ + transforms.Resize(self.image_size + 32), + transforms.CenterCrop(self.image_size), + transforms.ToTensor(), + transforms.Normalize( + mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225] + ), + ]) + """ + + preprocessing = transforms.Compose( + [ + transforms.Resize(self.image_size + 32), + transforms.CenterCrop(self.image_size), + transforms.ToTensor(), + imagenet_normalization(), + ] + ) + return preprocessing + + def train_dataloader(self, *args: Any, **kwargs: Any) -> DataLoader: + """ The train dataloader """ + if self.num_aug_repeats == 0: + shuffle = self.shuffle + sampler = None + else: + shuffle = False + from timm.data.distributed_sampler import RepeatAugSampler + sampler = RepeatAugSampler(self.dataset_train, num_repeats=self.num_aug_repeats) + return self._data_loader(self.dataset_train, batch_size=self.batch_size, + shuffle=shuffle, mixup=self.mixup, sampler=sampler) + + def val_dataloader(self, *args: Any, **kwargs: Any) -> Union[DataLoader, List[DataLoader]]: + """ The val dataloader """ + # If using RepeatAugment, we set trainer.replace_sampler_ddp=False, so we have to + # construct the DistributedSampler ourselves. 
+ if not self.cache_val_dataset: + sampler = (DistributedSampler(self.dataset_val, shuffle=False, drop_last=self.drop_last) + if self.num_aug_repeats != 0 else None) + return self._data_loader(self.dataset_val, batch_size=self.batch_size_eval, + sampler=sampler) + else: + print('Caching val dataset') + sampler = (SequentialSampler(self.dataset_val) if self.trainer.world_size <= 1 + else DistributedSampler(self.dataset_val, shuffle=False, + drop_last=self.drop_last)) + indices = list(iter(sampler)) + loader = DataLoader(self.dataset_val, batch_size=None, shuffle=False, sampler=sampler, + num_workers=self.num_workers, drop_last=self.drop_last) + batches = list(loader) + assert len(batches) == len(indices) + self.dataset_val = DictDataset(dict(zip(indices, batches)), + length=len(self.dataset_val)) + sampler = (DistributedSampler(self.dataset_val, shuffle=False, drop_last=self.drop_last) + if self.num_aug_repeats != 0 else None) + return self._data_loader(self.dataset_val, batch_size=self.batch_size_eval, + sampler=sampler) + + def test_dataloader(self, *args: Any, **kwargs: Any) -> Union[DataLoader, List[DataLoader]]: + """ The test dataloader """ + sampler = (DistributedSampler(self.dataset_test, shuffle=False, drop_last=self.drop_last) + if self.num_aug_repeats != 0 else None) + return self._data_loader(self.dataset_test, batch_size=self.batch_size_eval, sampler=sampler) + + def _data_loader(self, dataset: Dataset, batch_size: int, shuffle: bool = False, + mixup: Optional[Callable] = None, sampler=None) -> DataLoader: + collate_fn = ((lambda batch: mixup(*default_collate(batch))) if mixup is not None + else default_collate) + return DataLoader( + dataset, + collate_fn=collate_fn, + batch_size=batch_size, + shuffle=shuffle, + sampler=sampler, + num_workers=self.num_workers, + drop_last=self.drop_last, + pin_memory=self.pin_memory, + persistent_workers=True + ) + + +class Imagenet21kPDataModule(ImagenetDataModule): + """ImageNet-21k (winter 21) processed with https://github.com/Alibaba-MIIL/ImageNet21K + """ + + @property + def num_classes(self) -> int: + """ + Return: + 10450 + """ + return 10450 diff --git a/training/src/datamodules/language_modeling_hf.py b/training/src/datamodules/language_modeling_hf.py new file mode 100644 index 000000000..eaa35d3a0 --- /dev/null +++ b/training/src/datamodules/language_modeling_hf.py @@ -0,0 +1,299 @@ +# Adapted from https://github.com/huggingface/transformers/blob/master/examples/pytorch/language-modeling/run_clm.py +from itertools import chain +from pathlib import Path +import pickle +from typing import Any, List, Union +import subprocess +import mmap + +from multiprocessing.shared_memory import SharedMemory + +import numpy as np + +import torch +from torch.utils.data.dataloader import DataLoader, Dataset +from transformers import AutoTokenizer +from datasets import load_dataset + +from pytorch_lightning import LightningDataModule + +from src.datamodules.datasets.lm_dataset import LMDataset +from src.datamodules.fault_tolerant_sampler import RandomFaultTolerantSampler +from src.datamodules.fault_tolerant_sampler import FaultTolerantDistributedSampler +from src.datamodules.datasets.detokenizer import DATASET_TOKENIZATION_REGISTRY +from src.utils.utils import get_logger +logger = get_logger() + + +# https://github.com/numpy/numpy/issues/18294 +class SHMArray(np.ndarray): #copied from https://numpy.org/doc/stable/user/basics.subclassing.html#slightly-more-realistic-example-attribute-added-to-existing-array + + def __new__(cls, input_array, shm=None): + 
obj = np.asarray(input_array).view(cls) + obj.shm = shm + return obj + + def __array_finalize__(self, obj): + if obj is None: return + self.shm = getattr(obj, 'shm', None) + + +class LMDataModule(LightningDataModule): + def __init__(self, dataset_name, tokenizer_name, dataset_config_name=None, max_length=1024, + cache_dir=None, val_ratio=0.0005, val_split_seed=2357, add_eos=True, + detokenize=False, val_only=False, batch_size=32, batch_size_eval=None, num_workers=1, + shuffle=False, pin_memory=False, drop_last=False, fault_tolerant=False, ddp=False, + fast_forward_epochs=None, fast_forward_batches=None, + use_shmem=True): + super().__init__() + self.dataset_name = dataset_name + self.dataset_config_name = dataset_config_name + self.tokenizer_name = tokenizer_name + self.cache_dir = None if cache_dir is None else Path(cache_dir).expanduser() + self.max_length = max_length + self.val_ratio = val_ratio + self.val_split_seed = val_split_seed + self.val_only = val_only + self.add_eos = add_eos + self.detokenize = detokenize + self.batch_size = batch_size + self.batch_size_eval = batch_size_eval if batch_size_eval is not None else self.batch_size + self.num_workers = num_workers + self.shuffle = shuffle + self.pin_memory = pin_memory + self.drop_last = drop_last + if fault_tolerant: + assert self.shuffle + self.fault_tolerant = fault_tolerant + if ddp: + assert fault_tolerant + self.ddp = ddp + self.fast_forward_epochs = fast_forward_epochs + self.fast_forward_batches = fast_forward_batches + if self.fast_forward_epochs is not None or self.fast_forward_batches is not None: + assert ddp and fault_tolerant + + self.use_shmem = use_shmem + if self.use_shmem: + assert cache_dir is not None + + def prepare_data(self): + if self.cache_dir is None: # Just download the dataset + load_dataset(self.dataset_name, self.dataset_config_name) + else: # Process the dataset and save it + self.process_dataset() + + def setup(self, stage=None): + if stage == 'test' and hasattr(self, 'dataset_test'): + return + concat_ids, self.tokenizer = self.process_dataset() + self.vocab_size = len(self.tokenizer) + # Create all splits + self.dataset_train, self.dataset_val, self.dataset_test = [ + LMDataset(concat_ids[split], seq_len=self.max_length) + for split in ['train', 'validation', 'test'] + ] + + def process_dataset(self): + cache_dir = None if self.cache_dir is None else self.cache_dir / self._cache_dir_name + if cache_dir is not None: + if cache_dir.is_dir(): + return self._load_from_cache(cache_dir) + + raw_datasets = load_dataset(self.dataset_name, self.dataset_config_name) + # https://github.com/stanford-crfm/mistral/blob/main/src/corpora/auto.py + if 'validation' not in raw_datasets: + assert "train" in raw_datasets, "You must have train in raw_datasets to make a validation raw_datasets" + raw_datasets = raw_datasets["train"].train_test_split( + test_size=self.val_ratio, seed=self.val_split_seed, + shuffle=True # Otherwise test will be at the end of the dataset + ) + raw_datasets['validation'] = raw_datasets['test'] + + if self.val_only: # Should only be used for evaluation, not for training + raw_datasets['train'] = raw_datasets['validation'] + + # [2021-12-25] TD: Running the detokenizer on wikitext-103 makes ppl worse + # (GPT2-small val ppl after 10 epochs ~22 -> ~25) + # However, it's useful for zero-shot transfer from Openwebtext, + # as after detokenization it's closer to Openwebtext's format. 
+ # https://github.com/stanford-crfm/mistral/issues/12 + if self.detokenize: + if self.dataset_name in DATASET_TOKENIZATION_REGISTRY: + detokenizer = DATASET_TOKENIZATION_REGISTRY[self.dataset_name] + raw_datasets = raw_datasets.map( + lambda example: {'text': detokenizer(example['text'])}, + num_proc=max(self.num_workers, 1), + desc='Running detokenizer on dataset' + ) + + tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name, use_fast=True) + # Preprocessing the datasets. + # First we tokenize all the texts. + column_names = raw_datasets["train"].column_names + text_column_name = "text" if "text" in column_names else column_names[0] + # [2021-12-25] TD: For wikitext, don't need to add the EOS since each example already ends + # with '\n', and there are no other '\n' in the examples. + # assert all([t.count('\n') == 1 for t in raw_datasets['train']['text'] if t]) + # Add EOS token to the end of the text if the text is not empty + # https://github.com/stanford-crfm/mistral/issues/91 + # https://github.com/stanford-crfm/mistral/pull/98 + if self.add_eos: + add_eos = lambda seq: (seq + tokenizer.eos_token) if seq else seq + add_eos_batched = lambda seqs: [add_eos(seq) for seq in seqs] + tokenize = lambda example: tokenizer(add_eos_batched(example[text_column_name])) + else: + tokenize = lambda example: tokenizer(example[text_column_name]) + # tokenized_datasets = raw_datasets.map( + # tokenize, + # batched=True, + # num_proc=max(self.num_workers, 1), + # remove_columns=column_names, + # desc="Running tokenizer on dataset", + # ) + dtype = np.uint16 if tokenizer.vocab_size < 64 * 1024 else np.int32 + def tokenize_concat(examples): + # We just need 'input_ids', not 'attention_mask' (since it's all 1) + input_ids = np.fromiter(chain(*tokenize(examples)['input_ids']), dtype=dtype) + # Need to return a list since we're doing batched processing + return {'input_ids': [input_ids], 'len': [len(input_ids)]} + tokenized_datasets = raw_datasets.map( + tokenize_concat, + batched=True, + num_proc=max(self.num_workers, 1), + remove_columns=column_names, + desc="Running tokenizer on dataset", + ) + + if self.use_shmem: + # Concatenate all input_ids into an array in shared memory + def write_ids_to_shm(example, shm_name, array_len): + shm = SharedMemory(name=shm_name) + shm_arr = np.ndarray((array_len,), dtype=dtype, buffer=shm.buf) + start_idx = example['len_offset'] - len(example['input_ids']) + shm_arr[start_idx:example['len_offset']] = example['input_ids'] + shm.close() + concat_ids = {} + for name, ds in tokenized_datasets.items(): + tokenized_datasets[name] = ds.add_column('len_offset', np.cumsum(ds['len'])) + array_len = tokenized_datasets[name][-1]['len_offset'] + shm = SharedMemory(create=True, size=array_len * np.dtype(dtype).itemsize) + shm_name = shm.name + tokenized_datasets[name].map( + write_ids_to_shm, + fn_kwargs={'shm_name': shm_name, 'array_len': array_len}, + batched=False, + num_proc=max(self.num_workers, 1), + desc="Concatenating examples", + ) + shm_arr = np.ndarray((array_len,), dtype=dtype, buffer=shm.buf) + # We need to keep a reference to the shared memory, otherwise it gets garbage-collected + # when it goes out of scope, and that memory is gone. 
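+ # SHMArray (defined above) keeps the SharedMemory handle on the array object itself, so the
+ # shared-memory block stays alive for as long as concat_ids[name] is referenced.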
+ # https://github.com/numpy/numpy/issues/18294 + concat_ids[name] = SHMArray(shm_arr, shm=shm) + else: + # Use disk + concat_ids = {} + assert cache_dir is not None + cache_dir.mkdir(parents=True, exist_ok=True) + def write_ids_to_disk(example, filename): + with open(filename, 'r+b') as f: + mm = mmap.mmap(f.fileno(), 0) + start_idx = example['len_offset'] - len(example['input_ids']) + array_len = len(example['input_ids']) + arr = np.ndarray((array_len,), dtype=dtype, buffer=mm, + offset=np.dtype(dtype).itemsize * start_idx) + arr[:] = example['input_ids'] + mm.flush() + for name, ds in tokenized_datasets.items(): + tokenized_datasets[name] = ds.add_column('len_offset', np.cumsum(ds['len'])) + array_len = tokenized_datasets[name][-1]['len_offset'] + filename = cache_dir / f'{name}.bin' + # Need to create the file with this specific size first + # https://ostechnix.com/create-files-certain-size-linux/ + subprocess.run(['truncate', '-s', str(array_len * np.dtype(dtype).itemsize), + str(filename)], check=True) + tokenized_datasets[name].map( + write_ids_to_disk, + fn_kwargs={'filename': filename}, + batched=False, + num_proc=max(self.num_workers, 1), + desc="Concatenating examples", + ) + concat_ids[name] = np.memmap(filename, dtype=dtype, mode='r', shape=(array_len,)) + + if cache_dir is not None: + self._save_to_cache(concat_ids, tokenizer, cache_dir) + if not self.use_shmem: + for name in concat_ids: + Path(cache_dir / f'{name}.bin').unlink() + return concat_ids, tokenizer + + def _save_to_cache(self, concat_ids, tokenizer, cache_dir): + cache_dir.mkdir(parents=True, exist_ok=True) + logger.info(f'Saving to cache at {str(cache_dir)}') + for k, v in concat_ids.items(): + np.save(cache_dir / f'{k}.npy', v) + with open(cache_dir / 'tokenizer.pkl', 'wb') as f: + pickle.dump(tokenizer, f) + + def _load_from_cache(self, cache_dir): + assert cache_dir.is_dir() + logger.info(f'Load from cache at {str(cache_dir)}') + concat_ids = {split: np.load(cache_dir / f'{split}.npy', mmap_mode='r') + for split in ['train', 'validation', 'test']} + with open(cache_dir / 'tokenizer.pkl', 'rb') as f: + tokenizer = pickle.load(f) + return concat_ids, tokenizer + + @property + def _cache_dir_name(self): + return f'tokenizer_name-{self.tokenizer_name}-val_ratio-{self.val_ratio}-val_split_seed-{self.val_split_seed}-add_eos-{self.add_eos}-detokenize-{self.detokenize}' + + def train_dataloader(self, *args: Any, **kwargs: Any) -> DataLoader: + """ The train dataloader """ + if self.shuffle and self.fault_tolerant: + shuffle = False + sampler = (FaultTolerantDistributedSampler(self.dataset_train) if self.ddp + else RandomFaultTolerantSampler(self.dataset_train)) + # TD [2022-08-06]: Only the DDP sampler supports fast-forwarding for now + # We assume that it's being resumed with the same number of GPUs + if self.ddp and self.fast_forward_epochs is not None and self.fast_forward_batches is not None: + sampler.load_state_dict({ + 'epoch': self.fast_forward_epochs, + 'counter': self.fast_forward_batches * self.batch_size + }) + else: + shuffle = self.shuffle + sampler = None + return self._data_loader(self.dataset_train, batch_size=self.batch_size, + shuffle=shuffle, sampler=sampler) + + def val_dataloader(self, *args: Any, **kwargs: Any) -> Union[DataLoader, List[DataLoader]]: + """ The val dataloader """ + return self._data_loader(self.dataset_val, batch_size=self.batch_size_eval) + + def test_dataloader(self, *args: Any, **kwargs: Any) -> Union[DataLoader, List[DataLoader]]: + """ The test dataloader """ + return 
self._data_loader(self.dataset_test, batch_size=self.batch_size_eval) + + def _data_loader(self, dataset: Dataset, batch_size: int, shuffle: bool = False, + sampler=None) -> DataLoader: + return DataLoader( + dataset, + batch_size=batch_size, + num_workers=1, # Data is already in memory, we don't need many workers + shuffle=shuffle, + sampler=sampler, + drop_last=self.drop_last, + pin_memory=self.pin_memory, + # persistent_workers=True + ) + + def load_state_dict(self, checkpoint): + if self.fault_tolerant: + self.fast_forward_epochs = checkpoint['loops']['fit_loop']['epoch_progress']['current']['completed'] + # TD [2022-08-07] ['epoch_loop.batch_progress']['total']['completed'] is 1 iteration + # behind, so we're using the optimizer's progress. This is set correctly in seq.py. + self.fast_forward_batches = checkpoint['loops']['fit_loop']['epoch_loop.batch_progress']['current']['completed'] + # At this point the train loader hasn't been constructed yet diff --git a/training/src/datamodules/timm_mixup.py b/training/src/datamodules/timm_mixup.py new file mode 100644 index 000000000..66080985f --- /dev/null +++ b/training/src/datamodules/timm_mixup.py @@ -0,0 +1,20 @@ +import torch + +from timm.data import Mixup +from timm.data.mixup import mixup_target + + +class TimmMixup(Mixup): + """ Wrap timm.data.Mixup that avoids the assert that batch size must be even. + """ + def __call__(self, x, target): + if self.mode == 'elem': + lam = self._mix_elem(x) + elif self.mode == 'pair': + # We move the assert from the beginning of the function to here + assert len(x) % 2 == 0, 'Batch size should be even when using this' + lam = self._mix_pair(x) + else: + lam = self._mix_batch(x) + target = mixup_target(target, self.num_classes, lam, self.label_smoothing, x.device) + return x, target diff --git a/training/src/distributed/ddp_comm_hooks.py b/training/src/distributed/ddp_comm_hooks.py new file mode 100644 index 000000000..ad436a9a3 --- /dev/null +++ b/training/src/distributed/ddp_comm_hooks.py @@ -0,0 +1,43 @@ +# Adapted from https://pytorch.org/docs/stable/_modules/torch/distributed/algorithms/ddp_comm_hooks/default_hooks.html +# We divide by world_size first before converting to fp16, so it's safer. +from typing import Any, Callable + +import torch +import torch.distributed as dist + + +def fp16_compress_hook( + process_group: dist.ProcessGroup, bucket: dist.GradBucket +) -> torch.futures.Future[torch.Tensor]: + """ + This DDP communication hook implements a simple gradient compression + approach that casts ``GradBucket`` tensor to half-precision floating-point format (``torch.float16``) + and then divides it by the process group size. + It allreduces those ``float16`` gradient tensors. Once compressed gradient + tensors are allreduced, the chained callback ``decompress`` casts it back to the input data type (such as ``float32``). + + Example:: + >>> ddp_model.register_comm_hook(process_group, fp16_compress_hook) + """ + group_to_use = process_group if process_group is not None else dist.group.WORLD + world_size = group_to_use.size() + + # Divide first before converting to fp16 + # Use out argument to fuse the division and the conversion. + compressed_tensor = torch.div(bucket.buffer(), world_size, + out=torch.empty_like(bucket.buffer(), dtype=torch.float16)) + + fut = dist.all_reduce( + compressed_tensor, group=group_to_use, async_op=True + ).get_future() + + def decompress(fut): + decompressed_tensor = bucket.buffer() + # Decompress in place to reduce the peak memory. 
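+ # copy_ also casts the allreduced fp16 tensor (fut.value()[0]) back to the bucket's
+ # original dtype, so no extra full-precision buffer is allocated.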
+ # See: https://github.com/pytorch/pytorch/issues/45968 + decompressed_tensor.copy_(fut.value()[0]) + return decompressed_tensor + + # TODO: maybe have a backoff strategy: check if the buffer has inf / NaN, in that case + # resend with fp32? + return fut.then(decompress) diff --git a/training/src/eval.py b/training/src/eval.py new file mode 100644 index 000000000..161a23c89 --- /dev/null +++ b/training/src/eval.py @@ -0,0 +1,129 @@ +from typing import List, Optional +from pathlib import Path + +import torch + +import hydra +from omegaconf import OmegaConf, DictConfig +from pytorch_lightning import ( + Callback, + LightningDataModule, + LightningModule, + Trainer, + seed_everything, +) +from pytorch_lightning.loggers import LightningLoggerBase + +from src.utils import utils + +log = utils.get_logger(__name__) + + +def remove_prefix(text: str, prefix: str): + if text.startswith(prefix): + return text[len(prefix) :] + return text # or whatever + + +def load_checkpoint(path, device='cpu'): + path = Path(path).expanduser() + if path.is_dir(): + path /= 'last.ckpt' + # dst = f'cuda:{torch.cuda.current_device()}' + log.info(f'Loading checkpoint from {str(path)}') + state_dict = torch.load(path, map_location=device) + # T2T-ViT checkpoint is nested in the key 'state_dict_ema' + if state_dict.keys() == {'state_dict_ema'}: + state_dict = state_dict['state_dict_ema'] + # Swin checkpoint is nested in the key 'model' + if state_dict.keys() == {'model'}: + state_dict = state_dict['model'] + # Lightning checkpoint contains extra stuff, we only want the model state dict + if 'pytorch-lightning_version' in state_dict: + state_dict = {remove_prefix(k, 'model.'): v for k, v in state_dict['state_dict'].items()} + return state_dict + + +def evaluate(config: DictConfig) -> None: + """Example of inference with trained model. + It loads trained image classification model from checkpoint. + Then it loads example image and predicts its label. + """ + + # load model from checkpoint + # model __init__ parameters will be loaded from ckpt automatically + # you can also pass some parameter explicitly to override it + + # We want to add fields to config so need to call OmegaConf.set_struct + OmegaConf.set_struct(config, False) + + # load model + checkpoint_type = config.eval.get('checkpoint_type', 'pytorch') + if checkpoint_type not in ['lightning', 'pytorch']: + raise NotImplementedError(f'checkpoint_type ${checkpoint_type} not supported') + + if checkpoint_type == 'lightning': + cls = hydra.utils.get_class(config.task._target_) + model = cls.load_from_checkpoint(checkpoint_path=config.eval.ckpt) + elif checkpoint_type == 'pytorch': + model_cfg = config.model_pretrained if 'model_pretrained' in config else None + trained_model: LightningModule = hydra.utils.instantiate(config.task, cfg=config, + model_cfg=model_cfg, + _recursive_=False) + if 'ckpt' in config.eval: + load_return = trained_model.model.load_state_dict( + load_checkpoint(config.eval.ckpt, device=trained_model.device), strict=False + ) + log.info(load_return) + if 'model_pretrained' in config: + ... 
+ else: + model = trained_model + + datamodule: LightningDataModule = hydra.utils.instantiate(config.datamodule) + # datamodule: LightningDataModule = model._datamodule + datamodule.prepare_data() + datamodule.setup() + + # print model hyperparameters + log.info(f'Model hyperparameters: {model.hparams}') + + # Init Lightning callbacks + callbacks: List[Callback] = [] + if "callbacks" in config: + for _, cb_conf in config["callbacks"].items(): + if cb_conf is not None and "_target_" in cb_conf: + log.info(f"Instantiating callback <{cb_conf._target_}>") + callbacks.append(hydra.utils.instantiate(cb_conf)) + + # Init Lightning loggers + logger: List[LightningLoggerBase] = [] + if "logger" in config: + for _, lg_conf in config["logger"].items(): + if lg_conf is not None and "_target_" in lg_conf: + log.info(f"Instantiating logger <{lg_conf._target_}>") + logger.append(hydra.utils.instantiate(lg_conf)) + + # Init Lightning trainer + log.info(f"Instantiating trainer <{config.trainer._target_}>") + trainer: Trainer = hydra.utils.instantiate( + config.trainer, callbacks=callbacks, logger=logger, _convert_="partial" + ) + + # Evaluate the model + log.info("Starting evaluation!") + if config.eval.get('run_val', True): + trainer.validate(model=model, datamodule=datamodule) + if config.eval.get('run_test', True): + trainer.test(model=model, datamodule=datamodule) + + # Make sure everything closed properly + log.info("Finalizing!") + utils.finish( + config=config, + model=model, + datamodule=datamodule, + trainer=trainer, + callbacks=callbacks, + logger=logger, + ) diff --git a/training/src/losses/cross_entropy_apex.py b/training/src/losses/cross_entropy_apex.py new file mode 100644 index 000000000..ef7094659 --- /dev/null +++ b/training/src/losses/cross_entropy_apex.py @@ -0,0 +1,51 @@ +import torch +import torch.nn as nn + +import xentropy_cuda_lib + + +# https://github.com/NVIDIA/apex/blob/master/apex/contrib/xentropy/softmax_xentropy.py +class SoftmaxCrossEntropyLossFn(torch.autograd.Function): + @staticmethod + def forward(ctx, logits, labels, smoothing=0.0, padding_idx=0, inplace_backward=False): + losses, max_log_sum_exp = xentropy_cuda_lib.forward( + logits, labels, smoothing) + losses.masked_fill_(labels==padding_idx, 0) + ctx.save_for_backward(logits, max_log_sum_exp, labels) + ctx.smoothing = smoothing + ctx.padding_idx = padding_idx + ctx.inplace_backward = inplace_backward + return losses + + @staticmethod + def backward(ctx, grad_loss): + logits, max_log_sum_exp, labels = ctx.saved_tensors + if not grad_loss.is_contiguous(): + grad_loss = grad_loss.contiguous() + grad_loss.masked_fill_(labels==ctx.padding_idx, 0) + grad_logits = xentropy_cuda_lib.backward(grad_loss, logits, max_log_sum_exp, labels, + ctx.smoothing, ctx.inplace_backward) + return grad_logits, None, None, None, None + + +class CrossEntropyLossApex(nn.Module): + + def __init__(self, ignore_index=-100, reduction='mean', label_smoothing=0.0, + inplace_backward=False): + super().__init__() + if reduction not in ['mean', 'none']: + raise NotImplementedError("Only support reduction = 'mean' or 'none'") + self.ignore_index = ignore_index + self.reduction = reduction + self.label_smoothing = label_smoothing + self.inplace_backward = inplace_backward + + def forward(self, input, target): + assert input.is_cuda and target.is_cuda + # SoftmaxCrossEntropyLoss implicitly casts to float + loss = SoftmaxCrossEntropyLossFn.apply(input, target, self.label_smoothing, + self.ignore_index, self.inplace_backward) + if self.reduction == 'mean': 
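+ # Losses at ignore_index positions were zeroed in forward, so we average over the
+ # number of non-ignored targets rather than over the whole batch.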
+ return loss.sum() / (target != self.ignore_index).sum() + else: + return loss diff --git a/training/src/losses/cross_entropy_parallel.py b/training/src/losses/cross_entropy_parallel.py new file mode 100644 index 000000000..84fe82dc6 --- /dev/null +++ b/training/src/losses/cross_entropy_parallel.py @@ -0,0 +1,112 @@ +# Inspired by https://github.com/NVIDIA/apex/blob/master/apex/transformer/tensor_parallel/cross_entropy.py +# But we make it much faster: we compute the local loss and the LSE, and by exchanging the LSE and +# the losses we can get the global loss. There's no need to do it step by step +# (compute local max, exchange, compute exp, compute local sum, exchange, etc.) +import torch +import torch.nn as nn + +import xentropy_cuda_lib + +from apex.transformer.parallel_state import get_tensor_model_parallel_group +from apex.transformer.parallel_state import get_tensor_model_parallel_rank +from apex.transformer.parallel_state import get_tensor_model_parallel_world_size +from apex.transformer.tensor_parallel.utils import VocabUtility + +# `all_gather_into_tensor` and `reduce_scatter_tensor` are new placeholders for +# `_all_gather_base` and `_reduce_scatter_base`. They require the most recent +# version of PyTorch. The following 4 lines are for backward comparability with +# older PyTorch. +if "all_gather_into_tensor" not in dir(torch.distributed): + torch.distributed.all_gather_into_tensor = torch.distributed._all_gather_base +if "reduce_scatter_tensor" not in dir(torch.distributed): + torch.distributed.reduce_scatter_tensor = torch.distributed._reduce_scatter_base + + +class SoftmaxCrossEntropyLossParallelFn(torch.autograd.Function): + + @staticmethod + def forward(ctx, logits_parallel, labels, smoothing=0.0, ignored_index=-100, + inplace_backward=False): + """ + logits_parallel: (batch, vocab_size / world_size) + labels: (batch,) + """ + assert smoothing == 0.0, 'smoothing != 0.0 is not yet implemented, file an issue if you need it' + batch, partition_vocab_size = logits_parallel.shape + assert labels.shape == (batch,) + rank = get_tensor_model_parallel_rank() + world_size = get_tensor_model_parallel_world_size() + vocab_start_index, vocab_end_index = VocabUtility.vocab_range_from_per_partition_vocab_size( + partition_vocab_size, get_tensor_model_parallel_rank(), + get_tensor_model_parallel_world_size() + ) + + # Create a mask of valid vocab ids (1 means it needs to be masked). + labels_mask = (labels < vocab_start_index) | (labels >= vocab_end_index) + ignored_mask = labels == ignored_index + labels_local = torch.where(ignored_mask, labels, labels - vocab_start_index) + masked_labels = labels_local.clone() + masked_labels[labels_mask] = ignored_index + + losses, lse_local = xentropy_cuda_lib.forward(logits_parallel, masked_labels, smoothing) + assert lse_local.shape == (batch,) + assert losses.shape == (batch,) + losses.masked_fill_(masked_labels==ignored_index, 0) + + if world_size > 1: + lse_allgather = torch.empty(world_size, batch, dtype=lse_local.dtype, + device=lse_local.device) + torch.distributed.all_gather_into_tensor(lse_allgather, lse_local.contiguous(), + group=get_tensor_model_parallel_group()) + lse = torch.logsumexp(lse_allgather, dim=0) + torch.distributed.all_reduce(losses, op=torch.distributed.ReduceOp.SUM, + group=get_tensor_model_parallel_group()) + # The losses are currently lse_local - predicted_logit, we just have to subtract the + # lse_local and add the lse (global). 
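+ # After the all_reduce, only the rank owning each label contributed a non-zero loss, so
+ # losses[i] == lse_owner[i] - logit[i]; adding (lse[i] - lse_owner[i]) below turns this
+ # into the global lse[i] - logit[i]. The owner rank is recovered from the label id.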
+ rank_per_sample = labels // partition_vocab_size + lse_local = lse_allgather[rank_per_sample, + torch.arange(batch, device=lse_allgather.device)] + losses += lse - lse_local + losses.masked_fill_(ignored_mask, 0) + else: + lse = lse_local + + ctx.save_for_backward(logits_parallel, lse, labels_local) + ctx.smoothing = smoothing + ctx.ignored_index = ignored_index + ctx.inplace_backward = inplace_backward + return losses + + @staticmethod + def backward(ctx, grad_loss): + logits_parallel, lse, labels = ctx.saved_tensors + if not grad_loss.is_contiguous(): + grad_loss = grad_loss.contiguous() + grad_loss.masked_fill_(labels==ctx.ignored_index, 0) + grad_logits = xentropy_cuda_lib.backward(grad_loss, logits_parallel, lse, labels, + ctx.smoothing, ctx.inplace_backward) + return grad_logits, None, None, None, None, None + + +class CrossEntropyLossParallel(nn.Module): + + def __init__(self, ignore_index=-100, reduction='mean', label_smoothing=0.0, + inplace_backward=False): + super().__init__() + if reduction not in ['mean', 'none']: + raise NotImplementedError("Only support reduction = 'mean' or 'none'") + self.ignore_index = ignore_index + self.reduction = reduction + self.label_smoothing = label_smoothing + self.inplace_backward = inplace_backward + + def forward(self, input, target): + assert input.is_cuda and target.is_cuda + # SoftmaxCrossEntropyLoss implicitly casts to float + loss = SoftmaxCrossEntropyLossParallelFn.apply( + input, target, self.label_smoothing, self.ignore_index, self.inplace_backward + ) + if self.reduction == 'mean': + return loss.sum() / (target != self.ignore_index).sum() + else: + return loss diff --git a/training/src/metrics/accuracy.py b/training/src/metrics/accuracy.py new file mode 100644 index 000000000..810b6513b --- /dev/null +++ b/training/src/metrics/accuracy.py @@ -0,0 +1,11 @@ +import torch +from torch import Tensor + +from torchmetrics import Metric, Accuracy + + +class AccuracyMine(Accuracy): + """Wrap torchmetrics.Accuracy to take argmax of y in case of Mixup. + """ + def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore + super().update(preds, target.argmax(dim=-1) if target.is_floating_point() else target) diff --git a/training/src/metrics/num_tokens.py b/training/src/metrics/num_tokens.py new file mode 100644 index 000000000..9e731c9db --- /dev/null +++ b/training/src/metrics/num_tokens.py @@ -0,0 +1,45 @@ +from typing import Any, Dict, Optional + +import torch +from torch import Tensor + +from torchmetrics import Metric + + +class NumTokens(Metric): + """Keep track of how many tokens we've seen. + """ + # TODO: how do we prevent the reset between the epochs? The reset happens on the 1st batch + # of the next epoch. + # Right now the hack is that we override reset(), which would mess up the forward method. + # We then override forward to do the right thing. 
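+ # Concretely: reset() below restores the running count after torchmetrics clears the state,
+ # and _forward_reduce_state_update() makes forward() return the cumulative count rather
+ # than a per-batch value.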
+ + is_differentiable = False + higher_is_better = False + full_state_update = False + count: Tensor + + def __init__(self, **kwargs: Dict[str, Any]): + super().__init__(**kwargs) + self.add_state("count", default=torch.tensor(0, dtype=torch.int64), dist_reduce_fx="sum", + persistent=True) # We want the count to be saved to state-dict + + def update(self, preds: Tensor, target: Tensor, loss: Optional[Tensor] = None) -> None: # type: ignore + self.count += target.numel() + + def compute(self) -> Tensor: + return self.count + + def reset(self): + count = self.count + super().reset() + self.count = count + + # Adapted from https://github.com/Lightning-AI/metrics/blob/master/src/torchmetrics/metric.py + def _forward_reduce_state_update(self, *args: Any, **kwargs: Any) -> Any: + """forward computation using single call to `update` to calculate the metric value on the current batch and + accumulate global state. + This can be done when the global metric state is a sinple reduction of batch states. + """ + self.update(*args, **kwargs) + return self.compute() diff --git a/training/src/metrics/perplexity.py b/training/src/metrics/perplexity.py new file mode 100644 index 000000000..9e79a4bac --- /dev/null +++ b/training/src/metrics/perplexity.py @@ -0,0 +1,70 @@ +# Inspired by https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/common/metrics/perplexity.py +# But we compute the perplexity correctly: exp(average(nll)), not average(exp(nll)) +# Also adapted from https://github.com/Lightning-AI/metrics/blob/master/src/torchmetrics/text/perplexity.py +# But we pass in the loss to avoid recomputation + +from typing import Any, Dict, Optional + +import torch +import torch.nn.functional as F +from torch import Tensor +from torchmetrics import Metric + +try: + from src.losses.cross_entropy_apex import CrossEntropyLossApex as CrossEntropyLoss +except ImportError: + CrossEntropyLoss = torch.nn.CrossEntropyLoss + +__all__ = ['Perplexity'] + + +class Perplexity(Metric): + r""" + Perplexity measures how well a language model predicts a text sample. It's calculated as the average number of bits + per word a model needs to represent the sample. + Args: + kwargs: + Additional keyword arguments, see :ref:`Metric kwargs` for more info. + Examples: + >>> import torch + >>> preds = torch.rand(2, 8, 5, generator=torch.manual_seed(22)) + >>> target = torch.randint(5, (2, 8), generator=torch.manual_seed(22)) + >>> target[0, 6:] = -100 + >>> metric = Perplexity(ignore_index=-100) + >>> metric(preds, target) + tensor(5.2545) + """ + is_differentiable = True + higher_is_better = False + full_state_update = False + total_log_probs: Tensor + count: Tensor + + def __init__(self, **kwargs: Dict[str, Any]): + super().__init__(**kwargs) + self.add_state("total_log_probs", default=torch.tensor(0.0, dtype=torch.float64), + dist_reduce_fx="sum") + self.add_state("count", default=torch.tensor(0, dtype=torch.int64), dist_reduce_fx="sum") + + self.loss_fn = CrossEntropyLoss() + + def update(self, preds: Tensor, target: Tensor, loss: Optional[Tensor] = None) -> None: # type: ignore + """Compute and store intermediate statistics for Perplexity. + Args: + preds: + Probabilities assigned to each token in a sequence with shape [batch_size, seq_len, vocab_size]. + target: + Ground truth values with a shape [batch_size, seq_len]. 
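+ loss:
+ Optional precomputed cross-entropy loss averaged over tokens; if given, it is used
+ directly instead of recomputing the loss from ``preds``.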
+ """ + count = target.numel() + if loss is None: + loss = self.loss_fn(preds, target) + self.total_log_probs += loss.double() * count + self.count += count + + def compute(self) -> Tensor: + """Compute the Perplexity. + Returns: + Perplexity + """ + return torch.exp(self.total_log_probs / self.count) diff --git a/training/src/models/modules/seq_common.py b/training/src/models/modules/seq_common.py new file mode 100644 index 000000000..4d0469d80 --- /dev/null +++ b/training/src/models/modules/seq_common.py @@ -0,0 +1,342 @@ +import math +from functools import partial +from collections import namedtuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn.modules.utils import _pair + +import hydra + +from einops import reduce, rearrange + + +def pooling(x, pooling_mode='CLS', key_padding_mask=None, batch_first=True): + if pooling_mode not in ['MEAN', 'SUM', 'CLS', 'LAST', 'FLATTEN']: + raise NotImplementedError(f'pooling_mode must be MEAN, SUM, CLS, LAST, FLATTEN') + if pooling_mode in ['MEAN', 'SUM']: + if key_padding_mask is not None: + mask = rearrange(~key_padding_mask.bool_matrix, + 'b s -> b s 1' if batch_first else 'b s -> s b 1') + x = x.masked_fill(mask, 0) + s = reduce(x, 'b s ... -> b ...' if batch_first else 's b ... -> b ...', 'sum') + if pooling_mode == 'SUM': + return s + else: + if key_padding_mask is None: + return s / x.shape[1 if batch_first else 0] + else: + lengths = rearrange(key_padding_mask._lengths, 'b -> b 1') + return s / lengths + elif pooling_mode == 'CLS': + return x[:, 0] if batch_first else x[0] + elif pooling_mode == 'LAST': + if key_padding_mask is None: + return x[:, -1] if batch_first else x[-1] + else: + lengths = key_padding_mask._lengths + if batch_first: + batch_size = x.shape[0] + return x[torch.arange(batch_size, device=x.device), lengths - 1] + else: + batch_size = x.shape[1] + return x[lengths - 1, torch.arange(batch_size, device=x.device)] + elif pooling_mode == 'FLATTEN': + return rearrange(x, 'b ... -> b (...)' if batch_first else 's b ... 
-> b (s ...)') + + +class ClassificationHeadLinear(nn.Module): + """Head for sentence-level classification tasks.""" + + def __init__(self, d_model, num_classes, pooling_mode='MEAN', + batch_first=False, **kwargs): + super().__init__() + assert pooling_mode in ['MEAN', 'SUM', 'CLS', 'LAST', 'FLATTEN'], 'pooling_mode not supported' + self.pooling_mode = pooling_mode + self.batch_first = batch_first + self.out_proj = nn.Linear(d_model, num_classes) + + def forward(self, hidden_states, key_padding_mask=None, **kwargs): + """ + hidden_states: (B, S, D) if batch_first else (S, B, D) + """ + hidden_states = pooling(hidden_states, pooling_mode=self.pooling_mode, + key_padding_mask=key_padding_mask, batch_first=self.batch_first) + hidden_states = self.out_proj(hidden_states) + return hidden_states + + +# Adapted from https://github.com/huggingface/transformers/blob/master/src/transformers/models/reformer/modeling_reformer.py +class ClassificationHead(nn.Module): + """Head for sentence-level classification tasks.""" + + def __init__(self, d_model, d_inner, num_classes, dropout=0.0, pooling_mode='MEAN', + batch_first=False): + super().__init__() + assert pooling_mode in ['MEAN', 'SUM', 'CLS', 'LAST', 'FLATTEN'], 'pooling_mode not supported' + self.pooling_mode = pooling_mode + self.batch_first = batch_first + self.dense = nn.Linear(d_model, d_inner) + self.dropout = nn.Dropout(dropout) + self.out_proj = nn.Linear(d_inner, num_classes) + + def forward(self, hidden_states, key_padding_mask=None, **kwargs): + """ + hidden_states: (B, S, D) if batch_first else (S, B, D) + """ + hidden_states = pooling(hidden_states, pooling_mode=self.pooling_mode, + key_padding_mask=key_padding_mask, batch_first=self.batch_first) + hidden_states = self.dropout(hidden_states) + hidden_states = self.dense(hidden_states) + # Huggingface uses tanh instead of relu + hidden_states = torch.relu(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.out_proj(hidden_states) + return hidden_states + + +class ClassificationHeadDual(nn.Module): + """Head for sentence-level classification tasks.""" + + def __init__(self, d_model, d_inner, num_classes, dropout=0.0, pooling_mode='MEAN', + batch_first=False, interaction='NLI'): + super().__init__() + assert pooling_mode in ['MEAN', 'SUM', 'CLS'], 'pooling_mode not supported' + assert interaction in [None, 'NLI'], 'interaction not supported' + self.pooling_mode = pooling_mode + self.batch_first = batch_first + self.interaction = interaction + self.dense = nn.Linear(d_model * (4 if self.interaction == 'NLI' else 2), d_inner) + self.dropout = nn.Dropout(dropout) + self.out_proj = nn.Linear(d_inner, num_classes) + + def forward(self, hidden_states1, hidden_states2, + key_padding_mask1=None, key_padding_mask2=None, **kwargs): + """ + hidden_states: (B, S, D) if batch_first else (S, B, D) + """ + x1 = pooling(hidden_states1, pooling_mode=self.pooling_mode, + key_padding_mask=key_padding_mask1, batch_first=self.batch_first) + x2 = pooling(hidden_states2, pooling_mode=self.pooling_mode, + key_padding_mask=key_padding_mask2, batch_first=self.batch_first) + hidden_states = (torch.cat([x1, x2, x1 * x2, x1 - x2], dim=-1) if self.interaction == 'NLI' + else torch.cat([x1, x2], dim=-1)) + hidden_states = self.dropout(hidden_states) + hidden_states = self.dense(hidden_states) + # Huggingface uses tanh instead of relu + hidden_states = torch.relu(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.out_proj(hidden_states) + return 
hidden_states + + +class LMHead(nn.Module): + + def __init__(self, d_model, num_classes, batch_first=True, bias=True): + super().__init__() + self.lm_head = nn.Linear(d_model, num_classes, bias=bias) + + def forward(self, hidden_states, **kwargs): + """ + hidden_states: (B, S, D) if batch_first else (S, B, D) + """ + CausalLMOutput = namedtuple('CausalLMOutput', ['logits']) + return CausalLMOutput(self.lm_head(hidden_states)) + + +def sinusoidal_init_(tensor): + """ + tensor: (max_len, d_model) + """ + max_len, d_model = tensor.shape + position = rearrange(torch.arange(0.0, max_len), 's -> s 1') + div_term = torch.exp(-math.log(10000.0) * torch.arange(0.0, d_model, 2.0) / d_model) + tensor[:, 0::2] = torch.sin(position * div_term) + tensor[:, 1::2] = torch.cos(position * div_term) + return tensor + + +# Adapted from https://github.com/pytorch/examples/blob/master/word_language_model/model.py +class PositionalEncoding(nn.Module): + r"""Inject some information about the relative or absolute position of the tokens + in the sequence. The positional encodings have the same dimension as + the embeddings, so that the two can be summed. Here, we use sine and cosine + functions of different frequencies. + .. math:: + \text{PosEncoder}(pos, 2i) = sin(pos/10000^(2i/d_model)) + \text{PosEncoder}(pos, 2i+1) = cos(pos/10000^(2i/d_model)) + \text{where pos is the word position and i is the embed idx) + Args: + d_model: the embed dim (required). + dropout: the dropout value (default=0.1). + max_len: the max. length of the incoming sequence (default=5000). + Examples: + >>> pos_encoder = PositionalEncoding(d_model) + """ + + def __init__(self, d_model, dropout=0.1, max_len=5000, batch_first=False, initializer=None): + super().__init__() + self.batch_first = batch_first + self.dropout = nn.Dropout(p=dropout) + pe = torch.empty(max_len, d_model) + if initializer is None: + sinusoidal_init_(pe) + pe = rearrange(pe, 's d -> 1 s d' if self.batch_first else 's d -> s 1 d') + self.register_buffer('pe', pe) + else: + hydra.utils.call(initializer, pe) + pe = rearrange(pe, 's d -> 1 s d' if self.batch_first else 's d -> s 1 d') + self.pe = nn.Parameter(pe) + + def forward(self, x): + r"""Inputs of forward function + Args: + x: the sequence fed to the positional encoder model (required). + Shape: + x: [sequence length, batch size, embed dim] if not batch_first else [B, S, D] + output: [sequence length, batch size, embed dim] if not batch_first else [B, S, D] + Examples: + >>> output = pos_encoder(x) + """ + x = x + (self.pe[:, :x.size(1)] if self.batch_first else self.pe[:x.size(0)]) + return self.dropout(x) + + +# Adapted from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/mlp.py +class Mlp(nn.Module): + """ MLP as used in Vision Transformer, MLP-Mixer and related networks + """ + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, + act_fn=None, drop=0., device=None, dtype=None): + """TD [2021-10-27] act_fn takes precedence over act_layer if set. + This is to support Pytorch 1.10 Transformer interface that construct the activation + *function*, not the activation *layer*. 
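+ ``drop`` may be a single probability or a pair; the first is applied after the activation
+ and the second after ``fc2`` (see the ``_pair(drop)`` call below).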
+ """ + factory_kwargs = {'device': device, 'dtype': dtype} + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + drop_probs = _pair(drop) + self.fc1 = nn.Linear(in_features, hidden_features, **factory_kwargs) + self.act = act_layer() if act_fn is None else act_fn + self.drop1 = nn.Dropout(drop_probs[0]) + self.fc2 = nn.Linear(hidden_features, out_features, **factory_kwargs) + self.drop2 = nn.Dropout(drop_probs[1]) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop1(x) + x = self.fc2(x) + x = self.drop2(x) + return x + + +class MlpBig(nn.Module): + """ MLP as used in Vision Transformer, MLP-Mixer and related networks + """ + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, + act_fn=None, drop=0., device=None, dtype=None): + """Copied from Mlp above. If num_layers > 2, add more Mlp layers, doubling each time. + """ + factory_kwargs = {'device': device, 'dtype': dtype} + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + cur_hidden_features = hidden_features + layers = [] + for _ in range(4): + layers.append(nn.Linear(in_features, cur_hidden_features, **factory_kwargs)) + layers.append(act_layer()) + layers.append(nn.Dropout(drop)) + in_features = cur_hidden_features + cur_hidden_features *= 2 + layers.append(nn.Linear(in_features, out_features, **factory_kwargs)) + layers.append(nn.Dropout(drop)) + self.fwd = nn.Sequential(*layers) + + def forward(self, x): + return self.fwd(x) + +class GluMlp(nn.Module): + """ MLP w/ GLU style gating + See: https://arxiv.org/abs/1612.08083, https://arxiv.org/abs/2002.05202 + """ + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.Sigmoid, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + assert hidden_features % 2 == 0 + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features // 2, out_features) + self.drop = nn.Dropout(drop) + + def init_weights(self): + # override init of fc1 w/ gate portion set to weight near zero, bias=1 + fc1_mid = self.fc1.bias.shape[0] // 2 + nn.init.ones_(self.fc1.bias[fc1_mid:]) + nn.init.normal_(self.fc1.weight[fc1_mid:], std=1e-6) + + def forward(self, x): + x = self.fc1(x) + x, gates = x.chunk(2, dim=-1) + x = x * self.act(gates) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class GatedMlp(nn.Module): + """ MLP as used in gMLP + """ + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, + gate_layer=None, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + if gate_layer is not None: + assert hidden_features % 2 == 0 + self.gate = gate_layer(hidden_features) + hidden_features = hidden_features // 2 # FIXME base reduction on gate property? 
+ else: + self.gate = nn.Identity() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.gate(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class ConvMlp(nn.Module): + """ MLP using 1x1 convs that keeps spatial dims + """ + def __init__( + self, in_features, hidden_features=None, out_features=None, act_layer=nn.ReLU, norm_layer=None, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Conv2d(in_features, hidden_features, kernel_size=1, bias=True) + self.norm = norm_layer(hidden_features) if norm_layer else nn.Identity() + self.act = act_layer() + self.fc2 = nn.Conv2d(hidden_features, out_features, kernel_size=1, bias=True) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.norm(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + return x + diff --git a/training/src/optim/param_grouping.py b/training/src/optim/param_grouping.py new file mode 100644 index 000000000..31f06f24c --- /dev/null +++ b/training/src/optim/param_grouping.py @@ -0,0 +1,114 @@ +import inspect + +import torch.nn as nn + +import hydra + +try: + from apex.contrib.layer_norm import FastLayerNorm +except ImportError: + FastLayerNorm = None + +from src.models.modules.seq_common import PositionalEncoding + + +def group_parameters_for_optimizer(model, optimizer_cfg, bias_weight_decay=False, + normalization_weight_decay=False): + """Set weight_decay=0.0 for parameters in model.no_weight_decay, for parameters with + attribute _no_weight_decay==True, for bias parameters if bias_weight_decay==False, for + normalization parameters if normalization_weight_decay==False + """ + # Get the weight decay from the config, or from the default value of the optimizer constructor + # if it's not specified in the config. + if 'weight_decay' in optimizer_cfg: + weight_decay = optimizer_cfg.weight_decay + else: + # https://stackoverflow.com/questions/12627118/get-a-function-arguments-default-value + signature = inspect.signature(hydra.utils.get_class(optimizer_cfg._target_)) + if 'weight_decay' in signature.parameters: + weight_decay = signature.parameters['weight_decay'].default + if weight_decay is inspect.Parameter.empty: + weight_decay = 0.0 + else: + weight_decay = 0.0 + + # If none of the parameters have weight decay anyway, and there are no parameters with special + # optimization params + if weight_decay == 0.0 and not any(hasattr(p, '_optim') for p in model.parameters()): + return model.parameters() + + skip = model.no_weight_decay() if hasattr(model, 'no_weight_decay') else set() + skip_keywords = (model.no_weight_decay_keywords() if hasattr(model, 'no_weight_decay_keywords') + else set()) + + # Adapted from https://github.com/karpathy/minGPT/blob/master/mingpt/model.py#L134 + """ + This long function is unfortunately doing something very simple and is being very defensive: + We are separating out all parameters of the model into two buckets: those that will experience + weight decay for regularization and those that won't (biases, and layernorm/embedding weights). + We are then returning the PyTorch optimizer object. 
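+ (Unlike minGPT, this function returns the parameter groups rather than a constructed
+ optimizer; the caller instantiates the optimizer from these groups.)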
+ """ + + # separate out all parameters to those that will and won't experience regularizing weight decay + decay = set() + no_decay = set() + special = set() + whitelist_weight_modules = (nn.Linear, ) + blacklist_weight_modules = (nn.Embedding, PositionalEncoding) + if not normalization_weight_decay: + blacklist_weight_modules += (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, + nn.LazyBatchNorm1d, nn.LazyBatchNorm2d, nn.LazyBatchNorm3d, + nn.GroupNorm, nn.SyncBatchNorm, + nn.InstanceNorm1d, nn.InstanceNorm2d, nn.InstanceNorm3d, + nn.LayerNorm, nn.LocalResponseNorm) + if FastLayerNorm is not None: + blacklist_weight_modules += (FastLayerNorm,) + + param_dict = {pn: p for pn, p in model.named_parameters() if p.requires_grad} + for mn, m in model.named_modules(): + for pn, p in m.named_parameters(): + fpn = '%s.%s' % (mn, pn) if mn else pn # full param name + # In case of parameter sharing, some parameters show up here but are not in + # param_dict.keys() + if not p.requires_grad or fpn not in param_dict: + continue # frozen weights + if hasattr(p, '_optim'): + special.add(fpn) + elif fpn in skip or any(skip_keyword in fpn for skip_keyword in skip_keywords): + no_decay.add(fpn) + elif getattr(p, '_no_weight_decay', False): + no_decay.add(fpn) + elif not bias_weight_decay and pn.endswith('bias'): + no_decay.add(fpn) + elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules): + # weights of whitelist modules will be weight decayed + decay.add(fpn) + elif isinstance(m, blacklist_weight_modules): + # weights of blacklist modules will NOT be weight decayed + no_decay.add(fpn) + + decay |= (param_dict.keys() - no_decay - special) + # validate that we considered every parameter + inter_params = decay & no_decay + union_params = decay | no_decay + assert len(inter_params) == 0, f"Parameters {str(inter_params)} made it into both decay/no_decay sets!" + assert len(param_dict.keys() - special - union_params) == 0, f"parameters {str(param_dict.keys() - union_params)} were not separated into either decay/no_decay set!" + + if weight_decay == 0.0 or not no_decay: + param_groups = [{"params": [param_dict[pn] for pn in sorted(list(no_decay | decay))], + "weight_decay": weight_decay}] + else: + # We need sorted(list()) so that the order is deterministic. Otherwise when we resume + # the order could change and resume will fail. [H/t Albert] + param_groups = [ + {"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": weight_decay}, + {"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0}, + ] + # Add parameters with special hyperparameters + # Unique dicts + hps = [dict(s) for s in set(frozenset(param_dict[pn]._optim.items()) for pn in special)] + for hp in hps: + params = [param_dict[pn] for pn in sorted(list(special)) if param_dict[pn]._optim == hp] + param_groups.append({"params": params, **hp}) + + return param_groups diff --git a/training/src/optim/timm_lr_scheduler.py b/training/src/optim/timm_lr_scheduler.py new file mode 100644 index 000000000..cfba73cbe --- /dev/null +++ b/training/src/optim/timm_lr_scheduler.py @@ -0,0 +1,30 @@ +import torch +from torch.optim import Optimizer + +from timm.scheduler import CosineLRScheduler + + +# We need to subclass torch.optim.lr_scheduler._LRScheduler, or Pytorch-lightning will complain +class TimmCosineLRScheduler(CosineLRScheduler, torch.optim.lr_scheduler._LRScheduler): + """ Wrap timm.scheduler.CosineLRScheduler so we can call scheduler.step() without passing in epoch. 
+ It supports resuming as well. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._last_epoch = -1 + self.step(epoch=0) + + def step(self, epoch=None): + if epoch is None: + self._last_epoch += 1 + else: + self._last_epoch = epoch + # We call either step or step_update, depending on whether we're using the scheduler every + # epoch or every step. + # Otherwise, lightning will always call step (i.e., meant for each epoch), and if we set + # scheduler interval to "step", then the learning rate update will be wrong. + if self.t_in_epochs: + super().step(epoch=self._last_epoch) + else: + super().step_update(num_updates=self._last_epoch) diff --git a/training/src/tasks/seq.py b/training/src/tasks/seq.py new file mode 100644 index 000000000..c44c968ba --- /dev/null +++ b/training/src/tasks/seq.py @@ -0,0 +1,192 @@ +from typing import Any, List +import inspect + +import torch +import hydra +from pytorch_lightning import LightningModule, LightningDataModule +from torchmetrics import MetricCollection + +from einops import rearrange + +from omegaconf import OmegaConf + +from src.utils.utils import get_logger +from src.optim.param_grouping import group_parameters_for_optimizer +from src.utils.checkpoint import load_checkpoint + +logger = get_logger(__name__) + + +class SequenceModel(LightningModule): + + def __init__(self, cfg, model_cfg=None): + """If model_cfg is passed, it will take precedence over cfg.model + """ + super().__init__() + # this line ensures params passed to LightningModule will be saved to ckpt + # it also allows to access params with 'self.hparams' attribute + self.save_hyperparameters(cfg) + self.cfg = cfg + self.model_cfg = model_cfg or self.cfg.model + + self.instantiate_datamodule() + self.instantiate_model() + self.warmstart() + self.instantiate_loss() + self.instantiate_metrics() + + def instantiate_datamodule(self): + logger.info(f"Instantiating datamodule <{self.cfg.datamodule._target_}>") + # Calling this self.datamodule will mess with PL since it also assigns self.datamodule + self._datamodule: LightningDataModule = hydra.utils.instantiate(self.cfg.datamodule) + self._datamodule.prepare_data() + self._datamodule.setup() + OmegaConf.clear_resolver('datamodule') + OmegaConf.register_new_resolver('datamodule', lambda attr: getattr(self._datamodule, attr)) + + def instantiate_model(self): + # if hasattr(self._datamodule, 'num_classes'): + # self.model_cfg.num_classes = self._datamodule.num_classes + # if (hasattr(self._datamodule, 'vocab_size') + # and self.model_cfg.get('embedding_cfg', None) is not None + # and self.model_cfg.embedding_cfg._target_ == "torch.nn.Embedding"): + # self.model_cfg.embedding_cfg.num_embeddings = self._datamodule.vocab_size + logger.info(f"Instantiating model <{self.model_cfg._target_}>") + recursive = getattr(self.model_cfg, '_recursive_', False) + self.model = hydra.utils.instantiate(self.model_cfg, _recursive_=recursive) + + def instantiate_loss(self): + loss_fn_cfg = self.cfg.train.get('loss_fn') + if loss_fn_cfg is None: + loss_fn_cfg = {'_target_': 'torch.nn.CrossEntropyLoss'} + self.loss_fn = hydra.utils.instantiate(loss_fn_cfg) + loss_fn_val_cfg = self.cfg.train.get('loss_fn_val', loss_fn_cfg) + self.loss_fn_val = hydra.utils.instantiate(loss_fn_val_cfg) + + def instantiate_metrics(self): + # use separate metric instance for train, val and test step + # to ensure a proper reduction over the epoch + if 'eval' in self.cfg and 'metrics' in self.cfg.eval: + metrics_cfg = self.cfg.eval.metrics + else: + 
metrics_cfg = {'acc': {'_target_': 'torchmetrics.Accuracy'}} + metrics = MetricCollection({name: hydra.utils.instantiate(cfg) + for name, cfg in metrics_cfg.items()}) + self.train_metrics = metrics.clone(prefix='train/') + self.val_metrics = metrics.clone(prefix='val/') + self.test_metrics = metrics.clone(prefix='test/') + + def warmstart(self): + if self.cfg.train.get('warmstart', None) is not None: + logger.info(f"Warm-starting with weights from {self.cfg.train.warmstart.path}") + strict = self.cfg.train.warmstart.get('strict', True) + state_dict = load_checkpoint(self.cfg.train.warmstart.path) + if self.cfg.train.warmstart.get('post_process', None) is not None: + state_dict = hydra.utils.instantiate(self.cfg.train.warmstart.post_process, + state_dict) + load_return = self.model.load_state_dict(state_dict, strict=False) + logger.info(load_return) + + def forward(self, *args, **kwargs): + return self.model(*args, **kwargs) + + def step(self, batch: Any, is_train=True): + try: + x, y, lengths = batch + except ValueError: + x, y = batch + lengths = None + output = self.forward(x) if lengths is None else self.forward(x, lengths=lengths) + loss = self.loss_fn(output, y) if is_train else self.loss_fn_val(output, y) + return loss, output, y + + def shared_step(self, batch: Any, batch_idx: int, phase='train'): + loss, output, targets = self.step(batch, is_train=(phase == 'train')) + metrics = getattr(self, f'{phase}_metrics') + metrics(output, targets) + log_on_step = 'eval' in self.cfg and self.cfg.eval.get('log_on_step', False) and phase == 'train' + self.log(f"{phase}/loss", loss, on_step=log_on_step, on_epoch=True, + prog_bar=False, sync_dist=True) + # https://pytorch-lightning.readthedocs.io/en/stable/visualize/logging_advanced.html#enable-metrics-for-distributed-training + # We need to log the Metrics object, not the metric result, since otherwise + # pytorch-lightning will use torch.mean to reduce it. + # This would be wrong for perplexity, for example. 
+ self.log_dict(metrics, on_step=log_on_step, on_epoch=True, prog_bar=True, sync_dist=True) + return {"loss": loss, "output": output, "targets": targets} + + def training_step(self, batch: Any, batch_idx: int): + return self.shared_step(batch, batch_idx, phase='train') + + def validation_step(self, batch: Any, batch_idx: int): + return self.shared_step(batch, batch_idx, phase='val') + + def test_step(self, batch: Any, batch_idx: int): + return self.shared_step(batch, batch_idx, phase='test') + + def configure_optimizers(self): + if 'optimizer_param_grouping' in self.cfg.train: # Set zero weight decay for some params + parameters = group_parameters_for_optimizer(self.model, self.cfg.train.optimizer, + **self.cfg.train.optimizer_param_grouping) + else: + # parameters = self.model.parameters() + parameters = self.parameters() # [21-09-08] AG: this will train task specific parameters such as Retrieval head for AAN + optimizer = hydra.utils.instantiate(self.cfg.train.optimizer, parameters) + + # Log optimizer info + for i, g in enumerate(optimizer.param_groups): + ntensors = len(g['params']) + nparams = sum(p.numel() for p in g['params']) + hparams = {k: v for k, v in g.items() if k != 'params'} + logger.info(f'Optimizer group {i}: {ntensors} tensors, {nparams} parameters, {hparams}') + + if 'scheduler' not in self.cfg.train: + return optimizer + else: + # lr_scheduler should be called either every step (default) or every epoch + lr_scheduler = hydra.utils.instantiate(self.cfg.train.scheduler, optimizer) + return [optimizer], {'scheduler': lr_scheduler, + 'interval': self.cfg.train.get('scheduler_interval', 'step'), + 'monitor': self.cfg.train.get('scheduler_monitor', 'val/loss')} + + def optimizer_zero_grad(self, epoch, batch_idx, optimizer, optimizer_idx): + # https://pytorch-lightning.readthedocs.io/en/latest/guides/speed.html#set-grads-to-none + # TD [2022-04-30]: DeepSpeed optimizer uses the kwarg set_grad_to_none instead of set_to_none + if 'set_to_none' in inspect.signature(optimizer.zero_grad).parameters: + optimizer.zero_grad(set_to_none=True) + else: + optimizer.zero_grad() + + def on_save_checkpoint(self, checkpoint): + # TD [2022-08-07] ['epoch_loop.batch_progress']['total']['completed'] is 1 iteration + # behind, so we're using the optimizer's progress. + checkpoint['loops']['fit_loop']['epoch_loop.batch_progress']['total']['completed'] = checkpoint['loops']['fit_loop']['epoch_loop.batch_loop.optimizer_loop.optim_progress']['optimizer']['step']['total']['completed'] * self.trainer.accumulate_grad_batches + checkpoint['loops']['fit_loop']['epoch_loop.batch_progress']['current']['completed'] = checkpoint['loops']['fit_loop']['epoch_loop.batch_loop.optimizer_loop.optim_progress']['optimizer']['step']['current']['completed'] * self.trainer.accumulate_grad_batches + # _batches_that_stepped tracks the number of global steps, not the number + # of local steps, so we don't multiply with self.trainer.accumulate_grad_batches here. + checkpoint['loops']['fit_loop']['epoch_loop.state_dict']['_batches_that_stepped'] = checkpoint['loops']['fit_loop']['epoch_loop.batch_loop.optimizer_loop.optim_progress']['optimizer']['step']['total']['completed'] + + +class SequenceLMModel(SequenceModel): + + def step(self, batch: Any, is_train=True): + x, y = batch + output = self.forward(x).logits + output = rearrange(output, '... C -> (...) C') + y = rearrange(y, '... 
-> (...)') + loss = self.loss_fn(output, y) if is_train else self.loss_fn_val(output, y) + return loss, output, y + + def shared_step(self, batch: Any, batch_idx: int, phase='train'): + loss, output, targets = self.step(batch, is_train=(phase == 'train')) + # Passing the loss to the perplexity metrics to avoid recomputation + metrics = getattr(self, f'{phase}_metrics') + metrics(output, targets, loss=loss) + log_on_step = 'eval' in self.cfg and self.cfg.eval.get('log_on_step', False) and phase == 'train' + self.log(f"{phase}/loss", loss, on_step=log_on_step, on_epoch=True, + prog_bar=False, sync_dist=True) + # https://pytorch-lightning.readthedocs.io/en/stable/visualize/logging_advanced.html#enable-metrics-for-distributed-training + # We need to log the Metrics object, not the metric result, since otherwise + # pytorch-lightning will use torch.mean to reduce it. + # This would be wrong for perplexity, for example. + self.log_dict(metrics, on_step=log_on_step, on_epoch=True, prog_bar=True, sync_dist=True) + return {"loss": loss, "output": output, "targets": targets} diff --git a/training/src/train.py b/training/src/train.py new file mode 100644 index 000000000..8c92413e4 --- /dev/null +++ b/training/src/train.py @@ -0,0 +1,136 @@ +from typing import List, Optional, Sequence +from pathlib import Path + +import hydra +from omegaconf import OmegaConf, DictConfig +from pytorch_lightning import ( + Callback, + LightningDataModule, + LightningModule, + Trainer, + seed_everything, +) +from pytorch_lightning.loggers import LightningLoggerBase + +from src.utils import utils + +log = utils.get_logger(__name__) + + +def last_modification_time(path): + """Including files / directory 1-level below the path + """ + path = Path(path) + if path.is_file(): + return path.stat().st_mtime + elif path.is_dir(): + return max(child.stat().st_mtime for child in path.iterdir()) + else: + return None + + +def train(config: DictConfig) -> Optional[float]: + """Contains training pipeline. + Instantiates all PyTorch Lightning objects from config. + + Args: + config (DictConfig): Configuration composed by Hydra. + + Returns: + Optional[float]: Metric score for hyperparameter optimization. 
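+            The metric is only returned when `optimized_metric` is set in the config;
+            otherwise the function returns None.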
+ """ + + # Set seed for random number generators in pytorch, numpy and python.random + if config.get("seed"): + seed_everything(config.seed, workers=True) + + # We want to add fields to config so need to call OmegaConf.set_struct + OmegaConf.set_struct(config, False) + # Init lightning model + model: LightningModule = hydra.utils.instantiate(config.task, cfg=config, _recursive_=False) + datamodule: LightningDataModule = model._datamodule + + # Init lightning callbacks + callbacks: List[Callback] = [] + if "callbacks" in config: + for _, cb_conf in config.callbacks.items(): + if cb_conf is not None and "_target_" in cb_conf: + log.info(f"Instantiating callback <{cb_conf._target_}>") + callbacks.append(hydra.utils.instantiate(cb_conf)) + + # Init lightning loggers + logger: List[LightningLoggerBase] = [] + if "logger" in config: + for _, lg_conf in config.logger.items(): + if lg_conf is not None and "_target_" in lg_conf: + log.info(f"Instantiating logger <{lg_conf._target_}>") + logger.append(hydra.utils.instantiate(lg_conf)) + + ckpt_cfg = {} + if config.get('resume'): + try: + checkpoint_path = Path(config.callbacks.model_checkpoint.dirpath) + if checkpoint_path.is_dir(): + last_ckpt = checkpoint_path / 'last.ckpt' + autosave_ckpt = checkpoint_path / '.pl_auto_save.ckpt' + if not (last_ckpt.exists() or autosave_ckpt.exists()): + raise FileNotFoundError("Resume requires either last.ckpt or .pl_autosave.ckpt") + if ((not last_ckpt.exists()) + or (autosave_ckpt.exists() + and last_modification_time(autosave_ckpt) > last_modification_time(last_ckpt))): + # autosave_ckpt = autosave_ckpt.replace(autosave_ckpt.with_name('.pl_auto_save_loaded.ckpt')) + checkpoint_path = autosave_ckpt + else: + checkpoint_path = last_ckpt + # DeepSpeed's checkpoint is a directory, not a file + if checkpoint_path.is_file() or checkpoint_path.is_dir(): + ckpt_cfg = {'ckpt_path': str(checkpoint_path)} + else: + log.info(f'Checkpoint file {str(checkpoint_path)} not found. 
Will start training from scratch') + except (KeyError, FileNotFoundError): + pass + + # Configure ddp automatically + n_devices = config.trainer.get('devices', 1) + if isinstance(n_devices, Sequence): # trainer.devices could be [1, 3] for example + n_devices = len(n_devices) + if n_devices > 1 and config.trainer.get('strategy', None) is None: + config.trainer.strategy = dict( + _target_='pytorch_lightning.strategies.DDPStrategy', + find_unused_parameters=False, + gradient_as_bucket_view=True, # https://pytorch-lightning.readthedocs.io/en/stable/advanced/advanced_gpu.html#ddp-optimizations + ) + + # Init lightning trainer + log.info(f"Instantiating trainer <{config.trainer._target_}>") + trainer: Trainer = hydra.utils.instantiate( + config.trainer, callbacks=callbacks, logger=logger) + + # Train the model + log.info("Starting training!") + trainer.fit(model=model, datamodule=datamodule, **ckpt_cfg) + + # Evaluate model on test set, using the best model achieved during training + if config.get("test_after_training") and not config.trainer.get("fast_dev_run"): + log.info("Starting testing!") + trainer.test(model=model, datamodule=datamodule) + + # Make sure everything closed properly + log.info("Finalizing!") + utils.finish( + config=config, + model=model, + datamodule=datamodule, + trainer=trainer, + callbacks=callbacks, + logger=logger, + ) + + # Print path to best checkpoint + if not config.trainer.get("fast_dev_run"): + log.info(f"Best model ckpt: {trainer.checkpoint_callback.best_model_path}") + + # Return metric score for hyperparameter optimization + optimized_metric = config.get("optimized_metric") + if optimized_metric: + return trainer.callback_metrics[optimized_metric] diff --git a/training/src/utils/checkpoint.py b/training/src/utils/checkpoint.py new file mode 100644 index 000000000..64e3db63c --- /dev/null +++ b/training/src/utils/checkpoint.py @@ -0,0 +1,76 @@ +import re +from pathlib import Path + +import torch +import math +from einops import rearrange + +def load_checkpoint(path, device='cpu'): + path = Path(path).expanduser() + is_deepspeed = False + if path.is_dir(): # DeepSpeed checkpoint + is_deepspeed = True + latest_path = path / 'latest' + if latest_path.is_file(): + with open(latest_path, 'r') as fd: + tag = fd.read().strip() + else: + raise ValueError(f"Unable to find 'latest' file at {latest_path}") + path /= f'{tag}/mp_rank_00_model_states.pt' + state_dict = torch.load(path, map_location=device) + if is_deepspeed: + state_dict = state_dict['module'] + + # Replace the names of some of the submodules + def key_mapping(key): + return re.sub(r'^module.model.', '', key) + + state_dict = {key_mapping(k): v for k, v in state_dict.items()} + return state_dict + + +def blockdiag_to_dense_mlp_bert(state_dict): + from src.ops.blockdiag_multiply import blockdiag_weight_to_dense_weight + names = {name for name in state_dict + if re.match('bert.encoder.layer.(\d+).(mlp.fc(1|2)|(intermediate|output).dense).weight', + name)} + for name in names: + state_dict[name] = blockdiag_weight_to_dense_weight(state_dict[name]) + return state_dict + +def interpolate_pos_embedding(state_dict, out_seqlen, pos_embedding_name='model.pos_encoder.pe', interleave=False): + orig_emb = state_dict['state_dict'][pos_embedding_name] + assert (out_seqlen % orig_emb.shape[1]) == 0, 'out_seqlen must be a multiple of the original sequence length' + reps = [1 for i in orig_emb.shape] + reps[1] = out_seqlen // orig_emb.shape[1] + + if interleave: + assert math.isqrt(orig_emb.shape[1]) ** 2 == 
orig_emb.shape[1], 'interleave only works for square lengths' + assert math.isqrt(out_seqlen) ** 2 == out_seqlen, 'interleave only works for square lengths' + assert math.isqrt(reps[1]) ** 2 == reps[1], 'out_seqlen / seqlen must be a perfect square' + + emb_square = rearrange(orig_emb, 'b (h w) d -> b h w d', h = math.isqrt(orig_emb.shape[1])) + emb_square_expanded = emb_square.repeat_interleave(math.isqrt(reps[1]), axis=1).repeat_interleave(math.isqrt(reps[1]), axis=2) + new_emb = rearrange(emb_square_expanded, 'b h w d -> b (h w) d') + state_dict['state_dict'][pos_embedding_name] = new_emb + else: + state_dict['state_dict'][pos_embedding_name] = orig_emb.repeat(*reps) + + ret = remove_model_prefix(state_dict) + # # HACK: this is a hack for block-sparse flash attention + ret = { + k: v + for k, v in ret.items() + if not k.endswith('inner_attn.layout') + } + return ret + +def remove_model_prefix(state_dict): + # HACK: this is a hack to get the model to load properly, get rid of 'model.' prefix + for key in list(state_dict['state_dict'].keys()): + if key.startswith('model.'): + new_key = key[len('model.'):] + state_dict['state_dict'][new_key] = state_dict['state_dict'].pop(key) + + # HACK: something is wrong with the state dict being loaded... + return state_dict['state_dict'] diff --git a/training/src/utils/ddp_zero1.py b/training/src/utils/ddp_zero1.py new file mode 100644 index 000000000..da07c7bfc --- /dev/null +++ b/training/src/utils/ddp_zero1.py @@ -0,0 +1,101 @@ +# Meant to work with Pytorch's ZeroRedundancyOptimizer + +from typing import Any, Callable, Dict, List, Optional, Union +from pathlib import Path + +import torch +from torch.optim.optimizer import Optimizer +from torch.distributed.optim import ZeroRedundancyOptimizer + +from pytorch_lightning.strategies.ddp import DDPStrategy +from pytorch_lightning.core.optimizer import LightningOptimizer +from pytorch_lightning.utilities.types import _PATH +# from lightning_lite.utilities.types import _PATH + + +# Copied from Pytorch's ZeroRedundancyOptimizer's state_dict method, but we only get +# the local state dict to avoid synchronization across GPUs. 
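+# Each rank therefore produces a state dict containing only the optimizer states for the
+# parameters it owns (keyed by their global parameter index); DDPStrategyZero1 below saves
+# these shards separately, one file per rank.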
+# https://github.com/pytorch/pytorch/blob/0c7ca2d97ba5980a2af7dcd6b8106dc915e591cd/torch/distributed/optim/zero_redundancy_optimizer.py#L1131 +def get_zero_optimizer_state_dict_local(optimizer, global_rank): + optimizer._check_overlap_initialized() + + # Sync the exposed `param_groups` attributes to the local optimizer in + # case they have been updated + optimizer._sync_param_groups(optimizer.param_groups, optimizer.optim.param_groups) + + local_state_dict = optimizer.optim.state_dict() + state_dict = super(ZeroRedundancyOptimizer, optimizer).state_dict() + + # Update the global optimizer state with local state information, + # factoring in the translation from local to global indexing + rank = global_rank + # TODO: recursive copy to device + local_param_groups = local_state_dict["param_groups"] + global_param_groups = optimizer._partition_parameters()[rank] + assert len(local_param_groups) == len(global_param_groups), \ + "Mismatch between number of local and global parameter groups" + + for local_param_group, global_param_group in zip(local_param_groups, global_param_groups): + # `local_param_group` stores local indices, while + # `global_param_group` stores the tensors directly + local_param_indices = local_param_group["params"] + global_params = global_param_group["params"] + + assert len(local_param_indices) == len(global_params), \ + "Mismatch between number of local and global parameters in parameter group" + for local_param_index, global_param in zip(local_param_indices, global_params): + # Update the global parameter state, if any + if local_param_index in local_state_dict["state"]: + global_param_index = optimizer._param_to_index[global_param] + state_dict["state"][global_param_index] = local_state_dict["state"][local_param_index] + + # Sort the parameters in the state + state_dict["state"] = dict(sorted(state_dict["state"].items())) + return state_dict + + +class DDPStrategyZero1(DDPStrategy): + """To use ZeroRedundancyOptimizer, we need to shard the optimizer states when + saving/loading checkpoints. + """ + + strategy_name = "ddp_zero1" + + def optimizer_state(self, optimizer: Optimizer) -> Optional[dict]: + if isinstance(optimizer, LightningOptimizer): + optimizer = optimizer._optimizer + if isinstance(optimizer, ZeroRedundancyOptimizer): + return get_zero_optimizer_state_dict_local(optimizer, self.global_rank) + else: + return optimizer.state_dict() + + def save_checkpoint( + self, checkpoint: Dict[str, Any], filepath: _PATH, storage_options: Optional[Any] = None + ) -> None: + """Save model/training states as a checkpoint file through state-dump and file-write. 
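+        The checkpoint path is treated as a directory: every rank writes its local optimizer
+        shard to `{rank:03d}_optim_states.pt`, while only global rank 0 writes the model and
+        trainer state to `model_states.pt`.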
+ Args: + checkpoint: dict containing model and trainer state + filepath: write-target file's path + storage_options: parameter for how to save to storage, passed to ``CheckpointIO`` plugin + """ + filepath = Path(filepath) + filepath.mkdir(parents=True, exist_ok=True) + local_optimizer_states = checkpoint.pop('optimizer_states') + if self.is_global_zero: + self.checkpoint_io.save_checkpoint(checkpoint, filepath / 'model_states.pt', + storage_options=storage_options) + self.checkpoint_io.save_checkpoint(local_optimizer_states, + filepath / f'{self.global_rank:03d}_optim_states.pt', + storage_options=storage_options) + + def load_checkpoint(self, checkpoint_path: _PATH) -> Dict[str, Any]: + torch.cuda.empty_cache() + checkpoint_path = Path(checkpoint_path) + if checkpoint_path.is_file(): + return super().load_checkpoint(self, str(checkpoint_path)) + else: + assert checkpoint_path.is_dir() + global_states = self.checkpoint_io.load_checkpoint(checkpoint_path / 'model_states.pt') + local_optimizer_states = self.checkpoint_io.load_checkpoint(checkpoint_path / f'{self.global_rank:03d}_optim_states.pt') + global_states['optimizer_states'] = local_optimizer_states + return global_states diff --git a/training/src/utils/ddp_zero2.py b/training/src/utils/ddp_zero2.py new file mode 100644 index 000000000..e526abfc3 --- /dev/null +++ b/training/src/utils/ddp_zero2.py @@ -0,0 +1,141 @@ +# Meant to work with Apex's DistributeFusedAdam + +from typing import Any, Callable, Dict, List, Optional, Union +from pathlib import Path +import types + +import torch +from torch.optim.optimizer import Optimizer +from torch.optim import LBFGS + +from apex.contrib.optimizers.distributed_fused_adam import DistributedFusedAdam + +from pytorch_lightning.strategies.ddp import DDPStrategy +from pytorch_lightning.plugins.precision import PrecisionPlugin, NativeMixedPrecisionPlugin +from pytorch_lightning.core.optimizer import LightningOptimizer +from pytorch_lightning.utilities.types import _PATH +# from lightning_lite.utilities.types import _PATH +from pytorch_lightning.utilities.exceptions import MisconfigurationException + + +class DistAdamNativeMixedPrecisionPlugin(NativeMixedPrecisionPlugin): + + def optimizer_step( # type: ignore[override] + self, + model: "pl.LightningModule", + optimizer, + optimizer_idx: int, + closure: Callable[[], Any], + **kwargs: Any, + ) -> Any: + if self.scaler is None: + # skip scaler logic, as bfloat16 does not require scaler + return NativeMixedPrecisionPlugin.optimizer_step( + self, optimizer, model=model, optimizer_idx=optimizer_idx, closure=closure, **kwargs + ) + if isinstance(optimizer, LBFGS): + raise MisconfigurationException( + f"Native AMP and the LBFGS optimizer are not compatible (optimizer {optimizer_idx})." + ) + closure_result = closure() + # HACK: we don't call self.scaler.unscale_ here. This is because DistributedFusedAdam + # optimizer internally takes the scale into account. + # If we call unscale_ here, it would be equivalent to unscaling the gradients twice. + # Not unscaling has the side-effect that the NormMonitor callback will report the + # gradient norm to be much larger than reality. + # # `unscale` after the closure is executed but before the `on_before_optimizer_step` hook. 
+ # self.scaler.unscale_(optimizer) + # This will call gradient clipping + self._after_closure(model, optimizer, optimizer_idx) + skipped_backward = closure_result is None + # in manual optimization, the closure does not return a value + if not model.automatic_optimization or not skipped_backward: + # note: the scaler will skip the `optimizer.step` if nonfinite gradients are found + step_output = self.scaler.step(optimizer, **kwargs) + self.scaler.update() + return step_output + return closure_result + + def clip_grad_by_norm(self, optimizer: DistributedFusedAdam, clip_val: Union[int, float]) -> None: + """Clip gradients by norm.""" + # DistributedFusedAdam wants list, not generator + # Gradients have not be scaled, so we need to scale up the clip_val + if self.scaler is not None: + clip_val *= self.scaler.get_scale() + return optimizer.clip_grad_norm(clip_val) + + +class DDPStrategyZero2(DDPStrategy): + """To use Apex's DistributedFusedAdam, we need to shard the optimizer states when + saving/loading checkpoints. + """ + + strategy_name = "ddp_zero2" + + def __init__( + self, + *args, + precision_plugin: Optional[PrecisionPlugin] = DistAdamNativeMixedPrecisionPlugin, + # precision_plugin: Optional[PrecisionPlugin] = None, + **kwargs: Union[Any, Dict[str, Any]], + ) -> None: + super().__init__( + *args, precision_plugin=precision_plugin, **kwargs + ) + + @property + def precision_plugin(self) -> PrecisionPlugin: + return self._precision_plugin if self._precision_plugin is not None else PrecisionPlugin() + + @precision_plugin.setter + def precision_plugin(self, precision_plugin: Optional[PrecisionPlugin]) -> None: + self._precision_plugin = precision_plugin + # https://stackoverflow.com/questions/972/adding-a-method-to-an-existing-object-instance + self._precision_plugin.optimizer_step = types.MethodType( + DistAdamNativeMixedPrecisionPlugin.optimizer_step, self._precision_plugin + ) + self._precision_plugin.clip_grad_by_norm = types.MethodType( + DistAdamNativeMixedPrecisionPlugin.clip_grad_by_norm, self._precision_plugin + ) + + def optimizer_state(self, optimizer: Optimizer) -> Optional[dict]: + if isinstance(optimizer, LightningOptimizer): + optimizer = optimizer._optimizer + if isinstance(optimizer, DistributedFusedAdam): + return optimizer.state_dict(gather_on_root=False) + else: + return optimizer.state_dict() + + def save_checkpoint( + self, checkpoint: Dict[str, Any], filepath: _PATH, storage_options: Optional[Any] = None + ) -> None: + """Save model/training states as a checkpoint file through state-dump and file-write. 
+ Args: + checkpoint: dict containing model and trainer state + filepath: write-target file's path + storage_options: parameter for how to save to storage, passed to ``CheckpointIO`` plugin + """ + filepath = Path(filepath) + filepath.mkdir(parents=True, exist_ok=True) + local_optimizer_states = checkpoint.pop('optimizer_states') + if self.is_global_zero: + self.checkpoint_io.save_checkpoint(checkpoint, filepath / 'model_states.pt', + storage_options=storage_options) + self.checkpoint_io.save_checkpoint(local_optimizer_states, + filepath / f'{self.global_rank:03d}_optim_states.pt', + storage_options=storage_options) + + def load_checkpoint(self, checkpoint_path: _PATH) -> Dict[str, Any]: + torch.cuda.empty_cache() + checkpoint_path = Path(checkpoint_path) + if checkpoint_path.is_file(): + return super().load_checkpoint(self, str(checkpoint_path)) + else: + assert checkpoint_path.is_dir() + global_states = self.checkpoint_io.load_checkpoint(checkpoint_path / 'model_states.pt') + local_optimizer_states = self.checkpoint_io.load_checkpoint( + checkpoint_path / f'{self.global_rank:03d}_optim_states.pt', + map_location='cuda' + ) + global_states['optimizer_states'] = local_optimizer_states + return global_states diff --git a/training/src/utils/distributed.py b/training/src/utils/distributed.py new file mode 100644 index 000000000..073b6135d --- /dev/null +++ b/training/src/utils/distributed.py @@ -0,0 +1,111 @@ +# Copied from https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/LanguageModeling/Transformer-XL/pytorch/utils/distributed.py + +# Copyright (c) 2019-2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from contextlib import contextmanager + +import torch + + +def init_distributed(cuda): + """ + Initializes distributed backend. + :param cuda: (bool) if True initializes nccl backend, if False initializes + gloo backend + """ + world_size = int(os.environ.get('WORLD_SIZE', 1)) + distributed = (world_size > 1) + if distributed: + backend = 'nccl' if cuda else 'gloo' + torch.distributed.init_process_group(backend=backend, + init_method='env://') + assert torch.distributed.is_initialized() + return distributed + + +def barrier(): + """ + Call torch.distributed.barrier() if distritubed is in use + """ + if torch.distributed.is_available() and torch.distributed.is_initialized(): + torch.distributed.barrier() + + +def get_rank(): + """ + Gets distributed rank or returns zero if distributed is not initialized. + """ + if torch.distributed.is_available() and torch.distributed.is_initialized(): + rank = torch.distributed.get_rank() + else: + rank = 0 + return rank + + +def get_world_size(): + """ + Gets total number of distributed workers or returns one if distributed is + not initialized. 
+ """ + if torch.distributed.is_available() and torch.distributed.is_initialized(): + world_size = torch.distributed.get_world_size() + else: + world_size = 1 + return world_size + + +def all_reduce_item(value, op='sum'): + """ + All-reduces single scalar value if distributed is in use + """ + if torch.distributed.is_available() and torch.distributed.is_initialized(): + if op == 'sum' or op == 'mean': + dop = torch.distributed.ReduceOp.SUM + elif op == 'min': + dop = torch.distributed.ReduceOp.MIN + elif op == 'max': + dop = torch.distributed.ReduceOp.MAX + elif op == 'product': + dop = torch.distributed.ReduceOp.PRODUCT + else: + raise RuntimeError('Unsupported reduce op') + + backend = torch.distributed.get_backend() + if backend == torch.distributed.Backend.NCCL: + device = torch.device('cuda') + elif backend == torch.distributed.Backend.GLOO: + device = torch.device('cpu') + else: + raise RuntimeError('Unsupported distributed backend') + + tensor = torch.tensor(value, device=device) + torch.distributed.all_reduce(tensor, dop) + if op == 'mean': + tensor /= get_world_size() + ret = tensor.item() + else: + ret = value + return ret + + +@contextmanager +def sync_workers(): + """ + Yields distributed rank and synchronizes all workers on exit. + """ + rank = get_rank() + yield rank + barrier() diff --git a/training/src/utils/ema.py b/training/src/utils/ema.py new file mode 100644 index 000000000..9fb3beb68 --- /dev/null +++ b/training/src/utils/ema.py @@ -0,0 +1,280 @@ +# Copied from https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py +from __future__ import division +from __future__ import unicode_literals + +from typing import Iterable, Optional +import weakref +import copy +import contextlib + +import torch + + +def to_float_maybe(x): + return x.float() if x.dtype in [torch.float16, torch.bfloat16] else x + + +# Partially based on: +# https://github.com/tensorflow/tensorflow/blob/r1.13/tensorflow/python/training/moving_averages.py +class ExponentialMovingAverage: + """ + Maintains (exponential) moving average of a set of parameters. + Args: + parameters: Iterable of `torch.nn.Parameter` (typically from + `model.parameters()`). + decay: The exponential decay. + use_num_updates: Whether to use number of updates when computing + averages. + """ + def __init__( + self, + parameters: Iterable[torch.nn.Parameter], + decay: float, + use_num_updates: bool = True + ): + if decay < 0.0 or decay > 1.0: + raise ValueError('Decay must be between 0 and 1') + self.decay = decay + self.num_updates = 0 if use_num_updates else None + parameters = list(parameters) + self.shadow_params = [to_float_maybe(p.clone().detach()) + for p in parameters if p.requires_grad] + self.collected_params = None + # By maintaining only a weakref to each parameter, + # we maintain the old GC behaviour of ExponentialMovingAverage: + # if the model goes out of scope but the ExponentialMovingAverage + # is kept, no references to the model or its parameters will be + # maintained, and the model will be cleaned up. 
+ self._params_refs = [weakref.ref(p) for p in parameters] + + def _get_parameters( + self, + parameters: Optional[Iterable[torch.nn.Parameter]] + ) -> Iterable[torch.nn.Parameter]: + if parameters is None: + parameters = [p() for p in self._params_refs] + if any(p is None for p in parameters): + raise ValueError( + "(One of) the parameters with which this " + "ExponentialMovingAverage " + "was initialized no longer exists (was garbage collected);" + " please either provide `parameters` explicitly or keep " + "the model to which they belong from being garbage " + "collected." + ) + return parameters + else: + parameters = list(parameters) + if len(parameters) != len(self.shadow_params): + raise ValueError( + "Number of parameters passed as argument is different " + "from number of shadow parameters maintained by this " + "ExponentialMovingAverage" + ) + return parameters + + def update( + self, + parameters: Optional[Iterable[torch.nn.Parameter]] = None + ) -> None: + """ + Update currently maintained parameters. + Call this every time the parameters are updated, such as the result of + the `optimizer.step()` call. + Args: + parameters: Iterable of `torch.nn.Parameter`; usually the same set of + parameters used to initialize this object. If `None`, the + parameters with which this `ExponentialMovingAverage` was + initialized will be used. + """ + parameters = self._get_parameters(parameters) + decay = self.decay + if self.num_updates is not None: + self.num_updates += 1 + decay = min( + decay, + (1 + self.num_updates) / (10 + self.num_updates) + ) + one_minus_decay = 1.0 - decay + if parameters[0].device != self.shadow_params[0].device: + self.to(device=parameters[0].device) + with torch.no_grad(): + parameters = [p for p in parameters if p.requires_grad] + for s_param, param in zip(self.shadow_params, parameters): + torch.lerp(s_param, param.to(dtype=s_param.dtype), one_minus_decay, out=s_param) + + def copy_to( + self, + parameters: Optional[Iterable[torch.nn.Parameter]] = None + ) -> None: + """ + Copy current averaged parameters into given collection of parameters. + Args: + parameters: Iterable of `torch.nn.Parameter`; the parameters to be + updated with the stored moving averages. If `None`, the + parameters with which this `ExponentialMovingAverage` was + initialized will be used. + """ + parameters = self._get_parameters(parameters) + for s_param, param in zip(self.shadow_params, parameters): + if param.requires_grad: + param.data.copy_(s_param.data) + + def store( + self, + parameters: Optional[Iterable[torch.nn.Parameter]] = None + ) -> None: + """ + Save the current parameters for restoring later. + Args: + parameters: Iterable of `torch.nn.Parameter`; the parameters to be + temporarily stored. If `None`, the parameters of with which this + `ExponentialMovingAverage` was initialized will be used. + """ + parameters = self._get_parameters(parameters) + self.collected_params = [ + param.clone() + for param in parameters + if param.requires_grad + ] + + def restore( + self, + parameters: Optional[Iterable[torch.nn.Parameter]] = None + ) -> None: + """ + Restore the parameters stored with the `store` method. + Useful to validate the model with EMA parameters without affecting the + original optimization process. Store the parameters before the + `copy_to` method. After validation (or model saving), use this to + restore the former parameters. + Args: + parameters: Iterable of `torch.nn.Parameter`; the parameters to be + updated with the stored parameters. 
If `None`, the + parameters with which this `ExponentialMovingAverage` was + initialized will be used. + """ + if self.collected_params is None: + raise RuntimeError( + "This ExponentialMovingAverage has no `store()`ed weights " + "to `restore()`" + ) + parameters = self._get_parameters(parameters) + for c_param, param in zip(self.collected_params, parameters): + if param.requires_grad: + param.data.copy_(c_param.data) + + @contextlib.contextmanager + def average_parameters( + self, + parameters: Optional[Iterable[torch.nn.Parameter]] = None + ): + r""" + Context manager for validation/inference with averaged parameters. + Equivalent to: + ema.store() + ema.copy_to() + try: + ... + finally: + ema.restore() + Args: + parameters: Iterable of `torch.nn.Parameter`; the parameters to be + updated with the stored parameters. If `None`, the + parameters with which this `ExponentialMovingAverage` was + initialized will be used. + """ + parameters = self._get_parameters(parameters) + self.store(parameters) + self.copy_to(parameters) + try: + yield + finally: + self.restore(parameters) + + def to(self, device=None, dtype=None) -> None: + r"""Move internal buffers of the ExponentialMovingAverage to `device`. + Args: + device: like `device` argument to `torch.Tensor.to` + """ + # .to() on the tensors handles None correctly + self.shadow_params = [ + p.to(device=device, dtype=dtype) + if p.is_floating_point() + else p.to(device=device) + for p in self.shadow_params + ] + if self.collected_params is not None: + self.collected_params = [ + p.to(device=device, dtype=dtype) + if p.is_floating_point() + else p.to(device=device) + for p in self.collected_params + ] + return + + def state_dict(self) -> dict: + r"""Returns the state of the ExponentialMovingAverage as a dict.""" + # Following PyTorch conventions, references to tensors are returned: + # "returns a reference to the state and not its copy!" - + # https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict + return { + "decay": self.decay, + "num_updates": self.num_updates, + "shadow_params": self.shadow_params, + "collected_params": self.collected_params + } + + def load_state_dict(self, state_dict: dict) -> None: + r"""Loads the ExponentialMovingAverage state. + Args: + state_dict (dict): EMA state. Should be an object returned + from a call to :meth:`state_dict`. 
+ """ + # deepcopy, to be consistent with module API + state_dict = copy.deepcopy(state_dict) + self.decay = state_dict["decay"] + if self.decay < 0.0 or self.decay > 1.0: + raise ValueError('Decay must be between 0 and 1') + self.num_updates = state_dict["num_updates"] + assert self.num_updates is None or isinstance(self.num_updates, int), \ + "Invalid num_updates" + + self.shadow_params = state_dict["shadow_params"] + assert isinstance(self.shadow_params, list), \ + "shadow_params must be a list" + assert all( + isinstance(p, torch.Tensor) for p in self.shadow_params + ), "shadow_params must all be Tensors" + + self.collected_params = state_dict["collected_params"] + if self.collected_params is not None: + assert isinstance(self.collected_params, list), \ + "collected_params must be a list" + assert all( + isinstance(p, torch.Tensor) for p in self.collected_params + ), "collected_params must all be Tensors" + assert len(self.collected_params) == len(self.shadow_params), \ + "collected_params and shadow_params had different lengths" + + if len(self.shadow_params) == len(self._params_refs): + # Consistent with torch.optim.Optimizer, cast things to consistent + # device and dtype with the parameters + params = [p() for p in self._params_refs] + # If parameters have been garbage collected, just load the state + # we were given without change. + if not any(p is None for p in params): + # ^ parameter references are still good + for i, p in enumerate(params): + self.shadow_params[i] = to_float_maybe(self.shadow_params[i].to( + device=p.device, dtype=p.dtype + )) + if self.collected_params is not None: + self.collected_params[i] = self.collected_params[i].to( + device=p.device, dtype=p.dtype + ) + else: + raise ValueError( + "Tried to `load_state_dict()` with the wrong number of " + "parameters in the saved state." + ) diff --git a/training/src/utils/flops.py b/training/src/utils/flops.py new file mode 100644 index 000000000..bb1ca7902 --- /dev/null +++ b/training/src/utils/flops.py @@ -0,0 +1,45 @@ +# Adapted from https://github.com/rwightman/pytorch-image-models/blob/master/benchmark.py +import torch + +try: + from deepspeed.profiling.flops_profiler import get_model_profile + has_deepspeed_profiling = True +except ImportError as e: + has_deepspeed_profiling = False + +try: + from fvcore.nn import FlopCountAnalysis, flop_count_str, flop_count_table + from fvcore.nn import ActivationCountAnalysis + has_fvcore_profiling = True +except ImportError as e: + FlopCountAnalysis = None + ActivationCountAnalysis = None + has_fvcore_profiling = False + + +def profile_deepspeed(model, input_size=(3, 224, 224), input_dtype=torch.float32, + batch_size=1, detailed=False): + device, dtype = next(model.parameters()).device, next(model.parameters()).dtype + flops, macs, params = get_model_profile( + model=model, + args=torch.zeros((batch_size,) + input_size, device=device, dtype=input_dtype), + print_profile=detailed, # prints the model graph with the measured profile attached to each module + detailed=detailed, # print the detailed profile + warm_up=10, # the number of warm-ups before measuring the time of each module + as_string=False, # print raw numbers (e.g. 1000) or as human-readable strings (e.g. 1k) + output_file=None, # path to the output file. If None, the profiler prints to stdout. 
+ ignore_modules=None) # the list of modules to ignore in the profiling + return macs, 0 # no activation count in DS + + +def profile_fvcore(model, input_size=(3, 224, 224), input_dtype=torch.float32, max_depth=4, + batch_size=1, detailed=False, force_cpu=False): + if force_cpu: + model = model.to('cpu') + device, dtype = next(model.parameters()).device, next(model.parameters()).dtype + example_input = torch.zeros((batch_size,) + input_size, device=device, dtype=input_dtype) + fca = FlopCountAnalysis(model, example_input) + aca = ActivationCountAnalysis(model, example_input) + if detailed: + print(flop_count_table(fca, max_depth=max_depth)) + return fca, fca.total(), aca, aca.total() diff --git a/training/src/utils/gpu_affinity.py b/training/src/utils/gpu_affinity.py new file mode 100644 index 000000000..8636d504a --- /dev/null +++ b/training/src/utils/gpu_affinity.py @@ -0,0 +1,142 @@ +import collections +import math +import os +import pathlib +import re + +import pynvml + +pynvml.nvmlInit() + + +def systemGetDriverVersion(): + return pynvml.nvmlSystemGetDriverVersion() + + +def deviceGetCount(): + return pynvml.nvmlDeviceGetCount() + + +class device: + # assume nvml returns list of 64 bit ints + _nvml_affinity_elements = math.ceil(os.cpu_count() / 64) + + def __init__(self, device_idx): + super().__init__() + self.handle = pynvml.nvmlDeviceGetHandleByIndex(device_idx) + + def getName(self): + return pynvml.nvmlDeviceGetName(self.handle) + + def getCpuAffinity(self): + affinity_string = '' + for j in pynvml.nvmlDeviceGetCpuAffinity( + self.handle, device._nvml_affinity_elements + ): + # assume nvml returns list of 64 bit ints + affinity_string = '{:064b}'.format(j) + affinity_string + affinity_list = [int(x) for x in affinity_string] + affinity_list.reverse() # so core 0 is in 0th element of list + + ret = [i for i, e in enumerate(affinity_list) if e != 0] + return ret + + +def set_socket_affinity(gpu_id): + dev = device(gpu_id) + affinity = dev.getCpuAffinity() + os.sched_setaffinity(0, affinity) + + +def set_single_affinity(gpu_id): + dev = device(gpu_id) + affinity = dev.getCpuAffinity() + os.sched_setaffinity(0, affinity[:1]) + + +def set_single_unique_affinity(gpu_id, nproc_per_node): + devices = [device(i) for i in range(nproc_per_node)] + socket_affinities = [dev.getCpuAffinity() for dev in devices] + + siblings_list = get_thread_siblings_list() + siblings_dict = dict(siblings_list) + + # remove siblings + for idx, socket_affinity in enumerate(socket_affinities): + socket_affinities[idx] = list(set(socket_affinity) - set(siblings_dict.values())) + + affinities = [] + assigned = [] + + for socket_affinity in socket_affinities: + for core in socket_affinity: + if core not in assigned: + affinities.append([core]) + assigned.append(core) + break + os.sched_setaffinity(0, affinities[gpu_id]) + + +def set_socket_unique_affinity(gpu_id, nproc_per_node, mode): + device_ids = [device(i) for i in range(nproc_per_node)] + socket_affinities = [dev.getCpuAffinity() for dev in device_ids] + + siblings_list = get_thread_siblings_list() + siblings_dict = dict(siblings_list) + + # remove siblings + for idx, socket_affinity in enumerate(socket_affinities): + socket_affinities[idx] = list(set(socket_affinity) - set(siblings_dict.values())) + + socket_affinities_to_device_ids = collections.defaultdict(list) + + for idx, socket_affinity in enumerate(socket_affinities): + socket_affinities_to_device_ids[tuple(socket_affinity)].append(idx) + + for socket_affinity, device_ids in 
socket_affinities_to_device_ids.items(): + devices_per_group = len(device_ids) + cores_per_device = len(socket_affinity) // devices_per_group + for group_id, device_id in enumerate(device_ids): + if device_id == gpu_id: + if mode == 'interleaved': + affinity = list(socket_affinity[group_id::devices_per_group]) + elif mode == 'continuous': + affinity = list(socket_affinity[group_id*cores_per_device:(group_id+1)*cores_per_device]) + else: + raise RuntimeError('Unknown set_socket_unique_affinity mode') + + # reintroduce siblings + affinity += [siblings_dict[aff] for aff in affinity if aff in siblings_dict] + os.sched_setaffinity(0, affinity) + + +def get_thread_siblings_list(): + path = '/sys/devices/system/cpu/cpu*/topology/thread_siblings_list' + thread_siblings_list = [] + pattern = re.compile(r'(\d+)\D(\d+)') + for fname in pathlib.Path(path[0]).glob(path[1:]): + with open(fname) as f: + content = f.read().strip() + res = pattern.findall(content) + if res: + pair = tuple(map(int, res[0])) + thread_siblings_list.append(pair) + return thread_siblings_list + + +def set_affinity(gpu_id, nproc_per_node, mode='socket'): + if mode == 'socket': + set_socket_affinity(gpu_id) + elif mode == 'single': + set_single_affinity(gpu_id) + elif mode == 'single_unique': + set_single_unique_affinity(gpu_id, nproc_per_node) + elif mode == 'socket_unique_interleaved': + set_socket_unique_affinity(gpu_id, nproc_per_node, 'interleaved') + elif mode == 'socket_unique_continuous': + set_socket_unique_affinity(gpu_id, nproc_per_node, 'continuous') + else: + raise RuntimeError('Unknown affinity mode') + + affinity = os.sched_getaffinity(0) + return affinity diff --git a/training/src/utils/utils.py b/training/src/utils/utils.py new file mode 100644 index 000000000..32e64ab4d --- /dev/null +++ b/training/src/utils/utils.py @@ -0,0 +1,146 @@ +import logging +import warnings +from typing import List, Sequence + +import pytorch_lightning as pl +import rich.syntax +import rich.tree +from omegaconf import DictConfig, OmegaConf +from pytorch_lightning.utilities import rank_zero_only + + +# Copied from https://docs.python.org/3/howto/logging-cookbook.html#using-a-context-manager-for-selective-logging +class LoggingContext: + def __init__(self, logger, level=None, handler=None, close=True): + self.logger = logger + self.level = level + self.handler = handler + self.close = close + + def __enter__(self): + if self.level is not None: + self.old_level = self.logger.level + self.logger.setLevel(self.level) + if self.handler: + self.logger.addHandler(self.handler) + + def __exit__(self, et, ev, tb): + if self.level is not None: + self.logger.setLevel(self.old_level) + if self.handler: + self.logger.removeHandler(self.handler) + if self.handler and self.close: + self.handler.close() + # implicit return of None => don't swallow exceptions + + +def get_logger(name=__name__) -> logging.Logger: + """Initializes multi-GPU-friendly python logger.""" + + logger = logging.getLogger(name) + + # this ensures all logging levels get marked with the rank zero decorator + # otherwise logs would get multiplied for each GPU process in multi-GPU setup + for level in ("debug", "info", "warning", "error", "exception", "fatal", "critical"): + setattr(logger, level, rank_zero_only(getattr(logger, level))) + + return logger + + +def extras(config: DictConfig) -> None: + """A couple of optional utilities, controlled by main config file: + - disabling warnings + - forcing debug friendly configuration + - verifying experiment name is set when running in 
experiment mode + Modifies DictConfig in place. + Args: + config (DictConfig): Configuration composed by Hydra. + """ + + log = get_logger(__name__) + + # disable python warnings if + if config.get("ignore_warnings"): + log.info("Disabling python warnings! ") + warnings.filterwarnings("ignore") + + # verify experiment name is set when running in experiment mode + if config.get("experiment_mode") and not config.get("name"): + log.info( + "Running in experiment mode without the experiment name specified! " + "Use `python run.py mode=exp name=experiment_name`" + ) + log.info("Exiting...") + exit() + + # force debugger friendly configuration if + # debuggers don't like GPUs and multiprocessing + if config.trainer.get("fast_dev_run"): + log.info("Forcing debugger friendly configuration! ") + if config.trainer.get("gpus"): + config.trainer.gpus = 0 + if config.datamodule.get("pin_memory"): + config.datamodule.pin_memory = False + if config.datamodule.get("num_workers"): + config.datamodule.num_workers = 0 + + +@rank_zero_only +def print_config( + config: DictConfig, + fields: Sequence[str] = ( + "trainer", + "model", + "datamodule", + "train", + "eval", + "callbacks", + "logger", + "seed", + "name", + ), + resolve: bool = True, +) -> None: + """Prints content of DictConfig using Rich library and its tree structure. + Args: + config (DictConfig): Configuration composed by Hydra. + fields (Sequence[str], optional): Determines which main fields from config will + be printed and in what order. + resolve (bool, optional): Whether to resolve reference fields of DictConfig. + """ + + style = "dim" + tree = rich.tree.Tree("CONFIG", style=style, guide_style=style) + + for field in fields: + branch = tree.add(field, style=style, guide_style=style) + + config_section = config.get(field) + branch_content = str(config_section) + if isinstance(config_section, DictConfig): + branch_content = OmegaConf.to_yaml(config_section, resolve=resolve) + + branch.add(rich.syntax.Syntax(branch_content, "yaml")) + + rich.print(tree) + + with open("config_tree.txt", "w") as fp: + rich.print(tree, file=fp) + + +def finish( + config: DictConfig, + model: pl.LightningModule, + datamodule: pl.LightningDataModule, + trainer: pl.Trainer, + callbacks: List[pl.Callback], + logger: List[pl.loggers.LightningLoggerBase], +) -> None: + """Makes sure everything closed properly.""" + + # without this sweeps with wandb logger might crash! 
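+    # wandb.finish() flushes any pending logs and marks the run as finished, so the next run
+    # in a sweep starts from a clean state.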
+ for lg in logger: + if isinstance(lg, pl.loggers.wandb.WandbLogger): + import wandb + + wandb.finish() diff --git a/training/tests/datamodules/test_language_modeling_hf.py b/training/tests/datamodules/test_language_modeling_hf.py new file mode 100644 index 000000000..47da39e25 --- /dev/null +++ b/training/tests/datamodules/test_language_modeling_hf.py @@ -0,0 +1,218 @@ +import os +from pathlib import Path +current_dir = Path(__file__).parent.absolute() + + +import pytest + +import torch + +import dotenv + +from src.datamodules.language_modeling_hf import LMDataModule + +# load environment variables from `.env` file if it exists +# recursively searches for `.env` in all folders starting from work dir +dotenv.load_dotenv(override=True) + + +def div_up(x: int, y: int) -> int: + return (x + y - 1) // y + + +# https://stackoverflow.com/questions/1006289/how-to-find-out-the-number-of-cpus-using-python/55423170#55423170 +def num_cpu_cores(): + try: + import psutil + return psutil.cpu_count(logical=False) + except ImportError: + return len(os.sched_getaffinity(0)) + + +class TestLMDataModule: + + def test_wikitext2(self): + batch_size = 7 + dataset_name = 'wikitext' + dataset_config_name = 'wikitext-2-raw-v1' + data_dir = Path(os.getenv('DATA_DIR', current_dir.parent.parent / 'data')) + cache_dir = data_dir / 'wikitext-2' / 'cache' + max_length = 1024 + datamodule = LMDataModule(dataset_name, tokenizer_name='gpt2', + dataset_config_name=dataset_config_name, + max_length=max_length, cache_dir=cache_dir, + add_eos=False, batch_size=batch_size, num_workers=4) + datamodule.prepare_data() + datamodule.setup(stage='fit') + train_loader = datamodule.train_dataloader() + val_loader = datamodule.val_dataloader() + datamodule.setup(stage='test') + test_loader = datamodule.test_dataloader() + train_len = 2391884 + val_len = 247289 + test_len = 283287 + assert len(train_loader) == div_up((train_len - 1) // max_length, batch_size) + assert len(val_loader) == div_up((val_len - 1) // max_length, batch_size) + assert len(test_loader) == div_up((test_len - 1) // max_length, batch_size) + for loader in [train_loader, val_loader, test_loader]: + x, y = next(iter(loader)) + assert x.dim() == 2 + assert x.shape == (batch_size, max_length) + assert x.dtype == torch.long + assert torch.allclose(x[:, 1:], y[:, :-1]) + + def test_wikitext103(self): + batch_size = 7 + dataset_name = 'wikitext' + dataset_config_name = 'wikitext-103-raw-v1' + data_dir = Path(os.getenv('DATA_DIR', current_dir.parent.parent / 'data')) + cache_dir = data_dir / 'wikitext-103' / 'cache' + max_length = 1024 + datamodule = LMDataModule(dataset_name, tokenizer_name='gpt2', + dataset_config_name=dataset_config_name, + max_length=max_length, cache_dir=cache_dir, + add_eos=False, batch_size=batch_size, num_workers=4) + datamodule.prepare_data() + datamodule.setup(stage='fit') + train_loader = datamodule.train_dataloader() + val_loader = datamodule.val_dataloader() + datamodule.setup(stage='test') + test_loader = datamodule.test_dataloader() + train_len = 117920140 + val_len = 247289 + test_len = 283287 + assert len(train_loader) == div_up((train_len - 1) // max_length, batch_size) + assert len(val_loader) == div_up((val_len - 1) // max_length, batch_size) + assert len(test_loader) == div_up((test_len - 1) // max_length, batch_size) + for loader in [train_loader, val_loader, test_loader]: + x, y = next(iter(loader)) + assert x.dim() == 2 + assert x.shape == (batch_size, max_length) + assert x.dtype == torch.long + assert torch.allclose(x[:, 1:], y[:, 
:-1]) + + def test_openwebtext(self): + batch_size = 8 + dataset_name = 'openwebtext' + dataset_config_name = None + data_dir = Path(os.getenv('DATA_DIR', current_dir.parent.parent / 'data')) + cache_dir = data_dir / 'openwebtext' / 'cache' + max_length = 1024 + datamodule = LMDataModule(dataset_name, tokenizer_name='gpt2', + dataset_config_name=dataset_config_name, + max_length=max_length, cache_dir=cache_dir, + add_eos=True, batch_size=batch_size, + num_workers=num_cpu_cores() // 2) + datamodule.prepare_data() + datamodule.setup(stage='fit') + train_loader = datamodule.train_dataloader() + val_loader = datamodule.val_dataloader() + datamodule.setup(stage='test') + test_loader = datamodule.test_dataloader() + train_len = 9035582198 + val_len = 4434897 + test_len = 4434897 + assert len(train_loader) == div_up((train_len - 1) // max_length, batch_size) + assert len(val_loader) == div_up((val_len - 1) // max_length, batch_size) + assert len(test_loader) == div_up((test_len - 1) // max_length, batch_size) + for loader in [train_loader, val_loader, test_loader]: + x, y = next(iter(loader)) + assert x.dim() == 2 + assert x.shape == (batch_size, max_length) + assert x.dtype == torch.long + assert torch.allclose(x[:, 1:], y[:, :-1]) + + def test_lambada(self): + batch_size = 8 + dataset_name = 'lambada' + dataset_config_name = None + data_dir = Path(os.getenv('DATA_DIR', current_dir.parent.parent / 'data')) + cache_dir = data_dir / 'lambada' / 'cache' + max_length = 1024 + datamodule = LMDataModule(dataset_name, tokenizer_name='gpt2', + dataset_config_name=dataset_config_name, + max_length=max_length, cache_dir=cache_dir, + add_eos=True, batch_size=batch_size, + num_workers=64) + datamodule.prepare_data() + datamodule.setup(stage='fit') + train_loader = datamodule.train_dataloader() + val_loader = datamodule.val_dataloader() + datamodule.setup(stage='test') + test_loader = datamodule.test_dataloader() + train_len = 9035582198 + val_len = 4434897 + test_len = 4434897 + assert len(train_loader) == div_up((train_len - 1) // max_length, batch_size) + assert len(val_loader) == div_up((val_len - 1) // max_length, batch_size) + assert len(test_loader) == div_up((test_len - 1) // max_length, batch_size) + for loader in [train_loader, val_loader, test_loader]: + x, y = next(iter(loader)) + assert x.dim() == 2 + assert x.shape == (batch_size, max_length) + assert x.dtype == torch.long + assert torch.allclose(x[:, 1:], y[:, :-1]) + + def test_the_pile(self): + batch_size = 8 + dataset_name = 'the_pile' + dataset_config_name = None + data_dir = Path(os.getenv('DATA_DIR', current_dir.parent.parent / 'data')) + cache_dir = data_dir / 'the_pile' / 'cache' + max_length = 2048 + # Dataset is too large to fit into memory, need to use disk for concatenation + datamodule = LMDataModule(dataset_name, tokenizer_name='gpt2', + dataset_config_name=dataset_config_name, + max_length=max_length, cache_dir=cache_dir, + add_eos=True, batch_size=batch_size, + num_workers=num_cpu_cores() // 2, use_shmem=False) + datamodule.prepare_data() + datamodule.setup(stage='fit') + train_loader = datamodule.train_dataloader() + val_loader = datamodule.val_dataloader() + datamodule.setup(stage='test') + test_loader = datamodule.test_dataloader() + train_len = 374337375694 + val_len = 383326395 + test_len = 373297018 + assert len(train_loader) == div_up((train_len - 1) // max_length, batch_size) + assert len(val_loader) == div_up((val_len - 1) // max_length, batch_size) + assert len(test_loader) == div_up((test_len - 1) // max_length, 
batch_size) + for loader in [train_loader, val_loader, test_loader]: + x, y = next(iter(loader)) + assert x.dim() == 2 + assert x.shape == (batch_size, max_length) + assert x.dtype == torch.long + assert torch.allclose(x[:, 1:], y[:, :-1]) + + def test_pg19(self): + batch_size = 8 + dataset_name = 'pg19' + dataset_config_name = None + data_dir = Path(os.getenv('DATA_DIR', current_dir.parent.parent / 'data')) + cache_dir = data_dir / 'pg19' / 'cache' + max_length = 2048 + # Dataset is too large to fit into memory, need to use disk for concatenation + datamodule = LMDataModule(dataset_name, tokenizer_name='gpt2', + dataset_config_name=dataset_config_name, + max_length=max_length, cache_dir=cache_dir, + add_eos=True, batch_size=batch_size, + num_workers=num_cpu_cores() // 2) + datamodule.prepare_data() + datamodule.setup(stage='fit') + train_loader = datamodule.train_dataloader() + val_loader = datamodule.val_dataloader() + datamodule.setup(stage='test') + test_loader = datamodule.test_dataloader() + train_len = 3066544128 + val_len = 4653056 + test_len = 10584064 + assert len(train_loader) == div_up((train_len - 1) // max_length, batch_size) + assert len(val_loader) == div_up((val_len - 1) // max_length, batch_size) + assert len(test_loader) == div_up((test_len - 1) // max_length, batch_size) + for loader in [train_loader, val_loader, test_loader]: + x, y = next(iter(loader)) + assert x.dim() == 2 + assert x.shape == (batch_size, max_length) + assert x.dtype == torch.long + assert torch.allclose(x[:, 1:], y[:, :-1])