Remove interpreter-based vmap transform (#1116)
IvanYashchuk authored Sep 7, 2024
1 parent 111ffa6 commit d14374f
Showing 2 changed files with 0 additions and 578 deletions.
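For context, the deleted code implemented an interpreter-based vmap: it traced the unbatched function and then re-evaluated the trace with per-primitive batching rules (registered in vmap_impls for SIN, COS, MUL, ADD, SUM, and BROADCAST_IN_DIM). The snippet below is a hypothetical usage sketch, not part of this commit; it assumes a checkout from before this change, where vmap was still importable from thunder.core.transforms, and a function that lowers to the primitives listed above.

    from thunder.core.transforms import vmap  # removed by this commit

    def f(x, y):
        # x * y + x lowers to prims.mul and prims.add, both of which had
        # batching rules registered in vmap_impls.
        return x * y + x

    # Map f over dimension 0 of each input; outputs are stacked along
    # dimension 0 of the result (in_dims and out_dims default to 0).
    batched_f = vmap(f, in_dims=0, out_dims=0)

Calling batched_f builds a trace of the unbatched function and re-evaluates it through vmap_symbol_mapper, which is exactly the interpreter-based path this commit removes.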
374 changes: 0 additions & 374 deletions thunder/core/transforms.py
@@ -1514,144 +1514,6 @@ def write(v: Variable, val: Any, allow_duplicates=False) -> None:
    return tree_map(read, trace.output)


# VMAP transform
# -------------


class NotMapped:
    """Represents a non-batched dimension."""

    def __repr__(self):
        return "not_mapped"


@dataclass(frozen=True)
class BatchedValue:
    """Batched value for the vmap transform.
    Attributes:
        value: Batched value.
        batch_dim: Batching dimension or not_mapped
    """

    value: Any
    batch_dim: int | NotMapped

    def __iter__(self):
        yield self.value
        yield self.batch_dim


def pair_to_batched_value(pair):
    """Converts a pair to a BatchedValue.
    Args:
        pair (Sequence): Pair to convert.
    Returns:
        BatchedValue: BatchedValue representation of the pair.
    """
    if isinstance(pair, BatchedValue):
        return pair
    else:
        assert isinstance(pair, Sequence) and len(pair) == 2
        return BatchedValue(*pair)


def vectorized_batcher(prim, axis_size, batched_values, **kwargs):
    batch_dim = batched_values[0].batch_dim
    assert all(
        batch_dim == bv.batch_dim for bv in batched_values[1:]
    ), f"`vectorized_batcher` got different batched dimensions {[bv.batch_dim for bv in batched_values]}"
    return BatchedValue(prim(*[bv.value for bv in batched_values], **kwargs), batch_dim)


not_mapped = NotMapped()


def movedim(x, src: int, dst: int):
    perm = [i for i in range(x.ndim) if i != src]
    perm.insert(dst, src)
    return prims.transpose(x, tuple(perm))


def move_batch_dim(axis_size, src, dst, x):
    if src is not_mapped:
        if isinstance(x, Number):
            return x
        target_shape = list(x.shape)
        target_shape.insert(dst, axis_size)
        bcast_dims = list(range(len(target_shape)))
        bcast_dims.pop(dst)
        return prims.broadcast_in_dim(x, target_shape, bcast_dims)
    elif src == dst:
        return x
    else:
        return movedim(x, src, dst)


def binary_op_batching_rule(op: Callable, axis_size: int, vals_in: Sequence[BatchedValue]):
    ((x, x_bdim), (y, y_bdim)) = vals_in
    if x_bdim != y_bdim:
        if x_bdim is not_mapped:
            x = move_batch_dim(axis_size, x_bdim, y_bdim, x)
            x_bdim = y_bdim
        else:
            y = move_batch_dim(axis_size, y_bdim, x_bdim, y)
    return BatchedValue(op(x, y), x_bdim)


def sin_vmap(axis_size: int, a: BatchedValue) -> BatchedValue:
    return vectorized_batcher(prims.sin, axis_size, (a,))


def cos_vmap(axis_size: int, a: BatchedValue) -> BatchedValue:
    return vectorized_batcher(prims.cos, axis_size, (a,))


def mul_vmap(axis_size: int, a: BatchedValue, b: BatchedValue) -> BatchedValue:
    return binary_op_batching_rule(prims.mul, axis_size, (a, b))


def add_vmap(axis_size: int, a: BatchedValue, b: BatchedValue) -> BatchedValue:
    return binary_op_batching_rule(prims.add, axis_size, (a, b))


def sum_vmap(axis_size: int, a: BatchedValue, dims: Sequence[int], **kwargs) -> BatchedValue:
    bdim = a.batch_dim
    # TODO: remove this when dims becomes a mandatory kwarg
    if len(dims) > 0:
        dims, _ = safe_zip(*dims)
    dims_before = tuple(el for el in dims if el < bdim)
    dims_after = tuple(el + 1 for el in dims if el >= bdim)
    batched_dims = dims_before + dims_after
    return vectorized_batcher(prims.sum, axis_size, (a,), dims=batched_dims, **kwargs)


# TODO: Please test this extensively
def broadcast_in_dim_vmap(
    axis_size: int, a: BatchedValue, shape: Sequence[BatchedValue], broadcast_dimensions: Sequence[BatchedValue]
) -> BatchedValue:
    bdim = a.batch_dim
    # TODO: remove this when shape and broadcast_dimensions become mandatory kwargs
    shape, _ = safe_zip(*shape)
    if len(broadcast_dimensions) > 0:
        broadcast_dimensions, _ = safe_zip(*broadcast_dimensions)
    if bdim is not_mapped:
        return BatchedValue(prims.broadcast_in_dim(a.value, shape, broadcast_dimensions), bdim)
    else:
        new_bdim = bdim + sum(1 for dim in broadcast_dimensions if dim < bdim)
        new_shape = list(shape)
        new_shape.insert(new_bdim, axis_size)
        new_broadcast_dimensions = (0,) + tuple(dim + 1 if dim >= bdim else dim for dim in broadcast_dimensions)
        if broadcast_dimensions == ():
            new_broadcast_dimensions = ()
        return BatchedValue(prims.broadcast_in_dim(a.value, new_shape, new_broadcast_dimensions), new_bdim)


vmap_impls: dict[prims.Symbol, Callable] = dict()


def unwrap_one_level_of_subsymbols(trace):
    new_symbols_iter = (
        bound_symbol.subsymbols if len(bound_symbol.subsymbols) > 0 else [bound_symbol]
@@ -1662,225 +1524,6 @@ def unwrap_one_level_of_subsymbols(trace):
    return trace


def decomposed_fn_vmap_rule(axis_size, *args, fn, **kwargs):
    args, in_dims = unzip2(args)
    unbatched_args = tree_map(lambda x: remove_batch_dim(x) if isinstance(x, TensorProxy) else x, args)
    trace = construct_trace()(fn, *unbatched_args, **kwargs)
    trace = unwrap_one_level_of_subsymbols(trace)
    outs = _vmap_call_metafunc(False, args, in_dims, 0, axis_size, function_trace=trace, **kwargs)
    if isinstance(outs, Sequence):
        out_dims = (0,) * len(outs)
        return safe_map(pair_to_batched_value, safe_zip(outs, out_dims))
    return BatchedValue(outs, 0)


vmap_impls[prims.PrimIDs.SIN] = sin_vmap
vmap_impls[prims.PrimIDs.COS] = cos_vmap
vmap_impls[prims.PrimIDs.MUL] = mul_vmap
vmap_impls[prims.PrimIDs.ADD] = add_vmap
vmap_impls[prims.PrimIDs.SUM] = sum_vmap
vmap_impls[prims.PrimIDs.BROADCAST_IN_DIM] = broadcast_in_dim_vmap


def vmap_symbol_mapper(symbol: prims.Symbol, *, axis_size: int):
    """Maps a symbol to a vmap function that evaluates it.
    Args:
        symbol (prims.Symbol): Symbol to evaluate.
    Raises:
        NotImplementedError: If the vmap for the symbol is not implemented.
    Returns:
        Callable: vmap function that evaluates the symbol.
    """

    def wrap_arg(x):
        if isinstance(x, BatchedValue):
            return x
        elif isinstance(x, (Number, NumberProxy)):
            return BatchedValue(x, not_mapped)
        else:
            raise ValueError(f"vmap wrap_arg got an unsupported type {type(x)}")

    if symbol.are_all_args_constant:

        def _vmap_impl_const(symbol, *args, **kwargs):
            out = symbol_to_eval(symbol)(*args, **kwargs)

            if isinstance(out, Sequence):
                return safe_map(pair_to_batched_value, safe_zip(out, [not_mapped] * len(out)))

            return BatchedValue(out, not_mapped)

        return partial(_vmap_impl_const, symbol)

    vmap_impl = vmap_impls.get(symbol.sym.id)
    if vmap_impl is None:
        if len(symbol.subsymbols) > 0:
            vmap_impl = partial(decomposed_fn_vmap_rule, fn=symbol.sym)
        else:
            raise NotImplementedError(f"vmap for {symbol.sym.id} is not implemented")

    def _vmap_impl(*args, **kwargs):
        args = tree_map(wrap_arg, args)
        assert all(isinstance(arg, BatchedValue) for arg in tree_flatten(args)[0])
        return vmap_impl(axis_size, *args, **kwargs)

    return _vmap_impl


def remove_batch_dim(tensor: TensorProxy, batch_dim: int = 0) -> TensorProxy:
    """Removes the batch dimension from a tensor.
    Args:
        tensor (TensorProxy): Tensor to remove the batch dimension from.
    Returns:
        TensorProxy: Tensor with the batch dimension removed.
    """
    new_shape = tensor.shape[:batch_dim] + tensor.shape[batch_dim + 1 :]
    return TensorProxy(like=tensor, shape=new_shape)


# TODO: in JAX args, in_dims are flattened the same way
# TODO: in JAX out_dims are flattened as well
def _vmap_call_metafunc(detached: bool, args, in_dims, out_dims, axis_size, function_trace: Trace, **kwargs):
    """Metafunction for vmap call.
    Args:
        detached (bool): Whether to detach the trace.
        args (Tuple[Proxy]): Arguments to the function.
        in_dims (Tuple[int]): Batch dimension for each argument.
        out_dims (Tuple[int]): Batch dimension for return values.
        function_trace (Trace): Trace to use for the function.
        kwargs: Keyword arguments.
    Raises:
        AssertionError: If the arguments live on more than one device (vmap across multiple devices is not implemented).
    Returns:
        Result of the vmap transform.
    """
    common_device = {x.device for x in args if isinstance(x, TensorProxy)}
    assert len(common_device) <= 1, "vmap for multiple devices is not implemented"
    (common_device,) = common_device if len(common_device) == 1 else (cpu,)

    if axis_size is None:
        (axis_size,) = {x.shape[ax] for x, ax in zip(args, in_dims) if ax is not not_mapped}
    in_dims = in_dims if isinstance(in_dims, Sequence) else (in_dims,)
    in_dims = tuple(not_mapped if isinstance(a, (Number, NumberProxy)) else d for a, d in safe_zip(args, in_dims))
    out_dims = out_dims if isinstance(out_dims, Sequence) else (out_dims,)

    ctx = detached_trace() if detached else nullcontext()
    with ctx:
        # We propagate the BatchedValue through the trace, and then unwrap it at the end
        batched_args = safe_map(pair_to_batched_value, safe_zip(args, in_dims))
        result = eval_trace(
            function_trace, *batched_args, symbol_mapper=partial(vmap_symbol_mapper, axis_size=axis_size), **kwargs
        )
        # Unwrapping the BatchedValue's
        if isinstance(result, Sequence):
            flat_result, spec = tree_flatten(result)
            assert all(isinstance(x, BatchedValue) for x in flat_result)
            outs, bdims = unzip2(flat_result)
            # TODO: handle the case where out_dims is a single value better
            if len(out_dims) == 1:
                out_dims = out_dims * len(outs)
            outs = safe_map(partial(move_batch_dim, axis_size), bdims, out_dims, outs)
            return tree_unflatten(outs, spec)
        if isinstance(result, (Number, NumberProxy)) and axis_size is not None:
            # TODO: fetch the default device from the context
            result = ltorch.full(shape=(), fill_value=result, device=common_device)
            result = BatchedValue(result, not_mapped)
        elif (
            isinstance(result, BatchedValue)
            and isinstance(result.value, (Number, NumberProxy))
            and axis_size is not None
        ):
            result = BatchedValue(
                ltorch.full(shape=(), fill_value=result.value, device=common_device), result.batch_dim
            )
        assert isinstance(result, BatchedValue)
        out = move_batch_dim(axis_size, result.batch_dim, out_dims[0], result.value)
        return out


vmap_call = Symbol(id=Transforms.VmapOp, name="vmap_call", meta=partial(_vmap_call_metafunc, False))


def _jvp_call_vmap(axis_size, batched_primals, batched_tangents, *, function_trace: Trace, **kwargs):
    primals, primals_bdims = safe_zip(*batched_primals)
    tangents, tangents_bdims = safe_zip(*batched_tangents)
    jvp_func = partial(_jvp_call_metafunc, False, function_trace=function_trace)
    vmapped_jvp_func = vmap(jvp_func, in_dims=(primals_bdims, tangents_bdims), axis_size=axis_size)
    result = vmapped_jvp_func(primals, tangents, **kwargs)
    return tree_map(lambda x: BatchedValue(x, 0), result)


vmap_impls[Transforms.JvpOp] = _jvp_call_vmap


def vmap(func, in_dims=0, out_dims=0, axis_size=None):
    """Vectorizing transform for a Thunder function.
    Args:
        func (Callable): A Thunder function to be transformed.
    Returns:
        Callable: A vmapped version of the function.
    """

    # TODO: flatten
    # In JAX flattening of in_dims is rather complicated because it can optionally be
    # specified as a “prefix” pytree, meaning that a single leaf value can be applied
    # to an entire sub-pytree.

    def flatten_func_for_vmap(func, args, kwargs):
        flat_args, spec = tree_flatten((args, kwargs))

        def flat_func(*flat_args):
            fn_args, fn_kwargs = tree_unflatten(flat_args, spec)
            return func(*fn_args, **fn_kwargs)

        return flat_func, flat_args, spec

    def wrapper(*args, **kwargs):
        func_flat, args_flat, args_spec = flatten_func_for_vmap(func, args, kwargs)
        if isinstance(in_dims, int):
            in_dims_flat = (in_dims,) * len(args_flat)
        else:
            in_dims_flat, in_dims_spec = tree_flatten(in_dims)
            assert len(in_dims_flat) == len(args_flat), "in_dims must have the same length as args, kwargs"
        unbatched_args_flat = [remove_batch_dim(arg) if isinstance(arg, TensorProxy) else arg for arg in args_flat]
        trace = construct_trace()(func_flat, *unbatched_args_flat)
        outs = vmap_call(args_flat, in_dims_flat, out_dims, axis_size=axis_size, function_trace=trace)
        return outs

    return wrapper


# TODO This function commented out because it calls make_traced, which does not exist
# def vmap_eager(func, args, in_dims=0, out_dims=0, axis_size=None, executor="torch"):
# """Computes the vmap of a Thunder function.

# Args:
# func (Callable): A Thunder function to be transformed.
# args (_type_): Args of the function.
# executor (str, optional): Executor to use. Defaults to "torch".

# Returns:
# The result of the vmapped function.
# """
# # TODO: fix this - not all args may be batched
# # TODO: here we assume batch axis is 0
# vmap_trace = make_trace(
# vmap(func, in_dims=in_dims, out_dims=out_dims, axis_size=axis_size), executor=executor,
# *args)
# vmap_traced = make_traced(partial(eval_trace, vmap_trace), executor=executor)
# return vmap_traced(*args)


# JVP transform
# -------------

@@ -2066,23 +1709,6 @@ def _jvp_call_metafunc(detached: bool, primals, tangents, *, function_trace: Tra
jvp_call = Symbol(id=Transforms.JvpOp, name="jvp_call", meta=partial(_jvp_call_metafunc, False))


def _vmap_call_jvp(args: JVPDual, in_dims, out_dims, axis_size, trace: Trace, **kwargs):
    primals, tangents = safe_zip(*args)
    in_dims, _ = safe_zip(*in_dims)
    out_dims, _ = safe_zip(*out_dims)
    vmapped_trace = construct_trace()(
        vmap(partial(eval_trace, trace), in_dims=in_dims, out_dims=out_dims, axis_size=axis_size), *primals
    )
    vmapped_func = partial(eval_trace, vmapped_trace)
    out_primals, out_tangents = jvp(vmapped_func)(primals, tangents, **kwargs)
    if isinstance(out_primals, Sequence):
        return safe_map(pair_to_jvp_dual, safe_zip(out_primals, out_tangents))
    return JVPDual(out_primals, out_tangents)


jvp_impls[Transforms.VmapOp] = _vmap_call_jvp


def jvp(func):
    """Jacobian-vector product transform for a Thunder function.