diff --git a/responsibleai_vision/responsibleai_vision/rai_vision_insights/rai_vision_insights.py b/responsibleai_vision/responsibleai_vision/rai_vision_insights/rai_vision_insights.py index 72b5814361..971dbf4df3 100644 --- a/responsibleai_vision/responsibleai_vision/rai_vision_insights/rai_vision_insights.py +++ b/responsibleai_vision/responsibleai_vision/rai_vision_insights/rai_vision_insights.py @@ -1191,6 +1191,7 @@ def compute_object_detection_metrics( continue metric_OD = MeanAveragePrecision( + average=aggregate_method.lower(), class_metrics=True, iou_thresholds=normalized_iou_threshold).to(device) true_y_cohort = [true_y[cohort_index] for cohort_index diff --git a/responsibleai_vision/tests/test_rai_vision_insights.py b/responsibleai_vision/tests/test_rai_vision_insights.py index a8da6e5b94..434aee6c2b 100644 --- a/responsibleai_vision/tests/test_rai_vision_insights.py +++ b/responsibleai_vision/tests/test_rai_vision_insights.py @@ -332,11 +332,12 @@ def run_rai_insights(model, test_data, target_column, ignore_index) if task_type == ModelTask.OBJECT_DETECTION: selection_indexes = [[0]] - aggregate_method = 'Macro' class_name = classes[0] iou_threshold = 70 object_detection_cache = {} - metrics = rai_insights.compute_object_detection_metrics( - selection_indexes, aggregate_method, class_name, iou_threshold, - object_detection_cache) - assert len(metrics) == 2 + aggregate_methods = ['macro', 'micro'] + for aggregate_method in aggregate_methods: + metrics = rai_insights.compute_object_detection_metrics( + selection_indexes, aggregate_method, class_name, iou_threshold, + object_detection_cache) + assert len(metrics) == 2