From da907575ab8b748073f6e9dbf7e91facaad6ace6 Mon Sep 17 00:00:00 2001 From: Apoorv Gupta Date: Fri, 24 Jan 2025 15:32:05 -0800 Subject: [PATCH] config update --- .../fuji-1B-v3-flash-single-host.txt | 95 +++++++++++++++++++ .../fuji-1B-v3-flash.txt | 95 +++++++++++++++++++ .../fuji-1B-v3-single-host.txt | 95 +++++++++++++++++++ .../fuji-1B-v3-tiktoken-flash-single-host.txt | 95 +++++++++++++++++++ .../fuji-1B-v3-tiktoken-flash.txt | 95 +++++++++++++++++++ .../fuji-1B-v3-tiktoken-single-host.txt | 95 +++++++++++++++++++ .../fuji-1B-v3-tiktoken.txt | 95 +++++++++++++++++++ .../fuji-1B-v3.txt | 95 +++++++++++++++++++ .../fuji-3B-v3-flash-single-host.txt | 95 +++++++++++++++++++ .../fuji-3B-v3-flash.txt | 95 +++++++++++++++++++ .../fuji-3B-v3-single-host.txt | 95 +++++++++++++++++++ .../fuji-3B-v3-tiktoken-flash-single-host.txt | 95 +++++++++++++++++++ .../fuji-3B-v3-tiktoken-flash.txt | 95 +++++++++++++++++++ .../fuji-3B-v3-tiktoken-single-host.txt | 95 +++++++++++++++++++ .../fuji-3B-v3-tiktoken.txt | 95 +++++++++++++++++++ .../fuji-3B-v3.txt | 95 +++++++++++++++++++ .../fuji-70B-v1-flash.txt | 84 +++++++++++++++- .../fuji-70B-v1.txt | 84 +++++++++++++++- .../fuji-70B-v2-flash.txt | 92 +++++++++++++++++- .../fuji-70B-v2.txt | 92 +++++++++++++++++- .../fuji-70B-v3-flash.txt | 92 +++++++++++++++++- .../fuji-70B-v3-tiktoken-flash.txt | 92 +++++++++++++++++- .../fuji-70B-v3-tiktoken.txt | 92 +++++++++++++++++- .../fuji-70B-v3.txt | 92 +++++++++++++++++- .../fuji-7B-v1-flash-single-host.txt | 87 +++++++++++++++++ .../fuji-7B-v1-flash.txt | 87 +++++++++++++++++ .../fuji-7B-v1-single-host.txt | 87 +++++++++++++++++ .../fuji-7B-v1.txt | 87 +++++++++++++++++ .../fuji-7B-v2-flash-single-host.txt | 95 +++++++++++++++++++ .../fuji-7B-v2-flash.txt | 95 +++++++++++++++++++ .../fuji-7B-v2-single-host.txt | 95 +++++++++++++++++++ .../fuji-7B-v2.txt | 95 +++++++++++++++++++ .../fuji-7B-v3-flash-single-host.txt | 95 +++++++++++++++++++ .../fuji-7B-v3-flash.txt | 95 +++++++++++++++++++ .../fuji-7B-v3-single-host.txt | 95 +++++++++++++++++++ .../fuji-7B-v3.txt | 95 +++++++++++++++++++ .../fuji-8B-v3-tiktoken-flash-single-host.txt | 95 +++++++++++++++++++ .../fuji-8B-v3-tiktoken-flash.txt | 95 +++++++++++++++++++ .../fuji-8B-v3-tiktoken-single-host.txt | 95 +++++++++++++++++++ .../fuji-8B-v3-tiktoken.txt | 95 +++++++++++++++++++ 40 files changed, 3704 insertions(+), 24 deletions(-) diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-flash-single-host.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-flash-single-host.txt index cdff10a4c..986575d71 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-flash-single-host.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-flash-single-host.txt @@ -122,6 +122,101 @@ mesh_axis_names[2]: 'expert' mesh_axis_names[3]: 'fsdp' mesh_axis_names[4]: 'seq' mesh_axis_names[5]: 'model' +mesh_rules[0][0]: 'neuron-(trn2|trn2n).48xlarge-64' +mesh_rules[0][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier' +mesh_rules[0][1].config_modifiers[0].mesh_shape[0]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[1]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[2]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[3]: -1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[4]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[5]: 4 +mesh_rules[0][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier' +mesh_rules[0][1].config_modifiers[1].modification.klass: 'axlearn.common.attention.StackedTransformerLayer' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.activation: 'nn.relu' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear1.bias: True +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear1.param_partition_spec[0]: None +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear1.param_partition_spec[1]: 'model' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear2.bias: True +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear2.param_partition_spec[0]: 'model' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear2.param_partition_spec[1]: None +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.norm.eps: 1e-08 +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.residual_weight: 1.0 +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.stochastic_depth.mode: 'row' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.structure: 'prenorm' +mesh_rules[0][1].config_modifiers[1].modification.layer.klass: 'axlearn.common.attention.TransformerLayer' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.QKVLinear' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.bias: True +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[0]: None +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[1]: 'model' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[2]: None +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.klass: 'axlearn.common.attention.MultiheadAttention' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.bias: True +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[0]: None +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[2]: None +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.norm.eps: 1e-08 +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.stochastic_depth.mode: 'row' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.structure: 'prenorm' +mesh_rules[0][1].config_modifiers[1].target_config: 'model.decoder.transformer' +mesh_rules[0][1].config_modifiers[2].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier' +mesh_rules[0][1].config_modifiers[2].modification.klass: 'axlearn.common.attention.GroupedQKVLinear' +mesh_rules[0][1].config_modifiers[2].modification.layer.bias: True +mesh_rules[0][1].config_modifiers[2].modification.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[0][1].config_modifiers[2].modification.layer.param_partition_spec[0]: None +mesh_rules[0][1].config_modifiers[2].modification.layer.param_partition_spec[1]: 'model' +mesh_rules[0][1].config_modifiers[2].modification.layer.param_partition_spec[2]: None +mesh_rules[0][1].config_modifiers[2].target_config: 'model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear' +mesh_rules[0][1].config_modifiers[3].klass: 'axlearn.common.trainer_config_modifier.PartitionSpecModifier' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][0]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][0]: 'expert' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][1]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][2]: 'seq' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['input_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['input_partition_spec'][1]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['output_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['output_partition_spec'][1]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['embedding_partition_spec'][0]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['embedding_partition_spec'][1]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][0]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][0]: 'expert' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][1]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][2]: 'seq' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][1]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][2]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][1]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][2]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][1]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][2]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][1]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][2]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][1]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][2]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][1]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][2]: None +mesh_rules[0][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' mesh_shape[0]: 1 mesh_shape[1]: -1 mesh_shape[2]: 1 diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-flash.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-flash.txt index 4393f09bf..f2abcecd7 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-flash.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-flash.txt @@ -122,6 +122,101 @@ mesh_axis_names[2]: 'expert' mesh_axis_names[3]: 'fsdp' mesh_axis_names[4]: 'seq' mesh_axis_names[5]: 'model' +mesh_rules[0][0]: 'neuron-(trn2|trn2n).48xlarge-64' +mesh_rules[0][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier' +mesh_rules[0][1].config_modifiers[0].mesh_shape[0]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[1]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[2]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[3]: -1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[4]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[5]: 4 +mesh_rules[0][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier' +mesh_rules[0][1].config_modifiers[1].modification.klass: 'axlearn.common.attention.StackedTransformerLayer' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.activation: 'nn.relu' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear1.bias: True +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear1.param_partition_spec[0]: None +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear1.param_partition_spec[1]: 'model' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear2.bias: True +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear2.param_partition_spec[0]: 'model' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear2.param_partition_spec[1]: None +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.norm.eps: 1e-08 +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.residual_weight: 1.0 +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.stochastic_depth.mode: 'row' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.structure: 'prenorm' +mesh_rules[0][1].config_modifiers[1].modification.layer.klass: 'axlearn.common.attention.TransformerLayer' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.QKVLinear' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.bias: True +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[0]: None +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[1]: 'model' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[2]: None +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.klass: 'axlearn.common.attention.MultiheadAttention' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.bias: True +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[0]: None +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[2]: None +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.norm.eps: 1e-08 +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.stochastic_depth.mode: 'row' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.structure: 'prenorm' +mesh_rules[0][1].config_modifiers[1].target_config: 'model.decoder.transformer' +mesh_rules[0][1].config_modifiers[2].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier' +mesh_rules[0][1].config_modifiers[2].modification.klass: 'axlearn.common.attention.GroupedQKVLinear' +mesh_rules[0][1].config_modifiers[2].modification.layer.bias: True +mesh_rules[0][1].config_modifiers[2].modification.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[0][1].config_modifiers[2].modification.layer.param_partition_spec[0]: None +mesh_rules[0][1].config_modifiers[2].modification.layer.param_partition_spec[1]: 'model' +mesh_rules[0][1].config_modifiers[2].modification.layer.param_partition_spec[2]: None +mesh_rules[0][1].config_modifiers[2].target_config: 'model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear' +mesh_rules[0][1].config_modifiers[3].klass: 'axlearn.common.trainer_config_modifier.PartitionSpecModifier' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][0]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][0]: 'expert' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][1]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][2]: 'seq' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['input_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['input_partition_spec'][1]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['output_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['output_partition_spec'][1]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['embedding_partition_spec'][0]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['embedding_partition_spec'][1]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][0]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][0]: 'expert' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][1]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][2]: 'seq' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][1]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][2]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][1]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][2]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][1]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][2]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][1]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][2]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][1]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][2]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][1]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][2]: None +mesh_rules[0][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' mesh_shape[0]: 1 mesh_shape[1]: -1 mesh_shape[2]: 1 diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-single-host.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-single-host.txt index 02926944a..66842ab38 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-single-host.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-single-host.txt @@ -122,6 +122,101 @@ mesh_axis_names[2]: 'expert' mesh_axis_names[3]: 'fsdp' mesh_axis_names[4]: 'seq' mesh_axis_names[5]: 'model' +mesh_rules[0][0]: 'neuron-(trn2|trn2n).48xlarge-64' +mesh_rules[0][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier' +mesh_rules[0][1].config_modifiers[0].mesh_shape[0]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[1]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[2]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[3]: -1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[4]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[5]: 4 +mesh_rules[0][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier' +mesh_rules[0][1].config_modifiers[1].modification.klass: 'axlearn.common.attention.StackedTransformerLayer' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.activation: 'nn.relu' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear1.bias: True +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear1.param_partition_spec[0]: None +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear1.param_partition_spec[1]: 'model' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear2.bias: True +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear2.param_partition_spec[0]: 'model' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear2.param_partition_spec[1]: None +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.norm.eps: 1e-08 +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.residual_weight: 1.0 +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.stochastic_depth.mode: 'row' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.structure: 'prenorm' +mesh_rules[0][1].config_modifiers[1].modification.layer.klass: 'axlearn.common.attention.TransformerLayer' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.QKVLinear' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.bias: True +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[0]: None +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[1]: 'model' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[2]: None +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.klass: 'axlearn.common.attention.MultiheadAttention' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.bias: True +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[0]: None +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[2]: None +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.norm.eps: 1e-08 +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.stochastic_depth.mode: 'row' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.structure: 'prenorm' +mesh_rules[0][1].config_modifiers[1].target_config: 'model.decoder.transformer' +mesh_rules[0][1].config_modifiers[2].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier' +mesh_rules[0][1].config_modifiers[2].modification.klass: 'axlearn.common.attention.GroupedQKVLinear' +mesh_rules[0][1].config_modifiers[2].modification.layer.bias: True +mesh_rules[0][1].config_modifiers[2].modification.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[0][1].config_modifiers[2].modification.layer.param_partition_spec[0]: None +mesh_rules[0][1].config_modifiers[2].modification.layer.param_partition_spec[1]: 'model' +mesh_rules[0][1].config_modifiers[2].modification.layer.param_partition_spec[2]: None +mesh_rules[0][1].config_modifiers[2].target_config: 'model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear' +mesh_rules[0][1].config_modifiers[3].klass: 'axlearn.common.trainer_config_modifier.PartitionSpecModifier' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][0]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][0]: 'expert' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][1]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][2]: 'seq' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['input_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['input_partition_spec'][1]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['output_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['output_partition_spec'][1]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['embedding_partition_spec'][0]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['embedding_partition_spec'][1]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][0]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][0]: 'expert' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][1]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][2]: 'seq' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][1]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][2]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][1]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][2]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][1]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][2]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][1]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][2]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][1]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][2]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][1]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][2]: None +mesh_rules[0][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' mesh_shape[0]: 1 mesh_shape[1]: -1 mesh_shape[2]: 1 diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-tiktoken-flash-single-host.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-tiktoken-flash-single-host.txt index e92520fa8..ec065a6d0 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-tiktoken-flash-single-host.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-tiktoken-flash-single-host.txt @@ -122,6 +122,101 @@ mesh_axis_names[2]: 'expert' mesh_axis_names[3]: 'fsdp' mesh_axis_names[4]: 'seq' mesh_axis_names[5]: 'model' +mesh_rules[0][0]: 'neuron-(trn2|trn2n).48xlarge-64' +mesh_rules[0][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier' +mesh_rules[0][1].config_modifiers[0].mesh_shape[0]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[1]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[2]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[3]: -1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[4]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[5]: 4 +mesh_rules[0][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier' +mesh_rules[0][1].config_modifiers[1].modification.klass: 'axlearn.common.attention.StackedTransformerLayer' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.activation: 'nn.relu' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear1.bias: True +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear1.param_partition_spec[0]: None +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear1.param_partition_spec[1]: 'model' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear2.bias: True +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear2.param_partition_spec[0]: 'model' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear2.param_partition_spec[1]: None +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.norm.eps: 1e-08 +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.residual_weight: 1.0 +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.stochastic_depth.mode: 'row' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.structure: 'prenorm' +mesh_rules[0][1].config_modifiers[1].modification.layer.klass: 'axlearn.common.attention.TransformerLayer' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.QKVLinear' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.bias: True +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[0]: None +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[1]: 'model' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[2]: None +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.klass: 'axlearn.common.attention.MultiheadAttention' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.bias: True +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[0]: None +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[2]: None +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.norm.eps: 1e-08 +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.stochastic_depth.mode: 'row' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.structure: 'prenorm' +mesh_rules[0][1].config_modifiers[1].target_config: 'model.decoder.transformer' +mesh_rules[0][1].config_modifiers[2].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier' +mesh_rules[0][1].config_modifiers[2].modification.klass: 'axlearn.common.attention.GroupedQKVLinear' +mesh_rules[0][1].config_modifiers[2].modification.layer.bias: True +mesh_rules[0][1].config_modifiers[2].modification.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[0][1].config_modifiers[2].modification.layer.param_partition_spec[0]: None +mesh_rules[0][1].config_modifiers[2].modification.layer.param_partition_spec[1]: 'model' +mesh_rules[0][1].config_modifiers[2].modification.layer.param_partition_spec[2]: None +mesh_rules[0][1].config_modifiers[2].target_config: 'model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear' +mesh_rules[0][1].config_modifiers[3].klass: 'axlearn.common.trainer_config_modifier.PartitionSpecModifier' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][0]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][0]: 'expert' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][1]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][2]: 'seq' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['input_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['input_partition_spec'][1]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['output_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['output_partition_spec'][1]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['embedding_partition_spec'][0]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['embedding_partition_spec'][1]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][0]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][0]: 'expert' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][1]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][2]: 'seq' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][1]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][2]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][1]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][2]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][1]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][2]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][1]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][2]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][1]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][2]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][1]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][2]: None +mesh_rules[0][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' mesh_shape[0]: 1 mesh_shape[1]: -1 mesh_shape[2]: 1 diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-tiktoken-flash.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-tiktoken-flash.txt index 67b87a020..0df5feb88 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-tiktoken-flash.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-tiktoken-flash.txt @@ -122,6 +122,101 @@ mesh_axis_names[2]: 'expert' mesh_axis_names[3]: 'fsdp' mesh_axis_names[4]: 'seq' mesh_axis_names[5]: 'model' +mesh_rules[0][0]: 'neuron-(trn2|trn2n).48xlarge-64' +mesh_rules[0][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier' +mesh_rules[0][1].config_modifiers[0].mesh_shape[0]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[1]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[2]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[3]: -1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[4]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[5]: 4 +mesh_rules[0][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier' +mesh_rules[0][1].config_modifiers[1].modification.klass: 'axlearn.common.attention.StackedTransformerLayer' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.activation: 'nn.relu' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear1.bias: True +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear1.param_partition_spec[0]: None +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear1.param_partition_spec[1]: 'model' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear2.bias: True +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear2.param_partition_spec[0]: 'model' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear2.param_partition_spec[1]: None +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.norm.eps: 1e-08 +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.residual_weight: 1.0 +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.stochastic_depth.mode: 'row' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.structure: 'prenorm' +mesh_rules[0][1].config_modifiers[1].modification.layer.klass: 'axlearn.common.attention.TransformerLayer' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.QKVLinear' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.bias: True +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[0]: None +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[1]: 'model' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[2]: None +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.klass: 'axlearn.common.attention.MultiheadAttention' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.bias: True +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[0]: None +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[2]: None +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.norm.eps: 1e-08 +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.stochastic_depth.mode: 'row' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.structure: 'prenorm' +mesh_rules[0][1].config_modifiers[1].target_config: 'model.decoder.transformer' +mesh_rules[0][1].config_modifiers[2].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier' +mesh_rules[0][1].config_modifiers[2].modification.klass: 'axlearn.common.attention.GroupedQKVLinear' +mesh_rules[0][1].config_modifiers[2].modification.layer.bias: True +mesh_rules[0][1].config_modifiers[2].modification.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[0][1].config_modifiers[2].modification.layer.param_partition_spec[0]: None +mesh_rules[0][1].config_modifiers[2].modification.layer.param_partition_spec[1]: 'model' +mesh_rules[0][1].config_modifiers[2].modification.layer.param_partition_spec[2]: None +mesh_rules[0][1].config_modifiers[2].target_config: 'model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear' +mesh_rules[0][1].config_modifiers[3].klass: 'axlearn.common.trainer_config_modifier.PartitionSpecModifier' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][0]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][0]: 'expert' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][1]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][2]: 'seq' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['input_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['input_partition_spec'][1]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['output_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['output_partition_spec'][1]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['embedding_partition_spec'][0]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['embedding_partition_spec'][1]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][0]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][0]: 'expert' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][1]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][2]: 'seq' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][1]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][2]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][1]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][2]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][1]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][2]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][1]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][2]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][1]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][2]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][1]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][2]: None +mesh_rules[0][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' mesh_shape[0]: 1 mesh_shape[1]: -1 mesh_shape[2]: 1 diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-tiktoken-single-host.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-tiktoken-single-host.txt index f01ea2bf9..332ac1994 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-tiktoken-single-host.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-tiktoken-single-host.txt @@ -122,6 +122,101 @@ mesh_axis_names[2]: 'expert' mesh_axis_names[3]: 'fsdp' mesh_axis_names[4]: 'seq' mesh_axis_names[5]: 'model' +mesh_rules[0][0]: 'neuron-(trn2|trn2n).48xlarge-64' +mesh_rules[0][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier' +mesh_rules[0][1].config_modifiers[0].mesh_shape[0]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[1]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[2]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[3]: -1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[4]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[5]: 4 +mesh_rules[0][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier' +mesh_rules[0][1].config_modifiers[1].modification.klass: 'axlearn.common.attention.StackedTransformerLayer' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.activation: 'nn.relu' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear1.bias: True +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear1.param_partition_spec[0]: None +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear1.param_partition_spec[1]: 'model' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear2.bias: True +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear2.param_partition_spec[0]: 'model' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear2.param_partition_spec[1]: None +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.norm.eps: 1e-08 +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.residual_weight: 1.0 +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.stochastic_depth.mode: 'row' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.structure: 'prenorm' +mesh_rules[0][1].config_modifiers[1].modification.layer.klass: 'axlearn.common.attention.TransformerLayer' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.QKVLinear' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.bias: True +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[0]: None +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[1]: 'model' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[2]: None +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.klass: 'axlearn.common.attention.MultiheadAttention' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.bias: True +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[0]: None +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[2]: None +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.norm.eps: 1e-08 +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.stochastic_depth.mode: 'row' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.structure: 'prenorm' +mesh_rules[0][1].config_modifiers[1].target_config: 'model.decoder.transformer' +mesh_rules[0][1].config_modifiers[2].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier' +mesh_rules[0][1].config_modifiers[2].modification.klass: 'axlearn.common.attention.GroupedQKVLinear' +mesh_rules[0][1].config_modifiers[2].modification.layer.bias: True +mesh_rules[0][1].config_modifiers[2].modification.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[0][1].config_modifiers[2].modification.layer.param_partition_spec[0]: None +mesh_rules[0][1].config_modifiers[2].modification.layer.param_partition_spec[1]: 'model' +mesh_rules[0][1].config_modifiers[2].modification.layer.param_partition_spec[2]: None +mesh_rules[0][1].config_modifiers[2].target_config: 'model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear' +mesh_rules[0][1].config_modifiers[3].klass: 'axlearn.common.trainer_config_modifier.PartitionSpecModifier' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][0]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][0]: 'expert' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][1]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][2]: 'seq' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['input_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['input_partition_spec'][1]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['output_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['output_partition_spec'][1]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['embedding_partition_spec'][0]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['embedding_partition_spec'][1]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][0]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][0]: 'expert' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][1]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][2]: 'seq' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][1]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][2]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][1]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][2]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][1]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][2]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][1]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][2]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][1]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][2]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][1]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][2]: None +mesh_rules[0][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' mesh_shape[0]: 1 mesh_shape[1]: -1 mesh_shape[2]: 1 diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-tiktoken.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-tiktoken.txt index ed0018f69..00e9b49d6 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-tiktoken.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-tiktoken.txt @@ -122,6 +122,101 @@ mesh_axis_names[2]: 'expert' mesh_axis_names[3]: 'fsdp' mesh_axis_names[4]: 'seq' mesh_axis_names[5]: 'model' +mesh_rules[0][0]: 'neuron-(trn2|trn2n).48xlarge-64' +mesh_rules[0][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier' +mesh_rules[0][1].config_modifiers[0].mesh_shape[0]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[1]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[2]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[3]: -1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[4]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[5]: 4 +mesh_rules[0][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier' +mesh_rules[0][1].config_modifiers[1].modification.klass: 'axlearn.common.attention.StackedTransformerLayer' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.activation: 'nn.relu' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear1.bias: True +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear1.param_partition_spec[0]: None +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear1.param_partition_spec[1]: 'model' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear2.bias: True +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear2.param_partition_spec[0]: 'model' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear2.param_partition_spec[1]: None +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.norm.eps: 1e-08 +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.residual_weight: 1.0 +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.stochastic_depth.mode: 'row' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.structure: 'prenorm' +mesh_rules[0][1].config_modifiers[1].modification.layer.klass: 'axlearn.common.attention.TransformerLayer' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.QKVLinear' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.bias: True +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[0]: None +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[1]: 'model' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[2]: None +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.klass: 'axlearn.common.attention.MultiheadAttention' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.bias: True +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[0]: None +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[2]: None +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.norm.eps: 1e-08 +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.stochastic_depth.mode: 'row' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.structure: 'prenorm' +mesh_rules[0][1].config_modifiers[1].target_config: 'model.decoder.transformer' +mesh_rules[0][1].config_modifiers[2].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier' +mesh_rules[0][1].config_modifiers[2].modification.klass: 'axlearn.common.attention.GroupedQKVLinear' +mesh_rules[0][1].config_modifiers[2].modification.layer.bias: True +mesh_rules[0][1].config_modifiers[2].modification.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[0][1].config_modifiers[2].modification.layer.param_partition_spec[0]: None +mesh_rules[0][1].config_modifiers[2].modification.layer.param_partition_spec[1]: 'model' +mesh_rules[0][1].config_modifiers[2].modification.layer.param_partition_spec[2]: None +mesh_rules[0][1].config_modifiers[2].target_config: 'model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear' +mesh_rules[0][1].config_modifiers[3].klass: 'axlearn.common.trainer_config_modifier.PartitionSpecModifier' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][0]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][0]: 'expert' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][1]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][2]: 'seq' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['input_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['input_partition_spec'][1]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['output_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['output_partition_spec'][1]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['embedding_partition_spec'][0]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['embedding_partition_spec'][1]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][0]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][0]: 'expert' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][1]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][2]: 'seq' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][1]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][2]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][1]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][2]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][1]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][2]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][1]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][2]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][1]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][2]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][1]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][2]: None +mesh_rules[0][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' mesh_shape[0]: 1 mesh_shape[1]: -1 mesh_shape[2]: 1 diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3.txt index ce42ebc30..14c8bda68 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3.txt @@ -122,6 +122,101 @@ mesh_axis_names[2]: 'expert' mesh_axis_names[3]: 'fsdp' mesh_axis_names[4]: 'seq' mesh_axis_names[5]: 'model' +mesh_rules[0][0]: 'neuron-(trn2|trn2n).48xlarge-64' +mesh_rules[0][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier' +mesh_rules[0][1].config_modifiers[0].mesh_shape[0]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[1]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[2]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[3]: -1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[4]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[5]: 4 +mesh_rules[0][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier' +mesh_rules[0][1].config_modifiers[1].modification.klass: 'axlearn.common.attention.StackedTransformerLayer' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.activation: 'nn.relu' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear1.bias: True +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear1.param_partition_spec[0]: None +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear1.param_partition_spec[1]: 'model' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear2.bias: True +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear2.param_partition_spec[0]: 'model' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear2.param_partition_spec[1]: None +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.norm.eps: 1e-08 +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.residual_weight: 1.0 +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.stochastic_depth.mode: 'row' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.structure: 'prenorm' +mesh_rules[0][1].config_modifiers[1].modification.layer.klass: 'axlearn.common.attention.TransformerLayer' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.QKVLinear' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.bias: True +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[0]: None +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[1]: 'model' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[2]: None +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.klass: 'axlearn.common.attention.MultiheadAttention' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.bias: True +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[0]: None +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[2]: None +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.norm.eps: 1e-08 +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.stochastic_depth.mode: 'row' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.structure: 'prenorm' +mesh_rules[0][1].config_modifiers[1].target_config: 'model.decoder.transformer' +mesh_rules[0][1].config_modifiers[2].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier' +mesh_rules[0][1].config_modifiers[2].modification.klass: 'axlearn.common.attention.GroupedQKVLinear' +mesh_rules[0][1].config_modifiers[2].modification.layer.bias: True +mesh_rules[0][1].config_modifiers[2].modification.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[0][1].config_modifiers[2].modification.layer.param_partition_spec[0]: None +mesh_rules[0][1].config_modifiers[2].modification.layer.param_partition_spec[1]: 'model' +mesh_rules[0][1].config_modifiers[2].modification.layer.param_partition_spec[2]: None +mesh_rules[0][1].config_modifiers[2].target_config: 'model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear' +mesh_rules[0][1].config_modifiers[3].klass: 'axlearn.common.trainer_config_modifier.PartitionSpecModifier' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][0]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][0]: 'expert' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][1]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][2]: 'seq' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['input_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['input_partition_spec'][1]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['output_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['output_partition_spec'][1]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['embedding_partition_spec'][0]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['embedding_partition_spec'][1]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][0]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][0]: 'expert' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][1]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][2]: 'seq' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][1]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][2]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][1]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][2]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][1]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][2]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][1]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][2]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][1]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][2]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][1]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][2]: None +mesh_rules[0][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' mesh_shape[0]: 1 mesh_shape[1]: -1 mesh_shape[2]: 1 diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-flash-single-host.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-flash-single-host.txt index 3f0c10291..b9b87ee15 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-flash-single-host.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-flash-single-host.txt @@ -122,6 +122,101 @@ mesh_axis_names[2]: 'expert' mesh_axis_names[3]: 'fsdp' mesh_axis_names[4]: 'seq' mesh_axis_names[5]: 'model' +mesh_rules[0][0]: 'neuron-(trn2|trn2n).48xlarge-64' +mesh_rules[0][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier' +mesh_rules[0][1].config_modifiers[0].mesh_shape[0]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[1]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[2]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[3]: -1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[4]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[5]: 4 +mesh_rules[0][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier' +mesh_rules[0][1].config_modifiers[1].modification.klass: 'axlearn.common.attention.StackedTransformerLayer' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.activation: 'nn.relu' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear1.bias: True +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear1.param_partition_spec[0]: None +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear1.param_partition_spec[1]: 'model' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear2.bias: True +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear2.param_partition_spec[0]: 'model' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear2.param_partition_spec[1]: None +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.norm.eps: 1e-08 +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.residual_weight: 1.0 +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.stochastic_depth.mode: 'row' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.structure: 'prenorm' +mesh_rules[0][1].config_modifiers[1].modification.layer.klass: 'axlearn.common.attention.TransformerLayer' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.QKVLinear' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.bias: True +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[0]: None +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[1]: 'model' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[2]: None +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.klass: 'axlearn.common.attention.MultiheadAttention' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.bias: True +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[0]: None +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[2]: None +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.norm.eps: 1e-08 +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.stochastic_depth.mode: 'row' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.structure: 'prenorm' +mesh_rules[0][1].config_modifiers[1].target_config: 'model.decoder.transformer' +mesh_rules[0][1].config_modifiers[2].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier' +mesh_rules[0][1].config_modifiers[2].modification.klass: 'axlearn.common.attention.GroupedQKVLinear' +mesh_rules[0][1].config_modifiers[2].modification.layer.bias: True +mesh_rules[0][1].config_modifiers[2].modification.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[0][1].config_modifiers[2].modification.layer.param_partition_spec[0]: None +mesh_rules[0][1].config_modifiers[2].modification.layer.param_partition_spec[1]: 'model' +mesh_rules[0][1].config_modifiers[2].modification.layer.param_partition_spec[2]: None +mesh_rules[0][1].config_modifiers[2].target_config: 'model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear' +mesh_rules[0][1].config_modifiers[3].klass: 'axlearn.common.trainer_config_modifier.PartitionSpecModifier' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][0]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][0]: 'expert' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][1]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][2]: 'seq' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['input_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['input_partition_spec'][1]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['output_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['output_partition_spec'][1]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['embedding_partition_spec'][0]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['embedding_partition_spec'][1]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][0]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][0]: 'expert' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][1]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][2]: 'seq' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][1]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][2]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][1]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][2]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][1]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][2]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][1]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][2]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][1]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][2]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][1]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][2]: None +mesh_rules[0][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' mesh_shape[0]: 1 mesh_shape[1]: -1 mesh_shape[2]: 1 diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-flash.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-flash.txt index 3e6436f4c..ec12261cf 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-flash.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-flash.txt @@ -122,6 +122,101 @@ mesh_axis_names[2]: 'expert' mesh_axis_names[3]: 'fsdp' mesh_axis_names[4]: 'seq' mesh_axis_names[5]: 'model' +mesh_rules[0][0]: 'neuron-(trn2|trn2n).48xlarge-64' +mesh_rules[0][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier' +mesh_rules[0][1].config_modifiers[0].mesh_shape[0]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[1]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[2]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[3]: -1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[4]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[5]: 4 +mesh_rules[0][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier' +mesh_rules[0][1].config_modifiers[1].modification.klass: 'axlearn.common.attention.StackedTransformerLayer' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.activation: 'nn.relu' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear1.bias: True +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear1.param_partition_spec[0]: None +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear1.param_partition_spec[1]: 'model' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear2.bias: True +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear2.param_partition_spec[0]: 'model' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear2.param_partition_spec[1]: None +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.norm.eps: 1e-08 +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.residual_weight: 1.0 +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.stochastic_depth.mode: 'row' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.structure: 'prenorm' +mesh_rules[0][1].config_modifiers[1].modification.layer.klass: 'axlearn.common.attention.TransformerLayer' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.QKVLinear' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.bias: True +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[0]: None +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[1]: 'model' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[2]: None +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.klass: 'axlearn.common.attention.MultiheadAttention' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.bias: True +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[0]: None +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[2]: None +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.norm.eps: 1e-08 +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.stochastic_depth.mode: 'row' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.structure: 'prenorm' +mesh_rules[0][1].config_modifiers[1].target_config: 'model.decoder.transformer' +mesh_rules[0][1].config_modifiers[2].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier' +mesh_rules[0][1].config_modifiers[2].modification.klass: 'axlearn.common.attention.GroupedQKVLinear' +mesh_rules[0][1].config_modifiers[2].modification.layer.bias: True +mesh_rules[0][1].config_modifiers[2].modification.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[0][1].config_modifiers[2].modification.layer.param_partition_spec[0]: None +mesh_rules[0][1].config_modifiers[2].modification.layer.param_partition_spec[1]: 'model' +mesh_rules[0][1].config_modifiers[2].modification.layer.param_partition_spec[2]: None +mesh_rules[0][1].config_modifiers[2].target_config: 'model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear' +mesh_rules[0][1].config_modifiers[3].klass: 'axlearn.common.trainer_config_modifier.PartitionSpecModifier' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][0]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][0]: 'expert' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][1]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][2]: 'seq' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['input_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['input_partition_spec'][1]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['output_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['output_partition_spec'][1]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['embedding_partition_spec'][0]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['embedding_partition_spec'][1]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][0]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][0]: 'expert' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][1]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][2]: 'seq' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][1]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][2]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][1]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][2]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][1]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][2]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][1]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][2]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][1]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][2]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][1]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][2]: None +mesh_rules[0][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' mesh_shape[0]: 1 mesh_shape[1]: -1 mesh_shape[2]: 1 diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-single-host.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-single-host.txt index 3e6a68d6e..a0e0898b7 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-single-host.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-single-host.txt @@ -122,6 +122,101 @@ mesh_axis_names[2]: 'expert' mesh_axis_names[3]: 'fsdp' mesh_axis_names[4]: 'seq' mesh_axis_names[5]: 'model' +mesh_rules[0][0]: 'neuron-(trn2|trn2n).48xlarge-64' +mesh_rules[0][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier' +mesh_rules[0][1].config_modifiers[0].mesh_shape[0]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[1]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[2]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[3]: -1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[4]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[5]: 4 +mesh_rules[0][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier' +mesh_rules[0][1].config_modifiers[1].modification.klass: 'axlearn.common.attention.StackedTransformerLayer' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.activation: 'nn.relu' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear1.bias: True +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear1.param_partition_spec[0]: None +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear1.param_partition_spec[1]: 'model' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear2.bias: True +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear2.param_partition_spec[0]: 'model' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear2.param_partition_spec[1]: None +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.norm.eps: 1e-08 +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.residual_weight: 1.0 +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.stochastic_depth.mode: 'row' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.structure: 'prenorm' +mesh_rules[0][1].config_modifiers[1].modification.layer.klass: 'axlearn.common.attention.TransformerLayer' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.QKVLinear' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.bias: True +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[0]: None +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[1]: 'model' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[2]: None +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.klass: 'axlearn.common.attention.MultiheadAttention' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.bias: True +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[0]: None +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[2]: None +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.norm.eps: 1e-08 +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.stochastic_depth.mode: 'row' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.structure: 'prenorm' +mesh_rules[0][1].config_modifiers[1].target_config: 'model.decoder.transformer' +mesh_rules[0][1].config_modifiers[2].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier' +mesh_rules[0][1].config_modifiers[2].modification.klass: 'axlearn.common.attention.GroupedQKVLinear' +mesh_rules[0][1].config_modifiers[2].modification.layer.bias: True +mesh_rules[0][1].config_modifiers[2].modification.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[0][1].config_modifiers[2].modification.layer.param_partition_spec[0]: None +mesh_rules[0][1].config_modifiers[2].modification.layer.param_partition_spec[1]: 'model' +mesh_rules[0][1].config_modifiers[2].modification.layer.param_partition_spec[2]: None +mesh_rules[0][1].config_modifiers[2].target_config: 'model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear' +mesh_rules[0][1].config_modifiers[3].klass: 'axlearn.common.trainer_config_modifier.PartitionSpecModifier' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][0]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][0]: 'expert' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][1]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][2]: 'seq' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['input_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['input_partition_spec'][1]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['output_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['output_partition_spec'][1]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['embedding_partition_spec'][0]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['embedding_partition_spec'][1]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][0]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][0]: 'expert' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][1]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][2]: 'seq' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][1]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][2]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][1]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][2]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][1]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][2]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][1]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][2]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][1]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][2]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][1]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][2]: None +mesh_rules[0][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' mesh_shape[0]: 1 mesh_shape[1]: -1 mesh_shape[2]: 1 diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-tiktoken-flash-single-host.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-tiktoken-flash-single-host.txt index 00fdc9ff7..01411bd60 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-tiktoken-flash-single-host.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-tiktoken-flash-single-host.txt @@ -122,6 +122,101 @@ mesh_axis_names[2]: 'expert' mesh_axis_names[3]: 'fsdp' mesh_axis_names[4]: 'seq' mesh_axis_names[5]: 'model' +mesh_rules[0][0]: 'neuron-(trn2|trn2n).48xlarge-64' +mesh_rules[0][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier' +mesh_rules[0][1].config_modifiers[0].mesh_shape[0]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[1]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[2]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[3]: -1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[4]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[5]: 4 +mesh_rules[0][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier' +mesh_rules[0][1].config_modifiers[1].modification.klass: 'axlearn.common.attention.StackedTransformerLayer' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.activation: 'nn.relu' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear1.bias: True +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear1.param_partition_spec[0]: None +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear1.param_partition_spec[1]: 'model' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear2.bias: True +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear2.param_partition_spec[0]: 'model' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear2.param_partition_spec[1]: None +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.norm.eps: 1e-08 +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.residual_weight: 1.0 +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.stochastic_depth.mode: 'row' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.structure: 'prenorm' +mesh_rules[0][1].config_modifiers[1].modification.layer.klass: 'axlearn.common.attention.TransformerLayer' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.QKVLinear' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.bias: True +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[0]: None +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[1]: 'model' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[2]: None +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.klass: 'axlearn.common.attention.MultiheadAttention' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.bias: True +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[0]: None +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[2]: None +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.norm.eps: 1e-08 +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.stochastic_depth.mode: 'row' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.structure: 'prenorm' +mesh_rules[0][1].config_modifiers[1].target_config: 'model.decoder.transformer' +mesh_rules[0][1].config_modifiers[2].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier' +mesh_rules[0][1].config_modifiers[2].modification.klass: 'axlearn.common.attention.GroupedQKVLinear' +mesh_rules[0][1].config_modifiers[2].modification.layer.bias: True +mesh_rules[0][1].config_modifiers[2].modification.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[0][1].config_modifiers[2].modification.layer.param_partition_spec[0]: None +mesh_rules[0][1].config_modifiers[2].modification.layer.param_partition_spec[1]: 'model' +mesh_rules[0][1].config_modifiers[2].modification.layer.param_partition_spec[2]: None +mesh_rules[0][1].config_modifiers[2].target_config: 'model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear' +mesh_rules[0][1].config_modifiers[3].klass: 'axlearn.common.trainer_config_modifier.PartitionSpecModifier' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][0]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][0]: 'expert' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][1]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][2]: 'seq' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['input_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['input_partition_spec'][1]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['output_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['output_partition_spec'][1]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['embedding_partition_spec'][0]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['embedding_partition_spec'][1]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][0]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][0]: 'expert' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][1]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][2]: 'seq' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][1]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][2]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][1]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][2]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][1]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][2]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][1]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][2]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][1]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][2]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][1]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][2]: None +mesh_rules[0][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' mesh_shape[0]: 1 mesh_shape[1]: -1 mesh_shape[2]: 1 diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-tiktoken-flash.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-tiktoken-flash.txt index e2670708e..5e7d9188a 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-tiktoken-flash.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-tiktoken-flash.txt @@ -122,6 +122,101 @@ mesh_axis_names[2]: 'expert' mesh_axis_names[3]: 'fsdp' mesh_axis_names[4]: 'seq' mesh_axis_names[5]: 'model' +mesh_rules[0][0]: 'neuron-(trn2|trn2n).48xlarge-64' +mesh_rules[0][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier' +mesh_rules[0][1].config_modifiers[0].mesh_shape[0]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[1]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[2]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[3]: -1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[4]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[5]: 4 +mesh_rules[0][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier' +mesh_rules[0][1].config_modifiers[1].modification.klass: 'axlearn.common.attention.StackedTransformerLayer' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.activation: 'nn.relu' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear1.bias: True +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear1.param_partition_spec[0]: None +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear1.param_partition_spec[1]: 'model' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear2.bias: True +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear2.param_partition_spec[0]: 'model' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear2.param_partition_spec[1]: None +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.norm.eps: 1e-08 +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.residual_weight: 1.0 +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.stochastic_depth.mode: 'row' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.structure: 'prenorm' +mesh_rules[0][1].config_modifiers[1].modification.layer.klass: 'axlearn.common.attention.TransformerLayer' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.QKVLinear' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.bias: True +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[0]: None +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[1]: 'model' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[2]: None +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.klass: 'axlearn.common.attention.MultiheadAttention' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.bias: True +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[0]: None +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[2]: None +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.norm.eps: 1e-08 +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.stochastic_depth.mode: 'row' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.structure: 'prenorm' +mesh_rules[0][1].config_modifiers[1].target_config: 'model.decoder.transformer' +mesh_rules[0][1].config_modifiers[2].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier' +mesh_rules[0][1].config_modifiers[2].modification.klass: 'axlearn.common.attention.GroupedQKVLinear' +mesh_rules[0][1].config_modifiers[2].modification.layer.bias: True +mesh_rules[0][1].config_modifiers[2].modification.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[0][1].config_modifiers[2].modification.layer.param_partition_spec[0]: None +mesh_rules[0][1].config_modifiers[2].modification.layer.param_partition_spec[1]: 'model' +mesh_rules[0][1].config_modifiers[2].modification.layer.param_partition_spec[2]: None +mesh_rules[0][1].config_modifiers[2].target_config: 'model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear' +mesh_rules[0][1].config_modifiers[3].klass: 'axlearn.common.trainer_config_modifier.PartitionSpecModifier' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][0]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][0]: 'expert' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][1]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][2]: 'seq' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['input_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['input_partition_spec'][1]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['output_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['output_partition_spec'][1]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['embedding_partition_spec'][0]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['embedding_partition_spec'][1]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][0]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][0]: 'expert' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][1]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][2]: 'seq' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][1]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][2]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][1]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][2]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][1]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][2]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][1]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][2]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][1]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][2]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][1]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][2]: None +mesh_rules[0][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' mesh_shape[0]: 1 mesh_shape[1]: -1 mesh_shape[2]: 1 diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-tiktoken-single-host.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-tiktoken-single-host.txt index 10ddc09e1..7d074f037 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-tiktoken-single-host.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-tiktoken-single-host.txt @@ -122,6 +122,101 @@ mesh_axis_names[2]: 'expert' mesh_axis_names[3]: 'fsdp' mesh_axis_names[4]: 'seq' mesh_axis_names[5]: 'model' +mesh_rules[0][0]: 'neuron-(trn2|trn2n).48xlarge-64' +mesh_rules[0][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier' +mesh_rules[0][1].config_modifiers[0].mesh_shape[0]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[1]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[2]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[3]: -1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[4]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[5]: 4 +mesh_rules[0][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier' +mesh_rules[0][1].config_modifiers[1].modification.klass: 'axlearn.common.attention.StackedTransformerLayer' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.activation: 'nn.relu' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear1.bias: True +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear1.param_partition_spec[0]: None +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear1.param_partition_spec[1]: 'model' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear2.bias: True +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear2.param_partition_spec[0]: 'model' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear2.param_partition_spec[1]: None +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.norm.eps: 1e-08 +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.residual_weight: 1.0 +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.stochastic_depth.mode: 'row' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.structure: 'prenorm' +mesh_rules[0][1].config_modifiers[1].modification.layer.klass: 'axlearn.common.attention.TransformerLayer' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.QKVLinear' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.bias: True +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[0]: None +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[1]: 'model' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[2]: None +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.klass: 'axlearn.common.attention.MultiheadAttention' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.bias: True +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[0]: None +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[2]: None +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.norm.eps: 1e-08 +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.stochastic_depth.mode: 'row' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.structure: 'prenorm' +mesh_rules[0][1].config_modifiers[1].target_config: 'model.decoder.transformer' +mesh_rules[0][1].config_modifiers[2].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier' +mesh_rules[0][1].config_modifiers[2].modification.klass: 'axlearn.common.attention.GroupedQKVLinear' +mesh_rules[0][1].config_modifiers[2].modification.layer.bias: True +mesh_rules[0][1].config_modifiers[2].modification.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[0][1].config_modifiers[2].modification.layer.param_partition_spec[0]: None +mesh_rules[0][1].config_modifiers[2].modification.layer.param_partition_spec[1]: 'model' +mesh_rules[0][1].config_modifiers[2].modification.layer.param_partition_spec[2]: None +mesh_rules[0][1].config_modifiers[2].target_config: 'model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear' +mesh_rules[0][1].config_modifiers[3].klass: 'axlearn.common.trainer_config_modifier.PartitionSpecModifier' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][0]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][0]: 'expert' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][1]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][2]: 'seq' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['input_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['input_partition_spec'][1]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['output_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['output_partition_spec'][1]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['embedding_partition_spec'][0]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['embedding_partition_spec'][1]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][0]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][0]: 'expert' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][1]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][2]: 'seq' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][1]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][2]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][1]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][2]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][1]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][2]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][1]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][2]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][1]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][2]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][1]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][2]: None +mesh_rules[0][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' mesh_shape[0]: 1 mesh_shape[1]: -1 mesh_shape[2]: 1 diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-tiktoken.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-tiktoken.txt index 33082cd80..d9278aeee 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-tiktoken.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-tiktoken.txt @@ -122,6 +122,101 @@ mesh_axis_names[2]: 'expert' mesh_axis_names[3]: 'fsdp' mesh_axis_names[4]: 'seq' mesh_axis_names[5]: 'model' +mesh_rules[0][0]: 'neuron-(trn2|trn2n).48xlarge-64' +mesh_rules[0][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier' +mesh_rules[0][1].config_modifiers[0].mesh_shape[0]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[1]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[2]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[3]: -1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[4]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[5]: 4 +mesh_rules[0][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier' +mesh_rules[0][1].config_modifiers[1].modification.klass: 'axlearn.common.attention.StackedTransformerLayer' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.activation: 'nn.relu' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear1.bias: True +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear1.param_partition_spec[0]: None +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear1.param_partition_spec[1]: 'model' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear2.bias: True +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear2.param_partition_spec[0]: 'model' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear2.param_partition_spec[1]: None +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.norm.eps: 1e-08 +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.residual_weight: 1.0 +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.stochastic_depth.mode: 'row' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.structure: 'prenorm' +mesh_rules[0][1].config_modifiers[1].modification.layer.klass: 'axlearn.common.attention.TransformerLayer' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.QKVLinear' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.bias: True +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[0]: None +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[1]: 'model' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[2]: None +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.klass: 'axlearn.common.attention.MultiheadAttention' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.bias: True +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[0]: None +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[2]: None +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.norm.eps: 1e-08 +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.stochastic_depth.mode: 'row' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.structure: 'prenorm' +mesh_rules[0][1].config_modifiers[1].target_config: 'model.decoder.transformer' +mesh_rules[0][1].config_modifiers[2].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier' +mesh_rules[0][1].config_modifiers[2].modification.klass: 'axlearn.common.attention.GroupedQKVLinear' +mesh_rules[0][1].config_modifiers[2].modification.layer.bias: True +mesh_rules[0][1].config_modifiers[2].modification.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[0][1].config_modifiers[2].modification.layer.param_partition_spec[0]: None +mesh_rules[0][1].config_modifiers[2].modification.layer.param_partition_spec[1]: 'model' +mesh_rules[0][1].config_modifiers[2].modification.layer.param_partition_spec[2]: None +mesh_rules[0][1].config_modifiers[2].target_config: 'model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear' +mesh_rules[0][1].config_modifiers[3].klass: 'axlearn.common.trainer_config_modifier.PartitionSpecModifier' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][0]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][0]: 'expert' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][1]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][2]: 'seq' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['input_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['input_partition_spec'][1]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['output_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['output_partition_spec'][1]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['embedding_partition_spec'][0]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['embedding_partition_spec'][1]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][0]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][0]: 'expert' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][1]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][2]: 'seq' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][1]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][2]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][1]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][2]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][1]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][2]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][1]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][2]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][1]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][2]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][1]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][2]: None +mesh_rules[0][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' mesh_shape[0]: 1 mesh_shape[1]: -1 mesh_shape[2]: 1 diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3.txt index b069f70a1..857df2673 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3.txt @@ -122,6 +122,101 @@ mesh_axis_names[2]: 'expert' mesh_axis_names[3]: 'fsdp' mesh_axis_names[4]: 'seq' mesh_axis_names[5]: 'model' +mesh_rules[0][0]: 'neuron-(trn2|trn2n).48xlarge-64' +mesh_rules[0][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier' +mesh_rules[0][1].config_modifiers[0].mesh_shape[0]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[1]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[2]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[3]: -1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[4]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[5]: 4 +mesh_rules[0][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier' +mesh_rules[0][1].config_modifiers[1].modification.klass: 'axlearn.common.attention.StackedTransformerLayer' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.activation: 'nn.relu' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear1.bias: True +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear1.param_partition_spec[0]: None +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear1.param_partition_spec[1]: 'model' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear2.bias: True +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear2.param_partition_spec[0]: 'model' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.linear2.param_partition_spec[1]: None +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.norm.eps: 1e-08 +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.residual_weight: 1.0 +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.stochastic_depth.mode: 'row' +mesh_rules[0][1].config_modifiers[1].modification.layer.feed_forward.structure: 'prenorm' +mesh_rules[0][1].config_modifiers[1].modification.layer.klass: 'axlearn.common.attention.TransformerLayer' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.QKVLinear' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.bias: True +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[0]: None +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[1]: 'model' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[2]: None +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.klass: 'axlearn.common.attention.MultiheadAttention' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.bias: True +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[0]: None +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[2]: None +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.norm.eps: 1e-08 +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.stochastic_depth.mode: 'row' +mesh_rules[0][1].config_modifiers[1].modification.layer.self_attention.structure: 'prenorm' +mesh_rules[0][1].config_modifiers[1].target_config: 'model.decoder.transformer' +mesh_rules[0][1].config_modifiers[2].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier' +mesh_rules[0][1].config_modifiers[2].modification.klass: 'axlearn.common.attention.GroupedQKVLinear' +mesh_rules[0][1].config_modifiers[2].modification.layer.bias: True +mesh_rules[0][1].config_modifiers[2].modification.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[0][1].config_modifiers[2].modification.layer.param_partition_spec[0]: None +mesh_rules[0][1].config_modifiers[2].modification.layer.param_partition_spec[1]: 'model' +mesh_rules[0][1].config_modifiers[2].modification.layer.param_partition_spec[2]: None +mesh_rules[0][1].config_modifiers[2].target_config: 'model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear' +mesh_rules[0][1].config_modifiers[3].klass: 'axlearn.common.trainer_config_modifier.PartitionSpecModifier' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][0]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][0]: 'expert' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][1]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][2]: 'seq' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['input_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['input_partition_spec'][1]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['output_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['output_partition_spec'][1]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['embedding_partition_spec'][0]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['embedding_partition_spec'][1]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][0]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][0]: 'expert' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][1]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][2]: 'seq' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][1]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][2]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][1]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][2]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][1]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][2]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][1]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][2]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][1]: 'model' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][2]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][1]: None +mesh_rules[0][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][2]: None +mesh_rules[0][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' mesh_shape[0]: 1 mesh_shape[1]: -1 mesh_shape[2]: 1 diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v1-flash.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v1-flash.txt index 3f3a02811..f3b0c55bc 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v1-flash.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v1-flash.txt @@ -18,7 +18,7 @@ evalers['train'].eval_policy.max_step: 367001 evalers['train'].eval_policy.min_step: 1 evalers['train'].eval_policy.n: 5000 evalers['train'].input.batcher.fn: 'axlearn.common.input_tf_data.batch' -evalers['train'].input.batcher.global_batch_size: 2048 +evalers['train'].input.batcher.global_batch_size: 8 evalers['train'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' evalers['train'].input.batcher.prefetch_buffer_size: -1 evalers['train'].input.is_training: False @@ -45,7 +45,7 @@ evalers['validation'].eval_policy.max_step: 367001 evalers['validation'].eval_policy.min_step: 1 evalers['validation'].eval_policy.n: 5000 evalers['validation'].input.batcher.fn: 'axlearn.common.input_tf_data.batch' -evalers['validation'].input.batcher.global_batch_size: 2048 +evalers['validation'].input.batcher.global_batch_size: 8 evalers['validation'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' evalers['validation'].input.batcher.prefetch_buffer_size: -1 evalers['validation'].input.is_training: False @@ -67,7 +67,7 @@ evalers['validation'].metric_calculator.model_method: 'forward' evalers['validation'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' evalers['validation'].summary_writer.write_every_n_steps: 1 input.batcher.fn: 'axlearn.common.input_tf_data.batch' -input.batcher.global_batch_size: 2048 +input.batcher.global_batch_size: 8 input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' input.batcher.prefetch_buffer_size: -1 input.input_partitioner.fn: 'axlearn.common.input_base.partition_by_path_rank' @@ -193,6 +193,84 @@ mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.l mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_saved: '.*[kqv]_proj|.*linear1_[01]' mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: None mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: None +mesh_rules[4][1].config_modifiers[2].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier' +mesh_rules[4][1].config_modifiers[2].modification.klass: 'axlearn.common.attention.StackedTransformerLayer' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.activation: 'nn.relu' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.linear1.bias: True +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.linear1.param_partition_spec[0]: None +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.linear1.param_partition_spec[1]: 'model' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.linear2.bias: True +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.linear2.param_partition_spec[0]: 'model' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.linear2.param_partition_spec[1]: None +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.norm.eps: 1e-08 +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.residual_weight: 1.0 +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.stochastic_depth.mode: 'row' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.structure: 'prenorm' +mesh_rules[4][1].config_modifiers[2].modification.layer.klass: 'axlearn.common.attention.TransformerLayer' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.QKVLinear' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.input_linear.layer.bias: True +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[0]: None +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[1]: 'model' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[2]: None +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.klass: 'axlearn.common.attention.MultiheadAttention' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.output_linear.bias: True +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.output_linear.param_partition_spec[0]: None +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.output_linear.param_partition_spec[2]: None +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.norm.eps: 1e-08 +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.stochastic_depth.mode: 'row' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.structure: 'prenorm' +mesh_rules[4][1].config_modifiers[2].target_config: 'model.decoder.transformer' +mesh_rules[4][1].config_modifiers[3].klass: 'axlearn.common.trainer_config_modifier.PartitionSpecModifier' +mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][0]: 'model' +mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][0]: 'expert' +mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][1]: 'fsdp' +mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][2]: 'seq' +mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['input_partition_spec'][0]: 'fsdp' +mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['input_partition_spec'][1]: None +mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['output_partition_spec'][0]: 'fsdp' +mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['output_partition_spec'][1]: 'model' +mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['embedding_partition_spec'][0]: 'model' +mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['embedding_partition_spec'][1]: 'fsdp' +mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][0]: 'model' +mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][0]: 'expert' +mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][1]: 'fsdp' +mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][2]: 'seq' +mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][1]: 'model' +mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][2]: None +mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][1]: None +mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][2]: None +mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][1]: 'model' +mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][2]: None +mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][1]: None +mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][2]: None +mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][1]: 'model' +mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][2]: None +mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][1]: None +mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][2]: None mesh_rules[4][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' mesh_shape[0]: 1 mesh_shape[1]: 1 diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v1.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v1.txt index e5d388a31..f287eb939 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v1.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v1.txt @@ -18,7 +18,7 @@ evalers['train'].eval_policy.max_step: 367001 evalers['train'].eval_policy.min_step: 1 evalers['train'].eval_policy.n: 5000 evalers['train'].input.batcher.fn: 'axlearn.common.input_tf_data.batch' -evalers['train'].input.batcher.global_batch_size: 2048 +evalers['train'].input.batcher.global_batch_size: 8 evalers['train'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' evalers['train'].input.batcher.prefetch_buffer_size: -1 evalers['train'].input.is_training: False @@ -45,7 +45,7 @@ evalers['validation'].eval_policy.max_step: 367001 evalers['validation'].eval_policy.min_step: 1 evalers['validation'].eval_policy.n: 5000 evalers['validation'].input.batcher.fn: 'axlearn.common.input_tf_data.batch' -evalers['validation'].input.batcher.global_batch_size: 2048 +evalers['validation'].input.batcher.global_batch_size: 8 evalers['validation'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' evalers['validation'].input.batcher.prefetch_buffer_size: -1 evalers['validation'].input.is_training: False @@ -67,7 +67,7 @@ evalers['validation'].metric_calculator.model_method: 'forward' evalers['validation'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' evalers['validation'].summary_writer.write_every_n_steps: 1 input.batcher.fn: 'axlearn.common.input_tf_data.batch' -input.batcher.global_batch_size: 2048 +input.batcher.global_batch_size: 8 input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' input.batcher.prefetch_buffer_size: -1 input.input_partitioner.fn: 'axlearn.common.input_base.partition_by_path_rank' @@ -193,6 +193,84 @@ mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.l mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_saved: '.*[kqv]_proj|.*linear1_[01]' mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: None mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: None +mesh_rules[4][1].config_modifiers[2].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier' +mesh_rules[4][1].config_modifiers[2].modification.klass: 'axlearn.common.attention.StackedTransformerLayer' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.activation: 'nn.relu' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.linear1.bias: True +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.linear1.param_partition_spec[0]: None +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.linear1.param_partition_spec[1]: 'model' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.linear2.bias: True +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.linear2.param_partition_spec[0]: 'model' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.linear2.param_partition_spec[1]: None +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.norm.eps: 1e-08 +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.residual_weight: 1.0 +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.stochastic_depth.mode: 'row' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.structure: 'prenorm' +mesh_rules[4][1].config_modifiers[2].modification.layer.klass: 'axlearn.common.attention.TransformerLayer' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.QKVLinear' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.input_linear.layer.bias: True +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[0]: None +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[1]: 'model' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[2]: None +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.klass: 'axlearn.common.attention.MultiheadAttention' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.output_linear.bias: True +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.output_linear.param_partition_spec[0]: None +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.output_linear.param_partition_spec[2]: None +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.norm.eps: 1e-08 +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.stochastic_depth.mode: 'row' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.structure: 'prenorm' +mesh_rules[4][1].config_modifiers[2].target_config: 'model.decoder.transformer' +mesh_rules[4][1].config_modifiers[3].klass: 'axlearn.common.trainer_config_modifier.PartitionSpecModifier' +mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][0]: 'model' +mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][0]: 'expert' +mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][1]: 'fsdp' +mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][2]: 'seq' +mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['input_partition_spec'][0]: 'fsdp' +mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['input_partition_spec'][1]: None +mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['output_partition_spec'][0]: 'fsdp' +mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['output_partition_spec'][1]: 'model' +mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['embedding_partition_spec'][0]: 'model' +mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['embedding_partition_spec'][1]: 'fsdp' +mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][0]: 'model' +mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][0]: 'expert' +mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][1]: 'fsdp' +mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][2]: 'seq' +mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][1]: 'model' +mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][2]: None +mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][1]: None +mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][2]: None +mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][1]: 'model' +mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][2]: None +mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][1]: None +mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][2]: None +mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][1]: 'model' +mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][2]: None +mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][1]: None +mesh_rules[4][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][2]: None mesh_rules[4][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' mesh_shape[0]: 1 mesh_shape[1]: 1 diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v2-flash.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v2-flash.txt index 1c8dc6844..ab8c0e563 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v2-flash.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v2-flash.txt @@ -18,7 +18,7 @@ evalers['train'].eval_policy.max_step: 524288 evalers['train'].eval_policy.min_step: 1 evalers['train'].eval_policy.n: 5000 evalers['train'].input.batcher.fn: 'axlearn.common.input_tf_data.batch' -evalers['train'].input.batcher.global_batch_size: 1024 +evalers['train'].input.batcher.global_batch_size: 8 evalers['train'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' evalers['train'].input.batcher.prefetch_buffer_size: -1 evalers['train'].input.is_training: False @@ -45,7 +45,7 @@ evalers['validation'].eval_policy.max_step: 524288 evalers['validation'].eval_policy.min_step: 1 evalers['validation'].eval_policy.n: 5000 evalers['validation'].input.batcher.fn: 'axlearn.common.input_tf_data.batch' -evalers['validation'].input.batcher.global_batch_size: 1024 +evalers['validation'].input.batcher.global_batch_size: 8 evalers['validation'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' evalers['validation'].input.batcher.prefetch_buffer_size: -1 evalers['validation'].input.is_training: False @@ -67,7 +67,7 @@ evalers['validation'].metric_calculator.model_method: 'forward' evalers['validation'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' evalers['validation'].summary_writer.write_every_n_steps: 1 input.batcher.fn: 'axlearn.common.input_tf_data.batch' -input.batcher.global_batch_size: 1024 +input.batcher.global_batch_size: 8 input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' input.batcher.prefetch_buffer_size: -1 input.input_partitioner.fn: 'axlearn.common.input_base.partition_by_path_rank' @@ -193,6 +193,92 @@ mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.l mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_saved: '.*[kqv]_proj|.*linear1_[01]' mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: None mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: None +mesh_rules[4][1].config_modifiers[2].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier' +mesh_rules[4][1].config_modifiers[2].modification.klass: 'axlearn.common.attention.StackedTransformerLayer' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.activation: 'nn.relu' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.linear1.bias: True +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.linear1.param_partition_spec[0]: None +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.linear1.param_partition_spec[1]: 'model' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.linear2.bias: True +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.linear2.param_partition_spec[0]: 'model' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.linear2.param_partition_spec[1]: None +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.norm.eps: 1e-08 +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.residual_weight: 1.0 +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.stochastic_depth.mode: 'row' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.structure: 'prenorm' +mesh_rules[4][1].config_modifiers[2].modification.layer.klass: 'axlearn.common.attention.TransformerLayer' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.QKVLinear' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.input_linear.layer.bias: True +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[0]: None +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[1]: 'model' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[2]: None +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.klass: 'axlearn.common.attention.MultiheadAttention' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.output_linear.bias: True +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.output_linear.param_partition_spec[0]: None +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.output_linear.param_partition_spec[2]: None +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.norm.eps: 1e-08 +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.stochastic_depth.mode: 'row' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.structure: 'prenorm' +mesh_rules[4][1].config_modifiers[2].target_config: 'model.decoder.transformer' +mesh_rules[4][1].config_modifiers[3].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier' +mesh_rules[4][1].config_modifiers[3].modification.klass: 'axlearn.common.attention.GroupedQKVLinear' +mesh_rules[4][1].config_modifiers[3].modification.layer.bias: True +mesh_rules[4][1].config_modifiers[3].modification.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[4][1].config_modifiers[3].modification.layer.param_partition_spec[0]: None +mesh_rules[4][1].config_modifiers[3].modification.layer.param_partition_spec[1]: 'model' +mesh_rules[4][1].config_modifiers[3].modification.layer.param_partition_spec[2]: None +mesh_rules[4][1].config_modifiers[3].target_config: 'model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear' +mesh_rules[4][1].config_modifiers[4].klass: 'axlearn.common.trainer_config_modifier.PartitionSpecModifier' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][0]: 'model' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][0]: 'expert' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][1]: 'fsdp' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][2]: 'seq' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.emb.token_emb']['input_partition_spec'][0]: 'fsdp' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.emb.token_emb']['input_partition_spec'][1]: None +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.emb.token_emb']['output_partition_spec'][0]: 'fsdp' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.emb.token_emb']['output_partition_spec'][1]: 'model' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.emb.token_emb']['embedding_partition_spec'][0]: 'model' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.emb.token_emb']['embedding_partition_spec'][1]: 'fsdp' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.lm_head']['param_partition_spec'][0]: 'model' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][0]: 'expert' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][1]: 'fsdp' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][2]: 'seq' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][1]: 'model' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][2]: None +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][1]: None +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][2]: None +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][1]: 'model' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][2]: None +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][1]: None +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][2]: None +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.output_norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.output_norm']['input_partition_spec'][1]: 'model' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.output_norm']['input_partition_spec'][2]: None +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.output_norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.output_norm']['output_partition_spec'][1]: None +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.output_norm']['output_partition_spec'][2]: None mesh_rules[4][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' mesh_shape[0]: 1 mesh_shape[1]: 1 diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v2.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v2.txt index 02b00035f..916a6ac29 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v2.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v2.txt @@ -18,7 +18,7 @@ evalers['train'].eval_policy.max_step: 524288 evalers['train'].eval_policy.min_step: 1 evalers['train'].eval_policy.n: 5000 evalers['train'].input.batcher.fn: 'axlearn.common.input_tf_data.batch' -evalers['train'].input.batcher.global_batch_size: 1024 +evalers['train'].input.batcher.global_batch_size: 8 evalers['train'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' evalers['train'].input.batcher.prefetch_buffer_size: -1 evalers['train'].input.is_training: False @@ -45,7 +45,7 @@ evalers['validation'].eval_policy.max_step: 524288 evalers['validation'].eval_policy.min_step: 1 evalers['validation'].eval_policy.n: 5000 evalers['validation'].input.batcher.fn: 'axlearn.common.input_tf_data.batch' -evalers['validation'].input.batcher.global_batch_size: 1024 +evalers['validation'].input.batcher.global_batch_size: 8 evalers['validation'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' evalers['validation'].input.batcher.prefetch_buffer_size: -1 evalers['validation'].input.is_training: False @@ -67,7 +67,7 @@ evalers['validation'].metric_calculator.model_method: 'forward' evalers['validation'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' evalers['validation'].summary_writer.write_every_n_steps: 1 input.batcher.fn: 'axlearn.common.input_tf_data.batch' -input.batcher.global_batch_size: 1024 +input.batcher.global_batch_size: 8 input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' input.batcher.prefetch_buffer_size: -1 input.input_partitioner.fn: 'axlearn.common.input_base.partition_by_path_rank' @@ -193,6 +193,92 @@ mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.l mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_saved: '.*[kqv]_proj|.*linear1_[01]' mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: None mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: None +mesh_rules[4][1].config_modifiers[2].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier' +mesh_rules[4][1].config_modifiers[2].modification.klass: 'axlearn.common.attention.StackedTransformerLayer' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.activation: 'nn.relu' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.linear1.bias: True +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.linear1.param_partition_spec[0]: None +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.linear1.param_partition_spec[1]: 'model' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.linear2.bias: True +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.linear2.param_partition_spec[0]: 'model' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.linear2.param_partition_spec[1]: None +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.norm.eps: 1e-08 +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.residual_weight: 1.0 +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.stochastic_depth.mode: 'row' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.structure: 'prenorm' +mesh_rules[4][1].config_modifiers[2].modification.layer.klass: 'axlearn.common.attention.TransformerLayer' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.QKVLinear' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.input_linear.layer.bias: True +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[0]: None +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[1]: 'model' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[2]: None +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.klass: 'axlearn.common.attention.MultiheadAttention' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.output_linear.bias: True +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.output_linear.param_partition_spec[0]: None +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.output_linear.param_partition_spec[2]: None +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.norm.eps: 1e-08 +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.stochastic_depth.mode: 'row' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.structure: 'prenorm' +mesh_rules[4][1].config_modifiers[2].target_config: 'model.decoder.transformer' +mesh_rules[4][1].config_modifiers[3].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier' +mesh_rules[4][1].config_modifiers[3].modification.klass: 'axlearn.common.attention.GroupedQKVLinear' +mesh_rules[4][1].config_modifiers[3].modification.layer.bias: True +mesh_rules[4][1].config_modifiers[3].modification.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[4][1].config_modifiers[3].modification.layer.param_partition_spec[0]: None +mesh_rules[4][1].config_modifiers[3].modification.layer.param_partition_spec[1]: 'model' +mesh_rules[4][1].config_modifiers[3].modification.layer.param_partition_spec[2]: None +mesh_rules[4][1].config_modifiers[3].target_config: 'model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear' +mesh_rules[4][1].config_modifiers[4].klass: 'axlearn.common.trainer_config_modifier.PartitionSpecModifier' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][0]: 'model' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][0]: 'expert' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][1]: 'fsdp' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][2]: 'seq' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.emb.token_emb']['input_partition_spec'][0]: 'fsdp' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.emb.token_emb']['input_partition_spec'][1]: None +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.emb.token_emb']['output_partition_spec'][0]: 'fsdp' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.emb.token_emb']['output_partition_spec'][1]: 'model' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.emb.token_emb']['embedding_partition_spec'][0]: 'model' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.emb.token_emb']['embedding_partition_spec'][1]: 'fsdp' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.lm_head']['param_partition_spec'][0]: 'model' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][0]: 'expert' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][1]: 'fsdp' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][2]: 'seq' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][1]: 'model' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][2]: None +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][1]: None +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][2]: None +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][1]: 'model' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][2]: None +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][1]: None +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][2]: None +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.output_norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.output_norm']['input_partition_spec'][1]: 'model' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.output_norm']['input_partition_spec'][2]: None +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.output_norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.output_norm']['output_partition_spec'][1]: None +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.output_norm']['output_partition_spec'][2]: None mesh_rules[4][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' mesh_shape[0]: 1 mesh_shape[1]: 1 diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-flash.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-flash.txt index 14530bb04..adc730c58 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-flash.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-flash.txt @@ -18,7 +18,7 @@ evalers['train'].eval_policy.max_step: 983040 evalers['train'].eval_policy.min_step: 1 evalers['train'].eval_policy.n: 5000 evalers['train'].input.batcher.fn: 'axlearn.common.input_tf_data.batch' -evalers['train'].input.batcher.global_batch_size: 2048 +evalers['train'].input.batcher.global_batch_size: 8 evalers['train'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' evalers['train'].input.batcher.prefetch_buffer_size: -1 evalers['train'].input.is_training: False @@ -45,7 +45,7 @@ evalers['validation'].eval_policy.max_step: 983040 evalers['validation'].eval_policy.min_step: 1 evalers['validation'].eval_policy.n: 5000 evalers['validation'].input.batcher.fn: 'axlearn.common.input_tf_data.batch' -evalers['validation'].input.batcher.global_batch_size: 2048 +evalers['validation'].input.batcher.global_batch_size: 8 evalers['validation'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' evalers['validation'].input.batcher.prefetch_buffer_size: -1 evalers['validation'].input.is_training: False @@ -67,7 +67,7 @@ evalers['validation'].metric_calculator.model_method: 'forward' evalers['validation'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' evalers['validation'].summary_writer.write_every_n_steps: 1 input.batcher.fn: 'axlearn.common.input_tf_data.batch' -input.batcher.global_batch_size: 2048 +input.batcher.global_batch_size: 8 input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' input.batcher.prefetch_buffer_size: -1 input.input_partitioner.fn: 'axlearn.common.input_base.partition_by_path_rank' @@ -193,6 +193,92 @@ mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.l mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_saved: '.*[kqv]_proj|.*linear1_[01]' mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: None mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: None +mesh_rules[4][1].config_modifiers[2].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier' +mesh_rules[4][1].config_modifiers[2].modification.klass: 'axlearn.common.attention.StackedTransformerLayer' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.activation: 'nn.relu' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.linear1.bias: True +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.linear1.param_partition_spec[0]: None +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.linear1.param_partition_spec[1]: 'model' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.linear2.bias: True +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.linear2.param_partition_spec[0]: 'model' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.linear2.param_partition_spec[1]: None +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.norm.eps: 1e-08 +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.residual_weight: 1.0 +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.stochastic_depth.mode: 'row' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.structure: 'prenorm' +mesh_rules[4][1].config_modifiers[2].modification.layer.klass: 'axlearn.common.attention.TransformerLayer' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.QKVLinear' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.input_linear.layer.bias: True +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[0]: None +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[1]: 'model' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[2]: None +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.klass: 'axlearn.common.attention.MultiheadAttention' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.output_linear.bias: True +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.output_linear.param_partition_spec[0]: None +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.output_linear.param_partition_spec[2]: None +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.norm.eps: 1e-08 +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.stochastic_depth.mode: 'row' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.structure: 'prenorm' +mesh_rules[4][1].config_modifiers[2].target_config: 'model.decoder.transformer' +mesh_rules[4][1].config_modifiers[3].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier' +mesh_rules[4][1].config_modifiers[3].modification.klass: 'axlearn.common.attention.GroupedQKVLinear' +mesh_rules[4][1].config_modifiers[3].modification.layer.bias: True +mesh_rules[4][1].config_modifiers[3].modification.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[4][1].config_modifiers[3].modification.layer.param_partition_spec[0]: None +mesh_rules[4][1].config_modifiers[3].modification.layer.param_partition_spec[1]: 'model' +mesh_rules[4][1].config_modifiers[3].modification.layer.param_partition_spec[2]: None +mesh_rules[4][1].config_modifiers[3].target_config: 'model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear' +mesh_rules[4][1].config_modifiers[4].klass: 'axlearn.common.trainer_config_modifier.PartitionSpecModifier' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][0]: 'model' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][0]: 'expert' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][1]: 'fsdp' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][2]: 'seq' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.emb.token_emb']['input_partition_spec'][0]: 'fsdp' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.emb.token_emb']['input_partition_spec'][1]: None +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.emb.token_emb']['output_partition_spec'][0]: 'fsdp' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.emb.token_emb']['output_partition_spec'][1]: 'model' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.emb.token_emb']['embedding_partition_spec'][0]: 'model' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.emb.token_emb']['embedding_partition_spec'][1]: 'fsdp' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.lm_head']['param_partition_spec'][0]: 'model' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][0]: 'expert' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][1]: 'fsdp' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][2]: 'seq' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][1]: 'model' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][2]: None +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][1]: None +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][2]: None +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][1]: 'model' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][2]: None +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][1]: None +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][2]: None +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.output_norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.output_norm']['input_partition_spec'][1]: 'model' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.output_norm']['input_partition_spec'][2]: None +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.output_norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.output_norm']['output_partition_spec'][1]: None +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.output_norm']['output_partition_spec'][2]: None mesh_rules[4][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' mesh_shape[0]: 1 mesh_shape[1]: 1 diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-tiktoken-flash.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-tiktoken-flash.txt index 3ffbaf7be..253b94df0 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-tiktoken-flash.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-tiktoken-flash.txt @@ -18,7 +18,7 @@ evalers['train'].eval_policy.max_step: 983040 evalers['train'].eval_policy.min_step: 1 evalers['train'].eval_policy.n: 5000 evalers['train'].input.batcher.fn: 'axlearn.common.input_tf_data.batch' -evalers['train'].input.batcher.global_batch_size: 2048 +evalers['train'].input.batcher.global_batch_size: 8 evalers['train'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' evalers['train'].input.batcher.prefetch_buffer_size: -1 evalers['train'].input.is_training: False @@ -45,7 +45,7 @@ evalers['validation'].eval_policy.max_step: 983040 evalers['validation'].eval_policy.min_step: 1 evalers['validation'].eval_policy.n: 5000 evalers['validation'].input.batcher.fn: 'axlearn.common.input_tf_data.batch' -evalers['validation'].input.batcher.global_batch_size: 2048 +evalers['validation'].input.batcher.global_batch_size: 8 evalers['validation'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' evalers['validation'].input.batcher.prefetch_buffer_size: -1 evalers['validation'].input.is_training: False @@ -67,7 +67,7 @@ evalers['validation'].metric_calculator.model_method: 'forward' evalers['validation'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' evalers['validation'].summary_writer.write_every_n_steps: 1 input.batcher.fn: 'axlearn.common.input_tf_data.batch' -input.batcher.global_batch_size: 2048 +input.batcher.global_batch_size: 8 input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' input.batcher.prefetch_buffer_size: -1 input.input_partitioner.fn: 'axlearn.common.input_base.partition_by_path_rank' @@ -193,6 +193,92 @@ mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.l mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_saved: '.*[kqv]_proj|.*linear1_[01]' mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: None mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: None +mesh_rules[4][1].config_modifiers[2].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier' +mesh_rules[4][1].config_modifiers[2].modification.klass: 'axlearn.common.attention.StackedTransformerLayer' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.activation: 'nn.relu' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.linear1.bias: True +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.linear1.param_partition_spec[0]: None +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.linear1.param_partition_spec[1]: 'model' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.linear2.bias: True +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.linear2.param_partition_spec[0]: 'model' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.linear2.param_partition_spec[1]: None +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.norm.eps: 1e-08 +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.residual_weight: 1.0 +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.stochastic_depth.mode: 'row' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.structure: 'prenorm' +mesh_rules[4][1].config_modifiers[2].modification.layer.klass: 'axlearn.common.attention.TransformerLayer' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.QKVLinear' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.input_linear.layer.bias: True +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[0]: None +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[1]: 'model' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[2]: None +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.klass: 'axlearn.common.attention.MultiheadAttention' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.output_linear.bias: True +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.output_linear.param_partition_spec[0]: None +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.output_linear.param_partition_spec[2]: None +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.norm.eps: 1e-08 +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.stochastic_depth.mode: 'row' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.structure: 'prenorm' +mesh_rules[4][1].config_modifiers[2].target_config: 'model.decoder.transformer' +mesh_rules[4][1].config_modifiers[3].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier' +mesh_rules[4][1].config_modifiers[3].modification.klass: 'axlearn.common.attention.GroupedQKVLinear' +mesh_rules[4][1].config_modifiers[3].modification.layer.bias: True +mesh_rules[4][1].config_modifiers[3].modification.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[4][1].config_modifiers[3].modification.layer.param_partition_spec[0]: None +mesh_rules[4][1].config_modifiers[3].modification.layer.param_partition_spec[1]: 'model' +mesh_rules[4][1].config_modifiers[3].modification.layer.param_partition_spec[2]: None +mesh_rules[4][1].config_modifiers[3].target_config: 'model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear' +mesh_rules[4][1].config_modifiers[4].klass: 'axlearn.common.trainer_config_modifier.PartitionSpecModifier' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][0]: 'model' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][0]: 'expert' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][1]: 'fsdp' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][2]: 'seq' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.emb.token_emb']['input_partition_spec'][0]: 'fsdp' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.emb.token_emb']['input_partition_spec'][1]: None +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.emb.token_emb']['output_partition_spec'][0]: 'fsdp' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.emb.token_emb']['output_partition_spec'][1]: 'model' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.emb.token_emb']['embedding_partition_spec'][0]: 'model' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.emb.token_emb']['embedding_partition_spec'][1]: 'fsdp' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.lm_head']['param_partition_spec'][0]: 'model' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][0]: 'expert' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][1]: 'fsdp' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][2]: 'seq' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][1]: 'model' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][2]: None +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][1]: None +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][2]: None +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][1]: 'model' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][2]: None +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][1]: None +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][2]: None +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.output_norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.output_norm']['input_partition_spec'][1]: 'model' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.output_norm']['input_partition_spec'][2]: None +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.output_norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.output_norm']['output_partition_spec'][1]: None +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.output_norm']['output_partition_spec'][2]: None mesh_rules[4][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' mesh_shape[0]: 1 mesh_shape[1]: 1 diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-tiktoken.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-tiktoken.txt index 37676eb4b..eadeb917e 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-tiktoken.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-tiktoken.txt @@ -18,7 +18,7 @@ evalers['train'].eval_policy.max_step: 983040 evalers['train'].eval_policy.min_step: 1 evalers['train'].eval_policy.n: 5000 evalers['train'].input.batcher.fn: 'axlearn.common.input_tf_data.batch' -evalers['train'].input.batcher.global_batch_size: 2048 +evalers['train'].input.batcher.global_batch_size: 8 evalers['train'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' evalers['train'].input.batcher.prefetch_buffer_size: -1 evalers['train'].input.is_training: False @@ -45,7 +45,7 @@ evalers['validation'].eval_policy.max_step: 983040 evalers['validation'].eval_policy.min_step: 1 evalers['validation'].eval_policy.n: 5000 evalers['validation'].input.batcher.fn: 'axlearn.common.input_tf_data.batch' -evalers['validation'].input.batcher.global_batch_size: 2048 +evalers['validation'].input.batcher.global_batch_size: 8 evalers['validation'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' evalers['validation'].input.batcher.prefetch_buffer_size: -1 evalers['validation'].input.is_training: False @@ -67,7 +67,7 @@ evalers['validation'].metric_calculator.model_method: 'forward' evalers['validation'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' evalers['validation'].summary_writer.write_every_n_steps: 1 input.batcher.fn: 'axlearn.common.input_tf_data.batch' -input.batcher.global_batch_size: 2048 +input.batcher.global_batch_size: 8 input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' input.batcher.prefetch_buffer_size: -1 input.input_partitioner.fn: 'axlearn.common.input_base.partition_by_path_rank' @@ -193,6 +193,92 @@ mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.l mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_saved: '.*[kqv]_proj|.*linear1_[01]' mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: None mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: None +mesh_rules[4][1].config_modifiers[2].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier' +mesh_rules[4][1].config_modifiers[2].modification.klass: 'axlearn.common.attention.StackedTransformerLayer' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.activation: 'nn.relu' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.linear1.bias: True +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.linear1.param_partition_spec[0]: None +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.linear1.param_partition_spec[1]: 'model' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.linear2.bias: True +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.linear2.param_partition_spec[0]: 'model' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.linear2.param_partition_spec[1]: None +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.norm.eps: 1e-08 +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.residual_weight: 1.0 +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.stochastic_depth.mode: 'row' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.structure: 'prenorm' +mesh_rules[4][1].config_modifiers[2].modification.layer.klass: 'axlearn.common.attention.TransformerLayer' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.QKVLinear' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.input_linear.layer.bias: True +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[0]: None +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[1]: 'model' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[2]: None +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.klass: 'axlearn.common.attention.MultiheadAttention' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.output_linear.bias: True +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.output_linear.param_partition_spec[0]: None +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.output_linear.param_partition_spec[2]: None +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.norm.eps: 1e-08 +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.stochastic_depth.mode: 'row' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.structure: 'prenorm' +mesh_rules[4][1].config_modifiers[2].target_config: 'model.decoder.transformer' +mesh_rules[4][1].config_modifiers[3].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier' +mesh_rules[4][1].config_modifiers[3].modification.klass: 'axlearn.common.attention.GroupedQKVLinear' +mesh_rules[4][1].config_modifiers[3].modification.layer.bias: True +mesh_rules[4][1].config_modifiers[3].modification.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[4][1].config_modifiers[3].modification.layer.param_partition_spec[0]: None +mesh_rules[4][1].config_modifiers[3].modification.layer.param_partition_spec[1]: 'model' +mesh_rules[4][1].config_modifiers[3].modification.layer.param_partition_spec[2]: None +mesh_rules[4][1].config_modifiers[3].target_config: 'model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear' +mesh_rules[4][1].config_modifiers[4].klass: 'axlearn.common.trainer_config_modifier.PartitionSpecModifier' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][0]: 'model' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][0]: 'expert' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][1]: 'fsdp' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][2]: 'seq' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.emb.token_emb']['input_partition_spec'][0]: 'fsdp' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.emb.token_emb']['input_partition_spec'][1]: None +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.emb.token_emb']['output_partition_spec'][0]: 'fsdp' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.emb.token_emb']['output_partition_spec'][1]: 'model' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.emb.token_emb']['embedding_partition_spec'][0]: 'model' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.emb.token_emb']['embedding_partition_spec'][1]: 'fsdp' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.lm_head']['param_partition_spec'][0]: 'model' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][0]: 'expert' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][1]: 'fsdp' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][2]: 'seq' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][1]: 'model' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][2]: None +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][1]: None +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][2]: None +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][1]: 'model' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][2]: None +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][1]: None +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][2]: None +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.output_norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.output_norm']['input_partition_spec'][1]: 'model' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.output_norm']['input_partition_spec'][2]: None +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.output_norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.output_norm']['output_partition_spec'][1]: None +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.output_norm']['output_partition_spec'][2]: None mesh_rules[4][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' mesh_shape[0]: 1 mesh_shape[1]: 1 diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3.txt index 8efdf10e2..21ba046cf 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3.txt @@ -18,7 +18,7 @@ evalers['train'].eval_policy.max_step: 983040 evalers['train'].eval_policy.min_step: 1 evalers['train'].eval_policy.n: 5000 evalers['train'].input.batcher.fn: 'axlearn.common.input_tf_data.batch' -evalers['train'].input.batcher.global_batch_size: 2048 +evalers['train'].input.batcher.global_batch_size: 8 evalers['train'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' evalers['train'].input.batcher.prefetch_buffer_size: -1 evalers['train'].input.is_training: False @@ -45,7 +45,7 @@ evalers['validation'].eval_policy.max_step: 983040 evalers['validation'].eval_policy.min_step: 1 evalers['validation'].eval_policy.n: 5000 evalers['validation'].input.batcher.fn: 'axlearn.common.input_tf_data.batch' -evalers['validation'].input.batcher.global_batch_size: 2048 +evalers['validation'].input.batcher.global_batch_size: 8 evalers['validation'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' evalers['validation'].input.batcher.prefetch_buffer_size: -1 evalers['validation'].input.is_training: False @@ -67,7 +67,7 @@ evalers['validation'].metric_calculator.model_method: 'forward' evalers['validation'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' evalers['validation'].summary_writer.write_every_n_steps: 1 input.batcher.fn: 'axlearn.common.input_tf_data.batch' -input.batcher.global_batch_size: 2048 +input.batcher.global_batch_size: 8 input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' input.batcher.prefetch_buffer_size: -1 input.input_partitioner.fn: 'axlearn.common.input_base.partition_by_path_rank' @@ -193,6 +193,92 @@ mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.l mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_saved: '.*[kqv]_proj|.*linear1_[01]' mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: None mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: None +mesh_rules[4][1].config_modifiers[2].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier' +mesh_rules[4][1].config_modifiers[2].modification.klass: 'axlearn.common.attention.StackedTransformerLayer' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.activation: 'nn.relu' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.linear1.bias: True +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.linear1.param_partition_spec[0]: None +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.linear1.param_partition_spec[1]: 'model' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.linear2.bias: True +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.linear2.param_partition_spec[0]: 'model' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.linear2.param_partition_spec[1]: None +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.norm.eps: 1e-08 +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.residual_weight: 1.0 +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.stochastic_depth.mode: 'row' +mesh_rules[4][1].config_modifiers[2].modification.layer.feed_forward.structure: 'prenorm' +mesh_rules[4][1].config_modifiers[2].modification.layer.klass: 'axlearn.common.attention.TransformerLayer' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.QKVLinear' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.input_linear.layer.bias: True +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[0]: None +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[1]: 'model' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[2]: None +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.klass: 'axlearn.common.attention.MultiheadAttention' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.output_linear.bias: True +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.output_linear.param_partition_spec[0]: None +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.output_linear.param_partition_spec[2]: None +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.norm.eps: 1e-08 +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.stochastic_depth.mode: 'row' +mesh_rules[4][1].config_modifiers[2].modification.layer.self_attention.structure: 'prenorm' +mesh_rules[4][1].config_modifiers[2].target_config: 'model.decoder.transformer' +mesh_rules[4][1].config_modifiers[3].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier' +mesh_rules[4][1].config_modifiers[3].modification.klass: 'axlearn.common.attention.GroupedQKVLinear' +mesh_rules[4][1].config_modifiers[3].modification.layer.bias: True +mesh_rules[4][1].config_modifiers[3].modification.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[4][1].config_modifiers[3].modification.layer.param_partition_spec[0]: None +mesh_rules[4][1].config_modifiers[3].modification.layer.param_partition_spec[1]: 'model' +mesh_rules[4][1].config_modifiers[3].modification.layer.param_partition_spec[2]: None +mesh_rules[4][1].config_modifiers[3].target_config: 'model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear' +mesh_rules[4][1].config_modifiers[4].klass: 'axlearn.common.trainer_config_modifier.PartitionSpecModifier' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][0]: 'model' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][0]: 'expert' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][1]: 'fsdp' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][2]: 'seq' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.emb.token_emb']['input_partition_spec'][0]: 'fsdp' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.emb.token_emb']['input_partition_spec'][1]: None +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.emb.token_emb']['output_partition_spec'][0]: 'fsdp' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.emb.token_emb']['output_partition_spec'][1]: 'model' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.emb.token_emb']['embedding_partition_spec'][0]: 'model' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.emb.token_emb']['embedding_partition_spec'][1]: 'fsdp' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.lm_head']['param_partition_spec'][0]: 'model' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][0]: 'expert' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][1]: 'fsdp' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][2]: 'seq' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][1]: 'model' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][2]: None +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][1]: None +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][2]: None +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][1]: 'model' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][2]: None +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][1]: None +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][2]: None +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.output_norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.output_norm']['input_partition_spec'][1]: 'model' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.output_norm']['input_partition_spec'][2]: None +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.output_norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.output_norm']['output_partition_spec'][1]: None +mesh_rules[4][1].config_modifiers[4].partition_specs['model.decoder.output_norm']['output_partition_spec'][2]: None mesh_rules[4][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' mesh_shape[0]: 1 mesh_shape[1]: 1 diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v1-flash-single-host.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v1-flash-single-host.txt index 94c96a380..923359ff6 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v1-flash-single-host.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v1-flash-single-host.txt @@ -202,6 +202,93 @@ mesh_rules[6][1][2]: 1 mesh_rules[6][1][3]: 8 mesh_rules[6][1][4]: 1 mesh_rules[6][1][5]: 1 +mesh_rules[7][0]: 'neuron-(trn2|trn2n).48xlarge-64' +mesh_rules[7][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier' +mesh_rules[7][1].config_modifiers[0].mesh_shape[0]: 1 +mesh_rules[7][1].config_modifiers[0].mesh_shape[1]: 1 +mesh_rules[7][1].config_modifiers[0].mesh_shape[2]: 1 +mesh_rules[7][1].config_modifiers[0].mesh_shape[3]: -1 +mesh_rules[7][1].config_modifiers[0].mesh_shape[4]: 1 +mesh_rules[7][1].config_modifiers[0].mesh_shape[5]: 4 +mesh_rules[7][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier' +mesh_rules[7][1].config_modifiers[1].modification.klass: 'axlearn.common.attention.StackedTransformerLayer' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.activation: 'nn.relu' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear1.bias: True +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear1.param_partition_spec[0]: None +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear1.param_partition_spec[1]: 'model' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear2.bias: True +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear2.param_partition_spec[0]: 'model' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear2.param_partition_spec[1]: None +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.norm.eps: 1e-08 +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.residual_weight: 1.0 +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.stochastic_depth.mode: 'row' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.structure: 'prenorm' +mesh_rules[7][1].config_modifiers[1].modification.layer.klass: 'axlearn.common.attention.TransformerLayer' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.QKVLinear' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.bias: True +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[0]: None +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[1]: 'model' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[2]: None +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.klass: 'axlearn.common.attention.MultiheadAttention' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.bias: True +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[0]: None +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[2]: None +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.norm.eps: 1e-08 +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.stochastic_depth.mode: 'row' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.structure: 'prenorm' +mesh_rules[7][1].config_modifiers[1].target_config: 'model.decoder.transformer' +mesh_rules[7][1].config_modifiers[2].klass: 'axlearn.common.trainer_config_modifier.PartitionSpecModifier' +mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][0]: 'model' +mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][0]: 'expert' +mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][1]: 'fsdp' +mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][2]: 'seq' +mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.emb.token_emb']['input_partition_spec'][0]: 'fsdp' +mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.emb.token_emb']['input_partition_spec'][1]: None +mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.emb.token_emb']['output_partition_spec'][0]: 'fsdp' +mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.emb.token_emb']['output_partition_spec'][1]: 'model' +mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.emb.token_emb']['embedding_partition_spec'][0]: 'model' +mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.emb.token_emb']['embedding_partition_spec'][1]: 'fsdp' +mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.lm_head']['param_partition_spec'][0]: 'model' +mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][0]: 'expert' +mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][1]: 'fsdp' +mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][2]: 'seq' +mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][1]: 'model' +mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][2]: None +mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][1]: None +mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][2]: None +mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][1]: 'model' +mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][2]: None +mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][1]: None +mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][2]: None +mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.output_norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.output_norm']['input_partition_spec'][1]: 'model' +mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.output_norm']['input_partition_spec'][2]: None +mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.output_norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.output_norm']['output_partition_spec'][1]: None +mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.output_norm']['output_partition_spec'][2]: None +mesh_rules[7][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' mesh_shape[0]: 1 mesh_shape[1]: -1 mesh_shape[2]: 1 diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v1-flash.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v1-flash.txt index f72349af7..a12719f2f 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v1-flash.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v1-flash.txt @@ -202,6 +202,93 @@ mesh_rules[6][1][2]: 1 mesh_rules[6][1][3]: 8 mesh_rules[6][1][4]: 1 mesh_rules[6][1][5]: 1 +mesh_rules[7][0]: 'neuron-(trn2|trn2n).48xlarge-64' +mesh_rules[7][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier' +mesh_rules[7][1].config_modifiers[0].mesh_shape[0]: 1 +mesh_rules[7][1].config_modifiers[0].mesh_shape[1]: 1 +mesh_rules[7][1].config_modifiers[0].mesh_shape[2]: 1 +mesh_rules[7][1].config_modifiers[0].mesh_shape[3]: -1 +mesh_rules[7][1].config_modifiers[0].mesh_shape[4]: 1 +mesh_rules[7][1].config_modifiers[0].mesh_shape[5]: 4 +mesh_rules[7][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier' +mesh_rules[7][1].config_modifiers[1].modification.klass: 'axlearn.common.attention.StackedTransformerLayer' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.activation: 'nn.relu' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear1.bias: True +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear1.param_partition_spec[0]: None +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear1.param_partition_spec[1]: 'model' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear2.bias: True +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear2.param_partition_spec[0]: 'model' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear2.param_partition_spec[1]: None +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.norm.eps: 1e-08 +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.residual_weight: 1.0 +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.stochastic_depth.mode: 'row' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.structure: 'prenorm' +mesh_rules[7][1].config_modifiers[1].modification.layer.klass: 'axlearn.common.attention.TransformerLayer' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.QKVLinear' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.bias: True +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[0]: None +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[1]: 'model' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[2]: None +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.klass: 'axlearn.common.attention.MultiheadAttention' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.bias: True +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[0]: None +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[2]: None +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.norm.eps: 1e-08 +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.stochastic_depth.mode: 'row' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.structure: 'prenorm' +mesh_rules[7][1].config_modifiers[1].target_config: 'model.decoder.transformer' +mesh_rules[7][1].config_modifiers[2].klass: 'axlearn.common.trainer_config_modifier.PartitionSpecModifier' +mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][0]: 'model' +mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][0]: 'expert' +mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][1]: 'fsdp' +mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][2]: 'seq' +mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.emb.token_emb']['input_partition_spec'][0]: 'fsdp' +mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.emb.token_emb']['input_partition_spec'][1]: None +mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.emb.token_emb']['output_partition_spec'][0]: 'fsdp' +mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.emb.token_emb']['output_partition_spec'][1]: 'model' +mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.emb.token_emb']['embedding_partition_spec'][0]: 'model' +mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.emb.token_emb']['embedding_partition_spec'][1]: 'fsdp' +mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.lm_head']['param_partition_spec'][0]: 'model' +mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][0]: 'expert' +mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][1]: 'fsdp' +mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][2]: 'seq' +mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][1]: 'model' +mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][2]: None +mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][1]: None +mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][2]: None +mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][1]: 'model' +mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][2]: None +mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][1]: None +mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][2]: None +mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.output_norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.output_norm']['input_partition_spec'][1]: 'model' +mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.output_norm']['input_partition_spec'][2]: None +mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.output_norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.output_norm']['output_partition_spec'][1]: None +mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.output_norm']['output_partition_spec'][2]: None +mesh_rules[7][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' mesh_shape[0]: 1 mesh_shape[1]: -1 mesh_shape[2]: 1 diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v1-single-host.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v1-single-host.txt index a3f8ac77e..f747fc568 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v1-single-host.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v1-single-host.txt @@ -202,6 +202,93 @@ mesh_rules[6][1][2]: 1 mesh_rules[6][1][3]: 8 mesh_rules[6][1][4]: 1 mesh_rules[6][1][5]: 1 +mesh_rules[7][0]: 'neuron-(trn2|trn2n).48xlarge-64' +mesh_rules[7][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier' +mesh_rules[7][1].config_modifiers[0].mesh_shape[0]: 1 +mesh_rules[7][1].config_modifiers[0].mesh_shape[1]: 1 +mesh_rules[7][1].config_modifiers[0].mesh_shape[2]: 1 +mesh_rules[7][1].config_modifiers[0].mesh_shape[3]: -1 +mesh_rules[7][1].config_modifiers[0].mesh_shape[4]: 1 +mesh_rules[7][1].config_modifiers[0].mesh_shape[5]: 4 +mesh_rules[7][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier' +mesh_rules[7][1].config_modifiers[1].modification.klass: 'axlearn.common.attention.StackedTransformerLayer' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.activation: 'nn.relu' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear1.bias: True +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear1.param_partition_spec[0]: None +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear1.param_partition_spec[1]: 'model' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear2.bias: True +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear2.param_partition_spec[0]: 'model' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear2.param_partition_spec[1]: None +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.norm.eps: 1e-08 +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.residual_weight: 1.0 +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.stochastic_depth.mode: 'row' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.structure: 'prenorm' +mesh_rules[7][1].config_modifiers[1].modification.layer.klass: 'axlearn.common.attention.TransformerLayer' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.QKVLinear' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.bias: True +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[0]: None +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[1]: 'model' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[2]: None +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.klass: 'axlearn.common.attention.MultiheadAttention' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.bias: True +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[0]: None +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[2]: None +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.norm.eps: 1e-08 +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.stochastic_depth.mode: 'row' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.structure: 'prenorm' +mesh_rules[7][1].config_modifiers[1].target_config: 'model.decoder.transformer' +mesh_rules[7][1].config_modifiers[2].klass: 'axlearn.common.trainer_config_modifier.PartitionSpecModifier' +mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][0]: 'model' +mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][0]: 'expert' +mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][1]: 'fsdp' +mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][2]: 'seq' +mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.emb.token_emb']['input_partition_spec'][0]: 'fsdp' +mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.emb.token_emb']['input_partition_spec'][1]: None +mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.emb.token_emb']['output_partition_spec'][0]: 'fsdp' +mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.emb.token_emb']['output_partition_spec'][1]: 'model' +mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.emb.token_emb']['embedding_partition_spec'][0]: 'model' +mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.emb.token_emb']['embedding_partition_spec'][1]: 'fsdp' +mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.lm_head']['param_partition_spec'][0]: 'model' +mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][0]: 'expert' +mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][1]: 'fsdp' +mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][2]: 'seq' +mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][1]: 'model' +mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][2]: None +mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][1]: None +mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][2]: None +mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][1]: 'model' +mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][2]: None +mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][1]: None +mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][2]: None +mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.output_norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.output_norm']['input_partition_spec'][1]: 'model' +mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.output_norm']['input_partition_spec'][2]: None +mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.output_norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.output_norm']['output_partition_spec'][1]: None +mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.output_norm']['output_partition_spec'][2]: None +mesh_rules[7][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' mesh_shape[0]: 1 mesh_shape[1]: -1 mesh_shape[2]: 1 diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v1.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v1.txt index c4426cac3..0ebd87ae0 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v1.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v1.txt @@ -202,6 +202,93 @@ mesh_rules[6][1][2]: 1 mesh_rules[6][1][3]: 8 mesh_rules[6][1][4]: 1 mesh_rules[6][1][5]: 1 +mesh_rules[7][0]: 'neuron-(trn2|trn2n).48xlarge-64' +mesh_rules[7][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier' +mesh_rules[7][1].config_modifiers[0].mesh_shape[0]: 1 +mesh_rules[7][1].config_modifiers[0].mesh_shape[1]: 1 +mesh_rules[7][1].config_modifiers[0].mesh_shape[2]: 1 +mesh_rules[7][1].config_modifiers[0].mesh_shape[3]: -1 +mesh_rules[7][1].config_modifiers[0].mesh_shape[4]: 1 +mesh_rules[7][1].config_modifiers[0].mesh_shape[5]: 4 +mesh_rules[7][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier' +mesh_rules[7][1].config_modifiers[1].modification.klass: 'axlearn.common.attention.StackedTransformerLayer' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.activation: 'nn.relu' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear1.bias: True +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear1.param_partition_spec[0]: None +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear1.param_partition_spec[1]: 'model' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear2.bias: True +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear2.param_partition_spec[0]: 'model' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear2.param_partition_spec[1]: None +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.norm.eps: 1e-08 +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.residual_weight: 1.0 +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.stochastic_depth.mode: 'row' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.structure: 'prenorm' +mesh_rules[7][1].config_modifiers[1].modification.layer.klass: 'axlearn.common.attention.TransformerLayer' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.QKVLinear' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.bias: True +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[0]: None +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[1]: 'model' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[2]: None +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.klass: 'axlearn.common.attention.MultiheadAttention' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.bias: True +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[0]: None +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[2]: None +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.norm.eps: 1e-08 +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.stochastic_depth.mode: 'row' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.structure: 'prenorm' +mesh_rules[7][1].config_modifiers[1].target_config: 'model.decoder.transformer' +mesh_rules[7][1].config_modifiers[2].klass: 'axlearn.common.trainer_config_modifier.PartitionSpecModifier' +mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][0]: 'model' +mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][0]: 'expert' +mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][1]: 'fsdp' +mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][2]: 'seq' +mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.emb.token_emb']['input_partition_spec'][0]: 'fsdp' +mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.emb.token_emb']['input_partition_spec'][1]: None +mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.emb.token_emb']['output_partition_spec'][0]: 'fsdp' +mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.emb.token_emb']['output_partition_spec'][1]: 'model' +mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.emb.token_emb']['embedding_partition_spec'][0]: 'model' +mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.emb.token_emb']['embedding_partition_spec'][1]: 'fsdp' +mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.lm_head']['param_partition_spec'][0]: 'model' +mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][0]: 'expert' +mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][1]: 'fsdp' +mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][2]: 'seq' +mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][1]: 'model' +mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][2]: None +mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][1]: None +mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][2]: None +mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][1]: 'model' +mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][2]: None +mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][1]: None +mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][2]: None +mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.output_norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.output_norm']['input_partition_spec'][1]: 'model' +mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.output_norm']['input_partition_spec'][2]: None +mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.output_norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.output_norm']['output_partition_spec'][1]: None +mesh_rules[7][1].config_modifiers[2].partition_specs['model.decoder.output_norm']['output_partition_spec'][2]: None +mesh_rules[7][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' mesh_shape[0]: 1 mesh_shape[1]: -1 mesh_shape[2]: 1 diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v2-flash-single-host.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v2-flash-single-host.txt index 40d58e819..ad59e47a6 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v2-flash-single-host.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v2-flash-single-host.txt @@ -202,6 +202,101 @@ mesh_rules[6][1][2]: 1 mesh_rules[6][1][3]: 8 mesh_rules[6][1][4]: 1 mesh_rules[6][1][5]: 1 +mesh_rules[7][0]: 'neuron-(trn2|trn2n).48xlarge-64' +mesh_rules[7][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier' +mesh_rules[7][1].config_modifiers[0].mesh_shape[0]: 1 +mesh_rules[7][1].config_modifiers[0].mesh_shape[1]: 1 +mesh_rules[7][1].config_modifiers[0].mesh_shape[2]: 1 +mesh_rules[7][1].config_modifiers[0].mesh_shape[3]: -1 +mesh_rules[7][1].config_modifiers[0].mesh_shape[4]: 1 +mesh_rules[7][1].config_modifiers[0].mesh_shape[5]: 4 +mesh_rules[7][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier' +mesh_rules[7][1].config_modifiers[1].modification.klass: 'axlearn.common.attention.StackedTransformerLayer' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.activation: 'nn.relu' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear1.bias: True +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear1.param_partition_spec[0]: None +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear1.param_partition_spec[1]: 'model' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear2.bias: True +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear2.param_partition_spec[0]: 'model' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear2.param_partition_spec[1]: None +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.norm.eps: 1e-08 +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.residual_weight: 1.0 +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.stochastic_depth.mode: 'row' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.structure: 'prenorm' +mesh_rules[7][1].config_modifiers[1].modification.layer.klass: 'axlearn.common.attention.TransformerLayer' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.QKVLinear' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.bias: True +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[0]: None +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[1]: 'model' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[2]: None +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.klass: 'axlearn.common.attention.MultiheadAttention' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.bias: True +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[0]: None +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[2]: None +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.norm.eps: 1e-08 +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.stochastic_depth.mode: 'row' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.structure: 'prenorm' +mesh_rules[7][1].config_modifiers[1].target_config: 'model.decoder.transformer' +mesh_rules[7][1].config_modifiers[2].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier' +mesh_rules[7][1].config_modifiers[2].modification.klass: 'axlearn.common.attention.GroupedQKVLinear' +mesh_rules[7][1].config_modifiers[2].modification.layer.bias: True +mesh_rules[7][1].config_modifiers[2].modification.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[7][1].config_modifiers[2].modification.layer.param_partition_spec[0]: None +mesh_rules[7][1].config_modifiers[2].modification.layer.param_partition_spec[1]: 'model' +mesh_rules[7][1].config_modifiers[2].modification.layer.param_partition_spec[2]: None +mesh_rules[7][1].config_modifiers[2].target_config: 'model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear' +mesh_rules[7][1].config_modifiers[3].klass: 'axlearn.common.trainer_config_modifier.PartitionSpecModifier' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][0]: 'model' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][0]: 'expert' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][1]: 'fsdp' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][2]: 'seq' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['input_partition_spec'][0]: 'fsdp' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['input_partition_spec'][1]: None +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['output_partition_spec'][0]: 'fsdp' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['output_partition_spec'][1]: 'model' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['embedding_partition_spec'][0]: 'model' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['embedding_partition_spec'][1]: 'fsdp' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][0]: 'model' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][0]: 'expert' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][1]: 'fsdp' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][2]: 'seq' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][1]: 'model' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][2]: None +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][1]: None +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][2]: None +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][1]: 'model' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][2]: None +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][1]: None +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][2]: None +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][1]: 'model' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][2]: None +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][1]: None +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][2]: None +mesh_rules[7][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' mesh_shape[0]: 1 mesh_shape[1]: -1 mesh_shape[2]: 1 diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v2-flash.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v2-flash.txt index 98be3b833..b78f52af1 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v2-flash.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v2-flash.txt @@ -202,6 +202,101 @@ mesh_rules[6][1][2]: 1 mesh_rules[6][1][3]: 8 mesh_rules[6][1][4]: 1 mesh_rules[6][1][5]: 1 +mesh_rules[7][0]: 'neuron-(trn2|trn2n).48xlarge-64' +mesh_rules[7][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier' +mesh_rules[7][1].config_modifiers[0].mesh_shape[0]: 1 +mesh_rules[7][1].config_modifiers[0].mesh_shape[1]: 1 +mesh_rules[7][1].config_modifiers[0].mesh_shape[2]: 1 +mesh_rules[7][1].config_modifiers[0].mesh_shape[3]: -1 +mesh_rules[7][1].config_modifiers[0].mesh_shape[4]: 1 +mesh_rules[7][1].config_modifiers[0].mesh_shape[5]: 4 +mesh_rules[7][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier' +mesh_rules[7][1].config_modifiers[1].modification.klass: 'axlearn.common.attention.StackedTransformerLayer' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.activation: 'nn.relu' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear1.bias: True +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear1.param_partition_spec[0]: None +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear1.param_partition_spec[1]: 'model' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear2.bias: True +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear2.param_partition_spec[0]: 'model' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear2.param_partition_spec[1]: None +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.norm.eps: 1e-08 +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.residual_weight: 1.0 +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.stochastic_depth.mode: 'row' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.structure: 'prenorm' +mesh_rules[7][1].config_modifiers[1].modification.layer.klass: 'axlearn.common.attention.TransformerLayer' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.QKVLinear' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.bias: True +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[0]: None +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[1]: 'model' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[2]: None +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.klass: 'axlearn.common.attention.MultiheadAttention' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.bias: True +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[0]: None +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[2]: None +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.norm.eps: 1e-08 +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.stochastic_depth.mode: 'row' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.structure: 'prenorm' +mesh_rules[7][1].config_modifiers[1].target_config: 'model.decoder.transformer' +mesh_rules[7][1].config_modifiers[2].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier' +mesh_rules[7][1].config_modifiers[2].modification.klass: 'axlearn.common.attention.GroupedQKVLinear' +mesh_rules[7][1].config_modifiers[2].modification.layer.bias: True +mesh_rules[7][1].config_modifiers[2].modification.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[7][1].config_modifiers[2].modification.layer.param_partition_spec[0]: None +mesh_rules[7][1].config_modifiers[2].modification.layer.param_partition_spec[1]: 'model' +mesh_rules[7][1].config_modifiers[2].modification.layer.param_partition_spec[2]: None +mesh_rules[7][1].config_modifiers[2].target_config: 'model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear' +mesh_rules[7][1].config_modifiers[3].klass: 'axlearn.common.trainer_config_modifier.PartitionSpecModifier' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][0]: 'model' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][0]: 'expert' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][1]: 'fsdp' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][2]: 'seq' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['input_partition_spec'][0]: 'fsdp' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['input_partition_spec'][1]: None +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['output_partition_spec'][0]: 'fsdp' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['output_partition_spec'][1]: 'model' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['embedding_partition_spec'][0]: 'model' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['embedding_partition_spec'][1]: 'fsdp' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][0]: 'model' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][0]: 'expert' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][1]: 'fsdp' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][2]: 'seq' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][1]: 'model' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][2]: None +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][1]: None +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][2]: None +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][1]: 'model' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][2]: None +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][1]: None +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][2]: None +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][1]: 'model' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][2]: None +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][1]: None +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][2]: None +mesh_rules[7][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' mesh_shape[0]: 1 mesh_shape[1]: -1 mesh_shape[2]: 1 diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v2-single-host.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v2-single-host.txt index 0e057f4b9..1ced561b5 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v2-single-host.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v2-single-host.txt @@ -202,6 +202,101 @@ mesh_rules[6][1][2]: 1 mesh_rules[6][1][3]: 8 mesh_rules[6][1][4]: 1 mesh_rules[6][1][5]: 1 +mesh_rules[7][0]: 'neuron-(trn2|trn2n).48xlarge-64' +mesh_rules[7][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier' +mesh_rules[7][1].config_modifiers[0].mesh_shape[0]: 1 +mesh_rules[7][1].config_modifiers[0].mesh_shape[1]: 1 +mesh_rules[7][1].config_modifiers[0].mesh_shape[2]: 1 +mesh_rules[7][1].config_modifiers[0].mesh_shape[3]: -1 +mesh_rules[7][1].config_modifiers[0].mesh_shape[4]: 1 +mesh_rules[7][1].config_modifiers[0].mesh_shape[5]: 4 +mesh_rules[7][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier' +mesh_rules[7][1].config_modifiers[1].modification.klass: 'axlearn.common.attention.StackedTransformerLayer' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.activation: 'nn.relu' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear1.bias: True +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear1.param_partition_spec[0]: None +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear1.param_partition_spec[1]: 'model' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear2.bias: True +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear2.param_partition_spec[0]: 'model' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear2.param_partition_spec[1]: None +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.norm.eps: 1e-08 +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.residual_weight: 1.0 +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.stochastic_depth.mode: 'row' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.structure: 'prenorm' +mesh_rules[7][1].config_modifiers[1].modification.layer.klass: 'axlearn.common.attention.TransformerLayer' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.QKVLinear' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.bias: True +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[0]: None +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[1]: 'model' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[2]: None +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.klass: 'axlearn.common.attention.MultiheadAttention' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.bias: True +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[0]: None +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[2]: None +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.norm.eps: 1e-08 +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.stochastic_depth.mode: 'row' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.structure: 'prenorm' +mesh_rules[7][1].config_modifiers[1].target_config: 'model.decoder.transformer' +mesh_rules[7][1].config_modifiers[2].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier' +mesh_rules[7][1].config_modifiers[2].modification.klass: 'axlearn.common.attention.GroupedQKVLinear' +mesh_rules[7][1].config_modifiers[2].modification.layer.bias: True +mesh_rules[7][1].config_modifiers[2].modification.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[7][1].config_modifiers[2].modification.layer.param_partition_spec[0]: None +mesh_rules[7][1].config_modifiers[2].modification.layer.param_partition_spec[1]: 'model' +mesh_rules[7][1].config_modifiers[2].modification.layer.param_partition_spec[2]: None +mesh_rules[7][1].config_modifiers[2].target_config: 'model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear' +mesh_rules[7][1].config_modifiers[3].klass: 'axlearn.common.trainer_config_modifier.PartitionSpecModifier' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][0]: 'model' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][0]: 'expert' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][1]: 'fsdp' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][2]: 'seq' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['input_partition_spec'][0]: 'fsdp' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['input_partition_spec'][1]: None +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['output_partition_spec'][0]: 'fsdp' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['output_partition_spec'][1]: 'model' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['embedding_partition_spec'][0]: 'model' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['embedding_partition_spec'][1]: 'fsdp' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][0]: 'model' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][0]: 'expert' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][1]: 'fsdp' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][2]: 'seq' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][1]: 'model' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][2]: None +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][1]: None +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][2]: None +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][1]: 'model' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][2]: None +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][1]: None +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][2]: None +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][1]: 'model' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][2]: None +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][1]: None +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][2]: None +mesh_rules[7][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' mesh_shape[0]: 1 mesh_shape[1]: -1 mesh_shape[2]: 1 diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v2.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v2.txt index 8d69c9254..04bbf53df 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v2.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v2.txt @@ -202,6 +202,101 @@ mesh_rules[6][1][2]: 1 mesh_rules[6][1][3]: 8 mesh_rules[6][1][4]: 1 mesh_rules[6][1][5]: 1 +mesh_rules[7][0]: 'neuron-(trn2|trn2n).48xlarge-64' +mesh_rules[7][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier' +mesh_rules[7][1].config_modifiers[0].mesh_shape[0]: 1 +mesh_rules[7][1].config_modifiers[0].mesh_shape[1]: 1 +mesh_rules[7][1].config_modifiers[0].mesh_shape[2]: 1 +mesh_rules[7][1].config_modifiers[0].mesh_shape[3]: -1 +mesh_rules[7][1].config_modifiers[0].mesh_shape[4]: 1 +mesh_rules[7][1].config_modifiers[0].mesh_shape[5]: 4 +mesh_rules[7][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier' +mesh_rules[7][1].config_modifiers[1].modification.klass: 'axlearn.common.attention.StackedTransformerLayer' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.activation: 'nn.relu' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear1.bias: True +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear1.param_partition_spec[0]: None +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear1.param_partition_spec[1]: 'model' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear2.bias: True +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear2.param_partition_spec[0]: 'model' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear2.param_partition_spec[1]: None +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.norm.eps: 1e-08 +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.residual_weight: 1.0 +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.stochastic_depth.mode: 'row' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.structure: 'prenorm' +mesh_rules[7][1].config_modifiers[1].modification.layer.klass: 'axlearn.common.attention.TransformerLayer' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.QKVLinear' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.bias: True +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[0]: None +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[1]: 'model' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[2]: None +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.klass: 'axlearn.common.attention.MultiheadAttention' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.bias: True +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[0]: None +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[2]: None +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.norm.eps: 1e-08 +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.stochastic_depth.mode: 'row' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.structure: 'prenorm' +mesh_rules[7][1].config_modifiers[1].target_config: 'model.decoder.transformer' +mesh_rules[7][1].config_modifiers[2].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier' +mesh_rules[7][1].config_modifiers[2].modification.klass: 'axlearn.common.attention.GroupedQKVLinear' +mesh_rules[7][1].config_modifiers[2].modification.layer.bias: True +mesh_rules[7][1].config_modifiers[2].modification.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[7][1].config_modifiers[2].modification.layer.param_partition_spec[0]: None +mesh_rules[7][1].config_modifiers[2].modification.layer.param_partition_spec[1]: 'model' +mesh_rules[7][1].config_modifiers[2].modification.layer.param_partition_spec[2]: None +mesh_rules[7][1].config_modifiers[2].target_config: 'model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear' +mesh_rules[7][1].config_modifiers[3].klass: 'axlearn.common.trainer_config_modifier.PartitionSpecModifier' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][0]: 'model' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][0]: 'expert' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][1]: 'fsdp' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][2]: 'seq' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['input_partition_spec'][0]: 'fsdp' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['input_partition_spec'][1]: None +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['output_partition_spec'][0]: 'fsdp' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['output_partition_spec'][1]: 'model' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['embedding_partition_spec'][0]: 'model' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['embedding_partition_spec'][1]: 'fsdp' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][0]: 'model' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][0]: 'expert' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][1]: 'fsdp' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][2]: 'seq' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][1]: 'model' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][2]: None +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][1]: None +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][2]: None +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][1]: 'model' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][2]: None +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][1]: None +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][2]: None +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][1]: 'model' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][2]: None +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][1]: None +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][2]: None +mesh_rules[7][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' mesh_shape[0]: 1 mesh_shape[1]: -1 mesh_shape[2]: 1 diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v3-flash-single-host.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v3-flash-single-host.txt index 467258bf0..5227067ef 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v3-flash-single-host.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v3-flash-single-host.txt @@ -202,6 +202,101 @@ mesh_rules[6][1][2]: 1 mesh_rules[6][1][3]: 8 mesh_rules[6][1][4]: 1 mesh_rules[6][1][5]: 1 +mesh_rules[7][0]: 'neuron-(trn2|trn2n).48xlarge-64' +mesh_rules[7][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier' +mesh_rules[7][1].config_modifiers[0].mesh_shape[0]: 1 +mesh_rules[7][1].config_modifiers[0].mesh_shape[1]: 1 +mesh_rules[7][1].config_modifiers[0].mesh_shape[2]: 1 +mesh_rules[7][1].config_modifiers[0].mesh_shape[3]: -1 +mesh_rules[7][1].config_modifiers[0].mesh_shape[4]: 1 +mesh_rules[7][1].config_modifiers[0].mesh_shape[5]: 4 +mesh_rules[7][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier' +mesh_rules[7][1].config_modifiers[1].modification.klass: 'axlearn.common.attention.StackedTransformerLayer' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.activation: 'nn.relu' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear1.bias: True +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear1.param_partition_spec[0]: None +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear1.param_partition_spec[1]: 'model' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear2.bias: True +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear2.param_partition_spec[0]: 'model' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear2.param_partition_spec[1]: None +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.norm.eps: 1e-08 +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.residual_weight: 1.0 +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.stochastic_depth.mode: 'row' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.structure: 'prenorm' +mesh_rules[7][1].config_modifiers[1].modification.layer.klass: 'axlearn.common.attention.TransformerLayer' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.QKVLinear' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.bias: True +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[0]: None +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[1]: 'model' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[2]: None +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.klass: 'axlearn.common.attention.MultiheadAttention' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.bias: True +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[0]: None +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[2]: None +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.norm.eps: 1e-08 +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.stochastic_depth.mode: 'row' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.structure: 'prenorm' +mesh_rules[7][1].config_modifiers[1].target_config: 'model.decoder.transformer' +mesh_rules[7][1].config_modifiers[2].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier' +mesh_rules[7][1].config_modifiers[2].modification.klass: 'axlearn.common.attention.GroupedQKVLinear' +mesh_rules[7][1].config_modifiers[2].modification.layer.bias: True +mesh_rules[7][1].config_modifiers[2].modification.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[7][1].config_modifiers[2].modification.layer.param_partition_spec[0]: None +mesh_rules[7][1].config_modifiers[2].modification.layer.param_partition_spec[1]: 'model' +mesh_rules[7][1].config_modifiers[2].modification.layer.param_partition_spec[2]: None +mesh_rules[7][1].config_modifiers[2].target_config: 'model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear' +mesh_rules[7][1].config_modifiers[3].klass: 'axlearn.common.trainer_config_modifier.PartitionSpecModifier' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][0]: 'model' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][0]: 'expert' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][1]: 'fsdp' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][2]: 'seq' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['input_partition_spec'][0]: 'fsdp' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['input_partition_spec'][1]: None +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['output_partition_spec'][0]: 'fsdp' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['output_partition_spec'][1]: 'model' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['embedding_partition_spec'][0]: 'model' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['embedding_partition_spec'][1]: 'fsdp' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][0]: 'model' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][0]: 'expert' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][1]: 'fsdp' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][2]: 'seq' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][1]: 'model' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][2]: None +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][1]: None +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][2]: None +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][1]: 'model' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][2]: None +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][1]: None +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][2]: None +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][1]: 'model' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][2]: None +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][1]: None +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][2]: None +mesh_rules[7][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' mesh_shape[0]: 1 mesh_shape[1]: -1 mesh_shape[2]: 1 diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v3-flash.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v3-flash.txt index 47fb69af4..47d614372 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v3-flash.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v3-flash.txt @@ -202,6 +202,101 @@ mesh_rules[6][1][2]: 1 mesh_rules[6][1][3]: 8 mesh_rules[6][1][4]: 1 mesh_rules[6][1][5]: 1 +mesh_rules[7][0]: 'neuron-(trn2|trn2n).48xlarge-64' +mesh_rules[7][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier' +mesh_rules[7][1].config_modifiers[0].mesh_shape[0]: 1 +mesh_rules[7][1].config_modifiers[0].mesh_shape[1]: 1 +mesh_rules[7][1].config_modifiers[0].mesh_shape[2]: 1 +mesh_rules[7][1].config_modifiers[0].mesh_shape[3]: -1 +mesh_rules[7][1].config_modifiers[0].mesh_shape[4]: 1 +mesh_rules[7][1].config_modifiers[0].mesh_shape[5]: 4 +mesh_rules[7][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier' +mesh_rules[7][1].config_modifiers[1].modification.klass: 'axlearn.common.attention.StackedTransformerLayer' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.activation: 'nn.relu' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear1.bias: True +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear1.param_partition_spec[0]: None +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear1.param_partition_spec[1]: 'model' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear2.bias: True +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear2.param_partition_spec[0]: 'model' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear2.param_partition_spec[1]: None +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.norm.eps: 1e-08 +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.residual_weight: 1.0 +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.stochastic_depth.mode: 'row' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.structure: 'prenorm' +mesh_rules[7][1].config_modifiers[1].modification.layer.klass: 'axlearn.common.attention.TransformerLayer' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.QKVLinear' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.bias: True +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[0]: None +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[1]: 'model' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[2]: None +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.klass: 'axlearn.common.attention.MultiheadAttention' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.bias: True +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[0]: None +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[2]: None +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.norm.eps: 1e-08 +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.stochastic_depth.mode: 'row' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.structure: 'prenorm' +mesh_rules[7][1].config_modifiers[1].target_config: 'model.decoder.transformer' +mesh_rules[7][1].config_modifiers[2].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier' +mesh_rules[7][1].config_modifiers[2].modification.klass: 'axlearn.common.attention.GroupedQKVLinear' +mesh_rules[7][1].config_modifiers[2].modification.layer.bias: True +mesh_rules[7][1].config_modifiers[2].modification.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[7][1].config_modifiers[2].modification.layer.param_partition_spec[0]: None +mesh_rules[7][1].config_modifiers[2].modification.layer.param_partition_spec[1]: 'model' +mesh_rules[7][1].config_modifiers[2].modification.layer.param_partition_spec[2]: None +mesh_rules[7][1].config_modifiers[2].target_config: 'model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear' +mesh_rules[7][1].config_modifiers[3].klass: 'axlearn.common.trainer_config_modifier.PartitionSpecModifier' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][0]: 'model' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][0]: 'expert' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][1]: 'fsdp' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][2]: 'seq' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['input_partition_spec'][0]: 'fsdp' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['input_partition_spec'][1]: None +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['output_partition_spec'][0]: 'fsdp' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['output_partition_spec'][1]: 'model' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['embedding_partition_spec'][0]: 'model' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['embedding_partition_spec'][1]: 'fsdp' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][0]: 'model' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][0]: 'expert' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][1]: 'fsdp' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][2]: 'seq' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][1]: 'model' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][2]: None +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][1]: None +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][2]: None +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][1]: 'model' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][2]: None +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][1]: None +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][2]: None +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][1]: 'model' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][2]: None +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][1]: None +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][2]: None +mesh_rules[7][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' mesh_shape[0]: 1 mesh_shape[1]: -1 mesh_shape[2]: 1 diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v3-single-host.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v3-single-host.txt index 27dc49fb1..a49c47dae 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v3-single-host.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v3-single-host.txt @@ -202,6 +202,101 @@ mesh_rules[6][1][2]: 1 mesh_rules[6][1][3]: 8 mesh_rules[6][1][4]: 1 mesh_rules[6][1][5]: 1 +mesh_rules[7][0]: 'neuron-(trn2|trn2n).48xlarge-64' +mesh_rules[7][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier' +mesh_rules[7][1].config_modifiers[0].mesh_shape[0]: 1 +mesh_rules[7][1].config_modifiers[0].mesh_shape[1]: 1 +mesh_rules[7][1].config_modifiers[0].mesh_shape[2]: 1 +mesh_rules[7][1].config_modifiers[0].mesh_shape[3]: -1 +mesh_rules[7][1].config_modifiers[0].mesh_shape[4]: 1 +mesh_rules[7][1].config_modifiers[0].mesh_shape[5]: 4 +mesh_rules[7][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier' +mesh_rules[7][1].config_modifiers[1].modification.klass: 'axlearn.common.attention.StackedTransformerLayer' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.activation: 'nn.relu' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear1.bias: True +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear1.param_partition_spec[0]: None +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear1.param_partition_spec[1]: 'model' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear2.bias: True +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear2.param_partition_spec[0]: 'model' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear2.param_partition_spec[1]: None +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.norm.eps: 1e-08 +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.residual_weight: 1.0 +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.stochastic_depth.mode: 'row' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.structure: 'prenorm' +mesh_rules[7][1].config_modifiers[1].modification.layer.klass: 'axlearn.common.attention.TransformerLayer' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.QKVLinear' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.bias: True +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[0]: None +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[1]: 'model' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[2]: None +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.klass: 'axlearn.common.attention.MultiheadAttention' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.bias: True +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[0]: None +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[2]: None +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.norm.eps: 1e-08 +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.stochastic_depth.mode: 'row' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.structure: 'prenorm' +mesh_rules[7][1].config_modifiers[1].target_config: 'model.decoder.transformer' +mesh_rules[7][1].config_modifiers[2].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier' +mesh_rules[7][1].config_modifiers[2].modification.klass: 'axlearn.common.attention.GroupedQKVLinear' +mesh_rules[7][1].config_modifiers[2].modification.layer.bias: True +mesh_rules[7][1].config_modifiers[2].modification.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[7][1].config_modifiers[2].modification.layer.param_partition_spec[0]: None +mesh_rules[7][1].config_modifiers[2].modification.layer.param_partition_spec[1]: 'model' +mesh_rules[7][1].config_modifiers[2].modification.layer.param_partition_spec[2]: None +mesh_rules[7][1].config_modifiers[2].target_config: 'model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear' +mesh_rules[7][1].config_modifiers[3].klass: 'axlearn.common.trainer_config_modifier.PartitionSpecModifier' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][0]: 'model' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][0]: 'expert' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][1]: 'fsdp' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][2]: 'seq' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['input_partition_spec'][0]: 'fsdp' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['input_partition_spec'][1]: None +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['output_partition_spec'][0]: 'fsdp' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['output_partition_spec'][1]: 'model' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['embedding_partition_spec'][0]: 'model' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['embedding_partition_spec'][1]: 'fsdp' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][0]: 'model' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][0]: 'expert' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][1]: 'fsdp' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][2]: 'seq' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][1]: 'model' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][2]: None +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][1]: None +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][2]: None +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][1]: 'model' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][2]: None +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][1]: None +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][2]: None +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][1]: 'model' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][2]: None +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][1]: None +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][2]: None +mesh_rules[7][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' mesh_shape[0]: 1 mesh_shape[1]: -1 mesh_shape[2]: 1 diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v3.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v3.txt index f391b0abc..77bbabdde 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v3.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v3.txt @@ -202,6 +202,101 @@ mesh_rules[6][1][2]: 1 mesh_rules[6][1][3]: 8 mesh_rules[6][1][4]: 1 mesh_rules[6][1][5]: 1 +mesh_rules[7][0]: 'neuron-(trn2|trn2n).48xlarge-64' +mesh_rules[7][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier' +mesh_rules[7][1].config_modifiers[0].mesh_shape[0]: 1 +mesh_rules[7][1].config_modifiers[0].mesh_shape[1]: 1 +mesh_rules[7][1].config_modifiers[0].mesh_shape[2]: 1 +mesh_rules[7][1].config_modifiers[0].mesh_shape[3]: -1 +mesh_rules[7][1].config_modifiers[0].mesh_shape[4]: 1 +mesh_rules[7][1].config_modifiers[0].mesh_shape[5]: 4 +mesh_rules[7][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier' +mesh_rules[7][1].config_modifiers[1].modification.klass: 'axlearn.common.attention.StackedTransformerLayer' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.activation: 'nn.relu' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear1.bias: True +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear1.param_partition_spec[0]: None +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear1.param_partition_spec[1]: 'model' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear2.bias: True +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear2.param_partition_spec[0]: 'model' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.linear2.param_partition_spec[1]: None +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.norm.eps: 1e-08 +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.residual_weight: 1.0 +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.stochastic_depth.mode: 'row' +mesh_rules[7][1].config_modifiers[1].modification.layer.feed_forward.structure: 'prenorm' +mesh_rules[7][1].config_modifiers[1].modification.layer.klass: 'axlearn.common.attention.TransformerLayer' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.QKVLinear' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.bias: True +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[0]: None +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[1]: 'model' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[2]: None +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.klass: 'axlearn.common.attention.MultiheadAttention' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.bias: True +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[0]: None +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[2]: None +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.norm.eps: 1e-08 +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.stochastic_depth.mode: 'row' +mesh_rules[7][1].config_modifiers[1].modification.layer.self_attention.structure: 'prenorm' +mesh_rules[7][1].config_modifiers[1].target_config: 'model.decoder.transformer' +mesh_rules[7][1].config_modifiers[2].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier' +mesh_rules[7][1].config_modifiers[2].modification.klass: 'axlearn.common.attention.GroupedQKVLinear' +mesh_rules[7][1].config_modifiers[2].modification.layer.bias: True +mesh_rules[7][1].config_modifiers[2].modification.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[7][1].config_modifiers[2].modification.layer.param_partition_spec[0]: None +mesh_rules[7][1].config_modifiers[2].modification.layer.param_partition_spec[1]: 'model' +mesh_rules[7][1].config_modifiers[2].modification.layer.param_partition_spec[2]: None +mesh_rules[7][1].config_modifiers[2].target_config: 'model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear' +mesh_rules[7][1].config_modifiers[3].klass: 'axlearn.common.trainer_config_modifier.PartitionSpecModifier' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][0]: 'model' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][0]: 'expert' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][1]: 'fsdp' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][2]: 'seq' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['input_partition_spec'][0]: 'fsdp' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['input_partition_spec'][1]: None +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['output_partition_spec'][0]: 'fsdp' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['output_partition_spec'][1]: 'model' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['embedding_partition_spec'][0]: 'model' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['embedding_partition_spec'][1]: 'fsdp' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][0]: 'model' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][0]: 'expert' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][1]: 'fsdp' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][2]: 'seq' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][1]: 'model' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][2]: None +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][1]: None +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][2]: None +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][1]: 'model' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][2]: None +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][1]: None +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][2]: None +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][1]: 'model' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][2]: None +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][1]: None +mesh_rules[7][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][2]: None +mesh_rules[7][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' mesh_shape[0]: 1 mesh_shape[1]: -1 mesh_shape[2]: 1 diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-tiktoken-flash-single-host.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-tiktoken-flash-single-host.txt index 878b7889a..1833c683c 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-tiktoken-flash-single-host.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-tiktoken-flash-single-host.txt @@ -186,6 +186,101 @@ mesh_rules[5][1][2]: 1 mesh_rules[5][1][3]: 8 mesh_rules[5][1][4]: 1 mesh_rules[5][1][5]: 1 +mesh_rules[6][0]: 'neuron-(trn2|trn2n).48xlarge-64' +mesh_rules[6][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier' +mesh_rules[6][1].config_modifiers[0].mesh_shape[0]: 1 +mesh_rules[6][1].config_modifiers[0].mesh_shape[1]: 1 +mesh_rules[6][1].config_modifiers[0].mesh_shape[2]: 1 +mesh_rules[6][1].config_modifiers[0].mesh_shape[3]: -1 +mesh_rules[6][1].config_modifiers[0].mesh_shape[4]: 1 +mesh_rules[6][1].config_modifiers[0].mesh_shape[5]: 4 +mesh_rules[6][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier' +mesh_rules[6][1].config_modifiers[1].modification.klass: 'axlearn.common.attention.StackedTransformerLayer' +mesh_rules[6][1].config_modifiers[1].modification.layer.feed_forward.activation: 'nn.relu' +mesh_rules[6][1].config_modifiers[1].modification.layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[6][1].config_modifiers[1].modification.layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer' +mesh_rules[6][1].config_modifiers[1].modification.layer.feed_forward.linear1.bias: True +mesh_rules[6][1].config_modifiers[1].modification.layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear' +mesh_rules[6][1].config_modifiers[1].modification.layer.feed_forward.linear1.param_partition_spec[0]: None +mesh_rules[6][1].config_modifiers[1].modification.layer.feed_forward.linear1.param_partition_spec[1]: 'model' +mesh_rules[6][1].config_modifiers[1].modification.layer.feed_forward.linear2.bias: True +mesh_rules[6][1].config_modifiers[1].modification.layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear' +mesh_rules[6][1].config_modifiers[1].modification.layer.feed_forward.linear2.param_partition_spec[0]: 'model' +mesh_rules[6][1].config_modifiers[1].modification.layer.feed_forward.linear2.param_partition_spec[1]: None +mesh_rules[6][1].config_modifiers[1].modification.layer.feed_forward.norm.eps: 1e-08 +mesh_rules[6][1].config_modifiers[1].modification.layer.feed_forward.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[6][1].config_modifiers[1].modification.layer.feed_forward.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[6][1].config_modifiers[1].modification.layer.feed_forward.residual_weight: 1.0 +mesh_rules[6][1].config_modifiers[1].modification.layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[6][1].config_modifiers[1].modification.layer.feed_forward.stochastic_depth.mode: 'row' +mesh_rules[6][1].config_modifiers[1].modification.layer.feed_forward.structure: 'prenorm' +mesh_rules[6][1].config_modifiers[1].modification.layer.klass: 'axlearn.common.attention.TransformerLayer' +mesh_rules[6][1].config_modifiers[1].modification.layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[6][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.QKVLinear' +mesh_rules[6][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.bias: True +mesh_rules[6][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[6][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[0]: None +mesh_rules[6][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[1]: 'model' +mesh_rules[6][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[2]: None +mesh_rules[6][1].config_modifiers[1].modification.layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey' +mesh_rules[6][1].config_modifiers[1].modification.layer.self_attention.attention.klass: 'axlearn.common.attention.MultiheadAttention' +mesh_rules[6][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.bias: True +mesh_rules[6][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear' +mesh_rules[6][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[0]: None +mesh_rules[6][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model' +mesh_rules[6][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[2]: None +mesh_rules[6][1].config_modifiers[1].modification.layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery' +mesh_rules[6][1].config_modifiers[1].modification.layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[6][1].config_modifiers[1].modification.layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer' +mesh_rules[6][1].config_modifiers[1].modification.layer.self_attention.norm.eps: 1e-08 +mesh_rules[6][1].config_modifiers[1].modification.layer.self_attention.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[6][1].config_modifiers[1].modification.layer.self_attention.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[6][1].config_modifiers[1].modification.layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[6][1].config_modifiers[1].modification.layer.self_attention.stochastic_depth.mode: 'row' +mesh_rules[6][1].config_modifiers[1].modification.layer.self_attention.structure: 'prenorm' +mesh_rules[6][1].config_modifiers[1].target_config: 'model.decoder.transformer' +mesh_rules[6][1].config_modifiers[2].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier' +mesh_rules[6][1].config_modifiers[2].modification.klass: 'axlearn.common.attention.GroupedQKVLinear' +mesh_rules[6][1].config_modifiers[2].modification.layer.bias: True +mesh_rules[6][1].config_modifiers[2].modification.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[6][1].config_modifiers[2].modification.layer.param_partition_spec[0]: None +mesh_rules[6][1].config_modifiers[2].modification.layer.param_partition_spec[1]: 'model' +mesh_rules[6][1].config_modifiers[2].modification.layer.param_partition_spec[2]: None +mesh_rules[6][1].config_modifiers[2].target_config: 'model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear' +mesh_rules[6][1].config_modifiers[3].klass: 'axlearn.common.trainer_config_modifier.PartitionSpecModifier' +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][0]: 'model' +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][0]: 'expert' +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][1]: 'fsdp' +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][2]: 'seq' +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['input_partition_spec'][0]: 'fsdp' +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['input_partition_spec'][1]: None +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['output_partition_spec'][0]: 'fsdp' +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['output_partition_spec'][1]: 'model' +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['embedding_partition_spec'][0]: 'model' +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['embedding_partition_spec'][1]: 'fsdp' +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][0]: 'model' +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][0]: 'expert' +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][1]: 'fsdp' +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][2]: 'seq' +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][1]: 'model' +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][2]: None +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][1]: None +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][2]: None +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][1]: 'model' +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][2]: None +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][1]: None +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][2]: None +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][1]: 'model' +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][2]: None +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][1]: None +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][2]: None +mesh_rules[6][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' mesh_shape[0]: 1 mesh_shape[1]: -1 mesh_shape[2]: 1 diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-tiktoken-flash.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-tiktoken-flash.txt index bd7c71f4a..347f7007a 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-tiktoken-flash.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-tiktoken-flash.txt @@ -186,6 +186,101 @@ mesh_rules[5][1][2]: 1 mesh_rules[5][1][3]: 8 mesh_rules[5][1][4]: 1 mesh_rules[5][1][5]: 1 +mesh_rules[6][0]: 'neuron-(trn2|trn2n).48xlarge-64' +mesh_rules[6][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier' +mesh_rules[6][1].config_modifiers[0].mesh_shape[0]: 1 +mesh_rules[6][1].config_modifiers[0].mesh_shape[1]: 1 +mesh_rules[6][1].config_modifiers[0].mesh_shape[2]: 1 +mesh_rules[6][1].config_modifiers[0].mesh_shape[3]: -1 +mesh_rules[6][1].config_modifiers[0].mesh_shape[4]: 1 +mesh_rules[6][1].config_modifiers[0].mesh_shape[5]: 4 +mesh_rules[6][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier' +mesh_rules[6][1].config_modifiers[1].modification.klass: 'axlearn.common.attention.StackedTransformerLayer' +mesh_rules[6][1].config_modifiers[1].modification.layer.feed_forward.activation: 'nn.relu' +mesh_rules[6][1].config_modifiers[1].modification.layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[6][1].config_modifiers[1].modification.layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer' +mesh_rules[6][1].config_modifiers[1].modification.layer.feed_forward.linear1.bias: True +mesh_rules[6][1].config_modifiers[1].modification.layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear' +mesh_rules[6][1].config_modifiers[1].modification.layer.feed_forward.linear1.param_partition_spec[0]: None +mesh_rules[6][1].config_modifiers[1].modification.layer.feed_forward.linear1.param_partition_spec[1]: 'model' +mesh_rules[6][1].config_modifiers[1].modification.layer.feed_forward.linear2.bias: True +mesh_rules[6][1].config_modifiers[1].modification.layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear' +mesh_rules[6][1].config_modifiers[1].modification.layer.feed_forward.linear2.param_partition_spec[0]: 'model' +mesh_rules[6][1].config_modifiers[1].modification.layer.feed_forward.linear2.param_partition_spec[1]: None +mesh_rules[6][1].config_modifiers[1].modification.layer.feed_forward.norm.eps: 1e-08 +mesh_rules[6][1].config_modifiers[1].modification.layer.feed_forward.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[6][1].config_modifiers[1].modification.layer.feed_forward.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[6][1].config_modifiers[1].modification.layer.feed_forward.residual_weight: 1.0 +mesh_rules[6][1].config_modifiers[1].modification.layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[6][1].config_modifiers[1].modification.layer.feed_forward.stochastic_depth.mode: 'row' +mesh_rules[6][1].config_modifiers[1].modification.layer.feed_forward.structure: 'prenorm' +mesh_rules[6][1].config_modifiers[1].modification.layer.klass: 'axlearn.common.attention.TransformerLayer' +mesh_rules[6][1].config_modifiers[1].modification.layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[6][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.QKVLinear' +mesh_rules[6][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.bias: True +mesh_rules[6][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[6][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[0]: None +mesh_rules[6][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[1]: 'model' +mesh_rules[6][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[2]: None +mesh_rules[6][1].config_modifiers[1].modification.layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey' +mesh_rules[6][1].config_modifiers[1].modification.layer.self_attention.attention.klass: 'axlearn.common.attention.MultiheadAttention' +mesh_rules[6][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.bias: True +mesh_rules[6][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear' +mesh_rules[6][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[0]: None +mesh_rules[6][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model' +mesh_rules[6][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[2]: None +mesh_rules[6][1].config_modifiers[1].modification.layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery' +mesh_rules[6][1].config_modifiers[1].modification.layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[6][1].config_modifiers[1].modification.layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer' +mesh_rules[6][1].config_modifiers[1].modification.layer.self_attention.norm.eps: 1e-08 +mesh_rules[6][1].config_modifiers[1].modification.layer.self_attention.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[6][1].config_modifiers[1].modification.layer.self_attention.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[6][1].config_modifiers[1].modification.layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[6][1].config_modifiers[1].modification.layer.self_attention.stochastic_depth.mode: 'row' +mesh_rules[6][1].config_modifiers[1].modification.layer.self_attention.structure: 'prenorm' +mesh_rules[6][1].config_modifiers[1].target_config: 'model.decoder.transformer' +mesh_rules[6][1].config_modifiers[2].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier' +mesh_rules[6][1].config_modifiers[2].modification.klass: 'axlearn.common.attention.GroupedQKVLinear' +mesh_rules[6][1].config_modifiers[2].modification.layer.bias: True +mesh_rules[6][1].config_modifiers[2].modification.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[6][1].config_modifiers[2].modification.layer.param_partition_spec[0]: None +mesh_rules[6][1].config_modifiers[2].modification.layer.param_partition_spec[1]: 'model' +mesh_rules[6][1].config_modifiers[2].modification.layer.param_partition_spec[2]: None +mesh_rules[6][1].config_modifiers[2].target_config: 'model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear' +mesh_rules[6][1].config_modifiers[3].klass: 'axlearn.common.trainer_config_modifier.PartitionSpecModifier' +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][0]: 'model' +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][0]: 'expert' +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][1]: 'fsdp' +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][2]: 'seq' +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['input_partition_spec'][0]: 'fsdp' +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['input_partition_spec'][1]: None +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['output_partition_spec'][0]: 'fsdp' +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['output_partition_spec'][1]: 'model' +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['embedding_partition_spec'][0]: 'model' +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['embedding_partition_spec'][1]: 'fsdp' +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][0]: 'model' +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][0]: 'expert' +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][1]: 'fsdp' +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][2]: 'seq' +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][1]: 'model' +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][2]: None +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][1]: None +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][2]: None +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][1]: 'model' +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][2]: None +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][1]: None +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][2]: None +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][1]: 'model' +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][2]: None +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][1]: None +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][2]: None +mesh_rules[6][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' mesh_shape[0]: 1 mesh_shape[1]: -1 mesh_shape[2]: 1 diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-tiktoken-single-host.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-tiktoken-single-host.txt index 5e0762544..01a1e850c 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-tiktoken-single-host.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-tiktoken-single-host.txt @@ -186,6 +186,101 @@ mesh_rules[5][1][2]: 1 mesh_rules[5][1][3]: 8 mesh_rules[5][1][4]: 1 mesh_rules[5][1][5]: 1 +mesh_rules[6][0]: 'neuron-(trn2|trn2n).48xlarge-64' +mesh_rules[6][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier' +mesh_rules[6][1].config_modifiers[0].mesh_shape[0]: 1 +mesh_rules[6][1].config_modifiers[0].mesh_shape[1]: 1 +mesh_rules[6][1].config_modifiers[0].mesh_shape[2]: 1 +mesh_rules[6][1].config_modifiers[0].mesh_shape[3]: -1 +mesh_rules[6][1].config_modifiers[0].mesh_shape[4]: 1 +mesh_rules[6][1].config_modifiers[0].mesh_shape[5]: 4 +mesh_rules[6][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier' +mesh_rules[6][1].config_modifiers[1].modification.klass: 'axlearn.common.attention.StackedTransformerLayer' +mesh_rules[6][1].config_modifiers[1].modification.layer.feed_forward.activation: 'nn.relu' +mesh_rules[6][1].config_modifiers[1].modification.layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[6][1].config_modifiers[1].modification.layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer' +mesh_rules[6][1].config_modifiers[1].modification.layer.feed_forward.linear1.bias: True +mesh_rules[6][1].config_modifiers[1].modification.layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear' +mesh_rules[6][1].config_modifiers[1].modification.layer.feed_forward.linear1.param_partition_spec[0]: None +mesh_rules[6][1].config_modifiers[1].modification.layer.feed_forward.linear1.param_partition_spec[1]: 'model' +mesh_rules[6][1].config_modifiers[1].modification.layer.feed_forward.linear2.bias: True +mesh_rules[6][1].config_modifiers[1].modification.layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear' +mesh_rules[6][1].config_modifiers[1].modification.layer.feed_forward.linear2.param_partition_spec[0]: 'model' +mesh_rules[6][1].config_modifiers[1].modification.layer.feed_forward.linear2.param_partition_spec[1]: None +mesh_rules[6][1].config_modifiers[1].modification.layer.feed_forward.norm.eps: 1e-08 +mesh_rules[6][1].config_modifiers[1].modification.layer.feed_forward.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[6][1].config_modifiers[1].modification.layer.feed_forward.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[6][1].config_modifiers[1].modification.layer.feed_forward.residual_weight: 1.0 +mesh_rules[6][1].config_modifiers[1].modification.layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[6][1].config_modifiers[1].modification.layer.feed_forward.stochastic_depth.mode: 'row' +mesh_rules[6][1].config_modifiers[1].modification.layer.feed_forward.structure: 'prenorm' +mesh_rules[6][1].config_modifiers[1].modification.layer.klass: 'axlearn.common.attention.TransformerLayer' +mesh_rules[6][1].config_modifiers[1].modification.layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[6][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.QKVLinear' +mesh_rules[6][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.bias: True +mesh_rules[6][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[6][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[0]: None +mesh_rules[6][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[1]: 'model' +mesh_rules[6][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[2]: None +mesh_rules[6][1].config_modifiers[1].modification.layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey' +mesh_rules[6][1].config_modifiers[1].modification.layer.self_attention.attention.klass: 'axlearn.common.attention.MultiheadAttention' +mesh_rules[6][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.bias: True +mesh_rules[6][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear' +mesh_rules[6][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[0]: None +mesh_rules[6][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model' +mesh_rules[6][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[2]: None +mesh_rules[6][1].config_modifiers[1].modification.layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery' +mesh_rules[6][1].config_modifiers[1].modification.layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[6][1].config_modifiers[1].modification.layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer' +mesh_rules[6][1].config_modifiers[1].modification.layer.self_attention.norm.eps: 1e-08 +mesh_rules[6][1].config_modifiers[1].modification.layer.self_attention.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[6][1].config_modifiers[1].modification.layer.self_attention.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[6][1].config_modifiers[1].modification.layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[6][1].config_modifiers[1].modification.layer.self_attention.stochastic_depth.mode: 'row' +mesh_rules[6][1].config_modifiers[1].modification.layer.self_attention.structure: 'prenorm' +mesh_rules[6][1].config_modifiers[1].target_config: 'model.decoder.transformer' +mesh_rules[6][1].config_modifiers[2].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier' +mesh_rules[6][1].config_modifiers[2].modification.klass: 'axlearn.common.attention.GroupedQKVLinear' +mesh_rules[6][1].config_modifiers[2].modification.layer.bias: True +mesh_rules[6][1].config_modifiers[2].modification.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[6][1].config_modifiers[2].modification.layer.param_partition_spec[0]: None +mesh_rules[6][1].config_modifiers[2].modification.layer.param_partition_spec[1]: 'model' +mesh_rules[6][1].config_modifiers[2].modification.layer.param_partition_spec[2]: None +mesh_rules[6][1].config_modifiers[2].target_config: 'model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear' +mesh_rules[6][1].config_modifiers[3].klass: 'axlearn.common.trainer_config_modifier.PartitionSpecModifier' +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][0]: 'model' +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][0]: 'expert' +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][1]: 'fsdp' +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][2]: 'seq' +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['input_partition_spec'][0]: 'fsdp' +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['input_partition_spec'][1]: None +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['output_partition_spec'][0]: 'fsdp' +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['output_partition_spec'][1]: 'model' +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['embedding_partition_spec'][0]: 'model' +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['embedding_partition_spec'][1]: 'fsdp' +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][0]: 'model' +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][0]: 'expert' +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][1]: 'fsdp' +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][2]: 'seq' +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][1]: 'model' +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][2]: None +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][1]: None +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][2]: None +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][1]: 'model' +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][2]: None +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][1]: None +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][2]: None +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][1]: 'model' +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][2]: None +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][1]: None +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][2]: None +mesh_rules[6][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' mesh_shape[0]: 1 mesh_shape[1]: -1 mesh_shape[2]: 1 diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-tiktoken.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-tiktoken.txt index 17ba6f233..e019219a8 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-tiktoken.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-tiktoken.txt @@ -186,6 +186,101 @@ mesh_rules[5][1][2]: 1 mesh_rules[5][1][3]: 8 mesh_rules[5][1][4]: 1 mesh_rules[5][1][5]: 1 +mesh_rules[6][0]: 'neuron-(trn2|trn2n).48xlarge-64' +mesh_rules[6][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier' +mesh_rules[6][1].config_modifiers[0].mesh_shape[0]: 1 +mesh_rules[6][1].config_modifiers[0].mesh_shape[1]: 1 +mesh_rules[6][1].config_modifiers[0].mesh_shape[2]: 1 +mesh_rules[6][1].config_modifiers[0].mesh_shape[3]: -1 +mesh_rules[6][1].config_modifiers[0].mesh_shape[4]: 1 +mesh_rules[6][1].config_modifiers[0].mesh_shape[5]: 4 +mesh_rules[6][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier' +mesh_rules[6][1].config_modifiers[1].modification.klass: 'axlearn.common.attention.StackedTransformerLayer' +mesh_rules[6][1].config_modifiers[1].modification.layer.feed_forward.activation: 'nn.relu' +mesh_rules[6][1].config_modifiers[1].modification.layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[6][1].config_modifiers[1].modification.layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer' +mesh_rules[6][1].config_modifiers[1].modification.layer.feed_forward.linear1.bias: True +mesh_rules[6][1].config_modifiers[1].modification.layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear' +mesh_rules[6][1].config_modifiers[1].modification.layer.feed_forward.linear1.param_partition_spec[0]: None +mesh_rules[6][1].config_modifiers[1].modification.layer.feed_forward.linear1.param_partition_spec[1]: 'model' +mesh_rules[6][1].config_modifiers[1].modification.layer.feed_forward.linear2.bias: True +mesh_rules[6][1].config_modifiers[1].modification.layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear' +mesh_rules[6][1].config_modifiers[1].modification.layer.feed_forward.linear2.param_partition_spec[0]: 'model' +mesh_rules[6][1].config_modifiers[1].modification.layer.feed_forward.linear2.param_partition_spec[1]: None +mesh_rules[6][1].config_modifiers[1].modification.layer.feed_forward.norm.eps: 1e-08 +mesh_rules[6][1].config_modifiers[1].modification.layer.feed_forward.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[6][1].config_modifiers[1].modification.layer.feed_forward.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[6][1].config_modifiers[1].modification.layer.feed_forward.residual_weight: 1.0 +mesh_rules[6][1].config_modifiers[1].modification.layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[6][1].config_modifiers[1].modification.layer.feed_forward.stochastic_depth.mode: 'row' +mesh_rules[6][1].config_modifiers[1].modification.layer.feed_forward.structure: 'prenorm' +mesh_rules[6][1].config_modifiers[1].modification.layer.klass: 'axlearn.common.attention.TransformerLayer' +mesh_rules[6][1].config_modifiers[1].modification.layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[6][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.QKVLinear' +mesh_rules[6][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.bias: True +mesh_rules[6][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[6][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[0]: None +mesh_rules[6][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[1]: 'model' +mesh_rules[6][1].config_modifiers[1].modification.layer.self_attention.attention.input_linear.layer.param_partition_spec[2]: None +mesh_rules[6][1].config_modifiers[1].modification.layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey' +mesh_rules[6][1].config_modifiers[1].modification.layer.self_attention.attention.klass: 'axlearn.common.attention.MultiheadAttention' +mesh_rules[6][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.bias: True +mesh_rules[6][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear' +mesh_rules[6][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[0]: None +mesh_rules[6][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model' +mesh_rules[6][1].config_modifiers[1].modification.layer.self_attention.attention.output_linear.param_partition_spec[2]: None +mesh_rules[6][1].config_modifiers[1].modification.layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery' +mesh_rules[6][1].config_modifiers[1].modification.layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout' +mesh_rules[6][1].config_modifiers[1].modification.layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer' +mesh_rules[6][1].config_modifiers[1].modification.layer.self_attention.norm.eps: 1e-08 +mesh_rules[6][1].config_modifiers[1].modification.layer.self_attention.norm.forward_dtype: 'jax.numpy.float32' +mesh_rules[6][1].config_modifiers[1].modification.layer.self_attention.norm.klass: 'axlearn.common.layers.LayerNorm' +mesh_rules[6][1].config_modifiers[1].modification.layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +mesh_rules[6][1].config_modifiers[1].modification.layer.self_attention.stochastic_depth.mode: 'row' +mesh_rules[6][1].config_modifiers[1].modification.layer.self_attention.structure: 'prenorm' +mesh_rules[6][1].config_modifiers[1].target_config: 'model.decoder.transformer' +mesh_rules[6][1].config_modifiers[2].klass: 'axlearn.common.trainer_config_modifier.ModelConfigModifier' +mesh_rules[6][1].config_modifiers[2].modification.klass: 'axlearn.common.attention.GroupedQKVLinear' +mesh_rules[6][1].config_modifiers[2].modification.layer.bias: True +mesh_rules[6][1].config_modifiers[2].modification.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +mesh_rules[6][1].config_modifiers[2].modification.layer.param_partition_spec[0]: None +mesh_rules[6][1].config_modifiers[2].modification.layer.param_partition_spec[1]: 'model' +mesh_rules[6][1].config_modifiers[2].modification.layer.param_partition_spec[2]: None +mesh_rules[6][1].config_modifiers[2].target_config: 'model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear' +mesh_rules[6][1].config_modifiers[3].klass: 'axlearn.common.trainer_config_modifier.PartitionSpecModifier' +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][0]: 'model' +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][0]: 'expert' +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][1]: 'fsdp' +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['param_partition_spec'][1][2]: 'seq' +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['input_partition_spec'][0]: 'fsdp' +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['input_partition_spec'][1]: None +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['output_partition_spec'][0]: 'fsdp' +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['output_partition_spec'][1]: 'model' +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['embedding_partition_spec'][0]: 'model' +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.emb.token_emb']['embedding_partition_spec'][1]: 'fsdp' +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][0]: 'model' +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][0]: 'expert' +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][1]: 'fsdp' +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.lm_head']['param_partition_spec'][1][2]: 'seq' +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][1]: 'model' +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['input_partition_spec'][2]: None +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][1]: None +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.self_attention.norm']['output_partition_spec'][2]: None +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][1]: 'model' +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['input_partition_spec'][2]: None +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][1]: None +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.transformer.layer.feed_forward.norm']['output_partition_spec'][2]: None +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][0]: 'fsdp' +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][1]: 'model' +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['input_partition_spec'][2]: None +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][0]: 'fsdp' +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][1]: None +mesh_rules[6][1].config_modifiers[3].partition_specs['model.decoder.output_norm']['output_partition_spec'][2]: None +mesh_rules[6][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' mesh_shape[0]: 1 mesh_shape[1]: -1 mesh_shape[2]: 1