diff --git a/integration/test_flow_integration.py b/integration/test_flow_integration.py index 9beec3bc..61c21dec 100644 --- a/integration/test_flow_integration.py +++ b/integration/test_flow_integration.py @@ -410,6 +410,7 @@ def test_writing_int_key(setup_teardown_test): ]).run() controller.await_termination() + def test_writing_timedelta_key(setup_teardown_test): table = Table(setup_teardown_test, V3ioDriver()) @@ -456,15 +457,42 @@ def test_write_multiple_keys_to_v3io_from_csv(setup_teardown_test): controller.await_termination() response = asyncio.run(get_kv_item(setup_teardown_test, '1.2')) + expected = {'n1': 1, 'n2': 2, 'n3': 3} + assert response.status_code == 200 + assert expected == response.output.item + + response = asyncio.run(get_kv_item(setup_teardown_test, '4.5')) + expected = {'n1': 4, 'n2': 5, 'n3': 6} + assert response.status_code == 200 + assert expected == response.output.item + + +def test_write_multiple_keys_to_v3io(setup_teardown_test): + table = Table(setup_teardown_test, V3ioDriver()) + + controller = build_flow([ + Source(key_field=['n1', 'n2']), + WriteToTable(table), + ]).run() + controller.emit({'n1': 1, 'n2': 2, 'n3': 3}) + controller.emit({'n1': 4, 'n2': 5, 'n3': 6}) + + controller.terminate() + controller.await_termination() + + response = asyncio.run(get_kv_item(setup_teardown_test, '1.2')) expected = {'n1': 1, 'n2': 2, 'n3': 3} + assert response.status_code == 200 + assert expected == response.output.item + response = asyncio.run(get_kv_item(setup_teardown_test, '4.5')) + expected = {'n1': 4, 'n2': 5, 'n3': 6} assert response.status_code == 200 assert expected == response.output.item def test_write_none_time(setup_teardown_test): - table = Table(setup_teardown_test, V3ioDriver()) data = pd.DataFrame( { diff --git a/storey/sources.py b/storey/sources.py index d021d2c6..8eb0ba73 100644 --- a/storey/sources.py +++ b/storey/sources.py @@ -5,7 +5,7 @@ import threading import uuid import warnings -from datetime import datetime, timezone +from datetime import datetime from typing import List, Optional, Union, Callable, Coroutine, Iterable import pandas @@ -57,25 +57,29 @@ def _get_uuid(self): return result def _build_event(self, element, key, event_time): - if event_time is None and self._time_field is None: - event_time = datetime.now(timezone.utc) - if hasattr(element, 'id'): - event = element + body = element + element_is_event = hasattr(element, 'id') + if element_is_event: + body = element.body + + if not key and self._key_field: + if isinstance(self._key_field, str): + key = body[self._key_field] + else: + key = [] + for field in self._key_field: + key.append(body[field]) + if not event_time and self._time_field: + event_time = body[self._time_field] + + if element_is_event: if key: - event.key = key - elif self._key_field: - event.key = event.body[self._key_field] + element.key = key if event_time: - event.time = event_time - elif self._time_field: - event.time = event.body[self._time_field] + element.time = event_time + return element else: - if not key and self._key_field: - key = element[self._key_field] - if not event_time and self._time_field: - event_time = element[self._time_field] - event = Event(element, id=self._get_uuid(), key=key, time=event_time) - return event + return Event(body, id=self._get_uuid(), key=key, time=event_time) class FlowController(FlowControllerBase): @@ -150,7 +154,7 @@ class Source(Flow): """ _legal_first_step = True - def __init__(self, buffer_size: Optional[int] = None, key_field: Optional[str] = None, time_field: Optional[str] = None, + def __init__(self, buffer_size: Optional[int] = None, key_field: Union[list, str, None] = None, time_field: Optional[str] = None, **kwargs): if buffer_size is None: buffer_size = 1024 @@ -317,8 +321,7 @@ class AsyncSource(Flow): """ _legal_first_step = True - def __init__(self, buffer_size: int = 1024, key_field: Optional[str] = None, time_field: Optional[str] = None, - **kwargs): + def __init__(self, buffer_size: int = 1024, key_field: Union[list, str, None] = None, time_field: Optional[str] = None, **kwargs): super().__init__(**kwargs) if buffer_size <= 0: raise ValueError('Buffer size must be positive')