Feat (float): int / bitwidth attributes to proxy / quant tensor (#1072)
i-colbert authored Oct 28, 2024
1 parent 1dfed96 · commit a215fc4
Showing 5 changed files with 107 additions and 6 deletions.
7 changes: 7 additions & 0 deletions src/brevitas/proxy/float_parameter_quant.py
@@ -14,6 +14,13 @@
 
 class WeightFloatQuantProxyFromInjectorBase(WeightQuantProxyFromInjectorBase, ABC):
 
+    def bit_width(self):
+        if not self.is_quant_enabled:
+            return None
+        x = self.__call__(self.tracked_parameter_list[0])
+        bit_width = x.mantissa_bit_width + x.exponent_bit_width + 1
+        return bit_width
+
     def scale(self):
         if not self.is_quant_enabled:
             return None
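The bit width recovered here is the full minifloat format width: mantissa bits plus exponent bits plus one sign bit. A quick check of that arithmetic for the two FP8 formats exercised elsewhere in this change (plain Python):

    # E4M3: 3 mantissa + 4 exponent + 1 sign = 8 bits
    mantissa_bit_width, exponent_bit_width = 3, 4
    assert mantissa_bit_width + exponent_bit_width + 1 == 8
    # E5M2: 2 mantissa + 5 exponent + 1 sign = 8 bits
    mantissa_bit_width, exponent_bit_width = 2, 5
    assert mantissa_bit_width + exponent_bit_width + 1 == 8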
4 changes: 4 additions & 0 deletions src/brevitas/quant_tensor/float_quant_tensor.py
@@ -109,6 +109,10 @@ def _pre_round_float_value(self):
         minifloat_value = minifloat_value / int_scale
         return minifloat_value
 
+    def int(self):
+        fx_value = torch.round(self._pre_round_float_value)
+        return fx_value
+
     @property
     def is_valid(self):
         with torch.no_grad():
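The new int() mirrors the integer API that IntQuantTensor already exposes: it rounds the pre-round float value, i.e. the fixed-point representation of the minifloat tensor. A minimal usage sketch under that reading, built with the same quantizer the tests below use (shapes illustrative):

    import torch

    from brevitas.nn import QuantIdentity
    from brevitas.quant.experimental.float_quant_ocp import Fp8e5m2OCPActPerTensorFloat

    mod = QuantIdentity(act_quant=Fp8e5m2OCPActPerTensorFloat, return_quant_tensor=True)
    q = mod(torch.randn(4, 4))   # returns a FloatQuantTensor
    i = q.int()                  # rounded fixed-point values, one per element
    assert torch.equal(i, torch.round(q._pre_round_float_value))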
4 changes: 4 additions & 0 deletions src/brevitas/quant_tensor/groupwise_float_quant_tensor.py
@@ -150,6 +150,10 @@ def _pre_round_float_value(self):
         minifloat_value = minifloat_value / int_scale
         return minifloat_value
 
+    def int(self):
+        fx_value = torch.round(self._pre_round_float_value)
+        return fx_value
+
     @property
     def is_valid(self):
         with torch.no_grad():
4 changes: 2 additions & 2 deletions src/brevitas/utils/quant_utils.py
@@ -91,7 +91,7 @@ def __init__(self, quant_tensor: GroupwiseFloatQuantTensor, metadata_only: bool):
         self.shape = quant_tensor.value.shape
         if metadata_only:
             self.value = None
-            self.quant_tensor = quant_tensor.set(value=None)
+            self.quant_tensor = quant_tensor.set(value_=None)
         else:
             self.quant_tensor = quant_tensor
         # torch.compile compatibility
@@ -146,7 +146,7 @@ def __init__(self, quant_tensor: GroupwiseIntQuantTensor, metadata_only: bool):
         self.shape = quant_tensor.value.shape
         if metadata_only:
             self.value = None
-            self.quant_tensor = quant_tensor.set(value=None)
+            self.quant_tensor = quant_tensor.set(value_=None)
         else:
             self.quant_tensor = quant_tensor
         # torch.compile compatibility
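The two one-line fixes above change set(value=None) to set(value_=None). This matches how the groupwise quant tensors appear to store their payload: the raw tensor lives in a value_ field while value is a computed property that reconstructs the expanded tensor, so dropping the payload for a metadata-only cache has to address the underlying field name. A self-contained sketch of that pattern with a hypothetical stand-in class (not Brevitas code):

    from typing import NamedTuple, Optional

    import torch

    class TinyGroupwiseTensor(NamedTuple):
        value_: Optional[torch.Tensor]  # stored field
        scale: torch.Tensor

        @property
        def value(self):
            # the real class reconstructs the expanded tensor here
            return self.value_

        def set(self, **kwargs):
            return self._replace(**kwargs)

    t = TinyGroupwiseTensor(value_=torch.ones(2, 32), scale=torch.ones(2, 1))
    t = t.set(value_=None)  # ok: value_ is a real field
    # t.set(value=None)     # ValueError: value is a property, not a field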
94 changes: 90 additions & 4 deletions tests/brevitas/quant_tensor/test_quant_tensor.py
@@ -1,7 +1,9 @@
 # Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved.
 # SPDX-License-Identifier: BSD-3-Clause
 
 from enum import Enum
 
+import numpy as np
+from packaging import version
 import pytest
 import pytest_cases
@@ -13,7 +15,11 @@
 from brevitas.quant.experimental.float_quant_ocp import Fp8e5m2OCPActPerTensorFloat
 from brevitas.quant.experimental.mx_quant_ocp import MXFloat8e4m3Act
 from brevitas.quant_tensor import FloatQuantTensor
+from brevitas.quant_tensor import GroupwiseFloatQuantTensor
 from brevitas.quant_tensor import IntQuantTensor
+from brevitas.utils.quant_utils import _CachedIO
+from brevitas.utils.quant_utils import _CachedIOFloat
+from brevitas.utils.quant_utils import _CachedIOGroupwiseFloat
 
 
 class Operator(Enum):
@@ -24,14 +30,40 @@ class Operator(Enum):
     MATMUL = 4
 
 
-def to_quant_tensor(input: torch.Tensor) -> IntQuantTensor:
-    mod = QuantIdentity(bit_width=8, return_quant_tensor=True)
+def to_quant_tensor(input: torch.Tensor, bit_width=8) -> IntQuantTensor:
+    mod = QuantIdentity(bit_width=bit_width, return_quant_tensor=True)
     return mod(input)
 
 
-def to_float_quant_tensor(input: torch.Tensor) -> FloatQuantTensor:
+def to_float_quant_tensor(
+        input: torch.Tensor,
+        bit_width=8,
+        exponent_bit_width=4,
+        mantissa_bit_width=3) -> FloatQuantTensor:
     mod = QuantIdentity(
-        bit_width=8, return_quant_tensor=True, act_quant=Fp8e5m2OCPActPerTensorFloat)
+        bit_width=bit_width,
+        exponent_bit_width=exponent_bit_width,
+        mantissa_bit_width=mantissa_bit_width,
+        return_quant_tensor=True,
+        act_quant=Fp8e5m2OCPActPerTensorFloat)
     return mod(input)
 
 
+def to_mx_quant_tensor(
+        input: torch.Tensor,
+        bit_width=8,
+        exponent_bit_width=4,
+        mantissa_bit_width=3,
+        group_size=32,
+        group_dim=1) -> GroupwiseFloatQuantTensor:
+    mod = QuantIdentity(
+        bit_width=bit_width,
+        group_size=group_size,
+        group_dim=group_dim,
+        exponent_bit_width=exponent_bit_width,
+        mantissa_bit_width=mantissa_bit_width,
+        return_quant_tensor=True,
+        act_quant=MXFloat8e4m3Act)
+    return mod(input)
@@ -138,3 +170,57 @@ def test_minifloat(quant_class_key_vale):
     qx = q(x)
     # Check that minifloat doesn't raise error
     qx.minifloat()
+
+
+@pytest.mark.parametrize("metadata_only", [True, False])
+def test_int_quant_tensor(metadata_only, bit_width=8):
+    limit = np.exp2(bit_width) - 1
+    w = torch.randn(32, 1024)
+    q = to_quant_tensor(w, bit_width=bit_width)
+    i = q.int().float()
+    assert ((i.max() - i.min()) <= limit).all()
+    # test caching works
+    cache = _CachedIO(q, metadata_only=metadata_only)
+    assert cache.bit_width == bit_width
+
+
+@pytest.mark.parametrize("metadata_only", [True, False])
+def test_float_quant_tensor(metadata_only, bit_width=8, exponent_bit_width=4, mantissa_bit_width=3):
+    assert mantissa_bit_width + exponent_bit_width + 1 == bit_width
+    limit = (np.exp2(mantissa_bit_width + 1) - 1) * np.exp2(np.exp2(exponent_bit_width) - 2)
+    w = torch.randn(32, 1024)
+    q = to_float_quant_tensor(
+        w,
+        bit_width=bit_width,
+        exponent_bit_width=exponent_bit_width,
+        mantissa_bit_width=mantissa_bit_width)
+    # test that the integer API returns fixed point values in the right range
+    i = q.int().float()
+    assert ((i.max() - i.min()) <= limit).all()
+    # test caching works
+    cache = _CachedIOFloat(q, metadata_only=metadata_only)
+    assert cache.mantissa_bit_width == mantissa_bit_width
+    assert cache.exponent_bit_width == exponent_bit_width
+
+
+@pytest.mark.parametrize("metadata_only", [True, False])
+def test_mx_quant_tensor(metadata_only, bit_width=8, exponent_bit_width=4, mantissa_bit_width=3):
+    assert mantissa_bit_width + exponent_bit_width + 1 == bit_width
+    limit = (np.exp2(mantissa_bit_width + 1) - 1) * np.exp2(np.exp2(exponent_bit_width) - 2)
+    w = torch.randn(32, 1024)
+    q = to_mx_quant_tensor(
+        w,
+        bit_width=bit_width,
+        exponent_bit_width=exponent_bit_width,
+        mantissa_bit_width=mantissa_bit_width,
+        group_size=32,
+        group_dim=1)
+    # test that the integer API returns fixed point values in the right range
+    i = q.int().float()
+    assert ((i.max() - i.min()) <= limit).all()
+    # test caching works
+    cache = _CachedIOGroupwiseFloat(q, metadata_only=metadata_only)
+    assert cache.mantissa_bit_width == mantissa_bit_width
+    assert cache.exponent_bit_width == exponent_bit_width
+    assert cache.group_size == 32
+    assert cache.group_dim == 1
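For reference, the bound asserted in the float tests, (2**(m + 1) - 1) * 2**(2**e - 2), evaluates to 245760 for the E4M3 defaults above; the integer test's bound, np.exp2(8) - 1, is simply the 255-step span of an 8-bit grid. A quick check of the float-test arithmetic (a sanity check of the expression used here, not a statement about the OCP spec itself):

    import numpy as np

    mantissa_bit_width, exponent_bit_width = 3, 4
    # (2**(3 + 1) - 1) * 2**(2**4 - 2) = 15 * 2**14
    limit = (np.exp2(mantissa_bit_width + 1) - 1) * np.exp2(np.exp2(exponent_bit_width) - 2)
    assert limit == 15 * 2 ** 14 == 245760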
