Feat (rotation): equalize across SDPA (#1149)
Giuseppe5 authored Jan 13, 2025
1 parent 8546589 commit e3de228
Showing 4 changed files with 83 additions and 8 deletions.
57 changes: 52 additions & 5 deletions src/brevitas/graph/equalize.py
@@ -1334,12 +1334,13 @@ def _apply_rotate(model: nn.Module, regions: List[Region], full_rotation_method=
             weight = module.weight.data
 
             if axis == 0:
-                weight = rot_func(weight.t(), rot_mat, K).t()
+                rotated_weight = rot_func(weight.t(), rot_mat, K).t()
+                _update_weights(module, rotated_weight, 'weight')
             elif axis == 1:
-                weight = rot_func(weight, rot_mat, K)
+                rotated_weight = rot_func(weight, rot_mat, K)
+                _update_weights(module, rotated_weight, 'weight')
             else:
                 raise RuntimeError("Not supported yet")
-            module.weight.data = weight
 
             if getattr(module, 'bias', None) is not None:
                 bias = module.bias.data
@@ -1356,9 +1357,11 @@ def _apply_rotate(model: nn.Module, regions: List[Region], full_rotation_method=
             weight = module.weight.data
 
             if axis == 1:
-                _update_weights(module, rot_func(weight, rot_mat, K), 'weight')
+                rotated_weight = rot_func(weight, rot_mat, K)
+                _update_weights(module, rotated_weight, 'weight')
             elif axis == 0:
-                _update_weights(module, rot_func(weight.t(), rot_mat, K).t(), 'weight')
+                rotated_weight = rot_func(weight.t(), rot_mat, K).t()
+                _update_weights(module, rotated_weight, 'weight')
             else:
                 raise RuntimeError("Not supported yet")

@@ -1428,6 +1431,7 @@ def __init__(
             self,
             blacklist_layers: Optional[List[str]] = None,
             orphan_sink: bool = False,
+            sdpa_regions: bool = False,
             rotate_matmul: bool = False,
             full_rotation_method: str = 'had',
             return_rewriters: bool = False) -> None:
@@ -1445,6 +1449,7 @@ def __init__(
         self.rotate_matmul = rotate_matmul
         self.full_rotation_method = full_rotation_method
         self.return_rewriters = return_rewriters
+        self.sdpa_regions = sdpa_regions
 
     def rotate_matmuls(self, graph_module):
         matmul_nodes = list(graph_module.graph.nodes)
@@ -1463,6 +1468,44 @@ def rotate_matmuls(self, graph_module):
         graph_module.recompile()
         graph_module.graph.lint()
 
+    def rotate_sdpa(self, graph_module):
+        sdpa_nodes = list(graph_module.graph.nodes)
+        sdpa_nodes = [
+            c for c in sdpa_nodes if 'scaled_dot_product' in str(c.meta.get('orig_target', 'None'))]
+        regions = []
+
+        def find_src(node):
+            if node.op != 'call_module':
+                return find_src(node.args[0])
+            else:
+                return node
+
+        def find_sink(node):
+            output_node = list(node.users.keys())[0]
+            if output_node.op != 'call_module':
+                return find_sink(output_node)
+            else:
+                return output_node
+
+        for sdpa_node in sdpa_nodes:
+            value_input = sdpa_node.args[-1]
+
+            value_node = find_src(value_input)
+            output_node = find_sink(value_input)
+            sink_module = get_module(graph_module, output_node.target)
+            src_module = get_module(graph_module, value_node.target)
+            sink_weight = get_weight_sink(sink_module)
+            src_weight = get_weight_source(src_module)
+            sink_eq_indexes = EqualizationIndexes(0, sink_weight.shape[0], 0)
+            src_eq_indexes = EqualizationIndexes(0, src_weight.shape[0], 0)
+            region = Region(
+                srcs={'src0': src_eq_indexes},
+                sinks={'sink0': sink_eq_indexes},
+                name_to_module={
+                    'src0': src_module, 'sink0': sink_module})
+            regions.append(region)
+        return regions
+
     def apply(self,
               graph_model: GraphModule) -> Union[Tuple[GraphModule, List[Transform]], GraphModule]:
         rewriters = []
@@ -1476,9 +1519,13 @@ def apply(self,
         eq_layers = set()
         orphan_regions = []
         self.find_module(graph_model, orphan_regions)
+        if self.sdpa_regions:
+            sdpa_regions = self.rotate_sdpa(graph_model)
+            regions.extend(sdpa_regions)
         for r in regions:
             id_list = [id(r.name_to_module[sink_name]) for sink_name in r.sinks_names]
             eq_layers.update(id_list)
+
         if self.orphan_sink:
             for o_r in orphan_regions:
                 # Layerwise have only a single sink named 'sinks0'
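Taken together, the new option is driven the same way as the existing rotation flags. Below is a minimal usage sketch, not part of the diff: it assumes the model has already been traced into an FX `GraphModule` (as `quantize_llm` in main.py does before rotation) and that `GraphRotationEqualization` is importable from `brevitas.graph.equalize`, where this commit defines it.

```python
from brevitas.graph.equalize import GraphRotationEqualization

# graph_model is assumed to be an FX-traced GraphModule of the LLM,
# produced by the same tracing step main.py already runs before rotation.
eq = GraphRotationEqualization(
    orphan_sink=True,            # mirrors --rotation-orphan-sink
    full_rotation_method='had',  # mirrors --rotation-mode had
    sdpa_regions=True)           # new in this commit: also equalize across SDPA
graph_model = eq.apply(graph_model)
```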
4 changes: 4 additions & 0 deletions src/brevitas_examples/llm/README.md
@@ -50,6 +50,7 @@ usage: main.py [-h] [--config CONFIG] [--model MODEL] [--seed SEED]
               [--replace-mha] [--weight-equalization]
               [--rotation {fx,layerwise,fused_no_fx}]
               [--rotation-mode {had,ort}] [--rotation-orphan-sink]
+              [--rotation-sdpa-regions]
               [--act-equalization {None,layerwise,fx}]
               [--act-equalization-alpha ACT_EQUALIZATION_ALPHA]
               [--load-awq LOAD_AWQ]
@@ -184,6 +185,9 @@ options:
   --rotation-orphan-sink
                         If GraphRotation is enabled, decide whether to add
                         standalone hadamard matrices for the unfused layers
+  --rotation-sdpa-regions
+                        If GraphRotation is enabled, decide whether to equalize
+                        across SDPA
   --act-equalization {None,layerwise,fx}
                         Apply activation equalization (SmoothQuant). Layerwise
                         introduces standalone mul nodes, while fx merges them
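Usage note (not part of the diff): with this commit applied, the new flag combines with the existing FX-mode rotation options. The checkpoint below is just the tiny model exercised by tests/brevitas_examples/test_llm.py, used here purely for illustration:

```bash
python main.py --model hf-internal-testing/tiny-random-LlamaForCausalLM \
    --rotation fx --rotation-orphan-sink --rotation-sdpa-regions
```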
11 changes: 9 additions & 2 deletions src/brevitas_examples/llm/main.py
@@ -90,7 +90,8 @@ def fused_rotation_no_fx(model, calibration_loader, args):
     eq = GraphRotationEqualization(
         orphan_sink=args.rotation_orphan_sink,
         full_rotation_method=args.rotation_mode,
-        return_rewriters=True)
+        return_rewriters=True,
+        sdpa_regions=args.rotation_sdpa_regions)
     new_model, rewriters = eq.apply(new_model)
     rewriters = fix_rewriter(rewriters, model, 'weight')

@@ -297,7 +298,9 @@ def quantize_llm(args):
     if args.rotation == 'fx':
         model = offload_model(model)
         eq = GraphRotationEqualization(
-            orphan_sink=args.rotation_orphan_sink, full_rotation_method=args.rotation_mode)
+            orphan_sink=args.rotation_orphan_sink,
+            full_rotation_method=args.rotation_mode,
+            sdpa_regions=args.rotation_sdpa_regions)
         model = eq.apply(model)
         remove_hooks(model)
     elif args.rotation == 'layerwise':
@@ -789,6 +792,10 @@ def parse_args(args, override_defaults={}):
         help=
         'If GraphRotation is enabled, decide whether to add standalone hadamard matrices for the unfused layers'
     )
+    parser.add_argument(
+        '--rotation-sdpa-regions',
+        action="store_true",
+        help='If GraphRotation is enabled, decide whether to equalize across SDPA')
     parser.add_argument(
         '--act-equalization',
         default=None,
19 changes: 18 additions & 1 deletion tests/brevitas_examples/test_llm.py
@@ -434,7 +434,8 @@ def test_small_models_quant_layer(caplog, layer_args):
         "llama-int8-act_equalization=layerwise",
         "mistral-int8-quant-last-layer",
         "llama-rotation-mixed-fx",
-        "llama-rotation-full-fx",],
+        "llama-rotation-full-fx",
+        "llama-rotation-full-fx-sdpa"],
     params=[
         {
             "model": "hf-internal-testing/tiny-random-MistralForCausalLM",
@@ -547,6 +548,22 @@ def test_small_models_quant_layer(caplog, layer_args):
                 "<class 'torch.nn.modules.linear.Linear'>":
                     15,  # LM Head + Q/K/V projs + Up/Gate/Down projs
                 "<class 'torch.nn.modules.normalization.RMSNorm'>": 5,  # Input + Post attention
-                "<class 'torch.nn.modules.normalization.LayerNorm'>": 0,}},])
+                "<class 'torch.nn.modules.normalization.LayerNorm'>": 0,}},
+        {
+            "model": "hf-internal-testing/tiny-random-LlamaForCausalLM",
+            "ln_affine_merge": True,
+            "replace_rmsnorm": True,
+            "quantize_last_layer": True,
+            "no_quantize": True,
+            "rotation_orphan_sink": True,
+            "convert_layernorm_to_rmsnorm": True,
+            "rotation_sdpa_regions": True,
+            "rotation": "fx",
+            "exp_layer_types_count": {
+                "<class 'brevitas.nn.equalized_layer.RotatedModule'>": 2,  # Sinks: Only Down proj
+                "<class 'torch.nn.modules.linear.Linear'>":
+                    15,  # LM Head + Q/K/V/O projs + Up/Gate/Down projs
+                "<class 'torch.nn.modules.normalization.RMSNorm'>": 5,
+                "<class 'torch.nn.modules.normalization.LayerNorm'>": 0,}},])
 def layer_args_types_count(default_run_args, request):
     args = default_run_args
