From 0a209ee4363643717cca5d74620761fbfe0ba019 Mon Sep 17 00:00:00 2001
From: Pablo Monteagudo Lago
Date: Tue, 31 Dec 2024 11:04:44 +0000
Subject: [PATCH] Fix logic for registering parametrizations

---
 src/brevitas/graph/base.py                     |  98 +++++++++++++++++
 src/brevitas/graph/equalize.py                 | 104 +++++++++++-------
 src/brevitas/graph/quantize_impl.py            |   3 +-
 src/brevitas/nn/equalized_layer.py             |  22 +---
 .../llm/llm_quant/rotation_optimization.py     |   3 +-
 .../llm/llm_quant/rotation_utils.py            |   4 +-
 src/brevitas_examples/llm/main.py              |  68 ++++++++++--
 tests/brevitas/graph/test_equalization.py      |   3 +-
 8 files changed, 229 insertions(+), 76 deletions(-)

diff --git a/src/brevitas/graph/base.py b/src/brevitas/graph/base.py
index dae5160ce..1546ecb67 100644
--- a/src/brevitas/graph/base.py
+++ b/src/brevitas/graph/base.py
@@ -6,9 +6,12 @@
 from collections import OrderedDict
 import inspect
 from inspect import getcallargs
+from typing import Any, Callable, Dict, Type, Union

 import torch
+from torch import Tensor
 from torch.nn import Module
+from torch.nn import Parameter
 import torch.nn.utils.parametrize as parametrize
 from torch.overrides import get_testing_overrides
@@ -187,6 +190,76 @@ def apply(self, graph_model: GraphModule) -> GraphModule:
         return graph_model


+class ModuleInstanceRegisterParametrization(Transform):
+
+    def __init__(
+            self, old_module_instance: Module, tensor_name: str,
+            parametrization_module: Module) -> None:
+        self.old_module_instance = old_module_instance
+        self.tensor_name = tensor_name
+        self.parametrization_module = parametrization_module
+
+    def apply(self, model: GraphModule) -> GraphModule:
+        for old_module in model.modules():
+            if old_module is self.old_module_instance:
+                # register the parametrization in the old_module
+                parametrize.register_parametrization(
+                    old_module, self.tensor_name, self.parametrization_module)
+                break
+        return model
+
+
+class ModuleInstanceFuseRotationWeights(Transform):
+
+    def __init__(
+            self,
+            old_module_instance: Module,
+            rot_mat: Union[Parameter, Tensor],
+            rot_func: Callable,
+            K: int,
+            tensor_name: str,
+            axis: int,
+            is_source: bool,
+    ):
+        self.old_module_instance = old_module_instance
+        self.rot_mat = rot_mat
+        self.rot_func = rot_func
+        self.K = K
+        self.tensor_name = tensor_name
+        self.axis = axis
+        self.is_source = is_source
+
+    def apply(self, model: GraphModule) -> GraphModule:
+        for old_module in model.modules():
+            if old_module is self.old_module_instance:
+                if hasattr(old_module, 'allocate_params'):
+                    old_module.allocate_params(old_module)
+                weight = getattr(old_module, self.tensor_name).data
+
+                if self.is_source:
+                    if self.axis == 0:
+                        weight = self.rot_func(weight.t(), self.rot_mat, self.K).t()
+                    elif self.axis == 1:
+                        weight = self.rot_func(weight, self.rot_mat, self.K)
+                    else:
+                        raise RuntimeError("Not supported yet")
+                # If not a source, the module is either a sink or an orphan
+                else:
+                    if self.axis == 1:
+                        weight = self.rot_func(weight, self.rot_mat, self.K)
+                    elif self.axis == 0:
+                        weight = self.rot_func(weight.t(), self.rot_mat, self.K).t()
+                    else:
+                        raise RuntimeError("Not supported yet")
+                # Modify the weights in-place
+                getattr(old_module, self.tensor_name).data = weight
+
+                if hasattr(old_module, 'offload_params'):
+                    old_module.offload_params(old_module)
+                break
+        return model
+
+
 class ModuleInstanceToModuleInstance(Transform):

     def __init__(self, old_module_instance, new_module_instance):
@@ -202,6 +275,31 @@ def apply(self, model: GraphModule) -> GraphModule:
         return model


+class ModuleInstanceWrapModule(Transform):
+
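+    """Wrap ``old_module_instance`` in ``wrapper_class``: the wrapped module is
+    passed to the wrapper's constructor as the keyword argument named by
+    ``module_attribute``, together with ``kwargs_wrapper``, and the resulting
+    wrapper instance replaces the original module in the model."""
+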
+    def __init__(
+            self,
+            old_module_instance: Module,
+            wrapper_class: Type[Module],
+            module_attribute: str,
+            kwargs_wrapper: Dict[str, Any]):
+        self.old_module_instance = old_module_instance
+        self.wrapper_class = wrapper_class
+        self.module_attribute = module_attribute
+        self.kwargs_wrapper = kwargs_wrapper
+
+    def apply(self, model: GraphModule) -> GraphModule:
+        for old_module in model.modules():
+            if old_module is self.old_module_instance:
+                kwargs = {self.module_attribute: self.old_module_instance}
+                kwargs.update(self.kwargs_wrapper)
+                new_module_instance = self.wrapper_class(**kwargs)
+                # init the new module based on the old one
+                replace_module(model, old_module, new_module_instance)
+                break
+        return model
+
+
 class ModuleToModuleByName(ModuleToModule):

     def __init__(self, old_module_name, new_module_class, **kwargs):
diff --git a/src/brevitas/graph/equalize.py b/src/brevitas/graph/equalize.py
index faae03a62..0da49233c 100644
--- a/src/brevitas/graph/equalize.py
+++ b/src/brevitas/graph/equalize.py
@@ -21,11 +21,13 @@
 from brevitas import torch_version
 from brevitas.fx import GraphModule
 from brevitas.fx import Node
-from brevitas.graph import ModuleToModuleByClass
 from brevitas.graph import ModuleToModuleByInstance
 from brevitas.graph.base import GraphTransform
 from brevitas.graph.base import InsertModuleCallAfter
+from brevitas.graph.base import ModuleInstanceFuseRotationWeights
+from brevitas.graph.base import ModuleInstanceRegisterParametrization
 from brevitas.graph.base import ModuleInstanceToModuleInstance
+from brevitas.graph.base import ModuleInstanceWrapModule
 from brevitas.graph.base import Transform
 from brevitas.graph.hadamard import get_hadK
 from brevitas.graph.hadamard import matmul_hadU
@@ -1316,7 +1318,8 @@ def _apply_rotate(
         model: nn.Module,
         regions: List[Region],
         full_rotation_method: str = 'had',
-        fuse_rotations: bool = True):
+        fuse_rotations: bool = True,
+        apply_inplace_rotations: bool = True):
     rewriters = []
     for region in regions:
         insert_rotation_module = len(region.srcs) == 0
@@ -1351,7 +1354,7 @@ def _apply_rotate(
                     continue

             # If the rotation is not fused, redefine as a Parameter, to enable its optimization
-            if not fuse_rotations:
+            if not insert_rotation_module and not fuse_rotations:
                 rot_mat = torch.nn.Parameter(rot_mat)

         for name, indexes in region.srcs.items():
@@ -1359,36 +1362,44 @@ def _apply_rotate(
             axis = _get_output_axis(module)

             if fuse_rotations:
-                if hasattr(module, 'allocate_params'):
-                    module.allocate_params(module)
-                weight = module.weight.data
-
-                if axis == 0:
-                    weight = rot_func(weight.t(), rot_mat, K).t()
-                elif axis == 1:
-                    weight = rot_func(weight, rot_mat, K)
-                else:
-                    raise RuntimeError("Not supported yet")
-                module.weight.data = weight
+                rewriter = ModuleInstanceFuseRotationWeights(
+                    old_module_instance=module,
+                    rot_mat=rot_mat,
+                    rot_func=rot_func,
+                    K=K,
+                    tensor_name="weight",
+                    axis=axis,
+                    is_source=True,
+                )
+                rewriters.append(rewriter)

                 if getattr(module, 'bias', None) is not None:
-                    bias = module.bias.data
-                    bias = rot_func(bias, rot_mat, K)
-                    module.bias.data = bias
-                if hasattr(module, 'offload_params'):
-                    module.offload_params(module)
+                    rewriter = ModuleInstanceFuseRotationWeights(
+                        old_module_instance=module,
+                        rot_mat=rot_mat,
+                        rot_func=rot_func,
+                        K=K,
+                        tensor_name="bias",
+                        axis=1,
+                        is_source=True,
+                    )
+                    rewriters.append(rewriter)
             else:
-                parametrize.register_parametrization(
+                rewriter = ModuleInstanceRegisterParametrization(
                     module,
                     "weight",
                     RotationWeightParametrization(
                         rot_mat=rot_mat,
                         rot_func=rot_func,
-                        output_axis=axis,
+                        axis=axis,
                         is_source=True,
                     ))
+                rewriters.append(rewriter)

                 if getattr(module, 'bias', None) is not None:
-                    parametrize.register_parametrization(
+                    # TODO: Consolidate RotationBiasParametrization into a single
+                    # class, by setting output_axis = 1. Also, could use a single
+                    # axis, as input_axis and output_axis are not used simultaneously
+                    rewriter = ModuleInstanceRegisterParametrization(
                         module,
                         "bias",
                         RotationBiasParametrization(
@@ -1397,45 +1408,49 @@ def _apply_rotate(
                             output_axis=axis,
                             is_source=True,
                         ))
+                    rewriters.append(rewriter)

         for name, indexes in region.sinks.items():
             module = region.get_module_from_name(name)
             axis = _get_input_axis(module)

             if not insert_rotation_module and not fuse_rotations:
-                parametrize.register_parametrization(
+                rewriter = ModuleInstanceRegisterParametrization(
                     module,
                     "weight",
                     RotationWeightParametrization(
                         rot_mat=rot_mat,
                         rot_func=rot_func,
-                        input_axis=axis,
+                        axis=axis,
                         is_sink=True,
                     ))
+                rewriters.append(rewriter)
             else:
                 # Verify that there are no parametrizations, as otherwise the underlying weights will not be updated
                 assert not hasattr(module, "parametrizations"), "Fused rotations need to be incorporated before the parametrized rotations."
-                if hasattr(module, 'allocate_params'):
-                    module.allocate_params(module)
-                weight = module.weight.data
-
-                if axis == 1:
-                    _update_weights(module, rot_func(weight, rot_mat, K), 'weight')
-                elif axis == 0:
-                    _update_weights(module, rot_func(weight.t(), rot_mat, K).t(), 'weight')
-                else:
-                    raise RuntimeError("Not supported yet")
-
-                if hasattr(module, 'offload_params'):
-                    module.offload_params(module)
+                rewriter = ModuleInstanceFuseRotationWeights(
+                    old_module_instance=module,
+                    rot_mat=rot_mat,
+                    rot_func=rot_func,
+                    K=K,
+                    tensor_name="weight",
+                    axis=axis,
+                    is_source=False,
+                )
+                rewriters.append(rewriter)

             if insert_rotation_module and len(region.srcs) == 0:
-                rewriter = ModuleInstanceToModuleInstance(
-                    module, RotatedModule(had_mat=rot_mat, k=K, layer=module))
+                rewriter = ModuleInstanceWrapModule(
+                    module, RotatedModule, "layer", {
+                        "had_mat": rot_mat, "k": K})
                 rewriters.append(rewriter)
     for r in rewriters:
-        model = r.apply(model)
+        # The parametrizations need to be registered after the potential HF hooks have been
+        # removed, as otherwise the device maps will not match the structure of the
+        # model's state_dict after the registration of the parametrizations.
+        if apply_inplace_rotations and not isinstance(r, ModuleInstanceRegisterParametrization):
+            model = r.apply(model)
     return rewriters
@@ -1532,7 +1547,8 @@ def apply(
             self,
             graph_model: GraphModule,
             fuse_rotations: bool = True,
-            additional_regions: Optional[List[Region]] = None
+            additional_regions: Optional[List[Region]] = None,
+            apply_inplace_rotations: bool = True,
     ) -> Union[Tuple[GraphModule, List[Transform]], GraphModule]:
         rewriters = []
         regions = _extract_regions(
@@ -1566,7 +1582,11 @@ def apply(
             self.rotate_matmuls(graph_model)
         if len(regions) > 0:
             rewriters = _apply_rotate(
-                graph_model, regions, self.full_rotation_method, fuse_rotations)
+                graph_model,
+                regions,
+                self.full_rotation_method,
+                fuse_rotations,
+                apply_inplace_rotations)
         if self.return_rewriters:
             return graph_model, rewriters
         else:
diff --git a/src/brevitas/graph/quantize_impl.py b/src/brevitas/graph/quantize_impl.py
index a4d348ab5..538ce5717 100644
--- a/src/brevitas/graph/quantize_impl.py
+++ b/src/brevitas/graph/quantize_impl.py
@@ -7,6 +7,7 @@
 import torch
 import torch.nn as nn
 import torch.nn.utils.parametrize as parametrize
+from tqdm import tqdm

 import brevitas
 from brevitas.graph.base import InsertModuleCallAfter
@@ -538,6 +539,6 @@ def layerwise_layer_handler(
         quant_module_class, quant_module_kwargs = layer_map[_module_class_name(parametrize.type_before_parametrizations(module))]
         rewriter = ModuleToModuleByInstance(module, quant_module_class, **quant_module_kwargs)
         rewriters.append(rewriter)
-    for rewriter in rewriters:
+    for rewriter in tqdm(rewriters, leave=False):
         model = rewriter.apply(model)
     return model
diff --git a/src/brevitas/nn/equalized_layer.py b/src/brevitas/nn/equalized_layer.py
index ccd812713..2c48f9da3 100644
--- a/src/brevitas/nn/equalized_layer.py
+++ b/src/brevitas/nn/equalized_layer.py
@@ -69,14 +69,6 @@ def __init__(self, layer, had_mat=None, k=None) -> None:
         self.layer = layer
         self.k = k

-    @property
-    def weight(self) -> Optional[torch.Tensor]:
-        return getattr(self.layer, 'weight', None)
-
-    @property
-    def bias(self) -> Optional[torch.Tensor]:
-        return getattr(self.layer, 'bias', None)
-
     def forward(self, inp, **kwargs):
         is_cuda = 'cuda' in str(inp.device) and torch.version.cuda is not None
         # If k is None, we assume that an orthogonal matrix is used
@@ -110,8 +102,7 @@ def __init__(
             self,
             rot_mat: torch.nn.Parameter,
             rot_func: Callable,
-            input_axis: Optional[int] = None,
-            output_axis: Optional[int] = None,
+            axis: int,
             is_source: bool = False,
             is_sink: bool = False,
             is_orphan: bool = False,
@@ -119,8 +110,7 @@ def __init__(
     ) -> None:
         super().__init__()
         self.rot_mat = rot_mat
         self.rot_func = rot_func
-        self.input_axis = input_axis
-        self.output_axis = output_axis
+        self.axis = axis
         self.is_source = is_source
         self.is_sink = is_sink
         self.is_orphan = is_orphan
@@ -128,17 +118,17 @@ def __init__(
     def forward(self, weight: torch.Tensor) -> torch.Tensor:
         if self.is_sink or self.is_orphan:
-            if self.input_axis == 1:
+            if self.axis == 1:
                 weight = self.rot_func(weight, self.rot_mat, self.K)
-            elif self.input_axis == 0:
+            elif self.axis == 0:
                 weight = self.rot_func(weight.t(), self.rot_mat, self.K).t()
             else:
                 raise RuntimeError("Not supported yet")

         if self.is_source:
-            if self.output_axis == 0:
+            if self.axis == 0:
                 weight = self.rot_func(weight.t(), self.rot_mat, self.K).t()
-            elif self.output_axis == 1:
+            elif self.axis == 1:
                 weight = self.rot_func(weight, self.rot_mat, self.K)
             else:
                 raise RuntimeError("Not supported yet")
diff --git a/src/brevitas_examples/llm/llm_quant/rotation_optimization.py b/src/brevitas_examples/llm/llm_quant/rotation_optimization.py
index 31cf00051..618763498 100644
--- a/src/brevitas_examples/llm/llm_quant/rotation_optimization.py
+++ b/src/brevitas_examples/llm/llm_quant/rotation_optimization.py
@@ -20,8 +20,7 @@
 @dataclass
 class ModelArguments:
     input_model: Optional[str] = field(
-        default="hf-internal-testing/tiny-random-LlamaForCausalLM",
-        metadata={"help": "Input model"})
+        default="meta-llama/Llama-3.2-1B", metadata={"help": "Input model"})
     output_rotation_path: Optional[str] = field(
         default="test-output", metadata={"help": "Output rotation checkpoint path"})
     optimized_rotation_path: Optional[str] = field(
diff --git a/src/brevitas_examples/llm/llm_quant/rotation_utils.py b/src/brevitas_examples/llm/llm_quant/rotation_utils.py
index 037de166d..9c84aeff7 100644
--- a/src/brevitas_examples/llm/llm_quant/rotation_utils.py
+++ b/src/brevitas_examples/llm/llm_quant/rotation_utils.py
@@ -66,13 +66,15 @@ def fuse_rotations(model: nn.Module) -> None:
             parametrize.remove_parametrizations(module, "bias", leave_parametrized=True)


+# TODO: Remove? We rely on ModuleInstanceRegisterParametrization
 def extract_rewriters_unfused_rotations(model: nn.Module,
                                         rewriters: List[Transform]) -> List[Transform]:
     extra_rewriters = []
     for module in model.modules():
         if hasattr(module, "parametrizations"):
             # Verify that the current module does not already have an associated RotatedModule
-            if len([r for r in rewriters if r.old_module_instance is module]) == 0:
+            if len([r for r in rewriters if r.old_module_instance is module and
+                    isinstance(r, ModuleInstanceToModuleInstance)]) == 0:
                 # Identity rewriter, only useful externally
                 rewriter = ModuleInstanceToModuleInstance(module, module)
                 extra_rewriters.append(rewriter)
diff --git a/src/brevitas_examples/llm/main.py b/src/brevitas_examples/llm/main.py
index 95a3055dd..947314185 100644
--- a/src/brevitas_examples/llm/main.py
+++ b/src/brevitas_examples/llm/main.py
@@ -3,7 +3,10 @@

 import argparse
 from copy import deepcopy
+from functools import wraps
+import os
 import sys
+from typing import Callable, List
 from warnings import warn

 import numpy as np
@@ -16,6 +19,7 @@
 from brevitas.export import export_torch_qcdq
 from brevitas.export.onnx.standard.qcdq.manager import StdQCDQONNXManager
+from brevitas.graph.base import ModuleInstanceFuseRotationWeights
 from brevitas.graph.equalize import GraphRotationEqualization
 from brevitas.graph.equalize import LayerwiseActivationRotation
 from brevitas.graph.quantize import layerwise_quantize
@@ -52,6 +56,37 @@ def set_seed(seed):
     torch.random.manual_seed(seed)


+def on_process(process_index: int):
+
+    def decorator(func: Callable):
+
+        @wraps(func)
+        def _wrapper(model, *args, **kwargs):
+            curr_process_index = int(os.environ.get('LOCAL_RANK', -1))
+
+            if curr_process_index == -1 or (process_index == curr_process_index):
+                print(f"Applying {func.__name__} on process index {curr_process_index}")
+                return func(model, *args, **kwargs)
+            else:
+                print(f"Skipping function {func.__name__} on process index {curr_process_index}")
+                return model
+
+        return _wrapper
+
+    return decorator
+
+
+def apply_fused_rotations(model: torch.nn.Module, rewriters: List) -> torch.nn.Module:
+    model = offload_model(model)
+    for r in rewriters:
+        if isinstance(r, ModuleInstanceFuseRotationWeights):
+            model = r.apply(model)
+    remove_hooks(model)
+    return model
+
+
+# TODO: Use no_grad?
+# The result of fusing the rotations would yield tensors with requires_grad set to False,
+# which might not be a problem, as that flag is set in the appropriate QAT/PTQ algorithms.
 def fused_rotation_no_fx(
     model,
     calibration_loader,
@@ -66,7 +101,6 @@ def fused_rotation_no_fx(
     for r in rewriters:
         r.apply(model)

-    new_model = offload_model(new_model)
     eq = GraphRotationEqualization(
         orphan_sink=args.rotation_orphan_sink,
         full_rotation_method=args.rotation_mode,
@@ -76,17 +110,17 @@ def fused_rotation_no_fx(
         find_self_attention_rotation_regions(
             new_model, model.config.hidden_size // model.config.num_attention_heads) if add_self_attention_regions else None)
-    new_model, rewriters = eq.apply(new_model, fuse_rotations=fuse_rotations, additional_regions=self_attention_regions)
-    # Additional rewriters need to be added if rotations are not fused
-    if not fuse_rotations:
-        rewriters_unfused_rotations = extract_rewriters_unfused_rotations(new_model, rewriters)
-        rewriters.extend(rewriters_unfused_rotations)
-
+    new_model, rewriters = eq.apply(new_model, fuse_rotations=fuse_rotations, additional_regions=self_attention_regions, apply_inplace_rotations=False)
+    # Rewriters need to be fixed to point to the module instances of the original model
     rewriters = fix_rewriter(rewriters, model, 'weight')
-
+    # The weights of the FX model and the original model are tied, so the rotation fusing has already been applied.
+    # Note that the parametrization registration cannot be done in a model that has been offloaded using
+    # offload_model, as the change in the state dictionary when registering the parametrization causes the removal
+    # of the hooks to crash. This is due to the fact that the device_map in the AlignDevicesHook is no longer valid.
+    model = apply_fused_rotations(model, rewriters)
     for r in rewriters:
-        r.apply(model)
-    remove_hooks(new_model)
+        if not isinstance(r, ModuleInstanceFuseRotationWeights):
+            model = r.apply(model)


 def set_seed(seed):
@@ -264,6 +298,14 @@ def main(args, unknown_args=None):
             del calibration_loader[i]["attention_mask"]
             calibration_loader[i]["labels"] = calibration_loader[i]["input_ids"]

+    def mock_save_pretrained_fn(*args, **kwargs):
+        pass
+
+    # For a PretrainedModel, the Trainer in accelerate calls save_pretrained after
+    # finishing the optimization. However, this method no longer works after
+    # registering parametrizations/quantizing, so this method is mocked to prevent
+    # a crash.
+    model.save_pretrained = mock_save_pretrained_fn
     model.config.use_cache = False
     model.config.loss_type = "ForCausalLM"
@@ -368,7 +410,6 @@ def main(args, unknown_args=None):
                 quantize_embedding=False)
             if not args.quantize_last_layer:
                 if require_fx:
-                    # TODO: Fix when using UnfusedRotation, layer_map[type(last_module)][1] crashes
                     last_node = [node for node in model.graph.nodes if node.op == 'call_module'][-1]
                     last_module = get_module(model, last_node.target)
                     last_layer_kwargs = layer_map[type(last_module)][1]
@@ -411,6 +452,8 @@ def main(args, unknown_args=None):
             unknown_args=unknown_args,
         )

+    remove_hooks(model)
+
     if args.act_calibration:
         print("Apply act calibration...")
         apply_calibration(model, calibration_loader)
@@ -447,10 +490,11 @@ def main(args, unknown_args=None):

     if args.eval and not args.no_quantize:
         print("Model eval...")
+        model = offload_model(model)
         quant_ppl = compute_perplexity(
             model, validation_loader, context_length=args.seqlen // 2, tokenizer=tokenizer)
         print(f"Quantized perplexity ({args.dataset}): {quant_ppl:.3f}")
-        remove_hooks(model)
+    remove_hooks(model)

     if args.checkpoint_name is not None:
         print(f"Saving checkpoint to {args.checkpoint_name}")
diff --git a/tests/brevitas/graph/test_equalization.py b/tests/brevitas/graph/test_equalization.py
index 1e913845c..2acf8287b 100644
--- a/tests/brevitas/graph/test_equalization.py
+++ b/tests/brevitas/graph/test_equalization.py
@@ -361,8 +361,7 @@ def test_composition_unfused_rotations(N):
             RotationWeightParametrization(
                 rot_mat=rot_mat,
                 rot_func=_apply_ort_device,
-                input_axis=_get_input_axis(rot_module),
-                output_axis=_get_output_axis(rot_module),
+                axis=_get_output_axis(rot_module) if is_source else _get_input_axis(rot_module),
                 is_source=is_source,
                 is_sink=is_sink,
                 is_orphan=is_orphan,
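Usage note (not part of the patch): the sketch below shows how the two new rewriters in src/brevitas/graph/base.py are intended to be driven. The toy model, the plain matmul rot_func and the SimpleRotation parametrization are illustrative stand-ins for the Hadamard/orthogonal helpers and RotationWeightParametrization that brevitas.graph.equalize actually passes in.

import torch
from torch import nn

from brevitas.graph.base import ModuleInstanceFuseRotationWeights
from brevitas.graph.base import ModuleInstanceRegisterParametrization


def rot_func(weight, rot_mat, K=None):
    # Illustrative stand-in for the orthogonal/Hadamard rotation helpers.
    return weight @ rot_mat


class SimpleRotation(nn.Module):
    # Illustrative parametrization; the patch itself uses RotationWeightParametrization.

    def __init__(self, rot_mat):
        super().__init__()
        self.rot_mat = rot_mat

    def forward(self, weight):
        return rot_func(weight, self.rot_mat)


model = nn.Sequential(nn.Linear(8, 8, bias=False))
rot_mat = torch.linalg.qr(torch.randn(8, 8)).Q  # random orthogonal matrix

# In-place fusion: the sink weight is rotated along its input axis (axis=1)
# and written back to module.weight.data, mirroring the fused path of _apply_rotate.
fuse = ModuleInstanceFuseRotationWeights(
    old_module_instance=model[0],
    rot_mat=rot_mat,
    rot_func=rot_func,
    K=None,
    tensor_name="weight",
    axis=1,
    is_source=False)
model = fuse.apply(model)

# Trainable alternative: register the rotation as a parametrization instead,
# so every access to model[0].weight goes through SimpleRotation.forward.
register = ModuleInstanceRegisterParametrization(
    model[0], "weight", SimpleRotation(nn.Parameter(rot_mat)))
model = register.apply(model)

assert hasattr(model[0], "parametrizations")

As the assert in _apply_rotate indicates, fused rotations have to be applied before any parametrization is registered on the same tensor, since parametrized tensors are no longer updated through .data.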