class RuntimeDynamicGroupZeroPoint(brevitas.jit.ScriptModule):
    """Zero-point derived at call time from statistics of the incoming tensor.

    Each forward call pipes the input through ``input_view_impl``, reduces it
    with ``zero_point_stats_impl``, negates the result and delegates to a
    ``_ScaleShiftZeroPoint`` along with the supplied scale and bit width.

    Args:
        input_view_impl: module applied to the raw input before statistics
            are gathered (presumably a reshape into groups, going by the
            class name -- TODO confirm against callers).
        zero_point_stats_impl: module that reduces the reshaped input to the
            statistic the zero-point is built from.
        int_quant: passed through unchanged to ``_ScaleShiftZeroPoint``.
        quantize_zero_point: passed through unchanged to
            ``_ScaleShiftZeroPoint``.
    """

    def __init__(
            self,
            input_view_impl: Module,
            zero_point_stats_impl: Module,
            int_quant: Module,
            quantize_zero_point: bool) -> None:
        super(RuntimeDynamicGroupZeroPoint, self).__init__()

        # Sub-modules are registered in this order on purpose: stats first,
        # then the view, then the scale/shift stage.
        self.zero_point_stats_impl = zero_point_stats_impl
        self.input_view_impl = input_view_impl
        self.scale_shift_zero_point = _ScaleShiftZeroPoint(
            int_quant, quantize_zero_point)

    @brevitas.jit.script_method
    def forward(self, stats_input: torch.Tensor, scale, bit_width) -> torch.Tensor:
        """Compute the zero-point for ``stats_input``.

        Args:
            stats_input: tensor the statistics are collected from.
            scale: scale forwarded to the scale/shift stage.
            bit_width: bit width forwarded to the scale/shift stage.

        Returns:
            Whatever ``scale_shift_zero_point`` produces for the negated
            statistic, ``scale`` and ``bit_width``.
        """
        group_stats = self.zero_point_stats_impl(self.input_view_impl(stats_input))
        # The statistic is negated before it reaches the scale/shift stage.
        return self.scale_shift_zero_point(-group_stats, scale, bit_width)