Commit

Run precommit

pablomlago committed Jan 10, 2025
1 parent fa9d0e0 commit 8b71a53

Showing 6 changed files with 133 additions and 84 deletions.
16 changes: 11 additions & 5 deletions src/brevitas/graph/base.py
@@ -5,14 +5,14 @@
from abc import abstractmethod
import inspect
from inspect import getcallargs
from typing import Any, Callable, Dict, Type, Union, Optional
from typing import Any, Callable, Dict, Optional, Type, Union

import torch
from torch import Tensor
from torch.nn import Module
from torch.nn import Parameter
from torch import Tensor
from torch.overrides import get_testing_overrides
import torch.nn.utils.parametrize as parametrize
from torch.overrides import get_testing_overrides

from brevitas.fx import GraphModule
from brevitas.fx import immutable_dict
@@ -177,6 +177,7 @@ def apply(self, graph_model: GraphModule) -> GraphModule:
graph_model.graph.lint()
return graph_model


class ModuleInstanceRegisterParametrization(Transform):

def __init__(
@@ -195,6 +196,7 @@ def apply(self, model: GraphModule) -> GraphModule:
break
return model


class RotationWeightParametrization(torch.nn.Module):

def __init__(
@@ -221,6 +223,7 @@ def forward(self, weight: torch.Tensor) -> torch.Tensor:

return weight


class ModuleInstanceFuseRotationWeights(Transform):

def __init__(
@@ -245,15 +248,17 @@ def apply(self, model: GraphModule) -> GraphModule:
if hasattr(old_module, 'allocate_params'):
old_module.allocate_params(old_module)
weight = getattr(old_module, self.tensor_name).data
weight = RotationWeightParametrization(self.rot_mat, self.rot_func, self.axis, self.K)(weight)
weight = RotationWeightParametrization(
self.rot_mat, self.rot_func, self.axis, self.K)(weight)
# Modify the weights in-place
getattr(old_module, self.tensor_name).data = weight

if hasattr(old_module, 'offload_params'):
old_module.offload_params(old_module)
break
return model



class ModuleInstanceWrapModule(Transform):

def __init__(
@@ -278,6 +283,7 @@ def apply(self, model: GraphModule) -> GraphModule:
break
return model


class ModuleInstanceToModuleInstance(Transform):

def __init__(self, old_module_instance, new_module_instance):
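The transforms above hinge on torch.nn.utils.parametrize: a rotation can either be fused directly into a tensor's .data (as ModuleInstanceFuseRotationWeights does) or registered lazily as a parametrization (as ModuleInstanceRegisterParametrization does). A minimal stand-alone sketch of the lazy path, using a hypothetical _RotateWeight module rather than the Brevitas classes and assuming a square orthogonal rot_mat applied along the weight's input axis:

```python
import torch
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize


class _RotateWeight(nn.Module):
    """Hypothetical parametrization: right-multiply the weight by an orthogonal matrix."""

    def __init__(self, rot_mat: torch.Tensor):
        super().__init__()
        self.rot_mat = nn.Parameter(rot_mat, requires_grad=False)

    def forward(self, weight: torch.Tensor) -> torch.Tensor:
        return weight @ self.rot_mat


linear = nn.Linear(4, 4, bias=False)
q, _ = torch.linalg.qr(torch.randn(4, 4))  # random orthogonal matrix

x = torch.randn(2, 4)
out_before = linear(x)

# The underlying tensor is left untouched; linear.weight is recomputed lazily
# on access, which is the alternative to fusing the rotation in-place.
parametrize.register_parametrization(linear, "weight", _RotateWeight(q))
out_after = linear(x @ q)  # rotating the input undoes the weight rotation

assert torch.allclose(out_before, out_after, atol=1e-5)
```

Fusing instead evaluates the same callable once and writes the result back into the tensor's .data, which is what the ModuleInstanceFuseRotationWeights.apply hunk above does.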
13 changes: 10 additions & 3 deletions src/brevitas/graph/equalize.py
@@ -23,9 +23,9 @@
from brevitas.graph import ModuleToModuleByInstance
from brevitas.graph.base import GraphTransform
from brevitas.graph.base import InsertModuleCallAfter
from brevitas.graph.base import ModuleInstanceToModuleInstance
from brevitas.graph.base import ModuleInstanceFuseRotationWeights
from brevitas.graph.base import ModuleInstanceRegisterParametrization
from brevitas.graph.base import ModuleInstanceToModuleInstance
from brevitas.graph.base import ModuleInstanceWrapModule
from brevitas.graph.base import RotationWeightParametrization
from brevitas.graph.base import Transform
@@ -1304,13 +1304,19 @@ def random_orthogonal_matrix(size):
return q


def _apply_rotate(model: nn.Module, regions: List[Region], full_rotation_method='had', fuse_rotations: bool = True, apply_inplace_rotations: bool = True):
def _apply_rotate(
model: nn.Module,
regions: List[Region],
full_rotation_method='had',
fuse_rotations: bool = True,
apply_inplace_rotations: bool = True):
rewriters = []
# First, rotations on orphan sinks are applied so the order in which rotations are
# applied is consistent, irrespective of the value of fuse_rotations. This is due to
# the fact that parametrizations need to be registered once all the in-place
# operations have taken place
regions = [region for region in regions if len(region.srcs) == 0] + [region for region in regions if len(region.srcs) > 0]
regions = [region for region in regions if len(region.srcs) == 0] + [
region for region in regions if len(region.srcs) > 0]
for region in regions:
insert_rotation_module = len(region.srcs) == 0

@@ -1436,6 +1442,7 @@ def _apply_rotate(model: nn.Module, regions: List[Region], full_rotation_method=
model = r.apply(model)
return rewriters


def _fuse_rotations(model: nn.Module) -> nn.Module:
for module in model.modules():
# Names of the tensors that can potentially be parametrized
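_fuse_rotations (body truncated above) walks the model and collapses any registered rotation parametrizations back into plain tensors. A rough sketch of the general pattern, with a hypothetical helper name and assuming PyTorch's remove_parametrizations with leave_parametrized=True is the mechanism used:

```python
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize


def fuse_parametrized_tensors(model: nn.Module, tensor_names=("weight", "bias")) -> nn.Module:
    """Illustrative only: bake registered parametrizations into the underlying tensors."""
    for module in model.modules():
        for tensor_name in tensor_names:
            if parametrize.is_parametrized(module, tensor_name):
                # leave_parametrized=True keeps the current (rotated) value as
                # the plain tensor, which is what "fusing" means here
                parametrize.remove_parametrizations(
                    module, tensor_name, leave_parametrized=True)
    return model
```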
17 changes: 7 additions & 10 deletions tests/brevitas/graph/equalization_fixtures.py
@@ -533,19 +533,16 @@ def forward(self, x):
RESIDUAL_MODEL_REGION_DICTS = [
{
"srcs": ["embedding", "block1_linear2", "block2_linear2"],
"sinks": ["block1_linear1", "block2_linear1", "head"],
},
"sinks": ["block1_linear1", "block2_linear1", "head"],},
{
"srcs": ["block1_linear1"],
"sinks": ["block1_linear2"]
},
"srcs": ["block1_linear1"], "sinks": ["block1_linear2"]},
{
"srcs": [],
"sinks": ["block2_linear2"]
},
]
"srcs": [], "sinks": ["block2_linear2"]},]


@pytest_cases.fixture
def block_residual_model():

class BlockResidualModel(nn.Module):

def __init__(self) -> None:
@@ -572,5 +569,5 @@ def forward(self, x):
x = self.block2_linear2(x) + r
x = self.head(x)
return x

return BlockResidualModel
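Each entry in RESIDUAL_MODEL_REGION_DICTS names source modules whose outputs feed the listed sink modules. That pairing is what makes a rotation lossless: an orthogonal matrix folded into the sources' output channels can be cancelled on the sinks' input channels. A tiny stand-alone sketch of the invariance (a hypothetical two-layer pair, not one of the fixtures):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
src = nn.Linear(8, 8, bias=False)   # plays the role of a region "src"
sink = nn.Linear(8, 8, bias=False)  # plays the role of a region "sink"
q, _ = torch.linalg.qr(torch.randn(8, 8))  # random orthogonal matrix

x = torch.randn(4, 8)
reference = sink(src(x))

# Rotate the source's output channels by Q and compensate on the sink's input
# channels; with nn.Linear's transposed weight storage this is a left-multiply
# by Q^T on the source and a right-multiply by Q on the sink.
src.weight.data = q.t() @ src.weight.data
sink.weight.data = sink.weight.data @ q
rotated = sink(src(x))

assert torch.allclose(reference, rotated, atol=1e-5)
```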
120 changes: 74 additions & 46 deletions tests/brevitas/graph/test_equalization.py
@@ -2,8 +2,9 @@
# SPDX-License-Identifier: BSD-3-Clause

import copy
from functools import partial
from functools import reduce
import itertools
from functools import partial, reduce
from unittest.mock import patch

import pytest
@@ -14,27 +15,27 @@
from brevitas.fx import symbolic_trace
from brevitas.graph.base import ModuleInstanceRegisterParametrization
from brevitas.graph.base import RotationWeightParametrization
from brevitas.graph.equalize import Region
from brevitas.graph.equalize import EqualizationIndexes
from brevitas.graph.equalize import _apply_had_device
from brevitas.graph.equalize import _apply_ort_device
from brevitas.graph.equalize import _apply_rotate
from brevitas.graph.equalize import _batch_norm
from brevitas.graph.equalize import _extract_regions
from brevitas.graph.equalize import _fuse_rotations
from brevitas.graph.equalize import _get_input_axis
from brevitas.graph.equalize import _get_output_axis
from brevitas.graph.equalize import _is_supported_module
from brevitas.graph.equalize import _supported_layers
from brevitas.graph.equalize import activation_equalization_mode
from brevitas.graph.equalize import EqualizationIndexes
from brevitas.graph.equalize import GraphRotationEqualization
from brevitas.graph.equalize import MergeLnAffine
from brevitas.graph.equalize import random_orthogonal_matrix
from brevitas.graph.equalize import _apply_rotate
from brevitas.graph.equalize import _apply_had_device
from brevitas.graph.equalize import _fuse_rotations
from brevitas.graph.equalize import _apply_ort_device
from brevitas.graph.equalize import _get_input_axis
from brevitas.graph.equalize import _get_output_axis
from brevitas.graph.equalize import Region
from brevitas.graph.hadamard import get_hadK
from brevitas.nn.equalized_layer import RotatedModule
from brevitas.graph.standardize import DuplicateSharedStatelessModule
from brevitas.graph.standardize import TorchFunctionalToModule
from brevitas.graph.utils import get_module
from brevitas.nn.equalized_layer import RotatedModule
from tests.marker import requires_pt_ge

from .equalization_fixtures import *
@@ -294,7 +295,7 @@ def test_models(rotation_fixtures, partial_had):
if partial_had:
last_weight_new = model.linear_2.layer.weight.data
assert not torch.allclose(last_weight, last_weight_new)


@pytest_cases.parametrize('N', [1, 2, 3], ids=lambda x: f"N={x}")
def test_composition_unfused_rotations(N):
@@ -372,21 +373,25 @@ def _random_orthogonal_matrix(size, generator):
q *= torch.sign(torch.diag(r)).unsqueeze(0).float()
return q


# Auxiliary method to convert a dictionary of sources/sinks into a valid region
def _instantiate_region(region_dict, model) -> Region:
if len(region_dict["srcs"]) > 0:
sorted_srcs = dict(sorted({ src: EqualizationIndexes(0, IN_FEATURES, 0) for src in region_dict["srcs"]}.items()))
sorted_sinks = dict(sorted({ sink: EqualizationIndexes(0, IN_FEATURES, 0) for sink in region_dict["sinks"]}.items()))
sorted_srcs = dict(
sorted({src: EqualizationIndexes(0, IN_FEATURES, 0) for src in region_dict["srcs"]
}.items()))
sorted_sinks = dict(
sorted({sink: EqualizationIndexes(0, IN_FEATURES, 0) for sink in region_dict["sinks"]
}.items()))
else:
sorted_srcs = dict()
sorted_sinks = dict(sorted({ sink: EqualizationIndexes(0, IN_FEATURES, 0) for sink in region_dict["sinks"]}.items()))
sorted_sinks = dict(
sorted({sink: EqualizationIndexes(0, IN_FEATURES, 0) for sink in region_dict["sinks"]
}.items()))
sorted_acts = tuple()
return Region(
srcs=sorted_srcs,
sinks=sorted_sinks,
acts=sorted_acts,
name_to_module=model._modules
)
srcs=sorted_srcs, sinks=sorted_sinks, acts=sorted_acts, name_to_module=model._modules)


# Auxiliary function to compare the weights of module instances belonging to classes_to_compare
def compare_model_weights(model_fused, model_unfused, classes_to_compare=(nn.Linear,)):
@@ -395,33 +400,31 @@ def compare_model_weights(model_fused, model_unfused, classes_to_compare=(nn.Lin
if isinstance(module_fused, classes_to_compare):
module_unfused = reduce(getattr, [model_unfused] + name_module_fused.split("."))
for tensor_name in tensor_names:
if hasattr(module_fused, tensor_name) and getattr(module_fused, tensor_name) is not None:
if hasattr(module_fused, tensor_name) and getattr(module_fused,
tensor_name) is not None:
assert torch.allclose(getattr(module_fused, tensor_name), getattr(module_unfused, tensor_name), atol=0.0, rtol=0.0), f"Tensor {tensor_name} does not match for module {name_module_fused}"


@pytest_cases.parametrize(
'mask',
itertools.product([False, True], repeat=3),
ids=lambda mask: "-".join([rot for mask_el, rot in zip(mask, ["R1", "R2", "R3"]) if mask_el])
)
ids=lambda mask: "-".join([rot for mask_el, rot in zip(mask, ["R1", "R2", "R3"]) if mask_el]))
@pytest_cases.parametrize('full_rotation_method', ['ort', 'had'])
@pytest_cases.parametrize('device', ['cpu', 'cuda'] if torch.cuda.is_available() else ['cpu'])
@pytest_cases.parametrize('fuse_rotations', [False, True], ids=["fused", "unfused"])
@pytest_cases.parametrize('use_fx', [True, False], ids=["fx", "no-fx"])
def test_apply_rotate(block_residual_model, mask, full_rotation_method, device, fuse_rotations, use_fx):
def test_apply_rotate(
block_residual_model, mask, full_rotation_method, device, fuse_rotations, use_fx):
# Instantiate a residual model for which a collection of regions is available
model = block_residual_model()
device = torch.device("cuda") if device == 'cuda' else torch.device("cpu")
device = torch.device("cuda") if device == 'cuda' else torch.device("cpu")
model.to(device)
# Sample input to pass through the models
sample_inputs = torch.rand(size=(5, IN_FEATURES)).to(device)
# Collect only a subset of regions to be applied
regions_dicts = [
region_dict
for mask_element, region_dict
in zip(mask, RESIDUAL_MODEL_REGION_DICTS)
if mask_element
]
region_dict for mask_element,
region_dict in zip(mask, RESIDUAL_MODEL_REGION_DICTS) if mask_element]
# Use FX model if requested
if use_fx:
graph_model, _ = torch._dynamo.export(model)(sample_inputs)
@@ -434,41 +437,62 @@ def test_apply_rotate(block_residual_model, mask, full_rotation_method, device,
if hasattr(module, "weight") and graph_module.weight is module.weight:
map_model_graph[name] = graph_module_name
# Replace the names of the modules in sources/sinks by the names of the modules in the FX model
regions_dicts = [
{k: list(map(lambda x: map_model_graph[x], v)) for k, v in region_dict.items()} for region_dict in regions_dicts
]
regions_dicts = [{
k: list(map(lambda x: map_model_graph[x], v))
for k, v in region_dict.items()}
for region_dict in regions_dicts]
# Rotation will be applied on the FX model
model = graph_model

# Deepcopy the models as parameters are going to be modified in-place
rotated_model_unfused = copy.deepcopy(model)
rotated_model_fused = copy.deepcopy(model)

# Generator to control the random orthogonal matrices generated
generator = torch.Generator()
generator.manual_seed(SEED)
# Clone generator to make sure we can use the same rotation matrices
generator_clone = generator.clone_state()

# Apply rotations on the model with unfused rotations
regions_unfused = list(map(lambda x: _instantiate_region(x, rotated_model_unfused), regions_dicts))
regions_unfused = list(
map(lambda x: _instantiate_region(x, rotated_model_unfused), regions_dicts))
if full_rotation_method == 'had':
# _apply_ort_device is patched to ensure that the Hadamard matrices in hadamard.pt are used, instead of
# the random ones generated by random_hadamard_matrices
with patch('brevitas.graph.equalize._apply_ort_device', lambda tensor, had_K, K: _apply_had_device(tensor, get_hadK(had_K.shape[0])[0], get_hadK(had_K.shape[0])[1])):
rewriters = _apply_rotate(rotated_model_unfused, regions_unfused, full_rotation_method=full_rotation_method, fuse_rotations=False)
elif full_rotation_method == 'ort':
with patch('brevitas.graph.equalize.random_orthogonal_matrix', partial(_random_orthogonal_matrix, generator=generator)):
rewriters = _apply_rotate(rotated_model_unfused, regions_unfused, full_rotation_method=full_rotation_method, fuse_rotations=False)
with patch('brevitas.graph.equalize._apply_ort_device',
lambda tensor,
had_K,
K: _apply_had_device(
tensor, get_hadK(had_K.shape[0])[0], get_hadK(had_K.shape[0])[1])):
rewriters = _apply_rotate(
rotated_model_unfused,
regions_unfused,
full_rotation_method=full_rotation_method,
fuse_rotations=False)
elif full_rotation_method == 'ort':
with patch('brevitas.graph.equalize.random_orthogonal_matrix',
partial(_random_orthogonal_matrix, generator=generator)):
rewriters = _apply_rotate(
rotated_model_unfused,
regions_unfused,
full_rotation_method=full_rotation_method,
fuse_rotations=False)
# Register parametrizations after calling _apply_rotate, as these are not immediately registered since they alter the structure of the
# model, thus potentially causing a crash if the model is offloaded
for r in rewriters:
if isinstance(r, ModuleInstanceRegisterParametrization):
rotated_model_unfused = r.apply(rotated_model_unfused)
# Apply rotations on the model with fused rotations
with patch('brevitas.graph.equalize.random_orthogonal_matrix', partial(_random_orthogonal_matrix, generator=generator_clone)):
regions_fused = list(map(lambda x: _instantiate_region(x, rotated_model_fused), regions_dicts))
_apply_rotate(rotated_model_fused, regions_fused, full_rotation_method=full_rotation_method, fuse_rotations=True)
with patch('brevitas.graph.equalize.random_orthogonal_matrix',
partial(_random_orthogonal_matrix, generator=generator_clone)):
regions_fused = list(
map(lambda x: _instantiate_region(x, rotated_model_fused), regions_dicts))
_apply_rotate(
rotated_model_fused,
regions_fused,
full_rotation_method=full_rotation_method,
fuse_rotations=True)

# Compute outputs for each model
model_output = model(sample_inputs)
@@ -477,10 +501,13 @@ def test_apply_rotate(block_residual_model, mask, full_rotation_method, device,

# Verify that the correct number of unique rotation matrices were included. Orphan sinks (len(region_dict["srcs"]) == 0) do not have
# an attached parametrization
assert sum([len(region_dict["srcs"]) > 0 for region_dict in regions_dicts]) == sum(["rot_mat" in name for name, _ in rotated_model_unfused.named_parameters(remove_duplicate=True)])
assert sum([len(region_dict["srcs"]) > 0 for region_dict in regions_dicts]) == sum([
"rot_mat" in name for name,
_ in rotated_model_unfused.named_parameters(remove_duplicate=True)])
# Verify that RotatedModules were added appropriately
for rotated_model in [rotated_model_fused, rotated_model_unfused]:
assert sum([len(region_dict["srcs"]) == 0 for region_dict in regions_dicts]) == sum([isinstance(module, RotatedModule) for module in rotated_model.modules()])
assert sum([len(region_dict["srcs"]) == 0 for region_dict in regions_dicts]) == sum([
isinstance(module, RotatedModule) for module in rotated_model.modules()])
# Optionally fuse the rotations
if fuse_rotations:
rotated_model_unfused = _fuse_rotations(rotated_model_unfused)
Expand All @@ -489,7 +516,8 @@ def test_apply_rotate(block_residual_model, mask, full_rotation_method, device,
assert not parametrize.is_parametrized(module)
# Outputs should match for rotated and unrotated models
assert torch.allclose(model_output, rotated_model_fused_output, atol=ATOL)
assert torch.allclose(rotated_model_unfused_output, rotated_model_fused_output, atol=0.0, rtol=0.0)
assert torch.allclose(
rotated_model_unfused_output, rotated_model_fused_output, atol=0.0, rtol=0.0)
# Verify that the weights have changed with respect to the unrotated module for the modules that have received parametrizations
# Verify that weights match between the fused and unfused model
compare_model_weights(rotated_model_fused, rotated_model_unfused)
compare_model_weights(rotated_model_fused, rotated_model_unfused)
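The fused/unfused comparison above only works because both pipelines draw identical "random" orthogonal matrices: random_orthogonal_matrix is patched to use an explicit torch.Generator, and the generator state is cloned before the first _apply_rotate call so the second run replays the same draws. A minimal sketch of that trick in isolation (arbitrary seed, not the test's SEED):

```python
import torch

gen = torch.Generator()
gen.manual_seed(1234)
# clone_state() snapshots the RNG so a second pipeline can replay the exact
# same draws without touching the global random state
gen_clone = gen.clone_state()

a = torch.randn(4, 4, generator=gen)
b = torch.randn(4, 4, generator=gen_clone)
assert torch.equal(a, b)  # identical draws from independent generators
```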