Skip to content

Commit

Permalink
Merge pull request #1589 from souhhmm/feat/detection-metadata
Browse files Browse the repository at this point in the history
Feat: Added detection metadata
  • Loading branch information
LinasKo authored Nov 4, 2024
2 parents d388857 + 9ee1b5d commit 11355c5
Show file tree
Hide file tree
Showing 3 changed files with 101 additions and 3 deletions.
16 changes: 16 additions & 0 deletions supervision/detection/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,10 @@
extract_ultralytics_masks,
get_data_item,
is_data_equal,
is_metadata_equal,
mask_to_xyxy,
merge_data,
merge_metadata,
process_roboflow_result,
xywh_to_xyxy,
)
Expand Down Expand Up @@ -125,6 +127,9 @@ class simplifies data manipulation and filtering, providing a uniform API for
data (Dict[str, Union[np.ndarray, List]]): A dictionary containing additional
data where each key is a string representing the data type, and the value
is either a NumPy array or a list of corresponding data.
metadata (Dict[str, Any]): A dictionary containing collection-level metadata
that applies to the entire set of detections. This may include information such
as the video name, camera parameters, timestamp, or other global metadata.
""" # noqa: E501 // docs

xyxy: np.ndarray
Expand All @@ -133,6 +138,7 @@ class simplifies data manipulation and filtering, providing a uniform API for
class_id: Optional[np.ndarray] = None
tracker_id: Optional[np.ndarray] = None
data: Dict[str, Union[np.ndarray, List]] = field(default_factory=dict)
metadata: Dict[str, Any] = field(default_factory=dict)

def __post_init__(self):
validate_detections_fields(
Expand Down Expand Up @@ -185,6 +191,7 @@ def __eq__(self, other: Detections):
np.array_equal(self.confidence, other.confidence),
np.array_equal(self.tracker_id, other.tracker_id),
is_data_equal(self.data, other.data),
is_metadata_equal(self.metadata, other.metadata),
]
)

Expand Down Expand Up @@ -985,6 +992,7 @@ def is_empty(self) -> bool:
"""
empty_detections = Detections.empty()
empty_detections.data = self.data
empty_detections.metadata = self.metadata
return self == empty_detections

@classmethod
Expand Down Expand Up @@ -1078,13 +1086,17 @@ def stack_or_none(name: str):

data = merge_data([d.data for d in detections_list])

metadata_list = [detections.metadata for detections in detections_list]
metadata = merge_metadata(metadata_list)

return cls(
xyxy=xyxy,
mask=mask,
confidence=confidence,
class_id=class_id,
tracker_id=tracker_id,
data=data,
metadata=metadata,
)

def get_anchors_coordinates(self, anchor: Position) -> np.ndarray:
Expand Down Expand Up @@ -1198,6 +1210,7 @@ def __getitem__(
class_id=self.class_id[index] if self.class_id is not None else None,
tracker_id=self.tracker_id[index] if self.tracker_id is not None else None,
data=get_data_item(self.data, index),
metadata=self.metadata,
)

def __setitem__(self, key: str, value: Union[np.ndarray, List]):
Expand Down Expand Up @@ -1459,13 +1472,16 @@ def merge_inner_detection_object_pair(
else:
winning_detection = detections_2

metadata = merge_metadata([detections_1.metadata, detections_2.metadata])

return Detections(
xyxy=merged_xyxy,
mask=merged_mask,
confidence=merged_confidence,
class_id=winning_detection.class_id,
tracker_id=winning_detection.tracker_id,
data=winning_detection.data,
metadata=metadata,
)


Expand Down
65 changes: 64 additions & 1 deletion supervision/detection/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from itertools import chain
from typing import Dict, List, Optional, Tuple, Union
from typing import Any, Dict, List, Optional, Tuple, Union

import cv2
import numpy as np
Expand Down Expand Up @@ -807,12 +807,36 @@ def is_data_equal(data_a: Dict[str, np.ndarray], data_b: Dict[str, np.ndarray])
)


def is_metadata_equal(metadata_a: Dict[str, Any], metadata_b: Dict[str, Any]) -> bool:
"""
Compares the metadata payloads of two Detections instances.
Args:
metadata_a, metadata_b: The metadata payloads of the instances.
Returns:
True if the metadata payloads are equal, False otherwise.
"""
return set(metadata_a.keys()) == set(metadata_b.keys()) and all(
np.array_equal(metadata_a[key], metadata_b[key])
if (
isinstance(metadata_a[key], np.ndarray)
and isinstance(metadata_b[key], np.ndarray)
)
else metadata_a[key] == metadata_b[key]
for key in metadata_a
)


def merge_data(
data_list: List[Dict[str, Union[npt.NDArray[np.generic], List]]],
) -> Dict[str, Union[npt.NDArray[np.generic], List]]:
"""
Merges the data payloads of a list of Detections instances.
Warning: Assumes that empty detections were filtered-out before passing data to
this function.
Args:
data_list: The data payloads of the Detections instances. Each data payload
is a dictionary with the same keys, and the values are either lists or
Expand Down Expand Up @@ -865,6 +889,45 @@ def merge_data(
return merged_data


def merge_metadata(metadata_list: List[Dict[str, Any]]) -> Dict[str, Any]:
"""
Merge metadata from a list of metadata dictionaries.
This function combines the metadata dictionaries. If a key appears in more than one
dictionary, the values must be identical for the merge to succeed.
Warning: Assumes that empty detections were filtered-out before passing metadata to
this function.
Args:
metadata_list (List[Dict[str, Any]]): A list of metadata dictionaries to merge.
Returns:
Dict[str, Any]: A single merged metadata dictionary.
Raises:
ValueError: If there are conflicting values for the same key or if
dictionaries have different keys.
"""
if not metadata_list:
return {}

all_keys_sets = [set(metadata.keys()) for metadata in metadata_list]
if not all(keys_set == all_keys_sets[0] for keys_set in all_keys_sets):
raise ValueError("All metadata dictionaries must have the same keys to merge.")

merged_metadata: Dict[str, Any] = {}
for metadata in metadata_list:
for key, value in metadata.items():
if key in merged_metadata:
if merged_metadata[key] != value:
raise ValueError(f"Conflicting metadata for key: '{key}'.")
else:
merged_metadata[key] = value

return merged_metadata


def get_data_item(
data: Dict[str, Union[np.ndarray, List]],
index: Union[int, slice, List[int], np.ndarray],
Expand Down
23 changes: 21 additions & 2 deletions test/utils/test_internal.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,15 @@ def __private_property(self):
(
Detections.empty(),
False,
{"xyxy", "class_id", "confidence", "mask", "tracker_id", "data"},
{
"xyxy",
"class_id",
"confidence",
"mask",
"tracker_id",
"data",
"metadata",
},
DoesNotRaise(),
),
(
Expand All @@ -134,6 +142,7 @@ def __private_property(self):
"mask",
"tracker_id",
"data",
"metadata",
"area",
"box_area",
},
Expand All @@ -149,6 +158,7 @@ def __private_property(self):
"mask",
"tracker_id",
"data",
"metadata",
},
DoesNotRaise(),
),
Expand All @@ -169,13 +179,22 @@ def __private_property(self):
"mask",
"tracker_id",
"data",
"metadata",
},
DoesNotRaise(),
),
(
Detections.empty(),
False,
{"xyxy", "class_id", "confidence", "mask", "tracker_id", "data"},
{
"xyxy",
"class_id",
"confidence",
"mask",
"tracker_id",
"data",
"metadata",
},
DoesNotRaise(),
),
],
Expand Down

0 comments on commit 11355c5

Please sign in to comment.