diff --git a/src/brevitas/utils/torch_utils.py b/src/brevitas/utils/torch_utils.py
index 461c7bc92..93e4435de 100644
--- a/src/brevitas/utils/torch_utils.py
+++ b/src/brevitas/utils/torch_utils.py
@@ -2,6 +2,7 @@
 # SPDX-License-Identifier: BSD-3-Clause
 
 import copy
+from functools import wraps
 from typing import List, Optional, Tuple
 
 import torch
@@ -127,3 +128,28 @@ def is_broadcastable(tensor, other):
         else:
             return False
     return True
+
+
+def torch_dtype(dtype):
+    """
+    Decorator factory that runs the wrapped function with `dtype` as torch's default dtype.
+
+    The default dtype that was active before the call is restored afterwards,
+    even if the wrapped function raises.
+    """
+
+    def decorator(fn):
+
+        @wraps(fn)
+        def wrapped_fn(*args, **kwargs):
+            cur_dtype = torch.get_default_dtype()
+            try:
+                torch.set_default_dtype(dtype)
+                # Propagate the result so that decorating a function does not discard its return value
+                return fn(*args, **kwargs)
+            finally:
+                torch.set_default_dtype(cur_dtype)
+
+        return wrapped_fn
+
+    return decorator
diff --git a/tests/brevitas/core/test_float_quant.py b/tests/brevitas/core/test_float_quant.py
index 6c7e26f31..148507ada 100644
--- a/tests/brevitas/core/test_float_quant.py
+++ b/tests/brevitas/core/test_float_quant.py
@@ -2,6 +2,7 @@
 # SPDX-License-Identifier: BSD-3-Clause
 
 from hypothesis import given
+from hypothesis import settings
 import mock
 import pytest
 import torch
@@ -15,9 +16,11 @@
 from brevitas.core.scaling import FloatScaling
 from brevitas.function.ops import max_float
 from brevitas.utils.torch_utils import float_internal_scale
+from brevitas.utils.torch_utils import torch_dtype
 from tests.brevitas.hyp_helper import float_st
 from tests.brevitas.hyp_helper import float_tensor_random_shape_st
 from tests.brevitas.hyp_helper import random_minifloat_format
+from tests.brevitas.hyp_helper import random_minifloat_format_and_value
 from tests.marker import jit_disabled_for_mock
 
 
@@ -233,3 +236,36 @@ def test_inner_scale(inp, minifloat_format, scale):
     out_nans = out.isnan()
     expected_out_nans = expected_out.isnan()
     assert torch.equal(out[~out_nans], expected_out[~expected_out_nans])
