[MoE][PyTorch] Add mask-based MoE permutation #1373
base: main
Conversation
Signed-off-by: Hongxiao Bai <[email protected]>
for more information, see https://pre-commit.ci Signed-off-by: Hongxiao Bai <[email protected]>
Force-pushed from 6160104 to ca94d72
class _moe_permute(torch.autograd.Function):
    """functional Permute"""
class _moe_permute_indice_map(torch.autograd.Function):
Suggested change:
- class _moe_permute_indice_map(torch.autograd.Function):
+ class _moe_permute_index_map(torch.autograd.Function):
We should make sure to use "index" in user-facing APIs like moe_permute/moe_unpermute.
import warnings
from typing import Tuple
import torch

import transformer_engine_torch as tex
from .constants import TE_DType
from .float8_tensor import Float8Tensor
import transformer_engine.pytorch.triton.permutation as triton_permuataion
Nit:
Suggested change:
- import transformer_engine.pytorch.triton.permutation as triton_permuataion
+ import transformer_engine.pytorch.triton.permutation as triton_permutation
if ctx.fp8:
    assert isinstance(
        permuted_act_grad, Float8Tensor
    ), "Grad of the output must be in Float8Tensor type for FP8 moe_permute."
Couldn't we decouple FP8 in the forward and backward?
Suggested change:
- if ctx.fp8:
-     assert isinstance(
-         permuted_act_grad, Float8Tensor
-     ), "Grad of the output must be in Float8Tensor type for FP8 moe_permute."
+ fp8 = isinstance(permuted_act_grad, Float8Tensor)
+ if fp8:
If there are no obstacles, we could also do the same thing for _moe_unpermute_mask_map and _moe_chunk_sort.
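A minimal sketch of that decoupling, under assumed names (this is not the PR's actual code): the backward keys the FP8 path off the type of the incoming gradient instead of the flag the forward recorded.

import torch
from transformer_engine.pytorch.float8_tensor import Float8Tensor

class _example_permute(torch.autograd.Function):  # illustrative name
    @staticmethod
    def forward(ctx, inp):
        # Forward-side FP8 handling stays keyed to the input's type.
        ctx.fp8 = isinstance(inp, Float8Tensor)
        return inp

    @staticmethod
    def backward(ctx, permuted_act_grad):
        # Decide independently of ctx.fp8, so an FP8 forward can be followed by
        # a higher-precision backward (and vice versa).
        fp8 = isinstance(permuted_act_grad, Float8Tensor)
        if fp8:
            ...  # FP8-specific gradient handling would go here
        return permuted_act_grad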
# Results Check
#
###################################################################################################################################
tols = dtype_tols(te_dtype)
Shouldn't we expect bit-wise exact results?
Suggested change:
- tols = dtype_tols(te_dtype)
+ tols = { "atol": 0, "rtol": 0 }
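For reference, a hedged sketch of what a bit-wise exact check could look like in the test (the tensor names here are placeholders, not the test's actual variables):

import torch

# Zero tolerances make assert_close require exact equality...
torch.testing.assert_close(te_output, reference_output, atol=0.0, rtol=0.0)
# ...or compare exactly with torch.equal.
assert torch.equal(te_output, reference_output)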
# Results Check
#
###################################################################################################################################
tols = dtype_tols(te_dtype)
We should expect bit-wise exact results.
Suggested change:
- tols = dtype_tols(te_dtype)
+ tols = { "atol": 0, "rtol": 0 }
    mask=(offset < num_tokens),
    other=0,
).to(tl.int64)
expert_token_cumsum = tl.cumsum(expert_token_mask) * expert_token_mask
An interesting way to exclude the zero entries of the token mask. Happy to learn!
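A tiny PyTorch illustration of the trick (not the Triton kernel itself), assuming a 0/1 routing mask for a single expert:

import torch

expert_token_mask = torch.tensor([1, 0, 1, 1, 0])
# cumsum gives a running count; multiplying by the mask again zeroes out the
# positions of tokens that are not routed to this expert.
print(torch.cumsum(expert_token_mask, dim=0) * expert_token_mask)
# tensor([1, 0, 2, 3, 0]) -> 1-based slot of each routed token, 0 for unrouted tokens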
chunk_cumsum = tl.load(
    row_id_map_ptr + pid_m * num_tokens + offset, mask=(offset < num_tokens), other=0
)

workspace_off = tl.arange(0, WORKSPACE_LOAD_WIDTH)
chunk_sums = tl.load(workspace_ptr + workspace_off, mask=workspace_off < chunk_idx)
chunk_cumsum = tl.where(chunk_cumsum == 0, -1, chunk_cumsum + tl.sum(chunk_sums) - 1)
These names are quite confusing: chunk_cumsum is loaded, then reassigned to something different a few lines later, with chunk_sums in between. If I understand it correctly, I suggest renaming them as follows:
- the chunk_cumsum loaded from row_id_map_ptr -> row_id_within_token_block
- chunk_sums -> n_tokens_per_expert
- the final chunk_cumsum produced by the tl.where -> row_id

In addition, I think we should move the -1 to pass 1, as it is the correction for the calculation of expert_token_cumsum:
expert_token_cumsum = (tl.cumsum(expert_token_mask) - 1) * expert_token_mask
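For what it's worth, a small PyTorch sketch of what pass 2 computes, written with the proposed names (illustrative only, not the Triton kernel):

import torch

# Pass 1 result for one expert: 1-based slot of each token within this expert's
# chunk, with 0 meaning the token is not routed to this expert.
row_id_within_token_block = torch.tensor([0, 1, 0, 2])
# Token counts of the preceding experts (what the workspace load provides).
n_tokens_per_expert = torch.tensor([3])

# Shift by the preceding chunks, convert to 0-based, and mark unrouted tokens with -1.
row_id = torch.where(
    row_id_within_token_block == 0,
    torch.tensor(-1),
    row_id_within_token_block + n_tokens_per_expert.sum() - 1,
)
print(row_id)  # tensor([-1,  3, -1,  4])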
Description
Add mask-based token permutation and local chunk permutation fused kernels. These kernels are implemented with OpenAI Triton.
Related commit in Megatron-LM NVIDIA/Megatron-LM@ac0474d
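For context, here is a pure-PyTorch sketch of the mask-based permutation these kernels fuse (an illustration with assumed names and shapes, not the PR's implementation):

import torch

def reference_moe_permute(hidden, routing_map):
    """Group token rows by expert according to a [num_tokens, num_experts] 0/1 mask."""
    num_experts = routing_map.shape[1]
    # For each expert, pick the tokens routed to it; concatenating the picks
    # yields the permuted activations grouped into per-expert chunks.
    per_expert_token_ids = [
        routing_map[:, e].nonzero(as_tuple=True)[0] for e in range(num_experts)
    ]
    sorted_token_ids = torch.cat(per_expert_token_ids)
    return hidden[sorted_token_ids], sorted_token_ids

hidden = torch.randn(4, 8)
routing_map = torch.tensor([[1, 0], [0, 1], [1, 1], [0, 1]])
permuted, sorted_token_ids = reference_moe_permute(hidden, routing_map)
print(sorted_token_ids)  # tensor([0, 2, 1, 2, 3]) -> expert-0 chunk, then expert-1 chunk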
Fixes # (issue)
Type of change
Changes
Please list the changes introduced in this PR:
- te.pytorch.permutation.moe_permute and te.pytorch.permutation.moe_unpermute
- te.pytorch.permutation.moe_sort_chunks_by_indices
Checklist: