TE integration via full TransformerLayer #1297

Status: Open. Wants to merge 1 commit into base: main.
13 changes: 11 additions & 2 deletions Dockerfile
@@ -1,4 +1,7 @@
# Copyright (c) 2024, EleutherAI
# This file is based on code by the authors denoted below and has been modified from its original version.
#
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,7 +15,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

FROM nvcr.io/nvidia/pytorch:24.02-py3
FROM nvcr.io/nvidia/pytorch:24.06-py3

ENV DEBIAN_FRONTEND=noninteractive

@@ -21,7 +24,7 @@ LABEL org.opencontainers.image.version = "2.0"
LABEL org.opencontainers.image.authors = "[email protected]"
LABEL org.opencontainers.image.source = "https://www.github.com/eleutherai/gpt-neox"
LABEL org.opencontainers.image.licenses = " Apache-2.0"
LABEL org.opencontainers.image.base.name="nvcr.io/nvidia/pytorch:24.02-py3"
LABEL org.opencontainers.image.base.name="nvcr.io/nvidia/pytorch:24.06-py3"

#### System package (uses default Python 3 version in Ubuntu 20.04)
RUN apt-get update -y && \
@@ -82,6 +85,12 @@ COPY megatron/fused_kernels/ /megatron/fused_kernels
WORKDIR /megatron/fused_kernels
RUN python setup.py install

SHELL ["/bin/bash", "-c"]

RUN DS_BUILD_FUSED_LAMB=1 DS_BUILD_FUSED_ADAM=1 DS_BUILD_TRANSFORMER=1 DS_BUILD_STOCHASTIC_TRANSFORMER=1 DS_BUILD_UTILS=1 \
TORCH_CUDA_ARCH_LIST="8.0 9.0+PTX" \
python -m pip install git+https://github.com/microsoft/[email protected]
Reviewer comment (Member): This should probably use the latest DeeperSpeed instead, and the arch list can't be hardcoded (a Python sketch follows this file's diff).


# Clear staging
RUN mkdir -p /tmp && chmod 0777 /tmp

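To illustrate the reviewer's point about not hardcoding the arch list, here is a minimal Python sketch that derives TORCH_CUDA_ARCH_LIST from whatever GPUs are visible, falling back to a broad default when none are (as during docker build). The helper name and fallback value are illustrative only and are not part of this PR.

# Sketch only: derive TORCH_CUDA_ARCH_LIST from visible GPUs instead of hardcoding it.
import torch


def detect_cuda_arch_list(default: str = "8.0 9.0+PTX") -> str:
    """Return a space-separated arch list for the GPUs visible to this process."""
    if not torch.cuda.is_available():
        # No GPUs visible (e.g. inside `docker build`); fall back to a broad default.
        return default
    archs = {
        "{}.{}".format(*torch.cuda.get_device_capability(i))
        for i in range(torch.cuda.device_count())
    }
    return " ".join(sorted(archs)) if archs else default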
19 changes: 19 additions & 0 deletions megatron/model/gpt2_model.py
@@ -30,6 +30,7 @@
from megatron import mpu
from megatron.mpu import ParallelRelativePositionBias
from megatron.model.transformer import (
ParallelTETransformerLayerPipe,
ParallelTransformerLayerPipe,
NormPipe,
ParallelLinearPipe,
@@ -271,6 +272,24 @@ def init_specs(self):
layer_number=i,
)
)
elif layer_type in ["TE"]:
Reviewer comment (Member): This needs to be tested with PP and TP, since we'd be relying on two external codebases (DeepSpeed for PP, TE for TP) whose topologies probably don't play nicely together.

self.specs.append(
LayerSpec(
ParallelTETransformerLayerPipe,
hidden_size=self.neox_args.hidden_size,
ffn_hidden_size=self.neox_args.hidden_size * 4,
num_attention_heads=self.neox_args.num_attention_heads,
hidden_dropout=self.neox_args.hidden_dropout,
attention_dropout=self.neox_args.attention_dropout,
init_method=self.init_method,
output_layer_init_method=self.output_layer_init_method,
layer_number=i + 1,
params_dtype=self.neox_args.params_dtype,
attn_input_format="sbhd",
seq_length=self.neox_args.seq_length,
set_parallel_mode=True
)
)
else:
self.specs.append(
LayerSpec(
42 changes: 41 additions & 1 deletion megatron/model/transformer.py
@@ -38,6 +38,14 @@
apply_rotary_pos_emb,
AliBi,
)

from transformer_engine.pytorch import TransformerLayer
from transformer_engine.pytorch.attention import RotaryPositionEmbedding
from transformer_engine.pytorch import checkpoint

import transformer_engine.pytorch as te
from transformer_engine.common.recipe import Format, DelayedScaling

from megatron.model.fused_rope import (
FusedRoPEFunc,
fused_apply_rotary_pos_emb_cached,
@@ -460,7 +468,7 @@ def __init__(
>= packaging.version.Version("2.4.0.post1")
)
)
self.sparse = self.attention_type not in ("global", "flash")
self.sparse = self.attention_type not in ("global", "flash", "TE")

if self.gqa:
assert not self.sparse
@@ -1301,6 +1309,38 @@ def forward(self, args):
self.last_moe_loss = moe_loss
return output, attention_mask

class ParallelTETransformerLayerPipe(TransformerLayer):
"""
Layer in the spirit of ParallelTransformerLayerPipe, but with the TE transformer layer.
"""
def __init__(self, *args, **kwargs):

self.tp_group = mpu.get_tensor_model_parallel_group()

super().__init__(
*args,
**kwargs,
tp_group=self.tp_group
)

hidden_size_per_attention_head = mpu.divide(
kwargs["hidden_size"], kwargs["num_attention_heads"]
)
PE = RotaryPositionEmbedding(dim=hidden_size_per_attention_head)
self.rotary_pos_emb = PE(kwargs["seq_length"]).to(device="cuda")

def forward(self, args):
assert len(args) == 2, ("TE transformer layer pipe must have two args, hidden_states and the attention mask. "
"The mask will be discarded")
hidden_states, attention_mask = args

fp8_format = Format.HYBRID # E4M3 during forward pass, E5M2 during backward pass
Reviewer comment (Member): This should instead take the FP8 format from the new te_fp8_format neox_arg rather than hardcoding it (a sketch follows this file's diff).

fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=1024, amax_compute_algo="max")
with te.fp8_autocast(enabled=False, fp8_recipe=fp8_recipe, fp8_group=self.tp_group):
output = checkpoint(super().forward, hidden_states, distribute_saved_activations=True, tp_group=self.tp_group, rotary_pos_emb=self.rotary_pos_emb)

return output, attention_mask


class ParallelLinearPipe(ParallelLinear):
"""Another helper class to pass presents through to the output when doing inference with a Pipe Parallel model"""
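Following the reviewer's suggestion, the FP8 format could be read from a neox_args field rather than hardcoded in forward(). A minimal sketch, assuming a te_fp8_format string with values like "hybrid" or "e4m3"; the exact argument name and accepted values are not defined anywhere in this PR beyond the comment above.

# Sketch only: build the DelayedScaling recipe from a (hypothetical) te_fp8_format setting.
from transformer_engine.common.recipe import DelayedScaling, Format


def build_fp8_recipe(te_fp8_format: str = "hybrid") -> DelayedScaling:
    """Map a config string onto a TransformerEngine FP8 recipe."""
    formats = {
        "hybrid": Format.HYBRID,  # E4M3 in the forward pass, E5M2 in the backward pass
        "e4m3": Format.E4M3,      # E4M3 in both passes
    }
    try:
        fp8_format = formats[te_fp8_format.lower()]
    except KeyError:
        raise ValueError(f"Unsupported te_fp8_format: {te_fp8_format!r}")
    return DelayedScaling(
        fp8_format=fp8_format, amax_history_len=1024, amax_compute_algo="max"
    )

ParallelTETransformerLayerPipe.forward could then build this recipe once from neox_args and pass it to te.fp8_autocast, with enabled driven by a corresponding flag instead of the hardcoded False.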
1 change: 1 addition & 0 deletions megatron/neox_arguments/neox_args.py
@@ -28,6 +28,7 @@

ATTENTION_TYPE_CHOICES = [
"global",
"TE",
"local",
"sparse_fixed",
"sparse_variable",