From e4b6596326fb01fd9d9b665962ccb5c83dc557ad Mon Sep 17 00:00:00 2001 From: Devin Robison Date: Tue, 18 Feb 2025 14:39:17 -0700 Subject: [PATCH] Add IngestControlMessage class to nv_ingest_api library (#455) --- api/__init__.py | 0 api/src/nv_ingest_api/__init__.py | 0 api/src/nv_ingest_api/primitives/__init__.py | 0 .../primitives/control_message_task.py | 10 + .../primitives/ingest_control_message.py | 216 ++++++++++++ tests/nv_ingest_api/__init__.py | 0 tests/nv_ingest_api/primitives/__init__.py | 0 .../primitives/test_ingest_control_message.py | 330 ++++++++++++++++++ .../test_ingest_control_message_task.py | 64 ++++ 9 files changed, 620 insertions(+) create mode 100644 api/__init__.py create mode 100644 api/src/nv_ingest_api/__init__.py create mode 100644 api/src/nv_ingest_api/primitives/__init__.py create mode 100644 api/src/nv_ingest_api/primitives/control_message_task.py create mode 100644 api/src/nv_ingest_api/primitives/ingest_control_message.py create mode 100644 tests/nv_ingest_api/__init__.py create mode 100644 tests/nv_ingest_api/primitives/__init__.py create mode 100644 tests/nv_ingest_api/primitives/test_ingest_control_message.py create mode 100644 tests/nv_ingest_api/primitives/test_ingest_control_message_task.py diff --git a/api/__init__.py b/api/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/api/src/nv_ingest_api/__init__.py b/api/src/nv_ingest_api/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/api/src/nv_ingest_api/primitives/__init__.py b/api/src/nv_ingest_api/primitives/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/api/src/nv_ingest_api/primitives/control_message_task.py b/api/src/nv_ingest_api/primitives/control_message_task.py new file mode 100644 index 00000000..c0d249ff --- /dev/null +++ b/api/src/nv_ingest_api/primitives/control_message_task.py @@ -0,0 +1,10 @@ +from pydantic import BaseModel, Field, ConfigDict +from typing import Any, Dict + + +class ControlMessageTask(BaseModel): + model_config = ConfigDict(extra="forbid") + + name: str + id: str + properties: Dict[str, Any] = Field(default_factory=dict) diff --git a/api/src/nv_ingest_api/primitives/ingest_control_message.py b/api/src/nv_ingest_api/primitives/ingest_control_message.py new file mode 100644 index 00000000..308c0940 --- /dev/null +++ b/api/src/nv_ingest_api/primitives/ingest_control_message.py @@ -0,0 +1,216 @@ +import copy +import re +from datetime import datetime + +import logging +import pandas as pd +from typing import Any, Dict, Generator, Union + +from nv_ingest_api.primitives.control_message_task import ControlMessageTask + + +logger = logging.getLogger(__name__) + + +class IngestControlMessage: + """ + A control message class for ingesting tasks and managing associated metadata, + timestamps, configuration, and payload. + """ + + def __init__(self): + """ + Initialize a new IngestControlMessage instance. + """ + self._tasks: Dict[str, ControlMessageTask] = {} + self._metadata: Dict[str, Any] = {} + self._timestamps: Dict[str, datetime] = {} + self._payload: pd.DataFrame = pd.DataFrame() + self._config: Dict[str, Any] = {} + + def add_task(self, task: ControlMessageTask): + """ + Add a task to the control message, keyed by the task's unique 'id'. + + Raises + ------ + ValueError + If a task with the same 'id' already exists. + """ + if task.id in self._tasks: + raise ValueError(f"Task with id '{task.id}' already exists. Tasks must be unique.") + self._tasks[task.id] = task + + def get_tasks(self) -> Generator[ControlMessageTask, None, None]: + """ + Return all tasks as a generator. + """ + yield from self._tasks.values() + + def has_task(self, task_id: str) -> bool: + """ + Check if a task with the given ID exists. + """ + return task_id in self._tasks + + def remove_task(self, task_id: str) -> None: + """ + Remove a task from the control message. Logs a warning if the task does not exist. + """ + if task_id in self._tasks: + del self._tasks[task_id] + else: + logger.warning(f"Attempted to remove non-existent task with id: {task_id}") + + def config(self, config: Dict[str, Any] = None) -> Dict[str, Any]: + """ + Get or update the control message configuration. + + If 'config' is provided, it must be a dictionary. The configuration is updated with the + provided values. If no argument is provided, returns a copy of the current configuration. + + Raises + ------ + ValueError + If the provided configuration is not a dictionary. + """ + if config is None: + return self._config.copy() + + if not isinstance(config, dict): + raise ValueError("Configuration must be provided as a dictionary.") + + self._config.update(config) + return self._config.copy() + + def copy(self) -> "IngestControlMessage": + """ + Create a deep copy of this control message. + """ + return copy.deepcopy(self) + + def get_metadata(self, key: Union[str, re.Pattern] = None, default_value: Any = None) -> Any: + """ + Retrieve metadata. If 'key' is None, returns a copy of all metadata. + + Parameters + ---------- + key : str or re.Pattern, optional + If a string is provided, returns the value for that exact key. + If a regex pattern is provided, returns a dictionary of all metadata key-value pairs + where the key matches the regex. If no matches are found, returns default_value. + default_value : Any, optional + The value to return if the key is not found or no regex matches. + + Returns + ------- + Any + The metadata value for an exact string key, or a dict of matching metadata if a regex is provided. + """ + if key is None: + return self._metadata.copy() + + # If key is a regex pattern (i.e. has a search method), perform pattern matching. + if hasattr(key, "search"): + matches = {k: v for k, v in self._metadata.items() if key.search(k)} + return matches if matches else default_value + + # Otherwise, perform an exact lookup. + return self._metadata.get(key, default_value) + + def has_metadata(self, key: Union[str, re.Pattern]) -> bool: + """ + Check if a metadata key exists. + + Parameters + ---------- + key : str or re.Pattern + If a string is provided, checks for the exact key. + If a regex pattern is provided, returns True if any metadata key matches the regex. + + Returns + ------- + bool + True if the key (or any matching key, in case of a regex) exists, False otherwise. + """ + if hasattr(key, "search"): + return any(key.search(k) for k in self._metadata) + return key in self._metadata + + def list_metadata(self) -> list: + """ + List all metadata keys. + """ + return list(self._metadata.keys()) + + def set_metadata(self, key: str, value: Any) -> None: + """ + Set a metadata key-value pair. + """ + self._metadata[key] = value + + def filter_timestamp(self, regex_filter: str) -> Dict[str, datetime]: + """ + Retrieve timestamps whose keys match the regex filter. + """ + pattern = re.compile(regex_filter) + return {key: ts for key, ts in self._timestamps.items() if pattern.search(key)} + + def get_timestamp(self, key: str, fail_if_nonexist: bool = False) -> datetime: + """ + Retrieve a timestamp for a given key. + + Raises + ------ + KeyError + If the key is not found and 'fail_if_nonexist' is True. + """ + if key in self._timestamps: + return self._timestamps[key] + if fail_if_nonexist: + raise KeyError(f"Timestamp for key '{key}' does not exist.") + return None + + def get_timestamps(self) -> Dict[str, datetime]: + """ + Retrieve all timestamps. + """ + return self._timestamps.copy() + + def set_timestamp(self, key: str, timestamp: Any) -> None: + """ + Set a timestamp for a given key. Accepts either a datetime object or an ISO format string. + + Raises + ------ + ValueError + If the provided timestamp is neither a datetime object nor a valid ISO format string. + """ + if isinstance(timestamp, datetime): + self._timestamps[key] = timestamp + elif isinstance(timestamp, str): + try: + dt = datetime.fromisoformat(timestamp) + self._timestamps[key] = dt + except ValueError as e: + raise ValueError(f"Invalid timestamp format: {timestamp}") from e + else: + raise ValueError("timestamp must be a datetime object or ISO format string") + + def payload(self, payload: pd.DataFrame = None) -> pd.DataFrame: + """ + Get or set the payload DataFrame. + + Raises + ------ + ValueError + If the provided payload is not a pandas DataFrame. + """ + if payload is None: + return self._payload + + if not isinstance(payload, pd.DataFrame): + raise ValueError("Payload must be a pandas DataFrame") + + self._payload = payload + return self._payload diff --git a/tests/nv_ingest_api/__init__.py b/tests/nv_ingest_api/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/nv_ingest_api/primitives/__init__.py b/tests/nv_ingest_api/primitives/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/nv_ingest_api/primitives/test_ingest_control_message.py b/tests/nv_ingest_api/primitives/test_ingest_control_message.py new file mode 100644 index 00000000..95c36a7a --- /dev/null +++ b/tests/nv_ingest_api/primitives/test_ingest_control_message.py @@ -0,0 +1,330 @@ +import re + +from nv_ingest_api.primitives.control_message_task import ControlMessageTask +from nv_ingest_api.primitives.ingest_control_message import IngestControlMessage + +import pytest +import pandas as pd +from datetime import datetime +from pydantic import ValidationError + + +def test_valid_task(): + data = { + "name": "Example Task", + "id": "task-123", + "properties": {"param1": "value1", "param2": 42}, + } + task = ControlMessageTask(**data) + assert task.name == "Example Task" + assert task.id == "task-123" + assert task.properties == {"param1": "value1", "param2": 42} + + +def test_valid_task_without_properties(): + data = {"name": "Minimal Task", "id": "task-456"} + task = ControlMessageTask(**data) + assert task.name == "Minimal Task" + assert task.id == "task-456" + assert task.properties == {} + + +def test_missing_required_field_name(): + data = {"id": "task-no-name", "properties": {"some_property": "some_value"}} + with pytest.raises(ValidationError) as exc_info: + ControlMessageTask(**data) + errors = exc_info.value.errors() + assert len(errors) == 1 + assert errors[0]["loc"] == ("name",) + assert errors[0]["type"] == "missing" + + +def test_missing_required_field_id(): + data = {"name": "Task With No ID", "properties": {"some_property": "some_value"}} + with pytest.raises(ValidationError) as exc_info: + ControlMessageTask(**data) + errors = exc_info.value.errors() + assert len(errors) == 1 + assert errors[0]["loc"] == ("id",) + assert errors[0]["type"] == "missing" + + +def test_extra_fields_forbidden(): + data = {"name": "Task With Extras", "id": "task-extra", "properties": {}, "unexpected_field": "foo"} + with pytest.raises(ValidationError) as exc_info: + ControlMessageTask(**data) + errors = exc_info.value.errors() + assert len(errors) == 1 + assert errors[0]["type"] == "extra_forbidden" + assert errors[0]["loc"] == ("unexpected_field",) + + +def test_properties_accepts_various_types(): + data = { + "name": "Complex Properties Task", + "id": "task-complex", + "properties": { + "string_prop": "string value", + "int_prop": 123, + "list_prop": [1, 2, 3], + "dict_prop": {"nested": True}, + }, + } + task = ControlMessageTask(**data) + assert task.properties["string_prop"] == "string value" + assert task.properties["int_prop"] == 123 + assert task.properties["list_prop"] == [1, 2, 3] + assert task.properties["dict_prop"] == {"nested": True} + + +def test_properties_with_invalid_type(): + data = {"name": "Invalid Properties Task", "id": "task-invalid-props", "properties": ["this", "should", "fail"]} + with pytest.raises(ValidationError) as exc_info: + ControlMessageTask(**data) + errors = exc_info.value.errors() + assert len(errors) == 1 + assert errors[0]["loc"] == ("properties",) + + +def test_set_and_get_metadata(): + cm = IngestControlMessage() + cm.set_metadata("key1", "value1") + # Test string lookup remains unchanged. + assert cm.get_metadata("key1") == "value1" + + +def test_get_all_metadata(): + cm = IngestControlMessage() + cm.set_metadata("key1", "value1") + cm.set_metadata("key2", "value2") + all_metadata = cm.get_metadata() + assert isinstance(all_metadata, dict) + assert all_metadata == {"key1": "value1", "key2": "value2"} + # Ensure a copy is returned. + all_metadata["key1"] = "modified" + assert cm.get_metadata("key1") == "value1" + + +def test_has_metadata(): + cm = IngestControlMessage() + cm.set_metadata("present", 123) + # Test string lookup remains unchanged. + assert cm.has_metadata("present") + assert not cm.has_metadata("absent") + + +def test_list_metadata(): + cm = IngestControlMessage() + keys = ["alpha", "beta", "gamma"] + for key in keys: + cm.set_metadata(key, key.upper()) + metadata_keys = cm.list_metadata() + assert sorted(metadata_keys) == sorted(keys) + + +def test_get_metadata_regex_match(): + """ + Validate that get_metadata returns a dict of all matching metadata entries when a regex is provided. + """ + cm = IngestControlMessage() + cm.set_metadata("alpha", 1) + cm.set_metadata("beta", 2) + cm.set_metadata("gamma", 3) + # Use a regex to match keys that start with "a" or "g". + pattern = re.compile("^(a|g)") + result = cm.get_metadata(pattern) + expected = {"alpha": 1, "gamma": 3} + assert result == expected + + +def test_get_metadata_regex_no_match(): + """ + Validate that get_metadata returns the default value when a regex is provided but no keys match. + """ + cm = IngestControlMessage() + cm.set_metadata("alpha", 1) + cm.set_metadata("beta", 2) + pattern = re.compile("z") + # Return default as an empty dict when no match is found. + result = cm.get_metadata(pattern, default_value={}) + assert result == {} + + +def test_has_metadata_regex_match(): + """ + Validate that has_metadata returns True if any metadata key matches the regex. + """ + cm = IngestControlMessage() + cm.set_metadata("key1", "value1") + cm.set_metadata("other", "value2") + assert cm.has_metadata(re.compile("^key")) + assert not cm.has_metadata(re.compile("nonexistent")) + + +def test_set_timestamp_with_datetime(): + cm = IngestControlMessage() + dt = datetime(2025, 1, 1, 12, 0, 0) + cm.set_timestamp("start", dt) + retrieved = cm.get_timestamp("start") + assert retrieved == dt + + +def test_set_timestamp_with_string(): + cm = IngestControlMessage() + iso_str = "2025-01-01T12:00:00" + dt = datetime.fromisoformat(iso_str) + cm.set_timestamp("start", iso_str) + retrieved = cm.get_timestamp("start") + assert retrieved == dt + + +def test_set_timestamp_invalid_input(): + cm = IngestControlMessage() + with pytest.raises(ValueError): + cm.set_timestamp("bad", 123) + with pytest.raises(ValueError): + cm.set_timestamp("bad", "not-a-timestamp") + + +def test_get_timestamp_nonexistent(): + cm = IngestControlMessage() + assert cm.get_timestamp("missing") is None + + +def test_get_timestamp_nonexistent_fail(): + cm = IngestControlMessage() + with pytest.raises(KeyError): + cm.get_timestamp("missing", fail_if_nonexist=True) + + +def test_get_timestamps(): + cm = IngestControlMessage() + dt1 = datetime(2025, 1, 1, 12, 0, 0) + dt2 = datetime(2025, 1, 2, 12, 0, 0) + cm.set_timestamp("start", dt1) + cm.set_timestamp("end", dt2) + timestamps = cm.get_timestamps() + assert timestamps == {"start": dt1, "end": dt2} + timestamps["start"] = datetime(2025, 1, 1, 0, 0, 0) + assert cm.get_timestamp("start") == dt1 + + +def test_filter_timestamp(): + cm = IngestControlMessage() + dt1 = datetime(2025, 1, 1, 12, 0, 0) + dt2 = datetime(2025, 1, 2, 12, 0, 0) + dt3 = datetime(2025, 1, 3, 12, 0, 0) + cm.set_timestamp("start", dt1) + cm.set_timestamp("end", dt2) + cm.set_timestamp("middle", dt3) + filtered = cm.filter_timestamp("nothing") + assert set(filtered.keys()) == set() + filtered = cm.filter_timestamp("^(s|m)") + expected_keys = {"start", "middle"} + assert set(filtered.keys()) == expected_keys + filtered_e = cm.filter_timestamp("^e") + assert set(filtered_e.keys()) == {"end"} + + +def test_remove_existing_task(): + cm = IngestControlMessage() + task = ControlMessageTask(name="Test Task", id="task1", properties={"param": "value"}) + cm.add_task(task) + assert cm.has_task("task1") + cm.remove_task("task1") + assert not cm.has_task("task1") + tasks = list(cm.get_tasks()) + assert all(t.id != "task1" for t in tasks) + + +def test_remove_nonexistent_task(): + cm = IngestControlMessage() + task = ControlMessageTask(name="Test Task", id="task1", properties={"param": "value"}) + cm.add_task(task) + cm.remove_task("nonexistent") + assert cm.has_task("task1") + tasks = list(cm.get_tasks()) + assert any(t.id == "task1" for t in tasks) + + +def test_payload_get_default(): + cm = IngestControlMessage() + payload = cm.payload() + assert isinstance(payload, pd.DataFrame) + assert payload.empty + + +def test_payload_set_valid(): + cm = IngestControlMessage() + df = pd.DataFrame({"col1": [1, 2], "col2": ["a", "b"]}) + returned_payload = cm.payload(df) + pd.testing.assert_frame_equal(returned_payload, df) + pd.testing.assert_frame_equal(cm.payload(), df) + + +def test_payload_set_invalid(): + cm = IngestControlMessage() + with pytest.raises(ValueError): + cm.payload("not a dataframe") + + +def test_config_get_default(): + cm = IngestControlMessage() + default_config = cm.config() + assert isinstance(default_config, dict) + assert default_config == {} + + +def test_config_update_valid(): + cm = IngestControlMessage() + new_config = {"setting": True, "threshold": 10} + updated_config = cm.config(new_config) + assert updated_config == new_config + additional_config = {"another_setting": "value"} + updated_config = cm.config(additional_config) + assert updated_config == {"setting": True, "threshold": 10, "another_setting": "value"} + + +def test_config_update_invalid(): + cm = IngestControlMessage() + with pytest.raises(ValueError): + cm.config("not a dict") + + +def test_copy_creates_deep_copy(): + cm = IngestControlMessage() + task = ControlMessageTask(name="Test Task", id="task1", properties={"param": "value"}) + cm.add_task(task) + cm.set_metadata("meta", "data") + dt = datetime(2025, 1, 1, 12, 0, 0) + cm.set_timestamp("start", dt) + df = pd.DataFrame({"col": [1, 2]}) + cm.payload(df) + cm.config({"config_key": "config_value"}) + + copy_cm = cm.copy() + assert copy_cm is not cm + assert list(copy_cm.get_tasks()) == list(cm.get_tasks()) + assert copy_cm.get_metadata() == cm.get_metadata() + assert copy_cm.get_timestamps() == cm.get_timestamps() + pd.testing.assert_frame_equal(copy_cm.payload(), cm.payload()) + assert copy_cm.config() == cm.config() + + copy_cm.remove_task("task1") + copy_cm.set_metadata("meta", "new_data") + copy_cm.set_timestamp("start", "2025-01-02T12:00:00") + copy_cm.payload(pd.DataFrame({"col": [3, 4]})) + copy_cm.config({"config_key": "new_config"}) + + assert cm.has_task("task1") + assert cm.get_metadata("meta") == "data" + assert cm.get_timestamp("start") == dt + pd.testing.assert_frame_equal(cm.payload(), df) + assert cm.config()["config_key"] == "config_value" + + +def test_remove_nonexistent_task_logs_warning(caplog): + cm = IngestControlMessage() + with caplog.at_level("WARNING"): + cm.remove_task("nonexistent") + assert "Attempted to remove non-existent task" in caplog.text diff --git a/tests/nv_ingest_api/primitives/test_ingest_control_message_task.py b/tests/nv_ingest_api/primitives/test_ingest_control_message_task.py new file mode 100644 index 00000000..e99fd294 --- /dev/null +++ b/tests/nv_ingest_api/primitives/test_ingest_control_message_task.py @@ -0,0 +1,64 @@ +import pytest + +from nv_ingest_api.primitives.ingest_control_message import IngestControlMessage +from nv_ingest_api.primitives.control_message_task import ControlMessageTask + + +def test_empty_control_message(): + """ + Validate that an IngestControlMessage with no tasks returns an empty list from get_tasks() + and that has_task returns False for any task id. + """ + cm = IngestControlMessage() + assert list(cm.get_tasks()) == [] + assert not cm.has_task("nonexistent") + + +def test_add_single_task(): + """ + Validate that adding a single ControlMessageTask stores the task correctly, making it retrievable + via has_task and get_tasks. + """ + cm = IngestControlMessage() + task = ControlMessageTask(name="Test Task", id="task1", properties={"key": "value"}) + cm.add_task(task) + assert cm.has_task("task1") + tasks = list(cm.get_tasks()) + assert len(tasks) == 1 + assert tasks[0] == task + + +def test_add_duplicate_task(): + """ + Validate that adding a duplicate task (same id) raises a ValueError indicating that tasks must be unique. + """ + cm = IngestControlMessage() + task = ControlMessageTask(name="Test Task", id="task1", properties={"key": "value"}) + cm.add_task(task) + duplicate_task = ControlMessageTask(name="Another Task", id="task1", properties={"key": "other"}) + with pytest.raises(ValueError) as exc_info: + cm.add_task(duplicate_task) + assert "already exists" in str(exc_info.value) + + +def test_multiple_tasks(): + """ + Validate that multiple tasks added to IngestControlMessage are stored and retrievable. + Ensures that has_task returns True for all added tasks and that get_tasks returns the correct set of tasks. + """ + cm = IngestControlMessage() + task_data = [ + {"name": "Task A", "id": "a", "properties": {}}, + {"name": "Task B", "id": "b", "properties": {"x": 10}}, + {"name": "Task C", "id": "c", "properties": {"y": 20}}, + ] + tasks = [ControlMessageTask(**data) for data in task_data] + for task in tasks: + cm.add_task(task) + for data in task_data: + assert cm.has_task(data["id"]) + retrieved_tasks = list(cm.get_tasks()) + assert len(retrieved_tasks) == len(task_data) + retrieved_ids = {t.id for t in retrieved_tasks} + expected_ids = {data["id"] for data in task_data} + assert retrieved_ids == expected_ids