From 8598b05f862c04278e9e684ceece72178da4f1f2 Mon Sep 17 00:00:00 2001 From: Ashwin Vaidya Date: Tue, 31 Oct 2023 12:33:03 +0100 Subject: [PATCH] Add label ids to anomaly OpenVINO model xml (#2590) * Add label ids to model xml --------- --- src/otx/algorithms/anomaly/tasks/inference.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/src/otx/algorithms/anomaly/tasks/inference.py b/src/otx/algorithms/anomaly/tasks/inference.py index 50c5a4b81f7..aac3d525b5e 100644 --- a/src/otx/algorithms/anomaly/tasks/inference.py +++ b/src/otx/algorithms/anomaly/tasks/inference.py @@ -359,7 +359,20 @@ def _add_metadata_to_ir(self, model_file: str, export_type: ExportType) -> None: extra_model_data[("model_info", "reverse_input_channels")] = False extra_model_data[("model_info", "model_type")] = "AnomalyDetection" - extra_model_data[("model_info", "labels")] = "Normal Anomaly" + + labels = [] + label_ids = [] + for label_entity in self.task_environment.label_schema.get_labels(include_empty=False): + label_name = label_entity.name.replace(" ", "_") + # There is a mismatch between labels in OTX and modelAPI + if label_name == "Anomalous": + label_name = "Anomaly" + labels.append(label_name) + label_ids.append(str(label_entity.id_)) + + extra_model_data[("model_info", "labels")] = " ".join(labels) + extra_model_data[("model_info", "label_ids")] = " ".join(label_ids) + if export_type == ExportType.OPENVINO: embed_ir_model_data(model_file, extra_model_data) elif export_type == ExportType.ONNX: