From b1b59a21147093cc9570c8abe1b1c2a179b6f257 Mon Sep 17 00:00:00 2001 From: Pablo Monteagudo Lago Date: Tue, 14 Jan 2025 11:21:04 +0000 Subject: [PATCH] Fix tests --- src/brevitas/graph/equalize.py | 2 +- tests/brevitas/graph/test_transforms.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src/brevitas/graph/equalize.py b/src/brevitas/graph/equalize.py index 3478ce2a8..b304fd22f 100644 --- a/src/brevitas/graph/equalize.py +++ b/src/brevitas/graph/equalize.py @@ -1384,7 +1384,7 @@ def _apply_rotate( # Only "weight" is rotated tensor_names_axis = [("weight", _get_input_axis(module))] # If rotations are fused or if the module is an orphan sink, transform is applied directly onto the tensor - rewriter_class = ModuleInstanceRegisterParametrization if insert_rotation_module or fuse_rotations else ModuleInstanceRegisterParametrization + rewriter_class = ModuleInstanceTransformTensor if insert_rotation_module or fuse_rotations else ModuleInstanceRegisterParametrization # Obtain rewriters for applying the rotations for tensor_name, axis in tensor_names_axis: rewriter = rewriter_class( diff --git a/tests/brevitas/graph/test_transforms.py b/tests/brevitas/graph/test_transforms.py index c56294d9a..2d5c7a78f 100644 --- a/tests/brevitas/graph/test_transforms.py +++ b/tests/brevitas/graph/test_transforms.py @@ -359,7 +359,8 @@ def forward(self, x): model_unfused.linear.weight.data = model_fused.linear.weight.data model_fused = ModuleInstanceTransformTensor( - model_fused.linear, "weight", rot_mat, rot_func, None, axis).apply(model_fused) + model_fused.linear, "weight", RotationWeightParametrization(rot_mat, rot_func, axis, + None)).apply(model_fused) model_unfused = ModuleInstanceRegisterParametrization( model_unfused.linear, "weight",