TRN2 Meshes and Configurations
apoorvtintin committed Jan 10, 2025
1 parent 2d1fb29 commit 6b404f6
Showing 3 changed files with 178 additions and 1 deletion.
62 changes: 62 additions & 0 deletions axlearn/common/trainer_config_modifier.py
@@ -8,6 +8,7 @@
from axlearn.common.base_layer import RematSpec
from axlearn.common.config import (
    REQUIRED,
    ConfigBase,
    ConfigModifier,
    ConfigOr,
    Required,
@@ -146,6 +147,67 @@ def __call__(self, cfg: SpmdTrainer.Config) -> SpmdTrainer.Config:
        return cfg


class ModelConfigModifier(ConfigModifier):
    """Update the model config for the trainer config."""

    @config_class
    class Config(ConfigModifier.Config):
        """Configure ModelConfigModifier.

        Attributes:
            model_cfg_modifications: A mapping from module path
                (e.g. `model.decoder.transformer.layer`) to a Config.
        """

        model_cfg_modifications: Required[Dict[str, ConfigBase]] = REQUIRED

    def __init__(self, cfg: Config):
        super().__init__(cfg)
        cfg = self.config
        self._model_cfg_modifications = cfg.model_cfg_modifications

    def __call__(self, cfg: SpmdTrainer.Config) -> SpmdTrainer.Config:
        """Overwrite the model config of the specified modules.

        Args:
            cfg: The trainer config to be modified.

        Raises:
            ValueError: The target module is not found.

        Returns:
            The modified trainer config.
        """

        for module_name, model_cfg in self._model_cfg_modifications.items():
            # No modification if the mapped config is None.
            if not model_cfg:
                continue
            # Here we assume an x.y.z module path format,
            # e.g. model.decoder.transformer.layer.
            target_modules = module_name.split(".")
            curr_module = cfg
            target_module_in_parent = None
            parent_module = None

            for target_module in target_modules:
                if not hasattr(curr_module, target_module):
                    raise ValueError(f"{target_module} is not found in {curr_module}.")
                parent_module = curr_module
                target_module_in_parent = target_module
                curr_module = getattr(curr_module, target_module)

            # Copy configurations from the config being replaced on a best-effort basis.
            for key in model_cfg.keys():
                if key == "klass":
                    continue
                elif hasattr(curr_module, key):
                    setattr(model_cfg, key, getattr(curr_module, key))
            # Replace the target module's config in the parent config.
            setattr(parent_module, target_module_in_parent, model_cfg)
        return cfg


class ChainConfigModifier(ConfigModifier):
"""Chain multiple config modifiers together."""

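For context, a minimal usage sketch of the new modifier (not part of the diff; it mirrors the unit test added below):

from axlearn.common import causal_lm
from axlearn.common.attention import RepeatedTransformerLayer
from axlearn.common.trainer import SpmdTrainer
from axlearn.common.trainer_config_modifier import ModelConfigModifier

# A trainer config whose model defaults to a StackedTransformerLayer decoder stack.
trainer_cfg = SpmdTrainer.default_config().set(model=causal_lm.Model.default_config())

# Replace the transformer stack at the given module path with RepeatedTransformerLayer.
cfg_modifier = (
    ModelConfigModifier.default_config()
    .set(
        model_cfg_modifications={
            "model.decoder.transformer": RepeatedTransformerLayer.default_config(),
        }
    )
    .instantiate()
)
trainer_cfg = cfg_modifier(trainer_cfg)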
25 changes: 24 additions & 1 deletion axlearn/common/trainer_config_modifier_test.py
@@ -5,13 +5,15 @@
import jax
from absl.testing import absltest

from axlearn.common import test_utils
from axlearn.common import causal_lm, test_utils
from axlearn.common.attention import RepeatedTransformerLayer
from axlearn.common.base_layer import RematSpec
from axlearn.common.trainer import SpmdTrainer
from axlearn.common.trainer_config_modifier import (
    ChainConfigModifier,
    GradientAccumulationModifier,
    MeshShapeModifier,
    ModelConfigModifier,
    RematSpecModifier,
)
from axlearn.common.trainer_test import DummyModel
@@ -65,6 +67,27 @@ def test_remat_policy_override(self):
        _ = cfg_modifier(cfg)


class ModelConfigModifierTest(test_utils.TestCase):
    def test_model_config_override(self):
        cfg = SpmdTrainer.default_config().set(model=causal_lm.Model.default_config())
        self.assertRegex(str(cfg.model.decoder), ".*StackedTransformerLayer")

        cfg_modifier = (
            ModelConfigModifier.default_config()
            .set(
                model_cfg_modifications={
                    "model.decoder.transformer": RepeatedTransformerLayer.default_config(),
                }
            )
            .instantiate()
        )

        cfg = cfg_modifier(cfg)
        # The default StackedTransformerLayer should have been replaced by RepeatedTransformerLayer.
        self.assertRegex(str(cfg.model.decoder), ".*RepeatedTransformerLayer")


class MeshShapeModifierTest(test_utils.TestCase):
    def test_mesh_shape_update(self):
        cfg = SpmdTrainer.default_config().set(model=DummyModel.default_config())
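Beyond the type swap that the test asserts, the best-effort copy in ModelConfigModifier.__call__ carries fields shared by the old and new configs across the replacement. A quick sketch of that behaviour, assuming num_layers is a field of both stack configs (it is defined on BaseStackedTransformerLayer); imports as in trainer_config_modifier_test.py above:

cfg = SpmdTrainer.default_config().set(model=causal_lm.Model.default_config())
cfg.model.decoder.transformer.num_layers = 12

cfg_modifier = ModelConfigModifier.default_config().set(
    model_cfg_modifications={
        "model.decoder.transformer": RepeatedTransformerLayer.default_config(),
    }
).instantiate()
cfg = cfg_modifier(cfg)

# The stack type changed, but shared fields such as num_layers survived the swap.
assert "RepeatedTransformerLayer" in str(cfg.model.decoder.transformer)
assert cfg.model.decoder.transformer.num_layers == 12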
92 changes: 92 additions & 0 deletions axlearn/experiments/text/gpt/fuji.py
@@ -22,10 +22,12 @@
    BaseStackedTransformerLayer,
    FusedGroupedQKVLinear,
    FusedQKVLinear,
    GroupedQKVLinear,
    GroupedQueryAttention,
    MultiheadAttention,
    RepeatedTransformerLayer,
    RoFormerQKVLinear,
    StackedTransformerLayer,
)
from axlearn.common.base_layer import RematSpec
from axlearn.common.config import config_for_function
@@ -37,6 +39,7 @@
    ChainConfigModifier,
    GradientAccumulationModifier,
    MeshShapeModifier,
    ModelConfigModifier,
    RematSpecModifier,
)
from axlearn.common.utils import extended_checkpoint_policies
@@ -130,6 +133,16 @@ def get_trainer_kwargs(

    rope_theta = ROPE_THETA[version]

    # TRN2-specific model config modifications.
    trn2_model_modifications = {
        # The Neuron compiler detects repeating blocks and reuses them during compilation,
        # so compile time does not grow with the number of layers.
        "model.decoder.transformer": StackedTransformerLayer.default_config(),
        "model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear": (
            None if version == Version.V1 else GroupedQKVLinear.default_config()
        ),
    }

    offload_dots_saveable_policy = config_for_function(
        extended_checkpoint_policies.offload_dots_saveable
    ).set(offload_src="device", offload_dst="pinned_host")
@@ -174,6 +187,23 @@
            train_batch_size=train_batch_size,
            max_step=max_step,
            mesh_shape=mesh_shape_from_axes(data=-1, fsdp=8),
            mesh_rules=(
                (
                    "neuron-(trn2|trn2n).48xlarge-64",
                    ChainConfigModifier.default_config().set(
                        config_modifiers=[
                            MeshShapeModifier.default_config().set(
                                # TP within the chip, FSDP across chips.
                                # Each TRN2 chip has 4 XLA cores.
                                mesh_shape=mesh_shape_from_axes(fsdp=-1, model=4)
                            ),
                            ModelConfigModifier.default_config().set(
                                model_cfg_modifications=trn2_model_modifications
                            ),
                        ],
                    ),
                ),
            ),
        )
elif model_size == "3B":
trainer_kwargs = dict(
Expand All @@ -192,6 +222,23 @@ def get_trainer_kwargs(
train_batch_size=train_batch_size,
max_step=max_step,
mesh_shape=mesh_shape_from_axes(data=-1, fsdp=8),
mesh_rules=(
(
"neuron-(trn2|trn2n).48xlarge-64",
ChainConfigModifier.default_config().set(
config_modifiers=[
MeshShapeModifier.default_config().set(
# TP within the chip, FSDP across chips.
# Each TRN2 chip has 4 XLA cores.
mesh_shape=mesh_shape_from_axes(fsdp=-1, model=4)
),
ModelConfigModifier.default_config().set(
model_cfg_modifications=trn2_model_modifications
),
],
),
),
),
)
elif model_size == "7B":
trainer_kwargs = dict(
@@ -287,6 +334,21 @@ def get_trainer_kwargs(
                    "gpu-(p5.48xlarge|p4de.24xlarge|a3-highgpu-8g)-(256|512|1024)",
                    mesh_shape_from_axes(data=-1, fsdp=8),
                ),
                (
                    "neuron-(trn2|trn2n).48xlarge-64",
                    ChainConfigModifier.default_config().set(
                        config_modifiers=[
                            MeshShapeModifier.default_config().set(
                                # TP within the chip, FSDP across chips.
                                # Each TRN2 chip has 4 XLA cores.
                                mesh_shape=mesh_shape_from_axes(fsdp=-1, model=4)
                            ),
                            ModelConfigModifier.default_config().set(
                                model_cfg_modifications=trn2_model_modifications
                            ),
                        ],
                    ),
                ),
            ),
        )
    elif model_size == "8B":
@@ -367,6 +429,21 @@ def get_trainer_kwargs(
                    "gpu-(p5.48xlarge|p4de.24xlarge|a3-highgpu-8g)-(256|512|1024)",
                    mesh_shape_from_axes(data=-1, fsdp=8),
                ),
                (
                    "neuron-(trn2|trn2n).48xlarge-64",
                    ChainConfigModifier.default_config().set(
                        config_modifiers=[
                            MeshShapeModifier.default_config().set(
                                # TP within the chip, FSDP across chips.
                                # Each TRN2 chip has 4 XLA cores.
                                mesh_shape=mesh_shape_from_axes(fsdp=-1, model=4)
                            ),
                            ModelConfigModifier.default_config().set(
                                model_cfg_modifications=trn2_model_modifications
                            ),
                        ],
                    ),
                ),
            ),
        )
    elif model_size == "70B":
@@ -417,6 +494,21 @@ def get_trainer_kwargs(
                    "gpu-(p5.48xlarge|p4de.24xlarge)-(512|1024)",
                    mesh_shape_from_axes(data=-1, fsdp=128),
                ),
                (
                    "neuron-(trn2|trn2n).48xlarge-64",
                    ChainConfigModifier.default_config().set(
                        config_modifiers=[
                            MeshShapeModifier.default_config().set(
                                # TP within the chip, FSDP across chips.
                                # Each TRN2 chip has 4 XLA cores.
                                mesh_shape=mesh_shape_from_axes(fsdp=-1, model=4)
                            ),
                            ModelConfigModifier.default_config().set(
                                model_cfg_modifications=trn2_model_modifications
                            ),
                        ],
                    ),
                ),
            ),
        )
    else:
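To illustrate what the new mesh rule does when it matches, a sketch of applying the same chain directly (assuming a single trn2.48xlarge-64 node, i.e. 64 Neuron cores = 16 chips x 4 cores, so the fsdp=-1 axis resolves to 16):

# Names as in fuji.py above; trainer_cfg is an SpmdTrainer.Config built elsewhere.
trn2_modifier = ChainConfigModifier.default_config().set(
    config_modifiers=[
        # model=4 keeps tensor parallelism within one chip (4 XLA cores);
        # fsdp=-1 takes the remaining 16 chips of a 64-core node.
        MeshShapeModifier.default_config().set(
            mesh_shape=mesh_shape_from_axes(fsdp=-1, model=4)
        ),
        ModelConfigModifier.default_config().set(
            model_cfg_modifications=trn2_model_modifications
        ),
    ],
).instantiate()

trainer_cfg = trn2_modifier(trainer_cfg)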
