diff --git a/src/otx/algo/visual_prompting/zero_shot_segment_anything.py b/src/otx/algo/visual_prompting/zero_shot_segment_anything.py
index 7172af130da..c3965c172ef 100644
--- a/src/otx/algo/visual_prompting/zero_shot_segment_anything.py
+++ b/src/otx/algo/visual_prompting/zero_shot_segment_anything.py
@@ -6,7 +6,7 @@ from __future__ import annotations
 
 import logging as log
-import os
+import pickle
 from collections import defaultdict
 from copy import deepcopy
 from itertools import product
@@ -625,7 +625,7 @@ def _decide_cascade_results(
 class OTXZeroShotSegmentAnything(OTXZeroShotVisualPromptingModel):
     """Zero-Shot Visual Prompting model."""
 
-    def __init__(
+    def __init__(  # noqa: PLR0913
         self,
         backbone: Literal["tiny_vit", "vit_b"],
         label_info: LabelInfoTypes = NullLabelInfo(),
@@ -633,7 +633,8 @@ def __init__(
         scheduler: LRSchedulerCallable | LRSchedulerListCallable = DefaultSchedulerCallable,
         metric: MetricCallable = VisualPromptingMetricCallable,
         torch_compile: bool = False,
-        root_reference_info: Path | str = "vpm_zsl_reference_infos",
+        reference_info_dir: Path | str = "reference_infos",
+        infer_reference_info_root: Path | str = "../.latest/train",
         save_outputs: bool = True,
         pixel_mean: list[float] | None = [123.675, 116.28, 103.53],  # noqa: B006
         pixel_std: list[float] | None = [58.395, 57.12, 57.375],  # noqa: B006
@@ -669,7 +670,8 @@ def __init__(
         )
 
         self.save_outputs = save_outputs
-        self.root_reference_info: Path = Path(root_reference_info)
+        self.reference_info_dir: Path = Path(reference_info_dir)
+        self.infer_reference_info_root: Path = Path(infer_reference_info_root)
 
         self.register_buffer("pixel_mean", Tensor(pixel_mean).view(-1, 1, 1), False)
         self.register_buffer("pixel_std", Tensor(pixel_std).view(-1, 1, 1), False)
@@ -877,21 +879,37 @@ def initialize_reference_info(self) -> None:
         self.register_buffer("reference_feats", torch.zeros(0, 1, self.model.embed_dim), False)
         self.register_buffer("used_indices", torch.tensor([], dtype=torch.int64), False)
 
-    def _find_latest_reference_info(self, root: Path) -> str | None:
-        """Find latest reference info to be used."""
-        if not Path.is_dir(root):
-            return None
-        if len(stamps := sorted(os.listdir(root), reverse=True)) > 0:
-            return stamps[0]
-        return None
+    def save_reference_info(self, default_root_dir: Path | str) -> None:
+        """Save reference info."""
+        reference_info = {
+            "reference_feats": self.reference_feats,
+            "used_indices": self.used_indices,
+        }
+        # save reference info
+        path_reference_info: Path = Path(default_root_dir) / self.reference_info_dir / "reference_info.pt"
+        path_reference_info.parent.mkdir(parents=True, exist_ok=True)
+        # TODO (sungchul): ticket no. 139210
+        torch.save(reference_info, path_reference_info)
+        pickle.dump(
+            {k: v.numpy() for k, v in reference_info.items()},
+            path_reference_info.with_suffix(".pickle").open("wb"),
+        )
+        log.info(f"Saved reference info at {path_reference_info}.")
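`save_reference_info` writes the same reference data twice: a `.pt` file of torch tensors and a `.pickle` sibling of numpy arrays, which the OpenVINO inference path can read without a torch dependency. A minimal round-trip sketch, assuming a hypothetical run directory layout (`outputs/reference_infos/`):

```python
import pickle
from pathlib import Path

import numpy as np
import torch

# hypothetical location: <default_root_dir>/reference_infos/reference_info.pt
path_pt = Path("outputs/reference_infos/reference_info.pt")

# torch artifact: tensors, consumed by the torch zero-shot model
info_pt = torch.load(path_pt)

# pickle artifact: plain numpy arrays, consumed by the OV zero-shot model
with path_pt.with_suffix(".pickle").open("rb") as f:
    info_np = pickle.load(f)  # noqa: S301 - trusted, locally produced file

# the two formats should agree element-wise
assert np.allclose(info_pt["reference_feats"].numpy(), info_np["reference_feats"])
```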
 
-    def load_latest_reference_info(self, device: str | torch.device = "cpu") -> bool:
+    def load_reference_info(self, default_root_dir: Path | str, device: str | torch.device = "cpu") -> bool:
         """Load latest reference info to be used."""
-        if (latest_stamp := self._find_latest_reference_info(self.root_reference_info)) is not None:
-            latest_reference_info = self.root_reference_info / latest_stamp / "reference_info.pt"
-            reference_info = torch.load(latest_reference_info)
+        _infer_reference_info_root: Path = (
+            self.infer_reference_info_root
+            if self.infer_reference_info_root == self.infer_reference_info_root.absolute()
+            else Path(default_root_dir) / self.infer_reference_info_root
+        )
+
+        if (
+            path_reference_info := _infer_reference_info_root / self.reference_info_dir / "reference_info.pt"
+        ).is_file():
+            reference_info = torch.load(path_reference_info)
             retval = True
-            log.info(f"reference info saved at {latest_reference_info} was successfully loaded.")
+            log.info(f"reference info saved at {path_reference_info} was successfully loaded.")
         else:
             reference_info = {}
             retval = False
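`load_reference_info` treats `infer_reference_info_root` as either an absolute path (used verbatim) or a path relative to the trainer's `default_root_dir`. A small sketch of that resolution rule, with hypothetical directories:

```python
from pathlib import Path

def resolve_reference_root(infer_reference_info_root: Path, default_root_dir: Path) -> Path:
    # absolute paths win; relative ones are anchored at default_root_dir
    if infer_reference_info_root == infer_reference_info_root.absolute():
        return infer_reference_info_root
    return default_root_dir / infer_reference_info_root

# the default "../.latest/train" points at the sibling `.latest/train` run directory
root = resolve_reference_root(Path("../.latest/train"), Path("/work/outputs/20240101_000000"))
print(root)  # /work/outputs/20240101_000000/../.latest/train
```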
diff --git a/src/otx/core/model/visual_prompting.py b/src/otx/core/model/visual_prompting.py
index 7ba4687b612..465aaf47eab 100644
--- a/src/otx/core/model/visual_prompting.py
+++ b/src/otx/core/model/visual_prompting.py
@@ -1,17 +1,12 @@
 # Copyright (C) 2023 Intel Corporation
 # SPDX-License-Identifier: Apache-2.0
 #
-"""Class definition for visual prompting model entity used in OTX."""
-
-# TODO(vinnamki): There are so many mypy errors. Resolve them after refactoring visual prompting code.
-# mypy: ignore-errors
+"""Class definition for visual prompting model entities used in OTX."""
 
 from __future__ import annotations
 
 import logging as log
-import os
 import pickle
-import time
 from collections import defaultdict
 from copy import deepcopy
 from functools import partial
@@ -25,7 +20,7 @@
 from torch import Tensor
 from torchvision import tv_tensors
 
-from otx.core.data.entity.base import OTXBatchLossEntity, Points
+from otx.core.data.entity.base import Points
 from otx.core.data.entity.visual_prompting import (
     VisualPromptingBatchDataEntity,
     VisualPromptingBatchPredEntity,
@@ -276,7 +271,7 @@ def _set_label_info(self, _: LabelInfoTypes) -> None:
 
 class OTXZeroShotVisualPromptingModel(
     OTXModel[ZeroShotVisualPromptingBatchDataEntity, ZeroShotVisualPromptingBatchPredEntity],
 ):
-    """Base class for the visual prompting models used in OTX."""
+    """Base class for the zero-shot visual prompting models used in OTX."""
 
     def __init__(
         self,
@@ -318,7 +313,7 @@ def _export_parameters(self) -> TaskLevelExportParameters:
 
     @property
     def _optimization_config(self) -> dict[str, Any]:
-        """PTQ config for visual prompting models."""
+        """PTQ config for zero-shot visual prompting models."""
         return {
             "model_type": "transformer",
             "advanced_parameters": {
@@ -344,8 +339,8 @@ def on_train_start(self) -> None:
     def on_test_start(self) -> None:
         """Load previously saved reference info."""
         super().on_test_start()
-        if not self.load_latest_reference_info(self.device):
-            log.warning("No reference info found. `Learn` will be automatically excuted first.")
+        if not self.load_reference_info(self.trainer.default_root_dir, self.device):
+            log.warning("No reference info found. `Learn` will be automatically executed first.")
             self.trainer.lightning_module.automatic_optimization = False
             self.trainer.fit_loop.run()
             # to use infer logic
@@ -353,12 +348,12 @@ def on_test_start(self) -> None:
             # to set _combined_loader
             self.trainer._evaluation_loop.setup_data()  # noqa: SLF001
             self.trainer._evaluation_loop.reset()  # noqa: SLF001
-            self.load_latest_reference_info(self.device)
+            self.load_reference_info(self.trainer.default_root_dir, self.device)
 
     def on_predict_start(self) -> None:
         """Load previously saved reference info."""
-        if not self.load_latest_reference_info(self.device):
-            log.warning("No reference info found. `Learn` will be automatically excuted first.")
+        if not self.load_reference_info(self.trainer.default_root_dir, self.device):
+            log.warning("No reference info found. `Learn` will be automatically executed first.")
             self.trainer.lightning_module.automatic_optimization = False
             self.trainer.fit_loop.run()
             # to use infer logic
@@ -366,7 +361,7 @@ def on_predict_start(self) -> None:
             # to set _combined_loader
             self.trainer._evaluation_loop.setup_data()  # noqa: SLF001
             self.trainer._evaluation_loop.reset()  # noqa: SLF001
-            self.load_latest_reference_info(self.device)
+            self.load_reference_info(self.trainer.default_root_dir, self.device)
 
     def on_train_epoch_start(self) -> None:
         """Skip on_train_epoch_start unused in zero-shot visual prompting."""
@@ -374,23 +369,7 @@ def on_train_epoch_start(self) -> None:
     def on_train_epoch_end(self) -> None:
         """Skip on_train_epoch_end unused in zero-shot visual prompting."""
         if self.save_outputs:
-            reference_info = {
-                "reference_feats": self.reference_feats,
-                "used_indices": self.used_indices,
-            }
-            # save reference info
-            path_reference_info: Path = self.root_reference_info / time.strftime("%Y%m%d_%H%M%S") / "reference_info.pt"
-            Path.mkdir(Path(path_reference_info).parent, parents=True, exist_ok=True)
-            if isinstance(self, OTXZeroShotVisualPromptingModel):
-                torch.save(reference_info, path_reference_info)
-                pickle.dump(
-                    {k: v.numpy() for k, v in reference_info.items()},
-                    Path.open(Path(str(path_reference_info).replace(".pt", ".pickle")), "wb"),
-                )
-            else:
-                torch.save({k: torch.as_tensor(v) for k, v in reference_info.items()}, path_reference_info)
-                pickle.dump(reference_info, Path.open(Path(str(path_reference_info).replace(".pt", ".pickle")), "wb"))
-            log.info(f"Saved reference info at {path_reference_info}.")
+            self.save_reference_info(self.trainer.default_root_dir)
 
     def on_validation_epoch_start(self) -> None:
         """Skip on_validation_epoch_start unused in zero-shot visual prompting."""
@@ -411,7 +390,7 @@ def training_step(
 
     def validation_step(
         self,
-        inputs: VisualPromptingBatchDataEntity | ZeroShotVisualPromptingBatchDataEntity,
+        inputs: ZeroShotVisualPromptingBatchDataEntity,
         batch_idx: int,
     ) -> None:
         """Skip validation_step unused in zero-shot visual prompting."""
@@ -424,11 +403,11 @@ def test_step(
         """Perform a single test step on a batch of data from the test set.
 
         Args:
-            inputs (VisualPromptingBatchDataEntity): The input data for the test step.
+            inputs (ZeroShotVisualPromptingBatchDataEntity): The input data for the test step.
             batch_idx (int): The index of the current batch.
 
         Raises:
-            TypeError: If the predictions are not of type VisualPromptingBatchPredEntity.
+            TypeError: If the predictions are not of type ZeroShotVisualPromptingBatchPredEntity.
         """
         _inference_step_for_zero_shot(model=self, metric=self.metric, inputs=inputs)
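When no reference info is found on disk, `on_test_start`/`on_predict_start` fall back to running the `learn` stage through the fit loop before evaluation continues. A simplified outline of that control flow, with a hypothetical helper name, not the exact Lightning internals:

```python
def ensure_reference_info(model, trainer) -> None:
    # hypothetical helper mirroring on_test_start/on_predict_start above
    if model.load_reference_info(trainer.default_root_dir, model.device):
        return
    # no saved reference info: run `learn` once via the fit loop
    trainer.lightning_module.automatic_optimization = False
    trainer.fit_loop.run()
    model.training = False  # switch back to infer logic
    model.load_reference_info(trainer.default_root_dir, model.device)
```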
""" _inference_step_for_zero_shot(model=self, metric=self.metric, inputs=inputs) @@ -584,7 +563,7 @@ def _customize_outputs( self, outputs: Any, # noqa: ANN401 inputs: VisualPromptingBatchDataEntity, # type: ignore[override] - ) -> VisualPromptingBatchPredEntity | OTXBatchLossEntity: + ) -> VisualPromptingBatchPredEntity: """Customize OTX output batch data entity if needed for model.""" masks: list[tv_tensors.Mask] = [] scores: list[torch.Tensor] = [] @@ -620,7 +599,7 @@ def check_if_quantized(model: openvino.Model) -> bool: return any(op.get_type_name() == "FakeQuantize" for op in nodes) def transform_fn( - data_batch: VisualPromptingBatchDataEntity | ZeroShotVisualPromptingBatchDataEntity, + data_batch: VisualPromptingBatchDataEntity, module: Literal["image_encoder", "decoder"], ) -> np.ndarray | dict[str, Any]: images, _, prompts = self._customize_inputs(data_batch) # type: ignore[arg-type] @@ -732,7 +711,12 @@ def _set_label_info(self, _: LabelInfoTypes) -> None: log.warning(msg) -class OVZeroShotVisualPromptingModel(OVVisualPromptingModel): +class OVZeroShotVisualPromptingModel( + OVModel[ + ZeroShotVisualPromptingBatchDataEntity, + ZeroShotVisualPromptingBatchPredEntity, + ], +): """Zero-shot visual prompting model compatible for OpenVINO IR inference. It can only consume OpenVINO IR model path and create the OTX zero-shot visual prompting model compatible @@ -748,10 +732,26 @@ def __init__( use_throughput_mode: bool = True, model_api_configuration: dict[str, Any] | None = None, metric: MetricCallable = VisualPromptingMetricCallable, - root_reference_info: str = "vpm_zsl_reference_infos", + reference_info_dir: Path | str = "reference_infos", + infer_reference_info_root: Path | str = "../.latest/train", save_outputs: bool = True, **kwargs, ) -> None: + if async_inference: + log.warning( + ( + "Async inference is not supported for zero-shot visual prompting models. 
" + "Setting async_inference to False.", + ), + ) + async_inference = False + + basename: str = Path(model_name).name + model_type_name: str = "_".join(basename.split("_")[:2]) + self.model_names: dict[str, str] = { + module: model_name.replace(basename, f"{model_type_name}_{module}.xml") + for module in ["image_encoder", "decoder"] + } super().__init__( model_name=model_name, model_type=model_type, @@ -761,7 +761,8 @@ def __init__( model_api_configuration=model_api_configuration, metric=metric, ) - self.root_reference_info: Path = Path(root_reference_info) + self.reference_info_dir: Path = Path(reference_info_dir) + self.infer_reference_info_root: Path = Path(infer_reference_info_root) self.save_outputs: bool = save_outputs self.point_labels_box = np.array([[2, 3]], dtype=np.float32) @@ -769,6 +770,29 @@ def __init__( self.initialize_reference_info() + def _create_model(self) -> dict[str, Model]: + """Create a OV model with help of Model API.""" + from openvino.model_api.adapters import OpenvinoAdapter, create_core, get_user_config + from openvino.model_api.models import Model + + ov_models: dict[str, Model] = {} + + plugin_config = get_user_config("AUTO", str(self.num_requests), "AUTO") + if self.use_throughput_mode: + plugin_config["PERFORMANCE_HINT"] = "THROUGHPUT" + + model_parameters = {"decoder": {"input_layouts": "image_embeddings:NCHW"}} + for module in ["image_encoder", "decoder"]: + model_adapter = OpenvinoAdapter( + core=create_core(), + model=self.model_names.get(module), + model_parameters=model_parameters.get(module, {}), + max_num_requests=self.num_requests, + plugin_config=plugin_config, + ) + ov_models[module] = Model.create_model(model_adapter, module, configuration=self.model_api_configuration) + return ov_models + def learn( self, inputs: ZeroShotVisualPromptingBatchDataEntity, @@ -905,7 +929,7 @@ def infer( def forward( # type: ignore[override] self, inputs: ZeroShotVisualPromptingBatchDataEntity, # type: ignore[override] - ) -> ZeroShotVisualPromptingBatchPredEntity | OTXBatchLossEntity: + ) -> ZeroShotVisualPromptingBatchPredEntity: """Model forward function.""" kwargs: dict[str, Any] = {} fn = self.learn if self.training else self.infer @@ -920,7 +944,7 @@ def forward( # type: ignore[override] if self.async_inference: log.warning( ( - "Async inference is not supported for visual prompting models yet. " + "Async inference is not supported for zero-shot visual prompting models yet. 
" "Running synchronous inference instead.", ), ) @@ -976,7 +1000,7 @@ def _customize_outputs( # type: ignore[override] self, outputs: Any, # noqa: ANN401 inputs: ZeroShotVisualPromptingBatchDataEntity, # type: ignore[override] - ) -> ZeroShotVisualPromptingBatchPredEntity | OTXBatchLossEntity: + ) -> ZeroShotVisualPromptingBatchPredEntity: """Customize OTX output batch data entity if needed for model.""" if self.training: return outputs @@ -1017,6 +1041,90 @@ def _customize_outputs( # type: ignore[override] labels=labels, ) + def optimize( # type: ignore[override] + self, + output_dir: Path, + data_module: OTXDataModule, + ptq_config: dict[str, Any] | None = None, + ) -> dict[str, Path]: + """Runs NNCF quantization.""" + import nncf + import openvino + + def check_if_quantized(model: openvino.Model) -> bool: + """Checks if OpenVINO model is already quantized.""" + nodes = model.get_ops() + return any(op.get_type_name() == "FakeQuantize" for op in nodes) + + def transform_fn( + data_batch: ZeroShotVisualPromptingBatchDataEntity, + module: Literal["image_encoder", "decoder"], + ) -> np.ndarray | dict[str, Any]: + images, _, prompts = self._customize_inputs(data_batch) # type: ignore[arg-type] + + image = images[0]["images"] # use only the first image + if module == "image_encoder": + # resize + resized_image = self.model["image_encoder"].resize( + image[0], + (self.model["image_encoder"].w, self.model["image_encoder"].h), + ) + + # pad image if necessary because `fit_to_window` resize for python in modelapi doesn't support pad + pad_w = max(0, self.model["image_encoder"].w - resized_image.shape[1]) + pad_h = max(0, self.model["image_encoder"].h - resized_image.shape[0]) + resized_image = np.pad( + resized_image, + ((0, pad_h), (0, pad_w), (0, 0)), + mode="constant", + constant_values=0, + ) + + # normalization + resized_image = self.model["image_encoder"].input_transform(resized_image) + + # change layout from HWC to NCHW + return self.model["image_encoder"]._change_layout(resized_image) # noqa: SLF001 + + # obtain image embeddings from image encoder + image_embeddings = self.model["image_encoder"].infer_sync(image) + # use only the first prompt + prompt_for_optim = next(iter(prompts[0].values()))[0] if isinstance(prompts[0], dict) else prompts[0][0] # type: ignore[attr-defined] + prompt_for_optim.pop("label") + prompt_for_optim.update(**image_embeddings) + return prompt_for_optim + + output_model_paths: dict[str, Path] = {} + for module in ["image_encoder", "decoder"]: + output_model_path = output_dir / (self._OPTIMIZED_MODEL_BASE_NAME + f"_{module}.xml") + + ov_model = openvino.Core().read_model(self.model_names[module]) + if check_if_quantized(ov_model): + msg = "Model is already optimized by PTQ" + raise RuntimeError(msg) + + train_dataset = data_module.train_dataloader() + + ptq_config_from_ir = self._read_ptq_config_from_ir(ov_model) + if ptq_config is not None: + ptq_config_from_ir.update(ptq_config) + ptq_config = ptq_config_from_ir + else: + ptq_config = ptq_config_from_ir + + quantization_dataset = nncf.Dataset(train_dataset, partial(transform_fn, module=module)) # type: ignore[attr-defined] + + compressed_model = nncf.quantize( # type: ignore[attr-defined] + ov_model, + quantization_dataset, + **ptq_config, + ) + + openvino.save_model(compressed_model, output_model_path) + output_model_paths[module] = output_model_path + + return output_model_paths + ###################################### # Preprocess # ###################################### @@ -1137,6 +1245,20 @@ def 
     ######################################
     #             Preprocess             #
     ######################################
@@ -1137,6 +1245,20 @@ def expand_reference_info(self, new_largest_label: int) -> None:
         diff = new_largest_label - cur_largest_label
         self.reference_feats = np.pad(self.reference_feats, ((0, diff), (0, 0), (0, 0)), constant_values=0.0)
 
+    def save_reference_info(self, default_root_dir: Path | str) -> None:
+        """Save reference info."""
+        reference_info = {
+            "reference_feats": self.reference_feats,
+            "used_indices": self.used_indices,
+        }
+        # save reference info
+        path_reference_info: Path = Path(default_root_dir) / self.reference_info_dir / "reference_info.pt"
+        path_reference_info.parent.mkdir(parents=True, exist_ok=True)
+        # TODO (sungchul): ticket no. 139210
+        torch.save({k: torch.as_tensor(v) for k, v in reference_info.items()}, path_reference_info)
+        pickle.dump(reference_info, path_reference_info.with_suffix(".pickle").open("wb"))
+        log.info(f"Saved reference info at {path_reference_info}.")
+
     def _generate_masked_features(
         self,
         feats: np.ndarray,
@@ -1189,25 +1311,24 @@ def _pad_to_square(self, x: np.ndarray, image_size: int = 1024) -> np.ndarray:
     ######################################
     #               Infer                #
     ######################################
-    def _find_latest_reference_info(self, root: Path) -> str | None:
-        """Find latest reference info to be used."""
-        if not Path.is_dir(root):
-            return None
-        if len(stamps := sorted(os.listdir(root), reverse=True)) > 0:
-            return stamps[0]
-        return None
-
-    def load_latest_reference_info(self, *args, **kwargs) -> bool:
+    def load_reference_info(self, default_root_dir: Path | str, *args, **kwargs) -> bool:
         """Load latest reference info to be used."""
-        if (latest_stamp := self._find_latest_reference_info(self.root_reference_info)) is not None:
-            latest_reference_info: Path = self.root_reference_info / latest_stamp / "reference_info.pickle"
-            reference_info: dict[str, np.ndarray] = pickle.load(Path.open(latest_reference_info, "rb"))  # noqa: S301
+        _infer_reference_info_root: Path = (
+            self.infer_reference_info_root
+            if self.infer_reference_info_root == self.infer_reference_info_root.absolute()
+            else Path(default_root_dir) / self.infer_reference_info_root
+        )
+
+        if (
+            path_reference_info := _infer_reference_info_root / self.reference_info_dir / "reference_info.pickle"
+        ).is_file():
+            reference_info: dict[str, np.ndarray] = pickle.load(path_reference_info.open("rb"))  # noqa: S301
             self.reference_feats = reference_info.get(
                 "reference_feats",
                 np.zeros((0, 1, self.model["decoder"].embed_dim), dtype=np.float32),
             )
             self.used_indices = reference_info.get("used_indices", np.array([], dtype=np.int64))
-            log.info(f"reference info saved at {latest_reference_info} was successfully loaded.")
+            log.info(f"reference info saved at {path_reference_info} was successfully loaded.")
             return True
         return False
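`expand_reference_info` (shown at the top of this hunk) grows `reference_feats` along the leading label axis with zero padding, so a newly seen, larger label id gets a slot. A tiny numpy illustration of that padding step:

```python
import numpy as np

reference_feats = np.ones((2, 1, 256), dtype=np.float32)  # labels 0 and 1 already learned
new_largest_label = 4
cur_largest_label = reference_feats.shape[0] - 1
diff = new_largest_label - cur_largest_label

# zero-pad only the leading (label) axis
reference_feats = np.pad(reference_feats, ((0, diff), (0, 0), (0, 0)), constant_values=0.0)
print(reference_feats.shape)  # (5, 1, 256) -> room for labels 0..4
```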
@@ -1385,6 +1506,65 @@ def _topk_numpy(self, x: np.ndarray, k: int, axis: int = -1, largest: bool = True
 
     def _reset_prediction_layer(self, num_classes: int) -> None:
         return
 
+    ######################################
+    #             Lit Module             #
+    ######################################
+    def on_train_start(self) -> None:
+        """Initialize reference infos before learn."""
+        self.initialize_reference_info()
+
+    def on_test_start(self) -> None:
+        """Load previously saved reference info."""
+        super().on_test_start()
+        if not self.load_reference_info(self.trainer.default_root_dir, self.device):
+            log.warning("No reference info found. `Learn` will be automatically executed first.")
+            self.trainer.lightning_module.automatic_optimization = False
+            self.trainer.fit_loop.run()
+            # to use infer logic
+            self.training = False
+            # to set _combined_loader
+            self.trainer._evaluation_loop.setup_data()  # noqa: SLF001
+            self.trainer._evaluation_loop.reset()  # noqa: SLF001
+            self.load_reference_info(self.trainer.default_root_dir, self.device)
+
+    def on_predict_start(self) -> None:
+        """Load previously saved reference info."""
+        if not self.load_reference_info(self.trainer.default_root_dir, self.device):
+            log.warning("No reference info found. `Learn` will be automatically executed first.")
+            self.trainer.lightning_module.automatic_optimization = False
+            self.trainer.fit_loop.run()
+            # to use infer logic
+            self.training = False
+            # to set _combined_loader
+            self.trainer._evaluation_loop.setup_data()  # noqa: SLF001
+            self.trainer._evaluation_loop.reset()  # noqa: SLF001
+            self.load_reference_info(self.trainer.default_root_dir, self.device)
+
+    def on_train_epoch_start(self) -> None:
+        """Skip on_train_epoch_start unused in zero-shot visual prompting."""
+
+    def on_train_epoch_end(self) -> None:
+        """Skip on_train_epoch_end unused in zero-shot visual prompting."""
+        if self.save_outputs:
+            self.save_reference_info(self.trainer.default_root_dir)
+
+    def on_validation_epoch_start(self) -> None:
+        """Skip on_validation_epoch_start unused in zero-shot visual prompting."""
+
+    def on_validation_epoch_end(self) -> None:
+        """Skip on_validation_epoch_end unused in zero-shot visual prompting."""
+
+    def configure_optimizers(self) -> None:  # type: ignore[override]
+        """Skip configure_optimizers unused in zero-shot visual prompting."""
+
+    def training_step(
+        self,
+        inputs: ZeroShotVisualPromptingBatchDataEntity,  # type: ignore[override]
+        batch_idx: int,
+    ) -> Tensor:
+        """Skip training_step unused in zero-shot visual prompting."""
+        self.forward(inputs)
+
     def validation_step(
         self,
         inputs: ZeroShotVisualPromptingBatchDataEntity,
@@ -1400,11 +1580,11 @@ def test_step(
         """Perform a single test step on a batch of data from the test set.
 
         Args:
-            inputs (VisualPromptingBatchDataEntity): The input data for the test step.
+            inputs (ZeroShotVisualPromptingBatchDataEntity): The input data for the test step.
             batch_idx (int): The index of the current batch.
 
         Raises:
-            TypeError: If the predictions are not of type VisualPromptingBatchPredEntity.
+            TypeError: If the predictions are not of type ZeroShotVisualPromptingBatchPredEntity.
         """
         _inference_step_for_zero_shot(model=self, metric=self.metric, inputs=inputs)
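Taken together, the renamed knobs split the old `root_reference_info` into where to write (`reference_info_dir`, created under the run's `default_root_dir`) and where to read (`infer_reference_info_root`). A hedged construction example for the torch model; the defaults shown are the ones introduced above:

```python
from otx.algo.visual_prompting.zero_shot_segment_anything import OTXZeroShotSegmentAnything

model = OTXZeroShotSegmentAnything(
    backbone="tiny_vit",
    reference_info_dir="reference_infos",          # written to <default_root_dir>/reference_infos/
    infer_reference_info_root="../.latest/train",  # resolved against default_root_dir unless absolute
    save_outputs=True,
)
```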
""" _inference_step_for_zero_shot(model=self, metric=self.metric, inputs=inputs) diff --git a/src/otx/recipe/zero_shot_visual_prompting/openvino_model.yaml b/src/otx/recipe/zero_shot_visual_prompting/openvino_model.yaml index a0f8d47d4d3..d5b5ecf5262 100644 --- a/src/otx/recipe/zero_shot_visual_prompting/openvino_model.yaml +++ b/src/otx/recipe/zero_shot_visual_prompting/openvino_model.yaml @@ -6,7 +6,8 @@ model: model_type: Zero_Shot_Visual_Prompting async_inference: False use_throughput_mode: True - root_reference_info: vpm_zsl_reference_infos + reference_info_dir: reference_infos + infer_reference_info_root: ../.latest/train # set absolute path for using reference_info saved in other location save_outputs: True engine: diff --git a/src/otx/recipe/zero_shot_visual_prompting/sam_tiny_vit.yaml b/src/otx/recipe/zero_shot_visual_prompting/sam_tiny_vit.yaml index fa2ae11e822..7e91958536c 100644 --- a/src/otx/recipe/zero_shot_visual_prompting/sam_tiny_vit.yaml +++ b/src/otx/recipe/zero_shot_visual_prompting/sam_tiny_vit.yaml @@ -9,7 +9,8 @@ model: default_threshold_reference: 0.3 default_threshold_target: 0.65 save_outputs: True - root_reference_info: vpm_zsl_reference_infos + reference_info_dir: reference_infos + infer_reference_info_root: ../.latest/train # set absolute path for using reference_info saved in other location # options use_stability_score: False return_single_mask: False diff --git a/src/otx/recipe/zero_shot_visual_prompting/sam_vit_b.yaml b/src/otx/recipe/zero_shot_visual_prompting/sam_vit_b.yaml index 2665589eb9b..febb5215d7a 100644 --- a/src/otx/recipe/zero_shot_visual_prompting/sam_vit_b.yaml +++ b/src/otx/recipe/zero_shot_visual_prompting/sam_vit_b.yaml @@ -9,7 +9,8 @@ model: default_threshold_reference: 0.3 default_threshold_target: 0.65 save_outputs: True - root_reference_info: vpm_zsl_reference_infos + reference_info_dir: reference_infos + infer_reference_info_root: ../.latest/train # set absolute path for using reference_info saved in other location # options use_stability_score: False return_single_mask: False diff --git a/tests/integration/api/test_auto_configuration.py b/tests/integration/api/test_auto_configuration.py index 40c7a375c83..94d9b7eed44 100644 --- a/tests/integration/api/test_auto_configuration.py +++ b/tests/integration/api/test_auto_configuration.py @@ -43,6 +43,10 @@ def test_auto_configuration( work_dir=tmp_path_train, device=fxt_accelerator, ) + if task.lower() == "zero_shot_visual_prompting": + engine.model.infer_reference_info_root = Path() + # update litmodule.hparams to reflect changed hparams + engine.model.hparams.update({"infer_reference_info_root": str(engine.model.infer_reference_info_root)}) # Check OTXModel & OTXDataModule assert isinstance(engine.model, OTXModel) diff --git a/tests/integration/api/test_engine_api.py b/tests/integration/api/test_engine_api.py index 5362da8132c..945be33cf60 100644 --- a/tests/integration/api/test_engine_api.py +++ b/tests/integration/api/test_engine_api.py @@ -47,6 +47,10 @@ def test_engine_from_config( work_dir=tmp_path_train, device=fxt_accelerator, ) + if task.lower() == "zero_shot_visual_prompting": + engine.model.infer_reference_info_root = Path() + # update litmodule.hparams to reflect changed hparams + engine.model.hparams.update({"infer_reference_info_root": str(engine.model.infer_reference_info_root)}) # Check OTXModel & OTXDataModule assert isinstance(engine.model, OTXModel) @@ -89,6 +93,14 @@ def test_engine_from_config( # Test with IR Model if task in OVMODEL_PER_TASK: if task.lower() in 
["visual_prompting", "zero_shot_visual_prompting"]: + if task.lower() == "zero_shot_visual_prompting": + engine.model = engine._auto_configurator.get_ov_model( + model_name=str(exported_model_path["decoder"]), + label_info=engine.datamodule.label_info, + ) + engine.model.infer_reference_info_root = Path() + # update litmodule.hparams to reflect changed hparams + engine.model.hparams.update({"infer_reference_info_root": str(engine.model.infer_reference_info_root)}) test_metric_from_ov_model = engine.test(checkpoint=exported_model_path["decoder"], accelerator="cpu") else: test_metric_from_ov_model = engine.test(checkpoint=exported_model_path, accelerator="cpu") diff --git a/tests/integration/cli/test_cli.py b/tests/integration/cli/test_cli.py index 07acab85b16..eea17a3c2dc 100644 --- a/tests/integration/cli/test_cli.py +++ b/tests/integration/cli/test_cli.py @@ -112,6 +112,15 @@ def test_otx_e2e( "--checkpoint", str(ckpt_file), ] + # Zero-shot visual prompting needs to specify `infer_reference_info_root` + if task in ["zero_shot_visual_prompting"]: + idx_task = str(ckpt_file).split("/").index(f"otx_train_{model_name}") + command_cfg.extend( + [ + "--model.init_args.infer_reference_info_root", + str(ckpt_file.parents[-idx_task] / f"otx_train_{model_name}/outputs/.latest/train"), + ], + ) run_main(command_cfg=command_cfg, open_subprocess=fxt_open_subprocess) @@ -187,7 +196,11 @@ def test_otx_e2e( (p for p in ov_output_dir.iterdir() if p.is_dir() and p.name != ".latest"), key=lambda p: p.stat().st_mtime, ) - exported_model_path = str(ov_latest_dir / "exported_model.xml") + if task in ("visual_prompting", "zero_shot_visual_prompting"): + exported_model_path = str(ov_latest_dir / "exported_model_decoder.xml") + recipe = str(Path(recipe).parents[0] / "openvino_model.yaml") + else: + exported_model_path = str(ov_latest_dir / "exported_model.xml") overrides = fxt_cli_override_command_per_task[task] if "anomaly" in task: @@ -208,6 +221,15 @@ def test_otx_e2e( "--checkpoint", exported_model_path, ] + # Zero-shot visual prompting needs to specify `infer_reference_info_root` + if task in ["zero_shot_visual_prompting"]: + idx_task = str(ckpt_file).split("/").index(f"otx_train_{model_name}") + command_cfg.extend( + [ + "--model.init_args.infer_reference_info_root", + str(ckpt_file.parents[-idx_task] / f"otx_train_{model_name}/outputs/.latest/train"), + ], + ) run_main(command_cfg=command_cfg, open_subprocess=fxt_open_subprocess) diff --git a/tests/integration/cli/test_export_inference.py b/tests/integration/cli/test_export_inference.py index 54e6303888e..1a85ad7cd4d 100644 --- a/tests/integration/cli/test_export_inference.py +++ b/tests/integration/cli/test_export_inference.py @@ -142,6 +142,21 @@ def run_cli_test(test_recipe: str, checkpoint_path: str, work_dir: Path, device: "--checkpoint", checkpoint_path, ] + + # Zero-shot visual prompting needs to specify `infer_reference_info_root` + if task in ["zero_shot_visual_prompting"]: + try: + idx_task = checkpoint_path.split("/").index(f"otx_train_{model_name}") + except ValueError: + idx_task = checkpoint_path.split("/").index(f"otx_test_{model_name}") + + command_cfg.extend( + [ + "--model.init_args.infer_reference_info_root", + str(Path(checkpoint_path).parents[-idx_task] / f"otx_train_{model_name}/outputs/.latest/train"), + ], + ) + run_main(command_cfg=command_cfg, open_subprocess=fxt_open_subprocess) return tmp_path_test diff --git a/tests/unit/algo/visual_prompting/test_zero_shot_segment_anything.py 
diff --git a/tests/unit/algo/visual_prompting/test_zero_shot_segment_anything.py b/tests/unit/algo/visual_prompting/test_zero_shot_segment_anything.py
index 79a0a0517f3..b9fe3a5e58e 100644
--- a/tests/unit/algo/visual_prompting/test_zero_shot_segment_anything.py
+++ b/tests/unit/algo/visual_prompting/test_zero_shot_segment_anything.py
@@ -3,7 +3,6 @@
 
 from __future__ import annotations
 
-from pathlib import Path
 from typing import Any, Callable
 
 import pytest
@@ -695,54 +694,40 @@ def test_initialize_reference_info(self, model) -> None:
         assert model.reference_feats.shape == (0, 1, 256)
         assert model.used_indices.shape == (0,)
 
-    def test_find_latest_reference_info(self, mocker, model) -> None:
-        """Test _find_latest_reference_info."""
-        mocker.patch(
-            "otx.algo.visual_prompting.zero_shot_segment_anything.os.path.isdir",
-            return_value=True,
-        )
+    def test_save_reference_info(self, mocker, tmpdir, model) -> None:
+        """Test save_reference_info."""
+        model.reference_info_dir = tmpdir
+        model.reference_feats = torch.tensor(1)
+        model.used_indices = torch.tensor(1)
+        mocker_mkdir = mocker.patch("pathlib.Path.mkdir")
+        mocker.patch("pathlib.Path.open")
+        mocker_torch_save = mocker.patch("torch.save")
+        mocker_pickle_dump = mocker.patch("pickle.dump")
 
-        # there are some saved reference info
-        mocker.patch(
-            "otx.algo.visual_prompting.zero_shot_segment_anything.os.listdir",
-            return_value=["1", "2"],
-        )
-        results = model._find_latest_reference_info(Path())
-        assert results == "2"
+        model.save_reference_info(".")
 
-        # there are no saved reference info
-        mocker.patch(
-            "otx.algo.visual_prompting.zero_shot_segment_anything.os.listdir",
-            return_value=[],
-        )
-        results = model._find_latest_reference_info(Path())
-        assert results is None
+        mocker_mkdir.assert_called_once()
+        mocker_torch_save.assert_called_once()
+        mocker_pickle_dump.assert_called_once()
 
-    def test_load_latest_reference_info(self, mocker, model) -> None:
-        """Test load_latest_reference_info."""
+    def test_load_reference_info(self, mocker, model) -> None:
+        """Test load_reference_info."""
         # get previously saved reference info
         mocker.patch(
-            "otx.algo.visual_prompting.zero_shot_segment_anything.OTXZeroShotSegmentAnything._find_latest_reference_info",
-            return_value="1",
-        )
-        mocker.patch(
-            "otx.algo.visual_prompting.zero_shot_segment_anything.torch.load",
+            "torch.load",
             return_value={"reference_feats": torch.zeros((1, 1, 256)), "used_indices": torch.tensor([0.0])},
         )
-        mocker.patch("builtins.open", return_value="Mocked data")
+        mocker.patch("pathlib.Path.is_file", return_value=True)
 
-        model.load_latest_reference_info()
+        model.load_reference_info(".")
         assert model.reference_feats.shape == (1, 1, 256)
         assert model.used_indices.shape == (1,)
 
         # no saved reference info
-        mocker.patch(
-            "otx.algo.visual_prompting.zero_shot_segment_anything.OTXZeroShotSegmentAnything._find_latest_reference_info",
-            return_value=None,
-        )
+        mocker.patch("pathlib.Path.is_file", return_value=False)
 
         model.initialize_reference_info()
-        model.load_latest_reference_info()
+        model.load_reference_info(".")
         assert model.reference_feats.shape == (0, 1, 256)
         assert model.used_indices.shape == (0,)
diff --git a/tests/unit/core/model/test_visual_prompting.py b/tests/unit/core/model/test_visual_prompting.py
index c7262e9729f..a52d520896b 100644
--- a/tests/unit/core/model/test_visual_prompting.py
+++ b/tests/unit/core/model/test_visual_prompting.py
@@ -221,7 +221,7 @@ def test_optimization_config(self, otx_zero_shot_visual_prompting_model) -> None:
 
     def test_on_test_start(self, mocker, otx_zero_shot_visual_prompting_model) -> None:
         """Test on_test_start."""
-        otx_zero_shot_visual_prompting_model.load_latest_reference_info = Mock(return_value=False)
+        otx_zero_shot_visual_prompting_model.load_reference_info = Mock(return_value=False)
         otx_zero_shot_visual_prompting_model.trainer = Mock()
         mocker_run = mocker.patch.object(otx_zero_shot_visual_prompting_model.trainer.fit_loop, "run")
         mocker_setup_data = mocker.patch.object(
@@ -238,7 +238,7 @@ def test_on_test_start(self, mocker, otx_zero_shot_visual_prompting_model) -> None:
 
     def test_on_predict_start(self, mocker, otx_zero_shot_visual_prompting_model) -> None:
         """Test on_predict_start."""
-        otx_zero_shot_visual_prompting_model.load_latest_reference_info = Mock(return_value=False)
+        otx_zero_shot_visual_prompting_model.load_reference_info = Mock(return_value=False)
         otx_zero_shot_visual_prompting_model.trainer = Mock()
         mocker_run = mocker.patch.object(otx_zero_shot_visual_prompting_model.trainer.fit_loop, "run")
         mocker_setup_data = mocker.patch.object(
@@ -256,20 +256,12 @@ def test_on_predict_start(self, mocker, otx_zero_shot_visual_prompting_model) -> None:
     def test_on_train_epoch_end(self, mocker, tmpdir, otx_zero_shot_visual_prompting_model) -> None:
         """Test on_train_epoch_end."""
         otx_zero_shot_visual_prompting_model.save_outputs = True
-        otx_zero_shot_visual_prompting_model.root_reference_info = tmpdir
-        otx_zero_shot_visual_prompting_model.reference_feats = torch.tensor(1)
-        otx_zero_shot_visual_prompting_model.used_indices = torch.tensor(1)
-        mocker_mkdir = mocker.patch("otx.core.model.visual_prompting.Path.mkdir")
-        mocker.patch("otx.core.model.visual_prompting.Path.open")
-        mocker_torch_save = mocker.patch("otx.core.model.visual_prompting.torch.save")
-        mocker_pickle_dump = mocker.patch("otx.core.model.visual_prompting.pickle.dump")
+        otx_zero_shot_visual_prompting_model.save_reference_info = Mock()
+        otx_zero_shot_visual_prompting_model.trainer = Mock()
+        mocker.patch.object(otx_zero_shot_visual_prompting_model.trainer, "default_root_dir")
 
         otx_zero_shot_visual_prompting_model.on_train_epoch_end()
 
-        mocker_mkdir.assert_called_once()
-        mocker_torch_save.assert_called_once()
-        mocker_pickle_dump.assert_called_once()
-
 
 class TestOVVisualPromptingModel:
     @pytest.fixture()
@@ -647,51 +639,28 @@ def test_pad_to_square(self, ov_zero_shot_visual_prompting_model) -> None:
         assert result[8:, :8].sum() == 0
         assert result[8:, 8:].sum() == 0
 
-    def test_find_latest_reference_info(self, mocker, ov_zero_shot_visual_prompting_model) -> None:
-        """Test _find_latest_reference_info."""
-        mocker.patch(
-            "otx.core.model.visual_prompting.os.path.isdir",
-            return_value=True,
-        )
-
-        # there are some saved reference info
-        mocker.patch(
-            "otx.core.model.visual_prompting.os.listdir",
-            return_value=["1", "2"],
-        )
-        results = ov_zero_shot_visual_prompting_model._find_latest_reference_info(Path())
-        assert results == "2"
-
-        # there are no saved reference info
-        mocker.patch(
-            "otx.core.model.visual_prompting.os.listdir",
-            return_value=[],
-        )
-        results = ov_zero_shot_visual_prompting_model._find_latest_reference_info(Path())
-        assert results is None
-
-    def test_load_latest_reference_info(self, mocker, ov_zero_shot_visual_prompting_model) -> None:
+    def test_load_reference_info(self, mocker, ov_zero_shot_visual_prompting_model) -> None:
         """Test load_reference_info."""
"otx.core.model.visual_prompting.pickle.load", + "pickle.load", return_value={"reference_feats": np.zeros((1, 1, 256)), "used_indices": np.array([0])}, ) - mocker.patch("otx.core.model.visual_prompting.Path.open", return_value="Mocked data") + mocker.patch("pathlib.Path.is_file", return_value=True) + mocker.patch("pathlib.Path.open", return_value="Mocked data") - ov_zero_shot_visual_prompting_model.load_latest_reference_info() + ov_zero_shot_visual_prompting_model.load_reference_info(".") assert ov_zero_shot_visual_prompting_model.reference_feats.shape == (1, 1, 256) assert ov_zero_shot_visual_prompting_model.used_indices.shape == (1,) # no saved reference info - mocker.patch.object(ov_zero_shot_visual_prompting_model, "_find_latest_reference_info", return_value=None) + mocker.patch("pathlib.Path.is_file", return_value=False) ov_zero_shot_visual_prompting_model.reference_feats = np.zeros((0, 1, 256), dtype=np.float32) ov_zero_shot_visual_prompting_model.used_indices = np.array([], dtype=np.int64) - ov_zero_shot_visual_prompting_model.load_latest_reference_info() + ov_zero_shot_visual_prompting_model.load_reference_info(".") assert ov_zero_shot_visual_prompting_model.reference_feats.shape == (0, 1, 256) assert ov_zero_shot_visual_prompting_model.used_indices.shape == (0,)