Test different layer norm #270
Draft · wants to merge 27 commits into base: main
14 changes: 13 additions & 1 deletion megatron/model/fused_layer_norm.py
@@ -19,6 +19,7 @@
 
 import numbers
 import torch
+from megatron import mpu
 from torch.nn.parameter import Parameter
 from torch.nn import init
 import importlib
@@ -31,7 +32,6 @@ class FusedLayerNormAffineFunction(torch.autograd.Function):
 
     @staticmethod
     def forward(ctx, input, weight, bias, normalized_shape, eps):
-
         ctx.normalized_shape = normalized_shape
         ctx.eps = eps
         input_ = input.contiguous()
@@ -84,6 +84,18 @@ def reset_parameters(self):
 
 
     def forward(self, input):
+        weights = [torch.empty_like(self.weight) for tp in range(mpu.get_tensor_model_parallel_world_size())]
+        torch.distributed.all_gather(weights, self.weight, group=mpu.get_tensor_model_parallel_group())
+        biases = [torch.empty_like(self.bias) for tp in range(mpu.get_tensor_model_parallel_world_size())]
+        torch.distributed.all_gather(biases, self.bias, group=mpu.get_tensor_model_parallel_group())
+        if any(torch.any(weight != self.weight) for weight in weights):
+            if mpu.get_tensor_model_parallel_rank() == 0:
+                print("Weight sync failed")
+                print(weights)
+        if any(torch.any(bias != self.bias) for bias in biases):
+            if mpu.get_tensor_model_parallel_rank() == 0:
+                print("Bias sync failed")
+                print(biases)
 
         return FusedLayerNormAffineFunction.apply(
             input, self.weight, self.bias, self.normalized_shape,self.eps)
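Note: the forward above all-gathers every tensor-parallel rank's copy of the layer-norm weight and bias and prints a warning when any rank has diverged. For reference, a minimal standalone sketch of the same check factored into a reusable helper (the name check_param_synced and its return-value contract are illustrative assumptions, not part of this PR):

    import torch
    import torch.distributed as dist

    def check_param_synced(param: torch.Tensor, group=None, name: str = "param") -> bool:
        """Hypothetical helper: True iff `param` is bit-identical on every rank of `group`."""
        world_size = dist.get_world_size(group=group)
        gathered = [torch.empty_like(param) for _ in range(world_size)]
        dist.all_gather(gathered, param.detach().contiguous(), group=group)
        synced = all(torch.equal(t, param.detach()) for t in gathered)
        if not synced and dist.get_rank(group=group) == 0:
            print(f"{name} sync failed across {world_size} ranks")
        return synced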
2 changes: 2 additions & 0 deletions megatron/mpu/mappings.py
@@ -82,6 +82,7 @@ def symbolic(graph, input_):
 
     @staticmethod
     def forward(ctx, input_):
+        # TODO: we need to assert that the input_ are all the same within a group
         return input_
 
     @staticmethod
@@ -102,6 +103,7 @@ def forward(ctx, input_):
 
     @staticmethod
     def backward(ctx, grad_output):
+        # TODO: we need to assert that the grad_output are all the same within a group
         return grad_output
 
 
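One cheap way to implement both TODOs without gathering world-size copies of the tensor is to all-reduce it twice, once with MIN and once with MAX, and check that the two results agree. A sketch under that idea (the helper name _assert_identical_in_group is hypothetical, not an existing mpu function):

    import torch
    import torch.distributed as dist

    def _assert_identical_in_group(tensor: torch.Tensor, group=None) -> None:
        """Hypothetical helper: assert `tensor` is identical on all ranks of `group`."""
        # If the elementwise min and max across ranks agree, every copy is identical.
        tmin = tensor.detach().clone()
        tmax = tensor.detach().clone()
        dist.all_reduce(tmin, op=dist.ReduceOp.MIN, group=group)
        dist.all_reduce(tmax, op=dist.ReduceOp.MAX, group=group)
        assert torch.equal(tmin, tmax), "tensor differs across ranks in the group"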
13 changes: 13 additions & 0 deletions tests/ds_config_bf16.json
@@ -0,0 +1,13 @@
+{
+  "train_micro_batch_size_per_gpu": 1,
+  "train_batch_size": 16,
+  "gradient_clipping": 1.0,
+  "zero_optimization": {
+    "stage": 0
+  },
+  "bf16": {
+    "enabled": true
+  },
+  "steps_per_print": 2000,
+  "wall_clock_breakdown": false
+}
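DeepSpeed ties these fields together via train_batch_size = train_micro_batch_size_per_gpu × gradient_accumulation_steps × data-parallel world size. A sketch of that arithmetic for this file (dp_world_size = 1 assumes the test's 2-GPU launch with TP=2, PP=1; in practice the Megatron CLI flags below may take precedence over the JSON values):

    # DeepSpeed enforces: train_batch_size ==
    #   train_micro_batch_size_per_gpu * gradient_accumulation_steps * dp_world_size
    train_batch_size = 16  # "train_batch_size" above
    micro_per_gpu = 1      # "train_micro_batch_size_per_gpu" above
    dp_world_size = 1      # assumption: 2 GPUs, tp=2, pp=1
    grad_accum_steps = train_batch_size // (micro_per_gpu * dp_world_size)  # -> 16
    assert train_batch_size == micro_per_gpu * grad_accum_steps * dp_world_size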
1 change: 0 additions & 1 deletion tests/test_tensor_parallel.py
@@ -293,6 +293,5 @@ def test_tokenizer_raise_error_make_vocab_size_divisible_by(self):
 
         self.assertEqual(str(exc_info.value), "5121 is not divisible by 128")
 
-
 if __name__ == '__main__':
     unittest.main()
118 changes: 116 additions & 2 deletions tests/test_training.py
@@ -20,6 +20,8 @@
 import re
 import unittest
 from pathlib import Path
+
+import torch
 from parameterized import parameterized
 
 from megatron.testing_utils import (
@@ -31,7 +33,7 @@
     require_bnb_non_decorator,
     require_deepspeed,
     require_torch_gpu,
-    set_seed
+    set_seed, torch_assert_equal
 )
 
 set_seed(42)
@@ -50,7 +52,7 @@ def get_3d_dimensions():
         dp_size = 2
         pp_size = 2
         tp_size = 2
-    if num_gpus >= 4:
+    elif num_gpus >= 4:
         dp_size = 1
         pp_size = 2
         tp_size = 2
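Note: the elif is a real fix — with the plain if, a host matching the earlier branch (presumably num_gpus >= 8, which sets dp=2/pp=2/tp=2 above) would fall through into the num_gpus >= 4 branch and have its sizes silently overwritten with dp=1/pp=2/tp=2.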
@@ -592,3 +594,115 @@ def test_skip_train_iteration(self):
         train_iterations = range(1,10)
         for i in train_iterations:
             self.assertTrue(f"iteration {i:8d}/" in cs.out)
+
+    @parameterized.expand(["bf16", "fp16"])
+    def test_layer_norm_consistent(self, variation):
+        src_dir = self.src_dir
+        output_dir = self.get_auto_remove_tmp_dir()
+        num_gpus = 2
+        seq_len = 128
+        data_dir = f"{self.data_dir}/gpt2"
+        args = f"""
+            --tensor-model-parallel-size {2}
+            --pipeline-model-parallel-size {1}
+            --distributed-backend nccl
+
+            --log-interval 1
+            --save-interval 10
+            --eval-interval 10
+            --eval-iters 5
+            --checkpoint-activations
+            --partition-activations
+            --exit-interval {20}
+
+            --merge-file {data_dir}/gpt2-tiny-merges.txt
+            --vocab-file {data_dir}/gpt2-tiny-vocab.json
+            --save {output_dir}/checkpoints
+            --load {output_dir}/checkpoints
+            --data-path {data_dir}/meg-gpt2-openwebtext_text_document
+            --tensorboard-dir {output_dir}/tensorboard
+            --tensorboard-queue-size 5
+            --log-timers-to-tensorboard
+            --log-batch-size-to-tensorboard
+            --log-validation-ppl-to-tensorboard
+
+            --num-layers 2
+            --hidden-size 64
+            --num-attention-heads 2
+            --seq-length {seq_len}
+            --max-position-embeddings 1024
+            --micro-batch-size 2
+            --global-batch-size 16
+
+            --optimizer adam
+            --adam-beta1 0.9
+            --adam-beta2 0.95
+            --adam-eps 1e-8
+            --lr 1e-1
+            --clip-grad 1.0
+            --weight-decay 1e-1
+            --embed-layernorm
+
+            --log-level debug
+            --log-level-replica info
+
+            --rampup-batch-size 2 2 200
+            --train-samples 200
+
+            --position-embedding-type alibi
+        """.split()
+
+        ds_args = f"""
+            --deepspeed
+            --deepspeed-activation-checkpointing
+        """.split()
+
+        if variation == "bf16":
+            args.append("--bf16")
+            ds_args += [
+                "--zero-stage", "0",
+                "--deepspeed_config", f"{self.test_file_dir_str}/ds_config_bf16.json"
+            ]
+        elif variation == "fp16":
+            args.append("--fp16")
+            ds_args += [
+                "--zero-stage", "1",
+                "--deepspeed_config", f"{self.test_file_dir_str}/ds_config.json"
+            ]
+
+        # args, ds_args, num_gpus = self.get_variation_config("base", output_dir, n_samples=200)
+
+        script = [f"{src_dir}/pretrain_gpt.py"]
+        launcher = get_launcher(num_gpus)
+        cmd = launcher + script + args + ds_args
+        # keep for quick debug
+        # print(" ".join([f"\nPYTHONPATH={self.src_dir_str}"] +cmd)); die
+
+        with CaptureStdout() as cs:
+            execute_subprocess_async(cmd, env=self.get_env())
+
+        checkpoints = ["global_step10", "global_step20"]
+
+        # Check transformer layer norm
+        keys_to_compare = ["input_layernorm.weight", "input_layernorm.bias", "post_attention_layernorm.weight", "post_attention_layernorm.bias"]
+        files_to_compare = [[f"layer_{layer_id:02d}-model_{tp:02d}-model_states.pt" for tp in range(num_gpus)] for layer_id in [3,4]]
+        for checkpoint in checkpoints:
+            checkpoint_path = os.path.join(output_dir, "checkpoints", checkpoint)
+            for key in keys_to_compare:
+                for files in files_to_compare:
+                    weights = [torch.load(os.path.join(checkpoint_path, file))[key] for file in files]
+                    ref = weights[0]
+                    for weight in weights[1:]:
+                        torch_assert_equal(ref, weight, check_device=False)
+
+        # Check embed layer norm
+        keys_to_compare = ["word_embeddings.norm.weight"]
+        files_to_compare = [[f"layer_{layer_id:02d}-model_{tp:02d}-model_states.pt" for tp in range(num_gpus)] for layer_id in [1]]
+        for checkpoint in checkpoints:
+            checkpoint_path = os.path.join(output_dir, "checkpoints", checkpoint)
+            for key in keys_to_compare:
+                for files in files_to_compare:
+                    weights = [torch.load(os.path.join(checkpoint_path, file))[key] for file in files]
+                    ref = weights[0]
+                    for weight in weights[1:]:
+                        torch_assert_equal(ref, weight, check_device=False)
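For quick manual inspection outside the test harness, the same comparison can be done by hand. A sketch (the paths are illustrative and assume the checkpoint layout the test writes under its temporary output_dir):

    import torch

    # Compare one layer-norm weight across the two TP shards of global_step10.
    a = torch.load("checkpoints/global_step10/layer_03-model_00-model_states.pt", map_location="cpu")
    b = torch.load("checkpoints/global_step10/layer_03-model_01-model_states.pt", map_location="cpu")
    print(torch.equal(a["input_layernorm.weight"], b["input_layernorm.weight"]))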