From e53629433610372bd2a63b8605a4d4ff14e47ec1 Mon Sep 17 00:00:00 2001
From: Pablo Monteagudo Lago
Date: Thu, 9 Jan 2025 20:05:30 +0000
Subject: [PATCH] Enable parametrized rotations

---
 src/brevitas/graph/base.py                    | 104 ++++++++
 src/brevitas/graph/equalize.py                | 134 ++++++++---
 tests/brevitas/graph/equalization_fixtures.py |  46 ++++
 tests/brevitas/graph/test_equalization.py     | 222 ++++++++++++++++++
 tests/brevitas/graph/test_transforms.py       |  66 ++++++
 tests/brevitas_examples/test_llm.py           |  81 +++++++
 6 files changed, 620 insertions(+), 33 deletions(-)

diff --git a/src/brevitas/graph/base.py b/src/brevitas/graph/base.py
index def3f7070..0052351d0 100644
--- a/src/brevitas/graph/base.py
+++ b/src/brevitas/graph/base.py
@@ -5,10 +5,14 @@
 from abc import abstractmethod
 import inspect
 from inspect import getcallargs
+from typing import Any, Callable, Dict, Type, Union, Optional
 
 import torch
 from torch.nn import Module
+from torch.nn import Parameter
+from torch import Tensor
 from torch.overrides import get_testing_overrides
+import torch.nn.utils.parametrize as parametrize
 
 from brevitas.fx import GraphModule
 from brevitas.fx import immutable_dict
@@ -173,6 +177,106 @@ def apply(self, graph_model: GraphModule) -> GraphModule:
         graph_model.graph.lint()
         return graph_model
 
+class ModuleInstanceRegisterParametrization(Transform):
+
+    def __init__(
+            self, old_module_instance: Module, tensor_name: str,
+            parametrization_module: Module) -> None:
+        self.old_module_instance = old_module_instance
+        self.tensor_name = tensor_name
+        self.parametrization_module = parametrization_module
+
+    def apply(self, model: GraphModule) -> GraphModule:
+        for old_module in model.modules():
+            if old_module is self.old_module_instance:
+                # register the parametrization in the old_module
+                parametrize.register_parametrization(
+                    old_module, self.tensor_name, self.parametrization_module, unsafe=True)
+                break
+        return model
+
+class RotationWeightParametrization(torch.nn.Module):
+
+    def __init__(
+            self,
+            rot_mat: torch.nn.Parameter,
+            rot_func: Callable,
+            axis: int,
+            K: Optional[int],
+    ) -> None:
+        super().__init__()
+        self.rot_mat = rot_mat
+        self.rot_func = rot_func
+        self.axis = axis
+        self.K = K
+
+    def forward(self, weight: torch.Tensor) -> torch.Tensor:
+
+        if self.axis == 0:
+            weight = self.rot_func(weight.t(), self.rot_mat, self.K).t()
+        elif self.axis == 1:
+            weight = self.rot_func(weight, self.rot_mat, self.K)
+        else:
+            raise RuntimeError("Not supported yet")
+
+        return weight
+
+class ModuleInstanceFuseRotationWeights(Transform):
+
+    def __init__(
+            self,
+            old_module_instance: Module,
+            rot_mat: Union[Parameter, Tensor],
+            rot_func: Callable,
+            K: Optional[int],
+            tensor_name: str,
+            axis: int,
+    ):
+        self.old_module_instance = old_module_instance
+        self.rot_mat = rot_mat
+        self.rot_func = rot_func
+        self.K = K
+        self.tensor_name = tensor_name
+        self.axis = axis
+
+    def apply(self, model: GraphModule) -> GraphModule:
+        for old_module in model.modules():
+            if old_module is self.old_module_instance:
+                if hasattr(old_module, 'allocate_params'):
+                    old_module.allocate_params(old_module)
+                weight = getattr(old_module, self.tensor_name).data
+                weight = RotationWeightParametrization(self.rot_mat, self.rot_func, self.axis, self.K)(weight)
+                # Modify the weights in-place
+                getattr(old_module, self.tensor_name).data = weight
+
+                if hasattr(old_module, 'offload_params'):
+                    old_module.offload_params(old_module)
+                break
+        return model
+
+class ModuleInstanceWrapModule(Transform):
+
+    def __init__(
+            self,
+            old_module_instance: Module,
+            wrapper_class: Type[Module],
+            module_attribute: str,
+            kwargs_wrapper: Dict[str, Any]):
+        self.old_module_instance = old_module_instance
+        self.wrapper_class = wrapper_class
+        self.module_attribute = module_attribute
+        self.kwargs_wrapper = kwargs_wrapper
+
+    def apply(self, model: GraphModule) -> GraphModule:
+        for old_module in model.modules():
+            if old_module is self.old_module_instance:
+                kwargs = {self.module_attribute: self.old_module_instance}
+                kwargs.update(self.kwargs_wrapper)
+                new_module_instance = self.wrapper_class(**kwargs)
+                # init the new module based on the old one
+                replace_module(model, old_module, new_module_instance)
+                break
+        return model
 
 class ModuleInstanceToModuleInstance(Transform):
 
diff --git a/src/brevitas/graph/equalize.py b/src/brevitas/graph/equalize.py
index 4e5c1a162..ce1f7041b 100644
--- a/src/brevitas/graph/equalize.py
+++ b/src/brevitas/graph/equalize.py
@@ -15,19 +15,24 @@
 import torch
 from torch.fx import GraphModule as TorchGraphModule
 import torch.nn as nn
+import torch.nn.utils.parametrize as parametrize
 
 from brevitas import torch_version
 from brevitas.fx import GraphModule
 from brevitas.fx import Node
-from brevitas.graph import ModuleToModuleByClass
 from brevitas.graph import ModuleToModuleByInstance
 from brevitas.graph.base import GraphTransform
 from brevitas.graph.base import InsertModuleCallAfter
 from brevitas.graph.base import ModuleInstanceToModuleInstance
+from brevitas.graph.base import ModuleInstanceFuseRotationWeights
+from brevitas.graph.base import ModuleInstanceRegisterParametrization
+from brevitas.graph.base import ModuleInstanceWrapModule
+from brevitas.graph.base import RotationWeightParametrization
 from brevitas.graph.base import Transform
 from brevitas.graph.hadamard import get_hadK
 from brevitas.graph.hadamard import matmul_hadU
 from brevitas.graph.hadamard import matmul_hadU_cuda
+from brevitas.graph.hadamard import random_hadamard_matrix
 from brevitas.graph.utils import get_module
 from brevitas.graph.utils import get_node
 from brevitas.nn.equalized_layer import EqualizedModule
@@ -1299,7 +1304,7 @@ def random_orthogonal_matrix(size):
     return q
 
 
-def _apply_rotate(model: nn.Module, regions: List[Region], full_rotation_method='had'):
+def _apply_rotate(model: nn.Module, regions: List[Region], full_rotation_method='had', fuse_rotations: bool = True, apply_inplace_rotations: bool = True):
     rewriters = []
     for region in regions:
         insert_rotation_module = len(region.srcs) == 0
@@ -1311,6 +1316,14 @@ def _apply_rotate(model: nn.Module, regions: List[Region], full_rotation_method=
             rot_mat = random_orthogonal_matrix(hidden_dim)
             K = None
             rot_func = _apply_ort_device
+        elif not insert_rotation_module and not fuse_rotations:
+            # If the model is distributed across GPUs, the device will
+            # not be the same for all of the parameters, so explicit moves
+            # to the same device as the weights need to be added
+            device = next(model.parameters()).device
+            rot_mat = random_hadamard_matrix(hidden_dim, device)
+            K = None
+            rot_func = _apply_ort_device
         else:
             try:
                 # Build hadamard rotation matrix
@@ -1326,53 +1339,108 @@ def _apply_rotate(model: nn.Module, regions: List[Region], full_rotation_method=
                 print("Skipping layers")
                 continue
 
+        # If the rotation is not fused, redefine as a Parameter, to enable its optimization
+        if not insert_rotation_module and not fuse_rotations:
+            rot_mat = torch.nn.Parameter(rot_mat)
+
         for name, indexes in region.srcs.items():
             module = region.get_module_from_name(name)
-            if hasattr(module, 'allocate_params'):
-                module.allocate_params(module)
             axis = _get_output_axis(module)
-            weight = module.weight.data
-            if axis == 0:
-                weight = rot_func(weight.t(), rot_mat, K).t()
-            elif axis == 1:
-                weight = rot_func(weight, rot_mat, K)
-            else:
-                raise RuntimeError("Not supported yet")
-            module.weight.data = weight
+            if fuse_rotations:
+                rewriter = ModuleInstanceFuseRotationWeights(
+                    old_module_instance=module,
+                    rot_mat=rot_mat,
+                    rot_func=rot_func,
+                    K=K,
+                    tensor_name="weight",
+                    axis=axis,
+                )
+                rewriters.append(rewriter)
 
-            if getattr(module, 'bias', None) is not None:
-                bias = module.bias.data
-                bias = rot_func(bias, rot_mat, K)
-                module.bias.data = bias
-            if hasattr(module, 'offload_params'):
-                module.offload_params(module)
+                if getattr(module, 'bias', None) is not None:
+                    rewriter = ModuleInstanceFuseRotationWeights(
+                        old_module_instance=module,
+                        rot_mat=rot_mat,
+                        rot_func=rot_func,
+                        K=K,
+                        tensor_name="bias",
+                        axis=1,
+                    )
+                    rewriters.append(rewriter)
+            else:
+                rewriter = ModuleInstanceRegisterParametrization(
+                    module,
+                    "weight",
+                    RotationWeightParametrization(
+                        rot_mat=rot_mat,
+                        rot_func=rot_func,
+                        axis=axis,
+                        K=K,
+                    ))
+                rewriters.append(rewriter)
+                if getattr(module, 'bias', None) is not None:
+                    rewriter = ModuleInstanceRegisterParametrization(
+                        module,
+                        "bias",
+                        RotationWeightParametrization(
+                            rot_mat=rot_mat,
+                            rot_func=rot_func,
+                            axis=1,
+                            K=K,
+                        ))
+                    rewriters.append(rewriter)
 
         for name, indexes in region.sinks.items():
             module = region.get_module_from_name(name)
-            if hasattr(module, 'allocate_params'):
-                module.allocate_params(module)
             axis = _get_input_axis(module)
-            weight = module.weight.data
 
-            if axis == 1:
-                _update_weights(module, rot_func(weight, rot_mat, K), 'weight')
-            elif axis == 0:
-                _update_weights(module, rot_func(weight.t(), rot_mat, K).t(), 'weight')
+            if not insert_rotation_module and not fuse_rotations:
+                rewriter = ModuleInstanceRegisterParametrization(
+                    module,
+                    "weight",
+                    RotationWeightParametrization(
+                        rot_mat=rot_mat,
+                        rot_func=rot_func,
+                        axis=axis,
+                        K=K,
+                    ))
+                rewriters.append(rewriter)
             else:
-                raise RuntimeError("Not supported yet")
-
-            if hasattr(module, 'offload_params'):
-                module.offload_params(module)
+                rewriter = ModuleInstanceFuseRotationWeights(
+                    old_module_instance=module,
+                    rot_mat=rot_mat,
+                    rot_func=rot_func,
+                    K=K,
+                    tensor_name="weight",
+                    axis=axis,
+                )
+                rewriters.append(rewriter)
 
         if insert_rotation_module and len(region.srcs) == 0:
-            rewriter = ModuleInstanceToModuleInstance(
-                module, RotatedModule(had_mat=rot_mat, k=K, layer=module))
+            rewriter = ModuleInstanceWrapModule(
+                module, RotatedModule, "layer", {
+                    "had_mat": None, "k": K})
             rewriters.append(rewriter)
-    for r in rewriters:
-        model = r.apply(model)
+    if apply_inplace_rotations:
+        for r in rewriters:
+            # The parametrizations need to be registered after the potential HF hooks have been
+            # removed, as otherwise the device maps will not match the structure of the
+            # model's state_dict after the registration of the parametrizations.
+            if not isinstance(r, ModuleInstanceRegisterParametrization):
+                model = r.apply(model)
     return rewriters
 
+def _fuse_rotations(model: nn.Module) -> nn.Module:
+    for module in model.modules():
+        # Names of the tensors that can potentially be parametrized
+        tensor_names = ["weight", "bias"]
+        # Remove parametrizations from each tensor
+        for tensor_name in tensor_names:
+            if parametrize.is_parametrized(module) and tensor_name in module.parametrizations:
+                parametrize.remove_parametrizations(module, tensor_name, leave_parametrized=True)
+    return model
+
 
 def _replace_bias(next_module, new_bias):
     new_bias = new_bias.view(-1)
diff --git a/tests/brevitas/graph/equalization_fixtures.py b/tests/brevitas/graph/equalization_fixtures.py
index 035cdaadd..a31e3cfef 100644
--- a/tests/brevitas/graph/equalization_fixtures.py
+++ b/tests/brevitas/graph/equalization_fixtures.py
@@ -528,3 +528,49 @@ def forward(self, x):
 
 rotation_fixtures = fixture_union(
     'rotation_fixtures', list_of_rotation_mixtures, ids=list_of_rotation_mixtures)
+
+IN_FEATURES = 12
+RESIDUAL_MODEL_REGION_DICTS = [
+    {
+        "srcs": ["embedding", "block1_linear2", "block2_linear2"],
+        "sinks": ["block1_linear1", "block2_linear1", "head"],
+    },
+    {
+        "srcs": ["block1_linear1"],
+        "sinks": ["block1_linear2"]
+    },
+    {
+        "srcs": [],
+        "sinks": ["block2_linear2"]
+    },
+]
+@pytest_cases.fixture
+def block_residual_model():
+
+    class BlockResidualModel(nn.Module):
+
+        def __init__(self) -> None:
+            super().__init__()
+            self.embedding = nn.Linear(IN_FEATURES, IN_FEATURES, bias=False)
+
+            self.block1_linear1 = nn.Linear(IN_FEATURES, IN_FEATURES, bias=True)
+            self.block1_linear2 = nn.Linear(IN_FEATURES, IN_FEATURES, bias=False)
+
+            self.block2_linear1 = nn.Linear(IN_FEATURES, IN_FEATURES, bias=False)
+            self.act = nn.SiLU()
+            self.block2_linear2 = nn.Linear(IN_FEATURES, IN_FEATURES, bias=True)
+
+            self.head = nn.Linear(IN_FEATURES, IN_FEATURES, bias=False)
+
+        def forward(self, x):
+            x = self.embedding(x)
+            r = x
+            x = self.block1_linear1(x)
+            x = self.block1_linear2(x) + r
+            r = x
+            x = self.block2_linear1(x)
+            x = self.act(x)
+            x = self.block2_linear2(x) + r
+            x = self.head(x)
+            return x
+
+    return BlockResidualModel
diff --git a/tests/brevitas/graph/test_equalization.py b/tests/brevitas/graph/test_equalization.py
index afb8636e4..da8d9b987 100644
--- a/tests/brevitas/graph/test_equalization.py
+++ b/tests/brevitas/graph/test_equalization.py
@@ -2,11 +2,20 @@
 # SPDX-License-Identifier: BSD-3-Clause
 
 import copy
+import itertools
+from functools import partial
+from unittest.mock import patch
 
+import pytest
 import torch
+import torch.nn.utils.parametrize as parametrize
 from torchvision import models
 
 from brevitas.fx import symbolic_trace
+from brevitas.graph.base import ModuleInstanceRegisterParametrization
+from brevitas.graph.base import RotationWeightParametrization
+from brevitas.graph.equalize import Region
+from brevitas.graph.equalize import EqualizationIndexes
 from brevitas.graph.equalize import _batch_norm
 from brevitas.graph.equalize import _extract_regions
 from brevitas.graph.equalize import _is_supported_module
@@ -14,6 +23,14 @@
 from brevitas.graph.equalize import activation_equalization_mode
 from brevitas.graph.equalize import GraphRotationEqualization
 from brevitas.graph.equalize import MergeLnAffine
+from brevitas.graph.equalize import random_orthogonal_matrix
+from brevitas.graph.equalize import _apply_rotate
+from brevitas.graph.equalize import _apply_had_device
+from brevitas.graph.equalize import _fuse_rotations
+from brevitas.graph.equalize import _apply_ort_device
+from brevitas.graph.equalize import _get_input_axis
+from brevitas.graph.equalize import _get_output_axis
+from brevitas.nn.equalized_layer import RotatedModule
 from brevitas.graph.standardize import DuplicateSharedStatelessModule
 from brevitas.graph.standardize import TorchFunctionalToModule
 from brevitas.graph.utils import get_module
@@ -276,3 +293,208 @@ def test_models(rotation_fixtures, partial_had):
     if partial_had:
         last_weight_new = model.linear_2.layer.weight.data
         assert not torch.allclose(last_weight, last_weight_new)
+
+
+@pytest_cases.parametrize('N', [1, 2, 3], ids=lambda x: f"N={x}")
+def test_composition_unfused_rotations(N):
+    torch.manual_seed(SEED)
+
+    for rotation_flags in itertools.product([False, True], repeat=N):
+
+        in_features = 5
+        module = nn.Linear(in_features=in_features, out_features=in_features)
+        rot_module = copy.deepcopy(module)
+
+        # Sample input to pass through the block
+        sample_input = torch.rand((1, in_features),)
+        # Composite rotation matrices
+        rot_mat_input = torch.eye(in_features)
+        rot_mat_output = torch.eye(in_features)
+
+        for is_source in rotation_flags:
+            # Generate a random matrix
+            rot_mat = random_orthogonal_matrix(in_features).to(dtype=torch.float32)
+
+            # Aggregate rotation matrices
+            if is_source:
+                rot_mat_output = rot_mat_output @ rot_mat
+            else:
+                rot_mat_input = rot_mat_input @ rot_mat
+
+            # Compose rotation modules
+            parametrize.register_parametrization(
+                rot_module,
+                "weight",
+                RotationWeightParametrization(
+                    rot_mat=rot_mat,
+                    rot_func=_apply_ort_device,
+                    axis=_get_output_axis(rot_module) if is_source else _get_input_axis(rot_module),
+                ))
+            if is_source:
+                parametrize.register_parametrization(
+                    rot_module,
+                    "bias",
+                    RotationWeightParametrization(
+                        rot_mat=rot_mat,
+                        rot_func=_apply_ort_device,
+                        axis=1,
+                    ))
+
+        # If the node is a sink, the input is multiplied by the inverse of the rotation matrix x <- xQ^{-1}
+        # If the node is a source, the output is multiplied by the rotation matrix o <- oQ
+        gt_output = module(sample_input @ rot_mat_input.t()) @ rot_mat_output
+        rot_output = rot_module(sample_input)
+
+        # Verify that the rotation operations were computed correctly
+        assert torch.allclose(gt_output, rot_output, atol=ATOL)
+
+
+# This method is almost the same as brevitas.graph.equalize.random_orthogonal_matrix, except for the
+# possibility of passing a generator, which enables controlling the random matrices that are generated
+# Adapted from https://github.com/facebookresearch/SpinQuant/blob/main/eval_utils/rotation_utils.py#L26
+# This function needs to be patched to enable passing the generator and ensuring that the orthogonal
+# matrices generated are the same.
+def _random_orthogonal_matrix(size, generator):
+    """
+    Generate a random orthogonal matrix of the specified size.
+    First, we generate a random matrix with entries from a standard distribution.
+    Then, we use QR decomposition to obtain an orthogonal matrix.
+    Finally, we multiply by a diagonal matrix with diag r to adjust the signs.
+    Args:
+    size (int): The size of the matrix (size x size).
+    Returns:
+    torch.Tensor: An orthogonal matrix of the specified size.
+    """
+    torch.cuda.empty_cache()
+    random_matrix = torch.randn(size, size, dtype=torch.float64, generator=generator)
+    q, r = torch.linalg.qr(random_matrix)
+    q *= torch.sign(torch.diag(r)).unsqueeze(0).float()
+    return q
+
+# Auxiliary method to convert a dictionary of sources/sinks into a valid region
+def _instantiate_region(region_dict, model) -> Region:
+    if len(region_dict["srcs"]) > 0:
+        sorted_srcs = dict(sorted({src: EqualizationIndexes(0, IN_FEATURES, 0) for src in region_dict["srcs"]}.items()))
+        sorted_sinks = dict(sorted({sink: EqualizationIndexes(0, IN_FEATURES, 0) for sink in region_dict["sinks"]}.items()))
+    else:
+        sorted_srcs = dict()
+        sorted_sinks = dict(sorted({sink: EqualizationIndexes(0, IN_FEATURES, 0) for sink in region_dict["sinks"]}.items()))
+    sorted_acts = tuple()
+    return Region(
+        srcs=sorted_srcs,
+        sinks=sorted_sinks,
+        acts=sorted_acts,
+        name_to_module=model._modules,
+    )
+
+# Auxiliary function to compare the weights of module instances belonging to classes_to_compare
+def _compare_model_weights(model_fused, model_unfused, modules_not_matching=[], classes_to_compare=(nn.Linear,)):
+    tensor_names = ["weight", "bias"]
+    for (name_module_fused, module_fused), (_, module_unfused) in zip(model_fused.named_parameters(), model_unfused.named_parameters()):
+        if isinstance(module_fused, classes_to_compare):
+            for tensor_name in tensor_names:
+                if hasattr(module_fused, tensor_name):
+                    if name_module_fused in modules_not_matching:
+                        assert not torch.allclose(getattr(module_fused, tensor_name), getattr(module_unfused, tensor_name), atol=ATOL), f"Tensor {tensor_name} should not match for module {name_module_fused}"
+                    else:
+                        assert torch.allclose(getattr(module_fused, tensor_name), getattr(module_unfused, tensor_name), atol=0.0, rtol=0.0), f"Tensor {tensor_name} does not match for module {name_module_fused}"
+
+
+@pytest_cases.parametrize(
+    'mask',
+    itertools.product([False, True], repeat=3),
+    ids=lambda mask: "-".join([rot for mask_el, rot in zip(mask, ["R1", "R2", "R3"]) if mask_el])
+)
+@pytest_cases.parametrize('full_rotation_method', ['ort', 'had'])
+@pytest_cases.parametrize('device', ['cpu', 'cuda'] if torch.cuda.is_available() else ['cpu'])
+@pytest_cases.parametrize('fuse_rotations', [False, True], ids=["fused", "unfused"])
+@pytest_cases.parametrize('use_fx', [True, False], ids=["fx", "no-fx"])
+def test_apply_rotate(block_residual_model, mask, full_rotation_method, device, fuse_rotations, use_fx):
+    # Instantiate a residual model for which a collection of regions is available
+    model = block_residual_model()
+    device = torch.device("cuda") if device == 'cuda' else torch.device("cpu")
+    model.to(device)
+    # Sample input to pass through the models
+    sample_inputs = torch.rand(size=(5, IN_FEATURES)).to(device)
+    # Collect only a subset of regions to be applied
+    regions_dicts = [
+        region_dict
+        for mask_element, region_dict
+        in zip(mask, RESIDUAL_MODEL_REGION_DICTS)
+        if mask_element
+    ]
+    # Use FX model if requested
+    if use_fx:
+        graph_model, _ = torch._dynamo.export(model)(sample_inputs)
+        # The module names in the original model need to be mapped to the ones
+        # in graph_model
+        map_model_graph = {}
+        for graph_module_name, graph_module in graph_model.named_modules():
+            if hasattr(graph_module, "weight"):
+                for name, module in model.named_modules():
+                    if hasattr(module, "weight") and graph_module.weight is module.weight:
+                        map_model_graph[name] = graph_module_name
+        # Replace the names of the modules in sources/sinks by the names of the modules in the FX model
+        regions_dicts = [
+            {k: list(map(lambda x: map_model_graph[x], v)) for k, v in region_dict.items()}
+            for region_dict in regions_dicts
+        ]
+        # Rotation will be applied on the FX model
+        model = graph_model
+
+    # Deepcopy the models as parameters are going to be modified in-place
+    rotated_model_unfused = copy.deepcopy(model)
+    rotated_model_fused = copy.deepcopy(model)
+
+    # Generator to control the random orthogonal matrices generated
+    generator = torch.Generator()
+    generator.manual_seed(SEED)
+    # Clone generator to make sure we can use the same rotation matrices
+    generator_clone = generator.clone_state()
+
+    # Apply rotations on the model with unfused rotations
+    regions_unfused = list(map(lambda x: _instantiate_region(x, rotated_model_unfused), regions_dicts))
+    if full_rotation_method == 'had':
+        # _apply_ort_device is patched to ensure that the hadamard matrices in hadamard.pt are used,
+        # instead of the random ones generated by random_hadamard_matrix
+        with patch('brevitas.graph.equalize._apply_ort_device', _apply_had_device):
+            rewriters = _apply_rotate(rotated_model_unfused, regions_unfused, full_rotation_method=full_rotation_method, fuse_rotations=False)
+    elif full_rotation_method == 'ort':
+        with patch('brevitas.graph.equalize.random_orthogonal_matrix', partial(_random_orthogonal_matrix, generator=generator)):
+            rewriters = _apply_rotate(rotated_model_unfused, regions_unfused, full_rotation_method=full_rotation_method, fuse_rotations=False)
+    # Register parametrizations after calling _apply_rotate, as these are not immediately registered
+    # since they alter the structure of the model, thus potentially causing a crash if the model is offloaded
+    for r in rewriters:
+        if isinstance(r, ModuleInstanceRegisterParametrization):
+            rotated_model_unfused = r.apply(rotated_model_unfused)
+    # Apply rotations on the model with fused rotations
+    with patch('brevitas.graph.equalize.random_orthogonal_matrix', partial(_random_orthogonal_matrix, generator=generator_clone)):
+        regions_fused = list(map(lambda x: _instantiate_region(x, rotated_model_fused), regions_dicts))
+        _apply_rotate(rotated_model_fused, regions_fused, full_rotation_method=full_rotation_method, fuse_rotations=True)
+
+    # Compute outputs for each model
+    model_output = model(sample_inputs)
+    rotated_model_unfused_output = rotated_model_unfused(sample_inputs)
+    rotated_model_fused_output = rotated_model_fused(sample_inputs)
+
+    # Verify that the correct number of unique rotation matrices were included. Orphan sinks
+    # (len(region_dict["srcs"]) == 0) do not have an attached parametrization
+    assert sum([len(region_dict["srcs"]) > 0 for region_dict in regions_dicts]) == sum(["rot_mat" in name for name, _ in rotated_model_unfused.named_parameters(remove_duplicate=True)])
+    # Verify that RotatedModules were added appropriately
+    for rotated_model in [rotated_model_fused, rotated_model_unfused]:
+        assert sum([len(region_dict["srcs"]) == 0 for region_dict in regions_dicts]) == sum([isinstance(module, RotatedModule) for module in rotated_model.modules()])
+    # Optionally fuse the rotations
+    if fuse_rotations:
+        rotated_model_unfused = _fuse_rotations(rotated_model_unfused)
+        # Verify that no parametrizations remain after fusing
+        for module in rotated_model_unfused.modules():
+            assert not parametrize.is_parametrized(module)
+    # Outputs should match for rotated and unrotated models
+    assert torch.allclose(model_output, rotated_model_fused_output, atol=ATOL)
+    assert torch.allclose(rotated_model_unfused_output, rotated_model_fused_output, atol=0.0, rtol=0.0)
+    # Verify that the weights have changed with respect to the unrotated module for the modules that have received parametrizations
+    rotated_modules = set([module_name for region_dict in regions_dicts for module_name in (region_dict["srcs"] + region_dict["sinks"])])
+    for rotated_model in [rotated_model_fused, rotated_model_unfused]:
+        _compare_model_weights(model, rotated_model, modules_not_matching=rotated_modules)
+    # Verify that weights match between the fused and unfused model
+    _compare_model_weights(rotated_model_fused, rotated_model_unfused)
diff --git a/tests/brevitas/graph/test_transforms.py b/tests/brevitas/graph/test_transforms.py
index 875d5a52c..392703b44 100644
--- a/tests/brevitas/graph/test_transforms.py
+++ b/tests/brevitas/graph/test_transforms.py
@@ -5,6 +5,7 @@
 from packaging import version
 import pytest
+import pytest_cases
 import torch
 from torch import nn
 from torchvision import models
@@ -17,9 +18,14 @@
 from brevitas.graph import MergeBatchNorm
 from brevitas.graph import MethodToModule
 from brevitas.graph.base import ModuleToModuleByInstance
+from brevitas.graph.base import ModuleInstanceRegisterParametrization
+from brevitas.graph.base import ModuleInstanceFuseRotationWeights
+from brevitas.graph.base import RotationWeightParametrization
+from brevitas.graph.base import ModuleInstanceWrapModule
 from brevitas.nn import QuantConv1d
 from brevitas.nn import QuantConv2d
 from brevitas.nn import QuantConv3d
+from brevitas.nn.equalized_layer import RotatedModule
 
 SEED = 123456
 INPUT_SIZE = (1, 3, 224, 224)
@@ -290,3 +296,63 @@ def forward(self, x):
     kwargs = {'stride': lambda module, name: 2 if module.in_channels == 3 else 1}
     model = ModuleToModuleByInstance(model.conv, nn.Conv2d, **kwargs).apply(model)
     assert model.conv.stride == (2, 2)
+
+
+def test_module_instance_register_parametrization():
+
+    class TestModel(nn.Module):
+
+        def __init__(self) -> None:
+            super().__init__()
+            self.linear = nn.Linear(2, 2, bias=False)
+
+        def forward(self, x):
+            return self.linear(x)
+
+    class ZeroParametrization(nn.Module):
+
+        def forward(self, x):
+            return torch.zeros_like(x)
+
+    model = TestModel()
+    model = ModuleInstanceRegisterParametrization(model.linear, "weight", ZeroParametrization()).apply(model)
+    assert torch.all(model.linear.weight == 0.)
+
+
+def test_module_instance_wrap_module():
+
+    class TestModel(nn.Module):
+
+        def __init__(self) -> None:
+            super().__init__()
+            self.linear = nn.Linear(2, 2, bias=False)
+
+        def forward(self, x):
+            return self.linear(x)
+
+    model = TestModel()
+    model = ModuleInstanceWrapModule(model.linear, RotatedModule, "layer", {"had_mat": None, "k": None}).apply(model)
+    assert isinstance(model.linear, RotatedModule)
+
+
+@pytest_cases.parametrize("axis", [0, 1], ids=lambda axis: f"axis={axis}")
+def test_fuse_rotation_weights(axis):
+
+    def rot_func(weight, ort, K):
+        return torch.matmul(weight, ort)
+
+    class TestModel(nn.Module):
+
+        def __init__(self) -> None:
+            super().__init__()
+            self.linear = nn.Linear(2, 2, bias=False)
+
+        def forward(self, x):
+            return self.linear(x)
+
+    rot_mat = torch.rand(2, 2)
+    model_fused = TestModel()
+    model_unfused = TestModel()
+    model_unfused.linear.weight.data = model_fused.linear.weight.data
+
+    model_fused = ModuleInstanceFuseRotationWeights(model_fused.linear, rot_mat, rot_func, None, "weight", axis).apply(model_fused)
+    model_unfused = ModuleInstanceRegisterParametrization(model_unfused.linear, "weight", RotationWeightParametrization(rot_mat, rot_func, axis, None)).apply(model_unfused)
+    assert torch.all(model_fused.linear.weight == model_unfused.linear.weight)
diff --git a/tests/brevitas_examples/test_llm.py b/tests/brevitas_examples/test_llm.py
index c02a3e320..0817c3bec 100644
--- a/tests/brevitas_examples/test_llm.py
+++ b/tests/brevitas_examples/test_llm.py
@@ -723,3 +723,84 @@ def test_small_models_learned_round_ppl(caplog, learned_round_ppl_args_and_ppl):
     quant_ppl = quant_ppl.detach().cpu().numpy()
     assert allveryclose(exp_float_ppl, float_ppl), f"Expected float PPL {exp_float_ppl}, measured PPL {float_ppl}"
     assert allveryclose(exp_quant_ppl, quant_ppl), f"Expected quant PPL {exp_quant_ppl}, measured PPL {quant_ppl}"
+
+@pytest_cases.fixture(
+    ids=[
+        "llama_fused_rotation_ort",
+        "llama_fused_rotation_ort_no_orphan",
+        "llama_fused_rotation_had",
+        "llama_fused_rotation_had_no_orphan",
+        "llama_layerwise",],
+    params=[
+        {
+            "model": "hf-internal-testing/tiny-random-LlamaForCausalLM",
+            "act_calibration": False,
+            "weight_bit_width": 4,
+            "input_bit_width": None,
+            "replace_rmsnorm": True,
+            "rotation": "fused_no_fx",
+            "rotation_orphan_sink": True,
+            "rotation_mode": "ort",
+            "float_ppl": 33238.8984375,
+            "quant_ppl": 33232.65234375},
+        {
+            "model": "hf-internal-testing/tiny-random-LlamaForCausalLM",
+            "act_calibration": False,
+            "weight_bit_width": 4,
+            "input_bit_width": None,
+            "replace_rmsnorm": True,
+            "rotation": "fused_no_fx",
+            "rotation_orphan_sink": False,
+            "rotation_mode": "ort",
+            "float_ppl": 33238.8984375,
+            "quant_ppl": 33420.65234375},
+        {
+            "model": "hf-internal-testing/tiny-random-LlamaForCausalLM",
+            "act_calibration": False,
+            "weight_bit_width": 4,
+            "input_bit_width": None,
+            "replace_rmsnorm": True,
+            "rotation": "fused_no_fx",
+            "rotation_orphan_sink": True,
+            "rotation_mode": "had",
+            "float_ppl": 33238.8984375,
+            "quant_ppl": 33290.48046875},
+        {
+            "model": "hf-internal-testing/tiny-random-LlamaForCausalLM",
+            "act_calibration": False,
+            "weight_bit_width": 4,
+            "input_bit_width": None,
+            "replace_rmsnorm": True,
+            "rotation": "fused_no_fx",
+            "rotation_orphan_sink": False,
+            "rotation_mode": "had",
+            "float_ppl": 33238.8984375,
+            "quant_ppl": 33204.80859375},
+        {
+            "model": "hf-internal-testing/tiny-random-LlamaForCausalLM",
+            "act_calibration": False,
+            "weight_bit_width": 4,
+            "input_bit_width": None,
+            "replace_rmsnorm": True,
+            "rotation": "layerwise",
+            "float_ppl": 33238.8984375,
+            "quant_ppl": 33446.734375},])
+def rotation_ppl_args_and_ppl(default_run_args, request):
+    args = default_run_args
+    run_dict = request.param
+    float_ppl = run_dict["float_ppl"]
+    quant_ppl = run_dict["quant_ppl"]
+    del run_dict["float_ppl"]
+    del run_dict["quant_ppl"]
+    args.update(**run_dict)
+    yield args, float_ppl, quant_ppl
+
+@requires_pt_ge('2.4')
+def test_small_models_rotation_ppl(caplog, rotation_ppl_args_and_ppl):
+    caplog.set_level(logging.INFO)
+    args, exp_float_ppl, exp_quant_ppl = rotation_ppl_args_and_ppl
+    float_ppl, quant_ppl, model = validate_args_and_run_main(args)
+    float_ppl = float_ppl.detach().cpu().numpy()
+    quant_ppl = quant_ppl.detach().cpu().numpy()
+    assert allveryclose(exp_float_ppl, float_ppl), f"Expected float PPL {exp_float_ppl}, measured PPL {float_ppl}"
+    assert allveryclose(exp_quant_ppl, quant_ppl), f"Expected quant PPL {exp_quant_ppl}, measured PPL {quant_ppl}"
\ No newline at end of file
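
Note for reviewers (not part of the patch): a minimal usage sketch of how the two new transforms in src/brevitas/graph/base.py relate, assuming only the classes introduced above. The toy rot_func and the 4x4 layer are illustrative stand-ins for brevitas.graph.equalize._apply_ort_device and a real model layer, mirroring test_fuse_rotation_weights; the fusion step at the end is what _fuse_rotations in equalize.py performs model-wide.

    import torch
    import torch.nn as nn
    import torch.nn.utils.parametrize as parametrize

    from brevitas.graph.base import ModuleInstanceFuseRotationWeights
    from brevitas.graph.base import ModuleInstanceRegisterParametrization
    from brevitas.graph.base import RotationWeightParametrization


    def rot_func(weight, ort, K):
        # Toy stand-in for _apply_ort_device: plain right-multiplication by the rotation
        return torch.matmul(weight, ort)


    torch.manual_seed(0)
    layer_fused = nn.Linear(4, 4, bias=False)
    layer_unfused = nn.Linear(4, 4, bias=False)
    layer_unfused.weight.data = layer_fused.weight.data.clone()
    rot_mat, _ = torch.linalg.qr(torch.randn(4, 4))  # random orthogonal matrix

    # Fused path: the rotation is folded into the weight tensor in place
    ModuleInstanceFuseRotationWeights(
        layer_fused, rot_mat, rot_func, None, "weight", axis=1).apply(layer_fused)

    # Unfused path: the rotation stays as a parametrization, so rot_mat can still be optimized
    ModuleInstanceRegisterParametrization(
        layer_unfused, "weight",
        RotationWeightParametrization(rot_mat, rot_func, axis=1, K=None)).apply(layer_unfused)

    # Both paths expose the same effective weight
    assert torch.allclose(layer_fused.weight, layer_unfused.weight)

    # Fusing afterwards removes the parametrization but keeps the rotated weight
    parametrize.remove_parametrizations(layer_unfused, "weight", leave_parametrized=True)
    assert not parametrize.is_parametrized(layer_unfused)
    assert torch.allclose(layer_fused.weight, layer_unfused.weight)

This is the same pairing _apply_rotate relies on: with fuse_rotations=True it emits ModuleInstanceFuseRotationWeights rewriters, while with fuse_rotations=False it emits ModuleInstanceRegisterParametrization rewriters whose rot_mat is a trainable torch.nn.Parameter.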