Brevitas 0.2.0-alpha
Add config file based on env variables.
Add env variable to ignore expected missing keys when retraining.
Add env variable to force reinit weight quantization when loading a pretrained model.
Add scaling override for custom scaling implementation.
Add const scaling based on He init.
Add affine stats-based scaling, i.e. stats with scale and bias parameters.
Extend parameter-from-stats scaling to properly handle per-channel init.
Extend stats-based scaling support to activations.
Add optional hard minimum for scale factors, to avoid NaN during training.
Add bias quantization with user defined bit width.
Add binary activations support.
Add attributes to retrieve int_weight and scale factors to quantized layers.
Fix scaling in quantized avg pool.
Start documenting WeightQuantProxy and extending its docs to QuantConv2d and QuantLinear.
Various minor refactoring.
volcacius committed Sep 24, 2019
1 parent d994199 commit 25edbd7
Showing 115 changed files with 33,162 additions and 414 deletions.
8 changes: 5 additions & 3 deletions README.md
@@ -1,10 +1,8 @@
[![DOI](https://zenodo.org/badge/140494324.svg)](https://zenodo.org/badge/latestdoi/140494324)

# Brevitas

Brevitas is a Pytorch library for training-aware quantization.

*Brevitas is currently in alpha stage and under active development. APIs might and probably will change. Documentation, examples, and pretrained models will be progressively released.*
*Brevitas is currently under active development and to be considered in alpha stage. APIs might and probably will change. Documentation, examples, and pretrained models will be progressively released.*

## Requirements
* [Pytorch](https://pytorch.org) >= 1.1.0
@@ -24,6 +22,10 @@ Brevitas is mainly targeted at researchers and practitioners in the fields of tr

The implementation is quite rich in options and allows for very fine grained control over the trained model. However, compared to other software solutions in this space, the burden of correctly modelling the target data-path is currently placed on the user.

## Docs

Soon.

## Features

Soon.
7 changes: 7 additions & 0 deletions brevitas/config.py
@@ -0,0 +1,7 @@
import os
import docrep

docstrings = docrep.DocstringProcessor()

IGNORE_MISSING_KEYS = bool(os.environ.get('BREVITAS_IGNORE_MISSING_KEYS', False))
REINIT_WEIGHT_QUANT_ON_LOAD = bool(os.environ.get('BREVITAS_REINIT_WEIGHT_QUANT_ON_LOAD', True))
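
A minimal sketch of how these new flags might be toggled from a training script, assuming Brevitas and its docrep dependency are installed and that the flags are read once at import time as in the module above. Note that, as written, bool() is applied to the raw environment string, so any non-empty value enables a flag, and BREVITAS_REINIT_WEIGHT_QUANT_ON_LOAD can only be disabled by setting it to an empty string.

import os

# Set these before brevitas is imported anywhere, since config.py reads them at import time.
# bool() is applied to the raw string, so any non-empty value counts as True.
os.environ['BREVITAS_IGNORE_MISSING_KEYS'] = '1'         # tolerate the expected missing keys on load
os.environ['BREVITAS_REINIT_WEIGHT_QUANT_ON_LOAD'] = ''  # empty string is the only way to get False here

import brevitas.config as config

assert config.IGNORE_MISSING_KEYS is True
assert config.REINIT_WEIGHT_QUANT_ON_LOAD is False
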
25 changes: 15 additions & 10 deletions brevitas/core/bit_width.py
@@ -45,6 +45,7 @@
from torch import Tensor
from torch.nn import Parameter

import brevitas.config as config
from brevitas.utils.python_utils import AutoName
from brevitas.function.ops import tensor_clamp_ste, tensor_clamp
from .restrict_val import RestrictValueOpImplType, RestrictValueType, RestrictValue, FloatToIntImplType
@@ -107,10 +108,11 @@ def __init__(self,
bit_width_init_op = RestrictValue.restrict_value_op(restrict_bit_width_type,
restrict_value_op_impl_type=RestrictValueOpImplType.MATH)
self.restrict_bit_width = RestrictValue(restrict_bit_width_type,
float_to_int_impl_type=FloatToIntImplType.ROUND)
float_to_int_impl_type=FloatToIntImplType.ROUND,
min_val=None)
self.bit_width_base = bit_width_init_op(min_overall_bit_width)
self.max_bit_width = bit_width_init_op(min_overall_bit_width) if max_overall_bit_width is not None else None
bit_width_offset_init = bit_width_init_op(bit_width_init) - self.bit_width_base
bit_width_offset_init = max(bit_width_init_op(bit_width_init) - self.bit_width_base, 0.0)
self.bit_width_offset = Parameter(torch.tensor(float(bit_width_offset_init)))

@torch.jit.script_method
@@ -126,7 +128,7 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
super(BitWidthParameter, self)._load_from_state_dict(state_dict, prefix, local_metadata, strict,
missing_keys, unexpected_keys, error_msgs)
bit_width_offset_key = prefix + 'bit_width_offset'
if bit_width_offset_key in missing_keys:
if config.IGNORE_MISSING_KEYS and bit_width_offset_key in missing_keys:
missing_keys.remove(bit_width_offset_key)
if self.override_pretrained and bit_width_offset_key in state_dict:
del state_dict[bit_width_offset_key]
@@ -162,7 +164,7 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
super(RemoveBitwidthParameter, self)._load_from_state_dict(state_dict, prefix, local_metadata, strict,
missing_keys, unexpected_keys, error_msgs)
bit_width_coeff_key = prefix + 'bit_width_coeff'
if bit_width_coeff_key in missing_keys:
if config.IGNORE_MISSING_KEYS and bit_width_coeff_key in missing_keys:
missing_keys.remove(bit_width_coeff_key)
if self.override_pretrained and bit_width_coeff_key in state_dict:
del state_dict[bit_width_coeff_key]
@@ -187,7 +189,8 @@ def __init__(self,
self.bit_width_to_remove_impl = BitWidthConst(ms_bit_width_to_clamp, RestrictValueType.INT)
elif bit_width_impl_type == BitWidthImplType.PARAMETER:
restrict_bit_width_impl = RestrictValue(RestrictValueType.INT,
float_to_int_impl_type=FloatToIntImplType.ROUND)
float_to_int_impl_type=FloatToIntImplType.ROUND,
min_val=None)
self.bit_width_to_remove_impl = RemoveBitwidthParameter(bit_width_to_remove=ms_bit_width_to_clamp,
remove_at_least_init_val=clamp_at_least_init_val,
restrict_bit_width_impl=restrict_bit_width_impl,
@@ -202,7 +205,7 @@ def forward(self, input_bit_width: Tensor, zero_hw_sentinel: Tensor) -> Tensor:
output_bit_width = torch.abs(input_bit_width - bit_width_to_remove)
output_bit_width = tensor_clamp_ste(output_bit_width,
self.min_overall_bit_width + zero_hw_sentinel,
self.max_overall_bit_width + zero_hw_sentinel)
self.max_overall_bit_width + zero_hw_sentinel) #todo STE on max only
return output_bit_width


@@ -225,7 +228,8 @@ def __init__(self,
self.bit_width_to_remove_impl = BitWidthConst(ls_bit_width_to_trunc, RestrictValueType.INT)
elif bit_width_impl_type == BitWidthImplType.PARAMETER:
restrict_bit_width_impl = RestrictValue(RestrictValueType.INT,
float_to_int_impl_type=FloatToIntImplType.ROUND)
float_to_int_impl_type=FloatToIntImplType.ROUND,
min_val=None)
self.bit_width_to_remove_impl = RemoveBitwidthParameter(bit_width_to_remove=ls_bit_width_to_trunc,
remove_at_least_init_val=trunc_at_least_init_val,
restrict_bit_width_impl=restrict_bit_width_impl,
@@ -237,8 +241,9 @@ def __init__(self,
@torch.jit.script_method
def forward(self, input_bit_width: Tensor, zero_hw_sentinel: Tensor) -> Tensor:
bit_width_to_remove = self.bit_width_to_remove_impl(zero_hw_sentinel)
min_bit_width_to_remove = input_bit_width - self.max_overall_bit_width
max_bit_width_to_remove = input_bit_width - self.min_overall_bit_width
bit_width_to_remove = torch.where(bit_width_to_remove > max_bit_width_to_remove,
max_bit_width_to_remove,
bit_width_to_remove)
bit_width_to_remove = tensor_clamp(bit_width_to_remove, # pass gradient to boundaries
min_bit_width_to_remove, # since input_bit_width is possibly learned
max_bit_width_to_remove)
return bit_width_to_remove
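
The comment above ("pass gradient to boundaries") contrasts the two clamping behaviours used throughout this commit. A generic sketch of both, for orientation only; Brevitas's own tensor_clamp and tensor_clamp_ste live in brevitas.function.ops and may differ in detail:

import torch

def tensor_clamp(x, min_val, max_val):
    # Plain differentiable clamp built from min/max: wherever the clamp is
    # active, the gradient flows to the selected boundary tensor (here the
    # boundaries depend on a possibly learned input_bit_width), not to x.
    return torch.min(torch.max(x, min_val), max_val)

class _TensorClampSte(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, min_val, max_val):
        return torch.min(torch.max(x, min_val), max_val)

    @staticmethod
    def backward(ctx, grad_output):
        # Straight-through estimator: treat the clamp as the identity, so the
        # whole gradient goes back to x and none reaches the boundaries.
        return grad_output, None, None

def tensor_clamp_ste(x, min_val, max_val):
    return _TensorClampSte.apply(x, min_val, max_val)
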
96 changes: 86 additions & 10 deletions brevitas/core/function_wrapper.py
@@ -38,20 +38,11 @@
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.

from enum import auto

import torch

from brevitas.utils.python_utils import AutoName
from brevitas.function.ops import ceil_ste, round_ste, floor_ste
from brevitas.function.shape import *
from brevitas.function import tensor_clamp, tensor_clamp_ste


class TensorClampImplType(AutoName):
STE = auto()
DIFFERENTIABLE = auto()


class Identity(torch.jit.ScriptModule):
def __init__(self) -> None:
super(Identity, self).__init__()
@@ -122,3 +113,88 @@ def __init__(self) -> None:
@torch.jit.script_method
def forward(self, x: torch.Tensor, min_val: torch.Tensor, max_val: torch.Tensor):
return tensor_clamp(x, min_val=min_val, max_val=max_val)


class ConstScalarClamp(torch.jit.ScriptModule):
__constants__ = ['min_val', 'max_val']

def __init__(self, min_val, max_val) -> None:
super(ConstScalarClamp, self).__init__()
self.min_val = min_val
self.max_val = max_val

@torch.jit.script_method
def forward(self, x: torch.Tensor):
return torch.clamp(x, min=self.min_val, max=self.max_val)


class ClampMin(torch.jit.ScriptModule):
__constants__ = ['min_val']

def __init__(self, min_val: float) -> None:
super(ClampMin, self).__init__()
self.min_val = min_val

@torch.jit.script_method
def forward(self, x: torch.Tensor):
return x.clamp_min(self.min_val)


class OverTensorView(torch.jit.ScriptModule):

def __init__(self) -> None:
super(OverTensorView, self).__init__()

@torch.jit.script_method
def shape(self, x: torch.Tensor):
return over_tensor(x)

@torch.jit.script_method
def forward(self, x: torch.Tensor):
shape = self.shape(x)
return x.view(shape)


class OverOutputChannelView(torch.jit.ScriptModule):

def __init__(self) -> None:
super(OverOutputChannelView, self).__init__()

@torch.jit.script_method
def shape(self, x: torch.Tensor):
return over_output_channels(x)

@torch.jit.script_method
def forward(self, x: torch.Tensor):
shape = self.shape(x)
return x.view(shape)


class OverBatchOverTensorView(torch.jit.ScriptModule):

def __init__(self) -> None:
super(OverBatchOverTensorView, self).__init__()

@torch.jit.script_method
def shape(self, x: torch.Tensor):
return over_batch_over_tensor(x)

@torch.jit.script_method
def forward(self, x: torch.Tensor):
shape = self.shape(x)
return x.view(shape)


class OverBatchOverOutputChannelView(torch.jit.ScriptModule):

def __init__(self) -> None:
super(OverBatchOverOutputChannelView, self).__init__()

@torch.jit.script_method
def shape(self, x: torch.Tensor):
return over_batch_over_output_channels(x)

@torch.jit.script_method
def forward(self, x: torch.Tensor):
shape = self.shape(x)
return x.view(shape)
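
The view wrappers above only decide which dimensions later statistics (and hence scale factors) are computed over. A rough illustration of the intended groupings, assuming the shape helpers in brevitas.function.shape (not shown in this diff) behave as their names suggest:

import torch

# Hypothetical conv weight tensor: (out_channels, in_channels, kH, kW).
w = torch.randn(64, 3, 3, 3)

# Per-tensor grouping: one flat view, so downstream stats produce a single scale.
per_tensor = w.view(-1)               # shape: (1728,)

# Per-output-channel grouping: one row per output channel, so stats produce one
# scale per channel (what the per-channel init mentioned in this commit targets).
per_channel = w.view(w.shape[0], -1)  # shape: (64, 27)

# For activations, the same idea applied after keeping the batch dimension.
x = torch.randn(8, 64, 16, 16)
per_batch = x.view(x.shape[0], -1)    # shape: (8, 16384)
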
60 changes: 53 additions & 7 deletions brevitas/core/quant.py
@@ -46,7 +46,7 @@
from torch.nn import Module

from brevitas.utils.python_utils import AutoName
from brevitas.function.ops import min_int, max_int, max_uint
from brevitas.function.ops import min_int, max_int, max_uint, tensor_clamp, tensor_clamp_ste
from brevitas.function import binary_sign_ste, ternary_sign_ste


@@ -69,7 +69,7 @@ def forward(self, x: Tensor, zero_hw_sentinel: Tensor) -> Tuple[Tensor, Tensor,


class BinaryQuant(torch.jit.ScriptModule):
__constants__ = ['threshold', 'bit_width']
__constants__ = ['bit_width']

def __init__(self, scaling_impl: Module):
super(BinaryQuant, self).__init__()
@@ -83,6 +83,21 @@ def forward(self, x: Tensor, zero_hw_sentinel: Tensor) -> Tuple[Tensor, Tensor,
return y, scale, zero_hw_sentinel + self.bit_width


class ClampedBinaryQuant(torch.jit.ScriptModule):
__constants__ = ['bit_width']

def __init__(self, scaling_impl: Module):
super(ClampedBinaryQuant, self).__init__()
self.scaling_impl = scaling_impl
self.bit_width = 1

def forward(self, x: Tensor, zero_hw_sentinel: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
scale = self.scaling_impl(zero_hw_sentinel)
y = tensor_clamp(x, - scale, scale)
y = binary_sign_ste(y) * scale
return y, scale, zero_hw_sentinel + self.bit_width


class TernaryQuant(torch.jit.ScriptModule):
__constants__ = ['threshold', 'bit_width']

@@ -101,15 +116,15 @@ def forward(self, x: Tensor, zero_hw_sentinel: Tensor) -> Tuple[Tensor, Tensor,
return y, scale, zero_hw_sentinel + self.bit_width


class PrescaledRestrictIntQuant(torch.jit.ScriptModule):
class PrescaledRestrictIntQuantWithInputBitWidth(torch.jit.ScriptModule):

def __init__(self,
narrow_range: bool,
signed: bool,
tensor_clamp_impl: Module,
msb_clamp_bit_width_impl: Module,
float_to_int_impl: Module):
super(PrescaledRestrictIntQuant, self).__init__()
super(PrescaledRestrictIntQuantWithInputBitWidth, self).__init__()
self.int_quant = IntQuant(signed=signed,
narrow_range=narrow_range,
tensor_clamp_impl=tensor_clamp_impl,
@@ -127,6 +142,31 @@ def forward(self,
return y, scale, msb_clamp_bit_width


class PrescaledRestrictIntQuant(torch.jit.ScriptModule):

def __init__(self,
narrow_range: bool,
signed: bool,
tensor_clamp_impl: Module,
msb_clamp_bit_width_impl: Module,
float_to_int_impl: Module):
super(PrescaledRestrictIntQuant, self).__init__()
self.int_quant = IntQuant(signed=signed,
narrow_range=narrow_range,
tensor_clamp_impl=tensor_clamp_impl,
float_to_int_impl=float_to_int_impl)
self.msb_clamp_bit_width_impl = msb_clamp_bit_width_impl

@torch.jit.script_method
def forward(self,
x: Tensor,
scale: Tensor,
zero_hw_sentinel: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
msb_clamp_bit_width = self.msb_clamp_bit_width_impl(zero_hw_sentinel)
y = self.int_quant(scale, zero_hw_sentinel + 1, msb_clamp_bit_width, x)
return y, scale, msb_clamp_bit_width


class IdentityPrescaledIntQuant(torch.jit.ScriptModule):

@torch.jit.script_method
@@ -158,9 +198,11 @@ def forward(self,


class RescalingIntQuant(torch.jit.ScriptModule):
__constants__ = ['runtime']

def __init__(self,
narrow_range: bool,
runtime: bool,
signed: bool,
scaling_impl: Module,
int_scaling_impl: Module,
@@ -172,21 +214,25 @@ def __init__(self,
narrow_range=narrow_range,
tensor_clamp_impl=tensor_clamp_impl,
float_to_int_impl=float_to_int_impl)
self.runtime = runtime
self.scaling_impl = scaling_impl
self.int_scaling_impl = int_scaling_impl
self.msb_clamp_bit_width_impl = msb_clamp_bit_width_impl

@staticmethod
def scaling_init_from_min_max(min_val_init: Union[int, float], max_val_init: Union[int, float]) -> float:
def scaling_init_from_min_max(min_val_init: Union[int, float], max_val_init: Union[int, float]) -> torch.Tensor:
scaling_init = max(abs(float(min_val_init)), abs(float(max_val_init)))
return scaling_init
return torch.tensor(scaling_init)

@torch.jit.script_method
def forward(self,
x: Tensor,
zero_hw_sentinel: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
msb_clamp_bit_width = self.msb_clamp_bit_width_impl(zero_hw_sentinel)
scale = self.scaling_impl(zero_hw_sentinel)
if self.runtime:
scale = self.scaling_impl(x)
else:
scale = self.scaling_impl(zero_hw_sentinel)
int_scale = self.int_scaling_impl(msb_clamp_bit_width)
y = self.int_quant(scale, int_scale, msb_clamp_bit_width, x)
output_bit_width = msb_clamp_bit_width
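
For orientation, a generic sketch of the scaled binary quantization that BinaryQuant and ClampedBinaryQuant implement above: a straight-through sign scaled to {-scale, +scale}, with the clamped variant additionally restricting x to [-scale, scale] before the sign. This is a hand-rolled stand-in, not the library's binary_sign_ste, whose treatment of zero and of the backward pass may differ.

import torch

class _BinarySignSte(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        # sign(x) with zero mapped to +1, so the output is always in {-1, +1}.
        return torch.where(x >= 0, torch.ones_like(x), -torch.ones_like(x))

    @staticmethod
    def backward(ctx, grad_output):
        # Straight-through: pass the gradient back unchanged.
        return grad_output

def binary_quant(x, scale, clamp_input=False):
    if clamp_input:
        # ClampedBinaryQuant-style: restrict the input to [-scale, scale] first,
        # so out-of-range inputs route their gradient to the scale instead of x.
        x = torch.min(torch.max(x, -scale), scale)
    return _BinarySignSte.apply(x) * scale

x = torch.randn(4, requires_grad=True)
scale = torch.tensor(0.1, requires_grad=True)
y = binary_quant(x, scale, clamp_input=True)  # values in {-0.1, +0.1}
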