Fix for rotations with tied weights
pablomlago committed Jan 13, 2025
1 parent c42a1ed commit 20cffcc
Showing 3 changed files with 82 additions and 29 deletions.
32 changes: 32 additions & 0 deletions src/brevitas/graph/equalize.py
@@ -1443,7 +1443,39 @@ def _apply_rotate(
     return rewriters


+from brevitas.utils.python_utils import recurse_getattr
+
+
+def _untie_parameters_with_parametrizations(model: torch.nn.Module):
+    # get ALL model parameters and their names
+    all_named_parameters = {
+        name: param for name, param in model.named_parameters(remove_duplicate=False)}
+
+    # get ONLY unique named parameters,
+    # if parameter is tied and have multiple names, it will be included only once
+    no_duplicate_named_parameters = {
+        name: param for name, param in model.named_parameters(remove_duplicate=True)}
+
+    # the difference of the two sets will give us the tied parameters
+    tied_param_names = set(all_named_parameters.keys()) - set(no_duplicate_named_parameters.keys())
+
+    for tied_param_name in tied_param_names:
+        tied_param_name_split = tied_param_name.split(".")
+        # Check if the tied parameter is the original parameter in the module
+        if len(tied_param_name_split) >= 3 and tied_param_name_split[
+                -3] == "parametrizations" and tied_param_name_split[-1] == "original":
+            # If that is the case, retrieve the parent module
+            parent_module = recurse_getattr(model, ".".join(tied_param_name_split[:-1]))
+            # And set to a new parameter, thus breaking the tie
+            setattr(parent_module, "original", nn.Parameter(all_named_parameters[tied_param_name]))
+
+    return model
+
+
 def _fuse_rotations(model: nn.Module) -> nn.Module:
+    # First of all, parameters that have parametrizations need to be untied
+    model = _untie_parameters_with_parametrizations(model)
+    # Then, parametrizations can be safely removed
     for module in model.modules():
         # Names of the tensors that can potentially be parametrized
         tensor_names = ["weight", "bias"]
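
For readers unfamiliar with how weight tying interacts with parametrizations, the following standalone sketch (not part of the commit; the modules, the Negate parametrization and all names are made up for illustration) reproduces the situation the new helper handles: two modules share one weight, both receive a parametrization, and the tied alias then surfaces as a "...parametrizations.weight.original" entry, which is exactly the name pattern _untie_parameters_with_parametrizations matches before assigning a fresh nn.Parameter to the "original" attribute.

import torch.nn as nn
import torch.nn.utils.parametrize as parametrize


class Negate(nn.Module):
    # Stand-in for a rotation parametrization: any shape-preserving transform
    def forward(self, x):
        return -x


# Tie the weights of two linear layers, then parametrize both of them
embedding = nn.Linear(4, 4, bias=False)
head = nn.Linear(4, 4, bias=False)
head.weight = embedding.weight
model = nn.ModuleDict({"embedding": embedding, "head": head})
parametrize.register_parametrization(embedding, "weight", Negate())
parametrize.register_parametrization(head, "weight", Negate())

# Same trick as in the helper: duplicated parameter names reveal the tie
all_names = {n for n, _ in model.named_parameters(remove_duplicate=False)}
unique_names = {n for n, _ in model.named_parameters(remove_duplicate=True)}
print(all_names - unique_names)
# Prints a single "<module>.parametrizations.weight.original" name; replacing
# that "original" attribute with a fresh nn.Parameter breaks the tie, so the
# parametrizations can afterwards be fused independently.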
65 changes: 41 additions & 24 deletions tests/brevitas/graph/equalization_fixtures.py
@@ -1,6 +1,8 @@
 # Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
 # SPDX-License-Identifier: BSD-3-Clause
 
+import functools
+
 from packaging import version
 import pytest
 import pytest_cases
@@ -540,34 +542,49 @@ def forward(self, x):
     "srcs": [], "sinks": ["block2_linear2"]},]
 
 
-@pytest_cases.fixture
-def block_residual_model():
+class BlockResidualModel(nn.Module):
 
-    class BlockResidualModel(nn.Module):
+    def __init__(self, is_tied: bool = False) -> None:
+        super().__init__()
+        self.embedding = nn.Linear(IN_FEATURES, IN_FEATURES, bias=False)
 
-        def __init__(self) -> None:
-            super().__init__()
-            self.embedding = nn.Linear(IN_FEATURES, IN_FEATURES, bias=False)
+        self.block1_linear1 = nn.Linear(IN_FEATURES, IN_FEATURES, bias=True)
+        self.block1_linear2 = nn.Linear(IN_FEATURES, IN_FEATURES, bias=False)
 
-            self.block1_linear1 = nn.Linear(IN_FEATURES, IN_FEATURES, bias=True)
-            self.block1_linear2 = nn.Linear(IN_FEATURES, IN_FEATURES, bias=False)
+        self.block2_linear1 = nn.Linear(IN_FEATURES, IN_FEATURES, bias=False)
+        self.act = nn.SiLU()
+        self.block2_linear2 = nn.Linear(IN_FEATURES, IN_FEATURES, bias=True)
 
-            self.block2_linear1 = nn.Linear(IN_FEATURES, IN_FEATURES, bias=False)
-            self.act = nn.SiLU()
-            self.block2_linear2 = nn.Linear(IN_FEATURES, IN_FEATURES, bias=True)
+        self.head = nn.Linear(IN_FEATURES, IN_FEATURES, bias=False)
+        if is_tied:
+            self.head.weight = self.embedding.weight
 
-            self.head = nn.Linear(IN_FEATURES, IN_FEATURES, bias=False)
+    def forward(self, x):
+        x = self.embedding(x)
+        r = x
+        x = self.block1_linear1(x)
+        x = self.block1_linear2(x) + r
+        r = x
+        x = self.block2_linear1(x)
+        x = self.act(x)
+        x = self.block2_linear2(x) + r
+        x = self.head(x)
+        return x
 
-        def forward(self, x):
-            x = self.embedding(x)
-            r = x
-            x = self.block1_linear1(x)
-            x = self.block1_linear2(x) + r
-            r = x
-            x = self.block2_linear1(x)
-            x = self.act(x)
-            x = self.block2_linear2(x) + r
-            x = self.head(x)
-            return x
 
-    return BlockResidualModel
+@pytest_cases.fixture
+def block_residual_model():
+    return functools.partial(BlockResidualModel, is_tied=False)
+
+
+@pytest_cases.fixture
+def block_residual_model_tied():
+    return functools.partial(BlockResidualModel, is_tied=True)
+
+
+list_of_rotation_fixtures = [
+    "block_residual_model",
+    "block_residual_model_tied",]
+
+rotation_model = fixture_union(
+    'rotation_model', list_of_rotation_fixtures, ids=list_of_rotation_fixtures)
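
As a quick sanity check of what the updated fixtures construct, the snippet below (a sketch, not part of the commit; it assumes BlockResidualModel is importable from this fixtures module) shows the only behavioural difference between the two fixtures: with is_tied=True the head and the embedding share a single Parameter object, which is the case the rotation fix has to handle.

from tests.brevitas.graph.equalization_fixtures import BlockResidualModel

model_untied = BlockResidualModel(is_tied=False)
model_tied = BlockResidualModel(is_tied=True)
# Tied variant: head and embedding point to the same Parameter object
assert model_tied.head.weight is model_tied.embedding.weight
# Untied variant: separate Parameter objects
assert model_untied.head.weight is not model_untied.embedding.weight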
14 changes: 9 additions & 5 deletions tests/brevitas/graph/test_equalization.py
@@ -413,12 +413,11 @@ def compare_model_weights(model_fused, model_unfused, classes_to_compare=(nn.Lin
     ids=lambda mask: "-".join([rot for mask_el, rot in zip(mask, ["R1", "R2", "R3"]) if mask_el]))
 @pytest_cases.parametrize('full_rotation_method', ['ort', 'had'])
 @pytest_cases.parametrize('device', ['cpu', 'cuda'] if torch.cuda.is_available() else ['cpu'])
-@pytest_cases.parametrize('fuse_rotations', [False, True], ids=["fused", "unfused"])
+@pytest_cases.parametrize('fuse_rotations', [False, True], ids=["unfused", "fused"])
 @pytest_cases.parametrize('use_fx', [True, False], ids=["fx", "no-fx"])
-def test_apply_rotate(
-        block_residual_model, mask, full_rotation_method, device, fuse_rotations, use_fx):
+def test_apply_rotate(rotation_model, mask, full_rotation_method, device, fuse_rotations, use_fx):
     # Instantiate a residual model for which a collection of regions is available
-    model = block_residual_model()
+    model = rotation_model()
     device = torch.device("cuda") if device == 'cuda' else torch.device("cpu")
     model.to(device)
     # Sample input to pass through the models
@@ -433,11 +432,16 @@ def test_apply_rotate(
     # The module names in the original model need to be mapped to the ones
     # in graph_model
     map_model_graph = {}
+    assigned_graph_modules = set()
     for graph_module_name, graph_module in graph_model.named_modules():
         if hasattr(graph_module, "weight"):
             for name, module in model.named_modules():
-                if hasattr(module, "weight") and graph_module.weight is module.weight:
+                # The check name not in map_model_graph prevents the assignment to the same module
+                # when tied parameters are present
+                if name not in map_model_graph and graph_module_name not in assigned_graph_modules and hasattr(
+                        module, "weight") and graph_module.weight is module.weight:
                     map_model_graph[name] = graph_module_name
+                    assigned_graph_modules.add(graph_module_name)
     # Replace the names of the modules in sources/sinks by the names of the modules in the FX model
     regions_dicts = [{
         k: list(map(lambda x: map_model_graph[x], v))
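
The reason the mapping between the original modules and the FX-traced modules needs the extra guards can be seen in the toy example below (illustrative only, not from the repository): when weights are tied, the identity check graph_module.weight is module.weight is true for several (name, graph module) pairs, so without the name and assigned_graph_modules checks two names could end up mapped to the same graph module.

import torch.nn as nn

# Two linears with tied weights, as in block_residual_model_tied
embedding = nn.Linear(4, 4, bias=False)
head = nn.Linear(4, 4, bias=False)
head.weight = embedding.weight
# Hypothetical name -> module dicts standing in for the eager and FX models
model_modules = {"embedding": embedding, "head": head}
graph_modules = {"embedding_fx": embedding, "head_fx": head}

map_model_graph = {}
assigned_graph_modules = set()
for graph_name, graph_module in graph_modules.items():
    for name, module in model_modules.items():
        # Without the first two checks, both names would match "embedding_fx"
        if name not in map_model_graph and graph_name not in assigned_graph_modules and \
                graph_module.weight is module.weight:
            map_model_graph[name] = graph_name
            assigned_graph_modules.add(graph_name)
print(map_model_graph)  # {'embedding': 'embedding_fx', 'head': 'head_fx'}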
