Skip to content

Commit

Permalink
Fix fusible ops checkpoint
Browse files Browse the repository at this point in the history
Signed-off-by: Kirthi Shankar Sivamani <[email protected]>
  • Loading branch information
ksivaman committed Jan 8, 2025
1 parent 560bccf commit 78c869d
Showing 1 changed file with 1 addition and 13 deletions.
14 changes: 1 addition & 13 deletions transformer_engine/pytorch/ops/op.py
Original file line number Diff line number Diff line change
@@ -527,13 +527,6 @@ def get_extra_state(self) -> torch.Tensor:
# See: https://github.com/NVIDIA/TransformerEngine/pull/351
# See: https://github.com/NVIDIA/TransformerEngine/pull/363

- # Return immediately if op has no FP8 state
- has_fp8_state = any(
- self.num_fp8_scales(mode) > 0 for mode in ("input", "param", "grad_output")
- )
- if not has_fp8_state:
- return torch.Tensor()

def to_cpu(src: torch.Tensor) -> torch.Tensor:
"""Helper function to make CPU copy of tensor
@@ -548,12 +541,7 @@ def to_cpu(src: torch.Tensor) -> torch.Tensor:
# Store FP8 state
state = {}
for mode in ("input", "param", "grad_output"):

- # Get state for a given FP8 tensor
- if self.num_fp8_scales(mode) == 0:
- state[mode] = None
- continue
- fp8_meta = self.get_fp8_meta(mode)
+ fp8_meta = self._fp8_metas
if fp8_meta is None:
continue
state[mode] = {}
Expand Down

0 comments on commit 78c869d

Please sign in to comment.