Remove interpreter-based jvp transform (#1117)
IvanYashchuk authored Sep 7, 2024
1 parent d14374f commit f7bdd5a
Showing 4 changed files with 4 additions and 408 deletions.
228 changes: 0 additions & 228 deletions thunder/core/transforms.py
@@ -1431,7 +1431,6 @@ def python_callable(*args, **kwargs):

class Transforms(Enum):
    VmapOp = auto()
    JvpOp = auto()
    VjpOp = auto()


@@ -1524,233 +1523,6 @@ def unwrap_one_level_of_subsymbols(trace):
    return trace


# JVP transform
# -------------


@dataclass(frozen=True)
class JVPDual:
"""Dual number for the JVP transform.
Attributes:
primal: Primal value.
tangent: Tangent value.
"""

primal: Any
tangent: Any

def __iter__(self):
yield self.primal
yield self.tangent
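
# JVPDual mirrors the dual numbers of forward-mode AD: for f at primal x with
# tangent xd, the transform carries the pair (f(x), f'(x) * xd), i.e. the
# truncated expansion f(x + eps * xd) = f(x) + eps * f'(x) * xd. A minimal
# sketch with hypothetical scalar values:
#
#   d = JVPDual(primal=2.0, tangent=1.0)
#   x, xd = d  # __iter__ lets a dual unpack like a pair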


def pair_to_jvp_dual(pair):
"""Converts a pair to a JVPDual.
Args:
pair (Sequence): Pair to convert.
Returns:
JVPDual: JVPDual representation of the pair.
"""
if isinstance(pair, JVPDual):
return pair
else:
assert isinstance(pair, Sequence) and len(pair) == 2
return JVPDual(*pair)


def sin_jvp(a: JVPDual):
    x, xd = a
    return JVPDual(prims.sin(x), prims.cos(x) * xd)


def mul_jvp(a: JVPDual, b: JVPDual):
    x, xd = a
    y, yd = b
    return JVPDual(x * y, x * yd + y * xd)


def add_jvp(a: JVPDual, b: JVPDual):
    x, xd = a
    y, yd = b
    return JVPDual(x + y, xd + yd)
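
# These rules are the standard forward-mode derivatives: d(sin x) = cos(x) dx
# for sin_jvp, the product rule d(xy) = x dy + y dx for mul_jvp, and linearity
# of addition for add_jvp.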


def broadcast_in_dim_jvp(a: JVPDual, shape: tuple[JVPDual, ...], broadcast_dimensions: tuple[JVPDual, ...]) -> JVPDual:
    x, xd = a
    # TODO: shape and broadcast_dimensions should be tuples of ints
    # but for now it's a tuple of JVPDuals
    if len(shape) > 0 and isinstance(shape[0], JVPDual):
        shape, _ = safe_zip(*shape)
    if len(broadcast_dimensions) > 0 and isinstance(broadcast_dimensions[0], JVPDual):
        broadcast_dimensions, _ = safe_zip(*broadcast_dimensions)
    return JVPDual(
        prims.broadcast_in_dim(x, shape, broadcast_dimensions), prims.broadcast_in_dim(xd, shape, broadcast_dimensions)
    )
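
# broadcast_in_dim is linear in its input, so the tangent is broadcast with
# exactly the same shape and broadcast_dimensions as the primal.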


def unpack_sequence_jvp(sequence: JVPDual, length: JVPDual) -> JVPDual:
    x = tree_map(lambda x: x.primal, sequence)
    xd = tree_map(lambda x: x.tangent, sequence)
    length, _ = length
    primals = prims.unpack_sequence(x, length)
    tangents = prims.unpack_sequence(xd, length)
    return safe_map(pair_to_jvp_dual, safe_zip(primals, tangents))


def unpack_trivial_jvp(x: JVPDual) -> JVPDual:
    return x


jvp_impls: dict[prims.Symbol, Callable] = dict()

jvp_impls[prims.PrimIDs.SIN] = sin_jvp
jvp_impls[prims.PrimIDs.MUL] = mul_jvp
jvp_impls[prims.PrimIDs.ADD] = add_jvp
jvp_impls[prims.PrimIDs.BROADCAST_IN_DIM] = broadcast_in_dim_jvp
# jvp_impls[prims.PrimIDs.UNPACK_SEQUENCE] = unpack_sequence_jvp
# jvp_impls[prims.PrimIDs.UNPACK_TRIVIAL] = unpack_trivial_jvp
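
# Supporting another prim only requires registering a rule in jvp_impls. A
# minimal sketch for a hypothetical cosine entry (cos_jvp was not part of this
# table; it assumes a prims.PrimIDs.COS id exists):
#
#   def cos_jvp(a: JVPDual):
#       x, xd = a
#       return JVPDual(prims.cos(x), -prims.sin(x) * xd)
#
#   jvp_impls[prims.PrimIDs.COS] = cos_jvp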


def jvp_symbol_mapper(symbol: prims.Symbol):
    """Maps a symbol to a JVP function that evaluates it.
    Args:
        symbol (prims.Symbol): Symbol to evaluate.
    Raises:
        NotImplementedError: If the JVP for the symbol is not implemented.
    Returns:
        Callable: JVP function that evaluates the symbol.
    """

    def wrap_arg(x):
        if isinstance(x, JVPDual):
            return x
        elif isinstance(x, Number):
            return JVPDual(x, type(x)(0))
        else:
            raise ValueError(f"JVP wrap_arg got an unsupported type {type(x)}")

    # If symbol.args doesn't have subclasses of Variable, then we need to return a zero tangent
    # TODO: there may be a better way to detect constants in the trace
    if symbol.are_all_args_constant:

        def zeros_like(x):
            if isinstance(x, TensorProxy):
                return full_like(x, fill_value=0)
            elif isinstance(x, NumberProxy):
                return type(x.value)(0)
            elif isinstance(x, Number):
                return type(x)(0)
            else:
                raise ValueError(f"zeros_like inside JVP got an unsupported type {type(x)}")

        def jvp_impl_const(symbol, *args, **kwargs):
            primals = symbol_to_eval(symbol)(*args, **kwargs)
            if isinstance(primals, Sequence):
                tangents = tuple(zeros_like(p) for p in primals)
                return safe_map(pair_to_jvp_dual, safe_zip(primals, tangents))
            return JVPDual(primals, zeros_like(primals))

        return partial(jvp_impl_const, symbol)

    # Normal case, we have a proxy tangent
    jvp_impl = jvp_impls.get(symbol.sym.id)
    if jvp_impl is None:
        raise NotImplementedError(f"JVP for {symbol.sym.id} is not implemented")

    def _jvp_impl(*args, **kwargs):
        args = tree_map(wrap_arg, args)
        # Expecting JVPDuals wrapping pairs of primals and tangents
        assert all(isinstance(arg, JVPDual) for arg in tree_flatten(args)[0])
        return jvp_impl(*args, **kwargs)

    return _jvp_impl


def _jvp_call_metafunc(detached: bool, primals, tangents, *, function_trace: Trace, **kwargs):
    """Metafunction for the JVP transform.
    Args:
        detached (bool): Whether to detach the trace.
        primals (Tuple[Proxy]): Primal values.
        tangents (Tuple[Proxy]): Tangent values.
        function_trace (Trace): Trace of the function to be transformed.
        kwargs: Keyword arguments.
    Raises:
        AssertionError: If the JVP for keyword arguments is not implemented.
    Returns:
        Result of the JVP transform.
    """
    assert len(kwargs) == 0, "JVP for kwargs is not implemented"

    ctx = detached_trace() if detached else nullcontext()
    with ctx:
        # Wrapping the primals and tangents in JVPDuals is not strictly necessary, but it makes
        # the code more readable
        # We propagate the JVPDuals through the trace, and then unwrap them at the end
        primals_tangents_duals = safe_map(pair_to_jvp_dual, safe_zip(primals, tangents))
        result = eval_trace(function_trace, *primals_tangents_duals, symbol_mapper=jvp_symbol_mapper)
        # Unwrapping the JVPDuals
        if isinstance(result, Sequence):
            assert all(isinstance(x, JVPDual) for x in result)
            primals, tangents = unzip2(result)
            return primals, tangents
        assert isinstance(result, JVPDual)
        return result.primal, result.tangent


jvp_call = Symbol(id=Transforms.JvpOp, name="jvp_call", meta=partial(_jvp_call_metafunc, False))
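
# Exposing the transform as a Symbol means applying it can be recorded as a
# single jvp_call node in a trace instead of being evaluated eagerly.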


def jvp(func):
    """Jacobian-vector product transform for a Thunder function.
    Args:
        func (Callable): A Thunder function to be transformed.
    Returns:
        Callable: A function that computes the Jacobian-vector product
        taking primals and tangents as arguments.
    """

    def wrapper(primals, tangents):
        trace = construct_trace()(func, *primals)
        return jvp_call(primals, tangents, function_trace=trace)

    return wrapper
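
# A minimal usage sketch (hypothetical Thunder function and tensor inputs):
#
#   def f(a, b):
#       return a * b
#
#   out_primal, out_tangent = jvp(f)((x, y), (xd, yd))
#
# out_tangent is the directional derivative of f at (x, y) along (xd, yd);
# for f(a, b) = a * b that is x * yd + y * xd.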


# TODO This function is commented out because it calls make_traced, which does not exist
# def jvp_eager(func, primals, tangents, executor="torch"):
#     """Computes the Jacobian-vector product of a Thunder function.

#     Args:
#         func (Callable): A Thunder function to be transformed.
#         primals (_type_): Primals of the function.
#         tangents (_type_): Tangents of the function.
#         executor (str, optional): Executor to use. Defaults to "torch".

#     Returns:
#         The result of the Jacobian-vector product.
#     """
#     trace = make_trace(func, executor=executor, *primals)

#     def jvp_func(*primals_and_tangents):
#         _primals, _tangents = primals_and_tangents[: len(primals)], primals_and_tangents[len(primals) :]
#         return _jvp_call_metafunc(_primals, _tangents, trace, detached=False)

#     jvp_trace = make_trace(jvp_func, executor=executor)(*primals, *tangents)
#     jvp_traced = make_traced(partial(eval_trace, jvp_trace), executor=executor)
#     return jvp_traced(*primals, *tangents)


# VJP transform
# =============
@dataclass(frozen=True)
38 changes: 2 additions & 36 deletions thunder/tests/opinfos.py
@@ -1856,13 +1856,6 @@ def elementwise_binary_generator(op, device, dtype, requires_grad, *, no_rhs_num
    sample_input_generator=elementwise_binary_generator,
    torch_reference=torch.add,
    test_directives=(
        # See issue "broadcast_in_dim: The size of contiguity must equal to the
        # number of non-broadcasting IterDomains"
        DecorateInfo(
            pytest.mark.skip,
            "test_jvp_correctness",
            executors=("nvfuser",),
        ),
        DecorateInfo(
            pytest.mark.skip,
            "test_vjp_correctness",
@@ -2097,13 +2090,6 @@ def fmod_sample_input_generator(op, device, dtype, requires_grad, **kwargs):
    sample_input_generator=elementwise_binary_generator,
    torch_reference=torch.mul,
    test_directives=(
        # See issue "broadcast_in_dim: The size of contiguity must equal to the
        # number of non-broadcasting IterDomains"
        DecorateInfo(
            pytest.mark.skip,
            "test_jvp_correctness",
            executors=("nvfuser",),
        ),
        DecorateInfo(
            pytest.mark.skip,
            "test_vjp_correctness",
@@ -3113,13 +3099,6 @@ def broadcast_in_dim_error_generator(op, device, **kwargs):
            pytest.mark.xfail,
            "test_errors",
        ),
        # See issue "broadcast_in_dim: The size of contiguity must equal to the number of
        # non-broadcasting IterDomains"
        DecorateInfo(
            pytest.mark.skip,
            "test_jvp_correctness",
            executors=("nvfuser",),
        ),
        DecorateInfo(
            pytest.mark.skip,
            "test_vjp_correctness",
@@ -3296,9 +3275,8 @@ def expand_error_generator(op, device, *, dtype=torch.float32, **kwargs):
    error_input_generator=expand_error_generator,
    torch_reference=torch.Tensor.expand,
    test_directives=(
        # vjp and jvp not yet implemented
        # vjp not yet implemented
        DecorateInfo(pytest.mark.xfail, "test_vjp_correctness"),
        DecorateInfo(pytest.mark.xfail, "test_jvp_correctness"),
    ),
)
shape_ops.append(expand_opinfo)
@@ -3345,9 +3323,8 @@ def expand_as_error_generator(op, device, *, dtype=torch.float32, **kwargs):
    error_input_generator=expand_as_error_generator,
    torch_reference=torch.Tensor.expand_as,
    test_directives=(
        # vjp and jvp not yet implemented
        # vjp not yet implemented
        DecorateInfo(pytest.mark.xfail, "test_vjp_correctness"),
        DecorateInfo(pytest.mark.xfail, "test_jvp_correctness"),
    ),
)
shape_ops.append(expand_as_opinfo)
@@ -4376,10 +4353,6 @@ def stack_wrapper(*args, dim):
    sample_input_generator=stack_sample_generator,
    error_input_generator=stack_error_generator,
    torch_reference=lambda *args, dim: torch.stack(args, dim=dim),
    test_directives=(
        # vjp and jvp not yet implemented
        DecorateInfo(pytest.mark.xfail, "test_jvp_correctness"),
    ),
)
shape_ops.append(stack_opinfo)

@@ -5073,13 +5046,6 @@ def unsqueeze_sample_generator(op, device, dtype, requires_grad, **kwargs):
    sample_input_generator=unsqueeze_sample_generator,
    jax_reference=jax.lax.expand_dims if JAX_AVAILABLE else None,
    test_directives=(
        # See issue "broadcast_in_dim: The size of contiguity must equal to the
        # number of non-broadcasting IterDomains"
        DecorateInfo(
            pytest.mark.skip,
            "test_jvp_correctness",
            executors=("nvfuser",),
        ),
        DecorateInfo(
            pytest.mark.skip,
            "test_vjp_correctness",
