Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rename AdHocExecutor -> TransientOperatorExecutor and add documentation #1573

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
2 changes: 1 addition & 1 deletion thunder/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,7 +333,7 @@ def jit(

# Resolve names of executors
executors = resolve_executors(executors)
ad_hoc_executor = extend.AdHocExecutor()
ad_hoc_executor = extend.TemporaryExecutor()
executors = (ad_hoc_executor, *executors)

# TODO: verify that tutorials don't have false positives and enable warning by default
Expand Down
4 changes: 2 additions & 2 deletions thunder/core/symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -619,12 +619,12 @@ def import_ctx(self):
# NOTE If the call ctx was specified directly, then no import is needed to call the function
import_ctx = {}
else:
from thunder.extend import AdHocExecutor
from thunder.extend import TemporaryExecutor

# BoundSymbols of Symbols without Python implementations (either because they
# have Python implementations or defined call ctxs) are assumed to need
# a module import to run properly
if isinstance(self.sym.executor, AdHocExecutor):
if isinstance(self.sym.executor, TemporaryExecutor):
import_ctx = {}
else:
assert self.sym.module is not None # TODO: Is this a valid assumption?
Expand Down
30 changes: 27 additions & 3 deletions thunder/extend/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,8 +273,32 @@ def register_implementation(
self.implmap[_id] = impl


class AdHocExecutor(OperatorExecutor):
"""An "Anonymous" executor to be used for temporary registrations"""
class TemporaryExecutor(OperatorExecutor):
"""
A specialized executor for managing temporary operator registrations at runtime.

This executor generates unique identifiers for each registered operator by combining
the operator name with instance-specific identifiers. It's designed for scenarios
requiring dynamic operator registration without conflicting with existing operations.

Key Features:
- Creates unique operator names using instance ID and counter
- Supports runtime registration of operators
- Handles opaque function registration
- Maintains isolation between different temporary registrations

Example:
>>> executor = TemporaryExecutor()
>>> op = executor.register_operator(
... name="temp_add",
... like=thunder.torch.add,
... fn=lambda x, y: x + y
... )

Note:
Operators registered through this executor are intended for temporary use
and should not be relied upon for permanent implementations.
"""

def __init__(self):
super().__init__(f"__ad_hoc_executor_{id(self)}")
Expand Down Expand Up @@ -349,7 +373,7 @@ def meta(*args, **kwargs):
return symbol

def __repr__(self) -> str:
return f"<thunder.extend.AdHocExecutor object {id(self)}>"
return f"<thunder.extend.TemporaryExecutor object {id(self)}>"


def single_op_executor(
Expand Down
22 changes: 22 additions & 0 deletions thunder/tests/test_extend.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,3 +262,25 @@ def test_validate_executors():
assert thunder.resolve_executors(("python", pytorch_executor)) == (pythonex, pytorch_executor)
with pytest.raises(ValueError, match=re.compile("Expected an Executor or the name of a registered Executor")):
assert thunder.resolve_executors(("python", "foo", pytorch_executor, "bar"))


def test_transient_operator_executor():
from thunder.extend import TemporaryExecutor
from functools import partial

executor = TemporaryExecutor()
op = executor.register_operator(
name="temp_add",
like=thunder.torch.add,
fn=lambda a, b: a + b,
)

@partial(thunder.jit, executors=[executor])
def f(a, b):
return op(a, b)

a = torch.randn(4, 4, device="cpu")

assert_close(f(a, a), a + a)
assert thunder.last_traces(f)[-1].bound_symbols[2].sym.executor is executor
assert thunder.last_traces(f)[-1].bound_symbols[2].sym.name.startswith("temp_add")
Loading