A first sample version of FloatQuant
#159
Open
nghielme wants to merge 4 commits into fastmachinelearning:feature/float_quant from nghielme:float_quant
Commits (4)
3208a89 (nghielme): Sample `FloatQuant` function implemented. A sample use of the functi…
d7d35b2 (maltanar): [FloatQ] copy over float_quantize into custom op placeholder
0f6633a (maltanar): [Test] add test skeleton for compute_max_val and float_quantize
491a3be (nghielme): FloatQuant implementation improved to pass the nullifying tests Yaman…
@@ -0,0 +1,110 @@
# Copyright (c) 2024 Nicolo Ghielmetti
# Copyright (c) 2024 Advanced Micro Devices, Inc.
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice, this
#   list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright notice,
#   this list of conditions and the following disclaimer in the documentation
#   and/or other materials provided with the distribution.
#
# * Neither the name of qonnx nor the names of its
#   contributors may be used to endorse or promote products derived from
#   this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import numpy as np

from qonnx.custom_op.general.quant import resolve_rounding_mode


def compute_default_exponent_bias(exponent_bitwidth):
    return (2.0 ** (exponent_bitwidth - 1)) - 1


def compute_max_val(exponent_bitwidth, mantissa_bitwidth, exponent_bias=None):
    if exponent_bias is None:
        exponent_bias = compute_default_exponent_bias(exponent_bitwidth)
    max_exponent = (2.0**exponent_bitwidth) - 1.0 - exponent_bias
    max_mantissa = np.sum((2.0 ** np.arange(0, -1.0 * mantissa_bitwidth - 1.0, -1.0)))
    max_val = max_mantissa * (2**max_exponent)
    return max_val


def float_quantize(
    X,
    scale,
    exponent_bitwidth,
    mantissa_bitwidth,
    exponent_bias=None,
    max_val=None,
    rounding_mode="ROUND",
    lt_subnorm_to_zero=False,
):
    """Quantize a given floating-point array to minifloat format by specifying the desired minifloat quantization."""
    if exponent_bias is None:
        exponent_bias = compute_default_exponent_bias(exponent_bitwidth)
    if max_val is None:
        max_val = compute_max_val(exponent_bitwidth, mantissa_bitwidth, exponent_bias)
    # copy the sign of the input
    sign = np.sign(X)
    # compute the mask of the values equal to 0 - they will always be zero at the output
    zero_mask = np.where(X == 0)
    # copy the input in order to not modify it
    X = X.copy()
    # set the zeros to 1.0 - but it could be any value
    X[zero_mask] = 1.0
    # apply the scale to the input
    X /= scale
    # get the input exponents from the floats - no need to use eps since the zeros have already been removed
    e_inp = np.floor(np.log2(np.abs(X)))
    # compute the max exponent given the exponent bitwidth.
    # Note: the inf/NaN representation is included and it is clipped at the end of this function
    e_max = np.maximum(2.0 ** (exponent_bitwidth) - 1, 1.0)
    # compute the exponent range given the max exponent. e_low represents the subnormals of the
    # quantized representation, e_high the infs/NaNs
    e_low, e_high = -e_max + exponent_bias + 1, e_max - exponent_bias
    # limit the value of the exponent given the quantization range
    e_quant = np.clip(e_inp, e_low, e_high)
    # compute the shift needed to round the quantized value properly. This part essentially quantizes the mantissa
    # (rounds the mantissa by setting to 0 the bits that do not belong to the quantized representation)
    round_shift = 2.0 ** (e_quant - mantissa_bitwidth)
    # apply the shift
    man = X / round_shift
    # round the mantissa
    man_quant = resolve_rounding_mode(rounding_mode)(man)
    # compute the max value of the mantissa (i.e. all the mantissa bits set to 1)
    man_max = 2.0 ** (mantissa_bitwidth + 1) - 1
    # compute the min value of the mantissa (i.e. one bit at the position indicated by the exponent)
    man_min = 2.0**-mantissa_bitwidth
    # if the quantized value is a subnormal, remove 1 from the mantissa (i.e. 1 + 2**m => 2**m)
    man_max = np.where(e_quant != e_low, man_max, man_max - 1)
    # make sure the mantissa is in the representable range
    man_clip = np.clip(man_quant, -man_max, man_max)
    # go back to the float representation
    qx = man_clip * round_shift
    # if it's inf or NaN, saturate to sign * max_val
    qx = np.where(e_quant == e_high, sign * max_val, qx)
    if lt_subnorm_to_zero:
        # compute the min subnormal as the lowest possible exponent times the min mantissa
        min_subnormal = 2.0 ** (e_low + 1) * man_min
        # if the value is closer to zero than the minimum subnormal then set it to 0
        qx = np.where((X <= min_subnormal) & (X >= -min_subnormal), 0.0, qx)
    # restore the original zeros
    qx[zero_mask] = 0.0
    # unscale the result
    qx *= scale
    return qx
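For context, a minimal usage sketch of the two helpers above (not part of the diff; the tensor values and variable names are invented for illustration). It quantizes a small array to the FP4 E2M1 format (2 exponent bits, 1 mantissa bit) with a unit scale; values beyond the representable maximum saturate to sign * max_val:

import numpy as np

from qonnx.custom_op.general.floatquant import compute_max_val, float_quantize

# largest representable E2M1 value with the default exponent bias
print(compute_max_val(2, 1))  # 6.0

x = np.asarray([0.0, 1.5, 7.0], dtype=np.float32)
scale = np.asarray([1.0], dtype=np.float32)
# 0.0 stays 0.0, 1.5 is exactly representable in E2M1, 7.0 saturates to 6.0
print(float_quantize(x, scale, exponent_bitwidth=2, mantissa_bitwidth=1))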
@@ -0,0 +1,57 @@
# Copyright (c) 2024 Advanced Micro Devices, Inc.
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice, this
#   list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright notice,
#   this list of conditions and the following disclaimer in the documentation
#   and/or other materials provided with the distribution.
#
# * Neither the name of qonnx nor the names of its
#   contributors may be used to endorse or promote products derived from
#   this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.


import numpy as np

from qonnx.custom_op.general.floatquant import compute_max_val, float_quantize


def test_compute_max_val():
    # reference max normal values from OCP MX 1.0 standard
    assert compute_max_val(2, 3) == 7.5  # FP6 E2M3
    assert compute_max_val(3, 2) == 28.0  # FP6 E3M2
    assert compute_max_val(2, 1) == 6.0  # FP4 E2M1


def test_float_quantize():
    zero_tensor = np.zeros((2, 2))
    unit_scale = np.asarray([1.0], dtype=np.float32)
    assert np.all(float_quantize(zero_tensor, unit_scale, 2, 3) == zero_tensor)
    testcase_a = np.asarray([1.5], dtype=np.float32)
    testcase_b = np.asarray([3.25], dtype=np.float32)
    testcase_c = np.asarray([8.0], dtype=np.float32)
    testcase_d = np.asarray([28.2], dtype=np.float32)
    testcase_e = np.asarray([6.1], dtype=np.float32)
    testcase_f = np.asarray([0.124], dtype=np.float32)
    assert np.all(float_quantize(testcase_a, unit_scale, 2, 3) == testcase_a)
    assert np.all(float_quantize(testcase_b, unit_scale, 2, 3) == testcase_b)
    assert np.all(float_quantize(testcase_c, unit_scale, 2, 3) == compute_max_val(2, 3))
    assert np.all(float_quantize(testcase_d, unit_scale, 3, 2) == compute_max_val(3, 2))
    assert np.all(float_quantize(testcase_e, unit_scale, 2, 1) == compute_max_val(2, 1))
    assert np.all(float_quantize(testcase_f, unit_scale, 2, 3, lt_subnorm_to_zero=True) == 0.0)
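As a quick sanity check on the E2M3 reference value asserted above, the arithmetic behind compute_max_val can be worked out by hand (a sketch that mirrors the formula, not an additional API):

# FP6 E2M3: 2 exponent bits, 3 mantissa bits, default exponent bias
bias = 2.0 ** (2 - 1) - 1                         # 1.0
max_exponent = 2.0 ** 2 - 1.0 - bias              # 2.0
max_mantissa = sum(2.0 ** -i for i in range(4))   # 1 + 1/2 + 1/4 + 1/8 = 1.875
assert max_mantissa * 2.0 ** max_exponent == 7.5  # matches the OCP MX E2M3 max normal value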
@maltanar , this name is terrible! Please, help me in finding a better one 🤦♂️