Feat (rotation): equalize across SDPA #1149

Merged · 7 commits · Jan 13, 2025
57 changes: 52 additions & 5 deletions src/brevitas/graph/equalize.py
@@ -1334,12 +1334,13 @@ def _apply_rotate(model: nn.Module, regions: List[Region], full_rotation_method=
             weight = module.weight.data

             if axis == 0:
-                weight = rot_func(weight.t(), rot_mat, K).t()
+                rotated_weight = rot_func(weight.t(), rot_mat, K).t()
+                _update_weights(module, rotated_weight, 'weight')
             elif axis == 1:
-                weight = rot_func(weight, rot_mat, K)
+                rotated_weight = rot_func(weight, rot_mat, K)
+                _update_weights(module, rotated_weight, 'weight')
             else:
                 raise RuntimeError("Not supported yet")
-            module.weight.data = weight

             if getattr(module, 'bias', None) is not None:
                 bias = module.bias.data
@@ -1356,9 +1357,11 @@ def _apply_rotate(model: nn.Module, regions: List[Region], full_rotation_method=
             weight = module.weight.data

             if axis == 1:
-                _update_weights(module, rot_func(weight, rot_mat, K), 'weight')
+                rotated_weight = rot_func(weight, rot_mat, K)
+                _update_weights(module, rotated_weight, 'weight')
             elif axis == 0:
-                _update_weights(module, rot_func(weight.t(), rot_mat, K).t(), 'weight')
+                rotated_weight = rot_func(weight.t(), rot_mat, K).t()
+                _update_weights(module, rotated_weight, 'weight')
             else:
                 raise RuntimeError("Not supported yet")
@@ -1428,6 +1431,7 @@ def __init__(
         self,
         blacklist_layers: Optional[List[str]] = None,
         orphan_sink: bool = False,
+        sdpa_regions: bool = False,
         rotate_matmul: bool = False,
         full_rotation_method: str = 'had',
         return_rewriters: bool = False) -> None:
@@ -1445,6 +1449,7 @@ def __init__(
         self.rotate_matmul = rotate_matmul
         self.full_rotation_method = full_rotation_method
         self.return_rewriters = return_rewriters
+        self.sdpa_regions = sdpa_regions

     def rotate_matmuls(self, graph_module):
         matmul_nodes = list(graph_module.graph.nodes)
@@ -1463,6 +1468,44 @@ def rotate_matmuls(self, graph_module):
         graph_module.recompile()
         graph_module.graph.lint()

+    def rotate_sdpa(self, graph_module):
+        sdpa_nodes = list(graph_module.graph.nodes)
+        sdpa_nodes = [
+            c for c in sdpa_nodes if 'scaled_dot_product' in str(c.meta.get('orig_target', 'None'))]
+        regions = []
+
+        def find_src(node):
+            if node.op != 'call_module':
+                return find_src(node.args[0])
+            else:
+                return node
+
+        def find_sink(node):
+            output_node = list(node.users.keys())[0]
+            if output_node.op != 'call_module':
+                return find_sink(output_node)
+            else:
+                return output_node
+
+        for sdpa_node in sdpa_nodes:
+            value_input = sdpa_node.args[-1]
+
+            value_node = find_src(value_input)
+            output_node = find_sink(value_input)
+            sink_module = get_module(graph_module, output_node.target)
+            src_module = get_module(graph_module, value_node.target)
+            sink_weight = get_weight_sink(sink_module)
+            src_weight = get_weight_source(src_module)
+            sink_eq_indexes = EqualizationIndexes(0, sink_weight.shape[0], 0)
+            src_eq_indexes = EqualizationIndexes(0, src_weight.shape[0], 0)
+            region = Region(
+                srcs={'src0': src_eq_indexes},
+                sinks={'sink0': sink_eq_indexes},
+                name_to_module={
+                    'src0': src_module, 'sink0': sink_module})
+            regions.append(region)
+        return regions
+
     def apply(self,
               graph_model: GraphModule) -> Union[Tuple[GraphModule, List[Transform]], GraphModule]:
         rewriters = []
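The new rotate_sdpa builds one equalization region per SDPA call: find_src walks the value input up to the nearest call_module (the value projection, the source) and find_sink walks the users down to the nearest call_module (the output projection, the sink). A toy sketch of that traversal with plain torch.fx and illustrative module names; the PR matches on meta['orig_target'], which Brevitas' own tracing is assumed to populate, whereas plain symbolic_trace exposes the call on node.target instead:

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.fx import symbolic_trace

class TinyAttn(nn.Module):
    def __init__(self, dim=8):
        super().__init__()
        self.q_proj = nn.Linear(dim, dim)
        self.k_proj = nn.Linear(dim, dim)
        self.v_proj = nn.Linear(dim, dim)
        self.o_proj = nn.Linear(dim, dim)

    def forward(self, x):
        out = F.scaled_dot_product_attention(
            self.q_proj(x), self.k_proj(x), self.v_proj(x))
        return self.o_proj(out)

gm = symbolic_trace(TinyAttn())
sdpa = next(n for n in gm.graph.nodes if 'scaled_dot_product' in str(n.target))
value_node = sdpa.args[-1]          # here already a call_module: v_proj
sink_node = next(iter(sdpa.users))  # here already a call_module: o_proj
print(value_node.target, '->', sink_node.target)  # v_proj -> o_proj

In real transformer graphs there are usually reshapes and transposes between these modules, which is why find_src and find_sink recurse until they hit a call_module.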
@@ -1476,9 +1519,13 @@ def apply(self,
         eq_layers = set()
         orphan_regions = []
         self.find_module(graph_model, orphan_regions)
+        if self.sdpa_regions:
+            sdpa_regions = self.rotate_sdpa(graph_model)
+            regions.extend(sdpa_regions)
         for r in regions:
             id_list = [id(r.name_to_module[sink_name]) for sink_name in r.sinks_names]
             eq_layers.update(id_list)

         if self.orphan_sink:
             for o_r in orphan_regions:
                 # Layerwise has only a single sink named 'sinks0'
4 changes: 4 additions & 0 deletions src/brevitas_examples/llm/README.md
@@ -50,6 +50,7 @@ usage: main.py [-h] [--config CONFIG] [--model MODEL] [--seed SEED]
                [--replace-mha] [--weight-equalization]
                [--rotation {fx,layerwise,fused_no_fx}]
                [--rotation-mode {had,ort}] [--rotation-orphan-sink]
+               [--rotation-sdpa-regions]
                [--act-equalization {None,layerwise,fx}]
                [--act-equalization-alpha ACT_EQUALIZATION_ALPHA]
                [--load-awq LOAD_AWQ]
@@ -184,6 +185,9 @@
   --rotation-orphan-sink
                         If GraphRotation is enabled, decide whether to add
                         standalone hadamard matrices for the unfused layers
+  --rotation-sdpa-regions
+                        If GraphRotation is enabled, decide whether to
+                        equalize across SDPA
   --act-equalization {None,layerwise,fx}
                         Apply activation equalization (SmoothQuant). Layerwise
                         introduces standalone mul nodes, while fx merges them
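An illustrative invocation combining the new flag with the existing rotation options (the model here is the tiny test checkpoint used in this PR's tests; swap in your own):

python main.py --model hf-internal-testing/tiny-random-LlamaForCausalLM \
    --rotation fx --rotation-mode had --rotation-orphan-sink --rotation-sdpa-regions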
11 changes: 9 additions & 2 deletions src/brevitas_examples/llm/main.py
@@ -89,7 +89,8 @@ def fused_rotation_no_fx(model, calibration_loader, args):
     eq = GraphRotationEqualization(
         orphan_sink=args.rotation_orphan_sink,
         full_rotation_method=args.rotation_mode,
-        return_rewriters=True)
+        return_rewriters=True,
+        sdpa_regions=args.rotation_sdpa_regions)
     new_model, rewriters = eq.apply(new_model)
     rewriters = fix_rewriter(rewriters, model, 'weight')

@@ -296,7 +297,9 @@ def quantize_llm(args):
     if args.rotation == 'fx':
         model = offload_model(model)
         eq = GraphRotationEqualization(
-            orphan_sink=args.rotation_orphan_sink, full_rotation_method=args.rotation_mode)
+            orphan_sink=args.rotation_orphan_sink,
+            full_rotation_method=args.rotation_mode,
+            sdpa_regions=args.rotation_sdpa_regions)
         model = eq.apply(model)
         remove_hooks(model)
     elif args.rotation == 'layerwise':
@@ -772,6 +775,10 @@ def parse_args(args, override_defaults={}):
         help=
         'If GraphRotation is enabled, decide whether to add standalone hadamard matrices for the unfused layers'
     )
+    parser.add_argument(
+        '--rotation-sdpa-regions',
+        action="store_true",
+        help='If GraphRotation is enabled, decide whether to equalize across SDPA')
     parser.add_argument(
         '--act-equalization',
         default=None,
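The same wiring can be reproduced outside the CLI. A hedged sketch, where traced_model is a placeholder for an FX-traced model and the keyword names follow the diffs above:

from brevitas.graph.equalize import GraphRotationEqualization

# Mirrors --rotation fx --rotation-mode had --rotation-orphan-sink
# --rotation-sdpa-regions; traced_model stands in for a GraphModule of the LLM.
eq = GraphRotationEqualization(
    orphan_sink=True,
    full_rotation_method='had',
    sdpa_regions=True)
rotated_model = eq.apply(traced_model)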
19 changes: 18 additions & 1 deletion tests/brevitas_examples/test_llm.py
@@ -434,7 +434,8 @@ def test_small_models_quant_layer(caplog, layer_args):
"llama-int8-act_equalization=layerwise",
"mistral-int8-quant-last-layer",
"llama-rotation-mixed-fx",
"llama-rotation-full-fx",],
"llama-rotation-full-fx",
"llama-rotation-full-fx-sdpa"],
params=[
{
"model": "hf-internal-testing/tiny-random-MistralForCausalLM",
Expand Down Expand Up @@ -547,6 +548,22 @@ def test_small_models_quant_layer(caplog, layer_args):
"<class 'torch.nn.modules.linear.Linear'>":
15, # LM Head + Q/K/V projs + Up/Gate/Down projs
"<class 'torch.nn.modules.normalization.RMSNorm'>": 5, # Input + Post attention
"<class 'torch.nn.modules.normalization.LayerNorm'>": 0,}},
{
"model": "hf-internal-testing/tiny-random-LlamaForCausalLM",
"ln_affine_merge": True,
"replace_rmsnorm": True,
"quantize_last_layer": True,
"no_quantize": True,
"rotation_orphan_sink": True,
"convert_layernorm_to_rmsnorm": True,
"rotation_sdpa_regions": True,
"rotation": "fx",
"exp_layer_types_count": {
"<class 'brevitas.nn.equalized_layer.RotatedModule'>": 2, # Sinks: Only Down proj
"<class 'torch.nn.modules.linear.Linear'>":
15, # LM Head + Q/K/V/O projs + Up/Gate/Down projs
"<class 'torch.nn.modules.normalization.RMSNorm'>": 5,
"<class 'torch.nn.modules.normalization.LayerNorm'>": 0,}},])
def layer_args_types_count(default_run_args, request):
args = default_run_args
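To exercise only the new configuration locally, filtering on the new fixture id should work, e.g.:

pytest tests/brevitas_examples/test_llm.py -k "llama-rotation-full-fx-sdpa"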