[fp8] add fallback and make compile option configurable (#6092)
ver217 authored Oct 18, 2024
1 parent 3b1d7d1 commit 5ddad48
Showing 2 changed files with 6 additions and 1 deletion.
colossalai/quantization/fp8.py (5 additions, 1 deletion)
@@ -8,6 +8,8 @@
 from packaging.version import Version
 from torch.distributed import ReduceOp
 
+from .fp8_config import dynamic_kernel
+
 SUPPORT_TORCH_COMPILE = Version(torch.__version__) >= Version("2.4.0")
 SCALE_BYTES = 4
 try:
@@ -832,11 +834,13 @@ def backward(ctx: Any, out_grad) -> Any:
     return x_grad.reshape(ctx.x_shape), w_grad, bias_grad
 
 
-@torch.compile(mode="max-autotune-no-cudagraphs", disable=not SUPPORT_TORCH_COMPILE, dynamic=False)
+@torch.compile(mode="max-autotune-no-cudagraphs", disable=not SUPPORT_TORCH_COMPILE, dynamic=dynamic_kernel)
 def _linear_fp8(input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor:
     return _LinearFp8.apply(input, weight, bias)
 
 
 def linear_fp8(input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+    if input.shape[-1] % 16 != 0 or np.prod(input.shape[:-1]) % 16 != 0:
+        return F.linear(input, weight, bias)
     out = _linear_fp8(input, weight, bias)
     return out
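
Note: the fallback exists because the FP8 GEMM underneath _linear_fp8 (ultimately torch._scaled_mm) generally requires every matmul dimension to be a multiple of 16; inputs with other shapes now take the plain F.linear path instead of erroring. A minimal sketch of the dispatch, with illustrative shapes, assuming a CUDA device that supports FP8:

import torch
from colossalai.quantization.fp8 import linear_fp8

weight = torch.randn(64, 64, device="cuda")  # (out_features, in_features), both divisible by 16
bias = torch.randn(64, device="cuda")

x_fast = torch.randn(32, 64, device="cuda")      # 32 % 16 == 0 -> compiled FP8 kernel
x_fallback = torch.randn(30, 64, device="cuda")  # 30 % 16 != 0 -> plain F.linear fallback

y_fast = linear_fp8(x_fast, weight, bias)
y_slow = linear_fp8(x_fallback, weight, bias)

For a 3-D input of shape (batch, seq, hidden), the guard checks hidden % 16 and (batch * seq) % 16, since np.prod(input.shape[:-1]) multiplies all leading dimensions.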
colossalai/quantization/fp8_config.py (new file, 1 addition)
@@ -0,0 +1 @@
+dynamic_kernel: bool = False
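
Note: fp8.py binds the flag with `from .fp8_config import dynamic_kernel`, so its value is captured once, when the @torch.compile decorator runs at module import. To enable dynamic-shape compilation, the config therefore has to be patched before that first import. A hedged sketch (it assumes nothing in the process, including colossalai's package __init__, has imported colossalai.quantization.fp8 yet):

# Flip the compile option before colossalai.quantization.fp8 is imported;
# the decorator reads dynamic_kernel at import time, so later changes are ignored.
import colossalai.quantization.fp8_config as fp8_config

fp8_config.dynamic_kernel = True  # compile _linear_fp8 with dynamic shapes

from colossalai.quantization.fp8 import linear_fp8  # decorator now sees dynamic=True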
