Skip to content

Commit

Permalink
fix + readme
Browse files Browse the repository at this point in the history
  • Loading branch information
Giuseppe5 committed Jan 10, 2025
1 parent 65795e0 commit 253a51c
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 5 deletions.
6 changes: 3 additions & 3 deletions src/brevitas/graph/equalize.py
Original file line number Diff line number Diff line change
Expand Up @@ -1525,13 +1525,13 @@ def apply(self,
eq_layers = set()
orphan_regions = []
self.find_module(graph_model, orphan_regions)
if self.rotate_sdpa:
sdpa_regions = self.rotate_sdpa(graph_model)
regions.extend(sdpa_regions)
for r in regions:
id_list = [id(r.name_to_module[sink_name]) for sink_name in r.sinks_names]
eq_layers.update(id_list)

if self.rotate_sdpa:
sdpa_regions = self.rotate_sdpa(graph_model)
regions.extend(sdpa_regions)
if self.orphan_sink:
for o_r in orphan_regions:
# Layerwise have only a single sink named 'sinks0'
Expand Down
4 changes: 4 additions & 0 deletions src/brevitas_examples/llm/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ usage: main.py [-h] [--config CONFIG] [--model MODEL] [--seed SEED]
[--replace-mha] [--weight-equalization]
[--rotation {fx,layerwise,fused_no_fx}]
[--rotation-mode {had,ort}] [--rotation-orphan-sink]
[--rotation-sdpa-regions]
[--act-equalization {None,layerwise,fx}]
[--act-equalization-alpha ACT_EQUALIZATION_ALPHA]
[--load-awq LOAD_AWQ]
Expand Down Expand Up @@ -184,6 +185,9 @@ options:
--rotation-orphan-sink
If GraphRotation is enabled, decide wheter to add
standalone hadamard matrices for the unfused layers
--rotation-sdpa-regions
If GraphRotation is enabled, decide wheter to equalize
across SDPA
--act-equalization {None,layerwise,fx}
Apply activation equalization (SmoothQuant). Layerwise
introduces standalone mul nodes,while fx merges them
Expand Down
4 changes: 2 additions & 2 deletions src/brevitas_examples/llm/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def fused_rotation_no_fx(model, calibration_loader, args):
orphan_sink=args.rotation_orphan_sink,
full_rotation_method=args.rotation_mode,
return_rewriters=True,
sdpa_regions=args.sdpa_regions)
sdpa_regions=args.rotation_sdpa_regions)
new_model, rewriters = eq.apply(new_model)
rewriters = fix_rewriter(rewriters, model, 'weight')

Expand Down Expand Up @@ -299,7 +299,7 @@ def quantize_llm(args):
eq = GraphRotationEqualization(
orphan_sink=args.rotation_orphan_sink,
full_rotation_method=args.rotation_mode,
sdpa_regions=args.sdpa_regions)
sdpa_regions=args.rotation_sdpa_regions)
model = eq.apply(model)
remove_hooks(model)
elif args.rotation == 'layerwise':
Expand Down

0 comments on commit 253a51c

Please sign in to comment.