TRN2 Meshes and Configurations
apoorvtintin committed Jan 24, 2025
1 parent 185b1b5 commit c9f9ab1
Showing 3 changed files with 370 additions and 15 deletions.
176 changes: 163 additions & 13 deletions axlearn/common/trainer_config_modifier.py
@@ -2,22 +2,69 @@

"""Defines trainer config modifiers, which will be used in model definitions."""

-from typing import Dict, Sequence, Union
+from typing import Dict, NamedTuple, Sequence, Union

from axlearn.common import config
from axlearn.common.base_layer import RematSpec
from axlearn.common.config import (
    REQUIRED,
    ConfigModifier,
    ConfigOr,
    Configurable,
    Required,
    config_class,
    maybe_instantiate,
)
from axlearn.common.gradient_accumulation import with_minibatch_steps
from axlearn.common.metrics import MetricAccumulator
from axlearn.common.trainer import SpmdTrainer
-from axlearn.common.utils import HybridMeshShape, MeshShape
+from axlearn.common.utils import HybridMeshShape, MeshShape, PartitionSpec


class _FoundModule(NamedTuple):
    """Module found in a recursive search of a module name in a nested Configurable."""

    # The module found.
    module: Configurable.Config
    # The parent of the found module.
    parent_module: Configurable.Config
    # The key of the found module in its parent.
    key_in_parent: str


def _find_target_module(module_name: str, cfg: SpmdTrainer.Config) -> _FoundModule:
    """Recursively searches for the target module matching `module_name` in `cfg`.

    Args:
        module_name: Name of the target module, in `x.y.z` format
            (e.g. `model.decoder.transformer.layer`).
        cfg: The trainer config to be searched for `module_name`.

    Raises:
        ValueError: The module_name is not found.

    Returns:
        A `_FoundModule` NamedTuple of:
            module: The module found.
            parent_module: The parent of the found module.
            key_in_parent: The key of the found module in its parent.
    """
    target_modules = module_name.split(".")
    curr_module = cfg
    key_in_parent = None
    parent_module = None

    for target_module_key in target_modules:
        if not hasattr(curr_module, target_module_key):
            raise ValueError(f"{target_module_key} is not found in {curr_module}.")
        parent_module = curr_module
        key_in_parent = target_module_key
        curr_module = getattr(curr_module, target_module_key)
    return _FoundModule(
        module=curr_module, parent_module=parent_module, key_in_parent=key_in_parent
    )
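Because the traversal relies only on `hasattr`/`getattr`, its behavior can be sketched standalone. In the snippet below the nested objects are `types.SimpleNamespace` stand-ins for real axlearn configs, and the attribute names are purely illustrative:

# Standalone sketch of the dotted-path traversal (illustrative names only).
from types import SimpleNamespace

layer = SimpleNamespace(remat_spec=None)
cfg = SimpleNamespace(model=SimpleNamespace(decoder=SimpleNamespace(layer=layer)))

found = _find_target_module("model.decoder.layer", cfg)
assert found.module is layer
assert found.parent_module is cfg.model.decoder
assert found.key_in_parent == "layer"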


class GradientAccumulationModifier(ConfigModifier):
@@ -100,18 +147,11 @@ def __call__(self, cfg: SpmdTrainer.Config) -> SpmdTrainer.Config:
"""

for module_name, remat_spec in self._remat_policies.items():
# Here we assume x.y.z format.
# One example would be model.decoder.transformer.layer.
target_modules = module_name.split(".")
curr_module = cfg
for target_module in target_modules:
if not hasattr(curr_module, target_module):
raise ValueError(f"{target_module} is not found in {curr_module}.")
curr_module = getattr(curr_module, target_module)
found_module = _find_target_module(module_name, cfg)
# Here we assume all modules have remat_spec attribute.
if not hasattr(curr_module, "remat_spec"):
raise ValueError(f"{curr_module} does not have remat_spec attribute")
curr_module.remat_spec = remat_spec
if not hasattr(found_module.module, "remat_spec"):
raise ValueError(f"{found_module.module} does not have remat_spec attribute")
found_module.module.remat_spec = remat_spec
return cfg
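For reference, a usage sketch of the modifier this hunk simplifies (RematSpecModifier, per the test imports below). The module path follows the `x.y.z` example above, and `RematSpec(prevent_cse=True)` is an assumed illustrative value rather than something taken from this commit:

# Sketch: apply a remat policy to one module path (values are illustrative).
cfg_modifier = (
    RematSpecModifier.default_config()
    .set(remat_policies={"model.decoder.transformer.layer": RematSpec(prevent_cse=True)})
    .instantiate()
)
trainer_cfg = cfg_modifier(trainer_cfg)  # trainer_cfg: a hypothetical SpmdTrainer.Config.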


@@ -146,6 +186,116 @@ def __call__(self, cfg: SpmdTrainer.Config) -> SpmdTrainer.Config:
        return cfg


class ModelConfigModifier(ConfigModifier):
    """Update the model config for the trainer config."""

    @config_class
    class Config(ConfigModifier.Config):
        """Configure ModelConfigModifier.

        Attributes:
            target_config: Module path of the config to be replaced
                (e.g. `model.decoder.transformer`).
            modification: The config that replaces the target config.
        """

        target_config: Required[str] = REQUIRED
        modification: Required[Configurable.Config] = REQUIRED

    def __init__(self, cfg: Config):
        super().__init__(cfg)
        self._target_config = self.config.target_config
        self._modification = self.config.modification

    def _merge_configs(
        self, target_cfg: Configurable.Config, found_module: _FoundModule
    ) -> Configurable.Config:
        """Merge configurations from the config being replaced, on a best-effort basis.

        Merge rules:
            - `klass` is never merged; the class always comes from `target_cfg`.
            - If a field exists in both configs, use the value from the config
              being replaced.
            - Otherwise, keep the value from `target_cfg`.

        Args:
            target_cfg: The configuration that will replace `found_module`.
            found_module: The existing configuration whose class will be replaced
                but whose configuration will be merged with `target_cfg`.

        Returns:
            The modified config.
        """
        for key in target_cfg.keys():
            if key == "klass":
                continue
            elif hasattr(found_module.module, key) and hasattr(target_cfg, key):
                setattr(target_cfg, key, getattr(found_module.module, key))
        return target_cfg
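In other words: the replacement keeps its class, inherits values for shared fields from the config being replaced, and keeps its own values for fields the old config lacks. A standalone sketch with plain dicts and hypothetical field names:

# Toy model of the merge rules above; dicts stand in for Configurable.Config.
old = {"klass": "StackedTransformerLayer", "input_dim": 1024, "dropout": 0.1}
new = {"klass": "RepeatedTransformerLayer", "input_dim": None, "num_layers": 12}

merged = {
    key: old.get(key, value) if key != "klass" else value
    for key, value in new.items()
}
assert merged == {
    "klass": "RepeatedTransformerLayer",  # Class always comes from the replacement.
    "input_dim": 1024,  # Shared field inherited from the config being replaced.
    "num_layers": 12,  # Field unknown to the old config keeps the replacement's value.
}
# Note: `dropout` is dropped, since only the replacement's fields are iterated.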

    def __call__(self, cfg: SpmdTrainer.Config) -> SpmdTrainer.Config:
        """Overwrite the model config of the specified modules.

        Args:
            cfg: The trainer config to be modified.

        Raises:
            ValueError: The target module is not found.

        Returns:
            The modified trainer config.
        """
        found_module = _find_target_module(self._target_config, cfg)
        self._modification = self._merge_configs(self._modification, found_module)
        # Replace the config in the parent module.
        setattr(found_module.parent_module, found_module.key_in_parent, self._modification)
        return cfg


class PartitionSpecModifier(ConfigModifier):
    """Update the partition spec attribute for the specified modules."""

    @config_class
    class Config(ConfigModifier.Config):
        """Configure PartitionSpecModifier.

        Attributes:
            partition_specs: A nested mapping from module path
                (e.g. `model.decoder.transformer.layer`) to another mapping of
                module attribute name to PartitionSpec.
        """

        partition_specs: Required[Dict[str, Dict[str, PartitionSpec]]] = REQUIRED

    def __init__(self, cfg: Config):
        super().__init__(cfg)
        self._attribute_dicts = self.config.partition_specs

    def __call__(self, cfg: SpmdTrainer.Config) -> SpmdTrainer.Config:
        """Update the partition_spec attributes for the specified modules.

        Args:
            cfg: The trainer config to be modified.

        Raises:
            ValueError: The target module is not found.
            ValueError: The partition_spec attribute is not found.

        Returns:
            The modified trainer config.
        """
        for module_name, partition_spec_dict in self._attribute_dicts.items():
            found_module = _find_target_module(module_name, cfg)
            for partition_spec_name, partition_spec in partition_spec_dict.items():
                if not hasattr(found_module.module, partition_spec_name):
                    raise ValueError(
                        f"{found_module.module} does not have {partition_spec_name} attribute"
                    )
                setattr(found_module.module, partition_spec_name, partition_spec)

        return cfg
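A usage sketch grounded in the tests below: the outer key is a module path, the inner key is an attribute of that module, and the value is the partition spec to install (`trainer_cfg` is a hypothetical trainer config):

# Sketch: shard model.linear's parameters; first dim on the "model" axis,
# second dim jointly over the ("expert", "fsdp", "seq") axes.
cfg_modifier = (
    PartitionSpecModifier.default_config()
    .set(
        partition_specs={
            "model.linear": {"param_partition_spec": ("model", ("expert", "fsdp", "seq"))},
        },
    )
    .instantiate()
)
trainer_cfg = cfg_modifier(trainer_cfg)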


class ChainConfigModifier(ConfigModifier):
"""Chain multiple config modifiers together."""

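The ChainConfigModifier body is collapsed in this view, so the following composition sketch is an assumption: the `config_modifiers`, `mesh_shape`, and `grad_acc_steps` field names come from the broader axlearn API and do not appear in this diff.

# Assumed sketch: chain several modifiers into one, applied in order.
chained = (
    ChainConfigModifier.default_config()
    .set(
        config_modifiers=[
            MeshShapeModifier.default_config().set(mesh_shape=(1, 1, 4, 1)),
            GradientAccumulationModifier.default_config().set(grad_acc_steps=4),
        ]
    )
    .instantiate()
)
trainer_cfg = chained(trainer_cfg)  # Each modifier transforms the trainer config in turn.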
85 changes: 84 additions & 1 deletion axlearn/common/trainer_config_modifier_test.py
@@ -5,13 +5,16 @@
import jax
from absl.testing import absltest

-from axlearn.common import test_utils
+from axlearn.common import causal_lm, test_utils
+from axlearn.common.attention import RepeatedTransformerLayer, StackedTransformerLayer
from axlearn.common.base_layer import RematSpec
from axlearn.common.trainer import SpmdTrainer
from axlearn.common.trainer_config_modifier import (
    ChainConfigModifier,
    GradientAccumulationModifier,
    MeshShapeModifier,
+   ModelConfigModifier,
+   PartitionSpecModifier,
    RematSpecModifier,
)
from axlearn.common.trainer_test import DummyModel
@@ -65,6 +68,86 @@ def test_remat_policy_override(self):
            _ = cfg_modifier(cfg)


class ModelConfigModifierTest(test_utils.TestCase):
    def test_model_config_override(self):
        cfg = SpmdTrainer.default_config().set(model=causal_lm.Model.default_config())
        self.assertEqual(
            str(cfg.model.decoder.transformer), str(StackedTransformerLayer.default_config())
        )

        cfg_modifier = (
            ModelConfigModifier.default_config()
            .set(
                target_config="model.decoder.transformer",
                modification=RepeatedTransformerLayer.default_config(),
            )
            .instantiate()
        )

        cfg = cfg_modifier(cfg)
        # The default StackedTransformerLayer should have changed to RepeatedTransformerLayer.
        self.assertEqual(
            str(cfg.model.decoder.transformer), str(RepeatedTransformerLayer.default_config())
        )
        cfg_modifier = (
            ModelConfigModifier.default_config()
            .set(
                target_config="model.decoder.unknown",
                modification=RepeatedTransformerLayer.default_config(),
            )
            .instantiate()
        )
        # Ensure that the exception is raised for an unknown module path.
        with self.assertRaisesRegex(ValueError, "unknown is not found in.*"):
            _ = cfg_modifier(cfg)


class PartitionSpecModifierTest(test_utils.TestCase):
    def test_partition_spec_override(self):
        cfg = SpmdTrainer.default_config().set(model=DummyModel.default_config())
        cfg_modifier = (
            PartitionSpecModifier.default_config()
            .set(
                partition_specs={
                    "model.linear": {"param_partition_spec": ("model", ("expert", "fsdp", "seq"))},
                },
            )
            .instantiate()
        )
        cfg = cfg_modifier(cfg)
        self.assertEqual(
            str(cfg.model.linear.param_partition_spec),
            """('model', ('expert', 'fsdp', 'seq'))""",
        )
        cfg_modifier = (
            PartitionSpecModifier.default_config()
            .set(
                partition_specs={
                    "model.linear": {"param_partition_spec": ("model", ("expert", "fsdp", "seq"))},
                    "model.unknown": {"param_partition_spec": ("model", ("expert", "fsdp", "seq"))},
                },
            )
            .instantiate()
        )
        # Ensure that the exception is raised for an unknown module path.
        with self.assertRaisesRegex(ValueError, "unknown is not found in.*"):
            _ = cfg_modifier(cfg)

        cfg_modifier = (
            PartitionSpecModifier.default_config()
            .set(
                partition_specs={
                    "model.linear": {
                        "param_partition_spec": ("model", ("expert", "fsdp", "seq")),
                        "unknown_partition_spec": ("model", ("expert", "fsdp", "seq")),
                    },
                },
            )
            .instantiate()
        )
        with self.assertRaisesRegex(ValueError, ".*does not have unknown_partition_spec attribute"):
            _ = cfg_modifier(cfg)


class MeshShapeModifierTest(test_utils.TestCase):
    def test_mesh_shape_update(self):
        cfg = SpmdTrainer.default_config().set(model=DummyModel.default_config())
