From cd6702afef6acb98837cebb08cf17c9e125c9d4e Mon Sep 17 00:00:00 2001 From: Eugene Liu Date: Thu, 4 Apr 2024 01:22:38 +0100 Subject: [PATCH] Fix tile PTQ transform (#3261) fix tile PTQ transfomr --- src/otx/core/model/base.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/src/otx/core/model/base.py b/src/otx/core/model/base.py index f23b1ad268a..84f8cd15de8 100644 --- a/src/otx/core/model/base.py +++ b/src/otx/core/model/base.py @@ -19,6 +19,7 @@ from jsonargparse import ArgumentParser from lightning import LightningModule from openvino.model_api.models import Model +from openvino.model_api.tilers import Tiler from torch import Tensor, nn from torch.optim.lr_scheduler import ConstantLR from torch.optim.sgd import SGD @@ -816,9 +817,11 @@ def transform_fn(self, data_batch: T_OTXBatchDataEntity) -> np.array: """Data transform function for PTQ.""" np_data = self._customize_inputs(data_batch) image = np_data["inputs"][0] - resized_image = self.model.resize(image, (self.model.w, self.model.h)) - resized_image = self.model.input_transform(resized_image) - return self.model._change_layout(resized_image) # noqa: SLF001 + # NOTE: Tiler wraps the model, so we need to unwrap it to get the model + model = self.model.model if isinstance(self.model, Tiler) else self.model + resized_image = model.resize(image, (model.w, model.h)) + resized_image = model.input_transform(resized_image) + return model._change_layout(resized_image) # noqa: SLF001 def _read_ptq_config_from_ir(self, ov_model: Model) -> dict[str, Any]: """Generates the PTQ (Post-Training Quantization) configuration from the meta data of the given OpenVINO model.