diff --git a/axlearn/common/trainer_config_modifier.py b/axlearn/common/trainer_config_modifier.py
index d647e1a06..42986f7d8 100644
--- a/axlearn/common/trainer_config_modifier.py
+++ b/axlearn/common/trainer_config_modifier.py
@@ -2,7 +2,7 @@
 
 """Defines trainer config modifiers, which will be used in model definitions."""
 
-from typing import Dict, Sequence, Union
+from typing import Dict, NamedTuple, Sequence, Union
 
 from axlearn.common import config
 from axlearn.common.base_layer import RematSpec
@@ -10,6 +10,7 @@
     REQUIRED,
     ConfigModifier,
     ConfigOr,
+    Configurable,
     Required,
     config_class,
     maybe_instantiate,
@@ -17,7 +18,53 @@
 from axlearn.common.gradient_accumulation import with_minibatch_steps
 from axlearn.common.metrics import MetricAccumulator
 from axlearn.common.trainer import SpmdTrainer
-from axlearn.common.utils import HybridMeshShape, MeshShape
+from axlearn.common.utils import HybridMeshShape, MeshShape, PartitionSpec
+
+
+class _FoundModule(NamedTuple):
+    """Module found by a recursive search for a module name in a nested configurable."""
+
+    # The module found.
+    module: Configurable.Config
+    # The parent of the found module.
+    parent_module: Configurable.Config
+    # The key of the found module in its parent.
+    key_in_parent: str
+
+
+def _find_target_module(module_name: str, cfg: SpmdTrainer.Config) -> _FoundModule:
+    """Recursively searches for the target module matching module_name in the provided cfg.
+
+    Args:
+        module_name: Name of the target module.
+        cfg: The trainer config to be searched for module_name.
+
+    Raises:
+        ValueError: The module_name is not found.
+
+    Returns:
+        A _FoundModule NamedTuple with:
+            module: The module found.
+            parent_module: The parent of the found module.
+            key_in_parent: The key of the found module in its parent.
+    """
+
+    # Here we assume x.y.z format.
+    # One example would be model.decoder.transformer.layer.
+    target_modules = module_name.split(".")
+    curr_module = cfg
+    key_in_parent = None
+    parent_module = None
+
+    for target_module_key in target_modules:
+        if not hasattr(curr_module, target_module_key):
+            raise ValueError(f"{target_module_key} is not found in {curr_module}.")
+        parent_module = curr_module
+        key_in_parent = target_module_key
+        curr_module = getattr(curr_module, target_module_key)
+    return _FoundModule(
+        module=curr_module, parent_module=parent_module, key_in_parent=key_in_parent
+    )
 
 
 class GradientAccumulationModifier(ConfigModifier):
@@ -100,18 +147,11 @@ def __call__(self, cfg: SpmdTrainer.Config) -> SpmdTrainer.Config:
         """
 
         for module_name, remat_spec in self._remat_policies.items():
-            # Here we assume x.y.z format.
-            # One example would be model.decoder.transformer.layer.
-            target_modules = module_name.split(".")
-            curr_module = cfg
-            for target_module in target_modules:
-                if not hasattr(curr_module, target_module):
-                    raise ValueError(f"{target_module} is not found in {curr_module}.")
-                curr_module = getattr(curr_module, target_module)
+            found_module = _find_target_module(module_name, cfg)
             # Here we assume all modules have remat_spec attribute.
-            if not hasattr(curr_module, "remat_spec"):
-                raise ValueError(f"{curr_module} does not have remat_spec attribute")
-            curr_module.remat_spec = remat_spec
+            if not hasattr(found_module.module, "remat_spec"):
+                raise ValueError(f"{found_module.module} does not have remat_spec attribute")
+            found_module.module.remat_spec = remat_spec
         return cfg
 
 
@@ -146,6 +186,116 @@ def __call__(self, cfg: SpmdTrainer.Config) -> SpmdTrainer.Config:
         return cfg
 
 
+class ModelConfigModifier(ConfigModifier):
+    """Replaces the config of a target module in the trainer config."""
+
+    @config_class
+    class Config(ConfigModifier.Config):
+        """Configure ModelConfigModifier.
+
+        Attributes:
+            target_config: Path of the target config (e.g. `model.decoder.transformer`).
+            modification: The new config that replaces the target module's config.
+        """
+
+        target_config: Required[str] = REQUIRED
+        modification: Required[Configurable.Config] = REQUIRED
+
+    def __init__(self, cfg: Config):
+        super().__init__(cfg)
+        self._target_config = self.config.target_config
+        self._modification = self.config.modification
+
+    def _merge_configs(
+        self, target_cfg: Configurable.Config, found_module: _FoundModule
+    ) -> Configurable.Config:
+        """Merges fields from the config being replaced on a best-effort basis.
+
+        Merge Rules:
+            - The class (`klass`) is never merged; target_cfg keeps its own class.
+            - If a field exists in both configs, use the value from the config being replaced.
+            - Otherwise, keep the value from target_cfg.
+
+        Args:
+            target_cfg: The configuration that will replace found_module's config.
+            found_module: The found module whose config class is being replaced but whose
+                field values are merged into target_cfg.
+
+        Returns:
+            The modified config.
+
+        """
+        for key in target_cfg.keys():
+            if key == "klass":
+                continue
+            elif hasattr(found_module.module, key) and hasattr(target_cfg, key):
+                setattr(target_cfg, key, getattr(found_module.module, key))
+        return target_cfg
+
+    def __call__(self, cfg: SpmdTrainer.Config) -> SpmdTrainer.Config:
+        """Overwrite the config of the target module.
+
+        Args:
+            cfg: The trainer config to be modified.
+
+        Raises:
+            ValueError: The target module is not found.
+
+        Returns:
+            The modified trainer config.
+        """
+
+        found_module = _find_target_module(self._target_config, cfg)
+        self._modification = self._merge_configs(self._modification, found_module)
+        # Replace the config in the parent config.
+        setattr(found_module.parent_module, found_module.key_in_parent, self._modification)
+        return cfg
+
+
+class PartitionSpecModifier(ConfigModifier):
+    """Update the partition spec attributes for the specified modules."""
+
+    @config_class
+    class Config(ConfigModifier.Config):
+        """Configure PartitionSpecModifier.
+
+        Attributes:
+            partition_specs: A nested mapping from module path
+                (e.g. `model.decoder.transformer.layer`) to a mapping of
+                partition-spec attribute name to PartitionSpec.
+        """
+
+        partition_specs: Required[Dict[str, PartitionSpec]] = REQUIRED
+
+    def __init__(self, cfg: Config):
+        super().__init__(cfg)
+        self._attribute_dicts = self.config.partition_specs
+
+    def __call__(self, cfg: SpmdTrainer.Config) -> SpmdTrainer.Config:
+        """Update the partition_spec attributes for the specified modules.
+
+        Args:
+            cfg: The trainer config to be modified.
+
+        Raises:
+            ValueError: The target module is not found.
+            ValueError: The partition_spec attribute is not found.
+
+        Returns:
+            The modified trainer config.
+        """
+        for module_name, partition_spec_dict in self._attribute_dicts.items():
+            found_module = _find_target_module(module_name, cfg)
+            for partition_spec_name, partition_spec in partition_spec_dict.items():
+                if not hasattr(found_module.module, partition_spec_name):
+                    raise ValueError(
+                        f"{found_module.module} does not have {partition_spec_name} attribute"
+                    )
+                setattr(found_module.module, partition_spec_name, partition_spec)
+
+        return cfg
+
+
 class ChainConfigModifier(ConfigModifier):
     """Chain multiple config modifiers together."""
 
diff --git a/axlearn/common/trainer_config_modifier_test.py b/axlearn/common/trainer_config_modifier_test.py
index ccfe00823..336ba9d79 100644
--- a/axlearn/common/trainer_config_modifier_test.py
+++ b/axlearn/common/trainer_config_modifier_test.py
@@ -5,13 +5,16 @@
 import jax
 from absl.testing import absltest
 
-from axlearn.common import test_utils
+from axlearn.common import causal_lm, test_utils
+from axlearn.common.attention import RepeatedTransformerLayer, StackedTransformerLayer
 from axlearn.common.base_layer import RematSpec
 from axlearn.common.trainer import SpmdTrainer
 from axlearn.common.trainer_config_modifier import (
     ChainConfigModifier,
     GradientAccumulationModifier,
     MeshShapeModifier,
+    ModelConfigModifier,
+    PartitionSpecModifier,
     RematSpecModifier,
 )
 from axlearn.common.trainer_test import DummyModel
@@ -65,6 +68,86 @@ def test_remat_policy_override(self):
             _ = cfg_modifier(cfg)
 
 
+class ModelConfigModifierTest(test_utils.TestCase):
+    def test_model_config_override(self):
+        cfg = SpmdTrainer.default_config().set(model=causal_lm.Model.default_config())
+        self.assertTrue(
+            str(cfg.model.decoder.transformer) == str(StackedTransformerLayer.default_config())
+        )
+
+        cfg_modifier = (
+            ModelConfigModifier.default_config()
+            .set(
+                target_config="model.decoder.transformer",
+                modification=RepeatedTransformerLayer.default_config(),
+            )
+            .instantiate()
+        )
+
+        cfg = cfg_modifier(cfg)
+        # The default StackedTransformerLayer should have changed to RepeatedTransformerLayer.
+        self.assertTrue(
+            str(cfg.model.decoder.transformer) == str(RepeatedTransformerLayer.default_config())
+        )
+        cfg_modifier = (
+            ModelConfigModifier.default_config()
+            .set(
+                target_config="model.decoder.unknown",
+                modification=RepeatedTransformerLayer.default_config(),
+            )
+            .instantiate()
+        )
+        # Ensure that the exception is working.
+        with self.assertRaisesRegex(ValueError, "unknown is not found in.*"):
+            _ = cfg_modifier(cfg)
+
+
+class PartitionSpecModifierTest(test_utils.TestCase):
+    def test_partition_spec_override(self):
+        cfg = SpmdTrainer.default_config().set(model=DummyModel.default_config())
+        cfg_modifier = (
+            PartitionSpecModifier.default_config()
+            .set(
+                partition_specs={
+                    "model.linear": {"param_partition_spec": ("model", ("expert", "fsdp", "seq"))},
+                },
+            )
+            .instantiate()
+        )
+        cfg = cfg_modifier(cfg)
+        self.assertEqual(
+            str(cfg.model.linear.param_partition_spec), "('model', ('expert', 'fsdp', 'seq'))"
+        )
+        cfg_modifier = (
+            PartitionSpecModifier.default_config()
+            .set(
+                partition_specs={
+                    "model.linear": {"param_partition_spec": ("model", ("expert", "fsdp", "seq"))},
+                    "model.unknown": {"param_partition_spec": ("model", ("expert", "fsdp", "seq"))},
+                },
+            )
+            .instantiate()
+        )
+        # Ensure that the exception is working.
+        with self.assertRaisesRegex(ValueError, "unknown is not found in.*"):
+            _ = cfg_modifier(cfg)
+
+        cfg_modifier = (
+            PartitionSpecModifier.default_config()
+            .set(
+                partition_specs={
+                    "model.linear": {
+                        "param_partition_spec": ("model", ("expert", "fsdp", "seq")),
+                        "unknown_partition_spec": ("model", ("expert", "fsdp", "seq")),
+                    },
+                },
+            )
+            .instantiate()
+        )
+        with self.assertRaisesRegex(ValueError, ".*does not have unknown_partition_spec attribute"):
+            _ = cfg_modifier(cfg)
+
+
 class MeshShapeModifierTest(test_utils.TestCase):
     def test_mesh_shape_update(self):
         cfg = SpmdTrainer.default_config().set(model=DummyModel.default_config())
diff --git a/axlearn/experiments/text/gpt/fuji.py b/axlearn/experiments/text/gpt/fuji.py
index bbd769dad..6f9635ead 100644
--- a/axlearn/experiments/text/gpt/fuji.py
+++ b/axlearn/experiments/text/gpt/fuji.py
@@ -22,11 +22,13 @@
     BaseStackedTransformerLayer,
     FusedGroupedQKVLinear,
     FusedQKVLinear,
+    GroupedQKVLinear,
     GroupedQueryAttention,
     MultiheadAttention,
     RematRegexSavePatterns,
     RepeatedTransformerLayer,
     RoFormerQKVLinear,
+    StackedTransformerLayer,
 )
 from axlearn.common.base_layer import RematSpec
 from axlearn.common.config import config_for_function
@@ -38,6 +40,8 @@
     ChainConfigModifier,
     GradientAccumulationModifier,
     MeshShapeModifier,
+    ModelConfigModifier,
+    PartitionSpecModifier,
     RematSpecModifier,
 )
 from axlearn.common.utils import (
@@ -151,6 +155,60 @@ def get_trainer_kwargs(
 
     rope_theta = ROPE_THETA[version]
 
+    # TRN2-specific model config modifications.
+    trn2_model_modifications = [
+        # The Neuron compiler detects repeating blocks and reuses them during compilation,
+        # so compile time does not grow with the number of layers.
+        ModelConfigModifier.default_config().set(
+            target_config="model.decoder.transformer",
+            modification=StackedTransformerLayer.default_config(),
+        )
+    ]
+    if version != Version.V1:
+        trn2_model_modifications.append(
+            ModelConfigModifier.default_config().set(
+                target_config="model.decoder.transformer.layer.self_attention.attention."
+                "input_linear.input_linear",
+                modification=GroupedQKVLinear.default_config(),
+            )
+        )
+
+    trn2_partition_spec_modifications = [
+        PartitionSpecModifier.default_config().set(
+            partition_specs={
+                # Vocab-parallel embedding sharding from Megatron-LM.
+                "model.decoder.emb.token_emb": {
+                    "param_partition_spec": (
+                        "model",
+                        ("expert", "fsdp", "seq"),
+                    ),
+                    "input_partition_spec": ("fsdp", None),
+                    "output_partition_spec": ("fsdp", "model"),
+                    "embedding_partition_spec": ("model", "fsdp"),
+                },
+                "model.decoder.lm_head": {
+                    "param_partition_spec": (
+                        "model",
+                        ("expert", "fsdp", "seq"),
+                    ),
+                },
+                # Sequence-parallel shardings for norms.
+                "model.decoder.transformer.layer.self_attention.norm": {
+                    "input_partition_spec": ("fsdp", "model", None),
+                    "output_partition_spec": ("fsdp", None, None),
+                },
+                "model.decoder.transformer.layer.feed_forward.norm": {
+                    "input_partition_spec": ("fsdp", "model", None),
+                    "output_partition_spec": ("fsdp", None, None),
+                },
+                "model.decoder.output_norm": {
+                    "input_partition_spec": ("fsdp", "model", None),
+                    "output_partition_spec": ("fsdp", None, None),
+                },
+            },
+        ),
+    ]
+
     offload_dots_saveable_policy = config_for_function(
         extended_checkpoint_policies.offload_dots_saveable
     ).set(offload_src="device", offload_dst="pinned_host")
@@ -204,6 +262,22 @@ def get_trainer_kwargs(
             train_batch_size=train_batch_size,
             max_step=max_step,
             mesh_shape=mesh_shape_from_axes(data=-1, fsdp=8),
+            mesh_rules=(
+                (
+                    "neuron-(trn2|trn2n).48xlarge-64",
+                    ChainConfigModifier.default_config().set(
+                        config_modifiers=[
+                            MeshShapeModifier.default_config().set(
+                                # TP within the chip, FSDP across chips.
+                                # Each TRN2 chip has 4 XLA cores.
+                                mesh_shape=mesh_shape_from_axes(fsdp=-1, model=4)
+                            ),
+                            *trn2_model_modifications,
+                            *trn2_partition_spec_modifications,
+                        ],
+                    ),
+                ),
+            ),
         )
     elif model_size == "3B":
         trainer_kwargs = dict(
@@ -222,6 +296,22 @@
             train_batch_size=train_batch_size,
             max_step=max_step,
             mesh_shape=mesh_shape_from_axes(data=-1, fsdp=8),
+            mesh_rules=(
+                (
+                    "neuron-(trn2|trn2n).48xlarge-64",
+                    ChainConfigModifier.default_config().set(
+                        config_modifiers=[
+                            MeshShapeModifier.default_config().set(
+                                # TP within the chip, FSDP across chips.
+                                # Each TRN2 chip has 4 XLA cores.
+                                mesh_shape=mesh_shape_from_axes(fsdp=-1, model=4)
+                            ),
+                            *trn2_model_modifications,
+                            *trn2_partition_spec_modifications,
+                        ],
+                    ),
+                ),
+            ),
         )
     elif model_size == "7B":
         trainer_kwargs = dict(
@@ -335,6 +425,20 @@
                     "gpu-(p5.48xlarge|p4de.24xlarge|a3-highgpu-8g)-(256|512|1024)",
                     mesh_shape_from_axes(data=-1, fsdp=8),
                 ),
+                (
+                    "neuron-(trn2|trn2n).48xlarge-64",
+                    ChainConfigModifier.default_config().set(
+                        config_modifiers=[
+                            MeshShapeModifier.default_config().set(
+                                # TP within the chip, FSDP across chips.
+                                # Each TRN2 chip has 4 XLA cores.
+                                mesh_shape=mesh_shape_from_axes(fsdp=-1, model=4)
+                            ),
+                            *trn2_model_modifications,
+                            *trn2_partition_spec_modifications,
+                        ],
+                    ),
+                ),
             ),
         )
     elif model_size == "8B":
@@ -415,6 +519,20 @@
                     "gpu-(p5.48xlarge|p4de.24xlarge|a3-highgpu-8g)-(256|512|1024)",
                     mesh_shape_from_axes(data=-1, fsdp=8),
                 ),
+                (
+                    "neuron-(trn2|trn2n).48xlarge-64",
+                    ChainConfigModifier.default_config().set(
+                        config_modifiers=[
+                            MeshShapeModifier.default_config().set(
+                                # TP within the chip, FSDP across chips.
+                                # Each TRN2 chip has 4 XLA cores.
+                                mesh_shape=mesh_shape_from_axes(fsdp=-1, model=4)
+                            ),
+                            *trn2_model_modifications,
+                            *trn2_partition_spec_modifications,
+                        ],
+                    ),
+                ),
             ),
         )
     elif model_size == "70B":
@@ -433,7 +551,7 @@
             ),
             learner_kwargs=dict(peak_lr=1.5e-4, weight_decay=0.1),
             max_sequence_length=max_sequence_length,
-            train_batch_size=train_batch_size,
+            train_batch_size=8,
             max_step=max_step,
             mesh_shape=mesh_shape_from_axes(fsdp=-1),
             mesh_rules=(
@@ -509,6 +627,8 @@ def get_trainer_kwargs(
                     ChainConfigModifier.default_config().set(
                         config_modifiers=[
                             MeshShapeModifier.default_config().set(
+                                # TP within the chip, FSDP across chips.
+                                # Each TRN2 chip has 4 XLA cores.
                                 mesh_shape=mesh_shape_from_axes(fsdp=-1, model=4)
                             ),
                             RematSpecModifier.default_config().set(
@@ -531,6 +651,8 @@ def get_trainer_kwargs(
                                     ),
                                 }
                             ),
+                            *trn2_model_modifications,
+                            *trn2_partition_spec_modifications,
                         ],
                     ),
                 ),
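
Usage note (illustrative sketch, not part of the patch): the new `ModelConfigModifier` and `PartitionSpecModifier` compose with the existing `ChainConfigModifier` in the same way the TRN2 mesh rules above use them; the module paths and partition specs below are taken from this diff.

```python
from axlearn.common import causal_lm
from axlearn.common.attention import StackedTransformerLayer
from axlearn.common.trainer import SpmdTrainer
from axlearn.common.trainer_config_modifier import (
    ChainConfigModifier,
    ModelConfigModifier,
    PartitionSpecModifier,
)

cfg = SpmdTrainer.default_config().set(model=causal_lm.Model.default_config())
cfg_modifier = (
    ChainConfigModifier.default_config()
    .set(
        config_modifiers=[
            # Swap the decoder stack implementation; matching fields are carried over
            # from the config being replaced (see ModelConfigModifier._merge_configs).
            ModelConfigModifier.default_config().set(
                target_config="model.decoder.transformer",
                modification=StackedTransformerLayer.default_config(),
            ),
            # Override sharding attributes on the token embedding.
            PartitionSpecModifier.default_config().set(
                partition_specs={
                    "model.decoder.emb.token_emb": {
                        "param_partition_spec": ("model", ("expert", "fsdp", "seq")),
                    },
                },
            ),
        ],
    )
    .instantiate()
)
cfg = cfg_modifier(cfg)  # Returns the modified trainer config.
```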