Address comments and new tests
pablomlago committed Jan 16, 2025
1 parent 28144be commit 80d247a
Showing 4 changed files with 136 additions and 38 deletions.
8 changes: 4 additions & 4 deletions src/brevitas/graph/quantize_impl.py
@@ -408,8 +408,8 @@ def act_handler(model, layer_map):
if node.op == 'call_module':
module = get_module(model, node.target)
if isinstance(module, tuple(layer_map.keys())):
if layer_map[type(module)] is not None:
quant_module_class, quant_module_kwargs = layer_map[type(module)]
if layer_map[type_before_parametrizations(module)] is not None:
quant_module_class, quant_module_kwargs = layer_map[type_before_parametrizations(module)]
quant_module = quant_module_class(**quant_module_kwargs)
# Check for activation equalization mul nodes
if len(node.users) == 1:
@@ -470,8 +470,8 @@ def layer_handler(
quant_identity_map=quant_identity_map,
quant_act_map=quant_act_map,
unsigned_act_tuple=unsigned_act_tuple)
if layer_map[type(module)] is not None:
quant_module_class, quant_module_kwargs = layer_map[type(module)]
if layer_map[type_before_parametrizations(module)] is not None:
quant_module_class, quant_module_kwargs = layer_map[type_before_parametrizations(module)]
# Quantize the input if is not quantized, input_quant is not specified,
# and the quant_identity_map is provided.
if not are_inputs_quantized_and_aligned(
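Illustrative sketch (not part of the commit): the two hunks above swap type(module) for type_before_parametrizations(module) because registering a parametrization rewrites the module's class, so a plain type() lookup no longer hits the layer_map. A minimal reproduction, assuming the torch.nn.utils.parametrize helper of the same name and a stand-in layer_map:

import torch.nn as nn
import torch.nn.utils.parametrize as parametrize

# Stand-in for the Brevitas layer_map: maps unparametrized types to quant replacements.
layer_map = {nn.Linear: ("QuantLinear-placeholder", {})}

class Double(nn.Module):
    # Toy parametrization: the weight is recomputed through this forward on every access.
    def forward(self, weight):
        return 2.0 * weight

linear = nn.Linear(4, 4)
parametrize.register_parametrization(linear, "weight", Double())

print(type(linear))                                                   # dynamic ParametrizedLinear subclass
print(type(linear) in layer_map)                                      # False: plain type() misses the map
print(parametrize.type_before_parametrizations(linear))               # <class 'torch.nn.modules.linear.Linear'>
print(parametrize.type_before_parametrizations(linear) in layer_map)  # True: lookup works again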
13 changes: 13 additions & 0 deletions src/brevitas_examples/llm/main.py
@@ -103,6 +103,19 @@ def fused_rotation_no_fx(model, calibration_loader, args, fuse_rotations: bool =
full_rotation_method=args.rotation_mode,
return_rewriters=True,
sdpa_regions=args.rotation_sdpa_regions)
# NOTE: When fuse_rotations=False, rotations are applied as parametrizations, i.e. the weights
# of the selected modules are no longer plain attributes but properties, whose value is
# computed by passing the original tensor through the forward passes of the registered
# parametrization modules. Parametrizations are registered with
# torch.nn.utils.parametrize.register_parametrization, which modifies the __class__
# attribute of the parametrized module, e.g. "<class 'torch.nn.modules.linear.Linear'>"
# becomes "<class 'torch.nn.utils.parametrize.ParametrizedLinear'>". Therefore, algorithms
# that rely on type checking might need to use type_before_parametrizations(module)
# instead of plain type(module) (see layerwise_layer_handler). Moreover, if, for instance,
# the "weight" attribute is parametrized, it is removed from the attributes of the module.
# Consequently, quantization algorithms that rely on in-place modification of the weights
# should not operate on parametrized modules; in that case, parametrizations need to be
# removed beforehand by invoking fuse_parametrized_rotations.
new_model, rewriters = eq.apply(new_model, fuse_rotations=fuse_rotations)
rewriters = fix_rewriter(rewriters, model, 'weight')
for r in rewriters:
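Illustrative sketch (not part of the diff) of the second half of the NOTE above: the parametrized weight turns into a recomputed property, which is why in-place weight updates require fusing first. fuse_parametrized_rotations itself is not shown in this commit; the sketch assumes it behaves like torch.nn.utils.parametrize.remove_parametrizations with leave_parametrized=True.

import torch
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize

class Negate(nn.Module):
    def forward(self, weight):
        return -weight

linear = nn.Linear(2, 2)
parametrize.register_parametrization(linear, "weight", Negate())

# "weight" is no longer a plain Parameter of the module: the raw tensor is stored under
# parametrizations.weight.original and "weight" is recomputed on every access, so in-place
# edits to linear.weight would be silently overwritten.
print([name for name, _ in linear.named_parameters()])  # ['bias', 'parametrizations.weight.original']

# Fusing bakes the current value back into a plain attribute and restores the original class,
# after which algorithms that modify weights in place are safe to run.
parametrize.remove_parametrizations(linear, "weight", leave_parametrized=True)
print(type(linear))                                     # <class 'torch.nn.modules.linear.Linear'>
print(isinstance(linear.weight, nn.Parameter))          # True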
49 changes: 49 additions & 0 deletions tests/brevitas/graph/test_quantize.py
@@ -7,8 +7,10 @@

from brevitas.graph.base import _remove_parametrization_entries_state_dict
from brevitas.graph.quantize import layerwise_quantize
from brevitas.graph.quantize import quantize
from brevitas.utils.python_utils import recurse_getattr
from brevitas.utils.rotation_utils import RotationWeightParametrization
from tests.marker import requires_pt_ge


@pytest_cases.parametrize(
@@ -142,3 +144,50 @@ def test_remove_parametrization_entries_state_dict(kwargs):
assert key in expected_state_dict_keys, f"Unexpected key {key} in state_dict"
# Compare tensor values
assert torch.allclose(value, old_state_dict[key], rtol=0.0, atol=0.0), f"Value of tensor {value} does not match with that in the original state_dict"


@requires_pt_ge('2.3.1')
@pytest_cases.parametrize(
'kwargs',
[
{
'model': nn.Sequential(nn.Linear(2, 3)),
'sample_input': torch.tensor([[0.8, -0.6]]),
'rot_mat': torch.tensor([[1., -1.], [1., 1.]]) / torch.sqrt(torch.tensor(2.)),
'rot_func': lambda tensor, rot_mat, K: torch.matmul(tensor, rot_mat),
'key': '0',
'expected': "<class 'torch.nn.utils.parametrize.ParametrizedQuantLinear'>"},])
def test_quantize_parametrized_modules(kwargs):
key = kwargs['key']
exp = kwargs['expected']
rot_mat = kwargs['rot_mat']
rot_func = kwargs['rot_func']
sample_input = kwargs['sample_input']
model = kwargs["model"]

graph_model, _ = torch._dynamo.export(model)(sample_input)
orig_module = recurse_getattr(model, key)
# Use tied weights to identify equivalent model
key, module = [(key, module)
               for key, module in graph_model.named_modules()
               if hasattr(module, "weight") and module.weight is orig_module.weight][0]
# Register rotation parametrization to module
parametrize.register_parametrization(
module=module,
tensor_name="weight",
parametrization=RotationWeightParametrization(
rot_mat=nn.Parameter(rot_mat),
rot_func=rot_func,
axis=1,
K=None,
))
qmodel = quantize(graph_model)
checked = False
found_names = []
for n, m in qmodel.named_modules():
found_names.append(n)
if n == key:
mt = str(type(m))
assert mt == exp, f"Expect module {n} to be type: {exp}, found type {mt}"
checked = True
assert checked, f"Layer named {key} not found. Layer names are: {found_names}"
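A worked check (illustrative, not part of the test above) of the rot_func/rot_mat pair used here: the matrix is orthogonal, so rotating the weight along axis=1 is function-preserving as long as the input is rotated with the same matrix. In the rotation-equalization flow the compensating rotation is roughly absorbed by a neighbouring module, which is what makes the transformation safe to carry through quantization.

import torch
import torch.nn as nn

# Orthogonal rotation: rot_mat @ rot_mat.T == I
rot_mat = torch.tensor([[1., -1.], [1., 1.]]) / torch.sqrt(torch.tensor(2.))
linear = nn.Linear(2, 3)
x = torch.tensor([[0.8, -0.6]])

rotated = nn.Linear(2, 3)
rotated.load_state_dict(linear.state_dict())
with torch.no_grad():
    # Mirrors the test's rot_func(tensor, rot_mat, K) = tensor @ rot_mat applied to the weight
    rotated.weight.copy_(linear.weight @ rot_mat)

# Rotating the input with the same matrix cancels the weight rotation:
# (x @ R) @ (W @ R).T == x @ R @ R.T @ W.T == x @ W.T
print(torch.allclose(linear(x), rotated(x @ rot_mat), atol=1e-6))  # True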
104 changes: 70 additions & 34 deletions tests/brevitas_examples/test_llm.py
@@ -7,6 +7,7 @@
import os
import platform
import shutil
from unittest.mock import patch

import numpy as np
import onnx
@@ -23,21 +24,16 @@
from tests.marker import jit_disabled_for_export
from tests.marker import requires_pt_ge

ATOL_PPL = 2e+02
RTOL_PPL = 1e-04


def ptid2pathname(string):
return string.replace("/", "-").replace(":", "-")


def allclose(x, y):
return np.allclose(x, y, rtol=1e-03, atol=1e+01, equal_nan=False)


def allveryclose(x, y):
return np.allclose(x, y, rtol=1e-04, atol=2e+02, equal_nan=False)


def allexact(x, y):
return np.allclose(x, y, rtol=0.0, atol=0.0, equal_nan=False)
def allclose(x, y, rtol=RTOL_PPL, atol=ATOL_PPL):
return np.allclose(x, y, rtol=rtol, atol=atol, equal_nan=False)


def transformers_version_ge(required_version: str):
@@ -252,8 +248,8 @@ def test_small_models_acc(caplog, acc_args_and_acc):
float_ppl, quant_ppl, model = validate_args_and_run_main(args)
float_ppl = float_ppl.detach().cpu().numpy()
quant_ppl = quant_ppl.detach().cpu().numpy()
assert allveryclose(exp_float_ppl, float_ppl), f"Expected float PPL {exp_float_ppl}, measured PPL {float_ppl}"
assert allveryclose(exp_quant_ppl, quant_ppl), f"Expected quant PPL {exp_quant_ppl}, measured PPL {quant_ppl}"
assert allclose(exp_float_ppl, float_ppl), f"Expected float PPL {exp_float_ppl}, measured PPL {float_ppl}"
assert allclose(exp_quant_ppl, quant_ppl), f"Expected quant PPL {exp_quant_ppl}, measured PPL {quant_ppl}"


@pytest_cases.fixture(
@@ -294,8 +290,8 @@ def test_small_models_acc_pt_ge_2_4(caplog, acc_args_and_acc_pt_ge_2_4):
float_ppl, quant_ppl, model = validate_args_and_run_main(args)
float_ppl = float_ppl.detach().cpu().numpy()
quant_ppl = quant_ppl.detach().cpu().numpy()
assert allveryclose(exp_float_ppl, float_ppl), f"Expected float PPL {exp_float_ppl}, measured PPL {float_ppl}"
assert allveryclose(exp_quant_ppl, quant_ppl), f"Expected quant PPL {exp_quant_ppl}, measured PPL {quant_ppl}"
assert allclose(exp_float_ppl, float_ppl), f"Expected float PPL {exp_float_ppl}, measured PPL {float_ppl}"
assert allclose(exp_quant_ppl, quant_ppl), f"Expected quant PPL {exp_quant_ppl}, measured PPL {quant_ppl}"


@pytest_cases.fixture(
@@ -738,8 +734,8 @@ def test_small_models_learned_round_ppl(caplog, learned_round_ppl_args_and_ppl):
float_ppl, quant_ppl, model = validate_args_and_run_main(args)
float_ppl = float_ppl.detach().cpu().numpy()
quant_ppl = quant_ppl.detach().cpu().numpy()
assert allveryclose(exp_float_ppl, float_ppl), f"Expected float PPL {exp_float_ppl}, measured PPL {float_ppl}"
assert allveryclose(exp_quant_ppl, quant_ppl), f"Expected quant PPL {exp_quant_ppl}, measured PPL {quant_ppl}"
assert allclose(exp_float_ppl, float_ppl), f"Expected float PPL {exp_float_ppl}, measured PPL {float_ppl}"
assert allclose(exp_quant_ppl, quant_ppl), f"Expected quant PPL {exp_quant_ppl}, measured PPL {quant_ppl}"


@pytest_cases.fixture(
@@ -760,7 +756,7 @@ def test_small_models_learned_round_ppl(caplog, learned_round_ppl_args_and_ppl):
"rotation_orphan_sink": True,
"rotation_mode": "ort",
"float_ppl": 33238.8984375,
"quant_ppl": 33232.65234375},
"quant_ppl": 33232.65234375,},
{
"model": "hf-internal-testing/tiny-random-LlamaForCausalLM",
"act_calibration": False,
@@ -771,7 +767,7 @@ def test_small_models_learned_round_ppl(caplog, learned_round_ppl_args_and_ppl):
"rotation_orphan_sink": False,
"rotation_mode": "ort",
"float_ppl": 33238.8984375,
"quant_ppl": 33420.65234375},
"quant_ppl": 33420.65234375,},
{
"model": "hf-internal-testing/tiny-random-LlamaForCausalLM",
"act_calibration": False,
@@ -782,7 +778,7 @@ def test_small_models_learned_round_ppl(caplog, learned_round_ppl_args_and_ppl):
"rotation_orphan_sink": True,
"rotation_mode": "had",
"float_ppl": 33238.8984375,
"quant_ppl": 33290.48046875},
"quant_ppl": 33290.48046875,},
{
"model": "hf-internal-testing/tiny-random-LlamaForCausalLM",
"act_calibration": False,
@@ -793,7 +789,7 @@ def test_small_models_learned_round_ppl(caplog, learned_round_ppl_args_and_ppl):
"rotation_orphan_sink": False,
"rotation_mode": "had",
"float_ppl": 33238.8984375,
"quant_ppl": 33204.80859375},
"quant_ppl": 33204.80859375,},
{
"model": "hf-internal-testing/tiny-random-LlamaForCausalLM",
"act_calibration": False,
@@ -802,7 +798,7 @@ def test_small_models_learned_round_ppl(caplog, learned_round_ppl_args_and_ppl):
"replace_rmsnorm": True,
"rotation": "layerwise",
"float_ppl": 33238.8984375,
"quant_ppl": 33446.734375},])
"quant_ppl": 33446.734375,},])
def rotation_ppl_args_and_ppl(default_run_args, request):
args = default_run_args
run_dict = request.param
@@ -823,8 +819,8 @@ def test_small_models_rotation_ppl(caplog, rotation_ppl_args_and_ppl):
float_ppl, quant_ppl, model = validate_args_and_run_main(args)
float_ppl = float_ppl.detach().cpu().numpy()
quant_ppl = quant_ppl.detach().cpu().numpy()
assert allveryclose(exp_float_ppl, float_ppl), f"Expected float PPL {exp_float_ppl}, measured PPL {float_ppl}"
assert allveryclose(exp_quant_ppl, quant_ppl), f"Expected quant PPL {exp_quant_ppl}, measured PPL {quant_ppl}"
assert allclose(exp_float_ppl, float_ppl), f"Expected float PPL {exp_float_ppl}, measured PPL {float_ppl}"
assert allclose(exp_quant_ppl, quant_ppl), f"Expected quant PPL {exp_quant_ppl}, measured PPL {quant_ppl}"


@pytest_cases.fixture(
@@ -857,7 +853,12 @@ def test_small_models_rotation_ppl(caplog, rotation_ppl_args_and_ppl):
"--save_strategy",
"no"],
"float_ppl": 33238.8984375,
"quant_ppl": 33232.65234375},
"quant_ppl": 33278.98828125,
"exp_layer_types_count": {
"<class 'brevitas.nn.equalized_layer.RotatedModule'>": 4,
"<class 'torch.nn.utils.parametrize.ParametrizedLinear'>": 1,
"<class 'torch.nn.utils.parametrize.ParametrizedEmbedding'>": 1,
"<class 'torch.nn.utils.parametrize.ParametrizedQuantLinear'>": 14,}},
{
"model": "hf-internal-testing/tiny-random-LlamaForCausalLM",
"act_calibration": False,
@@ -881,7 +882,12 @@ def test_small_models_rotation_ppl(caplog, rotation_ppl_args_and_ppl):
"--save_strategy",
"no"],
"float_ppl": 33238.8984375,
"quant_ppl": 33420.65234375},
"quant_ppl": 33424.73046875,
"exp_layer_types_count": {
"<class 'brevitas.nn.equalized_layer.RotatedModule'>": 0,
"<class 'torch.nn.utils.parametrize.ParametrizedLinear'>": 1,
"<class 'torch.nn.utils.parametrize.ParametrizedEmbedding'>": 1,
"<class 'torch.nn.utils.parametrize.ParametrizedQuantLinear'>": 14,}},
{
"model": "hf-internal-testing/tiny-random-LlamaForCausalLM",
"act_calibration": False,
@@ -905,7 +911,12 @@ def test_small_models_rotation_ppl(caplog, rotation_ppl_args_and_ppl):
"--save_strategy",
"no"],
"float_ppl": 33238.8984375,
"quant_ppl": 33290.48046875},
"quant_ppl": 33339.21875,
"exp_layer_types_count": {
"<class 'brevitas.nn.equalized_layer.RotatedModule'>": 4,
"<class 'torch.nn.utils.parametrize.ParametrizedLinear'>": 1,
"<class 'torch.nn.utils.parametrize.ParametrizedEmbedding'>": 1,
"<class 'torch.nn.utils.parametrize.ParametrizedQuantLinear'>": 14,}},
{
"model": "hf-internal-testing/tiny-random-LlamaForCausalLM",
"act_calibration": False,
@@ -929,28 +940,53 @@ def test_small_models_rotation_ppl(caplog, rotation_ppl_args_and_ppl):
"--save_strategy",
"no"],
"float_ppl": 33238.8984375,
"quant_ppl": 33204.80859375},])
def rotation_optimization_args_and_ppl(default_run_args, request):
"quant_ppl": 33219.08984375,
"exp_layer_types_count": {
"<class 'brevitas.nn.equalized_layer.RotatedModule'>": 0,
"<class 'torch.nn.utils.parametrize.ParametrizedLinear'>": 1,
"<class 'torch.nn.utils.parametrize.ParametrizedEmbedding'>": 1,
"<class 'torch.nn.utils.parametrize.ParametrizedQuantLinear'>": 14,}},])
def rotation_optimization_args_layer_count_and_ppl(default_run_args, request):
args = default_run_args
run_dict = request.param
unknown_args = run_dict["unknown_args"]
float_ppl = run_dict["float_ppl"]
quant_ppl = run_dict["quant_ppl"]
exp_layer_types_count = run_dict["exp_layer_types_count"]
del run_dict["float_ppl"]
del run_dict["quant_ppl"]
del run_dict["unknown_args"]
del run_dict["exp_layer_types_count"]
args.update(**run_dict)
yield args, unknown_args, float_ppl, quant_ppl
yield args, unknown_args, float_ppl, quant_ppl, exp_layer_types_count


@requires_pt_ge('2.4')
def test_small_models_rotation_optimization_ppl(caplog, rotation_optimization_args_and_ppl):
def test_small_models_rotation_optimization_ppl(
caplog, rotation_optimization_args_layer_count_and_ppl):
if platform.system() == "Windows":
pytest.skip("Skipping dynamo + windows")
# Tolerances are stricter for this test, to ensure that it does not pass
# with non-optimized quantized perplexities
RTOL_ROT, ATOL_ROT = 1e-05, 2.
caplog.set_level(logging.INFO)
args, unknown_args, exp_float_ppl, exp_quant_ppl = rotation_optimization_args_and_ppl
float_ppl, quant_ppl, model = validate_args_and_run_main(args, unknown_args)
args, unknown_args, exp_float_ppl, exp_quant_ppl, _ = rotation_optimization_args_layer_count_and_ppl
float_ppl, quant_ppl, _ = validate_args_and_run_main(args, unknown_args)
float_ppl = float_ppl.detach().cpu().numpy()
quant_ppl = quant_ppl.detach().cpu().numpy()
assert allveryclose(exp_float_ppl, float_ppl), f"Expected float PPL {exp_float_ppl}, measured PPL {float_ppl}"
assert allveryclose(exp_quant_ppl, quant_ppl), f"Expected quant PPL {exp_quant_ppl}, measured PPL {quant_ppl}"
assert allclose(exp_float_ppl, float_ppl), f"Expected float PPL {exp_float_ppl}, measured PPL {float_ppl}"
assert allclose(exp_quant_ppl, quant_ppl, rtol=RTOL_ROT, atol=ATOL_ROT), f"Expected quant PPL {exp_quant_ppl}, measured PPL {quant_ppl}"


@requires_pt_ge('2.4')
def test_small_models_rotation_optimization_layer_count(
caplog, rotation_optimization_args_layer_count_and_ppl):
if platform.system() == "Windows":
pytest.skip("Skipping dynamo + windows")
# fuse_parametrized_rotations is patched to an identity function below, so the
# parametrizations are kept and the resulting module types can be counted
caplog.set_level(logging.INFO)
args, unknown_args, _, _, exp_layer_types_count = rotation_optimization_args_layer_count_and_ppl
with patch('brevitas_examples.llm.main.fuse_parametrized_rotations', lambda model: model):
_, _, model = validate_args_and_run_main(args, unknown_args)
assert_layer_types_count(model, exp_layer_types_count)
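assert_layer_types_count is defined elsewhere in test_llm.py and is not part of this diff; a hypothetical equivalent, to make the assertion above concrete, might look like:

from collections import Counter

def assert_layer_types_count(model, exp_layer_types_count):
    # Hypothetical sketch, not the real helper: count every module by its class name,
    # e.g. "<class 'torch.nn.utils.parametrize.ParametrizedQuantLinear'>": 14
    counts = Counter(str(type(module)) for module in model.modules())
    for layer_type, expected in exp_layer_types_count.items():
        found = counts.get(layer_type, 0)
        assert found == expected, f"Expected {expected} layers of type {layer_type}, found {found}"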
