[MoE][PyTorch] Add mask-based MoE permutation #1373
base: main
Signed-off-by: Hongxiao Bai <[email protected]>
for more information, see https://pre-commit.ci
Signed-off-by: Hongxiao Bai <[email protected]>
Signed-off-by: Hongxiao Bai <[email protected]>
Signed-off-by: Hongxiao Bai <[email protected]>
6160104 to ca94d72
```python
class _moe_permute(torch.autograd.Function):
    """functional Permute"""
class _moe_permute_indice_map(torch.autograd.Function):
```
Suggested change:
```diff
-class _moe_permute_indice_map(torch.autograd.Function):
+class _moe_permute_index_map(torch.autograd.Function):
```
We should make sure to use "index" in user-facing APIs like `moe_permute`/`moe_unpermute`.
OK, modified.
Just to clarify, "index" is singular and "indices" is plural. It's weird since it's from Latin.
We still have some places in the tests that use "indice", but it's less important since it's internal.
Thanks. I fixed these names.
```python
if ctx.fp8:
    assert isinstance(
        permuted_act_grad, Float8Tensor
    ), "Grad of the output must be in Float8Tensor type for FP8 moe_permute."
```
Couldn't we decouple FP8 in the forward and backward?
Suggested change:
```diff
-if ctx.fp8:
-    assert isinstance(
-        permuted_act_grad, Float8Tensor
-    ), "Grad of the output must be in Float8Tensor type for FP8 moe_permute."
+fp8 = isinstance(permuted_act_grad, Float8Tensor)
+if fp8:
```
If there are no obstacles, we could do the same for `_moe_unpermute_mask_map` and `_moe_chunk_sort`.
Modified. The backward pass now follows the dtype of the grad tensor.
```python
# Results Check
#
###################################################################################################################################
tols = dtype_tols(te_dtype)
```
We should expect bit-wise exact results.
Suggested change:
```diff
-tols = dtype_tols(te_dtype)
+tols = { "atol": 0, "rtol": 0 }
```
As with the one above, I changed this to bit-wise matching, except for FP8.
```python
    mask=(offset < num_tokens),
    other=0,
).to(tl.int64)
expert_token_cumsum = tl.cumsum(expert_token_mask) * expert_token_mask
```
An interesting way to exclude tokens whose mask is zero. Happy to learn!
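The trick can be checked outside Triton. Below is a minimal pure-Python sketch (an illustration, not the kernel code) that mirrors `tl.cumsum(expert_token_mask) * expert_token_mask` for a single expert's 0/1 mask; the helper name `masked_positions` is hypothetical.

```python
from itertools import accumulate
from typing import List


def masked_positions(mask: List[int]) -> List[int]:
    """Mirror of `cumsum(mask) * mask` for a 0/1 mask (hypothetical helper).

    Routed tokens get their 1-based rank among routed tokens;
    masked-out tokens are multiplied by 0 and end up as 0.
    """
    return [c * m for c, m in zip(accumulate(mask), mask)]


print(masked_positions([1, 0, 1, 1, 0, 1]))  # [1, 0, 2, 3, 0, 4]
```

Because every routed token gets a value of at least 1, the value 0 is free to act as the "not routed" marker.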
```python
chunk_cumsum = tl.load(
    row_id_map_ptr + pid_m * num_tokens + offset, mask=(offset < num_tokens), other=0
)

workspace_off = tl.arange(0, WORKSPACE_LOAD_WIDTH)
chunk_sums = tl.load(workspace_ptr + workspace_off, mask=workspace_off < chunk_idx)
chunk_cumsum = tl.where(chunk_cumsum == 0, -1, chunk_cumsum + tl.sum(chunk_sums) - 1)
```
These three names, `chunk_cumsum`, `chunk_sums`, and the reassigned `chunk_cumsum`, are quite confusing.
If I understand correctly, I suggest renaming them to:
- `chunk_cumsum` (as loaded) -> `row_id_within_token_block`
- `chunk_sums` -> `n_tokens_per_expert`
- `chunk_cumsum` (after `tl.where`) -> `row_id`

In addition, I think we should move the `-1` to `pass1`, since it is a correction to the calculation of `expert_token_cumsum`:
```python
expert_token_cumsum = (tl.cumsum(expert_token_mask) - 1) * expert_token_mask
```
Thanks, you are right. I renamed these variables (`chunk_sums` became `n_tokens_per_block` rather than `n_tokens_per_expert`).
For the `-1`: if we move it to `pass1`, we can no longer easily distinguish `row_id: 0` (the first routed token) from `mask: 0` (a masked-out token), and we would need an extra way to track which tokens are masked out. So I left the `-1` in `pass2`. Is that OK with you?
Sounds good to me.
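The `-1` discussion above can be illustrated with a small pure-Python model of the two passes. This is a sketch under stated assumptions, not the Triton kernels: it treats the routing mask as a single flat 0/1 list split into fixed-size blocks, and the names `pass1_block_positions` and `pass2_row_ids` are hypothetical. Pass 1 produces per-block ranks (0 meaning masked out) plus a token count per block; pass 2 adds the counts of earlier blocks and applies the `-1`, keeping `-1` as the sentinel for masked-out tokens, as the `tl.where` line does.

```python
from itertools import accumulate
from typing import List, Tuple


def pass1_block_positions(mask: List[int], block_size: int) -> Tuple[List[int], List[int]]:
    """Hypothetical pass 1: 1-based rank within each block (0 = masked out) and per-block counts."""
    positions: List[int] = []
    n_tokens_per_block: List[int] = []
    for start in range(0, len(mask), block_size):
        block = mask[start:start + block_size]
        positions.extend(c * m for c, m in zip(accumulate(block), block))
        n_tokens_per_block.append(sum(block))
    return positions, n_tokens_per_block


def pass2_row_ids(positions: List[int], n_tokens_per_block: List[int], block_size: int) -> List[int]:
    """Hypothetical pass 2: add counts of earlier blocks, shift to 0-based, keep -1 for masked-out tokens."""
    row_ids: List[int] = []
    for i, pos in enumerate(positions):
        offset = sum(n_tokens_per_block[: i // block_size])
        row_ids.append(-1 if pos == 0 else pos + offset - 1)
    return row_ids


mask = [1, 0, 1, 1, 0, 1]
positions, counts = pass1_block_positions(mask, block_size=4)
print(positions, counts)  # [1, 0, 2, 3, 0, 1] [3, 1]
print(pass2_row_ids(positions, counts, block_size=4))  # [0, -1, 1, 2, -1, 3]

# Applying the -1 in pass 1 instead makes a routed token and a masked-out token collide:
first_block = mask[:4]
print([(c - 1) * m for c, m in zip(accumulate(first_block), first_block)])  # [0, 0, 1, 2]
```

In the last line, token 0 (routed) and token 1 (masked out) both become 0, which is the ambiguity described in the reply; keeping the `-1` in the second pass avoids it.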
Signed-off-by: Hongxiao Bai <[email protected]>
Signed-off-by: Hongxiao Bai <[email protected]>
Signed-off-by: Hongxiao Bai <[email protected]>
Signed-off-by: Hongxiao Bai <[email protected]>
Hi @timmoon10 @phu0ngng, could you take another look at this? We intend to incorporate this optimization into mcore v0.11 (6th Feb). Thanks a lot!
/te-ci pytorch
Overall LGTM. My suggestions are stylistic.
tests/pytorch/test_permutation.py (outdated)
```python
@pytest.mark.parametrize("num_expert", [8, 16])
@pytest.mark.parametrize("tp_size", [1, 2, 8])
@pytest.mark.parametrize("hidden_size", [4096])
def test_chunk_permuation(
```
Nit:
Suggested change:
```diff
-def test_chunk_permuation(
+def test_chunk_permutation(
```
Fixed.
tests/pytorch/test_permutation.py (outdated)
```python
def dtype_tols(te_dtype: tex.DType, has_tolerance: bool = True) -> Dict[str, float]:
    """Estimated tolerances for a datatype

    Based on tolerances for torch.testing.assert_close.

    """
    if te_dtype == tex.DType.kFloat32:
        return dict(rtol=1.0e-6, atol=1.0e-6)
    if te_dtype == tex.DType.kFloat16:
        return dict(rtol=3.0e-3, atol=1.0e-5)
    if te_dtype == tex.DType.kBFloat16:
        return dict(rtol=2.0e-2, atol=1.0e-5)
    if te_dtype == tex.DType.kFloat8E5M2 or te_dtype == tex.DType.kFloat8E4M3:
        return dict(rtol=2.0e-1, atol=1.0e-1)
    raise ValueError(f"Unsuppored dtype ({te_dtype})")
    if has_tolerance:
        if te_dtype == tex.DType.kFloat32:
            return dict(rtol=1.0e-6, atol=1.0e-6)
        if te_dtype == tex.DType.kFloat16:
            return dict(rtol=3.0e-3, atol=1.0e-5)
        if te_dtype == tex.DType.kBFloat16:
            return dict(rtol=2.0e-2, atol=1.0e-5)
        if te_dtype == tex.DType.kFloat8E5M2 or te_dtype == tex.DType.kFloat8E4M3:
            return dict(rtol=2.0e-1, atol=1.0e-1)
        raise ValueError(f"Unsuppored dtype ({te_dtype})")
    else:
        if te_dtype == tex.DType.kFloat8E5M2 or te_dtype == tex.DType.kFloat8E4M3:
            return dict(rtol=2.0e-1, atol=1.0e-1)
        else:
            return dict(rtol=0.0, atol=0.0)
```
I think this change makes this function much less intuitive. With `has_tolerance=False`, why do we expect bitwise accuracy with FP32 but not with FP8? It makes sense if you know the details of the MoE implementation, but why should you need to know that in a simple helper function like this?
Also, based on your comments, I think using dtype tolerances is fine for these tests. I think we can just revert this function.
OK. Just reverted.
Signed-off-by: Hongxiao Bai <[email protected]>
Description
Add mask-based token permutation and local chunk permutation fused kernels. These kernels are implemented with OpenAI Triton.
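For readers new to the mask-based variant, the pure-Python sketch below shows the kind of semantics a mask-based permute/unpermute pair computes. It is an assumption for illustration, not Transformer Engine's API or the Triton implementation: it takes a boolean `routing_map[token][expert]`, groups token rows expert by expert, and records `row_id_map[expert][token]` (-1 when the token is not routed to that expert), a layout suggested by the `row_id_map_ptr + pid_m * num_tokens + offset` indexing in the kernel excerpt. The names `mask_permute`, `mask_unpermute`, and the optional `probs` argument are hypothetical.

```python
from typing import List, Optional, Tuple

Matrix = List[List[float]]


def mask_permute(tokens: Matrix, routing_map: List[List[bool]]) -> Tuple[Matrix, List[List[int]]]:
    """Hypothetical reference: copy token rows into expert-major order and record row ids."""
    num_tokens = len(tokens)
    num_experts = len(routing_map[0]) if routing_map else 0
    permuted: Matrix = []
    row_id_map: List[List[int]] = []
    for e in range(num_experts):
        ids = []
        for t in range(num_tokens):
            if routing_map[t][e]:
                ids.append(len(permuted))        # destination row of token t for expert e
                permuted.append(list(tokens[t]))
            else:
                ids.append(-1)                   # token t is not routed to expert e
        row_id_map.append(ids)
    return permuted, row_id_map


def mask_unpermute(permuted: Matrix, row_id_map: List[List[int]],
                   probs: Optional[Matrix] = None) -> Matrix:
    """Hypothetical reference: scatter-add rows back to tokens, optionally weighted by probs[t][e]."""
    num_tokens = len(row_id_map[0]) if row_id_map else 0
    hidden = len(permuted[0]) if permuted else 0
    out = [[0.0] * hidden for _ in range(num_tokens)]
    for e, ids in enumerate(row_id_map):
        for t, r in enumerate(ids):
            if r >= 0:
                w = 1.0 if probs is None else probs[t][e]
                for h in range(hidden):
                    out[t][h] += w * permuted[r][h]
    return out


tokens = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
routing_map = [[True, False], [True, True], [False, True]]
permuted, row_id_map = mask_permute(tokens, routing_map)
print(permuted)    # [[1.0, 2.0], [3.0, 4.0], [3.0, 4.0], [5.0, 6.0]]
print(row_id_map)  # [[0, 1, -1], [-1, 2, 3]]
probs = [[1.0, 0.0], [0.5, 0.5], [0.0, 1.0]]
print(mask_unpermute(permuted, row_id_map, probs))  # [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
```

In this sketch the forward permutation only copies rows, which is why a bit-exact comparison against a reference is a natural expectation for it.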
Related commit in Megatron-LM NVIDIA/Megatron-LM@ac0474d
Fixes # (issue)
Type of change
Changes
Please list the changes introduced in this PR:
- `te.pytorch.permutation.moe_permute` and `te.pytorch.permutation.moe_unpermute`
- `te.pytorch.permutation.moe_sort_chunks_by_indices`
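The local chunk permutation can be pictured with a short pure-Python sketch. It is a hedged illustration only, not the Transformer Engine signature: it assumes the input rows are split into contiguous chunks of `split_sizes[i]` rows and that output chunk `i` is input chunk `sorted_indices[i]`; the function name `sort_chunks_by_indices` and both parameter names are hypothetical. The example assumes a typical MoE use case, regrouping tokens received from several ranks so that each local expert's tokens become contiguous.

```python
from itertools import accumulate
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def sort_chunks_by_indices(rows: Sequence[T], split_sizes: Sequence[int],
                           sorted_indices: Sequence[int]) -> List[T]:
    """Hypothetical reference: output chunk i is input chunk sorted_indices[i]."""
    if sum(split_sizes) != len(rows):
        raise ValueError("split_sizes must sum to the number of rows")
    starts = [0, *accumulate(split_sizes)]  # start offset of every input chunk
    out: List[T] = []
    for idx in sorted_indices:
        out.extend(rows[starts[idx]:starts[idx + 1]])
    return out


# Input chunks ordered (rank0, e0), (rank0, e1), (rank1, e0), (rank1, e1).
rows = ["r0e0_t0", "r0e1_t0", "r0e1_t1", "r1e0_t0", "r1e1_t0"]
split_sizes = [1, 2, 1, 1]
sorted_indices = [0, 2, 1, 3]  # regroup by expert: e0 from both ranks, then e1
print(sort_chunks_by_indices(rows, split_sizes, sorted_indices))
# ['r0e0_t0', 'r1e0_t0', 'r0e1_t0', 'r0e1_t1', 'r1e1_t0']
```

The reverse reordering (back to rank-major order) is the same operation with the inverse index order.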
Checklist: