diff --git a/megatron/model/fused_layer_norm.py b/megatron/model/fused_layer_norm.py
index 78645c236..8430f528c 100644
--- a/megatron/model/fused_layer_norm.py
+++ b/megatron/model/fused_layer_norm.py
@@ -19,6 +19,7 @@
 
 import numbers
 import torch
+from megatron import mpu
 from torch.nn.parameter import Parameter
 from torch.nn import init
 import importlib
@@ -31,7 +32,6 @@ class FusedLayerNormAffineFunction(torch.autograd.Function):
 
   @staticmethod
   def forward(ctx, input, weight, bias, normalized_shape, eps):
 
-    ctx.normalized_shape = normalized_shape
     ctx.eps = eps
     input_ = input.contiguous()
@@ -84,6 +84,18 @@ def reset_parameters(self):
 
 
   def forward(self, input):
+    weights = [torch.empty_like(self.weight) for tp in range(mpu.get_tensor_model_parallel_world_size())]
+    torch.distributed.all_gather(weights, self.weight, group=mpu.get_tensor_model_parallel_group())
+    biases = [torch.empty_like(self.bias) for tp in range(mpu.get_tensor_model_parallel_world_size())]
+    torch.distributed.all_gather(biases, self.bias, group=mpu.get_tensor_model_parallel_group())
+    if any(torch.any(weight != self.weight) for weight in weights):
+      if mpu.get_tensor_model_parallel_rank() == 0:
+        print("Weight sync failed")
+        print(weights)
+    if any(torch.any(bias != self.bias) for bias in biases):
+      if mpu.get_tensor_model_parallel_rank() == 0:
+        print("Bias sync failed")
+        print(biases)
 
     return FusedLayerNormAffineFunction.apply(
       input, self.weight, self.bias, self.normalized_shape,self.eps)
diff --git a/megatron/mpu/mappings.py b/megatron/mpu/mappings.py
index 821d9acfe..6056f94f6 100644
--- a/megatron/mpu/mappings.py
+++ b/megatron/mpu/mappings.py
@@ -82,6 +82,7 @@ def symbolic(graph, input_):
 
     @staticmethod
     def forward(ctx, input_):
+        # TODO: we need to assert that the input_ are all the same within a group
        return input_
 
     @staticmethod
@@ -102,6 +103,7 @@ def forward(ctx, input_):
 
     @staticmethod
     def backward(ctx, grad_output):
+        # TODO: we need to assert that the grad_output are all the same within a group
        return grad_output
 
 
diff --git a/tests/ds_config_bf16.json b/tests/ds_config_bf16.json
new file mode 100644
index 000000000..1f02566c9
--- /dev/null
+++ b/tests/ds_config_bf16.json
@@ -0,0 +1,13 @@
+{
+  "train_micro_batch_size_per_gpu": 1,
+  "train_batch_size": 16,
+  "gradient_clipping": 1.0,
+  "zero_optimization": {
+    "stage": 0
+  },
+  "bf16": {
+    "enabled": true
+  },
+  "steps_per_print": 2000,
+  "wall_clock_breakdown": false
+}
\ No newline at end of file
diff --git a/tests/test_tensor_parallel.py b/tests/test_tensor_parallel.py
index 25921c12a..ed383e17a 100644
--- a/tests/test_tensor_parallel.py
+++ b/tests/test_tensor_parallel.py
@@ -293,6 +293,5 @@ def test_tokenizer_raise_error_make_vocab_size_divisible_by(self):
         self.assertEqual(str(exc_info.value), "5121 is not divisible by 128")
 
 
-
 if __name__ == '__main__':
     unittest.main()
diff --git a/tests/test_training.py b/tests/test_training.py
index c77cb9af2..65067982e 100644
--- a/tests/test_training.py
+++ b/tests/test_training.py
@@ -20,6 +20,8 @@
 import re
 import unittest
 from pathlib import Path
+
+import torch
 from parameterized import parameterized
 
 from megatron.testing_utils import (
@@ -31,7 +33,7 @@
     require_bnb_non_decorator,
     require_deepspeed,
     require_torch_gpu,
-    set_seed
+    set_seed, torch_assert_equal
 )
 
 set_seed(42)
@@ -50,7 +52,7 @@ def get_3d_dimensions():
         dp_size = 2
         pp_size = 2
         tp_size = 2
-    if num_gpus >= 4:
+    elif num_gpus >= 4:
         dp_size = 1
         pp_size = 2
         tp_size = 2
@@ -592,3 +594,115 @@ def test_skip_train_iteration(self):
         train_iterations = range(1,10)
         for i in train_iterations:
             self.assertTrue(f"iteration {i:8d}/" in cs.out)
+
+    @parameterized.expand(["bf16", "fp16"])
+    def test_layer_norm_consistent(self, variation):
+        src_dir = self.src_dir
+        output_dir = self.get_auto_remove_tmp_dir()
+        num_gpus = 2
+        seq_len = 128
+        data_dir = f"{self.data_dir}/gpt2"
+        args = f"""
+            --tensor-model-parallel-size {2}
+            --pipeline-model-parallel-size {1}
+            --distributed-backend nccl
+
+            --log-interval 1
+            --save-interval 10
+            --eval-interval 10
+            --eval-iters 5
+            --checkpoint-activations
+            --partition-activations
+            --exit-interval {20}
+
+            --merge-file {data_dir}/gpt2-tiny-merges.txt
+            --vocab-file {data_dir}/gpt2-tiny-vocab.json
+            --save {output_dir}/checkpoints
+            --load {output_dir}/checkpoints
+            --data-path {data_dir}/meg-gpt2-openwebtext_text_document
+            --tensorboard-dir {output_dir}/tensorboard
+            --tensorboard-queue-size 5
+            --log-timers-to-tensorboard
+            --log-batch-size-to-tensorboard
+            --log-validation-ppl-to-tensorboard
+
+            --num-layers 2
+            --hidden-size 64
+            --num-attention-heads 2
+            --seq-length {seq_len}
+            --max-position-embeddings 1024
+            --micro-batch-size 2
+            --global-batch-size 16
+
+            --optimizer adam
+            --adam-beta1 0.9
+            --adam-beta2 0.95
+            --adam-eps 1e-8
+            --lr 1e-1
+            --clip-grad 1.0
+            --weight-decay 1e-1
+            --embed-layernorm
+
+            --log-level debug
+            --log-level-replica info
+
+            --rampup-batch-size 2 2 200
+            --train-samples 200
+
+            --position-embedding-type alibi
+        """.split()
+
+        ds_args = f"""
+            --deepspeed
+            --deepspeed-activation-checkpointing
+        """.split()
+
+        if variation == "bf16":
+            args.append("--bf16")
+            ds_args += [
+                "--zero-stage", "0",
+                "--deepspeed_config", f"{self.test_file_dir_str}/ds_config_bf16.json"
+            ]
+        elif variation == "fp16":
+            args.append("--fp16")
+            ds_args += [
+                "--zero-stage", "1",
+                "--deepspeed_config", f"{self.test_file_dir_str}/ds_config.json"
+            ]
+
+        # args, ds_args, num_gpus = self.get_variation_config("base", output_dir, n_samples=200)
+
+        script = [f"{src_dir}/pretrain_gpt.py"]
+        launcher = get_launcher(num_gpus)
+        cmd = launcher + script + args + ds_args
+        # keep for quick debug
+        # print(" ".join([f"\nPYTHONPATH={self.src_dir_str}"] +cmd)); die
+
+        with CaptureStdout() as cs:
+            execute_subprocess_async(cmd, env=self.get_env())
+
+        checkpoints = ["global_step10", "global_step20"]
+
+        # Check transformer layer norm
+        keys_to_compare = ["input_layernorm.weight", "input_layernorm.bias", "post_attention_layernorm.weight", "post_attention_layernorm.bias"]
+        files_to_compare = [[f"layer_{layer_id:02d}-model_{tp:02d}-model_states.pt" for tp in range(num_gpus)] for layer_id in [3,4]]
+        for checkpoint in checkpoints:
+            checkpoint_path = os.path.join(output_dir, "checkpoints", checkpoint)
+            for key in keys_to_compare:
+                for files in files_to_compare:
+                    weights = [torch.load(os.path.join(checkpoint_path, file))[key] for file in files]
+                    ref = weights[0]
+                    for weight in weights[1:]:
+                        torch_assert_equal(ref, weight, check_device=False)
+
+        # Check embed layer norm
+        keys_to_compare = ["word_embeddings.norm.weight"]
+        files_to_compare = [[f"layer_{layer_id:02d}-model_{tp:02d}-model_states.pt" for tp in range(num_gpus)] for layer_id in [1]]
+        for checkpoint in checkpoints:
+            checkpoint_path = os.path.join(output_dir, "checkpoints", checkpoint)
+            for key in keys_to_compare:
+                for files in files_to_compare:
+                    weights = [torch.load(os.path.join(checkpoint_path, file))[key] for file in files]
+                    ref = weights[0]
+                    for weight in weights[1:]:
+                        torch_assert_equal(ref, weight, check_device=False)
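
Note, not part of the patch above: the two TODO comments added in megatron/mpu/mappings.py ask for an assertion that a tensor is identical across all ranks of a tensor-parallel group. A minimal sketch of such a check, following the same all_gather-and-compare pattern the patch adds to MixedFusedLayerNorm.forward, could look like the code below; the helper name assert_synced_across_tp_group is hypothetical and only the mpu accessors already used in this diff are assumed.

# Sketch only -- not part of the diff. Reuses the megatron.mpu helpers that
# this patch already imports in fused_layer_norm.py.
import torch
from megatron import mpu

def assert_synced_across_tp_group(tensor, name="tensor"):
    """Raise if `tensor` differs across the tensor-parallel group."""
    world_size = mpu.get_tensor_model_parallel_world_size()
    if world_size == 1:
        return
    # Gather a copy of `tensor` from every rank in the TP group.
    gathered = [torch.empty_like(tensor) for _ in range(world_size)]
    torch.distributed.all_gather(
        gathered, tensor, group=mpu.get_tensor_model_parallel_group())
    # Compare the local copy against every gathered copy.
    for rank, other in enumerate(gathered):
        if not torch.equal(other, tensor):
            raise RuntimeError(
                f"{name} differs between TP rank "
                f"{mpu.get_tensor_model_parallel_rank()} and rank {rank}")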