ML-887: Always write outstanding streaming records when worker is idle. (#291)

* ML-887: Always write outstanding streaming records when worker is idle.

* batch_size=8
Gal Topper authored Sep 9, 2021
1 parent 308e6d9 commit 59d8380
Showing 2 changed files with 21 additions and 14 deletions.
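
A note on what the change accomplishes (an illustrative reading, not part of the commit, assuming the test fixture creates a two-shard stream): the updated test routes ten records across two shards via sharding_func=lambda event: int(event.body), so each shard buffers only five records and never reaches the new batch_size=8 default. Before this commit such a partial batch was written only once it filled up or the flow terminated; afterwards the worker also flushes it as soon as it goes idle. A quick Python sketch of that arithmetic:

# Hypothetical illustration (not from the commit): ten records split across
# two shards under the new batch_size=8 default.
batch_size = 8
shards = {0: [], 1: []}
for i in range(10):
    shards[i % 2].append(str(i).encode())  # int(event.body), reduced mod 2

for shard_id, buffered in shards.items():
    full, leftover = divmod(len(buffered), batch_size)
    # Each shard holds 5 records: zero full batches and a 5-record remainder.
    # Pre-commit, that remainder waited for termination; post-commit it is
    # written as soon as the worker has nothing else to do.
    print(f'shard {shard_id}: {full} full batches, {leftover} outstanding')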
30 changes: 19 additions & 11 deletions integration/test_flow_integration.py
@@ -12,7 +12,7 @@
 from storey import Filter, JoinWithV3IOTable, SendToHttp, Map, Reduce, SyncEmitSource, HttpRequest, build_flow, \
     StreamTarget, V3ioDriver, TSDBTarget, Table, JoinWithTable, MapWithState, NoSqlTarget, DataframeSource, \
-    CSVSource
+    CSVSource, AsyncEmitSource
 from .integration_test_utils import V3ioHeaders, append_return, test_base_time, setup_kv_teardown_test, setup_teardown_test, \
     setup_stream_teardown_test
 
@@ -84,22 +84,30 @@ def test_join_with_http():
     assert termination_result == 200 * 7
 
 
-def test_write_to_v3io_stream(setup_stream_teardown_test):
+async def async_test_write_to_v3io_stream(setup_stream_teardown_test):
     stream_path = setup_stream_teardown_test
     controller = build_flow([
-        SyncEmitSource(),
+        AsyncEmitSource(),
         Map(lambda x: str(x)),
-        StreamTarget(V3ioDriver(), stream_path, sharding_func=lambda event: int(event.body))
+        StreamTarget(V3ioDriver(), stream_path, sharding_func=lambda event: int(event.body), batch_size=8)
     ]).run()
     for i in range(10):
-        controller.emit(i)
+        await controller.emit(i)
 
-    controller.terminate()
-    controller.await_termination()
-    shard0_data = asyncio.run(GetShardData().get_shard_data(f'{stream_path}/0'))
-    assert shard0_data == [b'0', b'2', b'4', b'6', b'8']
-    shard1_data = asyncio.run(GetShardData().get_shard_data(f'{stream_path}/1'))
-    assert shard1_data == [b'1', b'3', b'5', b'7', b'9']
+    await asyncio.sleep(5)
+
+    try:
+        shard0_data = await GetShardData().get_shard_data(f'{stream_path}/0')
+        assert shard0_data == [b'0', b'2', b'4', b'6', b'8']
+        shard1_data = await GetShardData().get_shard_data(f'{stream_path}/1')
+        assert shard1_data == [b'1', b'3', b'5', b'7', b'9']
+    finally:
+        await controller.terminate()
+        await controller.await_termination()
+
+
+def test_write_to_v3io_stream(setup_stream_teardown_test):
+    asyncio.run(async_test_write_to_v3io_stream(setup_stream_teardown_test))
 
 
 def test_write_to_v3io_stream_with_column_inference(setup_stream_teardown_test):
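
Most of the churn above comes from swapping SyncEmitSource for AsyncEmitSource: emit(), terminate(), and await_termination() become awaitables that must run inside an event loop, which is why the test body moved into an async helper driven by asyncio.run(). A minimal sketch of that calling pattern, using only the storey API visible in this diff (the Map/Reduce pipeline itself is illustrative):

import asyncio

from storey import AsyncEmitSource, Map, Reduce, build_flow


async def main():
    controller = build_flow([
        AsyncEmitSource(),                       # async source: emit() is a coroutine
        Map(lambda x: x * 2),
        Reduce(0, lambda acc, x: acc + x),       # sum the doubled values
    ]).run()
    for i in range(4):
        await controller.emit(i)                 # awaited, unlike with SyncEmitSource
    await controller.terminate()
    return await controller.await_termination()  # termination result of the Reduce


print(asyncio.run(main()))  # expected: 12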
5 changes: 2 additions & 3 deletions storey/targets.py
@@ -594,7 +594,7 @@ class StreamTarget(Flow, _Writer):
     :type storage_options: dict
     """
 
-    def __init__(self, storage: Driver, stream_path: str, sharding_func: Optional[Callable[[Event], int]] = None, batch_size: int = 1,
+    def __init__(self, storage: Driver, stream_path: str, sharding_func: Optional[Callable[[Event], int]] = None, batch_size: int = 8,
                  columns: Optional[List[str]] = None, infer_columns_from_data: Optional[bool] = None, **kwargs):
         kwargs['stream_path'] = stream_path
         kwargs['batch_size'] = batch_size
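
The only signature change is the batch_size default moving from 1 to 8, trading per-record writes for fewer, larger requests; the idle-flush fix below keeps the larger default from delaying records indefinitely. Callers that want the old per-record behavior can pin it explicitly (a sketch; the stream path is illustrative):

from storey import StreamTarget, V3ioDriver

# Restore the pre-commit behavior of one write request per record.
target = StreamTarget(
    V3ioDriver(),
    'some-container/some-stream',  # illustrative path
    sharding_func=lambda event: int(event.body),
    batch_size=1,
)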
@@ -658,8 +658,7 @@ async def _worker(self):
                     req = in_flight_reqs[shard_id]
                     in_flight_reqs[shard_id] = None
                     await self._handle_response(req)
-                    if len(buffers[shard_id]) >= self._batch_size:
-                        self._send_batch(buffers, in_flight_reqs, shard_id)
+                    self._send_batch(buffers, in_flight_reqs, shard_id)
                 event = await self._q.get()
                 if event is _termination_obj:  # handle outstanding batches and in flight requests on termination
                     for req in in_flight_reqs:
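
The behavioral fix is the removal of the guard on len(buffers[shard_id]) >= self._batch_size: once a shard's in-flight request has been handled, whatever that shard has buffered is sent immediately rather than waiting for a full batch, so records no longer sit unwritten while the worker idles. A structural sketch of the loop shape, with storey's internals replaced by hypothetical stand-ins (the full-batch send path that lives elsewhere in the real method is omitted):

_termination_obj = object()  # sentinel, standing in for storey's internal marker


async def worker_loop(q, num_shards, sharding_func, send_batch, handle_response):
    # Hypothetical stand-in for StreamTarget._worker. send_batch is assumed to
    # issue the write and record the new request in in_flight_reqs.
    buffers = [[] for _ in range(num_shards)]
    in_flight_reqs = [None] * num_shards
    while True:
        for shard_id in range(num_shards):
            req = in_flight_reqs[shard_id]
            if req is not None:
                in_flight_reqs[shard_id] = None
                await handle_response(req)
                # Pre-commit, the next call was guarded by
                # `if len(buffers[shard_id]) >= batch_size:`. Dropping the
                # guard flushes a partial batch as soon as the shard is free.
                send_batch(buffers, in_flight_reqs, shard_id)
        event = await q.get()
        if event is _termination_obj:
            return  # the real code first flushes outstanding batches (see diff)
        buffers[sharding_func(event) % num_shards].append(event)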
Expand Down
