TRN2 Meshes and Configurations
apoorvtintin committed Jan 10, 2025
1 parent 2d1fb29 commit 5be50d7
Showing 29 changed files with 1,859 additions and 20 deletions.
132 changes: 120 additions & 12 deletions axlearn/common/trainer_config_modifier.py
@@ -8,6 +8,7 @@
from axlearn.common.base_layer import RematSpec
from axlearn.common.config import (
REQUIRED,
ConfigBase,
ConfigModifier,
ConfigOr,
Required,
@@ -17,7 +18,27 @@
from axlearn.common.gradient_accumulation import with_minibatch_steps
from axlearn.common.metrics import MetricAccumulator
from axlearn.common.trainer import SpmdTrainer
-from axlearn.common.utils import HybridMeshShape, MeshShape
+from axlearn.common.utils import HybridMeshShape, MeshShape, PartitionSpec


def find_target_module(
    module_name: str, cfg: SpmdTrainer.Config
) -> tuple[ConfigBase, str, ConfigBase]:
    """Searches for the target module matching `module_name` in the provided config.

    Returns a `(target_config, key_in_parent, parent_config)` tuple.

    Raises:
        ValueError: If any component of the path is not found in the config.
    """
# Here we assume x.y.z format.
# One example would be model.decoder.transformer.layer.
target_modules = module_name.split(".")
curr_module = cfg
key_in_parent = None
parent_module = None

for target_module in target_modules:
if not hasattr(curr_module, target_module):
raise ValueError(f"{target_module} is not found in {curr_module}.")
parent_module = curr_module
key_in_parent = target_module
curr_module = getattr(curr_module, target_module)
return curr_module, key_in_parent, parent_module
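For illustration, a minimal usage sketch of the helper (not part of the diff; the model config mirrors the tests added below): resolve a dotted path to the target config, its key, and its parent, so a caller can either mutate the target in place or replace it in the parent.

from axlearn.common import causal_lm
from axlearn.common.trainer import SpmdTrainer
from axlearn.common.trainer_config_modifier import find_target_module

trainer_cfg = SpmdTrainer.default_config().set(model=causal_lm.Model.default_config())
found, key, parent = find_target_module("model.decoder.transformer", trainer_cfg)
# `found` is the transformer stack config, `key` is "transformer", and `parent` is the
# decoder config, so the target can be replaced via setattr(parent, key, new_cfg).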


class GradientAccumulationModifier(ConfigModifier):
@@ -100,18 +121,11 @@ def __call__(self, cfg: SpmdTrainer.Config) -> SpmdTrainer.Config:
"""

        for module_name, remat_spec in self._remat_policies.items():
-            # Here we assume x.y.z format.
-            # One example would be model.decoder.transformer.layer.
-            target_modules = module_name.split(".")
-            curr_module = cfg
-            for target_module in target_modules:
-                if not hasattr(curr_module, target_module):
-                    raise ValueError(f"{target_module} is not found in {curr_module}.")
-                curr_module = getattr(curr_module, target_module)
+            found_module, _, _ = find_target_module(module_name, cfg)
            # Here we assume all modules have remat_spec attribute.
-            if not hasattr(curr_module, "remat_spec"):
-                raise ValueError(f"{curr_module} does not have remat_spec attribute")
-            curr_module.remat_spec = remat_spec
+            if not hasattr(found_module, "remat_spec"):
+                raise ValueError(f"{found_module} does not have remat_spec attribute")
+            found_module.remat_spec = remat_spec
return cfg
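For context, a hedged sketch of how a remat override is wired through this modifier (the RematSpec field names and the jax checkpoint policy are assumptions, not taken from this diff):

import jax

from axlearn.common.base_layer import RematSpec
from axlearn.common.trainer_config_modifier import RematSpecModifier

remat_modifier = (
    RematSpecModifier.default_config()
    .set(
        remat_policies={
            # Assumed RematSpec fields, shown for illustration only.
            "model.decoder.transformer.layer": RematSpec(
                prevent_cse=True,
                policy=jax.checkpoint_policies.nothing_saveable,
            ),
        }
    )
    .instantiate()
)
# `trainer_cfg` as constructed in the earlier sketch; the layer config now carries the spec.
trainer_cfg = remat_modifier(trainer_cfg)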


@@ -146,6 +160,100 @@ def __call__(self, cfg: SpmdTrainer.Config) -> SpmdTrainer.Config:
return cfg


class ModelConfigModifier(ConfigModifier):
"""Update the model config for the trainer config."""

@config_class
class Config(ConfigModifier.Config):
"""Configure ModelConfigModifier.
Attributes:
model_cfg_modifications: A mapping from module path
(e.g. `model.decoder.transformer.layer`) to a Config.
"""

model_cfg_modifications: Required[Dict[str, ConfigBase]] = REQUIRED

def __init__(self, cfg: Config):
super().__init__(cfg)
cfg = self.config
self._model_cfg_modifications = cfg.model_cfg_modifications

def __call__(self, cfg: SpmdTrainer.Config) -> SpmdTrainer.Config:
"""Overwrite the model config of the specified modules.
Args:
cfg: The trainer config to be modified.
Raises:
ValueError: The target module is not found.
Returns:
The modified trainer config.
"""

for module_name, model_cfg in self._model_cfg_modifications.items():
# No modification if None
if not model_cfg:
continue

found_module, key_in_parent, parent_module = find_target_module(module_name, cfg)

            # Copy values from the config being replaced on a best-effort basis.
for key in model_cfg.keys():
if key == "klass":
continue
elif hasattr(found_module, key) and hasattr(model_cfg, key):
setattr(model_cfg, key, getattr(found_module, key))
# Replace in the parent config
setattr(parent_module, key_in_parent, model_cfg)
return cfg
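A short usage sketch mirroring the test added below (names taken from that test, not new API): swap the decoder's StackedTransformerLayer config for a RepeatedTransformerLayer config, with overlapping fields copied best-effort from the replaced config.

from axlearn.common.attention import RepeatedTransformerLayer
from axlearn.common.trainer_config_modifier import ModelConfigModifier

model_modifier = (
    ModelConfigModifier.default_config()
    .set(
        model_cfg_modifications={
            "model.decoder.transformer": RepeatedTransformerLayer.default_config(),
        }
    )
    .instantiate()
)
# `trainer_cfg` as in the earlier sketch; model.decoder.transformer is now a
# RepeatedTransformerLayer config whose shared fields (e.g. the inner `layer`) were copied.
trainer_cfg = model_modifier(trainer_cfg)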


class ParameterPartitionSpecModifier(ConfigModifier):
"""Update the parameter partition spec for specified modules."""

@config_class
class Config(ConfigModifier.Config):
"""Configure ParameterPartitionSpecModifier.
Attributes:
            partition_specs: A mapping from module path
                (e.g. `model.decoder.transformer.layer`) to a parameter PartitionSpec.
"""

partition_specs: Required[Dict[str, PartitionSpec]] = REQUIRED

def __init__(self, cfg: Config):
super().__init__(cfg)
cfg = self.config
self._partition_specs = cfg.partition_specs

def __call__(self, cfg: SpmdTrainer.Config) -> SpmdTrainer.Config:
"""Update the param_partition_spec for the specified modules.
Args:
cfg: The trainer config to be modified.
Raises:
ValueError: The target module is not found.
ValueError: The partition_spec attribute is not found.
Returns:
The modified trainer config.
"""

for module_name, param_partition_spec in self._partition_specs.items():
found_module, _, _ = find_target_module(module_name, cfg)

# Here we assume all modules have param_partition_spec attribute.
if not hasattr(found_module, "param_partition_spec"):
raise ValueError(f"{found_module} does not have param_partition_spec attribute")

found_module.param_partition_spec = param_partition_spec
return cfg
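Similarly, a sketch mirroring the ParameterPartitionSpecModifier test below: override the param_partition_spec of a single module addressed by its dotted path.

from axlearn.common.trainer import SpmdTrainer
from axlearn.common.trainer_config_modifier import ParameterPartitionSpecModifier
from axlearn.common.trainer_test import DummyModel  # test-only model with a `linear` child

trainer_cfg = SpmdTrainer.default_config().set(model=DummyModel.default_config())
partition_modifier = (
    ParameterPartitionSpecModifier.default_config()
    .set(
        partition_specs={
            "model.linear": ("model", ("expert", "fsdp", "seq")),
        },
    )
    .instantiate()
)
trainer_cfg = partition_modifier(trainer_cfg)
# trainer_cfg.model.linear.param_partition_spec == ("model", ("expert", "fsdp", "seq"))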


class ChainConfigModifier(ConfigModifier):
"""Chain multiple config modifiers together."""

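As a usage sketch (field names inferred from the imports and the golden config dump below, so treat them as assumptions): ChainConfigModifier composes several modifiers and applies them to the trainer config, presumably in order.

from axlearn.common.trainer_config_modifier import (
    ChainConfigModifier,
    GradientAccumulationModifier,
    MeshShapeModifier,
)

chain = (
    ChainConfigModifier.default_config()
    .set(
        config_modifiers=[
            MeshShapeModifier.default_config().set(mesh_shape=(1, 1, 1, -1, 1, 4)),
            GradientAccumulationModifier.default_config().set(grad_acc_steps=4),
        ]
    )
    .instantiate()
)
trainer_cfg = chain(trainer_cfg)  # each modifier rewrites the config in turn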
72 changes: 71 additions & 1 deletion axlearn/common/trainer_config_modifier_test.py
@@ -5,13 +5,16 @@
import jax
from absl.testing import absltest

-from axlearn.common import test_utils
+from axlearn.common import causal_lm, test_utils
from axlearn.common.attention import RepeatedTransformerLayer, StackedTransformerLayer
from axlearn.common.base_layer import RematSpec
from axlearn.common.trainer import SpmdTrainer
from axlearn.common.trainer_config_modifier import (
ChainConfigModifier,
GradientAccumulationModifier,
MeshShapeModifier,
ModelConfigModifier,
ParameterPartitionSpecModifier,
RematSpecModifier,
)
from axlearn.common.trainer_test import DummyModel
@@ -65,6 +68,73 @@ def test_remat_policy_override(self):
_ = cfg_modifier(cfg)


class ModelConfigModifierTest(test_utils.TestCase):
def test_model_config_override(self):
cfg = SpmdTrainer.default_config().set(model=causal_lm.Model.default_config())
self.assertTrue(
str(cfg.model.decoder.transformer) == str(StackedTransformerLayer.default_config())
)

cfg_modifier = (
ModelConfigModifier.default_config()
.set(
model_cfg_modifications={
"model.decoder.transformer": RepeatedTransformerLayer.default_config(),
}
)
.instantiate()
)

cfg = cfg_modifier(cfg)
# The default StackedTransformerLayer should have changed to RepeatedTransformerLayer
self.assertTrue(
str(cfg.model.decoder.transformer) == str(RepeatedTransformerLayer.default_config())
)
cfg_modifier = (
ModelConfigModifier.default_config()
.set(
model_cfg_modifications={
"model.decoder.unknown": RepeatedTransformerLayer.default_config(),
}
)
.instantiate()
)
        # Ensure that the exception is raised for an unknown module path.
with self.assertRaisesRegex(ValueError, "unknown is not found in.*"):
_ = cfg_modifier(cfg)


class ParameterPartitionSpecModifierTest(test_utils.TestCase):
def test_parameter_partition_spec_override(self):
cfg = SpmdTrainer.default_config().set(model=DummyModel.default_config())
cfg_modifier = (
ParameterPartitionSpecModifier.default_config()
.set(
partition_specs={
"model.linear": ("model", ("expert", "fsdp", "seq")),
},
)
.instantiate()
)
cfg = cfg_modifier(cfg)
        self.assertEqual(
            cfg.model.linear.param_partition_spec, ("model", ("expert", "fsdp", "seq"))
        )
cfg_modifier = (
ParameterPartitionSpecModifier.default_config()
.set(
partition_specs={
"model.linear": ("model", ("expert", "fsdp", "seq")),
"model.unknown": ("model", ("expert", "fsdp", "seq")),
},
)
.instantiate()
)
        # Ensure that the exception is raised for an unknown module path.
with self.assertRaisesRegex(ValueError, "unknown is not found in.*"):
_ = cfg_modifier(cfg)


class MeshShapeModifierTest(test_utils.TestCase):
def test_mesh_shape_update(self):
cfg = SpmdTrainer.default_config().set(model=DummyModel.default_config())
@@ -114,6 +114,65 @@ mesh_axis_names[2]: 'expert'
mesh_axis_names[3]: 'fsdp'
mesh_axis_names[4]: 'seq'
mesh_axis_names[5]: 'model'
mesh_rules[0][0]: 'neuron-(trn2|trn2n).48xlarge-64'
mesh_rules[0][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier'
mesh_rules[0][1].config_modifiers[0].mesh_shape[0]: 1
mesh_rules[0][1].config_modifiers[0].mesh_shape[1]: 1
mesh_rules[0][1].config_modifiers[0].mesh_shape[2]: 1
mesh_rules[0][1].config_modifiers[0].mesh_shape[3]: -1
mesh_rules[0][1].config_modifiers[0].mesh_shape[4]: 1
mesh_rules[0][1].config_modifiers[0].mesh_shape[5]: 4
mesh_rules[0][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier'
mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].klass: 'axlearn.common.attention.StackedTransformerLayer'
mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.activation: 'nn.relu'
mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout'
mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer'
mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.bias: True
mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear'
mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.param_partition_spec[0]: None
mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear1.param_partition_spec[1]: 'model'
mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.bias: True
mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear'
mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.param_partition_spec[0]: 'model'
mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.linear2.param_partition_spec[1]: None
mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.norm.eps: 1e-08
mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.norm.forward_dtype: 'jax.numpy.float32'
mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.norm.klass: 'axlearn.common.layers.LayerNorm'
mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.residual_weight: 1.0
mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth'
mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.stochastic_depth.mode: 'row'
mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.feed_forward.structure: 'prenorm'
mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.klass: 'axlearn.common.attention.TransformerLayer'
mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout'
mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.QKVLinear'
mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.bias: True
mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear'
mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.param_partition_spec[0]: None
mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.param_partition_spec[1]: 'model'
mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.input_linear.layer.param_partition_spec[2]: None
mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey'
mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.klass: 'axlearn.common.attention.MultiheadAttention'
mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.bias: True
mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear'
mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.param_partition_spec[0]: None
mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model'
mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.output_linear.param_partition_spec[2]: None
mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery'
mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout'
mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer'
mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.norm.eps: 1e-08
mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.norm.forward_dtype: 'jax.numpy.float32'
mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.norm.klass: 'axlearn.common.layers.LayerNorm'
mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth'
mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.stochastic_depth.mode: 'row'
mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer'].layer.self_attention.structure: 'prenorm'
mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].klass: 'axlearn.common.attention.GroupedQKVLinear'
mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.bias: True
mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.klass: 'axlearn.common.attention.MultiheadInputLinear'
mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.param_partition_spec[0]: None
mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.param_partition_spec[1]: 'model'
mesh_rules[0][1].config_modifiers[1].model_cfg_modifications['model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear'].layer.param_partition_spec[2]: None
mesh_rules[0][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier'
mesh_shape[0]: 1
mesh_shape[1]: -1
mesh_shape[2]: 1
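For context on the dump above: 'neuron-(trn2|trn2n).48xlarge-64' is a regex-keyed mesh rule, and the paired ChainConfigModifier (the MeshShapeModifier and ModelConfigModifier entries listed) is applied when the requested mesh selector matches. A hedged sketch of that selection, assuming simple first-match regex semantics rather than the exact AXLearn helper; `chain_modifier_cfg` is a stand-in for the rule's ChainConfigModifier config.

import re

mesh_rules = [("neuron-(trn2|trn2n).48xlarge-64", chain_modifier_cfg)]

selector = "neuron-trn2.48xlarge-64"
for pattern, modifier_cfg in mesh_rules:
    if re.fullmatch(pattern, selector):
        trainer_cfg = modifier_cfg.instantiate()(trainer_cfg)  # apply the chained modifiers
        break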