From 253a51cfbc66ba50e2ab58f76e08b425a8526497 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Fri, 10 Jan 2025 15:25:59 +0000 Subject: [PATCH] fix + readme --- src/brevitas/graph/equalize.py | 6 +++--- src/brevitas_examples/llm/README.md | 4 ++++ src/brevitas_examples/llm/main.py | 4 ++-- 3 files changed, 9 insertions(+), 5 deletions(-) diff --git a/src/brevitas/graph/equalize.py b/src/brevitas/graph/equalize.py index ad9320724..9baeb9421 100644 --- a/src/brevitas/graph/equalize.py +++ b/src/brevitas/graph/equalize.py @@ -1525,13 +1525,13 @@ def apply(self, eq_layers = set() orphan_regions = [] self.find_module(graph_model, orphan_regions) + if self.rotate_sdpa: + sdpa_regions = self.rotate_sdpa(graph_model) + regions.extend(sdpa_regions) for r in regions: id_list = [id(r.name_to_module[sink_name]) for sink_name in r.sinks_names] eq_layers.update(id_list) - if self.rotate_sdpa: - sdpa_regions = self.rotate_sdpa(graph_model) - regions.extend(sdpa_regions) if self.orphan_sink: for o_r in orphan_regions: # Layerwise have only a single sink named 'sinks0' diff --git a/src/brevitas_examples/llm/README.md b/src/brevitas_examples/llm/README.md index 0d6fb5f42..dc1d16b38 100644 --- a/src/brevitas_examples/llm/README.md +++ b/src/brevitas_examples/llm/README.md @@ -50,6 +50,7 @@ usage: main.py [-h] [--config CONFIG] [--model MODEL] [--seed SEED] [--replace-mha] [--weight-equalization] [--rotation {fx,layerwise,fused_no_fx}] [--rotation-mode {had,ort}] [--rotation-orphan-sink] + [--rotation-sdpa-regions] [--act-equalization {None,layerwise,fx}] [--act-equalization-alpha ACT_EQUALIZATION_ALPHA] [--load-awq LOAD_AWQ] @@ -184,6 +185,9 @@ options: --rotation-orphan-sink If GraphRotation is enabled, decide wheter to add standalone hadamard matrices for the unfused layers + --rotation-sdpa-regions + If GraphRotation is enabled, decide whether to equalize + across SDPA --act-equalization {None,layerwise,fx} Apply activation equalization (SmoothQuant). 
Layerwise introduces standalone mul nodes,while fx merges them diff --git a/src/brevitas_examples/llm/main.py b/src/brevitas_examples/llm/main.py index 99d9d5013..c09b398ed 100644 --- a/src/brevitas_examples/llm/main.py +++ b/src/brevitas_examples/llm/main.py @@ -90,7 +90,7 @@ def fused_rotation_no_fx(model, calibration_loader, args): orphan_sink=args.rotation_orphan_sink, full_rotation_method=args.rotation_mode, return_rewriters=True, - sdpa_regions=args.sdpa_regions) + sdpa_regions=args.rotation_sdpa_regions) new_model, rewriters = eq.apply(new_model) rewriters = fix_rewriter(rewriters, model, 'weight') @@ -299,7 +299,7 @@ def quantize_llm(args): eq = GraphRotationEqualization( orphan_sink=args.rotation_orphan_sink, full_rotation_method=args.rotation_mode, - sdpa_regions=args.sdpa_regions) + sdpa_regions=args.rotation_sdpa_regions) model = eq.apply(model) remove_hooks(model) elif args.rotation == 'layerwise':