From e3de22814b02636447fa5d1787e88be465280832 Mon Sep 17 00:00:00 2001
From: Giuseppe Franco
Date: Mon, 13 Jan 2025 22:20:18 +0100
Subject: [PATCH] Feat (rotation): equalize across SDPA (#1149)

---
 src/brevitas/graph/equalize.py      | 57 ++++++++++++++++++++++++++---
 src/brevitas_examples/llm/README.md |  4 ++
 src/brevitas_examples/llm/main.py   | 11 +++++-
 tests/brevitas_examples/test_llm.py | 19 +++++++++-
 4 files changed, 83 insertions(+), 8 deletions(-)

diff --git a/src/brevitas/graph/equalize.py b/src/brevitas/graph/equalize.py
index 4e5c1a162..7cbe38c6a 100644
--- a/src/brevitas/graph/equalize.py
+++ b/src/brevitas/graph/equalize.py
@@ -1334,12 +1334,13 @@ def _apply_rotate(model: nn.Module, regions: List[Region], full_rotation_method=
             weight = module.weight.data
 
             if axis == 0:
-                weight = rot_func(weight.t(), rot_mat, K).t()
+                rotated_weight = rot_func(weight.t(), rot_mat, K).t()
+                _update_weights(module, rotated_weight, 'weight')
             elif axis == 1:
-                weight = rot_func(weight, rot_mat, K)
+                rotated_weight = rot_func(weight, rot_mat, K)
+                _update_weights(module, rotated_weight, 'weight')
             else:
                 raise RuntimeError("Not supported yet")
-            module.weight.data = weight
 
             if getattr(module, 'bias', None) is not None:
                 bias = module.bias.data
@@ -1356,9 +1357,11 @@
             weight = module.weight.data
 
             if axis == 1:
-                _update_weights(module, rot_func(weight, rot_mat, K), 'weight')
+                rotated_weight = rot_func(weight, rot_mat, K)
+                _update_weights(module, rotated_weight, 'weight')
             elif axis == 0:
-                _update_weights(module, rot_func(weight.t(), rot_mat, K).t(), 'weight')
+                rotated_weight = rot_func(weight.t(), rot_mat, K).t()
+                _update_weights(module, rotated_weight, 'weight')
             else:
                 raise RuntimeError("Not supported yet")
 
@@ -1428,6 +1431,7 @@ def __init__(
             self,
             blacklist_layers: Optional[List[str]] = None,
             orphan_sink: bool = False,
+            sdpa_regions: bool = False,
             rotate_matmul: bool = False,
             full_rotation_method: str = 'had',
             return_rewriters: bool = False) -> None:
@@ -1445,6 +1449,7 @@
         self.rotate_matmul = rotate_matmul
         self.full_rotation_method = full_rotation_method
         self.return_rewriters = return_rewriters
+        self.sdpa_regions = sdpa_regions
 
     def rotate_matmuls(self, graph_module):
         matmul_nodes = list(graph_module.graph.nodes)
@@ -1463,6 +1468,44 @@
         graph_module.recompile()
         graph_module.graph.lint()
 
+    def rotate_sdpa(self, graph_module):
+        sdpa_nodes = list(graph_module.graph.nodes)
+        sdpa_nodes = [
+            c for c in sdpa_nodes if 'scaled_dot_product' in str(c.meta.get('orig_target', 'None'))]
+        regions = []
+
+        def find_src(node):
+            if node.op != 'call_module':
+                return find_src(node.args[0])
+            else:
+                return node
+
+        def find_sink(node):
+            output_node = list(node.users.keys())[0]
+            if output_node.op != 'call_module':
+                return find_sink(output_node)
+            else:
+                return output_node
+
+        for sdpa_node in sdpa_nodes:
+            value_input = sdpa_node.args[-1]
+
+            value_node = find_src(value_input)
+            output_node = find_sink(value_input)
+            sink_module = get_module(graph_module, output_node.target)
+            src_module = get_module(graph_module, value_node.target)
+            sink_weight = get_weight_sink(sink_module)
+            src_weight = get_weight_source(src_module)
+            sink_eq_indexes = EqualizationIndexes(0, sink_weight.shape[0], 0)
+            src_eq_indexes = EqualizationIndexes(0, src_weight.shape[0], 0)
+            region = Region(
+                srcs={'src0': src_eq_indexes},
+                sinks={'sink0': sink_eq_indexes},
+                name_to_module={
+                    'src0': src_module, 'sink0': sink_module})
+            regions.append(region)
+        return regions
+
     def apply(self,
               graph_model: GraphModule) -> Union[Tuple[GraphModule, List[Transform]], GraphModule]:
         rewriters = []
@@ -1476,9 +1519,13 @@
         eq_layers = set()
         orphan_regions = []
         self.find_module(graph_model, orphan_regions)
+        if self.sdpa_regions:
+            sdpa_regions = self.rotate_sdpa(graph_model)
+            regions.extend(sdpa_regions)
         for r in regions:
             id_list = [id(r.name_to_module[sink_name]) for sink_name in r.sinks_names]
             eq_layers.update(id_list)
+
         if self.orphan_sink:
             for o_r in orphan_regions:
                 # Layerwise have only a single sink named 'sinks0'
diff --git a/src/brevitas_examples/llm/README.md b/src/brevitas_examples/llm/README.md
index 59e17084f..ee0e3df77 100644
--- a/src/brevitas_examples/llm/README.md
+++ b/src/brevitas_examples/llm/README.md
@@ -50,6 +50,7 @@ usage: main.py [-h] [--config CONFIG] [--model MODEL] [--seed SEED]
                [--replace-mha] [--weight-equalization]
                [--rotation {fx,layerwise,fused_no_fx}] [--rotation-mode {had,ort}]
                [--rotation-orphan-sink]
+               [--rotation-sdpa-regions]
                [--act-equalization {None,layerwise,fx}]
                [--act-equalization-alpha ACT_EQUALIZATION_ALPHA]
                [--load-awq LOAD_AWQ]
@@ -184,6 +185,9 @@ options:
   --rotation-orphan-sink
                         If GraphRotation is enabled, decide wheter to add
                         standalone hadamard matrices for the unfused layers
+  --rotation-sdpa-regions
+                        If GraphRotation is enabled, decide whether to equalize
+                        across SDPA
   --act-equalization {None,layerwise,fx}
                         Apply activation equalization (SmoothQuant). Layerwise
                         introduces standalone mul nodes,while fx merges them
diff --git a/src/brevitas_examples/llm/main.py b/src/brevitas_examples/llm/main.py
index cca2172ab..cd997df11 100644
--- a/src/brevitas_examples/llm/main.py
+++ b/src/brevitas_examples/llm/main.py
@@ -90,7 +90,8 @@ def fused_rotation_no_fx(model, calibration_loader, args):
     eq = GraphRotationEqualization(
         orphan_sink=args.rotation_orphan_sink,
         full_rotation_method=args.rotation_mode,
-        return_rewriters=True)
+        return_rewriters=True,
+        sdpa_regions=args.rotation_sdpa_regions)
     new_model, rewriters = eq.apply(new_model)
     rewriters = fix_rewriter(rewriters, model, 'weight')
 
@@ -297,7 +298,9 @@ def quantize_llm(args):
     if args.rotation == 'fx':
         model = offload_model(model)
         eq = GraphRotationEqualization(
-            orphan_sink=args.rotation_orphan_sink, full_rotation_method=args.rotation_mode)
+            orphan_sink=args.rotation_orphan_sink,
+            full_rotation_method=args.rotation_mode,
+            sdpa_regions=args.rotation_sdpa_regions)
         model = eq.apply(model)
         remove_hooks(model)
     elif args.rotation == 'layerwise':
@@ -789,6 +792,10 @@
         help=
         'If GraphRotation is enabled, decide wheter to add standalone hadamard matrices for the unfused layers'
     )
+    parser.add_argument(
+        '--rotation-sdpa-regions',
+        action="store_true",
+        help='If GraphRotation is enabled, decide whether to equalize across SDPA')
     parser.add_argument(
         '--act-equalization',
         default=None,
diff --git a/tests/brevitas_examples/test_llm.py b/tests/brevitas_examples/test_llm.py
index c02a3e320..ca1c7cda7 100644
--- a/tests/brevitas_examples/test_llm.py
+++ b/tests/brevitas_examples/test_llm.py
@@ -434,7 +434,8 @@ def test_small_models_quant_layer(caplog, layer_args):
         "llama-int8-act_equalization=layerwise",
         "mistral-int8-quant-last-layer",
         "llama-rotation-mixed-fx",
-        "llama-rotation-full-fx",],
+        "llama-rotation-full-fx",
+        "llama-rotation-full-fx-sdpa"],
     params=[
         {
             "model": "hf-internal-testing/tiny-random-MistralForCausalLM",
@@ -547,6 +548,22 @@ def test_small_models_quant_layer(caplog, layer_args):
                 "": 15,  # LM Head + Q/K/V projs + Up/Gate/Down projs
                 "": 5,  # Input + Post attention
+                "": 0,}},
+        {
+            "model": "hf-internal-testing/tiny-random-LlamaForCausalLM",
+            "ln_affine_merge": True,
+            "replace_rmsnorm": True,
+            "quantize_last_layer": True,
+            "no_quantize": True,
+            "rotation_orphan_sink": True,
+            "convert_layernorm_to_rmsnorm": True,
+            "rotation_sdpa_regions": True,
+            "rotation": "fx",
+            "exp_layer_types_count": {
+                "": 2,  # Sinks: Only Down proj
+                "":
+                    15,  # LM Head + Q/K/V/O projs + Up/Gate/Down projs
+                "": 5,
+                "": 0,}},])
 def layer_args_types_count(default_run_args, request):
     args = default_run_args
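For context, a minimal usage sketch of the new option, not part of the patch above. It assumes `graph_model` is an FX-traced `GraphModule`, which is what `GraphRotationEqualization.apply` expects; all constructor arguments shown appear in the diff. From the LLM entry point, the same behaviour should be reachable with `main.py --model MODEL --rotation fx --rotation-sdpa-regions`, as documented in the README changes above.

from brevitas.graph.equalize import GraphRotationEqualization

# Sketch under the assumptions stated above: graph_model is an FX GraphModule.
eq = GraphRotationEqualization(
    orphan_sink=True,  # standalone Hadamard rotations for the unfused (orphan) sinks
    full_rotation_method='had',
    sdpa_regions=True,  # new flag: also build src/sink regions across scaled_dot_product_attention
    return_rewriters=True)
# With return_rewriters=True, apply() returns the rotated model plus the rewriters it used.
graph_model, rewriters = eq.apply(graph_model)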