Enable parametrized rotations
pablomlago committed Jan 9, 2025
1 parent 860b437 commit e536294
Showing 6 changed files with 620 additions and 33 deletions.
104 changes: 104 additions & 0 deletions src/brevitas/graph/base.py
@@ -5,10 +5,14 @@
from abc import abstractmethod
import inspect
from inspect import getcallargs
from typing import Any, Callable, Dict, Optional, Type, Union

import torch
from torch.nn import Module
from torch.nn import Parameter
from torch import Tensor
from torch.overrides import get_testing_overrides
import torch.nn.utils.parametrize as parametrize

from brevitas.fx import GraphModule
from brevitas.fx import immutable_dict
@@ -173,6 +177,106 @@ def apply(self, graph_model: GraphModule) -> GraphModule:
graph_model.graph.lint()
return graph_model

class ModuleInstanceRegisterParametrization(Transform):

def __init__(
self, old_module_instance: Module, tensor_name: str,
parametrization_module: Module) -> None:
self.old_module_instance = old_module_instance
self.tensor_name = tensor_name
self.parametrization_module = parametrization_module

def apply(self, model: GraphModule) -> GraphModule:
for old_module in model.modules():
if old_module is self.old_module_instance:
# register the parametrization in the old_module
parametrize.register_parametrization(
old_module, self.tensor_name, self.parametrization_module, unsafe=True)
break
return model

class RotationWeightParametrization(torch.nn.Module):

def __init__(
self,
rot_mat: torch.nn.Parameter,
rot_func: Callable,
axis: int,
K: Optional[int],
) -> None:
super().__init__()
self.rot_mat = rot_mat
self.rot_func = rot_func
self.axis = axis
self.K = K

def forward(self, weight: torch.Tensor) -> torch.Tensor:

if self.axis == 0:
weight = self.rot_func(weight.t(), self.rot_mat, self.K).t()
elif self.axis == 1:
weight = self.rot_func(weight, self.rot_mat, self.K)
else:
raise RuntimeError("Not supported yet")

return weight
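
A minimal sketch of what this parametrization computes (apply_ort below is a hypothetical stand-in for the _apply_ort_device helper used elsewhere in this commit; RotationWeightParametrization is the class defined above):

import torch

from brevitas.graph.base import RotationWeightParametrization

def apply_ort(weight, rot_mat, K=None):
    # Right-multiply by an orthogonal matrix; K is unused in this stand-in
    return torch.matmul(weight, rot_mat)

weight = torch.randn(4, 4)
rot_mat = torch.linalg.qr(torch.randn(4, 4))[0]  # random orthogonal matrix

# axis=1 applies rot_func to the tensor as-is; axis=0 transposes before and after
rotated = RotationWeightParametrization(rot_mat, apply_ort, axis=1, K=None)(weight)
assert torch.allclose(rotated, weight @ rot_mat)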

class ModuleInstanceFuseRotationWeights(Transform):

def __init__(
self,
old_module_instance: Module,
rot_mat: Union[Parameter, Tensor],
rot_func: Callable,
K: Optional[int],
tensor_name: str,
axis: int,
):
self.old_module_instance = old_module_instance
self.rot_mat = rot_mat
self.rot_func = rot_func
self.K = K
self.tensor_name = tensor_name
self.axis = axis

def apply(self, model: GraphModule) -> GraphModule:
for old_module in model.modules():
if old_module is self.old_module_instance:
if hasattr(old_module, 'allocate_params'):
old_module.allocate_params(old_module)
weight = getattr(old_module, self.tensor_name).data
weight = RotationWeightParametrization(self.rot_mat, self.rot_func, self.axis, self.K)(weight)
# Modify the weights in-place
getattr(old_module, self.tensor_name).data = weight

if hasattr(old_module, 'offload_params'):
old_module.offload_params(old_module)
break
return model

class ModuleInstanceWrapModule(Transform):

def __init__(
self,
old_module_instance: Module,
wrapper_class: Type[Module],
module_attribute: str,
kwargs_wrapper: Dict[str, Any]):
self.old_module_instance = old_module_instance
self.wrapper_class = wrapper_class
self.module_attribute = module_attribute
self.kwargs_wrapper = kwargs_wrapper

def apply(self, model: GraphModule) -> GraphModule:
for old_module in model.modules():
if old_module is self.old_module_instance:
kwargs = {self.module_attribute: self.old_module_instance}
kwargs.update(self.kwargs_wrapper)
new_module_instance = self.wrapper_class(**kwargs)
                # Replace the old module with the wrapped instance
replace_module(model, old_module, new_module_instance)
break
return model
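
A minimal usage sketch for this wrapper transform, mirroring the unit test added later in this commit (TinyModel is a hypothetical stand-in; RotatedModule comes from brevitas.nn.equalized_layer):

import torch.nn as nn

from brevitas.graph.base import ModuleInstanceWrapModule
from brevitas.nn.equalized_layer import RotatedModule

class TinyModel(nn.Module):

    def __init__(self) -> None:
        super().__init__()
        self.linear = nn.Linear(2, 2, bias=False)

    def forward(self, x):
        return self.linear(x)

model = TinyModel()
# Wrap model.linear into a RotatedModule, passing the old module through the "layer" keyword
model = ModuleInstanceWrapModule(
    model.linear, RotatedModule, "layer", {"had_mat": None, "k": None}).apply(model)
assert isinstance(model.linear, RotatedModule)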

class ModuleInstanceToModuleInstance(Transform):

134 changes: 101 additions & 33 deletions src/brevitas/graph/equalize.py
@@ -15,19 +15,24 @@
import torch
from torch.fx import GraphModule as TorchGraphModule
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize

from brevitas import torch_version
from brevitas.fx import GraphModule
from brevitas.fx import Node
from brevitas.graph import ModuleToModuleByClass
from brevitas.graph import ModuleToModuleByInstance
from brevitas.graph.base import GraphTransform
from brevitas.graph.base import InsertModuleCallAfter
from brevitas.graph.base import ModuleInstanceToModuleInstance
from brevitas.graph.base import ModuleInstanceFuseRotationWeights
from brevitas.graph.base import ModuleInstanceRegisterParametrization
from brevitas.graph.base import ModuleInstanceWrapModule
from brevitas.graph.base import RotationWeightParametrization
from brevitas.graph.base import Transform
from brevitas.graph.hadamard import get_hadK
from brevitas.graph.hadamard import matmul_hadU
from brevitas.graph.hadamard import matmul_hadU_cuda
from brevitas.graph.hadamard import random_hadamard_matrix
from brevitas.graph.utils import get_module
from brevitas.graph.utils import get_node
from brevitas.nn.equalized_layer import EqualizedModule
@@ -1299,7 +1304,7 @@ def random_orthogonal_matrix(size):
return q


def _apply_rotate(model: nn.Module, regions: List[Region], full_rotation_method='had'):
def _apply_rotate(model: nn.Module, regions: List[Region], full_rotation_method='had', fuse_rotations: bool = True, apply_inplace_rotations: bool = True):
rewriters = []
for region in regions:
insert_rotation_module = len(region.srcs) == 0
@@ -1311,6 +1316,14 @@ def _apply_rotate(model: nn.Module, regions: List[Region], full_rotation_method=
rot_mat = random_orthogonal_matrix(hidden_dim)
K = None
rot_func = _apply_ort_device
elif not insert_rotation_module and not fuse_rotations:
            # If the model is distributed across GPUs, the device will not be
            # the same for all of the parameters, so explicit moves to the same
            # device as the weights need to be added
device = next(model.parameters()).device
rot_mat = random_hadamard_matrix(hidden_dim, device)
K = None
rot_func = _apply_ort_device
else:
try:
# Build hadamard rotation matrix
@@ -1326,53 +1339,108 @@ def _apply_rotate(model: nn.Module, regions: List[Region], full_rotation_method=
print("Skipping layers")
continue

        # If the rotation is not fused, redefine it as a Parameter to enable its optimization
if not insert_rotation_module and not fuse_rotations:
rot_mat = torch.nn.Parameter(rot_mat)

for name, indexes in region.srcs.items():
module = region.get_module_from_name(name)
if hasattr(module, 'allocate_params'):
module.allocate_params(module)
axis = _get_output_axis(module)
weight = module.weight.data

if axis == 0:
weight = rot_func(weight.t(), rot_mat, K).t()
elif axis == 1:
weight = rot_func(weight, rot_mat, K)
else:
raise RuntimeError("Not supported yet")
module.weight.data = weight
if fuse_rotations:
rewriter = ModuleInstanceFuseRotationWeights(
old_module_instance=module,
rot_mat=rot_mat,
rot_func=rot_func,
K=K,
tensor_name="weight",
axis=axis,
)
rewriters.append(rewriter)

if getattr(module, 'bias', None) is not None:
bias = module.bias.data
bias = rot_func(bias, rot_mat, K)
module.bias.data = bias
if hasattr(module, 'offload_params'):
module.offload_params(module)
if getattr(module, 'bias', None) is not None:
rewriter = ModuleInstanceFuseRotationWeights(
old_module_instance=module,
rot_mat=rot_mat,
rot_func=rot_func,
K=K,
tensor_name="bias",
axis=1,
)
rewriters.append(rewriter)
else:
rewriter = ModuleInstanceRegisterParametrization(
module,
"weight",
RotationWeightParametrization(
rot_mat=rot_mat,
rot_func=rot_func,
axis=axis,
K=K,
))
rewriters.append(rewriter)
if getattr(module, 'bias', None) is not None:
rewriter = ModuleInstanceRegisterParametrization(
module,
"bias",
RotationWeightParametrization(
rot_mat=rot_mat,
rot_func=rot_func,
axis=1,
K=K,
))
rewriters.append(rewriter)

for name, indexes in region.sinks.items():
module = region.get_module_from_name(name)
if hasattr(module, 'allocate_params'):
module.allocate_params(module)
axis = _get_input_axis(module)
weight = module.weight.data

if axis == 1:
_update_weights(module, rot_func(weight, rot_mat, K), 'weight')
elif axis == 0:
_update_weights(module, rot_func(weight.t(), rot_mat, K).t(), 'weight')
if not insert_rotation_module and not fuse_rotations:
rewriter = ModuleInstanceRegisterParametrization(
module,
"weight",
RotationWeightParametrization(
rot_mat=rot_mat,
rot_func=rot_func,
axis=axis,
K=K,
))
rewriters.append(rewriter)
else:
raise RuntimeError("Not supported yet")

if hasattr(module, 'offload_params'):
module.offload_params(module)
rewriter = ModuleInstanceFuseRotationWeights(
old_module_instance=module,
rot_mat=rot_mat,
rot_func=rot_func,
K=K,
tensor_name="weight",
axis=axis,
)
rewriters.append(rewriter)

if insert_rotation_module and len(region.srcs) == 0:
rewriter = ModuleInstanceToModuleInstance(
module, RotatedModule(had_mat=rot_mat, k=K, layer=module))
rewriter = ModuleInstanceWrapModule(
module, RotatedModule, "layer", {
"had_mat": None, "k": K})
rewriters.append(rewriter)
for r in rewriters:
model = r.apply(model)
if apply_inplace_rotations:
for r in rewriters:
# The parametrizations need to be registered after the potential HF hooks have been
# removed, as otherwise the device maps will not match the structure of the
# model's state_dict after the registration of the parametrizations.
if not isinstance(r, ModuleInstanceRegisterParametrization):
model = r.apply(model)
return rewriters

def _fuse_rotations(model: nn.Module) -> nn.Module:
for module in model.modules():
# Names of the tensors that can potentially be parametrized
tensor_names = ["weight", "bias"]
# Remove parametrizations from each tensor
for tensor_name in tensor_names:
if parametrize.is_parametrized(module) and tensor_name in module.parametrizations:
parametrize.remove_parametrizations(module, tensor_name, leave_parametrized=True)
return model
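
A minimal sketch of the intended flow around _fuse_rotations, assuming a plain nn.Linear in place of a traced model and a hypothetical apply_ort rotation function (the real path goes through _apply_rotate with fuse_rotations=False):

import torch
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize

from brevitas.graph.base import RotationWeightParametrization
from brevitas.graph.equalize import _fuse_rotations

def apply_ort(weight, rot_mat, K=None):
    return torch.matmul(weight, rot_mat)

model = nn.Linear(8, 8, bias=False)
rot_mat = nn.Parameter(torch.linalg.qr(torch.randn(8, 8))[0])

parametrize.register_parametrization(
    model, "weight",
    RotationWeightParametrization(rot_mat=rot_mat, rot_func=apply_ort, axis=1, K=None))
assert parametrize.is_parametrized(model)

# rot_mat can now be optimized, since the rotation is re-applied on every access to model.weight

# Fusing bakes the rotated values into model.weight and removes the parametrization
model = _fuse_rotations(model)
assert not parametrize.is_parametrized(model)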


def _replace_bias(next_module, new_bias):
new_bias = new_bias.view(-1)
46 changes: 46 additions & 0 deletions tests/brevitas/graph/equalization_fixtures.py
@@ -528,3 +528,49 @@ def forward(self, x):

rotation_fixtures = fixture_union(
'rotation_fixtures', list_of_rotation_mixtures, ids=list_of_rotation_mixtures)

IN_FEATURES = 12
RESIDUAL_MODEL_REGION_DICTS = [
{
"srcs": ["embedding", "block1_linear2", "block2_linear2"],
"sinks": ["block1_linear1", "block2_linear1", "head"],
},
{
"srcs": ["block1_linear1"],
"sinks": ["block1_linear2"]
},
{
"srcs": [],
"sinks": ["block2_linear2"]
},
]
@pytest_cases.fixture
def block_residual_model():
class BlockResidualModel(nn.Module):

def __init__(self) -> None:
super().__init__()
self.embedding = nn.Linear(IN_FEATURES, IN_FEATURES, bias=False)

self.block1_linear1 = nn.Linear(IN_FEATURES, IN_FEATURES, bias=True)
self.block1_linear2 = nn.Linear(IN_FEATURES, IN_FEATURES, bias=False)

self.block2_linear1 = nn.Linear(IN_FEATURES, IN_FEATURES, bias=False)
self.act = nn.SiLU()
self.block2_linear2 = nn.Linear(IN_FEATURES, IN_FEATURES, bias=True)

self.head = nn.Linear(IN_FEATURES, IN_FEATURES, bias=False)

def forward(self, x):
x = self.embedding(x)
r = x
x = self.block1_linear1(x)
x = self.block1_linear2(x) + r
r = x
x = self.block2_linear1(x)
x = self.act(x)
x = self.block2_linear2(x) + r
x = self.head(x)
return x

return BlockResidualModel
222 changes: 222 additions & 0 deletions tests/brevitas/graph/test_equalization.py
@@ -2,18 +2,35 @@
# SPDX-License-Identifier: BSD-3-Clause

import copy
import itertools
from functools import partial
from unittest.mock import patch

import pytest
import torch
import torch.nn.utils.parametrize as parametrize
from torchvision import models

from brevitas.fx import symbolic_trace
from brevitas.graph.base import ModuleInstanceRegisterParametrization
from brevitas.graph.base import RotationWeightParametrization
from brevitas.graph.equalize import Region
from brevitas.graph.equalize import EqualizationIndexes
from brevitas.graph.equalize import _batch_norm
from brevitas.graph.equalize import _extract_regions
from brevitas.graph.equalize import _is_supported_module
from brevitas.graph.equalize import _supported_layers
from brevitas.graph.equalize import activation_equalization_mode
from brevitas.graph.equalize import GraphRotationEqualization
from brevitas.graph.equalize import MergeLnAffine
from brevitas.graph.equalize import random_orthogonal_matrix
from brevitas.graph.equalize import _apply_rotate
from brevitas.graph.equalize import _apply_had_device
from brevitas.graph.equalize import _fuse_rotations
from brevitas.graph.equalize import _apply_ort_device
from brevitas.graph.equalize import _get_input_axis
from brevitas.graph.equalize import _get_output_axis
from brevitas.nn.equalized_layer import RotatedModule
from brevitas.graph.standardize import DuplicateSharedStatelessModule
from brevitas.graph.standardize import TorchFunctionalToModule
from brevitas.graph.utils import get_module
@@ -276,3 +293,208 @@ def test_models(rotation_fixtures, partial_had):
if partial_had:
last_weight_new = model.linear_2.layer.weight.data
assert not torch.allclose(last_weight, last_weight_new)


@pytest_cases.parametrize('N', [1, 2, 3], ids=lambda x: f"N={x}")
def test_composition_unfused_rotations(N):
torch.manual_seed(SEED)

for rotation_flags in itertools.product([False, True], repeat=N):

in_features = 5
module = nn.Linear(in_features=in_features, out_features=in_features)
rot_module = copy.deepcopy(module)

# Sample input to pass through the block
sample_input = torch.rand((1, in_features),)
# Composite rotation matrices
rot_mat_input = torch.eye(in_features)
rot_mat_output = torch.eye(in_features)

for is_source in rotation_flags:
# Generate a random matrix
rot_mat = random_orthogonal_matrix(in_features).to(dtype=torch.float32)

# Aggregate rotation matrices
if is_source:
rot_mat_output = rot_mat_output @ rot_mat
else:
rot_mat_input = rot_mat_input @ rot_mat

# Compose rotation modules
parametrize.register_parametrization(
rot_module,
"weight",
RotationWeightParametrization(
rot_mat=rot_mat,
rot_func=_apply_ort_device,
axis=_get_output_axis(rot_module) if is_source else _get_input_axis(rot_module),
))
if is_source:
parametrize.register_parametrization(
rot_module,
"bias",
RotationWeightParametrization(
rot_mat=rot_mat,
rot_func=_apply_ort_device,
axis=1,
))

# If the node is a sink, the input is multiplied by the inverse of the rotation matrix x <- xQ^{-1}
# If the node is a source, the output is multiplied by the rotation matrix o <- oQ
gt_output = module(sample_input @ rot_mat_input.t()) @ rot_mat_output
rot_output = rot_module(sample_input)

# Verify that the rotation operations were computed correctly
assert torch.allclose(gt_output, rot_output, atol=ATOL)


# This method is almost the same as brevitas.graph.equalize.random_orthogonal_matrix, except for the
# possibility of passing a generator, which enables controlling the random matrices that are generated.
# Adapted from https://github.com/facebookresearch/SpinQuant/blob/main/eval_utils/rotation_utils.py#L26
# This function needs to be patched in to enable passing the generator and ensuring that the orthogonal
# matrices generated are the same.
def _random_orthogonal_matrix(size, generator):
"""
Generate a random orthogonal matrix of the specified size.
First, we generate a random matrix with entries from a standard distribution.
Then, we use QR decomposition to obtain an orthogonal matrix.
Finally, we multiply by a diagonal matrix with diag r to adjust the signs.
Args:
size (int): The size of the matrix (size x size).
Returns:
torch.Tensor: An orthogonal matrix of the specified size.
"""
torch.cuda.empty_cache()
random_matrix = torch.randn(size, size, dtype=torch.float64, generator=generator)
q, r = torch.linalg.qr(random_matrix)
q *= torch.sign(torch.diag(r)).unsqueeze(0).float()
return q

# Auxiliary method to convert a dictionary of sources/sinks into a valid region
def _instantiate_region(region_dict, model) -> Region:
if len(region_dict["srcs"]) > 0:
sorted_srcs = dict(sorted({ src: EqualizationIndexes(0, IN_FEATURES, 0) for src in region_dict["srcs"]}.items()))
sorted_sinks = dict(sorted({ sink: EqualizationIndexes(0, IN_FEATURES, 0) for sink in region_dict["sinks"]}.items()))
else:
sorted_srcs = dict()
sorted_sinks = dict(sorted({ sink: EqualizationIndexes(0, IN_FEATURES, 0) for sink in region_dict["sinks"]}.items()))
sorted_acts = tuple()
return Region(
srcs=sorted_srcs,
sinks=sorted_sinks,
acts=sorted_acts,
name_to_module=model._modules
)

# Auxiliary function to compare the weights of module instances belonging to classes_to_compare
def _compare_model_weights(model_fused, model_unfused, modules_not_matching=[], classes_to_compare=(nn.Linear,)):
tensor_names = ["weight", "bias"]
for (name_module_fused, module_fused), (_, module_unfused) in zip(model_fused.named_parameters(), model_unfused.named_parameters()):
if isinstance(module_fused, classes_to_compare):
for tensor_name in tensor_names:
if hasattr(module_fused, tensor_name):
if name_module_fused in modules_not_matching:
assert not torch.allclose(getattr(module_fused, tensor_name), getattr(module_unfused, tensor_name), atol=ATOL), f"Tensor {tensor_name} should not match for module {name_module_fused}"
else:
assert torch.allclose(getattr(module_fused, tensor_name), getattr(module_unfused, tensor_name), atol=0.0, rtol=0.0), f"Tensor {tensor_name} does not match for module {name_module_fused}"


@pytest_cases.parametrize(
'mask',
itertools.product([False, True], repeat=3),
ids=lambda mask: "-".join([rot for mask_el, rot in zip(mask, ["R1", "R2", "R3"]) if mask_el])
)
@pytest_cases.parametrize('full_rotation_method', ['ort', 'had'])
@pytest_cases.parametrize('device', ['cpu', 'cuda'] if torch.cuda.is_available() else ['cpu'])
@pytest_cases.parametrize('fuse_rotations', [False, True], ids=["fused", "unfused"])
@pytest_cases.parametrize('use_fx', [True, False], ids=["fx", "no-fx"])
def test_apply_rotate(block_residual_model, mask, full_rotation_method, device, fuse_rotations, use_fx):
# Instantiate a residual model for which a collection of regions is available
model = block_residual_model()
device = torch.device("cuda") if device == 'cuda' else torch.device("cpu")
model.to(device)
# Sample input to pass through the models
sample_inputs = torch.rand(size=(5, IN_FEATURES)).to(device)
# Collect only a subset of regions to be applied
regions_dicts = [
region_dict
for mask_element, region_dict
in zip(mask, RESIDUAL_MODEL_REGION_DICTS)
if mask_element
]
# Use FX model if requested
if use_fx:
graph_model, _ = torch._dynamo.export(model)(sample_inputs)
# The module names in the original model need to be mapped to the ones
# in graph_model
map_model_graph = {}
for graph_module_name, graph_module in graph_model.named_modules():
if hasattr(graph_module, "weight"):
for name, module in model.named_modules():
if hasattr(module, "weight") and graph_module.weight is module.weight:
map_model_graph[name] = graph_module_name
# Replace the names of the modules in sources/sinks by the names of the modules in the FX model
regions_dicts = [
{k: list(map(lambda x: map_model_graph[x], v)) for k, v in region_dict.items()} for region_dict in regions_dicts
]
# Rotation will be applied on the FX model
model = graph_model

# Deepcopy the models as parameters are going to be modified in-place
rotated_model_unfused = copy.deepcopy(model)
rotated_model_fused = copy.deepcopy(model)

# Generator to control the random orthogonal matrices generated
generator = torch.Generator()
generator.manual_seed(SEED)
# Clone generator to make sure we can use the same rotation matrices
generator_clone = generator.clone_state()

# Apply rotations on the model with unfused rotations
regions_unfused = list(map(lambda x: _instantiate_region(x, rotated_model_unfused), regions_dicts))
if full_rotation_method == 'had':
        # _apply_ort_device is patched to ensure that the Hadamard matrices in hadamard.pt are used, instead of
        # the random ones generated by random_hadamard_matrix
with patch('brevitas.graph.equalize._apply_ort_device', _apply_had_device):
rewriters = _apply_rotate(rotated_model_unfused, regions_unfused, full_rotation_method=full_rotation_method, fuse_rotations=False)
elif full_rotation_method == 'ort':
with patch('brevitas.graph.equalize.random_orthogonal_matrix', partial(_random_orthogonal_matrix, generator=generator)):
rewriters = _apply_rotate(rotated_model_unfused, regions_unfused, full_rotation_method=full_rotation_method, fuse_rotations=False)
    # Register parametrizations after calling _apply_rotate, as these are not immediately registered, since they alter
    # the structure of the model and could thus cause a crash if the model is offloaded
for r in rewriters:
if isinstance(r, ModuleInstanceRegisterParametrization):
rotated_model_unfused = r.apply(rotated_model_unfused)
# Apply rotations on the model with fused rotations
with patch('brevitas.graph.equalize.random_orthogonal_matrix', partial(_random_orthogonal_matrix, generator=generator_clone)):
regions_fused = list(map(lambda x: _instantiate_region(x, rotated_model_fused), regions_dicts))
_apply_rotate(rotated_model_fused, regions_fused, full_rotation_method=full_rotation_method, fuse_rotations=True)

# Compute outputs for each model
model_output = model(sample_inputs)
rotated_model_unfused_output = rotated_model_unfused(sample_inputs)
rotated_model_fused_output = rotated_model_fused(sample_inputs)

    # Verify that the correct number of unique rotation matrices were included. Orphan sinks (len(region_dict["srcs"]) == 0)
    # do not have an attached parametrization
assert sum([len(region_dict["srcs"]) > 0 for region_dict in regions_dicts]) == sum(["rot_mat" in name for name, _ in rotated_model_unfused.named_parameters(remove_duplicate=True)])
    # Verify that RotatedModules were added appropriately
for rotated_model in [rotated_model_fused, rotated_model_unfused]:
assert sum([len(region_dict["srcs"]) == 0 for region_dict in regions_dicts]) == sum([isinstance(module, RotatedModule) for module in rotated_model.modules()])
# Optionally fuse the rotations
if fuse_rotations:
rotated_model_unfused = _fuse_rotations(rotated_model_unfused)
# Verify that no parametrizations remain after fusing
for module in rotated_model_unfused.modules():
assert not parametrize.is_parametrized(module)
# Outputs should match for rotated and unrotated models
assert torch.allclose(model_output, rotated_model_fused_output, atol=ATOL)
assert torch.allclose(rotated_model_unfused_output, rotated_model_fused_output, atol=0.0, rtol=0.0)
    # Verify that the weights have changed with respect to the original model for the modules that were rotated
rotated_modules = set([module_name for region_dict in regions_dicts for module_name in (region_dict["srcs"] + region_dict["sinks"])])
for rotated_model in [rotated_model_fused, rotated_model_unfused]:
_compare_model_weights(model, rotated_model, modules_not_matching=rotated_modules)
# Verify that weights match between the fused and unfused model
_compare_model_weights(rotated_model_fused, rotated_model_unfused)

66 changes: 66 additions & 0 deletions tests/brevitas/graph/test_transforms.py
@@ -5,6 +5,7 @@

from packaging import version
import pytest
import pytest_cases
import torch
from torch import nn
from torchvision import models
@@ -17,9 +18,14 @@
from brevitas.graph import MergeBatchNorm
from brevitas.graph import MethodToModule
from brevitas.graph.base import ModuleToModuleByInstance
from brevitas.graph.base import ModuleInstanceRegisterParametrization
from brevitas.graph.base import ModuleInstanceFuseRotationWeights
from brevitas.graph.base import RotationWeightParametrization
from brevitas.graph.base import ModuleInstanceWrapModule
from brevitas.nn import QuantConv1d
from brevitas.nn import QuantConv2d
from brevitas.nn import QuantConv3d
from brevitas.nn.equalized_layer import RotatedModule

SEED = 123456
INPUT_SIZE = (1, 3, 224, 224)
@@ -290,3 +296,63 @@ def forward(self, x):
kwargs = {'stride': lambda module, name: 2 if module.in_channels == 3 else 1}
model = ModuleToModuleByInstance(model.conv, nn.Conv2d, **kwargs).apply(model)
assert model.conv.stride == (2, 2)


def test_module_instance_register_parametrization():

class TestModel(nn.Module):

def __init__(self) -> None:
super().__init__()
self.linear = nn.Linear(2,2, bias=False)

def forward(self, x):
return self.linear(x)

class ZeroParametrization(nn.Module):
def forward(self, x):
return torch.zeros_like(x)

model = TestModel()
model = ModuleInstanceRegisterParametrization(model.linear, "weight", ZeroParametrization()).apply(model)
assert torch.all(model.linear.weight == 0.)


def test_module_instance_wrap_module():
class TestModel(nn.Module):

def __init__(self) -> None:
super().__init__()
self.linear = nn.Linear(2,2, bias=False)

def forward(self, x):
return self.linear(x)

model = TestModel()
model = ModuleInstanceWrapModule(model.linear, RotatedModule, "layer", {"had_mat": None, "k": None}).apply(model)
assert isinstance(model.linear, RotatedModule)


@pytest_cases.parametrize("axis", [0, 1], ids=lambda axis:f"axis={axis}")
def test_fuse_rotation_weights(axis):

def rot_func(weight, ort, K):
return torch.matmul(weight, ort)

class TestModel(nn.Module):

def __init__(self) -> None:
super().__init__()
self.linear = nn.Linear(2,2, bias=False)

def forward(self, x):
return self.linear(x)

rot_mat = torch.rand(2,2)
model_fused = TestModel()
model_unfused = TestModel()
model_unfused.linear.weight.data = model_fused.linear.weight.data

model_fused = ModuleInstanceFuseRotationWeights(model_fused.linear, rot_mat, rot_func, None, "weight", axis).apply(model_fused)
model_unfused = ModuleInstanceRegisterParametrization(model_unfused.linear, "weight", RotationWeightParametrization(rot_mat, rot_func, axis, None)).apply(model_unfused)
assert torch.all(model_fused.linear.weight == model_unfused.linear.weight)
81 changes: 81 additions & 0 deletions tests/brevitas_examples/test_llm.py
@@ -723,3 +723,84 @@ def test_small_models_learned_round_ppl(caplog, learned_round_ppl_args_and_ppl):
quant_ppl = quant_ppl.detach().cpu().numpy()
assert allveryclose(exp_float_ppl, float_ppl), f"Expected float PPL {exp_float_ppl}, measured PPL {float_ppl}"
assert allveryclose(exp_quant_ppl, quant_ppl), f"Expected quant PPL {exp_quant_ppl}, measured PPL {quant_ppl}"

@pytest_cases.fixture(
ids=[
"llama_fused_rotation_ort",
"llama_fused_rotation_ort_no_orphan",
"llama_fused_rotation_had",
"llama_fused_rotation_had_no_orphan",
"llama_layerwise",],
params=[
{
"model": "hf-internal-testing/tiny-random-LlamaForCausalLM",
"act_calibration": False,
"weight_bit_width": 4,
"input_bit_width": None,
"replace_rmsnorm": True,
"rotation": "fused_no_fx",
"rotation_orphan_sink": True,
"rotation_mode": "ort",
"float_ppl": 33238.8984375,
"quant_ppl": 33232.65234375},
{
"model": "hf-internal-testing/tiny-random-LlamaForCausalLM",
"act_calibration": False,
"weight_bit_width": 4,
"input_bit_width": None,
"replace_rmsnorm": True,
"rotation": "fused_no_fx",
"rotation_orphan_sink": False,
"rotation_mode": "ort",
"float_ppl": 33238.8984375,
"quant_ppl": 33420.65234375},
{
"model": "hf-internal-testing/tiny-random-LlamaForCausalLM",
"act_calibration": False,
"weight_bit_width": 4,
"input_bit_width": None,
"replace_rmsnorm": True,
"rotation": "fused_no_fx",
"rotation_orphan_sink": True,
"rotation_mode": "had",
"float_ppl": 33238.8984375,
"quant_ppl": 33290.48046875},
{
"model": "hf-internal-testing/tiny-random-LlamaForCausalLM",
"act_calibration": False,
"weight_bit_width": 4,
"input_bit_width": None,
"replace_rmsnorm": True,
"rotation": "fused_no_fx",
"rotation_orphan_sink": False,
"rotation_mode": "had",
"float_ppl": 33238.8984375,
"quant_ppl": 33204.80859375},
{
"model": "hf-internal-testing/tiny-random-LlamaForCausalLM",
"act_calibration": False,
"weight_bit_width": 4,
"input_bit_width": None,
"replace_rmsnorm": True,
"rotation": "layerwise",
"float_ppl": 33238.8984375,
"quant_ppl": 33446.734375},])
def rotation_ppl_args_and_ppl(default_run_args, request):
args = default_run_args
run_dict = request.param
float_ppl = run_dict["float_ppl"]
quant_ppl = run_dict["quant_ppl"]
del run_dict["float_ppl"]
del run_dict["quant_ppl"]
args.update(**run_dict)
yield args, float_ppl, quant_ppl

@requires_pt_ge('2.4')
def test_small_models_rotation_ppl(caplog, rotation_ppl_args_and_ppl):
caplog.set_level(logging.INFO)
args, exp_float_ppl, exp_quant_ppl = rotation_ppl_args_and_ppl
float_ppl, quant_ppl, model = validate_args_and_run_main(args)
float_ppl = float_ppl.detach().cpu().numpy()
quant_ppl = quant_ppl.detach().cpu().numpy()
assert allveryclose(exp_float_ppl, float_ppl), f"Expected float PPL {exp_float_ppl}, measured PPL {float_ppl}"
assert allveryclose(exp_quant_ppl, quant_ppl), f"Expected quant PPL {exp_quant_ppl}, measured PPL {quant_ppl}"
