Skip to content

Commit

Permalink
Feat (rotation): equalize across SDPA
Browse files Browse the repository at this point in the history
  • Loading branch information
Giuseppe5 committed Jan 12, 2025
1 parent b83ab89 commit ace2f33
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 9 deletions.
67 changes: 60 additions & 7 deletions src/brevitas/graph/equalize.py
Original file line number Diff line number Diff line change
Expand Up @@ -1332,14 +1332,18 @@ def _apply_rotate(model: nn.Module, regions: List[Region], full_rotation_method=
module.allocate_params(module)
axis = _get_output_axis(module)
weight = module.weight.data

original_dtype = next(module.parameters()).dtype
if axis == 0:
weight = rot_func(weight.t(), rot_mat, K).t()
rotated_weight = rot_func(
weight.t().to(torch.float32), rot_mat.to(torch.float32),
K).t().to(original_dtype)
_update_weights(module, rotated_weight, 'weight')
elif axis == 1:
weight = rot_func(weight, rot_mat, K)
rotated_weight = rot_func(weight.to(torch.float32), rot_mat.to(torch.float32),
K).to(original_dtype)
_update_weights(module, rotated_weight, 'weight')
else:
raise RuntimeError("Not supported yet")
module.weight.data = weight

if getattr(module, 'bias', None) is not None:
bias = module.bias.data
Expand All @@ -1354,11 +1358,16 @@ def _apply_rotate(model: nn.Module, regions: List[Region], full_rotation_method=
module.allocate_params(module)
axis = _get_input_axis(module)
weight = module.weight.data

original_dtype = next(module.parameters()).dtype
if axis == 1:
_update_weights(module, rot_func(weight, rot_mat, K), 'weight')
rotated_weight = rot_func(weight.to(torch.float32), rot_mat.to(torch.float32),
K).to(original_dtype)
_update_weights(module, rotated_weight, 'weight')
elif axis == 0:
_update_weights(module, rot_func(weight.t(), rot_mat, K).t(), 'weight')
rotated_weight = rot_func(
weight.t().to(torch.float32), rot_mat.to(torch.float32),
K).t().to(original_dtype)
_update_weights(module, rotated_weight, 'weight')
else:
raise RuntimeError("Not supported yet")

Expand Down Expand Up @@ -1428,6 +1437,7 @@ def __init__(
self,
blacklist_layers: Optional[List[str]] = None,
orphan_sink: bool = False,
sdpa_regions: bool = False,
rotate_matmul: bool = False,
full_rotation_method: str = 'had',
return_rewriters: bool = False) -> None:
Expand All @@ -1445,6 +1455,7 @@ def __init__(
self.rotate_matmul = rotate_matmul
self.full_rotation_method = full_rotation_method
self.return_rewriters = return_rewriters
self.sdpa_regions = sdpa_regions

def rotate_matmuls(self, graph_module):
matmul_nodes = list(graph_module.graph.nodes)
Expand All @@ -1463,6 +1474,44 @@ def rotate_matmuls(self, graph_module):
graph_module.recompile()
graph_module.graph.lint()

def rotate_sdpa(self, graph_module):
sdpa_nodes = list(graph_module.graph.nodes)
sdpa_nodes = [
c for c in sdpa_nodes if 'scaled_dot_product' in str(c.meta.get('orig_target', 'None'))]
regions = []

def find_src(node):
if node.op != 'call_module':
return find_src(node.args[0])
else:
return node

def find_sink(node):
output_node = list(node.users.keys())[0]
if output_node.op != 'call_module':
return find_sink(output_node)
else:
return output_node

for sdpa_node in sdpa_nodes:
value_input = sdpa_node.args[-1]

value_node = find_src(value_input)
output_node = find_sink(value_input)
sink_module = get_module(graph_module, output_node.target)
src_module = get_module(graph_module, value_node.target)
sink_weight = get_weight_sink(sink_module)
src_weight = get_weight_source(src_module)
sink_eq_indexes = EqualizationIndexes(0, sink_weight.shape[0], 0)
src_eq_indexes = EqualizationIndexes(0, src_weight.shape[0], 0)
region = Region(
srcs={'src0': src_eq_indexes},
sinks={'sink0': sink_eq_indexes},
name_to_module={
'src0': src_module, 'sink0': sink_module})
regions.append(region)
return regions

def apply(self,
graph_model: GraphModule) -> Union[Tuple[GraphModule, List[Transform]], GraphModule]:
rewriters = []
Expand All @@ -1479,6 +1528,10 @@ def apply(self,
for r in regions:
id_list = [id(r.name_to_module[sink_name]) for sink_name in r.sinks_names]
eq_layers.update(id_list)

if self.rotate_sdpa:
sdpa_regions = self.rotate_sdpa(graph_model)
regions.extend(sdpa_regions)
if self.orphan_sink:
for o_r in orphan_regions:
# Layerwise have only a single sink named 'sinks0'
Expand Down
11 changes: 9 additions & 2 deletions src/brevitas_examples/llm/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,8 @@ def fused_rotation_no_fx(model, calibration_loader, args):
eq = GraphRotationEqualization(
orphan_sink=args.rotation_orphan_sink,
full_rotation_method=args.rotation_mode,
return_rewriters=True)
return_rewriters=True,
sdpa_regions=args.sdpa_regions)
new_model, rewriters = eq.apply(new_model)
rewriters = fix_rewriter(rewriters, model, 'weight')

Expand Down Expand Up @@ -296,7 +297,9 @@ def quantize_llm(args):
if args.rotation == 'fx':
model = offload_model(model)
eq = GraphRotationEqualization(
orphan_sink=args.rotation_orphan_sink, full_rotation_method=args.rotation_mode)
orphan_sink=args.rotation_orphan_sink,
full_rotation_method=args.rotation_mode,
sdpa_regions=args.sdpa_regions)
model = eq.apply(model)
remove_hooks(model)
elif args.rotation == 'layerwise':
Expand Down Expand Up @@ -772,6 +775,10 @@ def parse_args(args, override_defaults={}):
help=
'If GraphRotation is enabled, decide wheter to add standalone hadamard matrices for the unfused layers'
)
parser.add_argument(
'--rotation-sdpa-regions',
action="store_true",
help='If GraphRotation is enabled, decide wheter to equalize across SDPA')
parser.add_argument(
'--act-equalization',
default=None,
Expand Down

0 comments on commit ace2f33

Please sign in to comment.