diff --git a/src/brevitas/proxy/runtime_quant.py b/src/brevitas/proxy/runtime_quant.py index 5de901de3..da7db062c 100644 --- a/src/brevitas/proxy/runtime_quant.py +++ b/src/brevitas/proxy/runtime_quant.py @@ -102,6 +102,7 @@ def __init__(self, quant_layer, quant_injector): self.skip_create_quant_tensor = False quant_delay_steps = quant_injector.quant_delay_steps if 'quant_delay_steps' in quant_injector else None self.delay_wrapper = DelayWrapper(quant_delay_steps) + self.observer_only = False @property def input_view_impl(self): @@ -191,7 +192,8 @@ def forward(self, x: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTensor]: out = self.fused_activation_quant_proxy(y) quant_value, *quant_args = out quant_args = tuple(quant_args) - quant_value = self.dequantize(*((quant_value,) + quant_args)) + if not self.observer_only: + quant_value = self.dequantize(*((quant_value,) + quant_args)) quant_value = self.delay_wrapper(y, quant_value) out = (quant_value,) + quant_args # If y is an empty QuantTensor, we need to check if this is a passthrough proxy,