Skip to content

Commit

Permalink
updates for MLP with torchao.float8
Browse files Browse the repository at this point in the history
Signed-off-by: Masaki Kozuki <[email protected]>
  • Loading branch information
crcrpar committed Dec 30, 2024
1 parent db3a6c9 commit d9ed305
Show file tree
Hide file tree
Showing 9 changed files with 257 additions and 5 deletions.
22 changes: 22 additions & 0 deletions thunder/core/jit_ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -667,6 +667,7 @@ def _general_jit_torch_autograd_function_apply_lookaside(obj: Any, *args, **kwar
So far, non-tensor ``ctx`` attributes seem to be folded into a trace.
"""
from thunder.core.baseutils import check, sequencify
from thunder.core.transforms import dce

custom_autograd_function_cls = unwrap(obj)
custom_forward = custom_autograd_function_cls.forward
Expand All @@ -678,6 +679,7 @@ def _general_jit_torch_autograd_function_apply_lookaside(obj: Any, *args, **kwar
)
if trace_of_fwd is INTERPRETER_SIGNALS.EXCEPTION_RAISED:
return trace_of_fwd
trace_of_fwd = dce(trace_of_fwd)

# Forward.
unwrapped_custom_forward_args = tree_map(lambda a: unwrap(a), args)
Expand All @@ -691,6 +693,7 @@ def _general_jit_torch_autograd_function_apply_lookaside(obj: Any, *args, **kwar
for a in filter(lambda a: isinstance(a, Proxy), trace_of_fwd.args)
]
trace_of_fwd.bound_symbols = unpack_bsyms + trace_of_fwd.bound_symbols
trace_of_fwd = dce(trace_of_fwd)

@wraps(trace_of_fwd.python_callable())
def core_of_forward(*args, **kwargs):
Expand Down Expand Up @@ -737,6 +740,7 @@ def core_of_forward(*args, **kwargs):
for a in filter(lambda a: isinstance(a, Proxy), trace_of_backward.args)
]
trace_of_backward.bound_symbols = bwd_unpack_bsyms + trace_of_backward.bound_symbols
trace_of_backward = dce(trace_of_backward)

bwd_trace_impl = TraceCtx()
bwd_trace_impl.bound_symbols.extend(trace_of_backward.bound_symbols)
Expand Down Expand Up @@ -770,6 +774,24 @@ def grad_transform(*args, **kwargs):
execution_transform=core_of_forward,
grad_transform=grad_transform,
)

added_bsym: BoundSymbol = get_jit_ctx().computation_trace.scopes[-1][-1]
import_ctx, call_ctx, object_ctx = {}, {}, {}
for bsym in trace_of_fwd.bound_symbols:
cur_import_ctx, cur_call_ctx, cur_object_ctx = bsym.gather_ctxs()
import_ctx.update(cur_import_ctx)
call_ctx.update(cur_call_ctx)
object_ctx.update(cur_object_ctx)

if import_ctx:
added_bsym._import_ctx.update(import_ctx)
if call_ctx:
if added_bsym._call_ctx is not None:
added_bsym._call_ctx.update(call_ctx)
else:
added_bsym._call_ctx = call_ctx
if object_ctx:
added_bsym._object_ctx.update(object_ctx)
return forward_result


Expand Down
20 changes: 20 additions & 0 deletions thunder/core/proxies.py
Original file line number Diff line number Diff line change
Expand Up @@ -1502,6 +1502,26 @@ def distparallel_type(self):
def thunder_fsdp_padding_size(self):
return self._thunder_fsdp_padding_size

# n.b.(crcrpar): just returning contiguous for `_make_wrapper_subclasses`
def stride(self) -> Sequence[int]:
shape = self.shape
if len(shape) == 1:
return (1,)
elif len(shape) == 0:
return tuple()
else:
import numpy

_stride = reversed(numpy.cumprod([1] + list(shape[1:])).tolist())
return tuple(_stride)

def storage_offset(self) -> int:
return -1

@property
def layout(self) -> torch.layout:
return torch.strided

# We need to implement `__len__` as
# > In addition to bypassing any instance attributes in the
# > interest of correctness, implicit special method lookup
Expand Down
2 changes: 2 additions & 0 deletions thunder/core/pytree.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from enum import Enum
from functools import partial
from types import FunctionType
import dataclasses
Expand Down Expand Up @@ -64,6 +65,7 @@ def tree_flatten(args, namespace=OPTREE_NAMESPACE):
and not is_likely_from_collections_namedtuple(args)
and not dataclasses.is_dataclass(args)
and not type(args).__module__.startswith("torch.return_types")
and not issubclass(type(args), Enum)
):
raise TypeError(f"tree_flatten of type {type(args)} is not supported.")
return optree.tree_flatten(args, none_is_leaf=True, namespace=namespace)
Expand Down
7 changes: 5 additions & 2 deletions thunder/core/trace_interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,8 +128,11 @@ def add_to_swap_map(old, new):
old = old.replace(shape=new._shape)

if isinstance(new, VJPDual):
swap_map[variableify(new.primal)] = old
new.primal = old
# note(crcrpar): Without this sanity check, `subclass.__tensor_flatten__`,
# seems to cause `new.primal` == `old`, leading to a cycle in swapping.
if (key := variableify(new.primal)) != variableify(old):
swap_map[variableify(new.primal)] = old
new.primal = old
else:
assert isinstance(new, ProxyInterface), (old, new)
swap_map[variableify(new)] = old
Expand Down
2 changes: 1 addition & 1 deletion thunder/core/transform_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ def dce(trace: Trace, needed_proxies: None | set[Variable] = None) -> Trace:
# may mark some of the operation's outputs as unused
some_unused = False
for out in bsym.flat_proxy_outs:
if variableify(out) in needed_proxies and producer_map[out] == bsym:
if variableify(out) in needed_proxies and producer_map.get(out, None) == bsym:
needed = True
else:
some_unused = True
Expand Down
6 changes: 6 additions & 0 deletions thunder/executors/torch_compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,12 @@ def _to_torch(*args, **kwargs) -> Any:
if torch_op is None:
raise RuntimeError("op not found for {bsym.sym.name}")

# NOTE(crcrpar): Currently `ltorch.t` is mapped to `torchex.transpose`
# thus `args` needs to be updated to have dim0 and dim1
if bsym.sym.id == "torch.t":
utils.check(len(args) == 1, lambda: f"{bsym.sym.id} takes only one argument but {args=}")
args = args + (0, 1)

return torch_op(*args, **kwargs)

return _to_torch
Expand Down
31 changes: 31 additions & 0 deletions thunder/executors/torchex.py
Original file line number Diff line number Diff line change
Expand Up @@ -1403,13 +1403,44 @@ def _copy_with_setitem_impl(a, key, value):
#

matmul = _register_torch_operation("matmul")
_scaled_mm = _register_torch_operation("_scaled_mm")
outer = _register_torch_operation("outer")

_register_implementation(prims.matmul, matmul, checker=_always_executable)

_register_implementation(ltorch.matmul, matmul, checker=_always_executable)
_register_implementation(ltorch.outer, outer, checker=_always_executable)


def _scaled_mm_transform(
a: TensorLike,
b: TensorLike,
scale_a: TensorLike,
scale_b: TensorLike,
bias: TensorLike | None = None,
scale_result: TensorLike | None = None,
out_dtype: dtypeLike | None = None,
use_fast_accum: bool = False,
):

def is_column_major(mat: TensorLike) -> bool:
return mat.stride()[0] == 1 and mat.stride()[0] > 1

result_dtype: torch.dtype = to_torch_dtype(a.dtype if out_dtype is None else out_dtype)
if not is_column_major(b):
b = b.t().contiguous().t()

return _scaled_mm(a, b, scale_a, scale_b, bias, scale_result, result_dtype, use_fast_accum)


_register_implementation(
ltorch._scaled_mm, _scaled_mm, checker=_always_executable, execution_transform=_scaled_mm_transform
)
_register_implementation(
ltorch.core_aten_scaled_mm, _scaled_mm, checker=_always_executable, execution_transform=_scaled_mm_transform
)


#
# Normalization operations
#
Expand Down
81 changes: 80 additions & 1 deletion thunder/tests/test_tensor_subclass.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,29 @@
from __future__ import annotations
from typing import TYPE_CHECKING

from lightning_utilities.core.imports import package_available
import pytest
import torch
import torch.nn as nn
from torch.utils import _pytree as pytree

import thunder
from thunder.tests.framework import instantiate
from thunder.dynamo.compiler import ThunderCompiler
from thunder.tests.framework import (
DynamoThunderExecutor,
TorchExecutor,
instantiate,
nvFuserExecutor,
)
from thunder.tests.make_tensor import make_tensor

if TYPE_CHECKING:
from typing import Any


TORCHAO_AVAILABLE = package_available("torchao")


@torch._dynamo.allow_in_graph
class EncapsulateXandScale(torch.autograd.Function):
@staticmethod
Expand Down Expand Up @@ -232,3 +243,71 @@ def g(x: ScaleTensorSubclass, data: torch.Tensor, scale: torch.Tensor) -> torch.
torch.testing.assert_close(expected, actual)
if requires_grad:
actual.mean().backward()


@instantiate(
dtypes=(thunder.core.dtypes.float32, thunder.core.dtypes.bfloat16),
devicetypes=(thunder.core.devices.DeviceType.CUDA,),
executors=(TorchExecutor, nvFuserExecutor, DynamoThunderExecutor),
decorators=(
pytest.mark.skipif(
not (TORCHAO_AVAILABLE and torch.cuda.get_device_capability() >= (8, 9)),
reason="Requires capability >= 8.9 and torchao",
),
pytest.mark.parametrize("bias", (True, False)),
),
)
def test_torchao_float8_linear(executor, device, dtype, bias):
from torchao.float8 import convert_to_float8_training

batch_size, in_features, out_features = 16, 32, 64
device = torch.device("cuda")
torch_dtype = thunder.core.dtypes.to_torch_dtype(dtype)

model = nn.Sequential(
nn.Linear(in_features, out_features, bias=bias),
nn.GELU(approximate="tanh"),
nn.Linear(out_features, out_features, bias=bias),
).to(device=device, dtype=torch_dtype)
fp8_model = convert_to_float8_training(model)
x = make_tensor((batch_size, in_features), device=device, dtype=torch_dtype)

expected: torch.Tensor
jitted: nn.Module
backend: ThunderCompiler | None = None

if is_thunderfx := executor == DynamoThunderExecutor:
torch._dynamo.reset()
expected = torch.compile(fp8_model)(x)
backend = ThunderCompiler()
jitted = torch.compile(fp8_model, backend=backend)
else:
expected = fp8_model(x)
jitted = executor.make_callable(fp8_model)

if bias and dtype == thunder.core.dtypes.bfloat16 and executor == nvFuserExecutor:
with pytest.raises(
RuntimeError, match="Failed to compute the min-cut on the graph due to a path with infinite capacity"
):
jitted(x)
return
actual = jitted(x)
if bias and dtype == thunder.core.dtypes.bfloat16 and executor == DynamoThunderExecutor:
with pytest.raises(AssertionError, match="Tensor-likes are not close"):
torch.testing.assert_close(actual, expected)
return

if (dtype == thunder.core.dtypes.bfloat16 and executor != DynamoThunderExecutor) or (
not bias and dtype == thunder.core.dtypes.bfloat16 and executor == DynamoThunderExecutor
):
pytest.xfail("numerical error")
torch.testing.assert_close(actual, expected)

# TODO(crcrpar): Think of how to push tensor subclasses to `thunder.jit`.
# Currently no subgraphs go to thunder.jit.
if is_thunderfx:
for subgraph in backend.subgraph_infos:
if not bias and dtype == thunder.core.dtypes.bfloat16:
assert not subgraph.thunder_compiled_fns
else:
assert subgraph.thunder_compiled_fns
91 changes: 90 additions & 1 deletion thunder/torch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1406,7 +1406,9 @@ def t(a: TensorLike, /) -> TensorLike:
lambda: f"t() expects a tensor with <= 2 dimensions, but self is {a.ndim}D",
RuntimeError,
)
return prims.transpose(a, (1, 0)) if a.ndim == 2 else a
if a.ndim != 2:
return a
return transpose(a, 0, 1)


@torchsymbol(torch.ops.aten.t.default, id="torch.ops.aten.t.default")
Expand Down Expand Up @@ -1480,6 +1482,17 @@ def core_aten_transpose(a: TensorProxy, dim0: int, dim1: int) -> TensorProxy:
return _transpose_impl(a, dim0, dim1)


def _transpose_grad(a: TensorLike, /, dim0: int, dim1: int) -> TensorLike:
fwd = transpose(a, dim0, dim1)
g = get_grad(fwd)
a_grad = transpose(g, dim0, dim1)
put_grad(a, a_grad)
return fwd


register_grad(transpose, _transpose_grad)


@torchsymbol(torch.unbind, is_method=True)
def unbind(a: TensorLike, /, dim: int = 0) -> tuple[TensorLike, ...]:
utils.check(
Expand Down Expand Up @@ -4282,6 +4295,82 @@ def outer(a: TensorLike, b: TensorLike, /) -> TensorLike:
return a[:, None] * b[None, :]


# TODO(crcrpar): Add nvfuser support of `matmul(a.float() * scale_a, b.float() * scale_b) + bias`
# So far I haven't managed to get a nice result from nvfuser region as I left
# https://github.com/Lightning-AI/lightning-thunder/pull/1415/files#r1892875183
# reference: https://github.com/pytorch/pytorch/blob/6d4cd3e/torch/_meta_registrations.py#L5566
def _scaled_mm_impl(
a: TensorLike,
b: TensorLike,
scale_a: TensorLike,
scale_b: TensorLike,
bias: TensorLike | None = None,
scale_result: TensorLike | None = None,
out_dtype: dtypeLike | None = None,
use_fast_accum: bool = False,
) -> TensorLike:
fp8_dtypes = {dtypes.float8_e4m3fn, dtypes.float8_e4m3fnuz, dtypes.float8_e5m2, dtypes.float8_e5m2fnuz}
# TODO(crcrpar): Devise a way to make sure `a` is row-major and `b` is column-major.
utils.check(
(
(a.ndim == 2 and b.ndim == 2)
and (a.shape[1] == b.shape[0])
and (a.shape[1] % 16 == 0 and b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0)
and (to_dtype(a.dtype) in fp8_dtypes and to_dtype(b.dtype) in fp8_dtypes)
and not (a.dtype == dtypes.float8_e5m2 and b.dtype == dtypes.float8_e5m2)
and to_device(a.device).type == "cuda"
),
lambda: f"data matrices of {a=} and {b=} do not satisfy the condition.",
)
args = [a, b, scale_a, scale_b]
if bias is not None:
args.append(bias)
utils.check_same_device(args)
utils.check(
(
(scale_a.numel() == 1 and scale_b.numel() == 1)
and (scale_a.dtype == dtypes.float32 and scale_b.dtype == dtypes.float32)
),
lambda: f"Only tensor-wise scaling is supported but {scaled_a.shape = } and {scaled_b.shape = }",
exception_type=NotImplementedError,
)
result_dtype = a.dtype if out_dtype is None else to_dtype(out_dtype)
return TensorProxy(
like=a,
shape=(a.shape[0], b.shape[1]),
device=a.device,
dtype=result_dtype,
)


@torchsymbol(torch._scaled_mm)
def _scaled_mm(
a: TensorLike,
b: TensorLike,
scale_a: TensorLike,
scale_b: TensorLike,
bias: TensorLike | None = None,
scale_result: TensorLike | None = None,
out_dtype: dtypeLike | None = None,
use_fast_accum: bool = False,
) -> TensorLike:
return _scaled_mm_impl(a, b, scale_a, scale_b, bias, scale_result, out_dtype, use_fast_accum)


@torchsymbol(torch.ops.aten._scaled_mm.default, id="torch.ops.aten._scaled_mm")
def core_aten_scaled_mm(
a: TensorLike,
b: TensorLike,
scale_a: TensorLike,
scale_b: TensorLike,
bias: TensorLike | None = None,
scale_result: TensorLike | None = None,
out_dtype: dtypeLike | None = None,
use_fast_accum: bool = False,
) -> TensorLike:
return _scaled_mm_impl(a, b, scale_a, scale_b, bias, scale_result, out_dtype, use_fast_accum)


#
# Normalization operations
#
Expand Down

0 comments on commit d9ed305

Please sign in to comment.