Skip to content

Commit

Permalink
refactor: autocast.py
Browse files Browse the repository at this point in the history
  • Loading branch information
rittik9 committed Dec 9, 2024
1 parent 58b0c69 commit f88ef7a
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 37 deletions.
21 changes: 2 additions & 19 deletions thunder/tests/test_autocast.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ def test_thunder_autocast_transform(executor, device, dtype):
def f(a, b, c):
return a @ (b + c)

# The following functions needs to be updated as autocast_impls grows.
def g(a, b, c):
return a + b - c

Expand Down Expand Up @@ -58,6 +59,7 @@ def h(a, b, c):
out = compiled(x, y, z)

devicetype = torch.device(device).type
# note(crcrpar): This test could be broken in the future as thunder autocast develops.
with torch.autocast(device_type=devicetype, dtype=autocast_torch_dtype):
torch_output = func(x, y, z)
assert out.dtype == torch_output.dtype
Expand Down Expand Up @@ -309,22 +311,3 @@ def foo(a, b, c, d):

for eg, jg in zip(eager_grads, jit_grads):
torch.testing.assert_close(eg, jg, rtol=5e-3, atol=5e-3)


# def simple_addition(x, y):
# return x + y


# def test_autocast_transform():
# autocast_transform = AutocastTransform(dtype=torch.bfloat16)
# jitted_fn = jit(simple_addition, transforms=[autocast_transform])

# x = torch.randn(2, 2, dtype=torch.float32)
# y = torch.randn(2, 2, dtype=torch.float32)

# result = jitted_fn(x, y)

# assert result.dtype == torch.bfloat16, f"Expected dtype: bfloat16, but got: {result.dtype}"

# expected_result = simple_addition(x, y).to(torch.bfloat16)
# assert torch.allclose(result, expected_result), "The output values do not match the expected results."
21 changes: 3 additions & 18 deletions thunder/transforms/autocast.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,55 +347,40 @@ def __init__(self, trace, dtype, *args, **kwargs):
self.dtype = dtype

def process_bsym(self, bsym):
# Skip special symbols that shouldn't be processed
if bsym.sym.id in trace_interpreter_skip_list:
self.new_trace.bound_symbols.append(bsym.from_bsym())
return

# Check if symbol has an autocast implementation
autocast_impl = _maybe_get_autocast_rule_for_symbol(bsym.sym)

if autocast_impl is not None:
# Read the arguments with potential autocast conversion
args = tree_map(self.read, bsym.args)
kwargs = tree_map(self.read, bsym.kwargs)

# Apply the autocast implementation
with disable_autocast():
result = autocast_impl(*args, **kwargs, dtype=self.dtype)

self.set_result(result)
else:
# No autocast rule, process normally
args = tree_map(self.read, bsym.args)
kwargs = tree_map(self.read, bsym.kwargs)
result = bsym.sym(*args, **kwargs)
self.set_result(result)

# Add the bound symbol to new trace
new_bsym = bsym.from_bsym()
new_bsym.args = args
new_bsym.kwargs = kwargs
self.add_processed_bsyms([new_bsym])

# Process the computation trace
if computation_trace is not None:
processor = AutocastProcessor(computation_trace, self.dtype)

# Get the actual args and kwargs from the kwargs dict
args = kwargs.get("args", ())
kw = kwargs.get("kwargs", {})

with tracectx(processor.new_trace):
# Initialize the processor's environment with input arguments
for trace_arg, arg in zip(computation_trace.args, args):
processor.env[trace_arg.name] = arg
processor.process_args(*args, **kw)

# Initialize kwargs if any
for trace_kwarg, kwarg in zip(computation_trace.kwargs.values(), kw.values()):
processor.env[trace_kwarg.name] = kwarg

new_trace, _ = processor()
computation_trace = new_trace
new_trace, outputs = processor()
computation_trace = new_trace

return prologue_trace, computation_trace, epilogue_trace

0 comments on commit f88ef7a

Please sign in to comment.