Enabling FP8 all-gather for TE Float8Tensor when using Torch FSDP2 #1358
Conversation
/te-ci pytorch L0 L1
LGTM!
torch.ops.aten.copy_.default,
torch.ops.aten.view.default,
torch.ops.aten.as_strided.default,
torch.ops.aten._to_copy.default,
Having _to_copy in this list puts us in a weird position. @youngeunkwon0405 Where did you get this list of ops, and could we figure out a way to remove _to_copy?
The trouble is that we are implicitly using torch.Tensor.to as a dequantize function, so we always expect the _to_copy op to output a plain PyTorch tensor. The reason for this design was to work with Mcore's logic for maintaining FP32 master weights (see the logic for DDP and distopt). With this PR, we now see many spurious errors whenever we dequantize an FP8 tensor with to/float/half/etc.
If the current impl of _to_copy leads to insurmountable problems with FSDP2, we'll probably need to remove the implicit dequantization and change Mcore so that it explicitly calls Float8Tensor.dequantize.
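For readers unfamiliar with the mechanism being discussed, here is a minimal sketch of what "implicitly using torch.Tensor.to as a dequantize function" means. The class FakeFP8Tensor and its internals are illustrative assumptions, not Transformer Engine's actual Float8Tensor: calls such as .to()/.float()/.half() reach __torch_dispatch__ as aten._to_copy and come back as plain torch.Tensors, which is the behavior Mcore's FP32 master-weight logic relies on.

import torch
from torch.utils._pytree import tree_map

class FakeFP8Tensor(torch.Tensor):
    @staticmethod
    def __new__(cls, data):
        # Wrapper subclass: the outer tensor only advertises shape/dtype/device.
        return torch.Tensor._make_wrapper_subclass(
            cls, data.shape, dtype=data.dtype, device=data.device
        )

    def __init__(self, data):
        self._data = data  # stand-in for the FP8 payload plus scale

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}

        def unwrap(t):
            return t._data if isinstance(t, cls) else t

        out = func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs))
        if func is torch.ops.aten._to_copy.default:
            # Implicit dequantize: .to()/.float()/.half() return a *plain* tensor,
            # dropping the subclass -- the behavior debated in this thread.
            return out
        # Other ops (view, copy_, as_strided, ...) would re-wrap so the subclass
        # is preserved; this sketch simply re-wraps unconditionally.
        return cls(out) if isinstance(out, torch.Tensor) else out

x = FakeFP8Tensor(torch.randn(4))
y = x.half()          # dispatches aten._to_copy.default
print(type(y))        # <class 'torch.Tensor'> -- no longer the FP8 subclass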
I was following Torch AO's implementation:
https://github.com/pytorch/ao/blob/main/torchao/float8/fsdp_utils.py#L86-L99
I checked that it is currently okay for _to_copy to not preserve the tensor class, but I left the warning in for future reference.
This PR does not change the functional behavior of Float8Tensor's _to_copy; it only adds a warning here.
Description
This PR enables FP8 all-gather for TE Float8Tensor when using Torch FSDP2 (a.k.a. per-parameter-sharding FSDP).
This feature is automatically enabled when a user creates a module with transformer_engine.pytorch.fp8_model_init.
Fixes # (issue)
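As a usage illustration, here is a hedged sketch (not code from this PR). It assumes a multi-GPU run where torch.distributed is already initialized, e.g. via torchrun, and note that the fully_shard import path varies across PyTorch releases:

import transformer_engine.pytorch as te
# FSDP2, i.e. per-parameter-sharding FSDP; in newer PyTorch releases the entry
# point is torch.distributed.fsdp.fully_shard instead.
from torch.distributed._composable.fsdp import fully_shard

# Parameters created inside fp8_model_init are TE Float8Tensors.
with te.fp8_model_init(enabled=True):
    model = te.Linear(4096, 4096)

# Shard parameters per-parameter; with this change the weight all-gather can
# stay in FP8 instead of being dequantized to a higher-precision dtype first.
fully_shard(model)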
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: