NOTE(review): this patch was recovered from a copy whose line breaks and
indentation had been collapsed onto one line. Statement nesting below is
reconstructed from the Python syntax and the hunk line counts; confirm the
indentation of the prepare_for_export context lines against handler.py before
applying. The post-image blob id on the index line predates the added
`else` fallback and its comments.

diff --git a/src/brevitas/export/inference/handler.py b/src/brevitas/export/inference/handler.py
index 7e3347a21..01fb8a63b 100644
--- a/src/brevitas/export/inference/handler.py
+++ b/src/brevitas/export/inference/handler.py
@@ -62,9 +62,20 @@ def prepare_for_export(self, module: nn.Module):
         self.bit_width = module.bit_width()
         self.min_clamp = min_int(module.is_signed, module.is_narrow_range, self.bit_width)
         self.max_clamp = max_int(module.is_signed, module.is_narrow_range, self.bit_width)
+        # Take the float-to-int conversion from the module's own quantizer so
+        # quantize() uses it in place of a hard-coded torch.round.
+        if hasattr(module.tensor_quant, 'int_quant'):
+            self.float_to_int_impl = module.tensor_quant.int_quant.float_to_int_impl
+        elif hasattr(module, 'fused_activation_quant_proxy'):
+            self.float_to_int_impl = module.fused_activation_quant_proxy.tensor_quant.int_quant.float_to_int_impl
+        else:
+            # Neither layout matched: keep the rounding used before this change
+            # instead of leaving the attribute unset for quantize().
+            self.float_to_int_impl = torch.round
 
     def quantize(self, x: Tensor, scale: Tensor, zero_point: Tensor) -> Tuple[Tensor]:
-        return torch.clamp(torch.round(x / scale + zero_point), self.min_clamp, self.max_clamp)
+        return torch.clamp(
+            self.float_to_int_impl(x / scale + zero_point), self.min_clamp, self.max_clamp)
 
     def dequantize(self, x: Tensor, scale: Tensor, zero_point: Tensor) -> Tensor:
         return (x - zero_point) * scale