TRN2 Meshes and Configurations #916
base: main
Conversation
Force-pushed from 6b404f6 to 3f7c840.
Added a ModelConfigModifier that overrides the class for a module, allowing different model configurations based on model size and platform.
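For readers skimming the thread, here is a minimal, standalone sketch of the idea (the function name, and the SimpleNamespace objects standing in for nested configs, are illustrative and not the axlearn implementation): address a sub-config by a dotted path and swap the class it instantiates while leaving its other settings intact.

```python
from types import SimpleNamespace


def override_module_class(cfg, module_path: str, new_klass):
    """Replaces the `klass` of the sub-config addressed by a dotted path."""
    *parents, leaf = module_path.split(".")
    node = cfg
    for key in parents:
        node = getattr(node, key)  # walk e.g. cfg.model -> .decoder -> .transformer
    target = getattr(node, leaf)
    target.klass = new_klass  # swap the implementation class; other fields are untouched


# Toy usage: nested SimpleNamespace objects stand in for nested configs, and the
# class names here are plain strings purely for illustration.
layer = SimpleNamespace(klass="RepeatedTransformerLayer", remat_spec=None)
cfg = SimpleNamespace(
    model=SimpleNamespace(decoder=SimpleNamespace(transformer=SimpleNamespace(layer=layer)))
)
override_module_class(cfg, "model.decoder.transformer.layer", "StackedTransformerLayer")
assert cfg.model.decoder.transformer.layer.klass == "StackedTransformerLayer"
```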
Thank you for making this change; overall it looks good. A few nit comments.
continue
# Here we assume x.y.z format.
# One example would be model.decoder.transformer.layer.
target_modules = module_name.split(".")
Can you try to extract a common util function, named something like
def replace_module_recursive(target_modules: str, config_key: str, target_config)
and apply it to both here and RematSpecModifier?
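A rough sketch of what such a shared helper could look like (the exact name and signature in the PR differ; later revisions use find_target_module, which returns the node together with its parent and key so both ModelConfigModifier and RematSpecModifier can reuse it):

```python
from typing import Any, Tuple


def find_target_module(module_name: str, cfg: Any) -> Tuple[Any, Any, str]:
    """Walks an x.y.z-style path (e.g. "model.decoder.transformer.layer") into a
    nested config and returns (found_module, parent_module, key_in_parent)."""
    target_modules = module_name.split(".")
    parent, current = None, cfg
    for key in target_modules:
        if not hasattr(current, key):
            raise ValueError(f"{key} is not found in {current}.")
        parent, current = current, getattr(current, key)
    return current, parent, target_modules[-1]
```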
I extracted a helper function; let me know if this looks good.
Force-pushed from 708fc5e to d481132.
Added
Force-pushed from 5be50d7 to 9b10041.
found_module, parent_module, key_in_parent = find_target_module(module_name, cfg)
# Copy configurations from the config being replaced on a best effort basis |
Wait, this behavior is not explained in the class comments. So we are not replacing but merging the configs? Maybe we should support a merge function instead?
Yeah, the goal is to change the config to a similar module, so most of the configuration can be reused from before. Essentially we're replacing the module but merging the config. Let me extract out a merge function.
Abstracted out a merge function; let me know if more changes are needed for this.
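A simplified sketch of what such a best-effort merge could look like (iterating fields via keys() is an assumption about the config API, not the exact axlearn code): copy every field of the old config that the replacement also declares, so most settings carry over and only the class genuinely changes.

```python
def merge_configs(target_cfg, source_cfg):
    """Copies fields from source_cfg into target_cfg on a best-effort basis."""
    for key in source_cfg.keys():  # assumes a dict-like keys() view on configs
        if key == "klass":
            continue  # keep the replacement's class; that is the point of the swap
        if hasattr(target_cfg, key):
            setattr(target_cfg, key, getattr(source_cfg, key))
    return target_cfg
```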
Force-pushed from 9b10041 to 0f0a530.
@ruomingp Thank you for the review. I have addressed all your comments; please let me know if more changes are needed.
for module_name, model_cfg in self._model_cfg_modifications.items():
    found_module = _find_target_module(module_name, cfg)
In utils.py we have get_recursively and set_recursively for Nested[...]. I wonder if it will be useful to add corresponding methods to ConfigBase. Then we can do something like:
for cfg_path, cfg_modification in self._model_cfg_modifications.items():
    child_cfg = cfg.get_recursively(cfg_path)
    child_cfg = cfg_modification(child_cfg, path=cfg_path)
    cfg.set_recursively(cfg_path, value=child_cfg)
Added get_recursively and set_recursively functions to ConfigBase. Let me know if it looks good.
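A simplified, standalone sketch of the pair of helpers discussed here (the actual methods in the PR live on ConfigBase; the attribute lookup and the ValueError on an empty path mirror the behavior described later in the thread, but the free-function form below is only illustrative):

```python
from typing import Any, Sequence


def get_recursively(cfg: Any, path: Sequence[str]) -> Any:
    """Returns the value at `path`, e.g. ("model", "decoder", "transformer", "layer")."""
    if not path:
        raise ValueError("path must not be empty.")
    current = cfg
    for key in path:
        current = getattr(current, key)
    return current


def set_recursively(cfg: Any, path: Sequence[str], value: Any) -> None:
    """Sets the value at `path` by locating the parent first."""
    if not path:
        raise ValueError("path must not be empty.")
    parent = get_recursively(cfg, path[:-1]) if len(path) > 1 else cfg
    setattr(parent, path[-1], value)
```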
I wonder if an alternative (which aims to simplify the ConfigBase API) is to do something similar to Python's sorted: we allow utils.get_recursively to take a value_fn:
# Default behavior is to use key lookup:
utils.get_recursively(..., value_fn=lambda k, v: v[k])
# Custom behavior can be attribute lookup:
utils.get_recursively(..., value_fn=lambda k, v: getattr(v, k))
A benefit is that other non-config instances can also leverage get_recursively.
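Spelled out as a standalone sketch (this is the proposed follow-up, not code in this PR; the default value_fn matches the key-lookup example above):

```python
from types import SimpleNamespace
from typing import Any, Callable, Sequence


def get_recursively(
    target: Any,
    path: Sequence[str],
    value_fn: Callable[[str, Any], Any] = lambda k, v: v[k],  # default: key lookup
) -> Any:
    """Walks `path` into `target`, using `value_fn` for each lookup step."""
    current = target
    for key in path:
        current = value_fn(key, current)
    return current


# Key lookup for nested dicts (the default):
nested = {"model": {"decoder": {"eps": 1e-8}}}
assert get_recursively(nested, ["model", "decoder", "eps"]) == 1e-8

# Attribute lookup for config-like objects:
cfg = SimpleNamespace(model=SimpleNamespace(decoder=SimpleNamespace(eps=1e-8)))
assert get_recursively(cfg, ["model", "decoder", "eps"], value_fn=lambda k, v: getattr(v, k)) == 1e-8
```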
Hi @markblee, maybe we can do this in a follow-up PR?
Force-pushed from c23e3b2 to 94bfff6.
Added a more flexible
Force-pushed from 45c7df1 to 8807856.
Mostly lgtm, some minor comments.
Force-pushed from 8807856 to 25510d6.
Force-pushed from eec33eb to 86bafa8.
Force-pushed from fe96240 to da90757.
Force-pushed from da90757 to 7e2e5f2.
@ruomingp and @kelvin-zou thank you both for the review. I addressed all comments; please let me know if any more changes are needed. The PR looks clean now.
axlearn/common/config.py (outdated)
key: str

def recursive_traverse(self, key_path: Sequence[str]) -> tuple[Any, str]:
    """Recursively traverse the config to find the target key.
Please see the other comment re get_recursively; also, I wonder whether we actually need recursion here (seems like a loop would be simpler).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Both loops and recursion are fair choices; I was following the pattern of utils.get_recursively() based on previous guidance.
@@ -146,6 +137,110 @@ def __call__(self, cfg: SpmdTrainer.Config) -> SpmdTrainer.Config:
    return cfg


class ModelConfigModifier(ConfigModifier):
Which part of this class is specific to model? It seems to take generic modifications?
The class makes some assumptions about the klass attribute of configs when merging/reusing the previous configs, which makes it non-generic.
@markblee I would love to get some more guidance on a name that fits this class better; do you have any suggestions?
Hm, but all configs deriving from Configurable have klass. Did you mean ModuleConfigModifier perhaps, as it replaces arbitrary modules by path? (model -> module?)
Yeah I agree ModuleConfigModifier is better, thanks!
if version != Version.V1:
    trn2_model_modifications.append(
        ModelConfigModifier.default_config().set(
            target_config="model.decoder.transformer.layer.self_attention.attention."
A downside of representing these deeply nested configs as string paths is that they are brittle, and can quickly become outdated.
Have we considered using cfg.visit to achieve some of these modifications? E.g., in axlearn/common/layers.py (line 266 at 1c22688):
def set_layer_norm_eps_recursively(cfg: ConfigBase, eps: float, set_only_if_none: bool = False):
(A bit late to review, so apologies if this discussion has already taken place.)
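As a rough illustration of that alternative (a standalone sketch, not axlearn's cfg.visit API or the set_layer_norm_eps_recursively implementation): recursively visit every sub-config and modify the ones matching a predicate, so the change keeps working even if the exact nesting moves.

```python
from types import SimpleNamespace
from typing import Any, Callable


def modify_recursively(
    cfg: Any,
    should_modify: Callable[[Any], bool],
    modify: Callable[[Any], None],
) -> None:
    """Visits every sub-config depth-first and applies `modify` where the predicate matches."""
    if should_modify(cfg):
        modify(cfg)
    # SimpleNamespace stand-in; real configs expose their children differently.
    for child in vars(cfg).values():
        if isinstance(child, SimpleNamespace):
            modify_recursively(child, should_modify, modify)


# Example: set `eps` on every node that has one, regardless of where it sits.
attention = SimpleNamespace(norm=SimpleNamespace(eps=1e-6))
cfg = SimpleNamespace(layer=SimpleNamespace(self_attention=attention))
modify_recursively(cfg, lambda c: hasattr(c, "eps"), lambda c: setattr(c, "eps", 1e-8))
assert cfg.layer.self_attention.norm.eps == 1e-8
```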
axlearn/common/config.py (outdated)
"""Recursively find the target key in the config and return its value. | ||
|
||
Args: | ||
key_path: A sequence of keys for indexing to get the target value. |
Can path be empty? Maybe it can return self if path is empty?
Added a check; it now throws a ValueError if the path is empty.
axlearn/common/config.py (outdated)
Raises:
    ValueError: A key in key_path is not found.
"""
traverse_result = self.recursive_traverse(key_path)
Can we do something like:
if not path:
    raise ValueError(...)
parent = self.get_recursively(path[:-1])
I removed the recursive_traverse method and added a ValueError if the path is empty. I am still using path[1:] instead of path[:-1] since we are traversing parents first.
Force-pushed from 7e2e5f2 to 73e14a6.
Thank you for the reviews @markblee @ruomingp @kelvin-zou!
Hi @apoorvtintin, please feel free to leave TODO(markblee): for my comments and go ahead once @ruomingp approves. I'll be happy to send a follow-up PR to avoid blocking this one.
Force-pushed from 73e14a6 to 61efbb5.
Addressed all comments. @markblee I added TODOs for us to revisit cfg.visit later; please let me know if I should add more.
if index == len(path):
    return current
current = self
for i, key in enumerate(path):
    if i == len(path) - 1:
        setattr(current, key, value)
        return
    else:
        # TODO(markblee): maybe use cfg.visit instead of getattr
        current = getattr(current, key)
Suggested change:
parent = self.get_recursively(path[:-1])
setattr(parent, path[-1], value)
+ address comments
Force-pushed from 61efbb5 to dfbb412.
This PR adds meshes for TRN2/1 for Fuji models and a transformer layer configuration favorable to Neuron.
Neuron supports the stacked transformer and GroupedQKVLinear instead of FusedGroupedQKVLinear for Grouped Query Attention (GQA).
This is a newer version of PR #885; it resolves all comments and requested changes from the linked PR.
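As a rough, non-verbatim sketch of how such a substitution is expressed via the modifier added in this PR (the `target_config` field appears in the quoted diff, but the import paths, the field carrying the replacement config, and the list layout below are assumptions for illustration only):

```python
# Import paths are assumptions; ModuleConfigModifier is the modifier added/renamed in
# this PR, and StackedTransformerLayer is the Neuron-friendly transformer layout.
from axlearn.common.attention import StackedTransformerLayer
from axlearn.common.trainer_config_modifier import ModuleConfigModifier  # path assumed

trn2_model_modifications = [
    ModuleConfigModifier.default_config().set(
        # `target_config` is the dotted-path field shown in the quoted diff.
        target_config="model.decoder.transformer.layer",
        # The name of the field carrying the replacement config is assumed here.
        modification=StackedTransformerLayer.default_config(),
    ),
    # A second modifier (its full dotted path is truncated in the quoted diff) swaps
    # FusedGroupedQKVLinear for GroupedQKVLinear in the GQA attention input projection.
]
```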