Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Autotp training #6922

Open
wants to merge 51 commits into
base: master
Choose a base branch
from
Open

Autotp training #6922

wants to merge 51 commits into from

Conversation

inkcherry
Copy link
Contributor

@inkcherry inkcherry commented Jan 2, 2025

FYI @tjruwase @GuanhuaWang @delock @skyshine102 context: #5445
changes/support

  • auto tensor parallel training for HF model(zero compatible. I only tested zero1 currently)
  • distributed ckpt save(UCP is not supported).
  • HF model files save(set gather_16bit_weights_on_model_save=True in ds config).
  • Dataloader check.
  • Uts.
  • tp layer refactor by abstract layer design.

HF trainer dependency:
transformer: https://github.com/inkcherry/transformers/tree/ds_tp
accelerate: https://github.com/inkcherry/accelerate/tree/ds_tp
I could send them once ds support these api.

Usage:
Users do not need to modify the client code, they only need to configure the settings in the config file to achieve the desired functionality.
Below is an example of code for fine-tuning a LLaMA 2 model (SFT). It supports Zero3/FSDP training and enables TP training by simply adjusting the configuration

https://github.com/inkcherry/stanford_alpaca/commits/tp_demo_1127/
This branch contains three commits, with the last two commits added for quick experiments and logging purposes.
results
loss curve(gbs=16):
zero3(baseline)
image
tp(this)
image

zero1 with zero1+tp(zero compatible)
image

performance(For your reference only.):
zero3(not enabled any acceleration.) : 18GB 2.3s/it
zero1:38GB 1.30s/it
zero1+tp: 24GB 1.66s/it
extension:
I think async-TP/domino .etc. can be implemented by inheriting a class and overriding the fwd/bwd methods. The logic for gather/partition can be reused to achieve this.(please correct me if I am wrong)

Complex sharding can also be achieved through independent partitioning and gathering. Partitioning is mandatory, while gathering is required for training.
TODO:
embedding vocab parallel
Currently, the parallelism for embeddings is primarily based on hidden_dim parallel combined with allreduce. This approach takes advantage of efficient reduction kernels. and it is not forced to use.
In training, however, the more common method is vocab parallelism. Enabling by default can save a certain amount of GPU memory.

