Feat (equalize): enable parametrized rotations (#1148)
pablomlago authored Jan 14, 2025
1 parent 0f3877f commit 52cfffd
Showing 9 changed files with 781 additions and 46 deletions.
129 changes: 129 additions & 0 deletions src/brevitas/graph/base.py
@@ -5,16 +5,21 @@
from abc import abstractmethod
import inspect
from inspect import getcallargs
from typing import Any, Callable, Dict, Optional, Type, Union

import torch
from torch import Tensor
from torch.nn import Module
from torch.nn import Parameter
import torch.nn.utils.parametrize as parametrize
from torch.overrides import get_testing_overrides

from brevitas.fx import GraphModule
from brevitas.fx import immutable_dict
from brevitas.fx import Node
from brevitas.graph.utils import *
from brevitas.utils.python_utils import islambda
from brevitas.utils.rotation_utils import RotationWeightParametrization

__all__ = [
'Transform',
@@ -174,6 +179,130 @@ def apply(self, graph_model: GraphModule) -> GraphModule:
return graph_model


class ModuleInstanceRegisterParametrization(Transform):
r"""Transform to register a parametrization to a given parameter of a
module.
Args:
module (nn.Module): module on which to register the
parametrization
tensor_name: (str): name of the :class:`torch.nn.Parameter` of
module which is to be parametrized
transform_module (nn.Module): the parametrization to
register
"""

def __init__(self, module: Module, tensor_name: str, transform_module: Module) -> None:
self.module = module
self.tensor_name = tensor_name
self.transform_module = transform_module

# TODO: Unify interfaces with ModuleInstanceToModuleInstance for
# compatibility with fix_rewriter
@property
def old_module_instance(self):
return self.module

@old_module_instance.setter
def old_module_instance(self, old_module_instance):
self.module = old_module_instance

def apply(self, model: GraphModule) -> GraphModule:
for module in model.modules():
if module is self.module:
# register the parametrization to module
parametrize.register_parametrization(
module, self.tensor_name, self.transform_module)
break
return model
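
For readers unfamiliar with torch parametrizations, a minimal usage sketch of this transform follows. The RotateWeight module, the 4x4 sizes and the toy model are illustrative assumptions, not code from this commit:

import torch
import torch.nn as nn

class RotateWeight(nn.Module):
    """Parametrization that right-multiplies the parametrized tensor by a fixed orthogonal matrix."""

    def __init__(self, rot_mat: torch.Tensor) -> None:
        super().__init__()
        self.rot_mat = rot_mat

    def forward(self, weight: torch.Tensor) -> torch.Tensor:
        return weight @ self.rot_mat

model = nn.Sequential(nn.Linear(4, 4, bias=False))
rot_mat, _ = torch.linalg.qr(torch.randn(4, 4))  # random orthogonal matrix
rewriter = ModuleInstanceRegisterParametrization(model[0], "weight", RotateWeight(rot_mat))
model = rewriter.apply(model)
# model[0].weight is now recomputed on access; the original tensor is stored at
# model[0].parametrizations.weight.original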


class ModuleInstanceTransformTensor(Transform):
r"""Transform to transform in-place a given parameter of a module
Args:
module (nn.Module): parent module of the parameter to be transformed
tensor_name (str): name of the :class:`torch.nn.Parameter` of
module which is to be transformed
transform_module (nn.Module): module defining the transformation to apply
to the tensor
"""

def __init__(
self,
module: Module,
tensor_name: str,
transform_module: Module,
):
self.module = module
self.tensor_name = tensor_name
self.transform_module = transform_module

# TODO: Unify interfaces with ModuleInstanceToModuleInstance for
# compatibility with fix_rewriter
@property
def old_module_instance(self):
return self.module

@old_module_instance.setter
def old_module_instance(self, old_module_instance):
self.module = old_module_instance

def apply(self, model: GraphModule) -> GraphModule:
for module in model.modules():
if module is self.module:
# This check is needed so that the change to the parameters is applied even
# when the model is offloaded
# TODO: Move outside the apply function
if hasattr(module, 'allocate_params'):
module.allocate_params(module)
tensor = getattr(module, self.tensor_name).data
tensor = self.transform_module(tensor)
# Replace the parameter with the transformed tensor
setattr(module, self.tensor_name, torch.nn.Parameter(tensor))

if hasattr(module, 'offload_params'):
module.offload_params(module)
break
return model
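
By contrast, a sketch of the eager variant, reusing RotateWeight and rot_mat from the sketch above (illustrative names, not repo code): the rotation is applied once to the underlying tensor and nothing remains registered on the module.

model = nn.Sequential(nn.Linear(4, 4, bias=False))
rewriter = ModuleInstanceTransformTensor(
    module=model[0], tensor_name="weight", transform_module=RotateWeight(rot_mat))
model = rewriter.apply(model)  # model[0].weight now holds the rotated values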


class ModuleInstanceWrapModule(Transform):
r"""Transform to replace a module by a wrapper module which has the original
one as a submodule
Args:
old_module_instance (nn.Module): module to be wrapped
wrapper_class (type): class of the wrapper for old_module_instance
module_attribute (str): name of the parameter to pass the original
module to the constructor of wrapper_class
kwargs_wrapper (dict, optional): dictionary with the constructor
arguments for wrapper_class
"""

def __init__(
self,
old_module_instance: Module,
wrapper_class: Type[Module],
module_attribute: str,
kwargs_wrapper: Dict[str, Any]):
self.old_module_instance = old_module_instance
self.wrapper_class = wrapper_class
self.module_attribute = module_attribute
self.kwargs_wrapper = kwargs_wrapper

def apply(self, model: GraphModule) -> GraphModule:
for old_module in model.modules():
if old_module is self.old_module_instance:
kwargs = {self.module_attribute: self.old_module_instance}
kwargs.update(self.kwargs_wrapper)
new_module_instance = self.wrapper_class(**kwargs)
# init the new module based on the old one
replace_module(model, old_module, new_module_instance)
break
return model
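
A usage sketch for this wrapper transform, mirroring the call added in _apply_rotate below. RotatedModule and its had_mat/k/layer arguments come from brevitas.nn.equalized_layer; model and rot_mat are the illustrative names from the sketches above:

rewriter = ModuleInstanceWrapModule(
    old_module_instance=model[0],
    wrapper_class=RotatedModule,
    module_attribute="layer",
    kwargs_wrapper={"had_mat": rot_mat, "k": None})
model = rewriter.apply(model)  # model[0] is now a RotatedModule wrapping the original Linear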


class ModuleInstanceToModuleInstance(Transform):

def __init__(self, old_module_instance, new_module_instance):
166 changes: 123 additions & 43 deletions src/brevitas/graph/equalize.py
@@ -15,26 +15,32 @@
import torch
from torch.fx import GraphModule as TorchGraphModule
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize

from brevitas import torch_version
from brevitas.fx import GraphModule
from brevitas.fx import Node
from brevitas.graph import ModuleToModuleByClass
from brevitas.graph import ModuleToModuleByInstance
from brevitas.graph.base import GraphTransform
from brevitas.graph.base import InsertModuleCallAfter
from brevitas.graph.base import ModuleInstanceRegisterParametrization
from brevitas.graph.base import ModuleInstanceToModuleInstance
from brevitas.graph.base import ModuleInstanceTransformTensor
from brevitas.graph.base import ModuleInstanceWrapModule
from brevitas.graph.base import Transform
from brevitas.graph.hadamard import get_hadK
from brevitas.graph.hadamard import matmul_hadU
from brevitas.graph.hadamard import matmul_hadU_cuda
from brevitas.graph.hadamard import random_hadamard_matrix
from brevitas.graph.utils import get_module
from brevitas.graph.utils import get_node
from brevitas.nn.equalized_layer import EqualizedModule
from brevitas.nn.equalized_layer import functional_rotate_input
from brevitas.nn.equalized_layer import INPUT_NAMES
from brevitas.nn.equalized_layer import RotatedModule
from brevitas.nn.quant_scale_bias import ScaleBias
from brevitas.utils.python_utils import recurse_getattr
from brevitas.utils.rotation_utils import RotationWeightParametrization
from brevitas.utils.torch_utils import KwargsForwardHook

# External optional dependency
@@ -1299,8 +1305,19 @@ def random_orthogonal_matrix(size):
return q


def _apply_rotate(model: nn.Module, regions: List[Region], full_rotation_method='had'):
def _apply_rotate(
model: nn.Module,
regions: List[Region],
full_rotation_method='had',
fuse_rotations: bool = True,
apply_inplace_rotations: bool = True):
rewriters = []
# Rotations on orphan sinks are applied first so the order in which rotations are
# applied is consistent, irrespective of the value of fuse_rotations. This is because
# parametrizations need to be registered once all the in-place operations have
# taken place
regions = [region for region in regions if len(region.srcs) == 0] + [
region for region in regions if len(region.srcs) > 0]
for region in regions:
insert_rotation_module = len(region.srcs) == 0

@@ -1311,6 +1328,14 @@ def _apply_rotate(model: nn.Module, regions: List[Region], full_rotation_method=
rot_mat = random_orthogonal_matrix(hidden_dim)
K = None
rot_func = _apply_ort_device
elif not insert_rotation_module and not fuse_rotations:
# If the model is distributed across GPUs, the device will not be
# the same for all of the parameters, so explicit moves to the same
# device as the weights need to be added
device = next(model.parameters()).device
rot_mat = random_hadamard_matrix(hidden_dim, device)
K = None
rot_func = _apply_ort_device
else:
try:
# Build hadamard rotation matrix
@@ -1326,57 +1351,112 @@ def _apply_rotate(model: nn.Module, regions: List[Region], full_rotation_method=
print("Skipping layers")
continue

# Cast rotation matrix to the weight dtype
if rot_mat is not None:
dtype = next(model.parameters()).dtype
rot_mat = rot_mat.to(dtype=dtype)
# If the rotation is not fused, redefine it as a Parameter so that it can be optimized
if not insert_rotation_module and not fuse_rotations:
rot_mat = torch.nn.Parameter(rot_mat)

for name, indexes in region.srcs.items():
module = region.get_module_from_name(name)
if hasattr(module, 'allocate_params'):
module.allocate_params(module)
axis = _get_output_axis(module)
weight = module.weight.data

if axis == 0:
rotated_weight = rot_func(weight.t(), rot_mat, K).t()
_update_weights(module, rotated_weight, 'weight')
elif axis == 1:
rotated_weight = rot_func(weight, rot_mat, K)
_update_weights(module, rotated_weight, 'weight')
else:
raise RuntimeError("Not supported yet")

if getattr(module, 'bias', None) is not None:
bias = module.bias.data
bias = rot_func(bias, rot_mat, K)
module.bias.data = bias
if hasattr(module, 'offload_params'):
module.offload_params(module)
# Rotate "bias" if present
tensor_names_axis = [("weight", _get_output_axis(module))] + ([
("bias", 1)] if getattr(module, 'bias', None) is not None else [])
# If rotations are fused, the transform is applied directly to the tensor
rewriter_class = ModuleInstanceTransformTensor if fuse_rotations else ModuleInstanceRegisterParametrization
# Obtain rewriters for applying the rotations
for tensor_name, axis in tensor_names_axis:
rewriter = rewriter_class(
module=module,
tensor_name=tensor_name,
transform_module=RotationWeightParametrization(
rot_mat=rot_mat,
rot_func=rot_func,
axis=axis,
K=K,
))
rewriters.append(rewriter)

for name, indexes in region.sinks.items():
module = region.get_module_from_name(name)
if hasattr(module, 'allocate_params'):
module.allocate_params(module)
axis = _get_input_axis(module)
weight = module.weight.data

if axis == 1:
rotated_weight = rot_func(weight, rot_mat, K)
_update_weights(module, rotated_weight, 'weight')
elif axis == 0:
rotated_weight = rot_func(weight.t(), rot_mat, K).t()
_update_weights(module, rotated_weight, 'weight')
else:
raise RuntimeError("Not supported yet")

if hasattr(module, 'offload_params'):
module.offload_params(module)

# Only "weight" is rotated
tensor_names_axis = [("weight", _get_input_axis(module))]
# If rotations are fused or the module is an orphan sink, the transform is applied directly to the tensor
rewriter_class = ModuleInstanceTransformTensor if insert_rotation_module or fuse_rotations else ModuleInstanceRegisterParametrization
# Obtain rewriters for applying the rotations
for tensor_name, axis in tensor_names_axis:
rewriter = rewriter_class(
module=module,
tensor_name=tensor_name,
transform_module=RotationWeightParametrization(
rot_mat=rot_mat,
rot_func=rot_func,
axis=axis,
K=K,
))
rewriters.append(rewriter)
# Replace orphan sinks with a RotatedModule wrapper
if insert_rotation_module and len(region.srcs) == 0:
rewriter = ModuleInstanceToModuleInstance(
module, RotatedModule(had_mat=rot_mat, k=K, layer=module))
rewriter = ModuleInstanceWrapModule(
module, RotatedModule, "layer", {
"had_mat": rot_mat, "k": K})
rewriters.append(rewriter)
for r in rewriters:
model = r.apply(model)
if apply_inplace_rotations:
for r in rewriters:
# The parametrizations need to be registered after the potential HF hooks have been
# removed, as otherwise the device maps will not match the structure of the
# model's state_dict after the registration of the parametrizations.
if not isinstance(r, ModuleInstanceRegisterParametrization):
model = r.apply(model)
return rewriters
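
A hedged sketch of the calling pattern this return value enables (caller-side code, assumed rather than shown in this diff): the in-place rewriters are applied inside _apply_rotate, while the parametrization rewriters are returned so they can be registered once any accelerate/HF hooks have been removed.

rewriters = _apply_rotate(model, regions, fuse_rotations=False)
# ... remove accelerate/HF hooks and offloading from the model here ...
for r in rewriters:
    if isinstance(r, ModuleInstanceRegisterParametrization):
        model = r.apply(model)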


# This function is adapted from https://github.com/huggingface/accelerate/blob/main/src/accelerate/utils/modeling.py
def _untie_parameters_with_parametrizations(model: torch.nn.Module):
# get ALL model parameters and their names
all_named_parameters = {
name: param for name, param in model.named_parameters(remove_duplicate=False)}

# get ONLY unique named parameters,
# if a parameter is tied and has multiple names, it will be included only once
no_duplicate_named_parameters = {
name: param for name, param in model.named_parameters(remove_duplicate=True)}

# the difference of the two sets will give us the tied parameters
tied_param_names = set(all_named_parameters.keys()) - set(no_duplicate_named_parameters.keys())

for tied_param_name in tied_param_names:
tied_param_name_split = tied_param_name.split(".")
# The names of the original parameters after registering the parametrization
# have the format "prefix.parametrizations.tensor_name.original", e.g.
# "model.layer.parametrizations.weight.original". This allows to identify
# which subset of tied parameters are original tied parameters of the module
if len(tied_param_name_split) >= 3 and tied_param_name_split[
-3] == "parametrizations" and tied_param_name_split[-1] == "original":
# If that is the case, retrieve the parent module
parent_module = recurse_getattr(model, ".".join(tied_param_name_split[:-1]))
# And set to a new parameter, thus breaking the tie
setattr(parent_module, "original", nn.Parameter(all_named_parameters[tied_param_name]))

return model
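
A toy illustration of the tied-parameter pattern this helper untangles (the modules below are assumptions for illustration, using the nn and parametrize imports above): after registering a parametrization on one of two tied modules, the duplicated name is the one ending in "parametrizations.weight.original".

emb = nn.Embedding(8, 4)
head = nn.Linear(4, 8, bias=False)
head.weight = emb.weight  # tie the two weights
tied = nn.ModuleDict({"emb": emb, "head": head})
parametrize.register_parametrization(head, "weight", nn.Identity())
all_names = {n for n, _ in tied.named_parameters(remove_duplicate=False)}
unique_names = {n for n, _ in tied.named_parameters(remove_duplicate=True)}
print(all_names - unique_names)  # {'head.parametrizations.weight.original'}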


def _fuse_rotations(model: nn.Module) -> nn.Module:
# First of all, parameters that have parametrizations need to be untied
model = _untie_parameters_with_parametrizations(model)
# Then, parametrizations can be safely removed
for module in model.modules():
# Names of the tensors that can potentially be parametrized
tensor_names = ["weight", "bias"]
# Remove parametrizations from each tensor
for tensor_name in tensor_names:
if parametrize.is_parametrized(module) and tensor_name in module.parametrizations:
parametrize.remove_parametrizations(module, tensor_name, leave_parametrized=True)
return model
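
Putting the pieces together, a hedged end-to-end sketch of the workflow this commit enables (the optimization step is schematic, not repo code): register unfused, trainable rotations, optimize the rotation matrices, then fold them back into the weights.

rewriters = _apply_rotate(
    model, regions, fuse_rotations=False, apply_inplace_rotations=False)
for r in rewriters:
    model = r.apply(model)
# ... train the rot_mat Parameters held by the registered RotationWeightParametrization modules ...
model = _fuse_rotations(model)  # bake the learned rotations into the weights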


def _replace_bias(next_module, new_bias):
new_bias = new_bias.view(-1)
if next_module.bias is not None:
3 changes: 2 additions & 1 deletion src/brevitas/graph/hadamard.py
@@ -106,7 +106,8 @@ def random_hadamard_matrix(size, device):
Q = torch.randint(low=0, high=2, size=(size,)).to(torch.float64)
Q = Q * 2 - 1
Q = torch.diag(Q)
return matmul_hadU(Q).to(device)
# Set to float32 for consistency with random_orthogonal_matrix and get_hadK
return matmul_hadU(Q).to(device).float()


def matmul_hadU_cuda(X, hadK, K):
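A quick sanity check for the random_hadamard_matrix change above (illustrative; assumes a power-of-two size so get_hadK succeeds, and reuses the module's torch import): the returned matrix should now be float32 and remain orthogonal.

Q = random_hadamard_matrix(64, torch.device("cpu"))
assert Q.dtype == torch.float32
# Orthogonality is expected to hold since matmul_hadU applies the normalized transform
assert torch.allclose(Q @ Q.t(), torch.eye(64), atol=1e-5)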
