From 4894d2b17725c14835678d4ad3c2fc94880db0f6 Mon Sep 17 00:00:00 2001
From: Tobias Fischer
Date: Mon, 30 Sep 2024 14:39:22 +0200
Subject: [PATCH] TE integration

---
 Dockerfile                           | 13 +++++++--
 megatron/model/gpt2_model.py         | 19 +++++++++++++
 megatron/model/transformer.py        | 42 +++++++++++++++++++++++++++-
 megatron/neox_arguments/neox_args.py |  1 +
 4 files changed, 72 insertions(+), 3 deletions(-)

diff --git a/Dockerfile b/Dockerfile
index c338bacec..84b045646 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -1,4 +1,7 @@
 # Copyright (c) 2024, EleutherAI
+# This file is based on code by the authors denoted below and has been modified from its original version.
+#
+# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -12,7 +15,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-FROM nvcr.io/nvidia/pytorch:24.02-py3
+FROM nvcr.io/nvidia/pytorch:24.06-py3
 
 ENV DEBIAN_FRONTEND=noninteractive
 
@@ -21,7 +24,7 @@ LABEL org.opencontainers.image.version = "2.0"
 LABEL org.opencontainers.image.authors = "contact@eleuther.ai"
 LABEL org.opencontainers.image.source = "https://www.github.com/eleutherai/gpt-neox"
 LABEL org.opencontainers.image.licenses = " Apache-2.0"
-LABEL org.opencontainers.image.base.name="nvcr.io/nvidia/pytorch:24.02-py3"
+LABEL org.opencontainers.image.base.name="nvcr.io/nvidia/pytorch:24.06-py3"
 
 #### System package (uses default Python 3 version in Ubuntu 20.04)
 RUN apt-get update -y && \
@@ -82,6 +85,12 @@ COPY megatron/fused_kernels/ /megatron/fused_kernels
 WORKDIR /megatron/fused_kernels
 RUN python setup.py install
 
+SHELL ["/bin/bash", "-c"]
+
+RUN DS_BUILD_FUSED_LAMB=1 DS_BUILD_FUSED_ADAM=1 DS_BUILD_TRANSFORMER=1 DS_BUILD_STOCHASTIC_TRANSFORMER=1 DS_BUILD_UTILS=1 \
+    TORCH_CUDA_ARCH_LIST="8.0 9.0+PTX" \
+    python -m pip install git+https://github.com/microsoft/DeepSpeed.git@v0.14.4
+
 # Clear staging
 RUN mkdir -p /tmp && chmod 0777 /tmp
 
diff --git a/megatron/model/gpt2_model.py b/megatron/model/gpt2_model.py
index 7899048db..085dce4f2 100644
--- a/megatron/model/gpt2_model.py
+++ b/megatron/model/gpt2_model.py
@@ -30,6 +30,7 @@
 from megatron import mpu
 from megatron.mpu import ParallelRelativePositionBias
 from megatron.model.transformer import (
+    ParallelTETransformerLayerPipe,
     ParallelTransformerLayerPipe,
     NormPipe,
     ParallelLinearPipe,
@@ -271,6 +272,24 @@ def init_specs(self):
                         layer_number=i,
                     )
                 )
+            elif layer_type in ["TE"]:
+                self.specs.append(
+                    LayerSpec(
+                        ParallelTETransformerLayerPipe,
+                        hidden_size=self.neox_args.hidden_size,
+                        ffn_hidden_size=self.neox_args.hidden_size * 4,
+                        num_attention_heads=self.neox_args.num_attention_heads,
+                        hidden_dropout=self.neox_args.hidden_dropout,
+                        attention_dropout=self.neox_args.attention_dropout,
+                        init_method=self.init_method,
+                        output_layer_init_method=self.output_layer_init_method,
+                        layer_number=i + 1,
+                        params_dtype=self.neox_args.params_dtype,
+                        attn_input_format="sbhd",
+                        seq_length=self.neox_args.seq_length,
+                        set_parallel_mode=True
+                    )
+                )
             else:
                 self.specs.append(
                     LayerSpec(
diff --git a/megatron/model/transformer.py b/megatron/model/transformer.py
index d112a7461..976db5767 100644
--- a/megatron/model/transformer.py
+++ b/megatron/model/transformer.py
@@ -38,6 +38,14 @@
     apply_rotary_pos_emb,
     AliBi,
 )
+
+from transformer_engine.pytorch import TransformerLayer
+from transformer_engine.pytorch.attention import RotaryPositionEmbedding
+from transformer_engine.pytorch import checkpoint
+
+import transformer_engine.pytorch as te
+from transformer_engine.common.recipe import Format, DelayedScaling
+
 from megatron.model.fused_rope import (
     FusedRoPEFunc,
     fused_apply_rotary_pos_emb_cached,
@@ -460,7 +468,7 @@ def __init__(
                 >= packaging.version.Version("2.4.0.post1")
             )
         )
-        self.sparse = self.attention_type not in ("global", "flash")
+        self.sparse = self.attention_type not in ("global", "flash", "TE")
 
         if self.gqa:
             assert not self.sparse
@@ -1301,6 +1309,38 @@ def forward(self, args):
         self.last_moe_loss = moe_loss
         return output, attention_mask
 
 
+class ParallelTETransformerLayerPipe(TransformerLayer):
+    """
+    Layer in the spirit of ParallelTransformerLayerPipe, but with the TE transformer layer.
+    """
+    def __init__(self, *args, **kwargs):
+
+        self.tp_group = mpu.get_tensor_model_parallel_group()
+
+        super().__init__(
+            *args,
+            **kwargs,
+            tp_group=self.tp_group
+        )
+
+        hidden_size_per_attention_head = mpu.divide(
+            kwargs["hidden_size"], kwargs["num_attention_heads"]
+        )
+        PE = RotaryPositionEmbedding(dim=hidden_size_per_attention_head)
+        self.rotary_pos_emb = PE(kwargs["seq_length"]).to(device="cuda")
+
+    def forward(self, args):
+        assert len(args) == 2, ("TE transformer layer pipe must have two args, hidden_states and the attention mask. "
+                                "The mask will be discarded")
+        hidden_states, attention_mask = args
+
+        fp8_format = Format.HYBRID  # E4M3 during forward pass, E5M2 during backward pass
+        fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=1024, amax_compute_algo="max")
+        with te.fp8_autocast(enabled=False, fp8_recipe=fp8_recipe, fp8_group=self.tp_group):
+            output = checkpoint(super().forward, hidden_states, distribute_saved_activations=True, tp_group=self.tp_group, rotary_pos_emb=self.rotary_pos_emb)
+
+        return output, attention_mask
+
 class ParallelLinearPipe(ParallelLinear):
     """Another helper class to pass presents through to the output when doing inference with a Pipe Parallel model"""
diff --git a/megatron/neox_arguments/neox_args.py b/megatron/neox_arguments/neox_args.py
index b5dad71f5..8d157c117 100644
--- a/megatron/neox_arguments/neox_args.py
+++ b/megatron/neox_arguments/neox_args.py
@@ -28,6 +28,7 @@ ATTENTION_TYPE_CHOICES = [
     "global",
+    "TE",
     "local",
     "sparse_fixed",
     "sparse_variable",
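
Usage note (not part of the patch): with this change a NeoX config can request the
Transformer Engine layer per layer through the new "TE" attention type, e.g. something
like "attention_config": [[["TE"], 12]] for a 12-layer model, following the existing
attention_config syntax (illustrative only). The new layer's forward pass wraps
TransformerLayer in an FP8 autocast context with a delayed-scaling recipe, although
FP8 itself is left disabled here (enabled=False). Below is a minimal, self-contained
sketch of that recipe pattern, using a plain te.Linear instead of the full
TransformerLayer so it runs without Megatron/DeepSpeed; sizes and shapes are
illustrative.

    # Sketch only: mirrors the fp8_autocast + DelayedScaling usage in
    # ParallelTETransformerLayerPipe.forward(), applied to a single te.Linear module.
    # Assumes Transformer Engine and a CUDA GPU are available; FP8 only engages on
    # FP8-capable hardware (e.g. Hopper) and when enabled=True.
    import torch
    import transformer_engine.pytorch as te
    from transformer_engine.common.recipe import Format, DelayedScaling

    fp8_recipe = DelayedScaling(
        fp8_format=Format.HYBRID,  # E4M3 during forward pass, E5M2 during backward pass
        amax_history_len=1024,
        amax_compute_algo="max",
    )

    layer = te.Linear(1024, 1024, bias=True).cuda()
    x = torch.randn(8, 1024, device="cuda", requires_grad=True)

    # enabled=False matches the patch (TE kernels run in bf16/fp16); set it to True
    # on FP8-capable hardware to activate the recipe above.
    with te.fp8_autocast(enabled=False, fp8_recipe=fp8_recipe):
        y = layer(x)
    y.sum().backward()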