thanks for @delock guidance.
I also verified inference with cpu-inference workloads(Optimized Model List in https://github.com/intel/intel-extension-for-pytorch/tree/main).
many thanks for @xuguangxin @ikurtchen @rogerxfeng8 ,@Yejing-Lai ,@ys950902 .etc. Help review and address matters related to inference.

@delock
Copy link
Collaborator

delock commented Jan 2, 2025

@tjruwase @GuanhuaWang
We had internal review of @inkcherry 's PR. This PR allows train HF models with tensor parallel without need for megatron. Which is very friendly to user.

Let us know your plan for Domino integration. @inkcherry 's memory data looks good. With Domino we think it can have less impact on performance since TP communication can overlap with computation.

@inkcherry by design should autotp training work with ZeRO3 as well?

@inkcherry
Copy link
Contributor Author

@tjruwase @GuanhuaWang We had internal review of @inkcherry 's PR. This PR allows train HF models with tensor parallel without need for megatron. Which is very friendly to user.

Let us know your plan for Domino integration. @inkcherry 's memory data looks good. With Domino we think it can have less impact on performance since TP communication can overlap with computation.

@inkcherry by design should autotp training work with ZeRO3 as well?

for Zero3 + TP: Currently, the logic to combine the saving of HF weights for TP & DP has not been implemented, but it is entirely feasible. If needed, it can be implemented in the future.

@tjruwase
Copy link
Contributor

tjruwase commented Jan 6, 2025

We had internal review of @inkcherry 's PR. This PR allows train HF models with tensor parallel without need for megatron. Which is very friendly to user.

Bravo @inkcherry, this is an excellent technology and massive usability benefit for users. This is really exciting!

Let us know your plan for Domino integration. @inkcherry 's memory data looks good. With Domino we think it can have less impact on performance since TP communication can overlap with computation.

In terms of Domino integration, @GuanhuaWang will take the lead on that.

for Zero3 + TP: Currently, the logic to combine the saving of HF weights for TP & DP has not been implemented, but it is entirely feasible. If needed, it can be implemented in the future.

I would love to prioritize enabling UCP support sooner than later. @inkcherry, can you please share the work needed here?


def test(self):
set_autotp_mode(training=True)
tp_size = 4
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you parametrize tp_size to improve coverage?

reuse_dist_env = True

def testRowParallel(self):
tp_size = 4
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Parametrize tp_size for coverage.

@@ -339,7 +340,10 @@ class DeepSpeedZeroConfig(DeepSpeedConfigModel):
"""
Override nn.Module apply function, for Stage 3.
"""

autotp_size: int = Field(0, ge=0, new_param="autotp_size")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is autotp_size defined as a subfield of zero, instead of a top-level field in ds_config? Is there a dependency on zero logic?

Returns:
OrderedDict: The consolidated state dictionary if the current process rank is 0, otherwise None.
"""
#TODO: If we use both Zero3 and tensor parallel simultaneously
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you clarify what is meant by the gather mechanism of tensor parallelism?

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same question.
I could somehow understand as it's a similar function to

def _zero3_consolidated_16bit_state_dict(self, exclude_frozen_parameters=False):
but specific for TP. The function name can be improved.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@skyshine102, thanks for the comment. A key difference between is zero3 and TP is that partitioned zero3 modules materialized using allgather before compute, whereas TP modules compute in a partitioned manner. So, it is unclear to me what requires gathering for TP.

@@ -247,6 +248,11 @@ def _post_forward_hook(self, module, input, output):
self._model_times.append(elapsed_time)

def _create_model_parallel_group(self, config):

if is_autotp_training_mode():
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Conceptually, control flow for training should not come here. I think some refactoring/restructuring is needed for code quality.

Copy link

@skyshine102 skyshine102 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @inkcherry for this contribution. I have spent some time to read this PR and I'm happy to be involved in this discussion. (I'm not from deepspeed team but deepspeed user. My comments are relatively minor though.)

return Yuan_LinearALlreduce(child, self.mp_group)

# For MLP including chunk layer.
if 'gate_up_proj' in name or ('dense_h_to_4h' in name and 'GLM' in str(self.module)):

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This additional code block is trying to deal with "MLP including chunk layer" (general case), but the returned module/object is in the name of GLM prefix.
It could be better to rename the GLM_LinearLayer to sth like GateUpPack_LinearLayer.

@@ -11,10 +11,12 @@
from typing import Optional
import torch
from deepspeed import comm as dist
from .layers import LinearAllreduce, LinearLayer, LmHeadLinearAllreduce
from .layers import LinearAllreduce, LinearLayer, LmHeadLinearAllreduce, Yuan_LinearALlreduce, Yuan_LinearLayer, GLM_LinearLayer, Conv_LinearALlreduce, fused_LinearLayer, conv_LinearLayer

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Original coding style is LinearAllreduce instead of LinearALlreduce.

broadcast_and_check(args, bcast_rank, bcast_group)
broadcast_and_check(kwargs, bcast_rank, bcast_group)

print(f"RANK[{dist.get_rank()}]:The Dataloader has passed the TP group consistency check.")
Copy link

@skyshine102 skyshine102 Jan 7, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe use the logger at rank 0 instead of print.

Returns:
OrderedDict: The consolidated state dictionary if the current process rank is 0, otherwise None.
"""
#TODO: If we use both Zero3 and tensor parallel simultaneously

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same question.
I could somehow understand as it's a similar function to

def _zero3_consolidated_16bit_state_dict(self, exclude_frozen_parameters=False):
but specific for TP. The function name can be improved.

@skyshine102
Copy link

skyshine102 commented Jan 7, 2025

@tjruwase @GuanhuaWang We had internal review of @inkcherry 's PR. This PR allows train HF models with tensor parallel without need for megatron. Which is very friendly to user.

Let us know your plan for Domino integration. @inkcherry 's memory data looks good. With Domino we think it can have less impact on performance since TP communication can overlap with computation.

@inkcherry by design should autotp training work with ZeRO3 as well?

@inkcherry I have the same question. Does this PR support the flow like https://pytorch.org/tutorials/intermediate/TP_tutorial.html#combine-tensor-parallel-with-fully-sharded-data-parallel-together ? (TP to shared weight $W$ to $W_{tp_i}$, then further shard $W_{tp_i}$ by ZeRO-3 to $W_{tp_i, dp_j}$)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants