
[MoE][PyTorch] Add mask-based MoE permutation #1373

Open: wants to merge 11 commits into base: main

@hxbai commented Dec 13, 2024

Description

Add mask-based token permutation and local chunk permutation fused kernels. These kernels are implemented with OpenAI Triton.
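As a rough illustration of what mask-based permutation computes, here is a minimal pure-PyTorch sketch. It is not the Triton kernels in this PR and not TE's actual API. It assumes `routing_map` is a `[num_tokens, num_experts]` boolean mask and that permuted rows are grouped expert by expert, in token order within each expert; `permute_by_mask` and `unpermute_by_mask` are hypothetical names.

```python
import torch


def permute_by_mask(tokens: torch.Tensor, routing_map: torch.Tensor):
    """Group token rows by expert (hypothetical reference, not TE's API).

    tokens:      [num_tokens, hidden]
    routing_map: [num_tokens, num_experts] bool; True if the token goes to that expert.
    Returns (permuted_tokens, sorted_indices); sorted_indices[r] is the source
    token of output row r.
    """
    num_tokens, num_experts = routing_map.shape
    # Reading the transposed mask row by row yields expert-major order.
    mask_t = routing_map.T.contiguous()
    token_ids = torch.arange(num_tokens, device=tokens.device).expand(num_experts, -1)
    sorted_indices = token_ids.masked_select(mask_t)
    return tokens.index_select(0, sorted_indices), sorted_indices


def unpermute_by_mask(permuted, sorted_indices, routing_map, probs=None):
    """Scatter expert outputs back to token order, optionally weighted by probs.

    probs, if given, is [num_tokens, num_experts], aligned with routing_map.
    """
    num_tokens = routing_map.shape[0]
    if probs is not None:
        # Pick each (token, expert) weight in the same expert-major order as the rows.
        row_probs = probs.T.contiguous().masked_select(routing_map.T.contiguous())
        permuted = permuted * row_probs.unsqueeze(-1).to(permuted.dtype)
    out = torch.zeros(
        num_tokens, permuted.shape[-1], dtype=permuted.dtype, device=permuted.device
    )
    # index_add_ sums the copies of a token that was routed to several experts.
    out.index_add_(0, sorted_indices, permuted)
    return out
```

In this sketch, tokens routed to no expert come back as zeros after unpermutation.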

Related commit in Megatron-LM NVIDIA/Megatron-LM@ac0474d

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactor

Changes

Please list the changes introduced in this PR:

  • Non-breaking API changes in te.pytorch.permutation.moe_permute and te.pytorch.permutation.moe_unpermute
  • Add a new API: te.pytorch.permutation.moe_sort_chunks_by_indices
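The exact signature of `moe_sort_chunks_by_indices` is not given here. A plausible reading of "local chunk permutation", which this sketch assumes, is: split the rows into contiguous chunks of given sizes, then concatenate the chunks in a given order. `sort_chunks_by_indices_reference` is a hypothetical pure-PyTorch reference, not the TE function or its signature.

```python
import torch


def sort_chunks_by_indices_reference(
    inp: torch.Tensor, split_sizes: torch.Tensor, sorted_indices: torch.Tensor
) -> torch.Tensor:
    """Reorder contiguous row chunks of `inp` (hypothetical reference).

    inp:            [num_rows, hidden]
    split_sizes:    1-D int tensor; chunk i has split_sizes[i] rows (sum == num_rows).
    sorted_indices: 1-D int tensor; output chunk j is input chunk sorted_indices[j].
    """
    chunks = torch.split(inp, split_sizes.tolist(), dim=0)
    return torch.cat([chunks[i] for i in sorted_indices.tolist()], dim=0)
```

For example, with 2 source ranks and 2 local experts, chunks arriving as (rank 0, expert 0), (rank 0, expert 1), (rank 1, expert 0), (rank 1, expert 1) would be regrouped by expert with `sorted_indices = [0, 2, 1, 3]`. This scenario is illustrative only.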

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@hxbai changed the title from [MoE][Common/PyTorch] Add mask-based MoE permutation to [MoE][PyTorch] Add mask-based MoE permutation on Dec 13, 2024
@phu0ngng self-requested a review January 8, 2025 15:20
]


class _moe_permute(torch.autograd.Function):
"""functional Permute"""
class _moe_permute_indice_map(torch.autograd.Function):
Collaborator

Suggested change
class _moe_permute_indice_map(torch.autograd.Function):
class _moe_permute_index_map(torch.autograd.Function):

We should make sure to use "index" in user-facing APIs like moe_permute/moe_unpermute.

Author

OK, modified.

Collaborator

@timmoon10 Jan 24, 2025

Just to clarify, "index" is singular and "indices" is plural. It's weird since it's from Latin.

We still have some places in the tests that use "indice", but it's less important since it's internal.

Author

Thanks. I fixed these names.

transformer_engine/pytorch/permutation.py (outdated, resolved)
Comment on lines 292 to 295
if ctx.fp8:
    assert isinstance(
        permuted_act_grad, Float8Tensor
    ), "Grad of the output must be in Float8Tensor type for FP8 moe_permute."
Collaborator

Couldn't we decouple FP8 in the forward and backward?

Suggested change
if ctx.fp8:
    assert isinstance(
        permuted_act_grad, Float8Tensor
    ), "Grad of the output must be in Float8Tensor type for FP8 moe_permute."
fp8 = isinstance(permuted_act_grad, Float8Tensor)
if fp8:

If there are no obstacles, we could also do the same thing for _moe_unpermute_mask_map and _moe_chunk_sort.

Author

Modified. The backward pass now follows the dtype of the grad tensor.

tests/pytorch/test_permutation.py (resolved)
# Results Check
#
###################################################################################################################################
tols = dtype_tols(te_dtype)
Collaborator

We should expect bit-wise exact results.

Suggested change
tols = dtype_tols(te_dtype)
tols = { "atol": 0, "rtol": 0 }

Author

As with the one above, I changed the checks to bit-wise matching, except for FP8.
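The case for exact tolerances is that permutation only moves rows, so a correct kernel should reproduce a reference bit for bit. Below is a minimal sketch of such a check using `torch.testing.assert_close` with `atol=0, rtol=0`, comparing two hypothetical gather implementations; it is not the PR's test code. Paths that sum rows (for example, merging a token's copies during unpermute) are bitwise equal only if both sides also add in the same order.

```python
import torch


def gather_rows(x: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
    # Vectorized row gather: pure data movement, no arithmetic.
    return x.index_select(0, idx)


def gather_rows_loop(x: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
    # Naive per-row reference that copies the same values.
    return torch.stack([x[i] for i in idx.tolist()])


x = torch.randn(8, 16, dtype=torch.bfloat16)
idx = torch.tensor([3, 0, 7, 7, 1])
# Zero tolerances: any difference at all, even in the last bit, fails.
torch.testing.assert_close(gather_rows(x, idx), gather_rows_loop(x, idx), atol=0, rtol=0)
```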

@timmoon10 self-requested a review January 8, 2025 21:57
    mask=(offset < num_tokens),
    other=0,
).to(tl.int64)
expert_token_cumsum = tl.cumsum(expert_token_mask) * expert_token_mask
Collaborator

An interesting way to zero out the positions where the token mask is 0. Happy to learn!
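The trick can be seen on a toy mask. This is a minimal sketch in plain PyTorch rather than Triton, assuming `torch.cumsum` stands in for `tl.cumsum` here (both are inclusive prefix sums):

```python
import torch

# Token mask for one expert over one block of tokens (1 = routed to this expert).
expert_token_mask = torch.tensor([1, 0, 1, 1, 0])

# The inclusive prefix sum gives each routed token its 1-based rank among routed
# tokens; multiplying by the mask zeroes out the positions that are not routed.
expert_token_cumsum = torch.cumsum(expert_token_mask, dim=0) * expert_token_mask
print(expert_token_cumsum.tolist())  # [1, 0, 2, 3, 0]
```

Because routed tokens get ranks starting at 1, a 0 in the result can only mean "not routed".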

Comment on lines 61 to 67
chunk_cumsum = tl.load(
    row_id_map_ptr + pid_m * num_tokens + offset, mask=(offset < num_tokens), other=0
)

workspace_off = tl.arange(0, WORKSPACE_LOAD_WIDTH)
chunk_sums = tl.load(workspace_ptr + workspace_off, mask=workspace_off < chunk_idx)
chunk_cumsum = tl.where(chunk_cumsum == 0, -1, chunk_cumsum + tl.sum(chunk_sums) - 1)
Collaborator

@phu0ngng Jan 10, 2025

These three names, chunk_cumsum (as loaded), chunk_sums, and chunk_cumsum (after the tl.where), are quite confusing.
If I understand them correctly, I suggest renaming them:

  • chunk_cumsum (as loaded) -> row_id_within_token_block
  • chunk_sums -> n_tokens_per_expert
  • chunk_cumsum (after the tl.where) -> row_id

In addition, I think we should move the -1 into pass 1, since it corrects the calculation of expert_token_cumsum:

expert_token_cumsum = (tl.cumsum(expert_token_mask) - 1) * expert_token_mask

Author

Thanks. You are right. I modified these names (renamed chunk_sums to n_tokens_per_block rather than n_tokens_per_expert).

As for the -1: if we move it to pass 1, we can no longer easily distinguish a row_id of 0 from a masked-out entry (mask 0), and we would need an extra way to track whether a token is masked out. So I left the -1 in pass 2. Do you think that is OK?

Collaborator

Sounds good to me.
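Putting this thread together, the two-pass row_id_map computation can be sketched in plain PyTorch, with CPU loops instead of a Triton kernel. Assumptions: blocks are indexed expert-major as `chunk_idx = expert * num_blocks + block`, which is how the workspace prefix sum in the kernel snippet above reads, and `make_row_id_map` is a hypothetical name. Keeping the -1 in pass 2 leaves 0 free to mean "masked" after pass 1, which is the point made in the discussion.

```python
import torch


def make_row_id_map(routing_map: torch.Tensor, block_size: int = 4) -> torch.Tensor:
    """Two-pass row_id_map sketch (hypothetical helper, not the PR's kernel).

    routing_map: [num_tokens, num_experts] bool mask.
    Returns [num_experts, num_tokens] int64: entry (e, t) is the row of token t
    in the expert-major permuted output, or -1 if t is not routed to e.
    """
    mask = routing_map.T.to(torch.int64)  # [num_experts, num_tokens]
    num_experts, num_tokens = mask.shape
    num_blocks = (num_tokens + block_size - 1) // block_size

    def block(b: int) -> slice:
        return slice(b * block_size, min((b + 1) * block_size, num_tokens))

    # Pass 1: per (expert, block), a 1-based rank among routed tokens; 0 = masked.
    row_id_within_token_block = torch.zeros_like(mask)
    n_tokens_per_block = torch.zeros(num_experts * num_blocks, dtype=torch.int64)
    for e in range(num_experts):
        for b in range(num_blocks):
            m = mask[e, block(b)]
            row_id_within_token_block[e, block(b)] = torch.cumsum(m, dim=0) * m
            n_tokens_per_block[e * num_blocks + b] = m.sum()

    # Pass 2: add the token count of all earlier blocks (expert-major chunk_idx),
    # then apply the -1 that turns 1-based ranks into 0-based rows.
    # Because the -1 happens here, a 0 from pass 1 still unambiguously means "masked".
    row_id_map = torch.empty_like(mask)
    for e in range(num_experts):
        for b in range(num_blocks):
            chunk_idx = e * num_blocks + b
            offset = n_tokens_per_block[:chunk_idx].sum()
            local = row_id_within_token_block[e, block(b)]
            row_id_map[e, block(b)] = torch.where(
                local == 0, torch.full_like(local, -1), local + offset - 1
            )
    return row_id_map
```

Writing token t into row `row_id_map[e, t]` for every non-negative entry reproduces the same expert-major order as a `masked_select`-based reference permutation.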

@yanring commented Jan 21, 2025

Hi @timmoon10 @phu0ngng, could you help take another look at this? We intend to incorporate this optimization into mcore v0.11 (6th Feb). Thanks a lot!

@phu0ngng
Collaborator

/te-ci pytorch

Collaborator

@timmoon10 left a comment

Overall LGTM. My suggestions are stylistic.

@pytest.mark.parametrize("num_expert", [8, 16])
@pytest.mark.parametrize("tp_size", [1, 2, 8])
@pytest.mark.parametrize("hidden_size", [4096])
def test_chunk_permuation(
Collaborator

Nit:

Suggested change
def test_chunk_permuation(
def test_chunk_permutation(

Author

Fixed.

Comment on lines 184 to 204
+def dtype_tols(te_dtype: tex.DType, has_tolerance: bool = True) -> Dict[str, float]:
     """Estimated tolerances for a datatype

     Based on tolerances for torch.testing.assert_close.

     """
-    if te_dtype == tex.DType.kFloat32:
-        return dict(rtol=1.0e-6, atol=1.0e-6)
-    if te_dtype == tex.DType.kFloat16:
-        return dict(rtol=3.0e-3, atol=1.0e-5)
-    if te_dtype == tex.DType.kBFloat16:
-        return dict(rtol=2.0e-2, atol=1.0e-5)
-    if te_dtype == tex.DType.kFloat8E5M2 or te_dtype == tex.DType.kFloat8E4M3:
-        return dict(rtol=2.0e-1, atol=1.0e-1)
-    raise ValueError(f"Unsuppored dtype ({te_dtype})")
+    if has_tolerance:
+        if te_dtype == tex.DType.kFloat32:
+            return dict(rtol=1.0e-6, atol=1.0e-6)
+        if te_dtype == tex.DType.kFloat16:
+            return dict(rtol=3.0e-3, atol=1.0e-5)
+        if te_dtype == tex.DType.kBFloat16:
+            return dict(rtol=2.0e-2, atol=1.0e-5)
+        if te_dtype == tex.DType.kFloat8E5M2 or te_dtype == tex.DType.kFloat8E4M3:
+            return dict(rtol=2.0e-1, atol=1.0e-1)
+        raise ValueError(f"Unsuppored dtype ({te_dtype})")
+    else:
+        if te_dtype == tex.DType.kFloat8E5M2 or te_dtype == tex.DType.kFloat8E4M3:
+            return dict(rtol=2.0e-1, atol=1.0e-1)
+        else:
+            return dict(rtol=0.0, atol=0.0)
Collaborator

I think this change makes this function much less intuitive. With has_tolerance=False, why do we expect bitwise accuracy with FP32 but not with FP8? It makes sense if you know the details of the MoE impl, but why should you need to know that in a simple helper function like this?

Also, based on your comments I think using dtype tolerances is fine for these tests. I think we can just revert this function.

Author

OK. Just reverted.

tests/pytorch/test_permutation.py (resolved)
Signed-off-by: Hongxiao Bai <[email protected]>
5 participants