
TRN2 Meshes and Configurations #916

Open · apoorvtintin wants to merge 4 commits into main from mainline-upstream-boilerplate
Conversation

@apoorvtintin (Contributor)

This PR adds meshes for TRN2/1 for the Fuji models, plus a transformer layer configuration favorable to Neuron.

Neuron supports the stacked transformer, and GroupedQKVLinear instead of FusedGroupedQKVLinear for Grouped Query Attention (GQA).

This is a newer version of PR #885; it resolves all comments and requested changes from that PR.

@apoorvtintin requested review from ruomingp, markblee, and a team as code owners (January 10, 2025 00:48)
@apoorvtintin force-pushed the mainline-upstream-boilerplate branch 2 times, most recently from 6b404f6 to 3f7c840 (January 10, 2025 00:53)
@apoorvtintin (Contributor, Author)

Added a ModelConfigModifier that overrides the class for a module, allowing different model configurations based on model size and platform.
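A minimal usage sketch (hedged: the field names `target_config` and `modification`, and the module path, are inferred from snippets quoted later in this thread, not confirmed against the final code):

```python
from axlearn.common.attention import GroupedQKVLinear
from axlearn.common.trainer_config_modifier import ModelConfigModifier

# Swap the fused GQA projection for the unfused one that Neuron prefers.
modifier_cfg = ModelConfigModifier.default_config().set(
    # Dotted path to the module whose class should be replaced.
    # The exact target path in this PR may differ; this one is illustrative.
    target_config="model.decoder.transformer.layer.self_attention.attention.input_linear",
    modification=GroupedQKVLinear.default_config(),
)
# Applied to a trainer config, e.g.: cfg = modifier_cfg.instantiate()(cfg)
```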

@kelvin-zou (Contributor) left a comment:

Thank you for making this change; overall it looks good. A few nit comments.

```python
continue
# Here we assume x.y.z format.
# One example would be model.decoder.transformer.layer.
target_modules = module_name.split(".")
```
Contributor:

Can you try to extract a common util function, named something like:

```python
def replace_module_recursive(target_modules: str, config_key: str, target_config): ...
```

and apply it both here and in RematSpecModifier?
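A rough sketch of the proposed helper (illustrative only; it adds a `cfg` parameter for self-containedness and assumes configs expose children as attributes):

```python
def replace_module_recursive(cfg, target_modules: str, config_key: str, target_config):
    """Walks a dotted path such as "model.decoder.transformer.layer" and
    replaces the config stored under `config_key` on the final module."""
    current = cfg
    for key in target_modules.split("."):
        current = getattr(current, key)
    setattr(current, config_key, target_config)
    return cfg
```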

@apoorvtintin (Contributor, Author), Jan 10, 2025:

I extracted a helper function; let me know if this looks good.

@apoorvtintin force-pushed the mainline-upstream-boilerplate branch 2 times, most recently from 708fc5e to d481132 (January 10, 2025 07:38)
@apoorvtintin (Contributor, Author) commented Jan 10, 2025:

Added a ParameterPartitionSpecModifier to shard embedding parameters in a vocab-parallel manner, as described in Megatron-LM.
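For context, a hedged sketch of what vocab-parallel sharding means here, following Megatron-LM: the (vocab_size, hidden_dim) embedding weight is split across the "model" mesh axis along its vocab dimension:

```python
from jax.sharding import PartitionSpec

# Dim 0 (vocab) is sharded over the "model" mesh axis; dim 1 (hidden) is replicated.
vocab_parallel_embedding = PartitionSpec("model", None)
```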

@apoorvtintin force-pushed the mainline-upstream-boilerplate branch 2 times, most recently from 5be50d7 to 9b10041 (January 10, 2025 08:10)

```python
found_module, parent_module, key_in_parent = find_target_module(module_name, cfg)

# Copy configurations from the config being replaced on a best effort basis
```
Contributor:

Wait, this behavior is not explained in the class comments. So we are not replacing but merging the configs? Maybe we should support a merge function instead?

Contributor (Author):

Yeah, the goal is to change the config to a similar module, so most of the configuration can be reused from before: essentially replacing the module but merging the config. Let me extract a merge function.
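A best-effort merge could look like this sketch (illustrative only; it assumes axlearn-style configs expose their fields via keys(), and it skips klass so the replacement class wins):

```python
def merge_configs(target_cfg, source_cfg):
    """Copies fields shared by both configs from source to target, except klass."""
    for key in target_cfg.keys():
        if key != "klass" and hasattr(source_cfg, key):
            setattr(target_cfg, key, getattr(source_cfg, key))
    return target_cfg
```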

Contributor (Author):

Abstracted out a merge function; let me know if more changes are needed for this.

@apoorvtintin force-pushed the mainline-upstream-boilerplate branch from 9b10041 to 0f0a530 (January 12, 2025 07:06)
@apoorvtintin (Contributor, Author)

@ruomingp Thank you for the review. I have addressed all your comments; please let me know if more changes are needed.

@apoorvtintin requested a review from ruomingp (January 12, 2025 07:08)
Comment on lines 239 to 244
```python
for module_name, model_cfg in self._model_cfg_modifications.items():
    found_module = _find_target_module(module_name, cfg)
```
Contributor:

In utils.py we have get_recursively and set_recursively for Nested[...]. I wonder if it will be useful to add corresponding methods to ConfigBase. Then we can do something like:

Suggested change:

```diff
-for module_name, model_cfg in self._model_cfg_modifications.items():
-    found_module = _find_target_module(module_name, cfg)
+for cfg_path, cfg_modification in self._model_cfg_modifications.items():
+    child_cfg = cfg.get_recursively(cfg_path)
+    child_cfg = cfg_modification(child_cfg, path=cfg_path)
+    cfg.set_recursively(cfg_path, value=child_cfg)
```

Contributor (Author):

Added get_recursively and set_recursively functions to ConfigBase. Let me know if they look good.

Contributor:

I wonder if an alternative (which aims to simplify the ConfigBase API) is to do something similar to Python's sorted: we allow utils.get_recursively to take a value fn:

```python
# Default behavior is to use key lookup:
utils.get_recursively(..., value_fn=lambda k, v: v[k])

# Custom behavior can be attribute lookup:
utils.get_recursively(..., value_fn=lambda k, v: getattr(v, k))
```

A benefit is that other non-config instances can also leverage get_recursively.
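A sketch of how such a value_fn hook could work (a hypothetical helper, not the actual utils implementation):

```python
def get_recursively(obj, path, *, value_fn=lambda k, v: v[k]):
    """Follows `path` into `obj`, delegating each lookup step to `value_fn`."""
    for key in path:
        obj = value_fn(key, obj)
    return obj

# Dict-style key lookup (the default):
assert get_recursively({"a": {"b": 1}}, ["a", "b"]) == 1
# Attribute lookup for config-like objects:
# get_recursively(cfg, ["model", "decoder"], value_fn=lambda k, v: getattr(v, k))
```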

Contributor:

Hi @markblee, maybe we can do this in a follow-up PR?

@apoorvtintin (Contributor, Author) commented Jan 15, 2025:

Added a more flexible PartitionSpecModifier that can modify multiple partition_spec attributes in a single module config.
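A usage sketch, assuming the modifier maps a module path to the partition_spec attributes to override (the dict layout, attribute names, and paths here are illustrative, not taken from the final code):

```python
from axlearn.common.trainer_config_modifier import PartitionSpecModifier

modifier_cfg = PartitionSpecModifier.default_config().set(
    partition_specs={
        # Vocab-parallel embedding: shard the (vocab, hidden) weight on "model".
        "model.decoder.emb.token_emb": {
            "param_partition_spec": ("model", None),
        },
        # Multiple partition_spec attributes of one module can be set at once.
        "model.decoder.lm_head": {
            "param_partition_spec": ("model", None),
        },
    },
)
```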

@apoorvtintin force-pushed the mainline-upstream-boilerplate branch 2 times, most recently from 45c7df1 to 8807856 (January 17, 2025 01:17)
@kelvin-zou (Contributor) left a comment:

Mostly LGTM; some minor comments.

@apoorvtintin force-pushed the mainline-upstream-boilerplate branch from 8807856 to 25510d6 (January 22, 2025 01:39)
@apoorvtintin force-pushed the mainline-upstream-boilerplate branch 2 times, most recently from eec33eb to 86bafa8 (January 23, 2025 05:40)
@apoorvtintin force-pushed the mainline-upstream-boilerplate branch 3 times, most recently from fe96240 to da90757 (January 24, 2025 23:32)
@apoorvtintin requested a review from ruomingp (January 24, 2025 23:40)
@apoorvtintin force-pushed the mainline-upstream-boilerplate branch from da90757 to 7e2e5f2 (January 27, 2025 10:07)
@apoorvtintin (Contributor, Author) commented Jan 27, 2025:

@ruomingp and @kelvin-zou, thank you both for the review. I addressed all comments; please let me know if any more changes are needed. The PR looks clean now.

```python
key: str

def recursive_traverse(self, key_path: Sequence[str]) -> tuple[Any, str]:
    """Recursively traverse the config to find the target key.
```
Contributor:

Please see the other comment re get_recursively; also, I wonder whether we actually need recursion here (it seems like a loop would be simpler).

Contributor (Author):

Both loops and recursion are fair choices; I was following the pattern of utils.get_recursively() based on previous guidance.

```diff
@@ -146,6 +137,110 @@ def __call__(self, cfg: SpmdTrainer.Config) -> SpmdTrainer.Config:
         return cfg


 class ModelConfigModifier(ConfigModifier):
```
Contributor:

Which part of this class is specific to model? It seems to take generic modifications?

@apoorvtintin (Contributor, Author), Jan 31, 2025:

The class makes some assumptions about the klass attribute of configs when merging/reusing the previous configs, which makes it non-generic.

Contributor (Author):

@markblee I would love some more guidance on a name that fits this class better; do you have any suggestions?

Contributor:

Hm, but all configs deriving from Configurable have klass. Did you mean ModuleConfigModifier perhaps, as it replaces arbitrary modules by path? (model -> module?)

Contributor (Author):

Yeah, I agree ModuleConfigModifier is better, thanks!

```python
if version != Version.V1:
    trn2_model_modifications.append(
        ModelConfigModifier.default_config().set(
            target_config="model.decoder.transformer.layer.self_attention.attention."
```
Contributor:

A downside of representing these deeply nested configs as string paths is that they are brittle and can quickly become outdated.

Have we considered using cfg.visit to achieve some of these modifications? E.g.:

```python
def set_layer_norm_eps_recursively(cfg: ConfigBase, eps: float, set_only_if_none: bool = False):
```

(A bit late to review, so apologies if this discussion has already taken place.)
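For illustration, the reviewer's example could be implemented with a plain recursive walk (a sketch assuming ConfigBase exposes keys() and items(); the real cfg.visit API may differ):

```python
from axlearn.common.config import ConfigBase

def set_layer_norm_eps_recursively(cfg: ConfigBase, eps: float, set_only_if_none: bool = False):
    """Sets `eps` on every nested child config that defines that field."""
    for _, child in cfg.items():
        if isinstance(child, ConfigBase):
            if "eps" in child.keys() and not (set_only_if_none and child.eps is not None):
                child.eps = eps
            # Recurse so deeply nested layer norms are also covered.
            set_layer_norm_eps_recursively(child, eps, set_only_if_none)
```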

"""Recursively find the target key in the config and return its value.

Args:
key_path: A sequence of keys for indexing to get the target value.
Contributor:

Can path be empty? Maybe it can return self if path is empty?

Contributor (Author):

Added a check for an empty path; a ValueError is now raised if the path is empty.

```python
Raises:
    ValueError: A key in key_path is not found.
"""
traverse_result = self.recursive_traverse(key_path)
```
Contributor:

Can we do something like:

Suggested change:

```diff
-traverse_result = self.recursive_traverse(key_path)
+if not path:
+    raise ValueError(...)
+parent = self.get_recursively(path[:-1])
```

@apoorvtintin (Contributor, Author), Jan 31, 2025:

I removed the recursive_traverse method and added a ValueError for an empty path. I am still using path[1:] instead of path[:-1] since we traverse parents first.

@apoorvtintin force-pushed the mainline-upstream-boilerplate branch from 7e2e5f2 to 73e14a6 (January 31, 2025 19:05)
@apoorvtintin (Contributor, Author)

Thank you for the reviews, @markblee, @ruomingp, and @kelvin-zou!
@ruomingp I addressed most of your comments; please let me know if more changes are needed. @markblee We can discuss using cfg.visit instead of full module paths further, but I would like to address that in a follow-up PR if possible.

@markblee (Contributor) left a comment:

Hi @apoorvtintin, please feel free to leave TODO(markblee): for my comments and go ahead once @ruomingp approves. I'll be happy to send a follow-up PR to avoid blocking this one.


@apoorvtintin force-pushed the mainline-upstream-boilerplate branch from 73e14a6 to 61efbb5 (February 3, 2025 09:28)
@apoorvtintin (Contributor, Author) commented Feb 3, 2025:

Addressed all comments. @markblee I added TODOs for us to revisit cfg.visit later; please let me know if I should add more.

Comment on lines +419 to +421

```python
if index == len(path):
    return current
```
Contributor:

Suggested change (delete these lines):

```diff
-if index == len(path):
-    return current
```

Comment on lines +440 to +447

```python
current = self
for i, key in enumerate(path):
    if i == len(path) - 1:
        setattr(current, key, value)
        return
    else:
        # TODO(markblee): maybe use cfg.visit instead of getattr
        current = getattr(current, key)
```
Contributor:

Suggested change:

```diff
-current = self
-for i, key in enumerate(path):
-    if i == len(path) - 1:
-        setattr(current, key, value)
-        return
-    else:
-        # TODO(markblee): maybe use cfg.visit instead of getattr
-        current = getattr(current, key)
+parent = self.get_recursively(path[:-1])
+setattr(parent, path[-1], value)
```

@apoorvtintin force-pushed the mainline-upstream-boilerplate branch from 61efbb5 to dfbb412 (February 3, 2025 18:45)