diff --git a/src/haliax/hof.py b/src/haliax/hof.py index f90ef7c..6f9ae8d 100644 --- a/src/haliax/hof.py +++ b/src/haliax/hof.py @@ -1,11 +1,13 @@ import dataclasses +import functools import inspect from functools import wraps -from typing import Any, Callable, ParamSpec, Protocol, Tuple, TypeVar, Union, overload +from typing import Any, Callable, Optional, ParamSpec, Protocol, Sequence, Tuple, TypeVar, Union, overload import equinox as eqx import jax import jax.lax as lax +import numpy as np from jaxtyping import PyTree import haliax @@ -14,7 +16,7 @@ from ._src.util import index_where from .axis import Axis, AxisSelector, selects_axis from .core import NamedArray -from .jax_utils import Static, broadcast_prefix, is_jax_array_like +from .jax_utils import Static, broadcast_prefix, checkpointed_scan, is_jax_array_like from .partitioning import physical_axis_name from .util import is_jax_or_hax_array_like, is_named_array @@ -45,6 +47,8 @@ def scan( reverse: bool = False, unroll: int = 1, is_scanned: BoolAxisSpec = is_named_or_shaped_array_like, + grad_checkpointing: bool = False, + checkpoint_blocks: Optional[Sequence[int]] = None, ) -> Callable[[Carry, PyTree[X]], Tuple[Carry, PyTree[Y]]]: ... @@ -57,6 +61,8 @@ def scan( reverse: bool = False, unroll: int = 1, is_scanned: BoolAxisSpec = is_named_or_shaped_array_like, + grad_checkpointing: bool = False, + checkpoint_blocks: Optional[Sequence[int]] = None, ) -> Callable: ... @@ -68,6 +74,8 @@ def scan( reverse=False, unroll=1, is_scanned: BoolAxisSpec = is_named_or_shaped_array_like, + grad_checkpointing: bool = False, + checkpoint_blocks: Optional[Sequence[int]] = None, ): """ Scan over a named axis. Non-scalar unnamed arrays will have their first axis scanned over. @@ -112,6 +120,16 @@ def scanned_f(init, *args, **kwargs): # invariants until we're ready to create the result. axis_first_xs = htu.tree_map(_ensure_first(axis), scanned_xs) + # if we were passed in a string arg, we need to get its axis size out from some arg + if isinstance(axis, str): + true_axis = _infer_axis_size_from_tree(axis_first_xs, axis) + if true_axis is not None: + true_axis + else: + raise ValueError("scan requires either an actual Axis or at least one NamedArray or array arg") + else: + true_axis = axis + # now get a template of an element of "X" x_elem = htu.tree_map(_select_0th(axis), axis_first_xs) # NB: we don't want to use htu.tree_structure here because we want to eliminate the leading axis @@ -130,9 +148,20 @@ def wrapped_fn(carry, scanned_x_leaves): # as above, we don't want to use htu.tree_leaves here because we want to eliminate the leading axis leaves = jax.tree_util.tree_leaves(axis_first_xs) - with jax.named_scope(f"scan({haliax.axis_name(axis)})"): - carry, ys = lax.scan(wrapped_fn, init, leaves, reverse=reverse, unroll=unroll) - true_axis = _infer_axis_size_from_result(ys, axis) + if grad_checkpointing: + if unroll != 1: + # TODO: support for case when it's a suffix of block size? + raise ValueError("Can't use grad_checkpointing with unroll != 1") + with jax.named_scope(f"ckpt_scan({haliax.axis_name(axis)})"): + blocks = _rectify_scan_lengths(true_axis, checkpoint_blocks) + + scan_fn = functools.partial(checkpointed_scan, lengths=blocks, prevent_cse=False, reverse=reverse) + carry, ys = scan_fn(wrapped_fn, init, leaves) + else: + with jax.named_scope(f"scan({haliax.axis_name(axis)})"): + carry, ys = lax.scan(wrapped_fn, init, leaves, reverse=reverse, unroll=unroll) + + true_axis = _infer_axis_size_from_tree(ys, axis) ys = jax.tree_util.tree_map(_prepend_named_batch_axis(true_axis), ys, is_leaf=_is_passive_array) return carry, ys @@ -140,6 +169,18 @@ def wrapped_fn(carry, scanned_x_leaves): return scanned_f +def _rectify_scan_lengths(axis: Axis, checkpoint_blocks: Optional[Sequence[int]]) -> list[int]: + blocks = checkpoint_blocks or [axis.size] + cur_size = np.prod(blocks) + if cur_size != axis.size: + left = axis.size // cur_size + if left * cur_size != axis.size: + raise ValueError(f"Can't partition {axis.size} into blocks of size {blocks}") + return list(blocks) + [left] + else: + return list(blocks) + + @overload def fold( fn: Callable[[Carry, X], Carry], @@ -148,6 +189,8 @@ def fold( reverse: bool = False, unroll: int = 1, is_scanned: BoolAxisSpec = is_jax_or_hax_array_like, + grad_checkpointing: bool = False, + checkpoint_blocks: Optional[Sequence[int]] = None, ) -> Callable[[Carry, PyTree[X]], Carry]: ... @@ -160,6 +203,8 @@ def fold( reverse: bool = False, unroll: int = 1, is_scanned: BoolAxisSpec = is_jax_or_hax_array_like, + grad_checkpointing: bool = False, + checkpoint_blocks: Optional[Sequence[int]] = None, ) -> Callable: ... @@ -171,6 +216,8 @@ def fold( reverse: bool = False, unroll: int = 1, is_scanned: BoolAxisSpec = is_named_or_shaped_array_like, + grad_checkpointing: bool = False, + checkpoint_blocks: Optional[Sequence[int]] = None, ) -> Callable: """ Slightly simpler implementation of scan that folds over the named axis of the array, not returning intermediates. @@ -196,7 +243,15 @@ def fold( def scan_compatible_fn(carry, *args, **kwargs): return fn(carry, *args, **kwargs), None - scan_preconfig = scan(scan_compatible_fn, axis, reverse=reverse, unroll=unroll, is_scanned=is_scanned) + scan_preconfig = scan( + scan_compatible_fn, + axis, + reverse=reverse, + unroll=unroll, + is_scanned=is_scanned, + grad_checkpointing=grad_checkpointing, + checkpoint_blocks=checkpoint_blocks, + ) def scanned_f(init, *args, **kwargs): return scan_preconfig(init, *args, **kwargs)[0] @@ -359,7 +414,7 @@ def wrapped_fn(args, kwargs): result = eqx.combine(result_dynamic, result_static.value) # if we were passed in a string arg, we need to get its axis size out from some result - true_axis = _infer_axis_size_from_result(result, axis) + true_axis = _infer_axis_size_from_tree(result, axis) if true_axis is None: raise ValueError("vmap failed to infer axis size from result") @@ -369,17 +424,19 @@ def wrapped_fn(args, kwargs): return wrapped_vmap_fn -def _infer_axis_size_from_result(result, axis): +def _infer_axis_size_from_tree(result, axis): if isinstance(axis, str): result_leaves = jax.tree_util.tree_leaves(result, is_leaf=_is_passive_array) if len(result_leaves) == 0: - # this really shouldn't happen return None - if isinstance(result_leaves[0], _PassiveNamedArray): - true_axis_size = result_leaves[0].array.shape[0] # batch axis is defined to be 0 above + leaf = result_leaves[0] + if isinstance(leaf, _PassiveNamedArray): + true_axis_size = leaf.array.shape[0] # batch axis is defined to be 0 above true_axis = Axis(axis, true_axis_size) - else: - true_axis_size = result_leaves[0].shape[0] # batch axis is defined to be 0 above + elif isinstance(leaf, NamedArray): + true_axis = leaf.resolve_axis(axis) + elif isinstance(leaf, jax.numpy.ndarray) and leaf.ndim > 0: + true_axis_size = leaf.shape[0] # batch axis is defined to be 0 above true_axis = Axis(axis, true_axis_size) else: true_axis = axis @@ -424,7 +481,7 @@ def tree_unflatten(cls, aux, tree: Any) -> Any: def _is_passive_array(arr): - return isinstance(arr, _PassiveNamedArray) + return isinstance(arr, _PassiveNamedArray) or isinstance(arr, NamedArray) def _prepend_named_batch_axis(leading_axis: Axis): diff --git a/src/haliax/jax_utils.py b/src/haliax/jax_utils.py index 7856b25..fb7017c 100644 --- a/src/haliax/jax_utils.py +++ b/src/haliax/jax_utils.py @@ -1,3 +1,4 @@ +import functools import functools as ft import typing from typing import Any, Callable, List, Optional, Sequence, Union @@ -9,6 +10,8 @@ from jax import random as jrandom from jaxtyping import PRNGKeyArray +import haliax + F = typing.TypeVar("F", bound=Callable[..., Any]) @@ -140,3 +143,80 @@ def is_pallas_dslice(x: object) -> bool: _PALLAS_DSLICE_TYPE = type(pdslice(0, 1)) return isinstance(x, _PALLAS_DSLICE_TYPE) + + +def is_scalarish(x): + if isinstance(x, haliax.NamedArray): + return x.ndim == 0 + else: + return jnp.isscalar(x) or x.shape == () + + +def checkpointed_scan( + body_fn, + init, + xs, + lengths: Sequence[int], + *, + reverse: bool = False, + policy: Optional[Callable[..., bool]] = None, + prevent_cse: bool = False, +): + """ + Runs a recursive checkpointed scan over xs, where the scan is split into multiple scans, each of which has length + lengths[i] for some i. + + This uses less memory than not checkpointing a scan, but more than + + Note this uses "vanilla" JAX arrays, not NamedArrays + + """ + if len(lengths) == 1: + return jax.lax.scan(jax.checkpoint(body_fn, prevent_cse=prevent_cse, policy=policy), init, xs, lengths[0]) + else: + # we want to split the scan up into multiple recursive scans, doing a total of `prod(lengths)` steps + # this makes a tree of scans with depth len(lengths) + # check total length against any xs + total_length = np.prod(lengths) + + def check_leaf(x): + assert x.shape[0] == total_length + + jax.tree_util.tree_map(lambda x: check_leaf(x), xs) + + ckpt = functools.partial(jax.checkpoint, prevent_cse=prevent_cse, policy=policy) + + @ckpt + def _body_fn(carry, i, start): + my_xs = jax.tree_util.tree_map(lambda x: x[start + i], xs) + return body_fn(carry, my_xs) + + def rec_scan_fn(lengths): + # returns a fn that, when called, scans over prod(lengths) steps recursively + if len(lengths) == 1: + range = jnp.arange(lengths[0]) + return ckpt( + lambda carry, start: jax.lax.scan( + functools.partial(_body_fn, start=start), carry, range, lengths[0], reverse=reverse + ) + ) + else: + my_len = lengths[0] + rest_len = lengths[1:] + range_to_scan = jnp.arange(my_len) * np.prod(rest_len) + return ckpt( + lambda carry, start: jax.lax.scan( + rec_scan_fn(rest_len), + carry, + range_to_scan + start, + reverse=reverse, + ) + ) + + res, unflattened = rec_scan_fn(lengths)(init, 0) + + # need to flatten the output + # we need to flatten the leading len(lengths) dimensions of the output + flattened = jax.tree_util.tree_map(lambda y: jnp.reshape(y, (-1,) + y.shape[len(lengths) :]), unflattened) + + return res, flattened diff --git a/src/haliax/nn/scan.py b/src/haliax/nn/scan.py index f9750c3..27202ef 100644 --- a/src/haliax/nn/scan.py +++ b/src/haliax/nn/scan.py @@ -1,4 +1,5 @@ import functools +import math from typing import Dict, Generic, Optional, Protocol, Sequence, Type, TypeVar import equinox as eqx @@ -6,7 +7,7 @@ import haliax import haliax.util -from haliax.jax_utils import filter_checkpoint +from haliax.jax_utils import filter_checkpoint, named_call from ..axis import Axis @@ -70,7 +71,7 @@ class Stacked(eqx.Module, Generic[M]): @staticmethod def init( - Block: Axis, module: Type[M], *, gradient_checkpointing: bool = False, prevent_cse: bool = True + Block: Axis, module: Type[M], *, gradient_checkpointing: bool = False, prevent_cse: bool = False ) -> ModuleInit["Stacked[M]"]: """ Initialize a Stacked module. This method is curried: you can pass in the Block and module, and it will return @@ -89,16 +90,34 @@ def fn(*args, **kwargs): return fn - def scan(self, init, *extra_args, **extra_kwargs): + @named_call(name="Stacked.scan") + def scan(self, init, *args, **kwargs): if self.gradient_checkpointing: do_block = filter_checkpoint(self._do_block, prevent_cse=self.prevent_cse) + # determine a checkpoint block size, should be roughly sqrt(self.Block.size) + size = int(math.sqrt(self.Block.size)) + num_blocks = int(math.ceil(self.Block.size / size)) + rest = self.Block.size // size + block_spec = [num_blocks, rest] + + return haliax.scan( + do_block, self.Block, grad_checkpointing=self.gradient_checkpointing, checkpoint_blocks=block_spec + )(init, self.stacked, *args, **kwargs) else: - do_block = self._do_block - return haliax.scan(do_block, self.Block)(init, self.stacked, *extra_args, **extra_kwargs) + return haliax.scan(self._do_block, self.Block)(init, self.stacked, *args, **kwargs) + @named_call(name="Stacked.fold") def fold(self, init, *args, **kwargs): + print(f"FOLD! {self.gradient_checkpointing} {self.prevent_cse}", flush=True) if self.gradient_checkpointing: - do_block = filter_checkpoint(self._do_block) + do_block = filter_checkpoint(self._do_block, prevent_cse=self.prevent_cse) + # determine a checkpoint block size, should be roughly sqrt(self.Block.size) + size = int(math.sqrt(self.Block.size)) + num_blocks = int(math.ceil(self.Block.size / size)) + + return haliax.fold( + do_block, self.Block, grad_checkpointing=self.gradient_checkpointing, checkpoint_blocks=[num_blocks] + )(init, self.stacked, *args, **kwargs) else: do_block = self._do_block diff --git a/tests/test_hof.py b/tests/test_hof.py index 3923837..0a754e3 100644 --- a/tests/test_hof.py +++ b/tests/test_hof.py @@ -248,3 +248,41 @@ def __call__(self, x): Width = Axis("Width", 3) hax.vmap(lambda a: Module(a), Batch)(Width) + + +def test_scan_raises_with_string_arg_and_no_args(): + def scan_fun(acc): + return acc, acc + + try: + hax.scan(scan_fun, "Height")(0.0) + except ValueError as e: + assert "scan requires either an actual Axis or at least one NamedArray or array" in str(e) + else: + assert False, "should have raised" + + +def test_scan_works_with_string_arg_and_one_arg(): + Height = Axis("Height", 10) + named1 = hax.random.uniform(PRNGKey(0), (Height,)) + + def scan_fun(acc, x): + return acc + x.scalar(), x + + total, named2 = hax.scan(scan_fun, "Height")(0.0, named1) + + assert jnp.all(jnp.isclose(total, jnp.sum(named1.array))) + assert jnp.all(jnp.equal(named1.array, named2.array)) + + +def test_scan_works_with_string_and_unnamed_args(): + Height = Axis("Height", 10) + named1 = hax.random.uniform(PRNGKey(0), (Height,)) + + def scan_fun(acc, x): + return acc + x, x + + total, named2 = hax.scan(scan_fun, "Height")(0.0, named1.array) + + assert jnp.all(jnp.isclose(total, jnp.sum(named1.array))) + assert jnp.all(jnp.equal(named1.array, named2)) diff --git a/tests/test_jax_utils.py b/tests/test_jax_utils.py new file mode 100644 index 0000000..ca96442 --- /dev/null +++ b/tests/test_jax_utils.py @@ -0,0 +1,37 @@ +import jax.lax +import jax.numpy as jnp + +from haliax.jax_utils import checkpointed_scan + + +def test_checkpointed_scan(): + def body_fn(carry, x): + return carry - x, carry + jnp.log1p(x) + + init = 0 + xs = jnp.arange(2 * 3 * 4, dtype=jnp.float32) + + lengths = [2, 3, 4] + + result, partials = checkpointed_scan(body_fn, init, xs, lengths) + + # compare to vanilla + vanilla_result, vanilla_partials = jax.lax.scan(body_fn, init, xs) + + assert jnp.all(result == vanilla_result) + assert jnp.all(partials == vanilla_partials) + + # check derivatives + def f(x): + x, y = checkpointed_scan(body_fn, init, x, lengths=lengths) + return x + y.sum() + + def vanilla_f(x): + x, y = jax.lax.scan(body_fn, init, x) + return x + y.sum() + + z = jax.jit(jax.grad(f))(xs) + + vanilla_z = jax.jit(jax.grad(vanilla_f))(xs) + + assert jnp.allclose(z, vanilla_z)