Skip to content

Commit

Permalink
Allow last event to be garbage collected in ConcurrentExecution (#550)
Browse files Browse the repository at this point in the history
  • Loading branch information
gtopper authored Dec 19, 2024
1 parent dc7f8c1 commit 9b970f7
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 1 deletion.
5 changes: 4 additions & 1 deletion storey/flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -840,9 +840,12 @@ def _init(self):
self._lazy_init_complete = False

async def _worker(self):
event = None
try:
while True:
# Allow event to be garbage collected
job = None # noqa
event = None
completed = None # noqa
try:
# If we don't handle the event before we remove it from the queue, the effective max_in_flight will
# be 1 higher than requested. Hence, we peek.
Expand Down
45 changes: 45 additions & 0 deletions tests/test_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@
build_flow,
)
from storey.flow import (
ConcurrentExecution,
Context,
ParallelExecution,
ParallelExecutionRunnable,
Expand Down Expand Up @@ -360,6 +361,50 @@ def test_async_offset_commit_before_termination_with_nosqltarget():
asyncio.run(async_offset_commit_before_termination_with_nosqltarget())


async def async_offset_commit_before_termination_with_concurrent_execution():
    """Exercise offset committing on a flow containing a ConcurrentExecution step.

    Emits ``num_records_per_shard`` events for each of ``num_shards`` shards through
    an explicit-ack source, drops the local reference to the last emitted event,
    sleeps for ``max_wait_before_commit + 1`` seconds, and then asserts that the
    committer's ``offsets`` map every ``("/", shard)`` key to
    ``num_records_per_shard`` while the flow is still running. The flow is always
    terminated afterwards and its reduced result is checked.
    """
    max_wait_before_commit = 1
    num_shards = 10
    num_records_per_shard = 10

    committer = Committer()
    committer_context = CommitterContext(committer)

    source = AsyncEmitSource(
        context=committer_context,
        explicit_ack=True,
        max_wait_before_commit=max_wait_before_commit,
    )
    steps = [
        source,
        ConcurrentExecution(event_processor=lambda x: x + 1),
        Filter(lambda x: x < 3),
        FlatMap(lambda x: [x, x * 10]),
        Reduce(0, lambda acc, x: acc + x),
    ]
    controller = build_flow(steps).run()

    # Offsets are the outer loop: every shard receives offset k before any shard
    # receives offset k + 1.
    for record_offset in range(1, num_records_per_shard + 1):
        for shard_id in range(num_shards):
            record = Event(shard_id)
            record.shard_id = shard_id
            record.offset = record_offset
            await controller.emit(record)

    # Drop this frame's reference to the last emitted event, so that the test
    # itself is not what keeps that event alive.
    del record

    # Wait past the commit interval before inspecting the committed offsets.
    await asyncio.sleep(max_wait_before_commit + 1)

    expected_offsets = {("/", shard_id): num_records_per_shard for shard_id in range(num_shards)}
    try:
        committed_offsets = copy.copy(committer.offsets)
        assert committed_offsets == expected_offsets
    finally:
        # Terminate even when the assertion above fails.
        await controller.terminate()

    reduced_total = await controller.await_termination()
    # NOTE(review): 330 is presumably 10 offset rounds x (1 + 10 + 2 + 20) -- confirm
    # against the semantics of the Filter/FlatMap/Reduce steps.
    assert reduced_total == 330


# ML-8799
def test_async_offset_commit_before_termination_with_concurrent_execution():
    """Run the async ConcurrentExecution offset-commit scenario to completion."""
    scenario = async_offset_commit_before_termination_with_concurrent_execution()
    asyncio.run(scenario)


def test_offset_not_committed_prematurely():
platform = Committer()
context = CommitterContext(platform)
Expand Down

0 comments on commit 9b970f7

Please sign in to comment.