diff --git a/src/brevitas/nn/quant_bn.py b/src/brevitas/nn/quant_bn.py index a8047a690..765e4918d 100644 --- a/src/brevitas/nn/quant_bn.py +++ b/src/brevitas/nn/quant_bn.py @@ -16,6 +16,29 @@ class _BatchNormToQuantScaleBias(QuantScaleBias, ABC): + def __init__( + self, + num_features: int, + bias: bool = True, + weight_quant: Optional[WeightQuantType] = Int8WeightPerTensorFloat, + bias_quant: Optional[BiasQuantType] = None, + input_quant: Optional[ActQuantType] = None, + output_quant: Optional[ActQuantType] = None, + return_quant_tensor: bool = False, + runtime_shape=(1, -1, 1), + **kwargs): + QuantScaleBias.__init__( + self, + num_features=num_features, + weight_quant=weight_quant, + bias_quant=bias_quant, + input_quant=input_quant, + output_quant=output_quant, + return_quant_tensor=return_quant_tensor, + runtime_shape=runtime_shape, + **kwargs + ) + def _load_from_state_dict( self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs): @@ -59,11 +82,12 @@ def __init__( input_quant: Optional[ActQuantType] = None, output_quant: Optional[ActQuantType] = None, return_quant_tensor: bool = False, + runtime_shape=(1, -1, 1), **kwargs): super(BatchNorm1dToQuantScaleBias, self).__init__( num_features, bias=True, - runtime_shape=(1, -1, 1), + runtime_shape=runtime_shape, weight_quant=weight_quant, bias_quant=bias_quant, input_quant=input_quant, @@ -84,11 +108,12 @@ def __init__( input_quant: Optional[ActQuantType] = None, output_quant: Optional[ActQuantType] = None, return_quant_tensor: bool = False, + runtime_shape=(1, -1, 1, 1), **kwargs): super(BatchNorm2dToQuantScaleBias, self).__init__( num_features, bias=True, - runtime_shape=(1, -1, 1, 1), + runtime_shape=runtime_shape, weight_quant=weight_quant, bias_quant=bias_quant, input_quant=input_quant, diff --git a/src/brevitas/nn/quant_scale_bias.py b/src/brevitas/nn/quant_scale_bias.py index a97f54ed5..7a6427da1 100644 --- a/src/brevitas/nn/quant_scale_bias.py +++ b/src/brevitas/nn/quant_scale_bias.py @@ -48,8 +48,9 @@ def __init__( input_quant: Optional[ActQuantType] = None, output_quant: Optional[ActQuantType] = None, return_quant_tensor: bool = False, + runtime_shape=(1, -1, 1), **kwargs) -> None: - ScaleBias.__init__(self, num_features, bias) + ScaleBias.__init__(self, num_features, bias, runtime_shape=runtime_shape) QuantWBIOL.__init__( self, weight_quant=weight_quant,