Commit: Cache (#184)
* refactor

* working integ

* lint

* fix typo

* test fixes

* lint

* temp

* integ

* parametrize tests

* rename

* handle exceptions inside flush worker

* mid

* fix reuse

* flushing

* use changed items list instead of time lookup

* check if running loop exists

* add test

* fix test

* add flush interval enum

* update test and doc

* code review

* make flush interval an optional[int] and init_flush_task only from async code

* update doc

Co-authored-by: Dina Nimrodi <[email protected]>
dinal and Dina Nimrodi authored Apr 12, 2021
1 parent 122fd04 commit 6aefc07
Showing 6 changed files with 264 additions and 179 deletions.
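
In short: `Table` gains a `flush_interval_secs` option, so writes made through the cache are persisted by a periodic flush task instead of on every event. A minimal usage sketch, modeled on the new `test_cache_flushing` integration test in this commit; the table path and the 3-second interval here are illustrative only:

```python
from storey import Source, Table, V3ioDriver, WriteToTable, build_flow

# Illustrative path; any v3io container/path the driver can write to.
table = Table('bigdata/cache-example', V3ioDriver(), flush_interval_secs=3)

controller = build_flow([
    Source(),
    WriteToTable(table),  # updates land in the in-memory cache first
]).run()

controller.emit({'col1': 0}, 'key1')  # not yet visible in the KV store
# ...roughly flush_interval_secs later, the flush task persists the changed keys...

controller.terminate()
controller.await_termination()
```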
29 changes: 17 additions & 12 deletions integration/test_aggregation_integration.py
@@ -10,12 +10,13 @@
 from storey.dtypes import SlidingWindows, FixedWindows
 from storey.utils import _split_path

-from .integration_test_utils import setup_teardown_test, append_return, test_base_time
+from .integration_test_utils import setup_teardown_test, append_return, test_base_time, V3ioHeaders


 @pytest.mark.parametrize('partitioned_by_key', [True, False])
-def test_aggregate_and_query_with_different_windows(setup_teardown_test, partitioned_by_key):
-    table = Table(setup_teardown_test, V3ioDriver(), partitioned_by_key=partitioned_by_key)
+@pytest.mark.parametrize('flush_interval', [None, 1])
+def test_aggregate_and_query_with_different_windows(setup_teardown_test, partitioned_by_key, flush_interval):
+    table = Table(setup_teardown_test, V3ioDriver(), partitioned_by_key=partitioned_by_key, flush_interval_secs=flush_interval)

     controller = build_flow([
         Source(),
@@ -167,8 +168,9 @@ def test_query_virtual_aggregations_flow(setup_teardown_test):


 @pytest.mark.parametrize('partitioned_by_key', [True, False])
-def test_query_aggregate_by_key(setup_teardown_test, partitioned_by_key):
-    table = Table(setup_teardown_test, V3ioDriver(), partitioned_by_key=partitioned_by_key)
+@pytest.mark.parametrize('flush_interval', [None, 1])
+def test_query_aggregate_by_key(setup_teardown_test, partitioned_by_key, flush_interval):
+    table = Table(setup_teardown_test, V3ioDriver(), partitioned_by_key=partitioned_by_key, flush_interval_secs=flush_interval)

     controller = build_flow([
         Source(),
@@ -330,7 +332,8 @@ def test_aggregate_and_query_with_dependent_aggrs_different_windows(setup_teardo


 @pytest.mark.parametrize('partitioned_by_key', [True, False])
-def test_aggregate_by_key_one_underlying_window(setup_teardown_test, partitioned_by_key):
+@pytest.mark.parametrize('flush_interval', [None, 1])
+def test_aggregate_by_key_one_underlying_window(setup_teardown_test, partitioned_by_key, flush_interval):
     expected = {1: [{'number_of_stuff_count_1h': 1, 'other_stuff_sum_1h': 0.0, 'col1': 0},
                     {'number_of_stuff_count_1h': 2, 'other_stuff_sum_1h': 1.0, 'col1': 1},
                     {'number_of_stuff_count_1h': 3, 'other_stuff_sum_1h': 3.0, 'col1': 2}],
@@ -346,7 +349,7 @@ def test_aggregate_by_key_one_underlying_window(setup_teardown_test, partitioned

     for current_expected in expected.values():

-        table = Table(setup_teardown_test, V3ioDriver(), partitioned_by_key=partitioned_by_key)
+        table = Table(setup_teardown_test, V3ioDriver(), partitioned_by_key=partitioned_by_key, flush_interval_secs=flush_interval)
         controller = build_flow([
             Source(),
             AggregateByKey([FieldAggregator("number_of_stuff", "col1", ["count"],
@@ -490,8 +493,9 @@ def enrich(event, state):
         f'actual did not match expected. \n actual: {actual} \n expected: {expected_results}'


-def test_write_cache_with_aggregations(setup_teardown_test):
-    table = Table(setup_teardown_test, V3ioDriver())
+@pytest.mark.parametrize('flush_interval', [None, 1])
+def test_write_cache_with_aggregations(setup_teardown_test, flush_interval):
+    table = Table(setup_teardown_test, V3ioDriver(), flush_interval_secs=flush_interval)

     table['tal'] = {'color': 'blue', 'age': 41, 'iss': True, 'sometime': test_base_time}

@@ -547,7 +551,7 @@ def enrich(event, state):
     assert actual == expected_results, \
         f'actual did not match expected. \n actual: {actual} \n expected: {expected_results}'

-    other_table = Table(setup_teardown_test, V3ioDriver())
+    other_table = Table(setup_teardown_test, V3ioDriver(), flush_interval_secs=flush_interval)

     controller = build_flow([
         Source(),
@@ -570,8 +574,9 @@ def enrich(event, state):
         f'actual did not match expected. \n actual: {actual} \n expected: {expected_results}'


-def test_write_cache(setup_teardown_test):
-    table = Table(setup_teardown_test, V3ioDriver())
+@pytest.mark.parametrize('flush_interval', [None, 1])
+def test_write_cache(setup_teardown_test, flush_interval):
+    table = Table(setup_teardown_test, V3ioDriver(), flush_interval_secs=flush_interval)

     table['tal'] = {'color': 'blue', 'age': 41, 'iss': True, 'sometime': datetime.now()}

24 changes: 22 additions & 2 deletions integration/test_flow_integration.py
@@ -11,8 +11,8 @@
 import v3io_frames as frames

 from storey import Filter, JoinWithV3IOTable, SendToHttp, Map, Reduce, Source, HttpRequest, build_flow, \
-    WriteToV3IOStream, V3ioDriver, WriteToTSDB, Table, JoinWithTable, MapWithState, WriteToTable, DataframeSource, ReduceToDataFrame, \
-    QueryByKey, AggregateByKey, ReadCSV
+    WriteToV3IOStream, V3ioDriver, WriteToTSDB, Table, JoinWithTable, MapWithState, WriteToTable, DataframeSource, \
+    ReadCSV
 from storey.utils import hash_list
 from .integration_test_utils import V3ioHeaders, append_return, test_base_time, setup_kv_teardown_test, setup_teardown_test, \
     setup_stream_teardown_test
@@ -524,3 +524,23 @@ def set_moshe_time_to_none(data):
     expected = {'first_name': 'moshe', 'color': 'blue'}
     assert response.status_code == 200
     assert expected == response.output.item
+
+
+def test_cache_flushing(setup_teardown_test):
+    table = Table(setup_teardown_test, V3ioDriver(), flush_interval_secs=3)
+    controller = build_flow([
+        Source(),
+        WriteToTable(table),
+    ]).run()
+
+    controller.emit({'col1': 0}, 'dina', test_base_time + timedelta(minutes=25))
+    response = asyncio.run(get_kv_item(setup_teardown_test, 'dina')).output.item
+    assert response == {}
+    time.sleep(4)
+
+    response = asyncio.run(get_kv_item(setup_teardown_test, 'dina')).output.item
+    assert response == {'col1': 0}
+
+    controller.terminate()
+    controller.await_termination()

136 changes: 1 addition & 135 deletions storey/flow.py
@@ -401,6 +401,7 @@ async def _call(self, event):
             key_data = self._state[event.key]
             res, new_state = self._fn(element, key_data)
             self._state._set_static_attrs(event.key, new_state)
+            self._state._init_flush_task()
         else:
             res, self._state = self._fn(element, self._state)
         if self._is_async:
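
The added `self._state._init_flush_task()` call lines up with the commit-message notes "check if running loop exists" and "init_flush_task only from async code": a periodic flush worker needs a running event loop, so it is started lazily from an async step rather than eagerly at construction time. The following is only an illustrative sketch of that pattern; the class and method bodies are hypothetical and are not the actual `Table` implementation, which this diff does not show:

```python
import asyncio

class CachedTableSketch:
    """Hypothetical cache that flushes changed keys on an interval."""

    def __init__(self, flush_interval_secs=None):
        self._flush_interval_secs = flush_interval_secs  # None disables periodic flushing
        self._flush_task = None
        self._changed_keys = set()

    def _init_flush_task(self):
        # Start the worker only if periodic flushing was requested and
        # we are currently inside a running event loop.
        if self._flush_task is not None or self._flush_interval_secs is None:
            return
        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            return  # called from sync code; a later async call will start it
        self._flush_task = loop.create_task(self._flush_worker())

    async def _flush_worker(self):
        while True:
            await asyncio.sleep(self._flush_interval_secs)
            changed, self._changed_keys = self._changed_keys, set()
            for key in changed:
                await self._persist(key)  # hypothetical persistence hook

    async def _persist(self, key):
        pass  # stand-in for writing one key to the backing store
```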
@@ -640,141 +641,6 @@ async def _do(self, event):
                 await self._worker_awaitable


-class _PendingEvent:
-    def __init__(self):
-        self.in_flight = []
-        self.pending = []
-
-
-class _ConcurrentByKeyJobExecution(Flow):
-    def __init__(self, max_in_flight=8, **kwargs):
-        kwargs['max_in_flight'] = max_in_flight
-        Flow.__init__(self, **kwargs)
-        self._max_in_flight = max_in_flight
-
-    def _init(self):
-        super()._init()
-        self._q = None
-        self._pending_by_key = {}
-
-    async def _worker(self):
-        event = None
-        received_job_count = 0
-        self_sent_jobs = {}
-        try:
-            while True:
-                jobs = self_sent_jobs.pop(received_job_count, None)
-                if jobs:
-                    job = jobs[0]
-                    if len(jobs) > 1:
-                        self_sent_jobs[received_job_count] = jobs[1:]
-                else:
-                    job = await self._q.get()
-                    received_job_count += 1
-                    if job is _termination_obj:
-                        if received_job_count in self_sent_jobs:
-                            await self._q.put(_termination_obj)
-                            continue
-                        for pending_event in self._pending_by_key.values():
-                            if pending_event.pending and not pending_event.in_flight:
-                                resp = await self._safe_process_events(pending_event.pending)
-                                for event in pending_event.pending:
-                                    await self._handle_completed(event, resp)
-                        break
-
-                event = job[0]
-                completed = await job[1]
-
-                if isinstance(event.key, list):
-                    event_key = str(event.key)
-                else:
-                    event_key = event.key
-
-                for event in self._pending_by_key[event_key].in_flight:
-                    await self._handle_completed(event, completed)
-                self._pending_by_key[event_key].in_flight = []
-
-                # If we got more pending events for the same key process them
-                if self._pending_by_key[event_key].pending:
-                    self._pending_by_key[event_key].in_flight = self._pending_by_key[event_key].pending
-                    self._pending_by_key[event_key].pending = []
-
-                    task = self._safe_process_events(self._pending_by_key[event_key].in_flight)
-                    tail_position = received_job_count + self._q.qsize()
-                    jobs_at_tail = self_sent_jobs.get(tail_position, [])
-                    jobs_at_tail.append((event, asyncio.get_running_loop().create_task(task)))
-                    self_sent_jobs[tail_position] = jobs_at_tail
-                else:
-                    del self._pending_by_key[event_key]
-        except BaseException as ex:
-            if event and event is not _termination_obj and event._awaitable_result:
-                event._awaitable_result._set_error(ex)
-            if not self._q.empty():
-                await self._q.get()
-            raise ex
-        finally:
-            await self._cleanup()
-
-    async def _do(self, event):
-        if not self._q:
-            await self._lazy_init()
-            self._q = asyncio.queues.Queue(self._max_in_flight)
-            self._worker_awaitable = asyncio.get_running_loop().create_task(self._worker())
-
-        if self._worker_awaitable.done():
-            await self._worker_awaitable
-            raise FlowError("ConcurrentByKeyJobExecution worker has already terminated")
-
-        if event is _termination_obj:
-            await self._q.put(_termination_obj)
-            await self._worker_awaitable
-            return await self._do_downstream(_termination_obj)
-        else:
-            # Initializing the key with 2 lists. One for pending requests and one for requests that an update request has been issued for.
-            if isinstance(event.key, list):
-                # list can't be key in a dictionary
-                event_key = str(event.key)
-            else:
-                event_key = event.key
-
-            if event_key not in self._pending_by_key:
-                self._pending_by_key[event_key] = _PendingEvent()
-
-            # If there is a current update in flight for the key, add the event to the pending list. Otherwise update the key.
-            self._pending_by_key[event_key].pending.append(event)
-            if len(self._pending_by_key[event_key].in_flight) == 0:
-                self._pending_by_key[event_key].in_flight = self._pending_by_key[event_key].pending
-                self._pending_by_key[event_key].pending = []
-
-                task = self._safe_process_events(self._pending_by_key[event_key].in_flight)
-                await self._q.put((event, asyncio.get_running_loop().create_task(task)))
-                if self._worker_awaitable.done():
-                    await self._worker_awaitable
-
-    async def _safe_process_events(self, events):
-        try:
-            return await self._process_events(events)
-        except BaseException as ex:
-            for event in events:
-                if event._awaitable_result:
-                    none_or_coroutine = event._awaitable_result._set_error(ex)
-                    if none_or_coroutine:
-                        await none_or_coroutine
-            raise ex
-
-    async def _process_events(self, events):
-        raise NotImplementedError()
-
-    async def _handle_completed(self, event, response):
-        raise NotImplementedError()
-
-    async def _cleanup(self):
-        pass
-
-    async def _lazy_init(self):
-        pass
-

 class SendToHttp(_ConcurrentJobExecution):
     """Joins each event with data from any HTTP source. Used for event augmentation.