Fix logic registering parametrizations
pablomlago committed Dec 31, 2024
1 parent 53488cc commit 0a209ee
Showing 8 changed files with 229 additions and 76 deletions.
98 changes: 98 additions & 0 deletions src/brevitas/graph/base.py
@@ -6,9 +6,12 @@
from collections import OrderedDict
import inspect
from inspect import getcallargs
from typing import Any, Callable, Dict, Type, Union

import torch
from torch import Tensor
from torch.nn import Module
from torch.nn import Parameter
import torch.nn.utils.parametrize as parametrize
from torch.overrides import get_testing_overrides

@@ -187,6 +190,76 @@ def apply(self, graph_model: GraphModule) -> GraphModule:
return graph_model


class ModuleInstanceRegisterParametrization(Transform):

def __init__(
self, old_module_instance: Module, tensor_name: str,
parametrization_module: Module) -> None:
self.old_module_instance = old_module_instance
self.tensor_name = tensor_name
self.parametrization_module = parametrization_module

def apply(self, model: GraphModule) -> GraphModule:
for old_module in model.modules():
if old_module is self.old_module_instance:
# register the parametrization in the old_module
parametrize.register_parametrization(
old_module, self.tensor_name, self.parametrization_module)
break
return model


class ModuleInstanceFuseRotationWeights(Transform):

def __init__(
self,
old_module_instance: Module,
rot_mat: Union[Parameter, Tensor],
rot_func: Callable,
K: int,
tensor_name: str,
axis: int,
is_source: bool,
):
self.old_module_instance = old_module_instance
self.rot_mat = rot_mat
self.rot_func = rot_func
self.K = K
self.tensor_name = tensor_name
self.axis = axis
self.is_source = is_source

def apply(self, model: GraphModule) -> GraphModule:
for old_module in model.modules():
if old_module is self.old_module_instance:
if hasattr(old_module, 'allocate_params'):
old_module.allocate_params(old_module)
weight = getattr(old_module, self.tensor_name).data

if self.is_source:
if self.axis == 0:
weight = self.rot_func(weight.t(), self.rot_mat, self.K).t()
elif self.axis == 1:
weight = self.rot_func(weight, self.rot_mat, self.K)
else:
raise RuntimeError("Not supported yet")
# If not a source, the module is either a sink or an orphan
else:
if self.axis == 1:
weight = self.rot_func(weight, self.rot_mat, self.K)
elif self.axis == 0:
weight = self.rot_func(weight.t(), self.rot_mat, self.K).t()
else:
raise RuntimeError("Not supported yet")
# Modify the weights in-place
getattr(old_module, self.tensor_name).data = weight

if hasattr(old_module, 'offload_params'):
old_module.offload_params(old_module)
break
return model


class ModuleInstanceToModuleInstance(Transform):

def __init__(self, old_module_instance, new_module_instance):
@@ -202,6 +275,31 @@ def apply(self, model: GraphModule) -> GraphModule:
return model


class ModuleInstanceWrapModule(Transform):

def __init__(
self,
old_module_instance: Module,
wrapper_class: Type[Module],
module_attribute: str,
kwargs_wrapper: Dict[str, Any]):
self.old_module_instance = old_module_instance
self.wrapper_class = wrapper_class
self.module_attribute = module_attribute
self.kwargs_wrapper = kwargs_wrapper

def apply(self, model: GraphModule) -> GraphModule:
for old_module in model.modules():
if old_module is self.old_module_instance:
kwargs = {self.module_attribute: self.old_module_instance}
kwargs.update(self.kwargs_wrapper)
new_module_instance = self.wrapper_class(**kwargs)
# init the new module based on the old one
replace_module(model, old_module, new_module_instance)
break
return model


class ModuleToModuleByName(ModuleToModule):

def __init__(self, old_module_name, new_module_class, **kwargs):
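For context, a minimal usage sketch (not part of the commit) of the new instance-level transforms: the toy Block module and the matmul rot_func are hypothetical stand-ins (the real callers in equalize.py use Hadamard-based rotation functions), and although apply() is typed for GraphModule it only relies on .modules(), so a plain nn.Module is used here.

# Illustrative sketch, not part of the commit.
import torch
import torch.nn as nn

from brevitas.graph.base import ModuleInstanceFuseRotationWeights
from brevitas.graph.base import ModuleInstanceRegisterParametrization
from brevitas.graph.base import ModuleInstanceWrapModule
from brevitas.nn.equalized_layer import RotatedModule
from brevitas.nn.equalized_layer import RotationWeightParametrization


def rot_func(weight, rot_mat, K):
    # Stand-in rotation: plain matmul; the real rotation functions also use
    # the Hadamard block size K, which is ignored here
    return weight @ rot_mat


class Block(nn.Module):  # hypothetical toy model

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 4, bias=False)


model = Block()
rot_mat = nn.Parameter(torch.eye(4))

# Fuse the rotation into the weight tensor in-place (sink, input axis 1)
fuse = ModuleInstanceFuseRotationWeights(
    old_module_instance=model.fc,
    rot_mat=rot_mat,
    rot_func=rot_func,
    K=None,
    tensor_name="weight",
    axis=1,
    is_source=False)
model = fuse.apply(model)

# Alternatively, keep rot_mat trainable by registering it as a parametrization
register = ModuleInstanceRegisterParametrization(
    model.fc,
    "weight",
    RotationWeightParametrization(
        rot_mat=rot_mat, rot_func=rot_func, axis=1, is_sink=True))
model = register.apply(model)

# Wrap the layer in a RotatedModule; the old module is passed as "layer"
wrap = ModuleInstanceWrapModule(
    model.fc, RotatedModule, "layer", {"had_mat": rot_mat, "k": None})
model = wrap.apply(model)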
104 changes: 62 additions & 42 deletions src/brevitas/graph/equalize.py
@@ -21,11 +21,13 @@
from brevitas import torch_version
from brevitas.fx import GraphModule
from brevitas.fx import Node
from brevitas.graph import ModuleToModuleByClass
from brevitas.graph import ModuleToModuleByInstance
from brevitas.graph.base import GraphTransform
from brevitas.graph.base import InsertModuleCallAfter
from brevitas.graph.base import ModuleInstanceFuseRotationWeights
from brevitas.graph.base import ModuleInstanceRegisterParametrization
from brevitas.graph.base import ModuleInstanceToModuleInstance
from brevitas.graph.base import ModuleInstanceWrapModule
from brevitas.graph.base import Transform
from brevitas.graph.hadamard import get_hadK
from brevitas.graph.hadamard import matmul_hadU
@@ -1316,7 +1318,8 @@ def _apply_rotate(
model: nn.Module,
regions: List[Region],
full_rotation_method: str = 'had',
fuse_rotations: bool = True):
fuse_rotations: bool = True,
apply_inplace_rotations: bool = True):
rewriters = []
for region in regions:
insert_rotation_module = len(region.srcs) == 0
@@ -1351,44 +1354,52 @@
continue

# If the rotation is not fused, redefine as a Parameter, to enable its optimization
if not fuse_rotations:
if not insert_rotation_module and not fuse_rotations:
rot_mat = torch.nn.Parameter(rot_mat)

for name, indexes in region.srcs.items():
module = region.get_module_from_name(name)
axis = _get_output_axis(module)

if fuse_rotations:
if hasattr(module, 'allocate_params'):
module.allocate_params(module)
weight = module.weight.data

if axis == 0:
weight = rot_func(weight.t(), rot_mat, K).t()
elif axis == 1:
weight = rot_func(weight, rot_mat, K)
else:
raise RuntimeError("Not supported yet")
module.weight.data = weight
rewriter = ModuleInstanceFuseRotationWeights(
old_module_instance=module,
rot_mat=rot_mat,
rot_func=rot_func,
K=K,
tensor_name="weight",
axis=axis,
is_source=True,
)
rewriters.append(rewriter)

if getattr(module, 'bias', None) is not None:
bias = module.bias.data
bias = rot_func(bias, rot_mat, K)
module.bias.data = bias
if hasattr(module, 'offload_params'):
module.offload_params(module)
rewriter = ModuleInstanceFuseRotationWeights(
old_module_instance=module,
rot_mat=rot_mat,
rot_func=rot_func,
K=K,
tensor_name="bias",
axis=1,
is_source=True,
)
rewriters.append(rewriter)
else:
parametrize.register_parametrization(
rewriter = ModuleInstanceRegisterParametrization(
module,
"weight",
RotationWeightParametrization(
rot_mat=rot_mat,
rot_func=rot_func,
output_axis=axis,
axis=axis,
is_source=True,
))
rewriters.append(rewriter)
if getattr(module, 'bias', None) is not None:
parametrize.register_parametrization(
# TODO: Consolidate RotationBiasParametrization into a single
# class, by setting output_axis = 1. Also, could use a single
# axis, as input_axis and output_axis are not used simultaneously
rewriter = ModuleInstanceRegisterParametrization(
module,
"bias",
RotationBiasParametrization(
@@ -1397,45 +1408,49 @@ def _apply_rotate(
output_axis=axis,
is_source=True,
))
rewriters.append(rewriter)

for name, indexes in region.sinks.items():
module = region.get_module_from_name(name)
axis = _get_input_axis(module)

if not insert_rotation_module and not fuse_rotations:
parametrize.register_parametrization(
rewriter = ModuleInstanceRegisterParametrization(
module,
"weight",
RotationWeightParametrization(
rot_mat=rot_mat,
rot_func=rot_func,
input_axis=axis,
axis=axis,
is_sink=True,
))
rewriters.append(rewriter)
else:
# Verify that there are no parametrizations, as otherwise the underlying weights will not be updated
assert not hasattr(module, "parametrizations"), "Fused rotations need to be incorporated before the parametrized rotations."

if hasattr(module, 'allocate_params'):
module.allocate_params(module)
weight = module.weight.data

if axis == 1:
_update_weights(module, rot_func(weight, rot_mat, K), 'weight')
elif axis == 0:
_update_weights(module, rot_func(weight.t(), rot_mat, K).t(), 'weight')
else:
raise RuntimeError("Not supported yet")

if hasattr(module, 'offload_params'):
module.offload_params(module)
rewriter = ModuleInstanceFuseRotationWeights(
old_module_instance=module,
rot_mat=rot_mat,
rot_func=rot_func,
K=K,
tensor_name="weight",
axis=axis,
is_source=False,
)
rewriters.append(rewriter)

if insert_rotation_module and len(region.srcs) == 0:
rewriter = ModuleInstanceToModuleInstance(
module, RotatedModule(had_mat=rot_mat, k=K, layer=module))
rewriter = ModuleInstanceWrapModule(
module, RotatedModule, "layer", {
"had_mat": rot_mat, "k": K})
rewriters.append(rewriter)
for r in rewriters:
model = r.apply(model)
# The parametrizations need to be registered after the potential HF hooks have been
# removed, as otherwise the device maps will not match the structure of the
# model's state_dict after the registration of the parametrizations.
if apply_inplace_rotations and not isinstance(r, ModuleInstanceRegisterParametrization):
model = r.apply(model)
return rewriters


@@ -1532,7 +1547,8 @@ def apply(
self,
graph_model: GraphModule,
fuse_rotations: bool = True,
additional_regions: Optional[List[Region]] = None
additional_regions: Optional[List[Region]] = None,
apply_inplace_rotations: bool = True,
) -> Union[Tuple[GraphModule, List[Transform]], GraphModule]:
rewriters = []
regions = _extract_regions(
@@ -1566,7 +1582,11 @@
self.rotate_matmuls(graph_model)
if len(regions) > 0:
rewriters = _apply_rotate(
graph_model, regions, self.full_rotation_method, fuse_rotations)
graph_model,
regions,
self.full_rotation_method,
fuse_rotations,
apply_inplace_rotations)
if self.return_rewriters:
return graph_model, rewriters
else:
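A sketch (not part of the commit) of how the new apply_inplace_rotations flag and the returned rewriters might be used to defer parametrization registration until HF/accelerate hooks are removed. The function name, the eq argument (the rotation equalization transform whose apply() is extended above, assumed constructed with return_rewriters=True) and the remove_hooks callable are placeholders.

# Illustrative sketch, not part of the commit.
from typing import Callable

from torch import nn

from brevitas.fx import GraphModule
from brevitas.graph.base import ModuleInstanceRegisterParametrization


def rotate_then_register(
        graph_model: GraphModule,
        eq,  # rotation equalization transform, assumed return_rewriters=True
        remove_hooks: Callable[[nn.Module], None]):  # placeholder hook remover
    # With apply_inplace_rotations=True, fused rotations and module wrapping
    # are applied in-place, while ModuleInstanceRegisterParametrization
    # rewriters are only collected, not applied.
    graph_model, rewriters = eq.apply(
        graph_model, fuse_rotations=False, apply_inplace_rotations=True)
    # Detach HF/accelerate hooks so the device maps still match the model's
    # state_dict before the parametrizations change its structure.
    remove_hooks(graph_model)
    # Now it is safe to register the rotation parametrizations.
    for r in rewriters:
        if isinstance(r, ModuleInstanceRegisterParametrization):
            graph_model = r.apply(graph_model)
    return graph_model, rewriters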
3 changes: 2 additions & 1 deletion src/brevitas/graph/quantize_impl.py
@@ -7,6 +7,7 @@
import torch
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize
from tqdm import tqdm

import brevitas
from brevitas.graph.base import InsertModuleCallAfter
@@ -538,6 +539,6 @@ def layerwise_layer_handler(
quant_module_class, quant_module_kwargs = layer_map[_module_class_name(parametrize.type_before_parametrizations(module))]
rewriter = ModuleToModuleByInstance(module, quant_module_class, **quant_module_kwargs)
rewriters.append(rewriter)
for rewriter in rewriters:
for rewriter in tqdm(rewriters, leave=False):
model = rewriter.apply(model)
return model
22 changes: 6 additions & 16 deletions src/brevitas/nn/equalized_layer.py
@@ -69,14 +69,6 @@ def __init__(self, layer, had_mat=None, k=None) -> None:
self.layer = layer
self.k = k

@property
def weight(self) -> Optional[torch.Tensor]:
return getattr(self.layer, 'weight', None)

@property
def bias(self) -> Optional[torch.Tensor]:
return getattr(self.layer, 'bias', None)

def forward(self, inp, **kwargs):
is_cuda = 'cuda' in str(inp.device) and torch.version.cuda is not None
# If k is None, we assume that an orthogonal matrix is used
@@ -110,35 +102,33 @@ def __init__(
self,
rot_mat: torch.nn.Parameter,
rot_func: Callable,
input_axis: Optional[int] = None,
output_axis: Optional[int] = None,
axis: int,
is_source: bool = False,
is_sink: bool = False,
is_orphan: bool = False,
) -> None:
super().__init__()
self.rot_mat = rot_mat
self.rot_func = rot_func
self.input_axis = input_axis
self.output_axis = output_axis
self.axis = axis
self.is_source = is_source
self.is_sink = is_sink
self.is_orphan = is_orphan
self.K = None

def forward(self, weight: torch.Tensor) -> torch.Tensor:
if self.is_sink or self.is_orphan:
if self.input_axis == 1:
if self.axis == 1:
weight = self.rot_func(weight, self.rot_mat, self.K)
elif self.input_axis == 0:
elif self.axis == 0:
weight = self.rot_func(weight.t(), self.rot_mat, self.K).t()
else:
raise RuntimeError("Not supported yet")

if self.is_source:
if self.output_axis == 0:
if self.axis == 0:
weight = self.rot_func(weight.t(), self.rot_mat, self.K).t()
elif self.output_axis == 1:
elif self.axis == 1:
weight = self.rot_func(weight, self.rot_mat, self.K)
else:
raise RuntimeError("Not supported yet")
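A small sketch (not part of the commit) of the consolidated single-axis parametrization registered on a source layer; the nn.Linear layer and the matmul rot_func are hypothetical stand-ins for the Hadamard-based rotations (the parametrization's K stays None and the stand-in ignores it).

# Illustrative sketch, not part of the commit.
import torch
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize

from brevitas.nn.equalized_layer import RotationWeightParametrization


def rot_func(weight, rot_mat, K):
    # Stand-in rotation; the real rotation functions also use the Hadamard
    # block size K, which this sketch ignores
    return weight @ rot_mat


source = nn.Linear(8, 8, bias=False)
rot_mat = nn.Parameter(torch.eye(8))

# A single axis now covers both roles: for a source it is the output axis
# (0 for nn.Linear), for a sink the input axis, so the former
# input_axis/output_axis pair was redundant.
parametrize.register_parametrization(
    source,
    "weight",
    RotationWeightParametrization(
        rot_mat=rot_mat, rot_func=rot_func, axis=0, is_source=True))

# The rotation is recomputed on each access to source.weight, so rot_mat can
# be optimized without mutating the underlying parameter.
out = source(torch.randn(2, 8))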
3 changes: 1 addition & 2 deletions src/brevitas_examples/llm/llm_quant/rotation_optimization.py
@@ -20,8 +20,7 @@
@dataclass
class ModelArguments:
input_model: Optional[str] = field(
default="hf-internal-testing/tiny-random-LlamaForCausalLM",
metadata={"help": "Input model"})
default="meta-llama/Llama-3.2-1B", metadata={"help": "Input model"})
output_rotation_path: Optional[str] = field(
default="test-output", metadata={"help": "Output rotation checkpoint path"})
optimized_rotation_path: Optional[str] = field(