+
+
+@given(
+    minifloat_format_and_value=random_minifloat_format_and_value(
+        min_bit_width=4, max_bit_with=10, rand_exp_bias=True))
+@settings(max_examples=1000)
+@jit_disabled_for_mock()
+@torch_dtype(torch.float64)
+@torch.no_grad()
+def test_valid_float_values(minifloat_format_and_value):
+    """Check that a value representable in the drawn minifloat format is returned unchanged by FloatQuant."""
+    minifloat_value, exponent, mantissa, sign, bit_width, exponent_bit_width, mantissa_bit_width, signed, exponent_bias = minifloat_format_and_value
+    scaling_impl = mock.Mock(side_effect=lambda x, y: 1.0)
+    float_scaling = FloatScaling(None, None, True)
+    float_clamp = FloatClamp(
+        tensor_clamp_impl=TensorClamp(),
+        signed=signed,
+        inf_values=None,
+        nan_values=None,
+        saturating=True)
+    float_quant = FloatQuant(
+        bit_width=bit_width,
+        exponent_bit_width=exponent_bit_width,
+        mantissa_bit_width=mantissa_bit_width,
+        exponent_bias=exponent_bias,
+        signed=signed,
+        input_view_impl=Identity(),
+        scaling_impl=scaling_impl,
+        float_scaling_impl=float_scaling,
+        float_clamp_impl=float_clamp)
+    inp = torch.tensor(minifloat_value)
+    quant_value, *_ = float_quant(inp)
+    assert torch.equal(inp, quant_value)
diff --git a/tests/brevitas/hyp_helper.py b/tests/brevitas/hyp_helper.py
index c3fd6a82a..baad3f6bd 100644
--- a/tests/brevitas/hyp_helper.py
+++ b/tests/brevitas/hyp_helper.py
@@ -227,16 +227,29 @@ def min_max_tensor_random_shape_st(draw, min_dims=1, max_dims=4, max_size=3, wid
 
 
 @st.composite
-def random_minifloat_format(draw, min_bit_width=MIN_INT_BIT_WIDTH, max_bit_with=MAX_INT_BIT_WIDTH):
+def random_minifloat_format(
+        draw,
+        min_bit_width=MIN_INT_BIT_WIDTH,
+        max_bit_with=MAX_INT_BIT_WIDTH,
+        rand_exp_bias=False,
+        valid_only=False):
     """"
     Generate a minifloat format. Returns bit_width, exponent, mantissa, and signed.
     """
     # TODO: add support for new minifloat format that comes with FloatQuantTensor
     bit_width = draw(st.integers(min_value=min_bit_width, max_value=max_bit_with))
-    exponent_bit_width = draw(st.integers(min_value=0, max_value=bit_width))
-    signed = draw(st.booleans())
-
-    exponent_bias = 2 ** (exponent_bit_width - 1) - 1
+    if valid_only:
+        # Only works if min_bit_width >= 3
+        signed = draw(st.booleans())
+        exponent_bit_width = draw(st.integers(min_value=1, max_value=bit_width - 1 - int(signed)))
+    else:
+        exponent_bit_width = draw(st.integers(min_value=0, max_value=bit_width))
+        signed = draw(st.booleans())
+
+    if rand_exp_bias:
+        exponent_bias = draw(st.integers(min_value=-127, max_value=127))
+    else:
+        exponent_bias = 2 ** (exponent_bit_width - 1) - 1
 
     # if no budget is left, return
     if bit_width == exponent_bit_width:
@@ -246,3 +259,43 @@ def random_minifloat_format(draw, min_bit_width=MIN_INT_BIT_WIDTH, max_bit_with=
     mantissa_bit_width = bit_width - exponent_bit_width - int(signed)
 
     return bit_width, exponent_bit_width, mantissa_bit_width, signed, exponent_bias
+
+
+@st.composite
+def random_valid_minifloat(
+        draw, bit_width, exponent_bit_width, mantissa_bit_width, signed, exponent_bias):
+    """
+    Generate a random floating-point value that can be represented in the specified minifloat format.
+    Returns the value together with the integer exponent, mantissa and sign it was built from.
+    """
+    # Sanity-check that the format is valid
+    assert bit_width == exponent_bit_width + mantissa_bit_width + int(signed)
+    # Generate int values of the minifloat components
+    sign = draw(st.integers(min_value=0, max_value=int(signed)))
+    mantissa = draw(st.integers(min_value=0, max_value=int(2 ** mantissa_bit_width - 1)))
+    exponent = draw(st.integers(min_value=0, max_value=int(2 ** exponent_bit_width - 1)))
+    # Scale mantissa between 0-1
+    mantissa_fixed = mantissa / 2 ** mantissa_bit_width
+    # Add 1 unless denormalised
+    mantissa_fixed += 0. if exponent == 0 else 1.
+    # Adjust exponent if denormalised, otherwise leave it unchanged
+    exponent_value = 1 if exponent == 0 else exponent
+    valid_minifloat = ((-1.) ** sign) * (mantissa_fixed * 2 ** (exponent_value - exponent_bias))
+    return valid_minifloat, exponent, mantissa, sign
+
+
+@st.composite
+def random_minifloat_format_and_value(
+        draw,
+        min_bit_width=MIN_INT_BIT_WIDTH,
+        max_bit_with=MAX_INT_BIT_WIDTH,
+        rand_exp_bias=False,
+        valid_format_only=True):
+    """
+    Generate a minifloat format together with a value drawn from random_valid_minifloat for that format.
+    Returns the value, its exponent, mantissa and sign, followed by the format
+    (bit_width, exponent_bit_width, mantissa_bit_width, signed, exponent_bias).
+    """
+    bit_width, exponent_bit_width, mantissa_bit_width, signed, exponent_bias = draw(random_minifloat_format(min_bit_width=min_bit_width, max_bit_with=max_bit_with, rand_exp_bias=rand_exp_bias, valid_only=valid_format_only))
+    valid_minifloat, exponent, mantissa, sign = draw(random_valid_minifloat(bit_width=bit_width, exponent_bit_width=exponent_bit_width, mantissa_bit_width=mantissa_bit_width, signed=signed, exponent_bias=exponent_bias))
+    return valid_minifloat, exponent, mantissa, sign, bit_width, exponent_bit_width, mantissa_bit_width, signed, exponent_bias