Skip to content

Commit

Permalink
Add function source
Browse files Browse the repository at this point in the history
  • Loading branch information
pablomlago committed Jan 14, 2025
1 parent b1b59a2 commit 2d5ef36
Showing 1 changed file with 5 additions and 1 deletion.
6 changes: 5 additions & 1 deletion src/brevitas/graph/equalize.py
Original file line number Diff line number Diff line change
Expand Up @@ -1413,6 +1413,7 @@ def _apply_rotate(
return rewriters


# This function is adapted from https://github.com/huggingface/accelerate/blob/main/src/accelerate/utils/modeling.py
def _untie_parameters_with_parametrizations(model: torch.nn.Module):
# get ALL model parameters and their names
all_named_parameters = {
Expand All @@ -1428,7 +1429,10 @@ def _untie_parameters_with_parametrizations(model: torch.nn.Module):

for tied_param_name in tied_param_names:
tied_param_name_split = tied_param_name.split(".")
# Check if the tied parameter is the original parameter in the module
# The names of the original parameters after registering the parametrization
# have the format "prefix.parametrizations.tensor_name.original", e.g.
# "model.layer.parametrizations.weight.original". This allows to identify
# which subset of tied parameters are original tied parameters of the module
if len(tied_param_name_split) >= 3 and tied_param_name_split[
-3] == "parametrizations" and tied_param_name_split[-1] == "original":
# If that is the case, retrieve the parent module
Expand Down

0 comments on commit 2d5ef36

Please sign in to comment.