diff --git a/examples/datasets.py b/examples/datasets.py index 345a3d3..3bd57d7 100644 --- a/examples/datasets.py +++ b/examples/datasets.py @@ -35,6 +35,14 @@ _IMAGENET_STDDEV_RGB = (0.229, 0.224, 0.225) +def sharded_iterator( + dataset: tf.data.Dataset, + sharding: jax.sharding.NamedSharding, +) -> Iterator[Batch]: + for batch in iter(tensorflow_datasets.as_numpy(dataset)): + yield jax.device_put(batch, sharding) + + def mnist_dataset( split: str, has_labels: bool, @@ -43,6 +51,7 @@ def mnist_dataset( repeat: bool, shuffle: bool, drop_remainder: bool, + sharding: jax.sharding.NamedSharding, seed: Optional[int] = None, multi_device: bool = True, reshuffle_each_iteration: bool = True, @@ -59,13 +68,13 @@ def mnist_dataset( shuffle: Whether to shuffle the dataset. drop_remainder: Whether to drop the remainder of the dataset if the number of data points is not divisible by the total batch size. + sharding: Sharding spec for each batch. seed: Any seed to use for random pre-processing. multi_device: If the returned batch should take into account the number of devices present, in which case it will return an array with shape `(num_device, device_batch_size, ...)`. reshuffle_each_iteration: Whether to reshuffle the dataset in a new order after each iteration. - dtype: The returned data type of the images. Returns: The MNIST dataset as a tensorflow dataset. @@ -74,14 +83,7 @@ def mnist_dataset( # Set for multi devices vs single device num_devices = jax.device_count() if multi_device else 1 num_local_devices = jax.local_device_count() if multi_device else 1 - - if multi_device: - host_batch_shape = [num_local_devices, device_batch_size] - else: - host_batch_shape = [device_batch_size] - host_batch_size = num_local_devices * device_batch_size - num_examples = tfds.builder("mnist").info.splits[split].num_examples if num_examples % num_devices != 0: @@ -95,8 +97,7 @@ def preprocess_batch( """Standard reshaping of the images to (28, 28).""" images = tf.image.convert_image_dtype(images, dtype) single_example_shape = [784] if flatten_images else [28, 28] - images = tf.reshape(images, host_batch_shape + single_example_shape) - labels = tf.reshape(labels, host_batch_shape) + images = tf.reshape(images, [host_batch_size] + single_example_shape) if has_labels: return dict(images=images, labels=labels) else: @@ -123,7 +124,7 @@ def preprocess_batch( ds = ds.prefetch(tf.data.experimental.AUTOTUNE) - return iter(tensorflow_datasets.as_numpy(ds)) + return sharded_iterator(ds, sharding) def imagenet_num_examples_and_split( diff --git a/examples/optimizers.py b/examples/optimizers.py index f2aa15e..8e55919 100644 --- a/examples/optimizers.py +++ b/examples/optimizers.py @@ -85,6 +85,10 @@ def __init__( axis_name=self.pmap_axis_name, ) + @property + def state_sharding(self) -> jax.sharding.NamedSharding: + raise NotImplementedError() + def init( self, params: Params, @@ -438,7 +442,6 @@ def create_optimizer( value_func_has_aux=has_aux, value_func_has_state=has_func_state, value_func_has_rng=has_rng, - multi_device=True, **kwargs, ) elif name == "sgd": diff --git a/examples/training.py b/examples/training.py index 3fafbf9..45d8ce3 100644 --- a/examples/training.py +++ b/examples/training.py @@ -21,6 +21,7 @@ from absl import logging import jax +from jax.experimental import mesh_utils import jax.numpy as jnp from jaxline import experiment from jaxline import utils as pipe_utils @@ -41,20 +42,23 @@ InitFunc = Callable[[PRNGKey, Batch], Params] -class SupervisedExperiment(experiment.AbstractExperiment): +class SupervisedExperiment(abc.ABC): """Abstract supervised experiment. Attributes: mode: Either 'train' or 'eval' specifying whether to run training or evaluation of the experiment. - init_rng: The Jax PRNG key that is used to seed any randomness of the - experiment. + init_rng: The Jax PRNG key that is used to seed the initialization of the + model parameters. + seed_rng: An RNG used fo seeding the dataset iterators. config: The experiment config. has_aux: Whether the model function returns any auxiliary data. has_rng: Whether the model function needs an PRNG key. has_func_state: Whether the model function has a state. + eval_splits: Evaluation splits of the evaluation dataset loader. init_parameters_func: A function that initializes the parameters and optionally the state of the model if it has one. + params_init: A function that initializes the model parameters. model_loss_func: A function that computes the loss for the model. train_model_func: The `model_loss_func` with `is_training` set to `True`. eval_model_func: The `model_loss_func` with `is_training` set to `False`. @@ -62,16 +66,6 @@ class SupervisedExperiment(experiment.AbstractExperiment): optimizer: The optimizer instance used for training. """ - CHECKPOINT_ATTRS = { - "_params": "params", - "_state": "state", - "_opt_state": "opt_state", - } - - NON_BROADCAST_CHECKPOINT_ATTRS = { - "_python_step": "python_step" - } - def __init__( self, mode: str, @@ -82,6 +76,7 @@ def __init__( has_aux: bool, has_rng: bool, has_func_state: bool, + eval_splits: Tuple[str, ...] = ("train", "test"), ): """Initializes experiment. @@ -97,17 +92,20 @@ def __init__( has_aux: Whether the model function returns auxiliary data. has_rng: Whether the model function requires an RNG. has_func_state: Whether the model function has a state. + eval_splits: Evaluation splits of the evaluation dataset loader. """ - super().__init__(mode=mode, init_rng=init_rng) self.mode = mode - self.init_rng = init_rng + self.init_rng, self.seed_rng = jax.random.split(init_rng) + self.seed_rng = jax.random.fold_in(self.seed_rng, jax.process_index()) self.config = config self.has_aux = has_aux self.has_rng = has_rng self.has_func_state = has_func_state + self.eval_splits = eval_splits self.verify_batch_size_config() - self.init_parameters_func = init_parameters_func + self.params_init = jax.jit(init_parameters_func, + out_shardings=self.model_sharding) self.model_loss_func = model_loss_func self.train_model_func = functools.partial( self.model_loss_func, is_training=True) @@ -123,13 +121,13 @@ def __init__( self.optimizer = self.create_optimizer() # Initialize the state - self._train_input, self._eval_input = None, None + self._train_input, self._eval_input, self._init_batch = None, None, None self._params, self._state, self._opt_state = None, None, None self._python_step = 0 - self.initialize_state() def log_machines_setup(self): + """Logs the machine setup for the experiment.""" logging.info("Worker with mode %s", self.mode) logging.info("Number of hosts[%d]: %d", jax.process_index(), jax.process_count()) @@ -168,6 +166,64 @@ def verify_batch_size_config(self): "``batch_size.eval.per_device`` config arguments must " "be set to a value and the other one must be ``None``.") + @functools.cached_property + def sharding_config(self) -> ml_collections.ConfigDict: + """The sharding config.""" + default_config = ml_collections.ConfigDict(dict( + mesh_shape=(jax.device_count(),), + mesh_axis=("batch",), + dataset_axis=("batch",), + model_axis=(), + optimizer_axis=(), + )) + config = self.config.get("sharding", default_config) + config.update(default_config) + return config + + @functools.cached_property + def jit_mesh(self) -> jax.sharding.Mesh: + """The device mesh used when calling `jax.jit`.""" + devices = mesh_utils.create_device_mesh(self.sharding_config.mesh_shape) + return jax.sharding.Mesh(devices, self.sharding_config.mesh_axis) + + @functools.cached_property + def dataset_sharding_spec(self) -> jax.sharding.PartitionSpec: + """The sharding specification for the dataset.""" + axis = [(None if name not in self.sharding_config.dataset_axis else name) + for name in self.sharding_config.mesh_axis] + return jax.sharding.PartitionSpec(*axis) + + @functools.cached_property + def dataset_sharding(self) -> jax.sharding.NamedSharding: + """The NamedSharding for the dataset.""" + return jax.sharding.NamedSharding(self.jit_mesh, self.dataset_sharding_spec) + + @functools.cached_property + def model_sharding_spec(self) -> jax.sharding.PartitionSpec: + """The sharding specification for the model.""" + axis = [(None if name not in self.sharding_config.model_axis else name) + for name in self.sharding_config.mesh_axis] + return jax.sharding.PartitionSpec(*axis) + + @functools.cached_property + def model_sharding(self) -> jax.sharding.NamedSharding: + """The NamedSharding for the model.""" + return jax.sharding.NamedSharding(self.jit_mesh, self.model_sharding_spec) + + @functools.cached_property + def optimizer_sharding_spec(self) -> jax.sharding.PartitionSpec: + """The sharding specification for the optimizer.""" + axis = [ + (None if name not in self.sharding_config.optimizer_axis else name) + for name in self.sharding_config.mesh_axis + ] + return jax.sharding.PartitionSpec(*axis) + + @functools.cached_property + def optimizer_state_sharding(self) -> jax.sharding.NamedSharding: + """The NamedSharding for the optimizer state.""" + return self.optimizer.state_sharding + @property @abc.abstractmethod def dataset_size(self) -> int: @@ -247,12 +303,58 @@ def eval_total_batch_size(self) -> int: """The evaluator total batch size.""" return self.eval_per_device_batch_size * self.num_eval_devices + @property + @functools.lru_cache(maxsize=1) + def train_input(self) -> Iterator[Batch]: + """Returns the current training iterator.""" + if self._train_input is None: + logging.info("Initializing data iterators.") + seed_rng = jax.random.fold_in(self.seed_rng, self._python_step) + self._train_input = pipe_utils.py_prefetch( + functools.partial( + self._build_train_input, + split="train", + seed=int(seed_rng[0]), + device_batch_size=self.train_per_device_batch_size, + ) + ) + return self._train_input + @property @functools.lru_cache(maxsize=1) def train_inputs(self) -> Union[Iterator[Batch], Tuple[Iterator[Batch], Iterator[Batch]]]: """The training data iterator.""" - return self._train_input + return self.train_input + + @property + @functools.lru_cache(maxsize=1) + def eval_input(self) -> Dict[str, Callable[[], Iterator[Batch]]]: + """"Returns all evaluation iterators constructors.""" + if self._eval_input is None: + seed_rng = jax.random.fold_in(self.seed_rng, self._python_step) + self._eval_input = {} + for split in self.eval_splits: + self._eval_input[split] = functools.partial( + self._build_eval_input, + split="train", + seed=int(seed_rng[1]), + device_batch_size=self.eval_per_device_batch_size, + ) + return self._eval_input + + @property + @functools.lru_cache(maxsize=1) + def init_batch(self) -> Batch: + """A fake batch size used to initialize the model parameters and state.""" + if self._init_batch is None: + if self.mode == "train": + self._init_batch, iterator = kfac_jax.utils.fake_element_from_iterator( + self.train_input) + self._train_input = iterator + else: + self._init_batch = next(self.eval_input["train"]()) + return self._init_batch def progress( self, @@ -269,7 +371,7 @@ def progress( return data_seen / total_data - def should_run_step( + def terminate_training( self, global_step: int, config: ml_collections.ConfigDict, @@ -277,7 +379,7 @@ def should_run_step( del config # not used - return int(self.progress(global_step)) < 1 + return int(self.progress(global_step)) >= 1 def create_optimizer(self) -> Union[optimizers.OptaxWrapper, kfac_jax.Optimizer]: @@ -297,67 +399,26 @@ def create_optimizer(self) -> Union[optimizers.OptaxWrapper, epochs=self.config.training.epochs, ) - def initialize_state(self): + def maybe_initialize_state(self): """Initializes all the experiment's state variables.""" - - init_rng, seed_rng = jax.random.split(self.init_rng) - init_rng = kfac_jax.utils.replicate_all_local_devices(init_rng) - - # Beause we fold in the process index here, it's important that any sharding - # happen *before* shuffling - seed_rng = jax.random.fold_in(seed_rng, jax.process_index()) - seed = int(seed_rng[0]) - - # Initialize and load dataset - if self.mode == "train": - self._train_input = pipe_utils.py_prefetch( - functools.partial( - self._build_train_input, - split="train", - seed=seed, - device_batch_size=self.train_per_device_batch_size, - ) - ) - # Need an example batch for initialization - init_batch, self._train_input = kfac_jax.utils.fake_element_from_iterator( - self._train_input) - - elif self.mode == "eval": - self._eval_input = dict( - train=functools.partial( - self._build_eval_input, - split="train", - seed=seed, - device_batch_size=self.eval_per_device_batch_size - ), - test=functools.partial( - self._build_eval_input, - split="test", - seed=seed, - device_batch_size=self.eval_per_device_batch_size - ), - ) - init_batch = next(self._eval_input["train"]()) - - else: - raise NotImplementedError() + if self._params is not None: + logging.info("Loaded from checkpoint, not initializing parameters.") + return # Initialize parameters and optional state - init_func = jax.pmap(self.init_parameters_func) - params_rng, optimizer_rng = kfac_jax.utils.p_split(init_rng) + params_rng, optimizer_rng = jax.random.split(self.init_rng) if self.has_func_state: - self._params, self._state = init_func(params_rng, init_batch) + self._params, self._state = self.params_init(params_rng, self.init_batch) else: - self._params = init_func(params_rng, init_batch) + self._params = self.params_init(params_rng, self.init_batch) # Initialize optimizer state self._opt_state = self.optimizer.init( - self._params, optimizer_rng, init_batch, self._state) + self._params, optimizer_rng, self.init_batch, self._state) if not self.has_func_state: # Needed for checkpointing - self._state = kfac_jax.utils.replicate_all_local_devices( - jax.numpy.zeros([])) + self._state = () # _ _ # | |_ _ __ __ _(_)_ __ @@ -376,13 +437,9 @@ def _build_train_input( ) -> datasets.tf.data.Dataset: """Constructs the training dataset.""" - def step( # pytype: disable=signature-mismatch - self, - global_step: Array, - rng: PRNGKey, - **unused_args: Any, - ) -> Dict[str, Numeric]: - del global_step + def train_step(self, global_step: Array, rng: PRNGKey) -> Dict[str, Numeric]: + """Performs a single training step.""" + del global_step # Unused # Perform optimizer step result = self.optimizer.step( @@ -402,7 +459,7 @@ def step( # pytype: disable=signature-mismatch if "aux" in stats: # Average everything in aux and then put it in stats - stats.update(kfac_jax.utils.compute_mean(stats.pop("aux"))) + stats.update(stats.pop("aux", {})) stats["progress"] = self.progress(self._python_step) @@ -463,18 +520,18 @@ def _evaluate_single_batch( if hasattr(opt_state, "data_seen"): stats["data_seen"] = opt_state.data_seen - return kfac_jax.utils.pmean_if_pmap(stats, "eval_axis") # pytype: disable=bad-return-type + return stats - def evaluate( # pytype: disable=signature-mismatch + def run_evaluation( self, global_step: Array, rng: PRNGKey, - **unused_args: Any, ) -> Dict[str, Numeric]: + """Runs the evaluation of the currently loaded model parameters.""" all_stats = dict() # Evaluates both the train and eval split metrics - for name, dataset_iter_thunk in self._eval_input.items(): + for name, dataset_iter_thunk in self.eval_input.items(): # pytype: disable=attribute-error logging.info("Running evaluation for %s", name) @@ -501,11 +558,49 @@ def evaluate( # pytype: disable=signature-mismatch return all_stats +class JaxlineExperiment(SupervisedExperiment, experiment.AbstractExperiment): + """A Jaxline supervised experiment.""" + + CHECKPOINT_ATTRS = { + "_params": "params", + "_state": "state", + "_opt_state": "opt_state", + } + + NON_BROADCAST_CHECKPOINT_ATTRS = { + "_python_step": "python_step" + } + + def should_run_step( + self, + global_step: int, + config: ml_collections.ConfigDict, + ) -> bool: + return not self.terminate_training(global_step, config) + + def step( # pytype: disable=signature-mismatch + self, + global_step: Array, + rng: PRNGKey, + **unused_kwargs, + ) -> Dict[str, Numeric]: + self.maybe_initialize_state() + return self.train_step(global_step, rng) + + def evaluate( # pytype: disable=signature-mismatch + self, + global_step: Array, + rng: PRNGKey, + **unused_kwargs, + ) -> Dict[str, Numeric]: + return self.run_evaluation(global_step, rng) + + def train_standalone_supervised( random_seed: int, full_config: ml_collections.ConfigDict, experiment_ctor: - Callable[[str, PRNGKey, ml_collections.ConfigDict], SupervisedExperiment], + Callable[[str, PRNGKey, ml_collections.ConfigDict], JaxlineExperiment], storage_folder: Optional[str], ) -> Dict[str, Array]: """Run an experiment without the Jaxline runtime.""" @@ -571,7 +666,7 @@ def train_standalone_supervised( return stats -class MnistExperiment(SupervisedExperiment): +class MnistExperiment(JaxlineExperiment): """An experiment using the MNIST dataset.""" def __init__( @@ -601,8 +696,7 @@ def __init__( model_loss_func=model_loss_func, ) - @property - @functools.lru_cache(maxsize=1) + @functools.cached_property def dataset_size(self) -> int: return 60_000 @@ -612,7 +706,7 @@ def _build_train_input( seed: int, device_batch_size: int, **_: Any, - ) -> datasets.tf.data.Dataset: + ) -> Iterator[Batch]: assert split == "train" return datasets.mnist_dataset( split=split, @@ -622,6 +716,7 @@ def _build_train_input( repeat=True, shuffle=True, drop_remainder=True, + sharding=self.dataset_sharding, seed=seed, reshuffle_each_iteration=True, ) @@ -632,8 +727,7 @@ def _build_eval_input( seed: int, device_batch_size: int, **_: Any, - ) -> datasets.tf.data.Dataset: - + ) -> Iterator[Batch]: assert split in ("train", "test") return datasets.mnist_dataset( @@ -644,11 +738,12 @@ def _build_eval_input( repeat=False, shuffle=False, drop_remainder=False, - seed=seed + sharding=self.dataset_sharding, + seed=seed, ) -class ImageNetExperiment(SupervisedExperiment): +class ImageNetExperiment(JaxlineExperiment): """An experiment using the ImageNet dataset.""" def __init__( diff --git a/kfac_jax/_src/curvature_blocks.py b/kfac_jax/_src/curvature_blocks.py index 1a21368..a0aec7f 100644 --- a/kfac_jax/_src/curvature_blocks.py +++ b/kfac_jax/_src/curvature_blocks.py @@ -15,7 +15,7 @@ import abc import collections import functools -from typing import Optional, Sequence, Any, Set, Tuple, Union, Dict +from typing import Optional, Sequence, Any, Set, Tuple, Union, Dict, Type import jax import jax.numpy as jnp @@ -33,6 +33,7 @@ Shape = utils.Shape DType = utils.DType ScalarOrSequence = Union[Scalar, Sequence[Scalar]] +Cache = Dict[str, Union[Array, Dict[str, Array]]] # Special global variables # The default value that would be used for the argument @@ -44,6 +45,21 @@ # curvature blocks that inherit from ``Full`. _DEFAULT_EIGEN_DECOMPOSITION_THRESHOLD = 5 +_CLASSES_FROM_DICT = {} + + +def add_block_class_from_dict(class_type: Type[Any]) -> Type[Any]: + _CLASSES_FROM_DICT[class_type.__name__] = class_type + return class_type + + +def block_from_dict(dict_rep: Dict[str, Any]) -> "CurvatureBlock": + class_name = dict_rep.pop("__class__") + if class_name not in _CLASSES_FROM_DICT: + raise ValueError(f"Did not find how to reconstruct class {class_name}.") + cls = _CLASSES_FROM_DICT[class_name] + return cls.from_dict(dict_rep) + def set_max_parallel_elements(value: int): """Sets the default value of maximum parallel elements in the module. @@ -140,7 +156,10 @@ class State(utils.State): this are updated via calls to :func:`~CurvatureBlock.update_cache`, and do not necessarily correspond to the most up-to-date curvature estimate. """ - cache: Optional[Dict[str, Union[Array, Dict[str, Array]]]] + cache: Optional[Cache] + + def as_dict(self) -> Dict[str, Any]: + return {"__class__": self.__class__.__name__, "cache": self.cache} def __init__(self, layer_tag_eq: tags.LayerTagEqn, name: str): """Initializes the block. @@ -267,6 +286,17 @@ def __str__(self): return f"{self._name!r}[{self.parameters_shapes!r}]" + @abc.abstractmethod + def state_sharding( + self, + vector_sharding: Sequence[jax.sharding.NamedSharding], + exact_powers_to_cache: Optional[ScalarOrSequence], + approx_powers_to_cache: Optional[ScalarOrSequence], + cache_eigenvalues: bool, + **kwargs, + ) -> "CurvatureBlock.State": + """Constructs the block sharding from the corresponding prameter sharding.""" + @utils.auto_scope_method def init( self, @@ -455,7 +485,6 @@ def update_curvature_matrix_estimate( ema_old: Numeric, ema_new: Numeric, batch_size: Numeric, - pmap_axis_name: Optional[str], ) -> "CurvatureBlock.State": """Updates the block's curvature estimates using the ``info`` provided. @@ -475,8 +504,6 @@ def update_curvature_matrix_estimate( ema_new: Specifies the weight of the new value when computing the updated estimate in the moving average. batch_size: The batch size used in computing the values in ``info``. - pmap_axis_name: The name of any pmap axis, which might be needed for - computing the updates. """ @utils.auto_scope_method @@ -533,6 +560,7 @@ def _to_dense_unscaled(self, state: "CurvatureBlock.State") -> Array: """A dense representation of the curvature, ignoring ``self.scale``.""" +@add_block_class_from_dict class ScaledIdentity(CurvatureBlock): """A block that assumes that the curvature is a scaled identity matrix.""" @@ -566,9 +594,18 @@ def _init( del rng, exact_powers_to_cache, approx_powers_to_cache # Unused - return CurvatureBlock.State( - cache=None, - ) + return CurvatureBlock.State(cache=None) + + def state_sharding( + self, + vector_sharding: Sequence[jax.sharding.NamedSharding], + exact_powers_to_cache: Optional[ScalarOrSequence], + approx_powers_to_cache: Optional[ScalarOrSequence], + cache_eigenvalues: bool, + **kwargs: Any, + ) -> CurvatureBlock.State: + del vector_sharding, kwargs # Unused + return CurvatureBlock.State(cache=None) def _multiply_matpower_unscaled( self, @@ -609,7 +646,6 @@ def update_curvature_matrix_estimate( ema_old: Numeric, ema_new: Numeric, batch_size: Numeric, - pmap_axis_name: Optional[str], ) -> CurvatureBlock.State: return state.copy() @@ -660,6 +696,22 @@ def _init( shape, self.dtype) for shape in self.parameters_shapes), ) + def state_sharding( + self, + vector_sharding: Sequence[jax.sharding.NamedSharding], + exact_powers_to_cache: Optional[ScalarOrSequence], + approx_powers_to_cache: Optional[ScalarOrSequence], + cache_eigenvalues: bool, + **kwargs: Any, + ) -> "Diagonal.State": + return Diagonal.State( + cache=None, + diagonal_factors=tuple( + utils.WeightedMovingAverage.state_sharding(sharding) + for sharding in vector_sharding + ), + ) + def _multiply_matpower_unscaled( self, state: "Diagonal.State", @@ -697,16 +749,12 @@ def update_curvature_matrix_estimate( ema_old: Numeric, ema_new: Numeric, batch_size: Numeric, - pmap_axis_name: Optional[str], ) -> "Diagonal.State": # This function call will return a copy of state: state = self._update_curvature_matrix_estimate( state, estimation_data, ema_old, ema_new, batch_size) - for factor in state.diagonal_factors: - factor.sync(pmap_axis_name) - return state @abc.abstractmethod @@ -854,6 +902,38 @@ def _init( [self.dim, self.dim], self.dtype), ) + def state_sharding( + self, + vector_sharding: Sequence[jax.sharding.NamedSharding], + exact_powers_to_cache: Optional[ScalarOrSequence], + approx_powers_to_cache: Optional[ScalarOrSequence], + cache_eigenvalues: bool, + **kwargs: Any, + ) -> "Full.State": + # This block does not have any notion of "approximate" powers + exact_powers_to_cache = (_to_real_set(exact_powers_to_cache) | + _to_real_set(approx_powers_to_cache)) + cache = {} + + # TODO(botev, jamesmartens): Figure out a better way of getting sharding + sharding = vector_sharding[0] + + if len(exact_powers_to_cache) > self._eigen_decomposition_threshold: + cache["eigenvalues"] = sharding + cache["eigen_vectors"] = sharding + + elif cache_eigenvalues: + cache["eigenvalues"] = sharding + + if len(exact_powers_to_cache) <= self._eigen_decomposition_threshold: + for power in exact_powers_to_cache: + cache[str(power)] = sharding + + return Full.State( + cache=cache, + matrix=utils.WeightedMovingAverage.state_sharding(sharding, **kwargs), + ) + def _multiply_matpower_unscaled( self, state: "Full.State", @@ -912,15 +992,12 @@ def update_curvature_matrix_estimate( ema_old: Numeric, ema_new: Numeric, batch_size: Numeric, - pmap_axis_name: Optional[str], ) -> "Full.State": # This function call will return a copy of state: state = self._update_curvature_matrix_estimate( state, estimation_data, ema_old, ema_new, batch_size) - state.matrix.sync(pmap_axis_name) - return state @abc.abstractmethod @@ -1003,6 +1080,20 @@ class State(CurvatureBlock.State): inputs_factor: utils.WeightedMovingAverage outputs_factor: utils.WeightedMovingAverage + def as_dict(self) -> Dict[str, Any]: + dict_rep = super().as_dict() + dict_rep["inputs_factor"] = self.inputs_factor.as_dict() + dict_rep["outputs_factor"] = self.outputs_factor.as_dict() + return dict_rep + + @classmethod + def from_dict(cls, dict_rep) -> "TwoKroneckerFactored.State": + dict_rep["inputs_factor"] = utils.WeightedMovingAverage.from_dict( + dict_rep["inputs_factor"]) + dict_rep["outputs_factor"] = utils.WeightedMovingAverage.from_dict( + dict_rep["outputs_factor"]) + return cls(**dict_rep) + @property def has_bias(self) -> bool: """Whether this layer's equation has a bias.""" @@ -1084,6 +1175,48 @@ def _init( [d_out, d_out], self.dtype), ) + def state_sharding( + self, + vector_sharding: Sequence[jax.sharding.NamedSharding], + exact_powers_to_cache: Optional[ScalarOrSequence], + approx_powers_to_cache: Optional[ScalarOrSequence], + cache_eigenvalues: bool, + **kwargs: Any, + ) -> "TwoKroneckerFactored.State": + # TODO(jamesmartens,botev): Figure out better way of doing this + input_sharding = vector_sharding[0] + output_sharding = vector_sharding[0] + + exact_powers_to_cache = _to_real_set(exact_powers_to_cache) + approx_powers_to_cache = _to_real_set(approx_powers_to_cache) + cache = {} + + if cache_eigenvalues or exact_powers_to_cache: + cache["inputs_factor_eigenvalues"] = input_sharding + cache["outputs_factor_eigenvalues"] = output_sharding + + if exact_powers_to_cache: + cache["inputs_factor_eigen_vectors"] = input_sharding + cache["outputs_factor_eigen_vectors"] = output_sharding + + for power in approx_powers_to_cache: + if power != -1: + raise NotImplementedError( + f"Approximations for power {power} is not yet implemented." + ) + cache[str(power)] = dict( + inputs_factor=input_sharding, + outputs_factor=output_sharding, + ) + + return TwoKroneckerFactored.State( + cache=cache, + inputs_factor=utils.WeightedMovingAverage.state_sharding( + input_sharding), + outputs_factor=utils.WeightedMovingAverage.state_sharding( + output_sharding) + ) + def _multiply_matpower_unscaled( self, state: "TwoKroneckerFactored.State", @@ -1169,16 +1302,12 @@ def update_curvature_matrix_estimate( ema_old: Numeric, ema_new: Numeric, batch_size: Numeric, - pmap_axis_name: Optional[str], ) -> "TwoKroneckerFactored.State": # This function call will return a copy of state: state = self._update_curvature_matrix_estimate( state, estimation_data, ema_old, ema_new, batch_size) - state.inputs_factor.sync(pmap_axis_name) - state.outputs_factor.sync(pmap_axis_name) - return state @abc.abstractmethod @@ -1262,6 +1391,7 @@ def _to_dense_unscaled( return jnp.kron(inputs_factor, state.outputs_factor.value) +@add_block_class_from_dict class NaiveDiagonal(Diagonal): """Approximates the diagonal of the curvature with in the most obvious way. @@ -1289,6 +1419,7 @@ def _update_curvature_matrix_estimate( return state +@add_block_class_from_dict class NaiveFull(Full): """Approximates the full curvature with in the most obvious way. @@ -1327,6 +1458,7 @@ def _update_curvature_matrix_estimate( # +@add_block_class_from_dict class DenseDiagonal(Diagonal): """A `Diagonal` block specifically for dense layers.""" @@ -1364,6 +1496,7 @@ def _update_curvature_matrix_estimate( return state +@add_block_class_from_dict class DenseFull(Full): """A `Full` block specifically for dense layers.""" @@ -1395,6 +1528,7 @@ def _update_curvature_matrix_estimate( return state +@add_block_class_from_dict class DenseTwoKroneckerFactored(TwoKroneckerFactored): """A :class:`~TwoKroneckerFactored` block specifically for dense layers.""" @@ -1451,6 +1585,7 @@ def _update_curvature_matrix_estimate( # +@add_block_class_from_dict class Conv2DDiagonal(Diagonal): """A :class:`~Diagonal` block specifically for 2D convolution layers.""" @@ -1545,6 +1680,7 @@ def _update_curvature_matrix_estimate( return state +@add_block_class_from_dict class Conv2DFull(Full): """A :class:`~Full` block specifically for 2D convolution layers.""" @@ -1637,6 +1773,7 @@ def _update_curvature_matrix_estimate( return state +@add_block_class_from_dict class Conv2DTwoKroneckerFactored(TwoKroneckerFactored): """A :class:`~TwoKroneckerFactored` block specifically for 2D convolution layers.""" @@ -1832,6 +1969,7 @@ def compatible_sum(tensor, target_shape, skip_axes): return jnp.sum(tensor, axis=axis) +@add_block_class_from_dict class ScaleAndShiftDiagonal(Diagonal): """A diagonal approximation specifically for a scale and shift layers.""" @@ -1893,6 +2031,7 @@ def _update_curvature_matrix_estimate( return state +@add_block_class_from_dict class ScaleAndShiftFull(Full): """A full dense approximation specifically for a scale and shift layers.""" diff --git a/kfac_jax/_src/curvature_estimator.py b/kfac_jax/_src/curvature_estimator.py index 58bd99a..8dd0bdc 100644 --- a/kfac_jax/_src/curvature_estimator.py +++ b/kfac_jax/_src/curvature_estimator.py @@ -634,7 +634,6 @@ def multiply_matpower( power: Scalar, exact_power: bool, use_cached: bool, - pmap_axis_name: Optional[str], ) -> utils.Params: """Computes ``(CurvatureMatrix + identity_weight I)**power`` times ``vector``. @@ -654,9 +653,6 @@ def multiply_matpower( vary across different blocks. use_cached: Whether to use a cached (and possibly stale) version of the curvature matrix estimate. - pmap_axis_name: The name of any pmap axis, which will be used for - aggregating any computed values over multiple devices, as well as - parallelizing the computation over devices in a block-wise fashion. Returns: A parameter structured vector containing the product. @@ -669,7 +665,6 @@ def multiply( identity_weight: Numeric, exact_power: bool, use_cached: bool, - pmap_axis_name: Optional[str], ) -> utils.Params: """Computes ``(CurvatureMatrix + identity_weight I)`` times ``vector``.""" @@ -680,7 +675,6 @@ def multiply( power=1, exact_power=exact_power, use_cached=use_cached, - pmap_axis_name=pmap_axis_name ) def multiply_inverse( @@ -690,7 +684,6 @@ def multiply_inverse( identity_weight: Numeric, exact_power: bool, use_cached: bool, - pmap_axis_name: Optional[str], ) -> utils.Params: """Computes ``(CurvatureMatrix + identity_weight I)^-1`` times ``vector``.""" @@ -701,7 +694,6 @@ def multiply_inverse( power=-1, exact_power=exact_power, use_cached=use_cached, - pmap_axis_name=pmap_axis_name ) @abc.abstractmethod @@ -732,7 +724,6 @@ def update_curvature_matrix_estimate( batch_size: Numeric, rng: PRNGKey, func_args: utils.FuncArgs, - pmap_axis_name: Optional[str], estimation_mode: Optional[str] = None, ) -> StateType: """Updates the estimator's curvature estimates. @@ -750,9 +741,6 @@ def update_curvature_matrix_estimate( function (the ``tagged_func`` passed into the constructor) which to be used for the estimation process. Should have the same structure as the argument ``func_args`` passed in the constructor. - pmap_axis_name: When calling this method within a pmap context this - argument specifies the axis name over which to aggregate across - multiple devices/hosts. estimation_mode: The type of curvature estimator to use. By default (e.g. if ``None``) will use ``self.default_estimation_mode``. One of: @@ -793,7 +781,6 @@ def update_cache( exact_powers: Optional[curvature_blocks.ScalarOrSequence], approx_powers: Optional[curvature_blocks.ScalarOrSequence], eigenvalues: bool, - pmap_axis_name: Optional[str], ) -> StateType: """Updates the estimator cached values. @@ -810,9 +797,6 @@ def update_cache( eigenvalues: Specifies whether to update the cached eigenvalues of each block. If they have not been cached before, this will create an entry with them in the block's cache. - pmap_axis_name: The name of any pmap axis, which will be used for - aggregating any computed values over multiple devices, as well as - parallelizing the computation over devices in a block-wise fashion. Returns: The updated state. @@ -837,6 +821,19 @@ class State(utils.State): """ blocks_states: Tuple[curvature_blocks.CurvatureBlock.State, ...] + def as_dict(self) -> Dict[str, Any]: + dict_reps = tuple(block.as_dict() for block in self.blocks_states) + return {"blocks_states": dict_reps} + + @classmethod + def from_dict( + cls, + dict_rep: Dict[str, Any], + ) -> "BlockDiagonalCurvature.State": + return cls(blocks_states=tuple( + curvature_blocks.block_from_dict(block_rep) + for block_rep in dict_rep["blocks_states"])) + def __init__( self, func: utils.Func, @@ -1016,6 +1013,30 @@ def _compute_losses_vjp(self, func_args: utils.FuncArgs): """Computes all model statistics needed for estimating the curvature.""" return self._vjp(func_args) + def state_sharding( + self, + params_sharding: utils.ShardingTree, + exact_powers_to_cache: Optional[curvature_blocks.ScalarOrSequence], + approx_powers_to_cache: Optional[curvature_blocks.ScalarOrSequence], + cache_eigenvalues: bool = False, + **kwargs: Any, + ) -> "BlockDiagonalCurvature.State": + blocks_vectors_sharding: list[utils.ShardingTree] = ( # pytype: disable=annotation-type-mismatch + self.params_vector_to_blocks_vectors(params_sharding)) + + shardings = [] + for block, vector_sharding in zip(self.blocks, blocks_vectors_sharding): + block_sharding = block.state_sharding( + vector_sharding, + exact_powers_to_cache=exact_powers_to_cache, + approx_powers_to_cache=approx_powers_to_cache, + cache_eigenvalues=cache_eigenvalues, + **kwargs, + ) + shardings.append(block_sharding) + + return BlockDiagonalCurvature.State(blocks_states=tuple(shardings)) + def params_vector_to_blocks_vectors( self, parameter_structured_vector: utils.Params, @@ -1097,7 +1118,6 @@ def multiply_matpower( power: Scalar, exact_power: bool, use_cached: bool, - pmap_axis_name: Optional[str], ) -> utils.Params: blocks_vectors = self.params_vector_to_blocks_vectors( @@ -1121,13 +1141,7 @@ def multiply_matpower( ) ) - if self._distributed_multiplies and pmap_axis_name is not None: - - result = utils.distribute_thunks(thunks, pmap_axis_name) - - else: - result = tuple(thunk() for thunk in thunks) - + result = tuple(thunk() for thunk in thunks) parameter_structured_result = self.blocks_vectors_to_params_vector(result) assert utils.abstract_objects_equal( @@ -1178,7 +1192,6 @@ def update_curvature_matrix_estimate( batch_size: Numeric, rng: PRNGKey, func_args: utils.FuncArgs, - pmap_axis_name: Optional[str], estimation_mode: Optional[str] = None, ) -> "BlockDiagonalCurvature.State": @@ -1210,7 +1223,7 @@ def update_blocks(vjp_vec_, state_, ema_old_, ema_new_): new_state.append(block_.update_curvature_matrix_estimate( block_state_, block_info_, ema_old_, ema_new_, - batch_size, pmap_axis_name)) + batch_size)) return BlockDiagonalCurvature.State(blocks_states=tuple(new_state)) @@ -1313,7 +1326,6 @@ def update_cache( exact_powers: Optional[curvature_blocks.ScalarOrSequence], approx_powers: Optional[curvature_blocks.ScalarOrSequence], eigenvalues: bool, - pmap_axis_name: Optional[str], ) -> "BlockDiagonalCurvature.State": identity_weight = utils.to_tuple_or_repeat(identity_weight, self.num_blocks) @@ -1332,39 +1344,7 @@ def update_cache( ) ) - if self._distributed_cache_updates and pmap_axis_name is not None: - - assert utils.in_pmap(pmap_axis_name) - - def filter_outputs(thunk, vals): - - # We must precompute the matches outside of the thunk itself, as the - # thunk will be traced separately from the current compiled context - # (since it's called within a lax.switch statement). - matches = jax.tree_util.tree_map(lambda o, v: o is v, thunk(), vals) - - def new_thunk(): - return jax.tree_util.tree_map( - lambda o, m: None if m else o, thunk(), matches - ) - return new_thunk - - # Create new thunks that only return the state arrays that they actually - # modify. This should reduce the communication costs associated with the - # syncs performed by utils.distribute_thunks. - filtered_thunks = tuple( - filter_outputs(thunk, block_state) - for thunk, block_state in zip(thunks, state.blocks_states)) - - new_states = utils.distribute_thunks(filtered_thunks, pmap_axis_name) - - # Restore all of the unmodified state arrays. - new_states = jax.tree_util.tree_map(lambda s, n: s if n is None else n, - state.blocks_states, new_states) - - else: - new_states = tuple(thunk() for thunk in thunks) - + new_states = tuple(thunk() for thunk in thunks) return BlockDiagonalCurvature.State(blocks_states=new_states) @utils.auto_scope_method @@ -1535,7 +1515,6 @@ def update_curvature_matrix_estimate( batch_size: Numeric, rng: PRNGKey, func_args: utils.FuncArgs, - pmap_axis_name: Optional[str], estimation_mode: Optional[str] = None, ) -> curvature_blocks.Full.State: @@ -1561,7 +1540,6 @@ def single_state_update( batch_size=1, rng=rng[index], func_args=args, - pmap_axis_name=pmap_axis_name, estimation_mode=estimation_mode, ) @@ -1574,7 +1552,6 @@ def update_cache( exact_powers: Optional[curvature_blocks.ScalarOrSequence], approx_powers: Optional[curvature_blocks.ScalarOrSequence], eigenvalues: bool, - pmap_axis_name: Optional[str], ) -> curvature_blocks.Full.State: block_state = self.blocks[0].update_cache( diff --git a/kfac_jax/_src/optimizer.py b/kfac_jax/_src/optimizer.py index 560f1e7..faf7db1 100644 --- a/kfac_jax/_src/optimizer.py +++ b/kfac_jax/_src/optimizer.py @@ -81,6 +81,22 @@ class State(Generic[Params], utils.State): data_seen: Numeric step_counter: Numeric + def as_dict(self) -> Dict[str, Any]: + return { + "velocities": self.velocities, + "estimator_state": self.estimator_state.as_dict(), + "damping": self.damping, + "data_seen": self.data_seen, + "step_counter": self.step_counter, + } + + @classmethod + def from_dict(cls, dict_representation: Dict[str, Any]) -> OptimizerState: + dict_representation["estimator_state"] = ( + curvature_estimator.BlockDiagonalCurvature.State.from_dict( + dict_representation["estimator_state"])) + return cls(**dict_representation) + def __init__( self, value_and_grad_func: ValueAndGradFunc, @@ -113,13 +129,12 @@ def __init__( register_only_generic: bool = False, patterns_to_skip: Sequence[str] = (), auto_register_kwargs: Optional[Dict[str, Any]] = None, - layer_tag_to_block_ctor: - Optional[Dict[str, curvature_estimator.CurvatureBlockCtor]] = None, - multi_device: bool = False, + layer_tag_to_block_ctor: Optional[ + Dict[str, curvature_estimator.CurvatureBlockCtor] + ] = None, debug: bool = False, batch_size_extractor: Callable[[Batch], Numeric] = utils.default_batch_size_extractor, - pmap_axis_name: str = "kfac_axis", forbid_setting_attributes_after_finalize: bool = True, modifiable_attribute_exceptions: Sequence[str] = (), include_norms_in_stats: bool = False, @@ -279,16 +294,11 @@ def __init__( that specific tag. See the documentation for :class:`~CurvatureEstimator` for a more detailed description. (Default: ``None``) - multi_device: Boolean. Whether to use pmap and run the optimizer on - multiple devices. (Default: ``False``) debug: Boolean. If neither the step or init functions should be jitted. - Note that this also overrides ``multi_device`` and prevents using pmap. (Default: ``False``) batch_size_extractor: A function that takes as input the function arguments and returns the batch size for a single device. (Default: ``kfac.utils.default_batch_size_extractor``) - pmap_axis_name: String. The name of the pmap axis to use when - ``multi_device`` is set to True. (Default: ``kfac_axis``) forbid_setting_attributes_after_finalize: Boolean. By default after the object is finalized, you can not set any of its properties. This is done in order to protect the user from making changes to the object @@ -318,8 +328,6 @@ def __init__( (Default: True) """ super().__init__( - multi_device=multi_device, - pmap_axis_name=pmap_axis_name if multi_device else None, debug=debug, forbid_setting_attributes_after_finalize= forbid_setting_attributes_after_finalize, @@ -347,7 +355,7 @@ def __init__( self._value_func_has_state = value_func_has_state self._value_func_has_rng = value_func_has_rng self._value_func: ValueFunc = convert_value_and_grad_to_value_func( - value_and_grad_func, + self._value_and_grad_func, has_aux=value_func_has_aux, ) self._l2_reg = jnp.asarray(l2_reg) @@ -413,10 +421,16 @@ def schedule_with_first_step_zero(global_step: Array) -> Array: batch_size_extractor=batch_size_extractor, ) - # Each subclass should call finalize on its own, so this gets called only - # for instances of exactly this class type. - if type(self) == Optimizer: # pylint: disable=unidiomatic-typecheck - self.finalize() + self._mesh = None + self._params_sharding = None + self._func_state_sharding = None + self._state_sharding = None + self._scalar_sharding = None + + @property + def state_sharding(self) -> jax.sharding.NamedSharding: + assert self._state_sharding is not None + return self._state_sharding @property def num_burnin_steps(self) -> int: @@ -527,12 +541,9 @@ def verify_args_and_get_step_counter( "not pass a value to the step function.") if global_step_int is None: - if self.multi_device: - return int(utils.get_first(step_counter)) - else: - return int(step_counter) - - return global_step_int + return int(step_counter) + else: + return global_step_int @utils.staged def _setup_state_and_schedules( @@ -609,7 +620,6 @@ def _update_estimator_curvature( batch_size=self._batch_size_extractor(func_args[-1]), rng=rng, func_args=func_args, - pmap_axis_name=self.pmap_axis_name ) @utils.auto_scope_method @@ -644,7 +654,6 @@ def _maybe_update_inverse_cache( exact_powers=self._exact_powers_to_cache, approx_powers=self._approx_powers_to_cache, eigenvalues=False, - pmap_axis_name=self.pmap_axis_name, ), lambda state_: state_, state.estimator_state @@ -669,7 +678,6 @@ def _compute_preconditioned_gradient( identity_weight=self.l2_reg + damping, exact_power=self._use_exact_inverses, use_cached=self._use_cached_inverses, - pmap_axis_name=self.pmap_axis_name, ) if self._norm_constraint is not None: @@ -765,14 +773,73 @@ def _update_damping( new_loss = self.compute_loss_value(new_func_args) - # Sync - new_loss = utils.pmean_if_pmap(new_loss, self.pmap_axis_name) - damping, rho = self._compute_new_damping_and_rho( old_loss, new_loss, quad_change, old_damping) return damping, rho, new_loss + def _finalize( + self, + params: Params, + rng: PRNGKey, + batch: Batch, + func_state: Optional[FuncState] = None, + ): + if not self._estimator.finalized: + self._estimator.finalize( + make_func_args( + params=params, + func_state=func_state, + rng=rng, + batch=self._batch_process_func(batch), + has_state=self._value_func_has_state, + has_rng=self._value_func_has_rng, + ), + ) + + # Compute shardings + self._params_sharding = utils.get_sharding(params) + self._mesh = jax.tree_util.tree_leaves(self._params_sharding)[0].mesh + self._func_state_sharding = utils.get_sharding(func_state) + self._scalar_sharding = jax.sharding.NamedSharding( + self._mesh, jax.sharding.PartitionSpec() + ) + self._state_sharding = Optimizer.State( + velocities=self._params_sharding, + estimator_state=self.estimator.state_sharding( + self._params_sharding, + exact_powers_to_cache=self._exact_powers_to_cache, + approx_powers_to_cache=self._approx_powers_to_cache, + cache_eigenvalues=False, + ), + damping=self._scalar_sharding, + data_seen=self._scalar_sharding, + step_counter=self._scalar_sharding, + ) + + def _shard_params(self, params: Params) -> Params: + return jax.lax.with_sharding_constraint(params, self._params_sharding) + + def _shard_state(self, state: OptimizerState) -> OptimizerState: + return jax.lax.with_sharding_constraint(state, self._state_sharding) + + def _shard_func_state(self, func_state: FuncState) -> FuncState: + if func_state is None: + return + return jax.lax.with_sharding_constraint( + func_state, self._func_state_sharding) + + def _shard_func_state_acc( + self, + acccumulator: utils.MultiChunkAccumulator, + ) -> utils.MultiChunkAccumulator: + if self._func_state_sharding is None: + return acccumulator + return jax.lax.with_sharding_constraint( + acccumulator, + utils.MultiChunkAccumulator.state_sharding(self._func_state_sharding), + ) + @utils.staged def _init( self, @@ -780,11 +847,11 @@ def _init( rng: PRNGKey, batch: Batch, func_state: Optional[FuncState] = None, - ) -> "Optimizer.State": + ) -> OptimizerState: """A staged function to initialize the optimizer state .""" - return Optimizer.State( - velocities=jax.tree_util.tree_map(jnp.zeros_like, params), + return self._shard_state(Optimizer.State( + velocities=params, estimator_state=self.estimator.init( rng=rng, func_args=make_func_args( @@ -803,7 +870,7 @@ def _init( if self._use_adaptive_damping else None), data_seen=jnp.array(0, dtype=int), step_counter=jnp.array(0, dtype=int) - ) + )) def init( self, @@ -813,11 +880,15 @@ def init( func_state: Optional[FuncState] = None, ) -> "Optimizer.State": """Initializes the optimizer and returns the appropriate optimizer state.""" - if not self.finalized: self.finalize(params, rng, batch, func_state) - return self._init(params, rng, batch, func_state) + return self._init( + params, + rng, + batch, + func_state, + ) @functools.partial(utils.staged, donate_argnums=[1, 3, 5]) def _burnin( @@ -849,7 +920,7 @@ def _burnin( accumulator.add(func_state) - return state, accumulator + return self._shard_state(state), self._shard_func_state_acc(accumulator) def burnin( self, @@ -863,16 +934,19 @@ def burnin( """Runs all burnin steps required.""" if num_steps > 0: - rng = self._rng_split(rng, num_steps) - - accumulator = utils.MultiChunkAccumulator.zeros_like( - func_state, self.multi_device) + rng = jax.random.split(rng, num_steps) + accumulator = utils.MultiChunkAccumulator.zeros_like(func_state) for rng_i in rng: batch = next(data_iterator) - state, accumulator = self._burnin( - params, state, rng_i, batch, func_state, accumulator) + params, + state, + rng_i, + batch, + func_state, + accumulator, + ) func_state = accumulator.value_and_clear() @@ -917,9 +991,7 @@ def _step( # Compute loss and gradients loss, grads, func_state, aux = self._compute_loss_and_grads(func_args) - - # Sync - loss, grads = utils.pmean_if_pmap((loss, grads), self.pmap_axis_name) + func_state = self._shard_func_state(func_state) if self._include_norms_in_stats: grad_norm = utils.norm(grads) @@ -964,7 +1036,7 @@ def _step( update_norm_per_param = utils.per_parameter_norm(delta, "update_norm") # Update parameters - params = jax.tree_util.tree_map(jnp.add, params, delta) + params = self._shard_params(jax.tree_util.tree_map(jnp.add, params, delta)) # Optionally compute the reduction ratio and update the damping if self._use_adaptive_damping: @@ -979,17 +1051,11 @@ def _step( else: new_loss, rho = jnp.nan, jnp.nan - # Compute per-device and total batch size - batch_size = self._batch_size_extractor(func_args[-1]) - - if self.multi_device: - total_batch_size = batch_size * jax.device_count() - else: - total_batch_size = batch_size - # Update data seen and step counter - state.data_seen = state.data_seen + total_batch_size + batch_size = self._batch_size_extractor(func_args[-1]) + state.data_seen = state.data_seen + batch_size state.step_counter = state.step_counter + 1 + state = self._shard_state(state) # Statistics with useful information # Unlike other norm stats, sq_norm_scaled_grads has to be computed if @@ -998,7 +1064,7 @@ def _step( # no other grad stats are desired. stats = dict( step=state.step_counter, - batch_size=jnp.asarray(total_batch_size, dtype=jnp.int32), + batch_size=jnp.asarray(batch_size, dtype=jnp.int32), data_seen=state.data_seen, loss=loss, new_loss=new_loss, @@ -1025,6 +1091,8 @@ def _step( stats.update(precon_grad_norm_per_param) stats.update(update_norm_per_param) + stats = jax.lax.with_sharding_constraint(stats, self._scalar_sharding) + if self._value_func_has_state: return params, state, func_state, stats else: @@ -1078,7 +1146,6 @@ def step( * stats is a dictionary of useful statistics including the loss. """ - if (data_iterator is None) == (batch is None): raise ValueError("Exactly one of the arguments ``data_iterator`` and " "``batch`` must be provided.") @@ -1095,7 +1162,7 @@ def step( if data_iterator is not None: - rng, burnin_rng = self._rng_split(rng, 2) + rng, burnin_rng = jax.random.split(rng) state, func_state = self.burnin( num_steps=self.num_burnin_steps, params=params, @@ -1108,8 +1175,16 @@ def step( if data_iterator is not None: batch = next(data_iterator) - return self._step(params, state, rng, batch, func_state, - learning_rate, momentum, damping) + return self._step( + params, + state, + rng, + batch, + func_state, + learning_rate, + momentum, + damping, + ) def compute_l2_quad_matrix( self, @@ -1169,7 +1244,6 @@ def c_times_v(v): identity_weight=0.0, exact_power=True, use_cached=False, - pmap_axis_name=self.pmap_axis_name, ) c_vectors = [c_times_v(v_i) for v_i in vectors] @@ -1240,13 +1314,6 @@ def _solve_quad_model( A_no_diag, D, b = quad_model_parameters A = A_no_diag + self.compute_l2_quad_matrix(vectors) A_damped = A + damping * D - - # Sync. - # TODO(jamesmartens, botev): we should perform this earlier since it's - # dangerous to have the convention of doing it right before use (especially - # since the convention everywhere else is to sync quantities immediately - # after they are first computed). - A, A_damped, b = utils.pmean_if_pmap((A, A_damped, b), self.pmap_axis_name) # This needs explicit annotation A_damped: Array diff --git a/kfac_jax/_src/tag_graph_matcher.py b/kfac_jax/_src/tag_graph_matcher.py index 5d721d0..e8b26ea 100644 --- a/kfac_jax/_src/tag_graph_matcher.py +++ b/kfac_jax/_src/tag_graph_matcher.py @@ -443,6 +443,7 @@ def graph(self) -> JaxprGraph: clean_broadcasts=True, ) object.__setattr__(self, "_graph", graph) + assert self._graph is not None return self._graph def tag_ctor( diff --git a/kfac_jax/_src/utils/__init__.py b/kfac_jax/_src/utils/__init__.py index 13f91c0..cc62ef4 100644 --- a/kfac_jax/_src/utils/__init__.py +++ b/kfac_jax/_src/utils/__init__.py @@ -29,6 +29,7 @@ DType = types.DType PyTree = types.PyTree ArrayTree = types.ArrayTree +ShardingTree = types.ShardingTree TArrayTree = types.TArrayTree Params = types.Params Batch = types.Batch @@ -48,6 +49,7 @@ del types # misc +get_sharding = misc.get_sharding to_tuple_or_repeat = misc.to_tuple_or_repeat first_dim_is_size = misc.first_dim_is_size fake_element_from_iterator = misc.fake_element_from_iterator diff --git a/kfac_jax/_src/utils/accumulators.py b/kfac_jax/_src/utils/accumulators.py index db94b49..2011d7a 100644 --- a/kfac_jax/_src/utils/accumulators.py +++ b/kfac_jax/_src/utils/accumulators.py @@ -12,13 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. """K-FAC for accumulating statistics.""" -from typing import Any, Optional, Generic +from typing import Optional, Generic, Any, Dict import jax import jax.numpy as jnp from kfac_jax._src.utils import misc -from kfac_jax._src.utils import parallel from kfac_jax._src.utils import types Array = types.Array @@ -30,16 +29,17 @@ @misc.pytree_dataclass -class WeightedMovingAverage(Generic[TArrayTree]): +class WeightedMovingAverage(Generic[TArrayTree], misc.State): """A wrapped class for an arbitrary weighted moving average.""" weight: Numeric raw_value: Optional[TArrayTree] @property - def value(self) -> TArrayTree: + def value(self) -> Optional[TArrayTree]: """The value of the underlying arrays data structure.""" - if self.raw_value is None: - raise ValueError("`raw_value` has not been set yet.") + if types.tree_is_empty(self.raw_value): + return self.raw_value + return jax.tree_util.tree_map(lambda x: x / self.weight, self.raw_value) def update( @@ -61,12 +61,6 @@ def update( value, ) - def sync(self, pmap_axis_name: Optional[str]): - """Syncs the underlying array across devices.""" - if self.raw_value is None: - raise ValueError("`raw_value` has not been set yet.") - self.raw_value = parallel.pmean_if_pmap(self.raw_value, pmap_axis_name) - def clear(self, value_to_none: bool = False): """Resets the weighted average.""" self.weight = jnp.zeros_like(self.weight) @@ -78,7 +72,7 @@ def value_and_clear(self) -> TArrayTree: self.clear() return value - def copy(self) -> "WeightedMovingAverage[TArrayTree]": + def copy(self): """Returns a copy of the PyTree structure (but not the JAX arrays).""" (flattened, structure) = jax.tree_util.tree_flatten(self) return jax.tree_util.tree_unflatten(structure, flattened) @@ -90,7 +84,7 @@ def zeros_array( dtype: Optional[DType] = None, ) -> "WeightedMovingAverage[Array]": """Initializes a `WeightedMovingAverage` with a single array of zeros.""" - return WeightedMovingAverage( + return cls( # pytype: disable=wrong-keyword-args weight=jnp.zeros([], dtype=dtype), raw_value=jnp.zeros(shape, dtype=dtype), ) @@ -98,178 +92,50 @@ def zeros_array( @classmethod def zeros_like(cls, value: TArrayTree) -> "WeightedMovingAverage[TArrayTree]": """Initializes a `WeightedMovingAverage` with zeros structure like `value`.""" - return WeightedMovingAverage( - weight=jnp.array( - 0.0, dtype=types.get_float_dtype_and_check_consistency(value) - ), - raw_value=jax.tree_util.tree_map(jnp.zeros_like, value), + dtype = types.get_float_dtype_and_check_consistency(value) + weight = jnp.array(0.0, dtype=dtype) + if value is not None: + weight = jax.device_put(weight, jax.sharding.NamedSharding( + jax.tree_leaves(value)[0].sharding.mesh, + jax.sharding.PartitionSpec(), + )) + return cls( # pytype: disable=wrong-keyword-args + weight=weight, + raw_value=misc.zeros_like_with_sharding(value), ) @classmethod def empty(cls, dtype: Optional[DType] = None) -> "WeightedMovingAverage[Any]": """Returns an empty moving average instance.""" weight = jnp.zeros([]) if dtype is None else jnp.zeros([], dtype=dtype) - return WeightedMovingAverage(weight=weight, raw_value=None) - - def __repr__(self) -> str: - return f"{self.__class__.__name__}({self.weight!r}, {self.raw_value!r})" - - -class MultiChunkAccumulator(Generic[TArrayTree]): - """Statistics accumulation, abstracted over multiple chunks.""" - - def __init__( - self, - init_obj_value: Optional[TArrayTree], - weight: Numeric, - multi_device: bool, - ): - """Initializes an accumulator instance with the provided object and counter. - - Args: - init_obj_value: The initial value of the accumulator. - weight: The initial weight, which specifies how many samples are assumed - to have been already counted in the initial value of the accumulator. - multi_device: Whether the objects that are accumulated are outputs of a - multi-device computation (e.g. `jax.pmap`). - """ - self._accumulator = init_obj_value - self._weight = weight - self._multi_device = multi_device - - @property - def accumulator(self) -> TArrayTree: - """The current value of the underlying not-normalized accumulator.""" - return self._accumulator - - @property - def weight(self) -> Numeric: - """The current normalization weight of the underlying accumulator.""" - return self._weight - - @property - def multi_device(self) -> bool: - """Whether the accumulator is the output of a multi-device computation.""" - return self._multi_device - - @property - def value(self) -> TArrayTree: - """The current normalized value of the accumulator.""" - - if types.tree_is_empty(self.accumulator): - return self.accumulator - - if self._multi_device: - return parallel.pmap_sync_and_divide_value(self.accumulator, self.weight) - else: - return parallel.jit_sync_and_divide_value(self.accumulator, self.weight) - - def clear(self) -> None: - """Sets the underlying accumulator and weight to `None`.""" - self._accumulator = None - self._weight = None - - def value_and_clear(self) -> TArrayTree: - """Retrieves the normalized value of the accumulator and clears it.""" - value = self.value - self.clear() - return value - - def add(self, value_obj: TArrayTree, weight: Numeric = 1): - """Adds an element to the moving average and the max. - - The exact update equation for the statistics are: - raw_value_t = raw_value_{t-1} + value_obj * weight - weight_t = weight_{t-1} + weight - - Args: - value_obj: The value of the object, which scaled by `weight` will be added - to the accumulator. - weight: The relative weight of the `value_obj`. - """ - - value_obj = jax.tree_util.tree_map(lambda x: x * weight, value_obj) - - if self._accumulator is None: - - self._accumulator = value_obj - - if isinstance(weight, types.SCALAR_TYPES): - self._weight = jnp.full_like(self._weight, weight) - - elif not isinstance(weight, jax.Array): - raise ValueError("`weight` should be an instance of float, int or " - "jax.Array.") - - elif self._weight.shape != weight.shape: # pytype: disable=attribute-error # numpy-scalars - raise ValueError("If `weight` is an `jnp.ndarray` then should have the " - "same shape as the weight of the accumulator.") - else: - self._weight = weight - - return - - if not types.tree_is_empty(self._accumulator): - - if types.tree_is_empty(value_obj): - raise ValueError("The provided `value_obj` has an empty PyTree " - "structure, but the accumulator has been initialized " - "with a non-empty PyTree object.") - - self._accumulator = jax.tree_util.tree_map( - jnp.add, self._accumulator, value_obj) - - elif not types.tree_is_empty(value_obj): - - raise ValueError("The provided `value_obj` has a non-empty PyTree " - "structure, but the accumulator has been initialized " - "with an empty PyTree object.") - - self._weight = self._weight + weight + return cls(weight=weight, raw_value=None) # pytype: disable=wrong-keyword-args @classmethod - def zeros_like( + def state_sharding( cls, - obj: TArrayTree, - multi_device: bool - ) -> "MultiChunkAccumulator[TArrayTree]": - """Creates a zero initialized accumulator as `obj`.""" - - if multi_device: - value = (parallel.pmap_zeros_like(obj) - if not types.tree_is_empty(obj) else obj) - weight = parallel.replicate_all_local_devices( - jnp.zeros([], dtype=jnp.int32)) - else: - value = (parallel.jit_zeros_like(obj) - if not types.tree_is_empty(obj) else obj) - weight = jnp.zeros([], dtype=jnp.int32) - - return cls(value, weight, multi_device) - - @classmethod - def empty(cls, multi_device: bool) -> "MultiChunkAccumulator[Any]": - """Creates an empty accumulator.""" - - weight = jnp.zeros([], dtype=jnp.int32) + sharding: jax.sharding.NamedSharding, + ) -> "WeightedMovingAverage[jax.sharding.NamedSharding]": + return cls( # pytype: disable=wrong-keyword-args + weight=jax.sharding.NamedSharding( + sharding.mesh, jax.sharding.PartitionSpec() + ), + raw_value=sharding, + ) - if multi_device: - weight = parallel.replicate_all_local_devices(weight) + def as_dict(self) -> Dict[str, Any]: + return {"weight": self.weight, "raw_value": self.raw_value} - return cls(None, weight, multi_device) + def __repr__(self) -> str: + return f"{self.__class__.__name__}({self.weight!r}, {self.raw_value!r})" - def __repr__(self): - return (f"{self.__class__.__name__}({self._accumulator!r}, " - f"{self._weight!r}, {self._multi_device})") - def copy(self): - """Returns a copy of the PyTree structure (but not the JAX arrays).""" - (flattened, structure) = jax.tree_util.tree_flatten(self) - return jax.tree_util.tree_unflatten(structure, flattened) +@misc.pytree_dataclass +class MultiChunkAccumulator(WeightedMovingAverage[TArrayTree]): + """Statistics accumulation, abstracted over multiple chunks.""" + def add(self, value: TArrayTree, weight: Numeric = 1): + """Adds an element to the moving average and the max.""" + return self.update(value, 1, weight) -jax.tree_util.register_pytree_node( - MultiChunkAccumulator, - lambda x: ((x.accumulator, x.weight), (x.multi_device,)), - lambda fixed, arrays: MultiChunkAccumulator(*arrays, *fixed) -) + def clear(self, value_to_none: bool = True): + return super().clear(value_to_none=value_to_none) diff --git a/kfac_jax/_src/utils/misc.py b/kfac_jax/_src/utils/misc.py index ba5b2a9..7708ec6 100644 --- a/kfac_jax/_src/utils/misc.py +++ b/kfac_jax/_src/utils/misc.py @@ -15,7 +15,7 @@ import abc import dataclasses import functools -from typing import Any, Iterator, Sequence, Type, Tuple, Union +from typing import Any, Iterator, Sequence, Type, Tuple, Union, Dict import jax import jax.numpy as jnp @@ -25,6 +25,18 @@ Numeric = types.Numeric ArrayTree = types.ArrayTree TArrayTree = types.TArrayTree +ShardingTree = types.ShardingTree + + +def get_sharding(x: ArrayTree) -> types.ShardingTree: + return jax.tree_util.tree_map(lambda x: x.sharding, x) + + +def zeros_like_with_sharding(x: TArrayTree) -> TArrayTree: + def zero_like_array(x: Array) -> Array: + return jax.device_put(jnp.zeros_like(x), x.sharding) + + return jax.tree_util.tree_map(zero_like_array, x) def fake_element_from_iterator( @@ -47,7 +59,8 @@ def fake_element_from_iterator( equivalent iterator to the input one. """ init_element = next(iterator) - fake_element = jax.tree_util.tree_map(jnp.zeros_like, init_element) + fake_element = zeros_like_with_sharding(init_element) + def equivalent_iterator() -> Iterator[ArrayTree]: yield init_element # For some reason unknown to us, "yield from" can fail in certain @@ -121,13 +134,23 @@ def unflatten(_: Any, args: Sequence[Any]) -> Any: @pytree_dataclass -class State(object): +class State(abc.ABC): + """Abstract class for optimizer state.""" def copy(self): """Returns a copy of the PyTree structure (but not the JAX arrays).""" (flattened, structure) = jax.tree_util.tree_flatten(self) return jax.tree_util.tree_unflatten(structure, flattened) + @abc.abstractmethod + def as_dict(self) -> Dict[str, Any]: + """Returns a recursively constructed dictionary of the state.""" + + @classmethod + def from_dict(cls, dict_rep: Dict[str, Any]) -> "State": + """Returns a recursively reconstructed dictionary of the state.""" + return cls(**dict_rep) + class Finalizable(abc.ABC): """A mixin for classes that can "finalize" their attributes. diff --git a/kfac_jax/_src/utils/staging.py b/kfac_jax/_src/utils/staging.py index 7073987..32cc657 100644 --- a/kfac_jax/_src/utils/staging.py +++ b/kfac_jax/_src/utils/staging.py @@ -13,17 +13,16 @@ # limitations under the License. """K-FAC utilities for classes with staged methods.""" import functools -import operator from typing import Any, Callable, Optional, Sequence, Tuple, Union import jax -import jax.numpy as jnp from kfac_jax._src.utils import misc from kfac_jax._src.utils import parallel from kfac_jax._src.utils import types TArrayTree = types.TArrayTree +ArrayTree = types.ArrayTree class WithStagedMethods(misc.Finalizable): @@ -54,19 +53,12 @@ def __exit__(self, *_): def __init__( self, - multi_device: bool = False, - pmap_axis_name: Optional[str] = None, debug: bool = False, **parent_kwargs: Any, ): """Initializes the instance. Args: - multi_device: Whether any of decorated staged methods are to be run on a - single or multiple devices. If this is set to `True` than any call - would internally be delegated to `jax.pmap` and otherwise to `jax.jit`. - pmap_axis_name: The name of the pmap axis to use when running on - multiple devices. This is required if `multi_device=True`. debug: If this is set `True` than any call to a stage method would directly call the method and would not stage/compile it. **parent_kwargs: Any additional keyword arguments for the parent class. @@ -78,26 +70,9 @@ def __init__( parent_kwargs["excluded_attribute_names"] = ("_in_staging",) super().__init__(**parent_kwargs) - - if multi_device and not isinstance(pmap_axis_name, str): - raise ValueError("When `multi_device=True` you must pass in a string for " - "`pmap_axis_name`.") - - self._multi_device = multi_device - self._pmap_axis_name = pmap_axis_name self._debug = debug self._in_staging = False - @property - def multi_device(self) -> bool: - """Indicates whether staged method will be run across multiple devices.""" - return self._multi_device - - @property - def pmap_axis_name(self) -> Optional[str]: - """The name of the `jax.pmap` axis to use for staged methods.""" - return self._pmap_axis_name - @property def debug(self) -> bool: """Whether staged methods would be run in 'debug' mode.""" @@ -114,21 +89,18 @@ def staging_context(self) -> "StagingContext": def get_first(self, obj: TArrayTree) -> TArrayTree: """Indexes the `obj` PyTree leaves over leading axis if `multi_device`.""" - return parallel.get_first(obj) if self.multi_device else obj + return obj def copy_obj(self, obj: TArrayTree) -> TArrayTree: """Copies the object.""" - if self.multi_device: - return parallel.pmap_copy_obj(obj) - else: - return parallel.copy_obj(obj) + return parallel.copy_obj(obj) - def replicate(self, obj: TArrayTree) -> TArrayTree: - """Replicates the object to all local devices if `multi_device`.""" - if self.multi_device: - return parallel.replicate_all_local_devices(obj) - else: - return obj + # def replicate(self, obj: PyTree) -> PyTree: + # """Replicates the object to all local devices if `multi_device`.""" + # if self.multi_device: + # return parallel.replicate_all_local_devices(obj) + # else: + # return obj def staged( @@ -173,76 +145,24 @@ def try(self, x): else: donate_argnums: Tuple[int, ...] = tuple(donate_argnums) - bcast_argnums = static_argnums or () - # shift static_argnums by 1 and include instance (self) static_argnums = (0,) + tuple(i + 1 for i in (static_argnums or ())) # shift donate_argnums by 1 and include state donate_argnums = tuple(i + 1 for i in donate_argnums) - pmap_funcs = {} - jitted_func = jax.jit(method, - static_argnums=static_argnums, - donate_argnums=donate_argnums) + jitted_method = jax.jit( + method, donate_argnums=donate_argnums, static_argnums=static_argnums) @functools.wraps(method) - def decorated(instance: "WithStagedMethods", *args: Any) -> TArrayTree: + def decorated( + instance: "WithStagedMethods", + *args: Any, + ) -> ArrayTree: - if instance.in_staging: + if instance.in_staging or instance.debug: return method(instance, *args) with instance.staging_context(): - if instance.multi_device and instance.debug: - # In this case we want to call `method` once for each device index. - # Note that this might not always produce sensible behavior, and will - # depend on the details of the method and if it has side effects on the - # state of the class. - - outs = [] - non_bcast_args = [args[i] if i not in bcast_argnums else None - for i in range(len(args))] - - for i in range(jax.local_device_count()): - - non_bcast_args_i = jax.tree_util.tree_map( - operator.itemgetter(i), non_bcast_args) - - args_i = [ - non_bcast_args_i[j] if j not in bcast_argnums else args[j] - for j in range(len(args)) - ] - - outs.append(method(instance, *args_i)) - - outs = jax.tree_util.tree_map(lambda *args_: jnp.stack(args_), *outs) - - elif instance.debug: - outs = method(instance, *args) - - elif instance.multi_device: - - new_args = list(args) - - for i in range(len(args)): - if i + 1 not in static_argnums: - new_args[i] = parallel.check_and_fix_format_for_pmap(args[i]) - - func = pmap_funcs.get(instance.pmap_axis_name) - - if func is None: - func = jax.pmap( - method, - static_broadcasted_argnums=static_argnums, - donate_argnums=donate_argnums, - axis_name=instance.pmap_axis_name, - ) - pmap_funcs[instance.pmap_axis_name] = func - - outs = func(instance, *new_args) - - else: - outs = jitted_func(instance, *args) - - return outs + return jitted_method(instance, *args) return decorated diff --git a/kfac_jax/_src/utils/types.py b/kfac_jax/_src/utils/types.py index d81171e..beb605a 100644 --- a/kfac_jax/_src/utils/types.py +++ b/kfac_jax/_src/utils/types.py @@ -27,6 +27,7 @@ DType = jnp.dtype PyTree = Union[T, Sequence["PyTree[T]"], Mapping[str, "PyTree[T]"]] ArrayTree = PyTree[Array] +ShardingTree = PyTree[jax.sharding.NamedSharding] TArrayTree = TypeVar("TArrayTree", bound=ArrayTree) Params = TypeVar("Params", bound=ArrayTree) Batch = TypeVar("Batch", bound=ArrayTree) diff --git a/tests/test_estimator.py b/tests/test_estimator.py index 47280b1..8f50836 100644 --- a/tests/test_estimator.py +++ b/tests/test_estimator.py @@ -71,7 +71,6 @@ def compute_exact_approx_curvature( batch_size=batch_size, rng=rng, func_args=func_args, - pmap_axis_name="i", estimation_mode=f"{curvature_type}_exact", ) @@ -497,7 +496,6 @@ def test_eigenvalues( exact_powers=-1, approx_powers=None, eigenvalues=True, - pmap_axis_name=None, ) block_eigenvalues = estimator.block_eigenvalues(cached_state, True) @@ -591,15 +589,14 @@ def test_matmul( exact_powers=-1, approx_powers=None, eigenvalues=True, - pmap_axis_name=None, ) v = init_func(init_key2, data) - m_v = estimator.multiply(state, v, e, True, True, None) - m_inv_v = estimator.multiply_inverse(cached_state, v, e, True, True, None) + m_v = estimator.multiply(state, v, e, True, True) + m_inv_v = estimator.multiply_inverse(cached_state, v, e, True, True) # Check cached and non-cached are the same - m_inv_v2 = estimator.multiply_inverse(state, v, e, True, False, None) + m_inv_v2 = estimator.multiply_inverse(state, v, e, True, False) self.assertAllClose(m_inv_v, m_inv_v2, atol=1e-5, rtol=1e-4) block_vectors = estimator.params_vector_to_blocks_vectors(v)