Skip to content

Commit

Permalink
Support list key_field in Source. (#197)
Browse files Browse the repository at this point in the history
Co-authored-by: Gal Topper <[email protected]>
  • Loading branch information
Gal Topper and Gal Topper authored Apr 8, 2021
1 parent 9a7e12b commit 122fd04
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 21 deletions.
30 changes: 29 additions & 1 deletion integration/test_flow_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,6 +410,7 @@ def test_writing_int_key(setup_teardown_test):
]).run()
controller.await_termination()


def test_writing_timedelta_key(setup_teardown_test):
table = Table(setup_teardown_test, V3ioDriver())

Expand Down Expand Up @@ -456,15 +457,42 @@ def test_write_multiple_keys_to_v3io_from_csv(setup_teardown_test):
controller.await_termination()

response = asyncio.run(get_kv_item(setup_teardown_test, '1.2'))
expected = {'n1': 1, 'n2': 2, 'n3': 3}
assert response.status_code == 200
assert expected == response.output.item

response = asyncio.run(get_kv_item(setup_teardown_test, '4.5'))
expected = {'n1': 4, 'n2': 5, 'n3': 6}
assert response.status_code == 200
assert expected == response.output.item


def test_write_multiple_keys_to_v3io(setup_teardown_test):
    """Write events keyed by a list of fields and read them back from the table.

    The source is configured with ``key_field=['n1', 'n2']``. The test expects
    each emitted body to be retrievable under the item keys ``'1.2'`` and
    ``'4.5'`` respectively.
    """
    table = Table(setup_teardown_test, V3ioDriver())
    steps = [
        Source(key_field=['n1', 'n2']),
        WriteToTable(table),
    ]
    controller = build_flow(steps).run()

    controller.emit({'n1': 1, 'n2': 2, 'n3': 3})
    controller.emit({'n1': 4, 'n2': 5, 'n3': 6})

    # Terminate the flow first, then read the items back.
    controller.terminate()
    controller.await_termination()

    # (item key, expected stored attributes) for each emitted event.
    expected_items = (
        ('1.2', {'n1': 1, 'n2': 2, 'n3': 3}),
        ('4.5', {'n1': 4, 'n2': 5, 'n3': 6}),
    )
    for item_key, expected in expected_items:
        response = asyncio.run(get_kv_item(setup_teardown_test, item_key))
        assert response.status_code == 200
        assert expected == response.output.item


def test_write_none_time(setup_teardown_test):

table = Table(setup_teardown_test, V3ioDriver())
data = pd.DataFrame(
{
Expand Down
43 changes: 23 additions & 20 deletions storey/sources.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import threading
import uuid
import warnings
from datetime import datetime, timezone
from datetime import datetime
from typing import List, Optional, Union, Callable, Coroutine, Iterable

import pandas
Expand Down Expand Up @@ -57,25 +57,29 @@ def _get_uuid(self):
return result

def _build_event(self, element, key, event_time):
if event_time is None and self._time_field is None:
event_time = datetime.now(timezone.utc)
if hasattr(element, 'id'):
event = element
body = element
element_is_event = hasattr(element, 'id')
if element_is_event:
body = element.body

if not key and self._key_field:
if isinstance(self._key_field, str):
key = body[self._key_field]
else:
key = []
for field in self._key_field:
key.append(body[field])
if not event_time and self._time_field:
event_time = body[self._time_field]

if element_is_event:
if key:
event.key = key
elif self._key_field:
event.key = event.body[self._key_field]
element.key = key
if event_time:
event.time = event_time
elif self._time_field:
event.time = event.body[self._time_field]
element.time = event_time
return element
else:
if not key and self._key_field:
key = element[self._key_field]
if not event_time and self._time_field:
event_time = element[self._time_field]
event = Event(element, id=self._get_uuid(), key=key, time=event_time)
return event
return Event(body, id=self._get_uuid(), key=key, time=event_time)


class FlowController(FlowControllerBase):
Expand Down Expand Up @@ -150,7 +154,7 @@ class Source(Flow):
"""
_legal_first_step = True

def __init__(self, buffer_size: Optional[int] = None, key_field: Optional[str] = None, time_field: Optional[str] = None,
def __init__(self, buffer_size: Optional[int] = None, key_field: Union[list, str, None] = None, time_field: Optional[str] = None,
**kwargs):
if buffer_size is None:
buffer_size = 1024
Expand Down Expand Up @@ -317,8 +321,7 @@ class AsyncSource(Flow):
"""
_legal_first_step = True

def __init__(self, buffer_size: int = 1024, key_field: Optional[str] = None, time_field: Optional[str] = None,
**kwargs):
def __init__(self, buffer_size: int = 1024, key_field: Union[list, str, None] = None, time_field: Optional[str] = None, **kwargs):
super().__init__(**kwargs)
if buffer_size <= 0:
raise ValueError('Buffer size must be positive')
Expand Down

0 comments on commit 122fd04

Please sign in to comment.