From ace2f3373b9c1b609150483e81f70faa472df187 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Fri, 10 Jan 2025 15:14:54 +0000 Subject: [PATCH 1/7] Feat (rotation): equalize across SDPA --- src/brevitas/graph/equalize.py | 67 +++++++++++++++++++++++++++---- src/brevitas_examples/llm/main.py | 11 ++++- 2 files changed, 69 insertions(+), 9 deletions(-) diff --git a/src/brevitas/graph/equalize.py b/src/brevitas/graph/equalize.py index 4e5c1a162..ad9320724 100644 --- a/src/brevitas/graph/equalize.py +++ b/src/brevitas/graph/equalize.py @@ -1332,14 +1332,18 @@ def _apply_rotate(model: nn.Module, regions: List[Region], full_rotation_method= module.allocate_params(module) axis = _get_output_axis(module) weight = module.weight.data - + original_dtype = next(module.parameters()).dtype if axis == 0: - weight = rot_func(weight.t(), rot_mat, K).t() + rotated_weight = rot_func( + weight.t().to(torch.float32), rot_mat.to(torch.float32), + K).t().to(original_dtype) + _update_weights(module, rotated_weight, 'weight') elif axis == 1: - weight = rot_func(weight, rot_mat, K) + rotated_weight = rot_func(weight.to(torch.float32), rot_mat.to(torch.float32), + K).to(original_dtype) + _update_weights(module, rotated_weight, 'weight') else: raise RuntimeError("Not supported yet") - module.weight.data = weight if getattr(module, 'bias', None) is not None: bias = module.bias.data @@ -1354,11 +1358,16 @@ def _apply_rotate(model: nn.Module, regions: List[Region], full_rotation_method= module.allocate_params(module) axis = _get_input_axis(module) weight = module.weight.data - + original_dtype = next(module.parameters()).dtype if axis == 1: - _update_weights(module, rot_func(weight, rot_mat, K), 'weight') + rotated_weight = rot_func(weight.to(torch.float32), rot_mat.to(torch.float32), + K).to(original_dtype) + _update_weights(module, rotated_weight, 'weight') elif axis == 0: - _update_weights(module, rot_func(weight.t(), rot_mat, K).t(), 'weight') + rotated_weight = rot_func( + 
weight.t().to(torch.float32), rot_mat.to(torch.float32), + K).t().to(original_dtype) + _update_weights(module, rotated_weight, 'weight') else: raise RuntimeError("Not supported yet") @@ -1428,6 +1437,7 @@ def __init__( self, blacklist_layers: Optional[List[str]] = None, orphan_sink: bool = False, + sdpa_regions: bool = False, rotate_matmul: bool = False, full_rotation_method: str = 'had', return_rewriters: bool = False) -> None: @@ -1445,6 +1455,7 @@ def __init__( self.rotate_matmul = rotate_matmul self.full_rotation_method = full_rotation_method self.return_rewriters = return_rewriters + self.sdpa_regions = sdpa_regions def rotate_matmuls(self, graph_module): matmul_nodes = list(graph_module.graph.nodes) @@ -1463,6 +1474,44 @@ def rotate_matmuls(self, graph_module): graph_module.recompile() graph_module.graph.lint() + def rotate_sdpa(self, graph_module): + sdpa_nodes = list(graph_module.graph.nodes) + sdpa_nodes = [ + c for c in sdpa_nodes if 'scaled_dot_product' in str(c.meta.get('orig_target', 'None'))] + regions = [] + + def find_src(node): + if node.op != 'call_module': + return find_src(node.args[0]) + else: + return node + + def find_sink(node): + output_node = list(node.users.keys())[0] + if output_node.op != 'call_module': + return find_sink(output_node) + else: + return output_node + + for sdpa_node in sdpa_nodes: + value_input = sdpa_node.args[-1] + + value_node = find_src(value_input) + output_node = find_sink(value_input) + sink_module = get_module(graph_module, output_node.target) + src_module = get_module(graph_module, value_node.target) + sink_weight = get_weight_sink(sink_module) + src_weight = get_weight_source(src_module) + sink_eq_indexes = EqualizationIndexes(0, sink_weight.shape[0], 0) + src_eq_indexes = EqualizationIndexes(0, src_weight.shape[0], 0) + region = Region( + srcs={'src0': src_eq_indexes}, + sinks={'sink0': sink_eq_indexes}, + name_to_module={ + 'src0': src_module, 'sink0': sink_module}) + regions.append(region) + return regions + 
def apply(self, graph_model: GraphModule) -> Union[Tuple[GraphModule, List[Transform]], GraphModule]: rewriters = [] @@ -1479,6 +1528,10 @@ def apply(self, for r in regions: id_list = [id(r.name_to_module[sink_name]) for sink_name in r.sinks_names] eq_layers.update(id_list) + + if self.rotate_sdpa: + sdpa_regions = self.rotate_sdpa(graph_model) + regions.extend(sdpa_regions) if self.orphan_sink: for o_r in orphan_regions: # Layerwise have only a single sink named 'sinks0' diff --git a/src/brevitas_examples/llm/main.py b/src/brevitas_examples/llm/main.py index ed2ebc2c8..99d9d5013 100644 --- a/src/brevitas_examples/llm/main.py +++ b/src/brevitas_examples/llm/main.py @@ -89,7 +89,8 @@ def fused_rotation_no_fx(model, calibration_loader, args): eq = GraphRotationEqualization( orphan_sink=args.rotation_orphan_sink, full_rotation_method=args.rotation_mode, - return_rewriters=True) + return_rewriters=True, + sdpa_regions=args.sdpa_regions) new_model, rewriters = eq.apply(new_model) rewriters = fix_rewriter(rewriters, model, 'weight') @@ -296,7 +297,9 @@ def quantize_llm(args): if args.rotation == 'fx': model = offload_model(model) eq = GraphRotationEqualization( - orphan_sink=args.rotation_orphan_sink, full_rotation_method=args.rotation_mode) + orphan_sink=args.rotation_orphan_sink, + full_rotation_method=args.rotation_mode, + sdpa_regions=args.sdpa_regions) model = eq.apply(model) remove_hooks(model) elif args.rotation == 'layerwise': @@ -772,6 +775,10 @@ def parse_args(args, override_defaults={}): help= 'If GraphRotation is enabled, decide wheter to add standalone hadamard matrices for the unfused layers' ) + parser.add_argument( + '--rotation-sdpa-regions', + action="store_true", + help='If GraphRotation is enabled, decide wheter to equalize across SDPA') parser.add_argument( '--act-equalization', default=None, From 51a8c38a2b48cfe8b9e3472408c1da96eebdd017 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Fri, 10 Jan 2025 15:25:59 +0000 Subject: [PATCH 2/7] fix + 
readme --- src/brevitas/graph/equalize.py | 6 +++--- src/brevitas_examples/llm/README.md | 4 ++++ src/brevitas_examples/llm/main.py | 4 ++-- 3 files changed, 9 insertions(+), 5 deletions(-) diff --git a/src/brevitas/graph/equalize.py b/src/brevitas/graph/equalize.py index ad9320724..9baeb9421 100644 --- a/src/brevitas/graph/equalize.py +++ b/src/brevitas/graph/equalize.py @@ -1525,13 +1525,13 @@ def apply(self, eq_layers = set() orphan_regions = [] self.find_module(graph_model, orphan_regions) + if self.rotate_sdpa: + sdpa_regions = self.rotate_sdpa(graph_model) + regions.extend(sdpa_regions) for r in regions: id_list = [id(r.name_to_module[sink_name]) for sink_name in r.sinks_names] eq_layers.update(id_list) - if self.rotate_sdpa: - sdpa_regions = self.rotate_sdpa(graph_model) - regions.extend(sdpa_regions) if self.orphan_sink: for o_r in orphan_regions: # Layerwise have only a single sink named 'sinks0' diff --git a/src/brevitas_examples/llm/README.md b/src/brevitas_examples/llm/README.md index 0d6fb5f42..dc1d16b38 100644 --- a/src/brevitas_examples/llm/README.md +++ b/src/brevitas_examples/llm/README.md @@ -50,6 +50,7 @@ usage: main.py [-h] [--config CONFIG] [--model MODEL] [--seed SEED] [--replace-mha] [--weight-equalization] [--rotation {fx,layerwise,fused_no_fx}] [--rotation-mode {had,ort}] [--rotation-orphan-sink] + [--rotation-sdpa-regions] [--act-equalization {None,layerwise,fx}] [--act-equalization-alpha ACT_EQUALIZATION_ALPHA] [--load-awq LOAD_AWQ] @@ -184,6 +185,9 @@ options: --rotation-orphan-sink If GraphRotation is enabled, decide wheter to add standalone hadamard matrices for the unfused layers + --rotation-sdpa-regions + If GraphRotation is enabled, decide wheter to equalize + across SDPA --act-equalization {None,layerwise,fx} Apply activation equalization (SmoothQuant). 
Layerwise introduces standalone mul nodes,while fx merges them diff --git a/src/brevitas_examples/llm/main.py b/src/brevitas_examples/llm/main.py index 99d9d5013..c09b398ed 100644 --- a/src/brevitas_examples/llm/main.py +++ b/src/brevitas_examples/llm/main.py @@ -90,7 +90,7 @@ def fused_rotation_no_fx(model, calibration_loader, args): orphan_sink=args.rotation_orphan_sink, full_rotation_method=args.rotation_mode, return_rewriters=True, - sdpa_regions=args.sdpa_regions) + sdpa_regions=args.rotation_sdpa_regions) new_model, rewriters = eq.apply(new_model) rewriters = fix_rewriter(rewriters, model, 'weight') @@ -299,7 +299,7 @@ def quantize_llm(args): eq = GraphRotationEqualization( orphan_sink=args.rotation_orphan_sink, full_rotation_method=args.rotation_mode, - sdpa_regions=args.sdpa_regions) + sdpa_regions=args.rotation_sdpa_regions) model = eq.apply(model) remove_hooks(model) elif args.rotation == 'layerwise': From 361e664868537e34cf02f01a61502caf9817a6e3 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Sun, 12 Jan 2025 16:08:14 +0000 Subject: [PATCH 3/7] no upcast --- src/brevitas/graph/equalize.py | 15 ++++----------- 1 file changed, 4 insertions(+), 11 deletions(-) diff --git a/src/brevitas/graph/equalize.py b/src/brevitas/graph/equalize.py index 9baeb9421..b9f9e36b8 100644 --- a/src/brevitas/graph/equalize.py +++ b/src/brevitas/graph/equalize.py @@ -1332,15 +1332,11 @@ def _apply_rotate(model: nn.Module, regions: List[Region], full_rotation_method= module.allocate_params(module) axis = _get_output_axis(module) weight = module.weight.data - original_dtype = next(module.parameters()).dtype if axis == 0: - rotated_weight = rot_func( - weight.t().to(torch.float32), rot_mat.to(torch.float32), - K).t().to(original_dtype) + rotated_weight = rot_func(weight.t(), rot_mat, K).t() _update_weights(module, rotated_weight, 'weight') elif axis == 1: - rotated_weight = rot_func(weight.to(torch.float32), rot_mat.to(torch.float32), - K).to(original_dtype) + rotated_weight = 
rot_func(weight, rot_mat, K) _update_weights(module, rotated_weight, 'weight') else: raise RuntimeError("Not supported yet") @@ -1360,13 +1356,10 @@ def _apply_rotate(model: nn.Module, regions: List[Region], full_rotation_method= weight = module.weight.data original_dtype = next(module.parameters()).dtype if axis == 1: - rotated_weight = rot_func(weight.to(torch.float32), rot_mat.to(torch.float32), - K).to(original_dtype) + rotated_weight = rot_func(weight, rot_mat, K) _update_weights(module, rotated_weight, 'weight') elif axis == 0: - rotated_weight = rot_func( - weight.t().to(torch.float32), rot_mat.to(torch.float32), - K).t().to(original_dtype) + rotated_weight = rot_func(weight.t(), rot_mat, K).t() _update_weights(module, rotated_weight, 'weight') else: raise RuntimeError("Not supported yet") From 04715435b7d4f77fe426f97ff0cb258ab2980240 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Mon, 13 Jan 2025 10:22:15 +0000 Subject: [PATCH 4/7] Attempt revert --- src/brevitas/graph/equalize.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/src/brevitas/graph/equalize.py b/src/brevitas/graph/equalize.py index b9f9e36b8..fe5ab016c 100644 --- a/src/brevitas/graph/equalize.py +++ b/src/brevitas/graph/equalize.py @@ -1332,12 +1332,11 @@ def _apply_rotate(model: nn.Module, regions: List[Region], full_rotation_method= module.allocate_params(module) axis = _get_output_axis(module) weight = module.weight.data + if axis == 0: - rotated_weight = rot_func(weight.t(), rot_mat, K).t() - _update_weights(module, rotated_weight, 'weight') + weight = rot_func(weight.t(), rot_mat, K).t() elif axis == 1: - rotated_weight = rot_func(weight, rot_mat, K) - _update_weights(module, rotated_weight, 'weight') + weight = rot_func(weight, rot_mat, K) else: raise RuntimeError("Not supported yet") @@ -1354,7 +1353,7 @@ def _apply_rotate(model: nn.Module, regions: List[Region], full_rotation_method= module.allocate_params(module) axis = _get_input_axis(module) weight 
= module.weight.data - original_dtype = next(module.parameters()).dtype + if axis == 1: rotated_weight = rot_func(weight, rot_mat, K) _update_weights(module, rotated_weight, 'weight') From a32a38bb0cda379692900aefdbacda55d7323237 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Mon, 13 Jan 2025 10:23:14 +0000 Subject: [PATCH 5/7] attempt revert pt2 --- src/brevitas/graph/equalize.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/brevitas/graph/equalize.py b/src/brevitas/graph/equalize.py index fe5ab016c..3cf095075 100644 --- a/src/brevitas/graph/equalize.py +++ b/src/brevitas/graph/equalize.py @@ -1339,6 +1339,7 @@ def _apply_rotate(model: nn.Module, regions: List[Region], full_rotation_method= weight = rot_func(weight, rot_mat, K) else: raise RuntimeError("Not supported yet") + module.weight.data = weight if getattr(module, 'bias', None) is not None: bias = module.bias.data From 7a320ff58bd266e05ed3e5b1eea6436ffa829915 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Mon, 13 Jan 2025 10:33:47 +0000 Subject: [PATCH 6/7] Fix tests --- src/brevitas/graph/equalize.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/brevitas/graph/equalize.py b/src/brevitas/graph/equalize.py index 3cf095075..7cbe38c6a 100644 --- a/src/brevitas/graph/equalize.py +++ b/src/brevitas/graph/equalize.py @@ -1334,12 +1334,13 @@ def _apply_rotate(model: nn.Module, regions: List[Region], full_rotation_method= weight = module.weight.data if axis == 0: - weight = rot_func(weight.t(), rot_mat, K).t() + rotated_weight = rot_func(weight.t(), rot_mat, K).t() + _update_weights(module, rotated_weight, 'weight') elif axis == 1: - weight = rot_func(weight, rot_mat, K) + rotated_weight = rot_func(weight, rot_mat, K) + _update_weights(module, rotated_weight, 'weight') else: raise RuntimeError("Not supported yet") - module.weight.data = weight if getattr(module, 'bias', None) is not None: bias = module.bias.data @@ -1518,7 +1519,7 @@ def apply(self, eq_layers = 
set() orphan_regions = [] self.find_module(graph_model, orphan_regions) - if self.rotate_sdpa: + if self.sdpa_regions: sdpa_regions = self.rotate_sdpa(graph_model) regions.extend(sdpa_regions) for r in regions: From 1a37ac7737a6824d2b0b120c790a6506f9e879da Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Mon, 13 Jan 2025 21:06:38 +0000 Subject: [PATCH 7/7] added test --- tests/brevitas_examples/test_llm.py | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/tests/brevitas_examples/test_llm.py b/tests/brevitas_examples/test_llm.py index c02a3e320..ca1c7cda7 100644 --- a/tests/brevitas_examples/test_llm.py +++ b/tests/brevitas_examples/test_llm.py @@ -434,7 +434,8 @@ def test_small_models_quant_layer(caplog, layer_args): "llama-int8-act_equalization=layerwise", "mistral-int8-quant-last-layer", "llama-rotation-mixed-fx", - "llama-rotation-full-fx",], + "llama-rotation-full-fx", + "llama-rotation-full-fx-sdpa"], params=[ { "model": "hf-internal-testing/tiny-random-MistralForCausalLM", @@ -547,6 +548,22 @@ def test_small_models_quant_layer(caplog, layer_args): "": 15, # LM Head + Q/K/V projs + Up/Gate/Down projs "": 5, # Input + Post attention + "": 0,}}, + { + "model": "hf-internal-testing/tiny-random-LlamaForCausalLM", + "ln_affine_merge": True, + "replace_rmsnorm": True, + "quantize_last_layer": True, + "no_quantize": True, + "rotation_orphan_sink": True, + "convert_layernorm_to_rmsnorm": True, + "rotation_sdpa_regions": True, + "rotation": "fx", + "exp_layer_types_count": { + "": 2, # Sinks: Only Down proj + "": + 15, # LM Head + Q/K/V/O projs + Up/Gate/Down projs + "": 5, "": 0,}},]) def layer_args_types_count(default_run_args, request): args = default_run_args