Multi-node training with deepspeed launcher #2605

Open · 5 tasks done
ghtaro opened this issue Jan 22, 2025 · 0 comments

Labels: 🚀 deepspeed (Related to deepspeed), 🏋 SFT (Related to SFT)

ghtaro commented Jan 22, 2025

Reproduction

Here is the deepspeed launcher command for 2 nodes of p5.48xlarge (I am using Slurm, but I do not think it matters).

deepspeed --hostfile=config/hostfile --master_addr=$MASTER_ADDR --master_port=$MASTER_PORT train.py --deepspeed_config config/ds_config_zero3_cpuoffload.json \
--model_type <llama 3.3 70b instruct> \
--batch_size $BATCH_SIZE \
--seq_length $SEQ_LENGTH \
--learning_rate $LR \
--gradient_accumulation_steps $GAS \
--num_train_epochs $EPOCH \
--fa flash_attention_2
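
The config/hostfile referenced above just lists the two nodes and their GPU slots. A minimal sketch (the hostnames are placeholders, not my real ones):

node-1 slots=8
node-2 slots=8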

Here is a simplified train.py file.

from datasets import load_from_disk
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments
from trl import SFTTrainer

script_args = get_args(CPTArguments)  # project-specific argument parsing

dataset = load_from_disk(<train_dataset_fullpath>)
dataset = dataset.train_test_split(test_size=0.1)

# No quantization or explicit device placement; DeepSpeed ZeRO-3 handles sharding.
device_map = None
quantization_config = None
torch_dtype = None

training_args = TrainingArguments(
    output_dir=output_dir,
    per_device_train_batch_size=script_args.batch_size,
    per_device_eval_batch_size=script_args.eval_batch_size,
    gradient_accumulation_steps=script_args.gradient_accumulation_steps,
    gradient_checkpointing=True,
    learning_rate=script_args.learning_rate,
    logging_steps=script_args.logging_steps,
    num_train_epochs=script_args.num_train_epochs,
    bf16=True,
    eval_strategy="no",
    save_strategy="epoch",
    warmup_ratio=0.05,
    logging_first_step=True,
    deepspeed=script_args.deepspeed_config,
    save_only_model=True,
)

model = AutoModelForCausalLM.from_pretrained(
    <llama 3.3 70b instruct>,  # pretrained_model_name_or_path
    quantization_config=quantization_config,
    device_map=device_map,
    trust_remote_code=script_args.trust_remote_code,
    torch_dtype=torch_dtype,
    token=<TOKEN>,
    attn_implementation=script_args.fa,
)

tokenizer = AutoTokenizer.from_pretrained(model_config['model_name'], token=<TOKEN>)
tokenizer.pad_token = model_config['pad_token'] # "<|finetune_right_pad_id|>"
tokenizer.padding_side = "right"

def formatting_func(example):
    text = tokenizer.bos_token + example["text"] + tokenizer.eos_token
    return text

trainer = SFTTrainer(
    model=model,
    args=training_args,
    max_seq_length=script_args.seq_length,
    train_dataset=dataset["train"],
    formatting_func=formatting_func,
    peft_config=None,
    packing=True,
)

print("Start Training")
trainer.train()
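
For anyone reproducing: ds_config_zero3_cpuoffload.json follows the usual Hugging Face ZeRO-3 + CPU-offload pattern. The sketch below is that standard shape with "auto" placeholders, not necessarily my exact file:

{
  "bf16": { "enabled": "auto" },
  "zero_optimization": {
    "stage": 3,
    "offload_optimizer": { "device": "cpu", "pin_memory": true },
    "offload_param": { "device": "cpu", "pin_memory": true },
    "overlap_comm": true,
    "stage3_gather_16bit_weights_on_model_save": true
  },
  "gradient_accumulation_steps": "auto",
  "gradient_clipping": "auto",
  "train_batch_size": "auto",
  "train_micro_batch_size_per_gpu": "auto"
}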

output:

...
10.0.29.212: {'loss': 0.5341, 'grad_norm': 0.8705843687057495, 'learning_rate': 3.1578947368421055e-07, 'epoch': 0.98}
10.0.29.212: {'loss': 0.5119, 'grad_norm': 5.807155132293701, 'learning_rate': 2.105263157894737e-07, 'epoch': 0.99}
10.0.28.174: [2025-01-22 07:48:04,012] [INFO] [launch.py:319:sigkill_handler] Killing subprocess 880348
10.0.28.174: [2025-01-22 07:48:04,055] [INFO] [launch.py:319:sigkill_handler] Killing subprocess 880349
...

errors:

This timeout happens at the last step of training (I confirmed that the final checkpoint was not created). When I sampled only 1K records (out of 70K) and trained on them, the error disappeared and training finished without any problem, which is very strange to me. (See the note after the traceback about the 30-minute collective timeout.)

 0%|          | 0/101 [00:00<?, ?it/s]`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.
 99%|█████████▉| 100/101 [3:04:16<01:50, 110.64s/it][rank2]:[E122 07:47:07.196390673 ProcessGroupNCCL.cpp:616] [Rank 2] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=100368, OpType=_ALLGATHER_BASE, NumelIn=65667072, NumelOut=1050673152, Timeout(ms)=1800000) ran for 1800003 milliseconds before timing out.
10.0.29.212: [rank2]:[E122 07:47:07.196986688 ProcessGroupNCCL.cpp:1785] [PG ID 0 PG GUID 0(default_pg) Rank 2] Exception (either an error or timeout) detected by watchdog at work: 100368, last enqueued NCCL work: 100370, last completed NCCL work: 100367.
10.0.29.212: [rank2]:[E122 07:47:07.197021650 ProcessGroupNCCL.cpp:1834] [PG ID 0 PG GUID 0(default_pg) Rank 2] Timeout at NCCL work: 100368, last enqueued NCCL work: 100370, last completed NCCL work: 100367.
10.0.29.212: [rank2]:[E122 07:47:07.197029051 ProcessGroupNCCL.cpp:630] [Rank 2] Some NCCL operations have failed or timed out. Due to the asynchronous nature of CUDA kernels, subsequent GPU operations might run on corrupted/incomplete data.
10.0.29.212: [rank2]:[E122 07:47:07.197032661 ProcessGroupNCCL.cpp:636] [Rank 2] To avoid data inconsistency, we are taking the entire process down.
10.0.29.212: [rank3]:[E122 07:47:07.197858759 ProcessGroupNCCL.cpp:616] [Rank 3] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=100368, OpType=_ALLGATHER_BASE, NumelIn=65667072, NumelOut=1050673152, Timeout(ms)=1800000) ran for 1800003 milliseconds before timing out.
10.0.29.212: [rank2]:[E122 07:47:07.198108174 ProcessGroupNCCL.cpp:1595] [PG ID 0 PG GUID 0(default_pg) Rank 2] Process group watchdog thread terminated with exception: [Rank 2] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=100368, OpType=_ALLGATHER_BASE, NumelIn=65667072, NumelOut=1050673152, Timeout(ms)=1800000) ran for 1800003 milliseconds before timing out.
10.0.29.212: Exception raised from checkTimeout at ../torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp:618 (most recent call first):
10.0.29.212: frame #0: c10::Error::Error(c10::SourceLocation, std::string) + 0x96 (0x7ccd838b9446 in /home/ubuntu/.pyenv/versions/anaconda3-2024.10-1/lib/python3.12/site-packages/torch/lib/libc10.so)
10.0.29.212: frame #1: c10d::ProcessGroupNCCL::WorkNCCL::checkTimeout(std::optional<std::chrono::duration<long, std::ratio<1l, 1000l> > >) + 0x282 (0x7ccd38bcc772 in /home/ubuntu/.pyenv/versions/anaconda3-2024.10-1/lib/python3.12/site-packages/torch/lib/libtorch_cuda.so)
10.0.29.212: frame #2: c10d::ProcessGroupNCCL::watchdogHandler() + 0x233 (0x7ccd38bd3bb3 in /home/ubuntu/.pyenv/versions/anaconda3-2024.10-1/lib/python3.12/site-packages/torch/lib/libtorch_cuda.so)
10.0.29.212: frame #3: c10d::ProcessGroupNCCL::ncclCommWatchdog() + 0x14d (0x7ccd38bd561d in /home/ubuntu/.pyenv/versions/anaconda3-2024.10-1/lib/python3.12/site-packages/torch/lib/libtorch_cuda.so)
10.0.29.212: frame #4: <unknown function> + 0x145c0 (0x7ccd842f95c0 in /home/ubuntu/.pyenv/versions/anaconda3-2024.10-1/lib/python3.12/site-packages/torch/lib/libtorch.so)
10.0.29.212: frame #5: <unknown function> + 0x94ac3 (0x7ccda4694ac3 in /lib/x86_64-linux-gnu/libc.so.6)
10.0.29.212: frame #6: <unknown function> + 0x126850 (0x7ccda4726850 in /lib/x86_64-linux-gnu/libc.so.6)
10.0.29.212:
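
The Timeout(ms)=1800000 in the traceback is the default 30-minute collective timeout. For reference, a minimal sketch of raising it via TrainingArguments.ddp_timeout (only a guess at a workaround, not a confirmed fix; output_dir here is a placeholder):

from transformers import TrainingArguments

# ddp_timeout is in seconds; the default 1800 s is exactly the 1800000 ms in the watchdog message.
args_with_longer_timeout = TrainingArguments(
    output_dir="output",  # placeholder
    ddp_timeout=7200,     # give the final all-gather / model-save step more headroom
)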

I also found a weird part in the same log file, shown below. I pass my own dataset and use packing=True, so I expected a single "Generating train split: <NUM_PACKED_SAMPLES> examples" message, with the packed dataset shared across all ranks (8 GPUs × 2 nodes = 16 ranks) in the DeepSpeed ZeRO-3 training. Instead, the packing seems to run once per rank (see the sketch after the log excerpt for what I expected to need at most).

Generating train split: 9617 examples [00:12, 775.12 examples/s] 
Generating train split: 9629 examples [00:12, 769.35 examples/s] 
Generating train split: 9626 examples [00:12, 765.56 examples/s] 
Generating train split: 9618 examples [00:12, 760.59 examples/s] 
Generating train split: 9610 examples [00:12, 759.24 examples/s] 
Generating train split: 9628 examples [00:12, 753.58 examples/s] 
Generating train split: 9615 examples [00:12, 747.94 examples/s] 
Generating train split: 9614 examples [00:13, 736.11 examples/s] 
Generating train split: 9614 examples [00:13, 734.01 examples/s] 
Generating train split: 9605 examples [00:13, 730.34 examples/s] 
Generating train split: 9590 examples [00:13, 717.06 examples/s] 
Generating train split: 9628 examples [00:13, 716.29 examples/s] 
Generating train split: 9624 examples [00:13, 706.29 examples/s] 
Generating train split: 1071 examples [00:01, 699.63 examples/s]
... (16 of these messages in total, one per rank)
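
For reference, a sketch of what I expected to need at most: gating the trainer construction (where SFTTrainer does the packing) on the main process so the other ranks can reuse the cached result. It uses the existing main_process_first context manager on TrainingArguments; I have not verified that it actually avoids the per-rank packing, it just illustrates my expectation:

# Sketch only: rank 0 builds the packed dataset first, the remaining 15 ranks follow
# and can, in principle, reuse the datasets cache instead of re-packing.
with training_args.main_process_first(desc="dataset packing"):
    trainer = SFTTrainer(
        model=model,
        args=training_args,
        max_seq_length=script_args.seq_length,
        train_dataset=dataset["train"],
        formatting_func=formatting_func,
        peft_config=None,
        packing=True,
    )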

System Info

  • Platform: Ubuntu 22.04
  • Python version: 3.12.7
  • PyTorch version: 2.5.1+cu124
  • CUDA device: NVIDIA H100 80GB HBM3
  • Transformers version: 4.46.3
  • Accelerate version: 1.0.1
  • Datasets version: 3.1.0
  • HF Hub version: 0.26.3
  • TRL version: 0.11.4
  • DeepSpeed version: 0.15.3
  • PEFT version: 0.13.2

Checklist

  • I have checked that my issue isn't already filed (see open issues)
  • I have included my system information
  • Any code provided is minimal, complete, and reproducible (more on MREs)
  • Any code provided is properly formatted in code blocks, (no screenshot, more on code blocks)
  • Any traceback provided is complete
github-actions bot added the 🚀 deepspeed and 🏋 SFT labels on Jan 22, 2025