Partial fix to integration scales/parametrizations
pablomlago committed Jan 7, 2025
1 parent e3d853f commit 11a6bb3
Showing 6 changed files with 603 additions and 7 deletions.
46 changes: 45 additions & 1 deletion src/brevitas/graph/base.py
@@ -157,6 +157,50 @@ def _init_new_module(self, old_module: Module, name=None):
        return new_module

    def _replace_old_module(self, model, old_module, new_module, load_state_dict=True):
        replace_module(model, old_module, new_module)
        if load_state_dict:
            # The dictionary entries relative to parametrizations need to be ignored, as these are passed
            # when invoking transfer_parametrizations_and_params.
            old_module_state_dict = old_module.state_dict()

            # If the old module is parametrized, filter the state_dict appropriately
            if parametrize.is_parametrized(old_module):
                # Map the keys "parametrizations.tensor_name.original" to "tensor_name"
                keys_to_remove = []
                keys_value_to_add = []
                for key, value in old_module_state_dict.items():
                    split_key = key.split(".")
                    if len(split_key) >= 3 and split_key[-3] == "parametrizations" and split_key[
                            -1] == "original":
                        tensor_name = split_key[-2]
                        keys_value_to_add.append((".".join(split_key[:-3] + [tensor_name]), value))
                    # We need to remove all the keys corresponding to the parametrizations added to the model
                    # to make sure the dictionary can be loaded with no missing/unused keys
                    # NOTE: For safety, an additional check could be added, as this would not work if a model
                    # without parametrizations has any key containing "parametrizations"
                    if "parametrizations" in split_key:
                        keys_to_remove.append(key)
                # The modifications need to be reflected in old_module_state_dict
                for key in keys_to_remove:
                    del old_module_state_dict[key]
                for key, value in keys_value_to_add:
                    old_module_state_dict[key] = value

            # Note that strict is set to True, as all the adaptations to the state dict were performed
            new_module.load_state_dict(old_module_state_dict)
        # If the old module is parametrized, its parametrizations need to be transferred to the new module.
        # We do not rely on the method transfer_parametrizations_and_params, as using it can result
        # in parameter ties being broken.
        # Note that unsafe is set to True for efficiency, as the checks should have been done
        # when first registering the parametrization to old_module
        if parametrize.is_parametrized(old_module):
            for tensor_name in old_module.parametrizations:
                for param_func in old_module.parametrizations[tensor_name]:
                    parametrize.register_parametrization(
                        new_module, tensor_name, param_func, unsafe=True)

    # TODO: Remove after debugging
    def _replace_old_module_legacy(self, model, old_module, new_module, load_state_dict=True):
        replace_module(model, old_module, new_module)
        if load_state_dict:
            # The dictionary entries relative to parametrizations need to be ignored, as these are passed
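For context on the key remapping in _replace_old_module above: torch.nn.utils.parametrize moves the unparametrized tensor to the key parametrizations.<tensor_name>.original in a module's state_dict. A minimal sketch, separate from the commit, showing the renaming the new code has to undo:

import torch.nn as nn
import torch.nn.utils.parametrize as parametrize

class Negate(nn.Module):
    def forward(self, x):
        return -x

linear = nn.Linear(2, 2)
print(sorted(linear.state_dict().keys()))
# ['bias', 'weight']

parametrize.register_parametrization(linear, "weight", Negate())
print(sorted(linear.state_dict().keys()))
# ['bias', 'parametrizations.weight.original']

# _replace_old_module maps "parametrizations.weight.original" back to "weight"
# and drops every remaining "parametrizations.*" entry, so the filtered dict
# loads into the (still unparametrized) replacement module with strict=True.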
@@ -204,7 +248,7 @@ def apply(self, model: GraphModule) -> GraphModule:
            if old_module is self.old_module_instance:
                # register the parametrization in the old_module
                parametrize.register_parametrization(
-                   old_module, self.tensor_name, self.parametrization_module)
+                   old_module, self.tensor_name, self.parametrization_module, unsafe=True)
                break
        return model

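The manual transfer loop in _replace_old_module re-registers the same parametrization instances on the new module instead of calling parametrize.transfer_parametrizations_and_params, which, per the commit comments, can break parameter ties. A self-contained sketch of why reusing the instance preserves sharing; the Scale class is a toy stand-in, not brevitas code:

import torch
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize

class Scale(nn.Module):
    """Toy parametrization that owns a parameter of its own."""

    def __init__(self):
        super().__init__()
        self.scale = nn.Parameter(torch.ones(1))

    def forward(self, x):
        return x * self.scale

old = nn.Linear(4, 4)
scale = Scale()
parametrize.register_parametrization(old, "weight", scale)

new = nn.Linear(4, 4)
for tensor_name in old.parametrizations:
    for param_func in old.parametrizations[tensor_name]:
        # unsafe=True skips consistency checks that already ran when the
        # parametrization was first registered on old.
        parametrize.register_parametrization(new, tensor_name, param_func, unsafe=True)

# The same module instance now backs both parametrizations, so scale.scale
# stays a single shared parameter instead of being re-created on the target.
assert new.parametrizations["weight"][0] is scale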
2 changes: 1 addition & 1 deletion src/brevitas/graph/equalize.py
@@ -1443,7 +1443,7 @@ def _apply_rotate(
        if insert_rotation_module and len(region.srcs) == 0:
            rewriter = ModuleInstanceWrapModule(
                module, RotatedModule, "layer", {
-                   "had_mat": rot_mat, "k": K})
+                   "had_mat": None, "k": K})
            rewriters.append(rewriter)
        for r in rewriters:
            # The parametrizations need to be registered after the potential HF hooks have been
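On the equalize.py change: for sink-only regions the rewriter now passes had_mat=None instead of the materialized rot_mat, presumably leaving the wrapper to derive the rotation from k rather than storing a dense matrix. A hedged sketch of the instance-wrapping rewrite pattern; ToyRotatedModule and wrap_instance are hypothetical stand-ins, since neither RotatedModule's internals nor ModuleInstanceWrapModule's are shown in this diff:

import torch.nn as nn

class ToyRotatedModule(nn.Module):
    def __init__(self, layer: nn.Module, had_mat=None, k=1):
        super().__init__()
        self.layer = layer
        # had_mat=None mirrors the change above: no dense rotation matrix is
        # stored on the wrapper; k alone would parameterize the rotation.
        self.had_mat = had_mat
        self.k = k

    def forward(self, x):
        return self.layer(x)

def wrap_instance(model, old_module, wrapper_cls, kwarg_name, kwargs):
    # The old module is handed to the wrapper class under the given kwarg
    # name ("layer" above), together with the extra kwargs ("had_mat", "k").
    wrapper = wrapper_cls(**{kwarg_name: old_module}, **kwargs)
    # Snapshot parents before mutating so the new wrapper is not revisited.
    for parent in list(model.modules()):
        for name, child in list(parent.named_children()):
            if child is old_module:
                setattr(parent, name, wrapper)
    return model

model = nn.Sequential(nn.Linear(8, 8), nn.ReLU())
model = wrap_instance(model, model[0], ToyRotatedModule, "layer", {"had_mat": None, "k": 1})
assert isinstance(model[0], ToyRotatedModule)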
