
Commit

move to dev_utils as that seems more suitable location
kshitij12345 committed Jan 22, 2025
1 parent b2435c9 commit ddc2603
Showing 2 changed files with 79 additions and 65 deletions.
79 changes: 79 additions & 0 deletions thunder/dev_utils/utils.py
@@ -1,4 +1,9 @@
import collections

from torch.utils.benchmark import Timer
from thunder.core.prims import PrimIDs
from thunder.core.symbol import BoundSymbol
from thunder.executors.torch_compile import make_compiled as make_torch_compile_callable

NON_COMPUTATION_PRIMS = (
PrimIDs.ASSERT_TENSOR_METADATA,
@@ -36,3 +41,77 @@
PrimIDs.PRINT,
PrimIDs.RETURN,
)

BenchmarkComparisonData = collections.namedtuple(
"BenchmarkComparisonData",
[
"nvfuser_walltime",
"torch_compile_walltime",
"nvfuser_kernel_time",
"torch_compile_kernel_time",
"nvfuser_profiler_data",
],
)


def _benchmark_fusion_region_with_nvfuser_and_torch_compile(bsym: BoundSymbol) -> BenchmarkComparisonData:
"""
Benchmark the performance of nvFuser and torch.compile for a given fusion region.
This function takes a BoundSymbol generated from nvFuser and performs the following:
1. Executes the fusion region using both nvFuser and torch.compile.
2. Measures wall time and kernel time for both implementations.
3. Collects profiling data for the nvFuser implementation.
Args:
bsym (BoundSymbol): A BoundSymbol generated from nvFuser.
Returns:
BenchmarkComparisonData: A named tuple containing:
- nvfuser_walltime: Wall time for nvFuser execution using `torch.utils.benchmark.Timer`.
- torch_compile_walltime: Wall time for torch.compile execution using `torch.utils.benchmark.Timer`.
- nvfuser_kernel_time: Kernel time for nvFuser execution using `triton.testing.do_bench`.
- torch_compile_kernel_time: Kernel time for torch.compile execution using `triton.testing.do_bench`.
- nvfuser_profiler_data: Profiling data for the nvFuser implementation by calling `fusion_definition.profile`.
.. note:: The function assumes that the fusion has been previously executed and inputs are recorded.
"""
assert "nvFusion" in bsym.sym.name, "Expected the BoundSymbol to be generated from nvFuser"
import triton  # Import triton here as it may not be available in a CPU-only setting.

nvfuser_callable = bsym._call_ctx[bsym.sym.name]
inputs = nvfuser_callable.last_inputs
if nvfuser_callable.last_used is None:
raise RuntimeError(
"Fusion definition needs to be executed to record the inputs. You must execute the fusion first before you can query the repro."
)

if nvfuser_callable.last_inputs is None:
raise RuntimeError(
"Fusion definition inputs need to be recorded. Use compile option 'nv_store_fusion_inputs=True' while tracing."
)

torch_compile_callable = make_torch_compile_callable(bsym.subsymbols, bsym.flat_args, bsym.flat_outs)

nvfuser_callable(*inputs)
torch_compile_callable(*inputs)

nvfuser_timer = Timer("nvfuser_callable(*inputs)", globals={"nvfuser_callable": nvfuser_callable, "inputs": inputs})
tc_timer = Timer(
"torch_compile_callable(*inputs)", globals={"torch_compile_callable": torch_compile_callable, "inputs": inputs}
)

# Wall times
wall_time_nvfuser = nvfuser_timer.blocked_autorange(min_run_time=2)
wall_time_tc = tc_timer.blocked_autorange(min_run_time=2)

# Kernel Times
kernel_time_nvfuser = triton.testing.do_bench(lambda: nvfuser_callable(*inputs), return_mode="median")
kernel_time_tc = triton.testing.do_bench(lambda: torch_compile_callable(*inputs), return_mode="median")

# nvFuser's profiling utility.
fd = nvfuser_callable.get_fd(nvfuser_callable.to_descriptors(inputs))
fd.execute(inputs, profile=True)
nvfuser_prof = fd.profile()

return BenchmarkComparisonData(wall_time_nvfuser, wall_time_tc, kernel_time_nvfuser, kernel_time_tc, nvfuser_prof)
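For context, a minimal usage sketch of the relocated helper. This is illustrative only and not part of the commit: the model, input, and trace-lookup code below are assumptions, and the exact way the compile option is passed may differ; only the 'nv_store_fusion_inputs=True' requirement and the nvFusion name check come from the function above.

import torch
import thunder
from thunder.dev_utils.utils import _benchmark_fusion_region_with_nvfuser_and_torch_compile

# Hypothetical model and input; any CUDA workload that produces an nvFuser fusion region works.
model = torch.nn.Sequential(torch.nn.Linear(1024, 1024), torch.nn.GELU()).cuda()
x = torch.randn(16, 1024, device="cuda")

# 'nv_store_fusion_inputs=True' records the fusion inputs, which the helper requires.
jitted = thunder.jit(model, nv_store_fusion_inputs=True)
jitted(x)  # run once so the fusion executes and its inputs are captured

# Pick an nvFusion bound symbol from the execution trace and benchmark it.
exec_trace = thunder.last_traces(jitted)[-1]
nvfusion_bsym = next(b for b in exec_trace.bound_symbols if "nvFusion" in b.sym.name)
data = _benchmark_fusion_region_with_nvfuser_and_torch_compile(nvfusion_bsym)
print(data.nvfuser_walltime, data.torch_compile_walltime)
print(data.nvfuser_kernel_time, data.torch_compile_kernel_time)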
65 changes: 0 additions & 65 deletions thunder/examine/__init__.py
@@ -11,10 +11,8 @@
from thunder.torch import _torch_to_thunder_function_map
from thunder.torch.default_torch_ops import torch_auto_registered_ops
from thunder.core.langctxs import resolve_language, LanguageContext, Languages
from thunder.executors.torch_compile import make_compiled as make_torch_compile_callable

import torch
from torch.utils.benchmark import Timer
from warnings import warn
from itertools import chain
import importlib
@@ -415,66 +413,3 @@ def _get_color(node_id):
"nvfuser_profiler_data",
],
)


def _benchmark_fusion_region_with_nvfuser_and_torch_compile(bsym: BoundSymbol) -> BenchmarkComparisonData:
"""
Benchmark the performance of nvFuser and torch.compile for a given fusion region.
This function takes a BoundSymbol generated from nvFuser and performs the following:
1. Executes the fusion region using both nvFuser and torch.compile.
2. Measures wall time and kernel time for both implementations.
3. Collects profiling data for the nvFuser implementation.
Args:
bsym (BoundSymbol): A BoundSymbol generated from nvFuser.
Returns:
BenchmarkComparisonData: A named tuple containing:
- nvfuser_walltime: Wall time for nvFuser execution using `torch.utils.benchmark.Timer`.
- torch_compile_walltime: Wall time for torch.compile execution using `torch.utils.benchmark.Timer`.
- nvfuser_kernel_time: Kernel time for nvFuser execution using `triton.testing.do_bench`.
- torch_compile_kernel_time: Kernel time for torch.compile execution using `triton.testing.do_bench`.
- nvfuser_profiler_data: Profiling data for the nvFuser implementation by calling `fusion_defition.profile`.
.. note:: The function assumes that the fusion has been previously executed and inputs are recorded.
"""
assert "nvFusion" in bsym.sym.name, "Expected the BoundSymbol to be generated from nvFuser"
import triton # Import triton here as it may not be available in CPU only setting.

nvfuser_callable = bsym._call_ctx[bsym.sym.name]
inputs = nvfuser_callable.last_inputs
if nvfuser_callable.last_used is None:
raise RuntimeError(
"Fusion definition needs to be executed to record the inputs. You must execute the fusion first before you can query the repro."
)

if nvfuser_callable.last_inputs is None:
raise RuntimeError(
"Fusion definition inputs need to be recorded. Use compile option 'nv_store_fusion_inputs=True' while tracing."
)

torch_compile_callable = make_torch_compile_callable(bsym.subsymbols, bsym.flat_args, bsym.flat_outs)

nvfuser_callable(*inputs)
torch_compile_callable(*inputs)

nvfuser_timer = Timer("nvfuser_callable(*inputs)", globals={"nvfuser_callable": nvfuser_callable, "inputs": inputs})
tc_timer = Timer(
"torch_compile_callable(*inputs)", globals={"torch_compile_callable": torch_compile_callable, "inputs": inputs}
)

# Wall times
wall_time_nvfuser = nvfuser_timer.blocked_autorange(min_run_time=2)
wall_time_tc = tc_timer.blocked_autorange(min_run_time=2)

# Kernel Times
kernel_time_nvfuser = triton.testing.do_bench(lambda: nvfuser_callable(*inputs), return_mode="median")
kernel_time_tc = triton.testing.do_bench(lambda: torch_compile_callable(*inputs), return_mode="median")

# nvFuser's profiling utility.
fd = nvfuser_callable.get_fd(nvfuser_callable.to_descriptors(inputs))
fd.execute(inputs, profile=True)
nvfuser_prof = fd.profile()

return BenchmarkComparisonData(wall_time_nvfuser, wall_time_tc, kernel_time_nvfuser, kernel_time_tc, nvfuser_prof)
