From 5ddad486cab3ef067d6ae0ab87475d52f34dc27f Mon Sep 17 00:00:00 2001
From: Hongxin Liu
Date: Fri, 18 Oct 2024 13:55:31 +0800
Subject: [PATCH] [fp8] add fallback and make compile option configurable
 (#6092)

---
 colossalai/quantization/fp8.py        | 6 +++++-
 colossalai/quantization/fp8_config.py | 1 +
 2 files changed, 6 insertions(+), 1 deletion(-)
 create mode 100644 colossalai/quantization/fp8_config.py

diff --git a/colossalai/quantization/fp8.py b/colossalai/quantization/fp8.py
index 8243a29ac825..e23da5cccd4d 100644
--- a/colossalai/quantization/fp8.py
+++ b/colossalai/quantization/fp8.py
@@ -8,6 +8,8 @@
 from packaging.version import Version
 from torch.distributed import ReduceOp
 
+from .fp8_config import dynamic_kernel
+
 SUPPORT_TORCH_COMPILE = Version(torch.__version__) >= Version("2.4.0")
 SCALE_BYTES = 4
 try:
@@ -832,11 +834,13 @@ def backward(ctx: Any, out_grad) -> Any:
         return x_grad.reshape(ctx.x_shape), w_grad, bias_grad
 
 
-@torch.compile(mode="max-autotune-no-cudagraphs", disable=not SUPPORT_TORCH_COMPILE, dynamic=False)
+@torch.compile(mode="max-autotune-no-cudagraphs", disable=not SUPPORT_TORCH_COMPILE, dynamic=dynamic_kernel)
 def _linear_fp8(input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor:
     return _LinearFp8.apply(input, weight, bias)
 
 
 def linear_fp8(input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+    if input.shape[-1] % 16 != 0 or np.prod(input.shape[:-1]) % 16 != 0:
+        return F.linear(input, weight, bias)
     out = _linear_fp8(input, weight, bias)
     return out
diff --git a/colossalai/quantization/fp8_config.py b/colossalai/quantization/fp8_config.py
new file mode 100644
index 000000000000..efa6251856aa
--- /dev/null
+++ b/colossalai/quantization/fp8_config.py
@@ -0,0 +1 @@
+dynamic_kernel: bool = False
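
Note on the fallback rule (illustrative sketch, not part of the patch): the guard
added to linear_fp8 skips the FP8 path whenever the feature dimension or the
flattened leading dimensions are not multiples of 16, the shape constraint FP8
GEMM kernels typically impose. The sketch below renders that rule as a
self-contained snippet; the name linear_fp8_with_fallback and the demo tensors
are made up here, and the FP8-eligible branch reuses F.linear as a stand-in for
the torch.compile'd _linear_fp8 so the code runs without FP8-capable hardware.

    import numpy as np
    import torch
    import torch.nn.functional as F

    def linear_fp8_with_fallback(input: torch.Tensor, weight: torch.Tensor, bias=None) -> torch.Tensor:
        # Same guard as the patched linear_fp8: shapes that are not multiples
        # of 16 in either matmul dimension fall back to the regular linear op.
        if input.shape[-1] % 16 != 0 or np.prod(input.shape[:-1]) % 16 != 0:
            return F.linear(input, weight, bias)
        # FP8-eligible shape: ColossalAI would call _linear_fp8 here; F.linear
        # is used as a placeholder so this sketch stays runnable everywhere.
        return F.linear(input, weight, bias)

    x_small, w_small = torch.randn(3, 10), torch.randn(8, 10)  # 3 and 10 -> fallback
    x_ok, w_ok = torch.randn(16, 32), torch.randn(8, 32)       # 16 and 32 -> FP8-eligible
    print(linear_fp8_with_fallback(x_small, w_small).shape)    # torch.Size([3, 8])
    print(linear_fp8_with_fallback(x_ok, w_ok).shape)          # torch.Size([16, 8])

Separately, because dynamic_kernel is consumed as an argument to the
@torch.compile decorator, it is evaluated when colossalai/quantization/fp8.py is
first imported; a caller wanting dynamic shapes would therefore presumably set
colossalai.quantization.fp8_config.dynamic_kernel = True before that module is
imported.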