diff --git a/transformer_engine/pytorch/tensor/float8_tensor.py b/transformer_engine/pytorch/tensor/float8_tensor.py index 0d216f8af8..4ac05c7d55 100644 --- a/transformer_engine/pytorch/tensor/float8_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_tensor.py @@ -443,7 +443,9 @@ def __new__( return self - def fsdp_pre_all_gather(self, mesh): # pylint: disable=unused-argument + def fsdp_pre_all_gather(self, mesh): + # pylint: disable=missing-function-docstring + return (self._data,), (self,) def fsdp_post_all_gather( @@ -454,6 +456,7 @@ def fsdp_post_all_gather( *, out: Optional[torch.Tensor] = None, ): + # pylint: disable=unused-argument (data,) = all_gather_outputs (sample,) = metadata if out is not None: