diff --git a/axlearn/common/trainer_config_modifier.py b/axlearn/common/trainer_config_modifier.py
index d647e1a06..63d0e5c4b 100644
--- a/axlearn/common/trainer_config_modifier.py
+++ b/axlearn/common/trainer_config_modifier.py
@@ -2,12 +2,13 @@
 
 """Defines trainer config modifiers, which will be used in model definitions."""
 
-from typing import Dict, Sequence, Union
+from typing import Callable, Dict, NamedTuple, Sequence, Union
 
 from axlearn.common import config
 from axlearn.common.base_layer import RematSpec
 from axlearn.common.config import (
     REQUIRED,
+    ConfigBase,
     ConfigModifier,
     ConfigOr,
     Required,
@@ -17,7 +18,53 @@
 from axlearn.common.gradient_accumulation import with_minibatch_steps
 from axlearn.common.metrics import MetricAccumulator
 from axlearn.common.trainer import SpmdTrainer
-from axlearn.common.utils import HybridMeshShape, MeshShape
+from axlearn.common.utils import HybridMeshShape, MeshShape, PartitionSpec
+
+
+class _FoundModule(NamedTuple):
+    """Module found in recursive search of a matching module."""
+
+    # The module found.
+    module: ConfigModifier.Config
+    # The parent of the found module.
+    parent_module: ConfigModifier.Config
+    # Key of the found module in its parent.
+    key_in_parent: str
+
+
+def _find_target_module(module_name: str, cfg: SpmdTrainer.Config) -> _FoundModule:
+    """Recursively searches for the target module matching module_name in the provided cfg.
+
+    Args:
+        module_name: Name of the target module.
+        cfg: The trainer config to be searched for module_name.
+
+    Raises:
+        ValueError: The module_name is not found.
+
+    Returns:
+        A _FoundModule(module, parent_module, key_in_parent) where
+        module: The module found.
+        parent_module: The parent of the found module.
+        key_in_parent: Key of the found module in its parent.
+    """
+
+    # Here we assume x.y.z format.
+    # One example would be model.decoder.transformer.layer.
+    target_modules = module_name.split(".")
+    curr_module = cfg
+    key_in_parent = None
+    parent_module = None
+
+    for target_module in target_modules:
+        if not hasattr(curr_module, target_module):
+            raise ValueError(f"{target_module} is not found in {curr_module}.")
+        parent_module = curr_module
+        key_in_parent = target_module
+        curr_module = getattr(curr_module, target_module)
+    return _FoundModule(
+        module=curr_module, parent_module=parent_module, key_in_parent=key_in_parent
+    )
 
 
 class GradientAccumulationModifier(ConfigModifier):
@@ -100,18 +147,11 @@ def __call__(self, cfg: SpmdTrainer.Config) -> SpmdTrainer.Config:
         """
 
         for module_name, remat_spec in self._remat_policies.items():
-            # Here we assume x.y.z format.
-            # One example would be model.decoder.transformer.layer.
-            target_modules = module_name.split(".")
-            curr_module = cfg
-            for target_module in target_modules:
-                if not hasattr(curr_module, target_module):
-                    raise ValueError(f"{target_module} is not found in {curr_module}.")
-                curr_module = getattr(curr_module, target_module)
+            found_module = _find_target_module(module_name, cfg)
             # Here we assume all modules have a remat_spec attribute.
-            if not hasattr(curr_module, "remat_spec"):
-                raise ValueError(f"{curr_module} does not have remat_spec attribute")
-            curr_module.remat_spec = remat_spec
+            if not hasattr(found_module.module, "remat_spec"):
+                raise ValueError(f"{found_module.module} does not have remat_spec attribute")
+            found_module.module.remat_spec = remat_spec
         return cfg
 
 
@@ -146,6 +186,113 @@
         return cfg
 
 
+class ModelConfigModifier(ConfigModifier):
+    """Update the model config for the trainer config."""
+
+    @config_class
+    class Config(ConfigModifier.Config):
+        """Configure ModelConfigModifier.
+
+        Attributes:
+            model_cfg_modifications: A mapping from module path
+                (e.g. `model.decoder.transformer.layer`) to a Config.
+        """
+
+        model_cfg_modifications: Dict[str, ConfigBase] = {}
+
+    def __init__(self, cfg: Config):
+        super().__init__(cfg)
+        self._model_cfg_modifications = self.config.model_cfg_modifications
+
+    def _merge_configs(self, target_cfg: ConfigBase, found_module: _FoundModule) -> ConfigBase:
+        """Merge configurations from the config being replaced on a best-effort basis.
+
+        Args:
+            target_cfg: Configuration that will replace found_module.
+            found_module: Existing configuration whose class will be replaced
+                but whose configuration will be merged with target_cfg.
+
+        Returns:
+            The modified config.
+        """
+        for key in target_cfg.keys():
+            if key == "klass":
+                continue
+            elif hasattr(found_module.module, key) and hasattr(target_cfg, key):
+                setattr(target_cfg, key, getattr(found_module.module, key))
+        return target_cfg
+
+    def __call__(self, cfg: SpmdTrainer.Config) -> SpmdTrainer.Config:
+        """Overwrite the model config of the specified modules.
+
+        Args:
+            cfg: The trainer config to be modified.
+
+        Raises:
+            ValueError: The target module is not found.
+
+        Returns:
+            The modified trainer config.
+        """
+
+        # Iterate over modules in the mapping, sorted by module path length.
+        # This ensures a parent is modified before its children, so no modification is missed.
+        for module_name, model_cfg in sorted(
+            self._model_cfg_modifications.items(), key=lambda item: len(item[0])
+        ):
+            found_module = _find_target_module(module_name, cfg)
+
+            model_cfg = self._merge_configs(model_cfg, found_module)
+            # Replace in the parent config.
+            setattr(found_module.parent_module, found_module.key_in_parent, model_cfg)
+        return cfg
+
+
+class PartitionSpecModifier(ConfigModifier):
+    """Update the partition spec attributes for the specified modules."""
+
+    @config_class
+    class Config(ConfigModifier.Config):
+        """Configure PartitionSpecModifier.
+
+        Attributes:
+            partition_specs: A nested mapping from module path
+                (e.g. `model.decoder.transformer.layer`) to another
+                mapping of module attribute to PartitionSpec.
+        """
+
+        partition_specs: Required[Dict[str, Dict[str, PartitionSpec]]] = REQUIRED
+
+    def __init__(self, cfg: Config):
+        super().__init__(cfg)
+        self._attribute_dicts = self.config.partition_specs
+
+    def __call__(self, cfg: SpmdTrainer.Config) -> SpmdTrainer.Config:
+        """Update the partition_spec attributes for the specified modules.
+
+        Args:
+            cfg: The trainer config to be modified.
+
+        Raises:
+            ValueError: The target module is not found.
+            ValueError: The partition_spec attribute is not found.
+
+        Returns:
+            The modified trainer config.
+        """
+        for module_name, partition_spec_dict in self._attribute_dicts.items():
+            found_module = _find_target_module(module_name, cfg)
+            for partition_spec_name, partition_spec in partition_spec_dict.items():
+                if not hasattr(found_module.module, partition_spec_name):
+                    raise ValueError(
+                        f"{found_module.module} does not have {partition_spec_name} attribute"
+                    )
+                setattr(found_module.module, partition_spec_name, partition_spec)
+
+        return cfg
+
+
 class ChainConfigModifier(ConfigModifier):
     """Chain multiple config modifiers together."""
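The two new modifiers compose with the existing ChainConfigModifier. Below is a minimal usage sketch (illustrative only, not part of the patch); the module paths "model.decoder.transformer" and "model.linear" are borrowed from the tests that follow, and `trainer_cfg` stands in for any SpmdTrainer.Config:

    from axlearn.common.attention import RepeatedTransformerLayer
    from axlearn.common.trainer_config_modifier import (
        ChainConfigModifier,
        MeshShapeModifier,
        ModelConfigModifier,
        PartitionSpecModifier,
    )

    chain = (
        ChainConfigModifier.default_config()
        .set(
            config_modifiers=[
                # Resize the mesh first, then swap the stacked transformer for a
                # repeated one, then re-shard the linear layer's parameters.
                MeshShapeModifier.default_config().set(mesh_shape=(1, 1, 1, -1, 1, 4)),
                ModelConfigModifier.default_config().set(
                    model_cfg_modifications={
                        "model.decoder.transformer": RepeatedTransformerLayer.default_config(),
                    }
                ),
                PartitionSpecModifier.default_config().set(
                    partition_specs={
                        "model.linear": {
                            "param_partition_spec": ("model", ("expert", "fsdp", "seq")),
                        },
                    },
                ),
            ]
        )
        .instantiate()
    )
    trainer_cfg = chain(trainer_cfg)

Because ModelConfigModifier merges attributes from the config it replaces on a best-effort basis, fields shared by both configs (e.g. `layer`) carry over; only the class and any non-overlapping defaults change.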
+ """ + for module_name, partition_spec_dict in self._attribute_dicts.items(): + found_module = _find_target_module(module_name, cfg) + for partition_spec_name, partition_spec in partition_spec_dict.items(): + if not hasattr(found_module.module, partition_spec_name): + raise ValueError( + f"{found_module.module} does not have {partition_spec_name} attribute" + ) + setattr(found_module.module, partition_spec_name, partition_spec) + + return cfg + + class ChainConfigModifier(ConfigModifier): """Chain multiple config modifiers together.""" diff --git a/axlearn/common/trainer_config_modifier_test.py b/axlearn/common/trainer_config_modifier_test.py index ccfe00823..28c88c052 100644 --- a/axlearn/common/trainer_config_modifier_test.py +++ b/axlearn/common/trainer_config_modifier_test.py @@ -5,13 +5,16 @@ import jax from absl.testing import absltest -from axlearn.common import test_utils +from axlearn.common import causal_lm, test_utils +from axlearn.common.attention import RepeatedTransformerLayer, StackedTransformerLayer from axlearn.common.base_layer import RematSpec from axlearn.common.trainer import SpmdTrainer from axlearn.common.trainer_config_modifier import ( ChainConfigModifier, GradientAccumulationModifier, MeshShapeModifier, + ModelConfigModifier, + PartitionSpecModifier, RematSpecModifier, ) from axlearn.common.trainer_test import DummyModel @@ -65,6 +68,88 @@ def test_remat_policy_override(self): _ = cfg_modifier(cfg) +class ModelConfigModifierTest(test_utils.TestCase): + def test_model_config_override(self): + cfg = SpmdTrainer.default_config().set(model=causal_lm.Model.default_config()) + self.assertTrue( + str(cfg.model.decoder.transformer) == str(StackedTransformerLayer.default_config()) + ) + + cfg_modifier = ( + ModelConfigModifier.default_config() + .set( + model_cfg_modifications={ + "model.decoder.transformer": RepeatedTransformerLayer.default_config(), + } + ) + .instantiate() + ) + + cfg = cfg_modifier(cfg) + # The default StackedTransformerLayer should have changed to RepeatedTransformerLayer + self.assertTrue( + str(cfg.model.decoder.transformer) == str(RepeatedTransformerLayer.default_config()) + ) + cfg_modifier = ( + ModelConfigModifier.default_config() + .set( + model_cfg_modifications={ + "model.decoder.unknown": RepeatedTransformerLayer.default_config(), + } + ) + .instantiate() + ) + # Ensure that the exception is working. + with self.assertRaisesRegex(ValueError, "unknown is not found in.*"): + _ = cfg_modifier(cfg) + + +class PartitionSpecModifierTest(test_utils.TestCase): + def test_partition_spec_override(self): + cfg = SpmdTrainer.default_config().set(model=DummyModel.default_config()) + cfg_modifier = ( + PartitionSpecModifier.default_config() + .set( + partition_specs={ + "model.linear": {"param_partition_spec": ("model", ("expert", "fsdp", "seq"))}, + }, + ) + .instantiate() + ) + cfg = cfg_modifier(cfg) + self.assertTrue( + str(cfg.model.linear.param_partition_spec), """("model", ("expert", "fsdp", "seq")""" + ) + cfg_modifier = ( + PartitionSpecModifier.default_config() + .set( + partition_specs={ + "model.linear": {"param_partition_spec": ("model", ("expert", "fsdp", "seq"))}, + "model.unknown": {"param_partition_spec": ("model", ("expert", "fsdp", "seq"))}, + }, + ) + .instantiate() + ) + # Ensure that the exception is working. 
diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-flash-single-host.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-flash-single-host.txt
index 41971d290..063a37fd3 100644
--- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-flash-single-host.txt
+++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-flash-single-host.txt
@@ -114,6 +114,65 @@ mesh_axis_names[2]: 'expert'
 mesh_axis_names[3]: 'fsdp'
 mesh_axis_names[4]: 'seq'
 mesh_axis_names[5]: 'model'
+mesh_rules[0][0]: 'neuron-(trn2|trn2n).48xlarge-64'
+mesh_rules[0][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier'
+mesh_rules[0][1].config_modifiers[0].mesh_shape[0]: 1
+mesh_rules[0][1].config_modifiers[0].mesh_shape[1]: 1
+mesh_rules[0][1].config_modifiers[0].mesh_shape[2]: 1
+mesh_rules[0][1].config_modifiers[0].mesh_shape[3]: -1
+mesh_rules[0][1].config_modifiers[0].mesh_shape[4]: 1
+mesh_rules[0][1].config_modifiers[0].mesh_shape[5]: 4
+mesh_rules[0][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier'
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].klass: 'axlearn.common.attention.StackedTransformerLayer'
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.activation: 'nn.relu'
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout'
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer'
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.bias: True
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear'
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.param_partition_spec[0]: None
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.param_partition_spec[1]: 'model'
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.bias: True
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear'
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.param_partition_spec[0]: 'model'
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.param_partition_spec[1]: None +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.norm.eps: 1e-08 +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.residual_weight: 1.0 +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.stochastic_depth.mode: 'row' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.structure: 'prenorm' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.klass: 'axlearn.common.attention.TransformerLayer' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.QKVLinear' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.bias: True +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.param_partition_spec[0]: None +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.param_partition_spec[1]: 'model' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.param_partition_spec[2]: None +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.klass: 'axlearn.common.attention.MultiheadAttention' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.bias: True +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.param_partition_spec[0]: None +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model' 
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.param_partition_spec[2]: None +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.norm.eps: 1e-08 +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.stochastic_depth.mode: 'row' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.structure: 'prenorm' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].klass: 'axlearn.common.attention.GroupedQKVLinear' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.bias: True +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.param_partition_spec[0]: None +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.param_partition_spec[1]: 'model' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.param_partition_spec[2]: None +mesh_rules[0][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' mesh_shape[0]: 1 mesh_shape[1]: -1 mesh_shape[2]: 1 diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-flash.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-flash.txt index a27337377..894ef8e0c 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-flash.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-flash.txt @@ -114,6 +114,65 @@ mesh_axis_names[2]: 'expert' mesh_axis_names[3]: 'fsdp' mesh_axis_names[4]: 'seq' mesh_axis_names[5]: 'model' +mesh_rules[0][0]: 'neuron-(trn2|trn2n).48xlarge-64' +mesh_rules[0][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier' 
+mesh_rules[0][1].config_modifiers[0].mesh_shape[0]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[1]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[2]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[3]: -1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[4]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[5]: 4 +mesh_rules[0][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].klass: 'axlearn.common.attention.StackedTransformerLayer' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.activation: 'nn.relu' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.bias: True +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.param_partition_spec[0]: None +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.param_partition_spec[1]: 'model' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.bias: True +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.param_partition_spec[0]: 'model' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.param_partition_spec[1]: None +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.norm.eps: 1e-08 +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.residual_weight: 1.0 +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.stochastic_depth.mode: 'row' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.structure: 'prenorm' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.klass: 'axlearn.common.attention.TransformerLayer' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout' 
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.QKVLinear' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.bias: True +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.param_partition_spec[0]: None +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.param_partition_spec[1]: 'model' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.param_partition_spec[2]: None +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.klass: 'axlearn.common.attention.MultiheadAttention' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.bias: True +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.param_partition_spec[0]: None +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.param_partition_spec[2]: None +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.norm.eps: 1e-08 +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.stochastic_depth.mode: 'row' 
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.structure: 'prenorm' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].klass: 'axlearn.common.attention.GroupedQKVLinear' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.bias: True +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.param_partition_spec[0]: None +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.param_partition_spec[1]: 'model' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.param_partition_spec[2]: None +mesh_rules[0][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' mesh_shape[0]: 1 mesh_shape[1]: -1 mesh_shape[2]: 1 diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-single-host.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-single-host.txt index 5cc38c163..0aec515fb 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-single-host.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-single-host.txt @@ -114,6 +114,65 @@ mesh_axis_names[2]: 'expert' mesh_axis_names[3]: 'fsdp' mesh_axis_names[4]: 'seq' mesh_axis_names[5]: 'model' +mesh_rules[0][0]: 'neuron-(trn2|trn2n).48xlarge-64' +mesh_rules[0][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier' +mesh_rules[0][1].config_modifiers[0].mesh_shape[0]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[1]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[2]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[3]: -1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[4]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[5]: 4 +mesh_rules[0][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].klass: 'axlearn.common.attention.StackedTransformerLayer' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.activation: 'nn.relu' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.bias: True +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear' 
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.param_partition_spec[0]: None +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.param_partition_spec[1]: 'model' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.bias: True +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.param_partition_spec[0]: 'model' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.param_partition_spec[1]: None +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.norm.eps: 1e-08 +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.residual_weight: 1.0 +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.stochastic_depth.mode: 'row' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.structure: 'prenorm' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.klass: 'axlearn.common.attention.TransformerLayer' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.QKVLinear' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.bias: True +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.param_partition_spec[0]: None +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.param_partition_spec[1]: 'model' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.param_partition_spec[2]: None +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.klass: 
'axlearn.common.attention.MultiheadAttention' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.bias: True +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.param_partition_spec[0]: None +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.param_partition_spec[2]: None +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.norm.eps: 1e-08 +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.stochastic_depth.mode: 'row' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.structure: 'prenorm' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].klass: 'axlearn.common.attention.GroupedQKVLinear' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.bias: True +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.param_partition_spec[0]: None +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.param_partition_spec[1]: 'model' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.param_partition_spec[2]: None +mesh_rules[0][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' mesh_shape[0]: 1 mesh_shape[1]: -1 mesh_shape[2]: 1 diff 
--git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3.txt index 86c13eb79..d8594cb2f 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3.txt @@ -114,6 +114,65 @@ mesh_axis_names[2]: 'expert' mesh_axis_names[3]: 'fsdp' mesh_axis_names[4]: 'seq' mesh_axis_names[5]: 'model' +mesh_rules[0][0]: 'neuron-(trn2|trn2n).48xlarge-64' +mesh_rules[0][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier' +mesh_rules[0][1].config_modifiers[0].mesh_shape[0]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[1]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[2]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[3]: -1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[4]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[5]: 4 +mesh_rules[0][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].klass: 'axlearn.common.attention.StackedTransformerLayer' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.activation: 'nn.relu' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.bias: True +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.param_partition_spec[0]: None +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.param_partition_spec[1]: 'model' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.bias: True +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.param_partition_spec[0]: 'model' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.param_partition_spec[1]: None +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.norm.eps: 1e-08 +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.residual_weight: 1.0 
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.stochastic_depth.mode: 'row' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.structure: 'prenorm' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.klass: 'axlearn.common.attention.TransformerLayer' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.QKVLinear' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.bias: True +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.param_partition_spec[0]: None +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.param_partition_spec[1]: 'model' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.param_partition_spec[2]: None +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.klass: 'axlearn.common.attention.MultiheadAttention' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.bias: True +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.param_partition_spec[0]: None +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.param_partition_spec[2]: None +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer' 
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.norm.eps: 1e-08 +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.stochastic_depth.mode: 'row' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.structure: 'prenorm' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].klass: 'axlearn.common.attention.GroupedQKVLinear' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.bias: True +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.param_partition_spec[0]: None +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.param_partition_spec[1]: 'model' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.param_partition_spec[2]: None +mesh_rules[0][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' mesh_shape[0]: 1 mesh_shape[1]: -1 mesh_shape[2]: 1 diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-flash-single-host.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-flash-single-host.txt index 32be1295c..3ea9305a5 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-flash-single-host.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-flash-single-host.txt @@ -114,6 +114,65 @@ mesh_axis_names[2]: 'expert' mesh_axis_names[3]: 'fsdp' mesh_axis_names[4]: 'seq' mesh_axis_names[5]: 'model' +mesh_rules[0][0]: 'neuron-(trn2|trn2n).48xlarge-64' +mesh_rules[0][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier' +mesh_rules[0][1].config_modifiers[0].mesh_shape[0]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[1]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[2]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[3]: -1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[4]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[5]: 4 +mesh_rules[0][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].klass: 'axlearn.common.attention.StackedTransformerLayer' 
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.activation: 'nn.relu' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.bias: True +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.param_partition_spec[0]: None +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.param_partition_spec[1]: 'model' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.bias: True +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.param_partition_spec[0]: 'model' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.param_partition_spec[1]: None +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.norm.eps: 1e-08 +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.residual_weight: 1.0 +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.stochastic_depth.mode: 'row' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.structure: 'prenorm' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.klass: 'axlearn.common.attention.TransformerLayer' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.QKVLinear' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.bias: True +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' 
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.param_partition_spec[0]: None +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.param_partition_spec[1]: 'model' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.param_partition_spec[2]: None +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.klass: 'axlearn.common.attention.MultiheadAttention' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.bias: True +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.param_partition_spec[0]: None +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.param_partition_spec[2]: None +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.norm.eps: 1e-08 +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.stochastic_depth.mode: 'row' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.structure: 'prenorm' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].klass: 'axlearn.common.attention.GroupedQKVLinear' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.bias: True 
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.param_partition_spec[0]: None +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.param_partition_spec[1]: 'model' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.param_partition_spec[2]: None +mesh_rules[0][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' mesh_shape[0]: 1 mesh_shape[1]: -1 mesh_shape[2]: 1 diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-flash.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-flash.txt index 3de7d2b95..0be311b31 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-flash.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-flash.txt @@ -114,6 +114,65 @@ mesh_axis_names[2]: 'expert' mesh_axis_names[3]: 'fsdp' mesh_axis_names[4]: 'seq' mesh_axis_names[5]: 'model' +mesh_rules[0][0]: 'neuron-(trn2|trn2n).48xlarge-64' +mesh_rules[0][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier' +mesh_rules[0][1].config_modifiers[0].mesh_shape[0]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[1]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[2]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[3]: -1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[4]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[5]: 4 +mesh_rules[0][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].klass: 'axlearn.common.attention.StackedTransformerLayer' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.activation: 'nn.relu' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.bias: True +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.param_partition_spec[0]: None +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.param_partition_spec[1]: 'model' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.bias: True +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear' 
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.param_partition_spec[0]: 'model' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.param_partition_spec[1]: None +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.norm.eps: 1e-08 +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.residual_weight: 1.0 +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.stochastic_depth.mode: 'row' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.structure: 'prenorm' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.klass: 'axlearn.common.attention.TransformerLayer' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.QKVLinear' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.bias: True +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.param_partition_spec[0]: None +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.param_partition_spec[1]: 'model' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.param_partition_spec[2]: None +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.klass: 'axlearn.common.attention.MultiheadAttention' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.bias: True +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.param_partition_spec[0]: None 
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model'
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.param_partition_spec[2]: None
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery'
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout'
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer'
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.norm.eps: 1e-08
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.norm.forward_dtype: 'jax.numpy.float32'
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.norm.klass: 'axlearn.common.layers.LayerNorm'
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth'
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.stochastic_depth.mode: 'row'
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.structure: 'prenorm'
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].klass: 'axlearn.common.attention.GroupedQKVLinear'
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.bias: True
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.klass: 'axlearn.common.attention.MultiheadInputLinear'
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.param_partition_spec[0]: None
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.param_partition_spec[1]: 'model'
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.param_partition_spec[2]: None
+mesh_rules[0][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier'
 mesh_shape[0]: 1
 mesh_shape[1]: -1
 mesh_shape[2]: 1
diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-single-host.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-single-host.txt
index 7cc3b4afc..1373c599c 100644
--- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-single-host.txt
+++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-single-host.txt
@@ -114,6 +114,65 @@ mesh_axis_names[2]: 'expert'
 mesh_axis_names[3]: 'fsdp'
 mesh_axis_names[4]: 'seq'
mesh_axis_names[5]: 'model' +mesh_rules[0][0]: 'neuron-(trn2|trn2n).48xlarge-64' +mesh_rules[0][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier' +mesh_rules[0][1].config_modifiers[0].mesh_shape[0]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[1]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[2]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[3]: -1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[4]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[5]: 4 +mesh_rules[0][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].klass: 'axlearn.common.attention.StackedTransformerLayer' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.activation: 'nn.relu' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.bias: True +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.param_partition_spec[0]: None +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.param_partition_spec[1]: 'model' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.bias: True +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.param_partition_spec[0]: 'model' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.param_partition_spec[1]: None +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.norm.eps: 1e-08 +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.residual_weight: 1.0 +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.stochastic_depth.mode: 'row' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.structure: 'prenorm' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.klass: 'axlearn.common.attention.TransformerLayer' 
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.QKVLinear' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.bias: True +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.param_partition_spec[0]: None +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.param_partition_spec[1]: 'model' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.param_partition_spec[2]: None +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.klass: 'axlearn.common.attention.MultiheadAttention' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.bias: True +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.param_partition_spec[0]: None +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.param_partition_spec[2]: None +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.norm.eps: 1e-08 +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' 
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.stochastic_depth.mode: 'row'
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.structure: 'prenorm'
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].klass: 'axlearn.common.attention.GroupedQKVLinear'
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.bias: True
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.klass: 'axlearn.common.attention.MultiheadInputLinear'
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.param_partition_spec[0]: None
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.param_partition_spec[1]: 'model'
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.param_partition_spec[2]: None
+mesh_rules[0][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier'
 mesh_shape[0]: 1
 mesh_shape[1]: -1
 mesh_shape[2]: 1
diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3.txt
index 612565b6f..757056860 100644
--- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3.txt
+++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3.txt
@@ -114,6 +114,65 @@ mesh_axis_names[2]: 'expert'
 mesh_axis_names[3]: 'fsdp'
 mesh_axis_names[4]: 'seq'
 mesh_axis_names[5]: 'model'
+mesh_rules[0][0]: 'neuron-(trn2|trn2n).48xlarge-64'
+mesh_rules[0][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier'
+mesh_rules[0][1].config_modifiers[0].mesh_shape[0]: 1
+mesh_rules[0][1].config_modifiers[0].mesh_shape[1]: 1
+mesh_rules[0][1].config_modifiers[0].mesh_shape[2]: 1
+mesh_rules[0][1].config_modifiers[0].mesh_shape[3]: -1
+mesh_rules[0][1].config_modifiers[0].mesh_shape[4]: 1
+mesh_rules[0][1].config_modifiers[0].mesh_shape[5]: 4
+mesh_rules[0][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier'
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].klass: 'axlearn.common.attention.StackedTransformerLayer'
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.activation: 'nn.relu'
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout'
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer'
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.bias: True
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.klass: 
'axlearn.common.layers.Linear' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.param_partition_spec[0]: None +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.param_partition_spec[1]: 'model' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.bias: True +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.param_partition_spec[0]: 'model' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.param_partition_spec[1]: None +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.norm.eps: 1e-08 +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.residual_weight: 1.0 +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.stochastic_depth.mode: 'row' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.structure: 'prenorm' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.klass: 'axlearn.common.attention.TransformerLayer' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.QKVLinear' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.bias: True +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.param_partition_spec[0]: None +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.param_partition_spec[1]: 'model' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.param_partition_spec[2]: None +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey' 
+mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.klass: 'axlearn.common.attention.MultiheadAttention' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.bias: True +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.param_partition_spec[0]: None +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.param_partition_spec[2]: None +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.norm.eps: 1e-08 +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.stochastic_depth.mode: 'row' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.structure: 'prenorm' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].klass: 'axlearn.common.attention.GroupedQKVLinear' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.bias: True +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.param_partition_spec[0]: None +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.param_partition_spec[1]: 'model' +mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.param_partition_spec[2]: None 
+mesh_rules[0][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier'
 mesh_shape[0]: 1
 mesh_shape[1]: -1
 mesh_shape[2]: 1
diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v1-flash.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v1-flash.txt
index 3052dab86..0c67d7a99 100644
--- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v1-flash.txt
+++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v1-flash.txt
@@ -185,6 +185,83 @@ mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.l
 mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_saved: '.*[kqv]_proj|.*linear1_[01]'
 mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: None
 mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: None
+mesh_rules[4][1].config_modifiers[2].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier'
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].klass: 'axlearn.common.attention.StackedTransformerLayer'
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.activation: 'nn.relu'
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout'
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer'
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.bias: True
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear'
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.param_partition_spec[0]: None
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.param_partition_spec[1]: 'model'
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.bias: True
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear'
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.param_partition_spec[0]: 'model'
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.param_partition_spec[1]: None
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.norm.eps: 1e-08
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.norm.forward_dtype: 'jax.numpy.float32'
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.norm.klass: 'axlearn.common.layers.LayerNorm'
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.residual_weight: 1.0
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.stochastic_depth.mode: 'row' +mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.structure: 'prenorm' +mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.klass: 'axlearn.common.attention.TransformerLayer' +mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.QKVLinear' +mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.bias: True +mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.param_partition_spec[0]: None +mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.param_partition_spec[1]: 'model' +mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.param_partition_spec[2]: None +mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey' +mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.klass: 'axlearn.common.attention.MultiheadAttention' +mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.bias: True +mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear' +mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.param_partition_spec[0]: None +mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model' +mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.param_partition_spec[2]: None +mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery' +mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer' 
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.norm.eps: 1e-08 +mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.stochastic_depth.mode: 'row' +mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.structure: 'prenorm' +mesh_rules[4][1].config_modifiers[3].klass: 'axlearn.common.trainer_config_modifier.PartitionSpecModifier' +mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][0]: 'model' +mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][0]: 'expert' +mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][1]: 'fsdp' +mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][2]: 'seq' +mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['input_partition_spec'][0]: 'fsdp' +mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['input_partition_spec'][1]: None +mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['output_partition_spec'][0]: 'fsdp' +mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['output_partition_spec'][1]: 'model' +mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['embedding_partition_spec'][0]: 'model' +mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['embedding_partition_spec'][1]: 'fsdp' +mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][0]: 'model' +mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][0]: 'expert' +mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][1]: 'fsdp' +mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][2]: 'seq' +mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][1]: 'model' +mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][2]: None +mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][1]: None +mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][2]: None 
+mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][0]: 'fsdp'
+mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][1]: 'model'
+mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][2]: None
+mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][0]: 'fsdp'
+mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][1]: None
+mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][2]: None
+mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][0]: 'fsdp'
+mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][1]: 'model'
+mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][2]: None
+mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][0]: 'fsdp'
+mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][1]: None
+mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][2]: None
 mesh_rules[4][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier'
 mesh_shape[0]: 1
 mesh_shape[1]: 1
diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v1.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v1.txt
index d57c35aea..db4233286 100644
--- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v1.txt
+++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v1.txt
@@ -185,6 +185,83 @@ mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.l
 mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_saved: '.*[kqv]_proj|.*linear1_[01]'
 mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: None
 mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: None
+mesh_rules[4][1].config_modifiers[2].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier'
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].klass: 'axlearn.common.attention.StackedTransformerLayer'
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.activation: 'nn.relu'
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout'
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer'
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.bias: True
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.klass: 
'axlearn.common.layers.Linear' +mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.param_partition_spec[0]: None +mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.param_partition_spec[1]: 'model' +mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.bias: True +mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear' +mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.param_partition_spec[0]: 'model' +mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.param_partition_spec[1]: None +mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.norm.eps: 1e-08 +mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.residual_weight: 1.0 +mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.stochastic_depth.mode: 'row' +mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.structure: 'prenorm' +mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.klass: 'axlearn.common.attention.TransformerLayer' +mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.QKVLinear' +mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.bias: True +mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.param_partition_spec[0]: None +mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.param_partition_spec[1]: 'model' +mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.param_partition_spec[2]: None +mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey' 
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.klass: 'axlearn.common.attention.MultiheadAttention' +mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.bias: True +mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear' +mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.param_partition_spec[0]: None +mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model' +mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.param_partition_spec[2]: None +mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery' +mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer' +mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.norm.eps: 1e-08 +mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.stochastic_depth.mode: 'row' +mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.structure: 'prenorm' +mesh_rules[4][1].config_modifiers[3].klass: 'axlearn.common.trainer_config_modifier.PartitionSpecModifier' +mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][0]: 'model' +mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][0]: 'expert' +mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][1]: 'fsdp' +mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][2]: 'seq' +mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['input_partition_spec'][0]: 'fsdp' +mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['input_partition_spec'][1]: None +mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['output_partition_spec'][0]: 'fsdp' +mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['output_partition_spec'][1]: 'model' 
+mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['embedding_partition_spec'][0]: 'model'
+mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['embedding_partition_spec'][1]: 'fsdp'
+mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][0]: 'model'
+mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][0]: 'expert'
+mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][1]: 'fsdp'
+mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][2]: 'seq'
+mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][0]: 'fsdp'
+mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][1]: 'model'
+mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][2]: None
+mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][0]: 'fsdp'
+mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][1]: None
+mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][2]: None
+mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][0]: 'fsdp'
+mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][1]: 'model'
+mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][2]: None
+mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][0]: 'fsdp'
+mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][1]: None
+mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][2]: None
+mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][0]: 'fsdp'
+mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][1]: 'model'
+mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][2]: None
+mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][0]: 'fsdp'
+mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][1]: None
+mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][2]: None
 mesh_rules[4][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier'
 mesh_shape[0]: 1
 mesh_shape[1]: 1
diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v2-flash.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v2-flash.txt
index 605d5f326..0a90eb22e 100644
--- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v2-flash.txt
+++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v2-flash.txt
@@ -185,6 +185,89 @@ mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.l
 mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_saved: '.*[kqv]_proj|.*linear1_[01]'
 mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: None
 mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: None
+mesh_rules[4][1].config_modifiers[2].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier'
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].klass: 'axlearn.common.attention.StackedTransformerLayer'
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.activation: 'nn.relu'
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout'
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer'
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.bias: True
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear'
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.param_partition_spec[0]: None
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.param_partition_spec[1]: 'model'
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.bias: True
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear'
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.param_partition_spec[0]: 'model'
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.param_partition_spec[1]: None
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.norm.eps: 1e-08
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.norm.forward_dtype: 'jax.numpy.float32'
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.norm.klass: 'axlearn.common.layers.LayerNorm'
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.residual_weight: 1.0
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth'
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.stochastic_depth.mode: 'row'
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.structure: 'prenorm'
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.klass: 'axlearn.common.attention.TransformerLayer'
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout'
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.QKVLinear'
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.bias: True
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear'
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.param_partition_spec[0]: None
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.param_partition_spec[1]: 'model'
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.param_partition_spec[2]: None
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey'
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.klass: 'axlearn.common.attention.MultiheadAttention'
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.bias: True
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear'
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.param_partition_spec[0]: None
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model'
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.param_partition_spec[2]: None
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery'
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout'
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer'
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.norm.eps: 1e-08
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.norm.forward_dtype: 'jax.numpy.float32'
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.norm.klass: 'axlearn.common.layers.LayerNorm'
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth'
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.stochastic_depth.mode: 'row'
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.structure: 'prenorm'
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].klass: 'axlearn.common.attention.GroupedQKVLinear'
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.bias: True
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.klass: 'axlearn.common.attention.MultiheadInputLinear'
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.param_partition_spec[0]: None
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.param_partition_spec[1]: 'model'
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.param_partition_spec[2]: None
+mesh_rules[4][1].config_modifiers[3].klass: 'axlearn.common.trainer_config_modifier.PartitionSpecModifier'
+mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][0]: 'model'
+mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][0]: 'expert'
+mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][1]: 'fsdp'
+mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][2]: 'seq'
+mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['input_partition_spec'][0]: 'fsdp'
+mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['input_partition_spec'][1]: None
+mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['output_partition_spec'][0]: 'fsdp'
+mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['output_partition_spec'][1]: 'model'
+mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['embedding_partition_spec'][0]: 'model'
+mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['embedding_partition_spec'][1]: 'fsdp'
+mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][0]: 'model'
+mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][0]: 'expert'
+mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][1]: 'fsdp'
+mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][2]: 'seq'
+mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][0]: 'fsdp'
+mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][1]: 'model'
+mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][2]: None
+mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][0]: 'fsdp'
+mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][1]: None
+mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][2]: None
+mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][0]: 'fsdp'
+mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][1]: 'model'
+mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][2]: None
+mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][0]: 'fsdp'
+mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][1]: None
+mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][2]: None
+mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][0]: 'fsdp'
+mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][1]: 'model'
+mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][2]: None
+mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][0]: 'fsdp'
+mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][1]: None
+mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][2]: None
 mesh_rules[4][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier'
 mesh_shape[0]: 1
 mesh_shape[1]: 1
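For reference, the PartitionSpecModifier entries dumped above collapse to a small nested mapping of module path to attribute name to PartitionSpec. The following is a minimal sketch of the equivalent construction, assuming the usual axlearn `default_config().set(...)` idiom; it reconstructs only two of the dumped module paths, and the remaining entries follow the same shape.

# Sketch only, not part of this diff. Nested tuples such as
# ('expert', 'fsdp', 'seq') mirror param_partition_spec[1][0..2] in the dump.
from axlearn.common.trainer_config_modifier import PartitionSpecModifier
from axlearn.common.utils import PartitionSpec

partition_spec_modifier = PartitionSpecModifier.default_config().set(
    partition_specs={
        "model.decoder.emb.token_emb": {
            "param_partition_spec": PartitionSpec("model", ("expert", "fsdp", "seq")),
            "input_partition_spec": PartitionSpec("fsdp", None),
            "output_partition_spec": PartitionSpec("fsdp", "model"),
            "embedding_partition_spec": PartitionSpec("model", "fsdp"),
        },
        "model.decoder.output_norm": {
            "input_partition_spec": PartitionSpec("fsdp", "model", None),
            "output_partition_spec": PartitionSpec("fsdp", None, None),
        },
    },
)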
diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v2.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v2.txt
index ccf575c40..082ef5efc 100644
--- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v2.txt
+++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v2.txt
@@ -185,6 +185,89 @@ mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.l
 mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_saved: '.*[kqv]_proj|.*linear1_[01]'
 mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: None
 mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: None
+mesh_rules[4][1].config_modifiers[2].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier'
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].klass: 'axlearn.common.attention.StackedTransformerLayer'
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.activation: 'nn.relu'
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout'
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer'
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.bias: True
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear'
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.param_partition_spec[0]: None
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.param_partition_spec[1]: 'model'
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.bias: True
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear'
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.param_partition_spec[0]: 'model'
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.param_partition_spec[1]: None
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.norm.eps: 1e-08
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.norm.forward_dtype: 'jax.numpy.float32'
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.norm.klass: 'axlearn.common.layers.LayerNorm'
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.residual_weight: 1.0
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth'
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.stochastic_depth.mode: 'row'
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.structure: 'prenorm'
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.klass: 'axlearn.common.attention.TransformerLayer'
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout'
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.QKVLinear'
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.bias: True
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear'
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.param_partition_spec[0]: None
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.param_partition_spec[1]: 'model'
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.param_partition_spec[2]: None
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey'
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.klass: 'axlearn.common.attention.MultiheadAttention'
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.bias: True
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear'
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.param_partition_spec[0]: None
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model'
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.param_partition_spec[2]: None
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery'
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout'
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer'
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.norm.eps: 1e-08
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.norm.forward_dtype: 'jax.numpy.float32'
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.norm.klass: 'axlearn.common.layers.LayerNorm'
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth'
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.stochastic_depth.mode: 'row'
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.structure: 'prenorm'
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].klass: 'axlearn.common.attention.GroupedQKVLinear'
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.bias: True
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.klass: 'axlearn.common.attention.MultiheadInputLinear'
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.param_partition_spec[0]: None
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.param_partition_spec[1]: 'model'
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.param_partition_spec[2]: None
+mesh_rules[4][1].config_modifiers[3].klass: 'axlearn.common.trainer_config_modifier.PartitionSpecModifier'
+mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][0]: 'model'
+mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][0]: 'expert'
+mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][1]: 'fsdp'
+mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][2]: 'seq'
+mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['input_partition_spec'][0]: 'fsdp'
+mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['input_partition_spec'][1]: None
+mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['output_partition_spec'][0]: 'fsdp'
+mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['output_partition_spec'][1]: 'model'
+mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['embedding_partition_spec'][0]: 'model'
+mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['embedding_partition_spec'][1]: 'fsdp'
+mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][0]: 'model'
+mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][0]: 'expert'
+mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][1]: 'fsdp'
+mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][2]: 'seq'
+mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][0]: 'fsdp'
+mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][1]: 'model'
+mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][2]: None
+mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][0]: 'fsdp'
+mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][1]: None
+mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][2]: None
+mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][0]: 'fsdp'
+mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][1]: 'model'
+mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][2]: None
+mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][0]: 'fsdp'
+mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][1]: None
+mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][2]: None
+mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][0]: 'fsdp'
+mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][1]: 'model'
+mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][2]: None
+mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][0]: 'fsdp'
+mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][1]: None
+mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][2]: None
 mesh_rules[4][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier'
 mesh_shape[0]: 1
 mesh_shape[1]: 1
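The ModelConfigModifier block that precedes each PartitionSpecModifier above encodes two replacements: 'model.decoder.transformer' becomes a StackedTransformerLayer, and the nested 'input_linear.input_linear' projection becomes a GroupedQKVLinear; the surrounding layer.* lines are the fields of the replaced config that survive the merge. A sketch of the equivalent construction follows, again assuming the `default_config().set(...)` idiom and showing only the two klass swaps.

# Sketch only, not part of this diff.
from axlearn.common.attention import GroupedQKVLinear, StackedTransformerLayer
from axlearn.common.trainer_config_modifier import ModelConfigModifier

model_config_modifier = ModelConfigModifier.default_config().set(
    model_cfg_modifications={
        # Replaces the repeated transformer stack wholesale.
        "model.decoder.transformer": StackedTransformerLayer.default_config(),
        # Replaces only the grouped QKV projection nested inside it.
        "model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear": GroupedQKVLinear.default_config(),
    },
)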
diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-flash.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-flash.txt
index e3f269bfa..9577a3439 100644
--- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-flash.txt
+++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-flash.txt
@@ -185,6 +185,89 @@ mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.l
 mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_saved: '.*[kqv]_proj|.*linear1_[01]'
 mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: None
 mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: None
+mesh_rules[4][1].config_modifiers[2].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier'
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].klass: 'axlearn.common.attention.StackedTransformerLayer'
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.activation: 'nn.relu'
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout'
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer'
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.bias: True
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear'
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.param_partition_spec[0]: None
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.param_partition_spec[1]: 'model'
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.bias: True
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear'
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.param_partition_spec[0]: 'model'
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.param_partition_spec[1]: None
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.norm.eps: 1e-08
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.norm.forward_dtype: 'jax.numpy.float32'
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.norm.klass: 'axlearn.common.layers.LayerNorm'
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.residual_weight: 1.0
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth'
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.stochastic_depth.mode: 'row'
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.structure: 'prenorm'
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.klass: 'axlearn.common.attention.TransformerLayer'
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout'
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.QKVLinear'
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.bias: True
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear'
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.param_partition_spec[0]: None
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.param_partition_spec[1]: 'model'
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.param_partition_spec[2]: None
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey'
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.klass: 'axlearn.common.attention.MultiheadAttention'
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.bias: True
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear'
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.param_partition_spec[0]: None
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model'
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.param_partition_spec[2]: None
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery'
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout'
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer'
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.norm.eps: 1e-08
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.norm.forward_dtype: 'jax.numpy.float32'
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.norm.klass: 'axlearn.common.layers.LayerNorm'
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth'
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.stochastic_depth.mode: 'row'
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.structure: 'prenorm'
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].klass: 'axlearn.common.attention.GroupedQKVLinear'
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.bias: True
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.klass: 'axlearn.common.attention.MultiheadInputLinear'
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.param_partition_spec[0]: None
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.param_partition_spec[1]: 'model'
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.param_partition_spec[2]: None
+mesh_rules[4][1].config_modifiers[3].klass: 'axlearn.common.trainer_config_modifier.PartitionSpecModifier'
+mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][0]: 'model'
+mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][0]: 'expert'
+mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][1]: 'fsdp'
+mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][2]: 'seq'
+mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['input_partition_spec'][0]: 'fsdp'
+mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['input_partition_spec'][1]: None
+mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['output_partition_spec'][0]: 'fsdp'
+mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['output_partition_spec'][1]: 'model'
+mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['embedding_partition_spec'][0]: 'model'
+mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['embedding_partition_spec'][1]: 'fsdp'
+mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][0]: 'model'
+mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][0]: 'expert'
+mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][1]: 'fsdp'
+mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][2]: 'seq'
+mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][0]: 'fsdp'
+mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][1]: 'model'
+mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][2]: None
+mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][0]: 'fsdp'
+mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][1]: None
+mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][2]: None
+mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][0]: 'fsdp'
+mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][1]: 'model'
+mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][2]: None
+mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][0]: 'fsdp'
+mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][1]: None
+mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][2]: None
+mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][0]: 'fsdp'
+mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][1]: 'model'
+mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][2]: None
+mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][0]: 'fsdp'
+mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][1]: None
+mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][2]: None
 mesh_rules[4][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier'
 mesh_shape[0]: 1
 mesh_shape[1]: 1
diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3.txt
index b7457c951..ad92c6680 100644
--- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3.txt
+++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3.txt
@@ -185,6 +185,89 @@ mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.l
 mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_saved: '.*[kqv]_proj|.*linear1_[01]'
 mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: None
 mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: None
+mesh_rules[4][1].config_modifiers[2].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier'
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].klass: 'axlearn.common.attention.StackedTransformerLayer'
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.activation: 'nn.relu'
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout'
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer'
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.bias: True
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear'
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.param_partition_spec[0]: None
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.param_partition_spec[1]: 'model'
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.bias: True
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear'
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.param_partition_spec[0]: 'model'
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.param_partition_spec[1]: None
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.norm.eps: 1e-08
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.norm.forward_dtype: 'jax.numpy.float32'
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.norm.klass: 'axlearn.common.layers.LayerNorm'
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.residual_weight: 1.0
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth'
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.stochastic_depth.mode: 'row'
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.structure: 'prenorm'
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.klass: 'axlearn.common.attention.TransformerLayer'
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout'
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.QKVLinear'
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.bias: True
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear'
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.param_partition_spec[0]: None
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.param_partition_spec[1]: 'model'
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.param_partition_spec[2]: None
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey'
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.klass: 'axlearn.common.attention.MultiheadAttention'
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.bias: True
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear'
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.param_partition_spec[0]: None
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model'
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.param_partition_spec[2]: None
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery'
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout'
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer'
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.norm.eps: 1e-08
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.norm.forward_dtype: 'jax.numpy.float32'
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.norm.klass: 'axlearn.common.layers.LayerNorm'
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth'
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.stochastic_depth.mode: 'row'
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.structure: 'prenorm'
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].klass: 'axlearn.common.attention.GroupedQKVLinear'
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.bias: True
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.klass: 'axlearn.common.attention.MultiheadInputLinear'
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.param_partition_spec[0]: None
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.param_partition_spec[1]: 'model'
+mesh_rules[4][1].config_modifiers[2].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.param_partition_spec[2]: None
+mesh_rules[4][1].config_modifiers[3].klass: 'axlearn.common.trainer_config_modifier.PartitionSpecModifier'
+mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][0]: 'model'
+mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][0]: 'expert'
+mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][1]: 'fsdp'
+mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][2]: 'seq'
+mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['input_partition_spec'][0]: 'fsdp'
+mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['input_partition_spec'][1]: None
+mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['output_partition_spec'][0]: 'fsdp'
+mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['output_partition_spec'][1]: 'model'
+mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['embedding_partition_spec'][0]: 'model'
+mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['embedding_partition_spec'][1]: 'fsdp'
+mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][0]: 'model'
+mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][0]: 'expert'
+mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][1]: 'fsdp'
+mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][2]: 'seq'
+mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][0]: 'fsdp'
+mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][1]: 'model'
+mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][2]: None
+mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][0]: 'fsdp'
+mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][1]: None
+mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][2]: None
+mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][0]: 'fsdp'
+mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][1]: 'model'
+mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][2]: None
+mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][0]: 'fsdp'
+mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][1]: None
+mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][2]: None
+mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][0]: 'fsdp'
+mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][1]: 'model'
+mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][2]: None
+mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][0]: 'fsdp'
+mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][1]: None
+mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][2]: None
 mesh_rules[4][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier'
 mesh_shape[0]: 1
 mesh_shape[1]: 1
diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v1-flash-single-host.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v1-flash-single-host.txt
index 535234d6e..4aad97098 100644
--- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v1-flash-single-host.txt
+++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v1-flash-single-host.txt
@@ -194,6 +194,59 @@ mesh_rules[6][1][2]: 1
 mesh_rules[6][1][3]: 8
 mesh_rules[6][1][4]: 1
 mesh_rules[6][1][5]: 1
+mesh_rules[7][0]: 'neuron-(trn2|trn2n).48xlarge-64'
+mesh_rules[7][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier'
+mesh_rules[7][1].config_modifiers[0].mesh_shape[0]: 1
+mesh_rules[7][1].config_modifiers[0].mesh_shape[1]: 1
+mesh_rules[7][1].config_modifiers[0].mesh_shape[2]: 1
+mesh_rules[7][1].config_modifiers[0].mesh_shape[3]: -1
+mesh_rules[7][1].config_modifiers[0].mesh_shape[4]: 1
+mesh_rules[7][1].config_modifiers[0].mesh_shape[5]: 4
+mesh_rules[7][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier'
+mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].klass: 'axlearn.common.attention.StackedTransformerLayer'
+mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.activation: 'nn.relu'
+mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout'
+mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer'
+mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.bias: True
+mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear'
+mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.param_partition_spec[0]: None
+mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.param_partition_spec[1]: 'model'
+mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.bias: True
+mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear'
+mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.param_partition_spec[0]: 'model'
+mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.param_partition_spec[1]: None
+mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.norm.eps: 1e-08
+mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.norm.forward_dtype: 'jax.numpy.float32'
+mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.norm.klass: 'axlearn.common.layers.LayerNorm'
+mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.residual_weight: 1.0
+mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth'
+mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.stochastic_depth.mode: 'row'
+mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.structure: 'prenorm'
+mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.klass: 'axlearn.common.attention.TransformerLayer'
+mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout'
+mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.QKVLinear'
+mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.bias: True
+mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear'
+mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.param_partition_spec[0]: None
+mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.param_partition_spec[1]: 'model'
+mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.param_partition_spec[2]: None
+mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey'
+mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.klass: 'axlearn.common.attention.MultiheadAttention'
+mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.bias: True
+mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear'
+mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.param_partition_spec[0]: None
+mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model'
+mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.param_partition_spec[2]: None
+mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery'
+mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout'
+mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer'
+mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.norm.eps: 1e-08
+mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.norm.forward_dtype: 'jax.numpy.float32'
+mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.norm.klass: 'axlearn.common.layers.LayerNorm'
+mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth'
+mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.stochastic_depth.mode: 'row'
+mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.structure: 'prenorm'
+mesh_rules[7][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier'
 mesh_shape[0]: 1
 mesh_shape[1]: -1
 mesh_shape[2]: 1
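Each fuji-7B-v1 variant gains the same 'neuron-(trn2|trn2n).48xlarge-64' rule: a ChainConfigModifier that first pins a six-axis mesh and then applies the ModelConfigModifier shown above. The following sketch shows how such a rule tuple could be assembled; `model_config_modifier` is the hypothetical object from the earlier sketch, and the axis ordering in the comment is an assumption, since the dump records only the sizes.

# Sketch only, not part of this diff.
from axlearn.common.trainer_config_modifier import ChainConfigModifier, MeshShapeModifier

trn2_rule = (
    "neuron-(trn2|trn2n).48xlarge-64",
    ChainConfigModifier.default_config().set(
        config_modifiers=[
            # Matches mesh_shape[0..5] above; the -1 axis presumably absorbs
            # the remaining devices, as is conventional for mesh shapes.
            MeshShapeModifier.default_config().set(mesh_shape=(1, 1, 1, -1, 1, 4)),
            model_config_modifier,
        ],
    ),
)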
diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v1-flash.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v1-flash.txt
index cf5ed9a88..97b025758 100644
--- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v1-flash.txt
+++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v1-flash.txt
@@ -194,6 +194,59 @@ mesh_rules[6][1][2]: 1
 mesh_rules[6][1][3]: 8
 mesh_rules[6][1][4]: 1
 mesh_rules[6][1][5]: 1
+mesh_rules[7][0]: 'neuron-(trn2|trn2n).48xlarge-64'
+mesh_rules[7][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier'
+mesh_rules[7][1].config_modifiers[0].mesh_shape[0]: 1
+mesh_rules[7][1].config_modifiers[0].mesh_shape[1]: 1
+mesh_rules[7][1].config_modifiers[0].mesh_shape[2]: 1
+mesh_rules[7][1].config_modifiers[0].mesh_shape[3]: -1
+mesh_rules[7][1].config_modifiers[0].mesh_shape[4]: 1
+mesh_rules[7][1].config_modifiers[0].mesh_shape[5]: 4
+mesh_rules[7][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier'
+mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].klass: 'axlearn.common.attention.StackedTransformerLayer'
+mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.activation: 'nn.relu'
+mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout'
+mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer'
+mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.bias: True
+mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear'
+mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.param_partition_spec[0]: None
+mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.param_partition_spec[1]: 'model'
+mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.bias: True
+mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear'
+mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.param_partition_spec[0]: 'model'
+mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.param_partition_spec[1]: None
+mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.norm.eps: 1e-08
+mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.norm.forward_dtype: 'jax.numpy.float32'
+mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.norm.klass: 'axlearn.common.layers.LayerNorm'
+mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.residual_weight: 1.0
+mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth'
+mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.stochastic_depth.mode: 'row'
+mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.structure: 'prenorm'
+mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.klass: 'axlearn.common.attention.TransformerLayer'
+mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout'
+mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.QKVLinear'
+mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.bias: True
+mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear'
+mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.param_partition_spec[0]: None
+mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.param_partition_spec[1]: 'model'
+mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.param_partition_spec[2]: None
+mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey'
+mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.klass: 'axlearn.common.attention.MultiheadAttention'
+mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.bias: True
+mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear'
+mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.param_partition_spec[0]: None
+mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model'
+mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.param_partition_spec[2]: None
+mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery'
+mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout'
+mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer'
+mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.norm.eps: 1e-08
+mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.norm.forward_dtype: 'jax.numpy.float32'
+mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.norm.klass: 'axlearn.common.layers.LayerNorm'
+mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth'
+mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.stochastic_depth.mode: 'row'
+mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.structure: 'prenorm'
+mesh_rules[7][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier'
 mesh_shape[0]: 1
 mesh_shape[1]: -1
 mesh_shape[2]: 1
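The rule key itself reads as a regular expression, which is why one entry covers both trn2 and trn2n instance types. A standalone illustration follows; the selector strings are made up, and `re.fullmatch` stands in for whatever matching the mesh-rule lookup actually performs, which this diff does not show.

# Sketch only, not part of this diff.
import re

rule = "neuron-(trn2|trn2n).48xlarge-64"
for selector in ("neuron-trn2.48xlarge-64", "neuron-trn2n.48xlarge-64", "tpu-v4-8"):
    # Prints True, True, False: the alternation covers both instance families.
    print(selector, bool(re.fullmatch(rule, selector)))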
+mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].klass: 'axlearn.common.attention.StackedTransformerLayer' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.activation: 'nn.relu' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.bias: True +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.param_partition_spec[0]: None +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.param_partition_spec[1]: 'model' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.bias: True +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.param_partition_spec[0]: 'model' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.param_partition_spec[1]: None +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.norm.eps: 1e-08 +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.residual_weight: 1.0 +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.stochastic_depth.mode: 'row' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.structure: 'prenorm' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.klass: 'axlearn.common.attention.TransformerLayer' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.QKVLinear' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.bias: True 
+mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.param_partition_spec[0]: None +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.param_partition_spec[1]: 'model' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.param_partition_spec[2]: None +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.klass: 'axlearn.common.attention.MultiheadAttention' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.bias: True +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.param_partition_spec[0]: None +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.param_partition_spec[2]: None +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.norm.eps: 1e-08 +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.stochastic_depth.mode: 'row' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.structure: 'prenorm' +mesh_rules[7][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' mesh_shape[0]: 1 mesh_shape[1]: -1 mesh_shape[2]: 1 diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v1.txt 
b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v1.txt index a58c11472..4f8bb0091 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v1.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v1.txt @@ -194,6 +194,59 @@ mesh_rules[6][1][2]: 1 mesh_rules[6][1][3]: 8 mesh_rules[6][1][4]: 1 mesh_rules[6][1][5]: 1 +mesh_rules[7][0]: 'neuron-(trn2|trn2n).48xlarge-64' +mesh_rules[7][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier' +mesh_rules[7][1].config_modifiers[0].mesh_shape[0]: 1 +mesh_rules[7][1].config_modifiers[0].mesh_shape[1]: 1 +mesh_rules[7][1].config_modifiers[0].mesh_shape[2]: 1 +mesh_rules[7][1].config_modifiers[0].mesh_shape[3]: -1 +mesh_rules[7][1].config_modifiers[0].mesh_shape[4]: 1 +mesh_rules[7][1].config_modifiers[0].mesh_shape[5]: 4 +mesh_rules[7][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].klass: 'axlearn.common.attention.StackedTransformerLayer' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.activation: 'nn.relu' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.bias: True +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.param_partition_spec[0]: None +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.param_partition_spec[1]: 'model' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.bias: True +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.param_partition_spec[0]: 'model' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.param_partition_spec[1]: None +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.norm.eps: 1e-08 +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.residual_weight: 1.0 +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' 
+mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.stochastic_depth.mode: 'row' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.structure: 'prenorm' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.klass: 'axlearn.common.attention.TransformerLayer' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.QKVLinear' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.bias: True +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.param_partition_spec[0]: None +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.param_partition_spec[1]: 'model' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.param_partition_spec[2]: None +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.klass: 'axlearn.common.attention.MultiheadAttention' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.bias: True +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.param_partition_spec[0]: None +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.param_partition_spec[2]: None +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.norm.eps: 1e-08 
+mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.stochastic_depth.mode: 'row' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.structure: 'prenorm' +mesh_rules[7][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' mesh_shape[0]: 1 mesh_shape[1]: -1 mesh_shape[2]: 1 diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v2-flash-single-host.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v2-flash-single-host.txt index 087727526..cb2898e94 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v2-flash-single-host.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v2-flash-single-host.txt @@ -194,6 +194,65 @@ mesh_rules[6][1][2]: 1 mesh_rules[6][1][3]: 8 mesh_rules[6][1][4]: 1 mesh_rules[6][1][5]: 1 +mesh_rules[7][0]: 'neuron-(trn2|trn2n).48xlarge-64' +mesh_rules[7][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier' +mesh_rules[7][1].config_modifiers[0].mesh_shape[0]: 1 +mesh_rules[7][1].config_modifiers[0].mesh_shape[1]: 1 +mesh_rules[7][1].config_modifiers[0].mesh_shape[2]: 1 +mesh_rules[7][1].config_modifiers[0].mesh_shape[3]: -1 +mesh_rules[7][1].config_modifiers[0].mesh_shape[4]: 1 +mesh_rules[7][1].config_modifiers[0].mesh_shape[5]: 4 +mesh_rules[7][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].klass: 'axlearn.common.attention.StackedTransformerLayer' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.activation: 'nn.relu' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.bias: True +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.param_partition_spec[0]: None +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.param_partition_spec[1]: 'model' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.bias: True +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear' 
+mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.param_partition_spec[0]: 'model' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.param_partition_spec[1]: None +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.norm.eps: 1e-08 +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.residual_weight: 1.0 +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.stochastic_depth.mode: 'row' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.structure: 'prenorm' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.klass: 'axlearn.common.attention.TransformerLayer' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.QKVLinear' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.bias: True +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.param_partition_spec[0]: None +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.param_partition_spec[1]: 'model' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.param_partition_spec[2]: None +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.klass: 'axlearn.common.attention.MultiheadAttention' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.bias: True +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.param_partition_spec[0]: None 
+mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.param_partition_spec[2]: None +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.norm.eps: 1e-08 +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.stochastic_depth.mode: 'row' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.structure: 'prenorm' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].klass: 'axlearn.common.attention.GroupedQKVLinear' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.bias: True +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.param_partition_spec[0]: None +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.param_partition_spec[1]: 'model' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.param_partition_spec[2]: None +mesh_rules[7][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' mesh_shape[0]: 1 mesh_shape[1]: -1 mesh_shape[2]: 1 diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v2-flash.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v2-flash.txt index 32f64479e..046e5fa16 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v2-flash.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v2-flash.txt @@ -194,6 +194,65 @@ mesh_rules[6][1][2]: 1 mesh_rules[6][1][3]: 8 mesh_rules[6][1][4]: 1 mesh_rules[6][1][5]: 1 +mesh_rules[7][0]: 
'neuron-(trn2|trn2n).48xlarge-64' +mesh_rules[7][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier' +mesh_rules[7][1].config_modifiers[0].mesh_shape[0]: 1 +mesh_rules[7][1].config_modifiers[0].mesh_shape[1]: 1 +mesh_rules[7][1].config_modifiers[0].mesh_shape[2]: 1 +mesh_rules[7][1].config_modifiers[0].mesh_shape[3]: -1 +mesh_rules[7][1].config_modifiers[0].mesh_shape[4]: 1 +mesh_rules[7][1].config_modifiers[0].mesh_shape[5]: 4 +mesh_rules[7][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].klass: 'axlearn.common.attention.StackedTransformerLayer' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.activation: 'nn.relu' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.bias: True +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.param_partition_spec[0]: None +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.param_partition_spec[1]: 'model' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.bias: True +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.param_partition_spec[0]: 'model' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.param_partition_spec[1]: None +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.norm.eps: 1e-08 +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.residual_weight: 1.0 +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.stochastic_depth.mode: 'row' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.structure: 'prenorm' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.klass: 'axlearn.common.attention.TransformerLayer' 
+mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.QKVLinear' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.bias: True +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.param_partition_spec[0]: None +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.param_partition_spec[1]: 'model' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.param_partition_spec[2]: None +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.klass: 'axlearn.common.attention.MultiheadAttention' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.bias: True +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.param_partition_spec[0]: None +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.param_partition_spec[2]: None +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.norm.eps: 1e-08 +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' 
+mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.stochastic_depth.mode: 'row' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.structure: 'prenorm' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].klass: 'axlearn.common.attention.GroupedQKVLinear' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.bias: True +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.param_partition_spec[0]: None +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.param_partition_spec[1]: 'model' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.param_partition_spec[2]: None +mesh_rules[7][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' mesh_shape[0]: 1 mesh_shape[1]: -1 mesh_shape[2]: 1 diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v2-single-host.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v2-single-host.txt index d55e01b42..a56b7a38f 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v2-single-host.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v2-single-host.txt @@ -194,6 +194,65 @@ mesh_rules[6][1][2]: 1 mesh_rules[6][1][3]: 8 mesh_rules[6][1][4]: 1 mesh_rules[6][1][5]: 1 +mesh_rules[7][0]: 'neuron-(trn2|trn2n).48xlarge-64' +mesh_rules[7][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier' +mesh_rules[7][1].config_modifiers[0].mesh_shape[0]: 1 +mesh_rules[7][1].config_modifiers[0].mesh_shape[1]: 1 +mesh_rules[7][1].config_modifiers[0].mesh_shape[2]: 1 +mesh_rules[7][1].config_modifiers[0].mesh_shape[3]: -1 +mesh_rules[7][1].config_modifiers[0].mesh_shape[4]: 1 +mesh_rules[7][1].config_modifiers[0].mesh_shape[5]: 4 +mesh_rules[7][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].klass: 'axlearn.common.attention.StackedTransformerLayer' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.activation: 'nn.relu' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.bias: True 
+mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.param_partition_spec[0]: None +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.param_partition_spec[1]: 'model' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.bias: True +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.param_partition_spec[0]: 'model' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.param_partition_spec[1]: None +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.norm.eps: 1e-08 +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.residual_weight: 1.0 +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.stochastic_depth.mode: 'row' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.structure: 'prenorm' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.klass: 'axlearn.common.attention.TransformerLayer' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.QKVLinear' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.bias: True +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.param_partition_spec[0]: None +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.param_partition_spec[1]: 'model' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.param_partition_spec[2]: None +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.key_scale.klass: 
'axlearn.common.attention.ScaleKey' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.klass: 'axlearn.common.attention.MultiheadAttention' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.bias: True +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.param_partition_spec[0]: None +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.param_partition_spec[2]: None +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.norm.eps: 1e-08 +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.stochastic_depth.mode: 'row' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.structure: 'prenorm' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].klass: 'axlearn.common.attention.GroupedQKVLinear' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.bias: True +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.param_partition_spec[0]: None +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.param_partition_spec[1]: 'model' 
+mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.param_partition_spec[2]: None +mesh_rules[7][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' mesh_shape[0]: 1 mesh_shape[1]: -1 mesh_shape[2]: 1 diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v2.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v2.txt index eb4182b28..b459bd559 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v2.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v2.txt @@ -194,6 +194,65 @@ mesh_rules[6][1][2]: 1 mesh_rules[6][1][3]: 8 mesh_rules[6][1][4]: 1 mesh_rules[6][1][5]: 1 +mesh_rules[7][0]: 'neuron-(trn2|trn2n).48xlarge-64' +mesh_rules[7][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier' +mesh_rules[7][1].config_modifiers[0].mesh_shape[0]: 1 +mesh_rules[7][1].config_modifiers[0].mesh_shape[1]: 1 +mesh_rules[7][1].config_modifiers[0].mesh_shape[2]: 1 +mesh_rules[7][1].config_modifiers[0].mesh_shape[3]: -1 +mesh_rules[7][1].config_modifiers[0].mesh_shape[4]: 1 +mesh_rules[7][1].config_modifiers[0].mesh_shape[5]: 4 +mesh_rules[7][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].klass: 'axlearn.common.attention.StackedTransformerLayer' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.activation: 'nn.relu' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.bias: True +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.param_partition_spec[0]: None +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.param_partition_spec[1]: 'model' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.bias: True +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.param_partition_spec[0]: 'model' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.param_partition_spec[1]: None +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.norm.eps: 1e-08 +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.norm.forward_dtype: 'jax.numpy.float32' 
+mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.residual_weight: 1.0 +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.stochastic_depth.mode: 'row' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.structure: 'prenorm' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.klass: 'axlearn.common.attention.TransformerLayer' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.QKVLinear' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.bias: True +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.param_partition_spec[0]: None +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.param_partition_spec[1]: 'model' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.param_partition_spec[2]: None +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.klass: 'axlearn.common.attention.MultiheadAttention' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.bias: True +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.param_partition_spec[0]: None +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.param_partition_spec[2]: None +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery' 
+mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.norm.eps: 1e-08 +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.stochastic_depth.mode: 'row' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.structure: 'prenorm' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].klass: 'axlearn.common.attention.GroupedQKVLinear' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.bias: True +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.param_partition_spec[0]: None +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.param_partition_spec[1]: 'model' +mesh_rules[7][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.param_partition_spec[2]: None +mesh_rules[7][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' mesh_shape[0]: 1 mesh_shape[1]: -1 mesh_shape[2]: 1 diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-flash-single-host.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-flash-single-host.txt index a15dfdf0b..0273cfddc 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-flash-single-host.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-flash-single-host.txt @@ -178,6 +178,65 @@ mesh_rules[5][1][2]: 1 mesh_rules[5][1][3]: 8 mesh_rules[5][1][4]: 1 mesh_rules[5][1][5]: 1 +mesh_rules[6][0]: 'neuron-(trn2|trn2n).48xlarge-64' +mesh_rules[6][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier' +mesh_rules[6][1].config_modifiers[0].mesh_shape[0]: 1 +mesh_rules[6][1].config_modifiers[0].mesh_shape[1]: 1 +mesh_rules[6][1].config_modifiers[0].mesh_shape[2]: 1 +mesh_rules[6][1].config_modifiers[0].mesh_shape[3]: -1 +mesh_rules[6][1].config_modifiers[0].mesh_shape[4]: 1 +mesh_rules[6][1].config_modifiers[0].mesh_shape[5]: 4 
+mesh_rules[6][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier' +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].klass: 'axlearn.common.attention.StackedTransformerLayer' +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.activation: 'nn.relu' +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer' +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.bias: True +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear' +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.param_partition_spec[0]: None +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.param_partition_spec[1]: 'model' +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.bias: True +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear' +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.param_partition_spec[0]: 'model' +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.param_partition_spec[1]: None +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.norm.eps: 1e-08 +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.residual_weight: 1.0 +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.stochastic_depth.mode: 'row' +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.structure: 'prenorm' +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.klass: 'axlearn.common.attention.TransformerLayer' +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.QKVLinear' +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.bias: True 
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.param_partition_spec[0]: None +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.param_partition_spec[1]: 'model' +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.param_partition_spec[2]: None +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey' +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.klass: 'axlearn.common.attention.MultiheadAttention' +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.bias: True +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear' +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.param_partition_spec[0]: None +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model' +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.param_partition_spec[2]: None +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery' +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer' +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.norm.eps: 1e-08 +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.stochastic_depth.mode: 'row' +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.structure: 'prenorm' +mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].klass: 'axlearn.common.attention.GroupedQKVLinear' 
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.bias: True
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.klass: 'axlearn.common.attention.MultiheadInputLinear'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.param_partition_spec[0]: None
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.param_partition_spec[1]: 'model'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.param_partition_spec[2]: None
+mesh_rules[6][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier'
 mesh_shape[0]: 1
 mesh_shape[1]: -1
 mesh_shape[2]: 1
diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-flash.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-flash.txt
index 7f520cbde..2a7765723 100644
--- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-flash.txt
+++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-flash.txt
@@ -178,6 +178,65 @@ mesh_rules[5][1][2]: 1
 mesh_rules[5][1][3]: 8
 mesh_rules[5][1][4]: 1
 mesh_rules[5][1][5]: 1
+mesh_rules[6][0]: 'neuron-(trn2|trn2n).48xlarge-64'
+mesh_rules[6][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier'
+mesh_rules[6][1].config_modifiers[0].mesh_shape[0]: 1
+mesh_rules[6][1].config_modifiers[0].mesh_shape[1]: 1
+mesh_rules[6][1].config_modifiers[0].mesh_shape[2]: 1
+mesh_rules[6][1].config_modifiers[0].mesh_shape[3]: -1
+mesh_rules[6][1].config_modifiers[0].mesh_shape[4]: 1
+mesh_rules[6][1].config_modifiers[0].mesh_shape[5]: 4
+mesh_rules[6][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].klass: 'axlearn.common.attention.StackedTransformerLayer'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.activation: 'nn.relu'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.bias: True
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.param_partition_spec[0]: None
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.param_partition_spec[1]: 'model'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.bias: True
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.param_partition_spec[0]: 'model'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.param_partition_spec[1]: None
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.norm.eps: 1e-08
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.norm.forward_dtype: 'jax.numpy.float32'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.norm.klass: 'axlearn.common.layers.LayerNorm'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.residual_weight: 1.0
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.stochastic_depth.mode: 'row'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.structure: 'prenorm'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.klass: 'axlearn.common.attention.TransformerLayer'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.QKVLinear'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.bias: True
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.param_partition_spec[0]: None
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.param_partition_spec[1]: 'model'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.param_partition_spec[2]: None
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.klass: 'axlearn.common.attention.MultiheadAttention'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.bias: True
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.param_partition_spec[0]: None
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.param_partition_spec[2]: None
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.norm.eps: 1e-08
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.norm.forward_dtype: 'jax.numpy.float32'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.norm.klass: 'axlearn.common.layers.LayerNorm'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.stochastic_depth.mode: 'row'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.structure: 'prenorm'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].klass: 'axlearn.common.attention.GroupedQKVLinear'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.bias: True
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.klass: 'axlearn.common.attention.MultiheadInputLinear'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.param_partition_spec[0]: None
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.param_partition_spec[1]: 'model'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.param_partition_spec[2]: None
+mesh_rules[6][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier'
 mesh_shape[0]: 1
 mesh_shape[1]: -1
 mesh_shape[2]: 1
diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-single-host.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-single-host.txt
index 225299e7b..462795562 100644
--- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-single-host.txt
+++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-single-host.txt
@@ -178,6 +178,65 @@ mesh_rules[5][1][2]: 1
 mesh_rules[5][1][3]: 8
 mesh_rules[5][1][4]: 1
 mesh_rules[5][1][5]: 1
+mesh_rules[6][0]: 'neuron-(trn2|trn2n).48xlarge-64'
+mesh_rules[6][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier'
+mesh_rules[6][1].config_modifiers[0].mesh_shape[0]: 1
+mesh_rules[6][1].config_modifiers[0].mesh_shape[1]: 1
+mesh_rules[6][1].config_modifiers[0].mesh_shape[2]: 1
+mesh_rules[6][1].config_modifiers[0].mesh_shape[3]: -1
+mesh_rules[6][1].config_modifiers[0].mesh_shape[4]: 1
+mesh_rules[6][1].config_modifiers[0].mesh_shape[5]: 4
+mesh_rules[6][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].klass: 'axlearn.common.attention.StackedTransformerLayer'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.activation: 'nn.relu'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.bias: True
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.param_partition_spec[0]: None
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.param_partition_spec[1]: 'model'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.bias: True
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.param_partition_spec[0]: 'model'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.param_partition_spec[1]: None
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.norm.eps: 1e-08
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.norm.forward_dtype: 'jax.numpy.float32'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.norm.klass: 'axlearn.common.layers.LayerNorm'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.residual_weight: 1.0
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.stochastic_depth.mode: 'row'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.structure: 'prenorm'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.klass: 'axlearn.common.attention.TransformerLayer'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.QKVLinear'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.bias: True
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.param_partition_spec[0]: None
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.param_partition_spec[1]: 'model'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.param_partition_spec[2]: None
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.klass: 'axlearn.common.attention.MultiheadAttention'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.bias: True
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.param_partition_spec[0]: None
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.param_partition_spec[2]: None
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.norm.eps: 1e-08
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.norm.forward_dtype: 'jax.numpy.float32'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.norm.klass: 'axlearn.common.layers.LayerNorm'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.stochastic_depth.mode: 'row'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.structure: 'prenorm'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].klass: 'axlearn.common.attention.GroupedQKVLinear'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.bias: True
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.klass: 'axlearn.common.attention.MultiheadInputLinear'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.param_partition_spec[0]: None
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.param_partition_spec[1]: 'model'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.param_partition_spec[2]: None
+mesh_rules[6][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier'
 mesh_shape[0]: 1
 mesh_shape[1]: -1
 mesh_shape[2]: 1
diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3.txt
index 6339517df..a1fc9b8dd 100644
--- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3.txt
+++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3.txt
@@ -178,6 +178,65 @@ mesh_rules[5][1][2]: 1
 mesh_rules[5][1][3]: 8
 mesh_rules[5][1][4]: 1
 mesh_rules[5][1][5]: 1
+mesh_rules[6][0]: 'neuron-(trn2|trn2n).48xlarge-64'
+mesh_rules[6][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier'
+mesh_rules[6][1].config_modifiers[0].mesh_shape[0]: 1
+mesh_rules[6][1].config_modifiers[0].mesh_shape[1]: 1
+mesh_rules[6][1].config_modifiers[0].mesh_shape[2]: 1
+mesh_rules[6][1].config_modifiers[0].mesh_shape[3]: -1
+mesh_rules[6][1].config_modifiers[0].mesh_shape[4]: 1
+mesh_rules[6][1].config_modifiers[0].mesh_shape[5]: 4
+mesh_rules[6][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].klass: 'axlearn.common.attention.StackedTransformerLayer'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.activation: 'nn.relu'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.bias: True
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.param_partition_spec[0]: None
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.param_partition_spec[1]: 'model'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.bias: True
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.param_partition_spec[0]: 'model'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.param_partition_spec[1]: None
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.norm.eps: 1e-08
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.norm.forward_dtype: 'jax.numpy.float32'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.norm.klass: 'axlearn.common.layers.LayerNorm'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.residual_weight: 1.0
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.stochastic_depth.mode: 'row'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.structure: 'prenorm'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.klass: 'axlearn.common.attention.TransformerLayer'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.QKVLinear'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.bias: True
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.param_partition_spec[0]: None
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.param_partition_spec[1]: 'model'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.param_partition_spec[2]: None
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.klass: 'axlearn.common.attention.MultiheadAttention'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.bias: True
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.param_partition_spec[0]: None
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.param_partition_spec[2]: None
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.norm.eps: 1e-08
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.norm.forward_dtype: 'jax.numpy.float32'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.norm.klass: 'axlearn.common.layers.LayerNorm'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.stochastic_depth.mode: 'row'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.structure: 'prenorm'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].klass: 'axlearn.common.attention.GroupedQKVLinear'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.bias: True
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.klass: 'axlearn.common.attention.MultiheadInputLinear'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.param_partition_spec[0]: None
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.param_partition_spec[1]: 'model'
+mesh_rules[6][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.param_partition_spec[2]: None
+mesh_rules[6][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier'
 mesh_shape[0]: 1
 mesh_shape[1]: -1
 mesh_shape[2]: 1
diff --git a/axlearn/experiments/text/gpt/fuji.py b/axlearn/experiments/text/gpt/fuji.py
index 001a873e3..9a0d5179e 100644
--- a/axlearn/experiments/text/gpt/fuji.py
+++ b/axlearn/experiments/text/gpt/fuji.py
@@ -22,11 +22,13 @@
     BaseStackedTransformerLayer,
     FusedGroupedQKVLinear,
     FusedQKVLinear,
+    GroupedQKVLinear,
     GroupedQueryAttention,
     MultiheadAttention,
     RematRegexSavePatterns,
     RepeatedTransformerLayer,
     RoFormerQKVLinear,
+    StackedTransformerLayer,
 )
 from axlearn.common.base_layer import RematSpec
 from axlearn.common.config import config_for_function
@@ -38,6 +40,8 @@
     ChainConfigModifier,
     GradientAccumulationModifier,
     MeshShapeModifier,
+    ModelConfigModifier,
+    PartitionSpecModifier,
     RematSpecModifier,
 )
 from axlearn.common.utils import (
@@ -141,6 +145,21 @@ def get_trainer_kwargs(
 
     rope_theta = ROPE_THETA[version]
 
+    # TRN2-specific model config modifications.
+    trn2_model_modifications = {
+        # The Neuron compiler detects repeating blocks and reuses them during compilation,
+        # so compile time does not grow with the number of layers.
+        "model.decoder.transformer": StackedTransformerLayer.default_config(),
+        **(
+            {}
+            if version == Version.V1
+            else {
+                "model.decoder.transformer.layer.self_attention.attention.input_linear"
+                ".input_linear": GroupedQKVLinear.default_config()
+            }
+        ),
+    }
+
     offload_dots_saveable_policy = config_for_function(
         extended_checkpoint_policies.offload_dots_saveable
     ).set(offload_src="device", offload_dst="pinned_host")
@@ -194,6 +213,23 @@
             train_batch_size=train_batch_size,
             max_step=max_step,
             mesh_shape=mesh_shape_from_axes(data=-1, fsdp=8),
+            mesh_rules=(
+                (
+                    "neuron-(trn2|trn2n).48xlarge-64",
+                    ChainConfigModifier.default_config().set(
+                        config_modifiers=[
+                            MeshShapeModifier.default_config().set(
+                                # TP within the chip, FSDP across chips.
+                                # Each TRN2 chip has 4 XLA cores.
+                                mesh_shape=mesh_shape_from_axes(fsdp=-1, model=4)
+                            ),
+                            ModelConfigModifier.default_config().set(
+                                model_cfg_modifications=trn2_model_modifications
+                            ),
+                        ],
+                    ),
+                ),
+            ),
         )
     elif model_size == "3B":
         trainer_kwargs = dict(
@@ -212,6 +248,23 @@
             train_batch_size=train_batch_size,
             max_step=max_step,
             mesh_shape=mesh_shape_from_axes(data=-1, fsdp=8),
+            mesh_rules=(
+                (
+                    "neuron-(trn2|trn2n).48xlarge-64",
+                    ChainConfigModifier.default_config().set(
+                        config_modifiers=[
+                            MeshShapeModifier.default_config().set(
+                                # TP within the chip, FSDP across chips.
+                                # Each TRN2 chip has 4 XLA cores.
+                                mesh_shape=mesh_shape_from_axes(fsdp=-1, model=4)
+                            ),
+                            ModelConfigModifier.default_config().set(
+                                model_cfg_modifications=trn2_model_modifications
+                            ),
+                        ],
+                    ),
+                ),
+            ),
         )
     elif model_size == "7B":
         trainer_kwargs = dict(
@@ -325,6 +378,21 @@
                     "gpu-(p5.48xlarge|p4de.24xlarge|a3-highgpu-8g)-(256|512|1024)",
                     mesh_shape_from_axes(data=-1, fsdp=8),
                 ),
+                (
+                    "neuron-(trn2|trn2n).48xlarge-64",
+                    ChainConfigModifier.default_config().set(
+                        config_modifiers=[
+                            MeshShapeModifier.default_config().set(
+                                # TP within the chip, FSDP across chips.
+                                # Each TRN2 chip has 4 XLA cores.
+                                mesh_shape=mesh_shape_from_axes(fsdp=-1, model=4)
+                            ),
+                            ModelConfigModifier.default_config().set(
+                                model_cfg_modifications=trn2_model_modifications
+                            ),
+                        ],
+                    ),
+                ),
             ),
         )
     elif model_size == "8B":
@@ -405,6 +473,21 @@
                     "gpu-(p5.48xlarge|p4de.24xlarge|a3-highgpu-8g)-(256|512|1024)",
                     mesh_shape_from_axes(data=-1, fsdp=8),
                 ),
+                (
+                    "neuron-(trn2|trn2n).48xlarge-64",
+                    ChainConfigModifier.default_config().set(
+                        config_modifiers=[
+                            MeshShapeModifier.default_config().set(
+                                # TP within the chip, FSDP across chips.
+                                # Each TRN2 chip has 4 XLA cores.
+                                mesh_shape=mesh_shape_from_axes(fsdp=-1, model=4)
+                            ),
+                            ModelConfigModifier.default_config().set(
+                                model_cfg_modifications=trn2_model_modifications
+                            ),
+                        ],
+                    ),
+                ),
             ),
         )
     elif model_size == "70B":
@@ -499,6 +582,8 @@
                     ChainConfigModifier.default_config().set(
                         config_modifiers=[
                             MeshShapeModifier.default_config().set(
+                                # TP within the chip, FSDP across chips.
+                                # Each TRN2 chip has 4 XLA cores.
                                 mesh_shape=mesh_shape_from_axes(fsdp=-1, model=4)
                             ),
                             RematSpecModifier.default_config().set(
@@ -521,6 +606,46 @@
                                 ),
                             }
                         ),
+                            ModelConfigModifier.default_config().set(
+                                model_cfg_modifications=trn2_model_modifications
+                            ),
+                            PartitionSpecModifier.default_config().set(
+                                partition_specs={
+                                    # Vocab parallel embeddings sharding from Megatron LM.
+                                    "model.decoder.emb.token_emb": {
+                                        "param_partition_spec": (
+                                            "model",
+                                            ("expert", "fsdp", "seq"),
+                                        ),
+                                        "input_partition_spec": ("fsdp", None),
+                                        "embedding_partition_spec": ("model", None),
+                                        "output_partition_spec": ("fsdp", None, None),
+                                    },
+                                    "model.decoder.lm_head": {
+                                        "param_partition_spec": (
+                                            "model",
+                                            ("expert", "fsdp", "seq"),
+                                        ),
+                                    },
+                                    # Sequence parallel shardings for norms.
+                                    "model.decoder.transformer.layer.self_attention.norm": {
+                                        "input_partition_spec": ("fsdp", "model", None),
+                                        "output_partition_spec": ("fsdp", None, None),
+                                    },
+                                    "model.decoder.transformer.layer.feed_forward.norm": {
+                                        "input_partition_spec": ("fsdp", "model", None),
+                                        "output_partition_spec": ("fsdp", None, None),
+                                    },
+                                    "model.decoder.output_norm": {
+                                        "input_partition_spec": ("fsdp", "model", None),
+                                        "output_partition_spec": ("fsdp", None, None),
+                                    },
+                                    # Sequence parallel shardings for the feed_forward layer.
+                                    "model.decoder.transformer.layer.feed_forward.linear2": {
+                                        "output_partition_spec": ("fsdp", None, None),
+                                    },
+                                },
+                            ),
                         ],
                     ),
                 ),
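
For orientation, the sketch below (not part of the change itself) shows one plausible way the new TRN2 mesh rule composes when applied by hand. It assumes an existing `SpmdTrainer.Config` named `trainer_cfg` and that instantiating a `ConfigModifier` config yields a callable over trainer configs; both are assumptions for illustration, not guarantees made by this diff.

    # Hypothetical usage sketch; `trainer_cfg` is an SpmdTrainer.Config built elsewhere.
    from axlearn.common.attention import StackedTransformerLayer
    from axlearn.common.trainer_config_modifier import (
        ChainConfigModifier,
        MeshShapeModifier,
        ModelConfigModifier,
        PartitionSpecModifier,
    )
    from axlearn.common.utils import mesh_shape_from_axes


    def apply_trn2_rule(trainer_cfg):
        chain_cfg = ChainConfigModifier.default_config().set(
            config_modifiers=[
                # TP within the chip, FSDP across chips (each TRN2 chip has 4 XLA cores).
                MeshShapeModifier.default_config().set(
                    mesh_shape=mesh_shape_from_axes(fsdp=-1, model=4)
                ),
                # Replace the transformer stack class; overlapping attributes of the
                # existing config are carried over on a best-effort basis.
                ModelConfigModifier.default_config().set(
                    model_cfg_modifications={
                        "model.decoder.transformer": StackedTransformerLayer.default_config(),
                    }
                ),
                # Override partition-spec attributes on a per-module basis.
                PartitionSpecModifier.default_config().set(
                    partition_specs={
                        "model.decoder.lm_head": {
                            "param_partition_spec": ("model", ("expert", "fsdp", "seq")),
                        },
                    }
                ),
            ],
        )
        # Assumed call pattern: the instantiated chain maps
        # SpmdTrainer.Config -> SpmdTrainer.Config, applying modifiers in order.
        return chain_cfg.instantiate()(trainer_cfg)

Because the modifiers run in order, the mesh shape is set before the model and partition-spec rewrites, matching how the mesh rules in this diff list `MeshShapeModifier` first.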