From 6a5d308567c9ed3f12ba2c27780625c995b55420 Mon Sep 17 00:00:00 2001 From: Ilya Matiach Date: Tue, 14 Nov 2023 14:32:57 -0500 Subject: [PATCH] add average parameter to MeanAveragePrecision to specify micro or macro calculation (#2412) --- .../rai_vision_insights/rai_vision_insights.py | 1 + .../tests/test_rai_vision_insights.py | 11 ++++++----- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/responsibleai_vision/responsibleai_vision/rai_vision_insights/rai_vision_insights.py b/responsibleai_vision/responsibleai_vision/rai_vision_insights/rai_vision_insights.py index 72b5814361..971dbf4df3 100644 --- a/responsibleai_vision/responsibleai_vision/rai_vision_insights/rai_vision_insights.py +++ b/responsibleai_vision/responsibleai_vision/rai_vision_insights/rai_vision_insights.py @@ -1191,6 +1191,7 @@ def compute_object_detection_metrics( continue metric_OD = MeanAveragePrecision( + average=aggregate_method.lower(), class_metrics=True, iou_thresholds=normalized_iou_threshold).to(device) true_y_cohort = [true_y[cohort_index] for cohort_index diff --git a/responsibleai_vision/tests/test_rai_vision_insights.py b/responsibleai_vision/tests/test_rai_vision_insights.py index a8da6e5b94..434aee6c2b 100644 --- a/responsibleai_vision/tests/test_rai_vision_insights.py +++ b/responsibleai_vision/tests/test_rai_vision_insights.py @@ -332,11 +332,12 @@ def run_rai_insights(model, test_data, target_column, ignore_index) if task_type == ModelTask.OBJECT_DETECTION: selection_indexes = [[0]] - aggregate_method = 'Macro' class_name = classes[0] iou_threshold = 70 object_detection_cache = {} - metrics = rai_insights.compute_object_detection_metrics( - selection_indexes, aggregate_method, class_name, iou_threshold, - object_detection_cache) - assert len(metrics) == 2 + aggregate_methods = ['macro', 'micro'] + for aggregate_method in aggregate_methods: + metrics = rai_insights.compute_object_detection_metrics( + selection_indexes, aggregate_method, class_name, iou_threshold, + object_detection_cache) + assert len(metrics) == 2