TRN2 Meshes and Configurations
apoorvtintin committed Jan 14, 2025
1 parent 2d1fb29 commit c23e3b2
Showing 29 changed files with 2,073 additions and 14 deletions.
173 changes: 160 additions & 13 deletions axlearn/common/trainer_config_modifier.py
@@ -2,12 +2,13 @@

"""Defines trainer config modifiers, which will be used in model definitions."""

-from typing import Dict, Sequence, Union
+from typing import Callable, Dict, NamedTuple, Sequence, Union

from axlearn.common import config
from axlearn.common.base_layer import RematSpec
from axlearn.common.config import (
    REQUIRED,
    ConfigBase,
    ConfigModifier,
    ConfigOr,
    Required,
@@ -17,7 +18,53 @@
from axlearn.common.gradient_accumulation import with_minibatch_steps
from axlearn.common.metrics import MetricAccumulator
from axlearn.common.trainer import SpmdTrainer
-from axlearn.common.utils import HybridMeshShape, MeshShape
+from axlearn.common.utils import HybridMeshShape, MeshShape, PartitionSpec


class _FoundModule(NamedTuple):
    """Module found in a recursive search for a matching module."""

    # The module that was found.
    module: ConfigModifier.Config
    # The parent of the found module.
    parent_module: ConfigModifier.Config
    # The key of the found module in its parent.
    key_in_parent: str


def _find_target_module(module_name: str, cfg: SpmdTrainer.Config) -> _FoundModule:
    """Recursively searches for the target module matching module_name in the provided cfg.

    Args:
        module_name: Name of the target module.
        cfg: The trainer config to be searched for module_name.

    Raises:
        ValueError: The module_name is not found.

    Returns:
        A _FoundModule with:
            module: The module that was found.
            parent_module: The parent of the found module.
            key_in_parent: The key of the found module in its parent.
    """

    # Here we assume an x.y.z format.
    # One example would be model.decoder.transformer.layer.
    target_modules = module_name.split(".")
    curr_module = cfg
    key_in_parent = None
    parent_module = None

    for target_module in target_modules:
        if not hasattr(curr_module, target_module):
            raise ValueError(f"{target_module} is not found in {curr_module}.")
        parent_module = curr_module
        key_in_parent = target_module
        curr_module = getattr(curr_module, target_module)
    return _FoundModule(
        module=curr_module, parent_module=parent_module, key_in_parent=key_in_parent
    )

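A minimal usage sketch of this helper; the model config and dotted path mirror the tests added later in this commit:

cfg = SpmdTrainer.default_config().set(model=causal_lm.Model.default_config())
found = _find_target_module("model.decoder.transformer", cfg)
# found.module is cfg.model.decoder.transformer, found.parent_module is
# cfg.model.decoder, and found.key_in_parent == "transformer".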

class GradientAccumulationModifier(ConfigModifier):
@@ -100,18 +147,11 @@ def __call__(self, cfg: SpmdTrainer.Config) -> SpmdTrainer.Config:
"""

        for module_name, remat_spec in self._remat_policies.items():
-            # Here we assume x.y.z format.
-            # One example would be model.decoder.transformer.layer.
-            target_modules = module_name.split(".")
-            curr_module = cfg
-            for target_module in target_modules:
-                if not hasattr(curr_module, target_module):
-                    raise ValueError(f"{target_module} is not found in {curr_module}.")
-                curr_module = getattr(curr_module, target_module)
+            found_module = _find_target_module(module_name, cfg)
            # Here we assume all modules have remat_spec attribute.
-            if not hasattr(curr_module, "remat_spec"):
-                raise ValueError(f"{curr_module} does not have remat_spec attribute")
-            curr_module.remat_spec = remat_spec
+            if not hasattr(found_module.module, "remat_spec"):
+                raise ValueError(f"{found_module.module} does not have remat_spec attribute")
+            found_module.module.remat_spec = remat_spec
        return cfg

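For reference, a RematSpecModifier might be configured as in the following sketch; the module path, the prevent_cse/policy fields, and the JAX checkpoint policy are illustrative assumptions, not part of this diff:

from jax.ad_checkpoint import checkpoint_policies as jax_remat_policies

remat_modifier = (
    RematSpecModifier.default_config()
    .set(
        remat_policies={
            # Hypothetical module path; any module exposing remat_spec works.
            "model.decoder.transformer.layer": RematSpec(
                prevent_cse=True,
                policy=jax_remat_policies.nothing_saveable,
            ),
        }
    )
    .instantiate()
)
cfg = remat_modifier(cfg)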

@@ -146,6 +186,113 @@ def __call__(self, cfg: SpmdTrainer.Config) -> SpmdTrainer.Config:
return cfg


class ModelConfigModifier(ConfigModifier):
    """Update the model config for the trainer config."""

    @config_class
    class Config(ConfigModifier.Config):
        """Configure ModelConfigModifier.

        Attributes:
            model_cfg_modifications: A mapping from module path
                (e.g. `model.decoder.transformer.layer`) to a Config.
        """

        model_cfg_modifications: Dict[str, Callable[[ConfigBase], ConfigBase]] = {}

    def __init__(self, cfg: Config):
        super().__init__(cfg)
        self._model_cfg_modifications = self.config.model_cfg_modifications

    def _merge_configs(self, target_cfg: ConfigBase, found_module: _FoundModule) -> ConfigBase:
        """Merge configurations from the config being replaced on a best-effort basis.

        Args:
            target_cfg: Configuration that will replace found_module.
            found_module: Existing configuration whose class will be replaced,
                but whose configuration will be merged with target_cfg.

        Returns:
            The modified config.
        """
        for key in target_cfg.keys():
            # Do not copy the class itself; only data fields are merged.
            if key == "klass":
                continue
            # For keys defined on both configs, keep the value from the config being replaced.
            elif hasattr(found_module.module, key) and hasattr(target_cfg, key):
                setattr(target_cfg, key, getattr(found_module.module, key))
        return target_cfg
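
To illustrate the merge semantics, a sketch under the assumption that both transformer configs define input_dim (the private-method call is for illustration only):

modifier = ModelConfigModifier.default_config().instantiate()
old_cfg = StackedTransformerLayer.default_config().set(input_dim=1024)
new_cfg = RepeatedTransformerLayer.default_config()
found = _FoundModule(module=old_cfg, parent_module=None, key_in_parent="transformer")
merged = modifier._merge_configs(new_cfg, found)  # pylint: disable=protected-access
# Keys present on both configs keep the value from the config being replaced;
# "klass" is skipped, so merged still instantiates a RepeatedTransformerLayer.
assert merged.input_dim == 1024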

    def __call__(self, cfg: SpmdTrainer.Config) -> SpmdTrainer.Config:
        """Overwrite the model config of the specified modules.

        Args:
            cfg: The trainer config to be modified.

        Raises:
            ValueError: The target module is not found.

        Returns:
            The modified trainer config.
        """

        # Iterate over modules in the mapping, sorted by module path length.
        # This ensures a parent is modified before its children, so no modification is missed.
        for module_name, model_cfg in sorted(
            self._model_cfg_modifications.items(), key=lambda item: len(item[0])
        ):
            found_module = _find_target_module(module_name, cfg)

            model_cfg = self._merge_configs(model_cfg, found_module)
            # Replace the config in the parent.
            setattr(found_module.parent_module, found_module.key_in_parent, model_cfg)
        return cfg

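Usage mirrors the test added in this commit: swap the decoder's stacked transformer for a repeated one while carrying over shared settings:

cfg_modifier = (
    ModelConfigModifier.default_config()
    .set(
        model_cfg_modifications={
            "model.decoder.transformer": RepeatedTransformerLayer.default_config(),
        }
    )
    .instantiate()
)
cfg = cfg_modifier(cfg)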

class PartitionSpecModifier(ConfigModifier):
    """Update the partition spec attribute for the specified modules."""

    @config_class
    class Config(ConfigModifier.Config):
        """Configure PartitionSpecModifier.

        Attributes:
            partition_specs: A nested mapping from module path
                (e.g. `model.decoder.transformer.layer`) to another
                mapping of module attribute to PartitionSpec.
        """

        partition_specs: Required[Dict[str, Dict[str, PartitionSpec]]] = REQUIRED

    def __init__(self, cfg: Config):
        super().__init__(cfg)
        self._attribute_dicts = self.config.partition_specs

    def __call__(self, cfg: SpmdTrainer.Config) -> SpmdTrainer.Config:
        """Update the partition_spec attributes for the specified modules.

        Args:
            cfg: The trainer config to be modified.

        Raises:
            ValueError: The target module is not found.
            ValueError: The partition_spec attribute is not found.

        Returns:
            The modified trainer config.
        """
        for module_name, partition_spec_dict in self._attribute_dicts.items():
            found_module = _find_target_module(module_name, cfg)
            for partition_spec_name, partition_spec in partition_spec_dict.items():
                if not hasattr(found_module.module, partition_spec_name):
                    raise ValueError(
                        f"{found_module.module} does not have {partition_spec_name} attribute"
                    )
                setattr(found_module.module, partition_spec_name, partition_spec)

        return cfg

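A usage sketch, mirroring the test added in this commit:

cfg_modifier = (
    PartitionSpecModifier.default_config()
    .set(
        partition_specs={
            "model.linear": {"param_partition_spec": ("model", ("expert", "fsdp", "seq"))},
        },
    )
    .instantiate()
)
cfg = cfg_modifier(cfg)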

class ChainConfigModifier(ConfigModifier):
    """Chain multiple config modifiers together."""

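The class body is collapsed in this diff. As a sketch, chaining modifiers might look like the following; the config_modifiers and grad_acc_steps field names and the mesh shape are assumptions for illustration:

chained = (
    ChainConfigModifier.default_config()
    .set(
        config_modifiers=[
            MeshShapeModifier.default_config().set(mesh_shape=(1, 1, 8, 1)),
            GradientAccumulationModifier.default_config().set(grad_acc_steps=4),
        ]
    )
    .instantiate()
)
cfg = chained(cfg)  # Each modifier is applied in order.
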
87 changes: 86 additions & 1 deletion axlearn/common/trainer_config_modifier_test.py
@@ -5,13 +5,16 @@
import jax
from absl.testing import absltest

-from axlearn.common import test_utils
+from axlearn.common import causal_lm, test_utils
from axlearn.common.attention import RepeatedTransformerLayer, StackedTransformerLayer
from axlearn.common.base_layer import RematSpec
from axlearn.common.trainer import SpmdTrainer
from axlearn.common.trainer_config_modifier import (
    ChainConfigModifier,
    GradientAccumulationModifier,
    MeshShapeModifier,
    ModelConfigModifier,
    PartitionSpecModifier,
    RematSpecModifier,
)
from axlearn.common.trainer_test import DummyModel
@@ -65,6 +68,88 @@ def test_remat_policy_override(self):
_ = cfg_modifier(cfg)


class ModelConfigModifierTest(test_utils.TestCase):
    def test_model_config_override(self):
        cfg = SpmdTrainer.default_config().set(model=causal_lm.Model.default_config())
        self.assertEqual(
            str(cfg.model.decoder.transformer), str(StackedTransformerLayer.default_config())
        )

        cfg_modifier = (
            ModelConfigModifier.default_config()
            .set(
                model_cfg_modifications={
                    "model.decoder.transformer": RepeatedTransformerLayer.default_config(),
                }
            )
            .instantiate()
        )

        cfg = cfg_modifier(cfg)
        # The default StackedTransformerLayer should have changed to RepeatedTransformerLayer.
        self.assertEqual(
            str(cfg.model.decoder.transformer), str(RepeatedTransformerLayer.default_config())
        )
        cfg_modifier = (
            ModelConfigModifier.default_config()
            .set(
                model_cfg_modifications={
                    "model.decoder.unknown": RepeatedTransformerLayer.default_config(),
                }
            )
            .instantiate()
        )
        # Ensure that the expected exception is raised.
        with self.assertRaisesRegex(ValueError, "unknown is not found in.*"):
            _ = cfg_modifier(cfg)


class PartitionSpecModifierTest(test_utils.TestCase):
    def test_partition_spec_override(self):
        cfg = SpmdTrainer.default_config().set(model=DummyModel.default_config())
        cfg_modifier = (
            PartitionSpecModifier.default_config()
            .set(
                partition_specs={
                    "model.linear": {"param_partition_spec": ("model", ("expert", "fsdp", "seq"))},
                },
            )
            .instantiate()
        )
        cfg = cfg_modifier(cfg)
        self.assertEqual(
            cfg.model.linear.param_partition_spec, ("model", ("expert", "fsdp", "seq"))
        )
        cfg_modifier = (
            PartitionSpecModifier.default_config()
            .set(
                partition_specs={
                    "model.linear": {"param_partition_spec": ("model", ("expert", "fsdp", "seq"))},
                    "model.unknown": {"param_partition_spec": ("model", ("expert", "fsdp", "seq"))},
                },
            )
            .instantiate()
        )
        # Ensure that the expected exception is raised.
        with self.assertRaisesRegex(ValueError, "unknown is not found in.*"):
            _ = cfg_modifier(cfg)

        cfg_modifier = (
            PartitionSpecModifier.default_config()
            .set(
                partition_specs={
                    "model.linear": {
                        "param_partition_spec": ("model", ("expert", "fsdp", "seq")),
                        "unknown_partition_spec": ("model", ("expert", "fsdp", "seq")),
                    },
                },
            )
            .instantiate()
        )
        with self.assertRaisesRegex(ValueError, ".*does not have unknown_partition_spec attribute"):
            _ = cfg_modifier(cfg)


class MeshShapeModifierTest(test_utils.TestCase):
    def test_mesh_shape_update(self):
        cfg = SpmdTrainer.default_config().set(model=DummyModel.default_config())
(The remaining 27 changed files are not shown.)
