From d21a467b3adc89bd052fb8728d6be45192c80744 Mon Sep 17 00:00:00 2001 From: Ryan Culbertson Date: Wed, 12 Feb 2025 22:21:09 +0000 Subject: [PATCH] Update structure of context and manager --- modal/_utils/async_utils.py | 3 + modal/parallel_map.py | 194 ++++++++++-------- modal/retries.py | 8 +- modal_proto/api.proto | 10 +- test/conftest.py | 12 +- test/map_item_context_test.py | 228 +++++++++------------ test/map_item_mananger_test.py | 357 +++++++++++++++++++++------------ 7 files changed, 447 insertions(+), 365 deletions(-) diff --git a/modal/_utils/async_utils.py b/modal/_utils/async_utils.py index 0b2422496d..9f840ab28c 100644 --- a/modal/_utils/async_utils.py +++ b/modal/_utils/async_utils.py @@ -314,6 +314,9 @@ async def get(self) -> Union[T, None]: def empty(self) -> bool: return self._queue.empty() + def __len__(self): + return self._queue.qsize() + async def queue_batch_iterator( q: Union[asyncio.Queue, TimestampPriorityQueue], max_batch_size=100, debounce_time=0.015 diff --git a/modal/parallel_map.py b/modal/parallel_map.py index 6098a7f3f3..6cfb9d07f2 100644 --- a/modal/parallel_map.py +++ b/modal/parallel_map.py @@ -1,9 +1,9 @@ # Copyright Modal Labs 2024 import asyncio +import enum import time import typing from dataclasses import dataclass -from enum import Enum from typing import Any, Callable, Optional from grpclib import GRPCError, Status @@ -30,6 +30,7 @@ _process_result, ) from modal._utils.grpc_utils import retry_transient_errors +from modal._utils.jwt_utils import DecodedJwt from modal.config import logger from modal.retries import RetryManager from modal_proto import api_pb2 @@ -169,7 +170,9 @@ async def retry_inputs(): async for retriable_idxs in queue_batch_iterator(retry_queue, max_batch_size=MAP_INVOCATION_CHUNK_SIZE): # For each index, use the context in the manager to create a FunctionRetryInputsItem. # This will also update the context state to RETRYING. 
- inputs: list[api_pb2.FunctionRetryInputsItem] = map_items_manager.get_items_for_retry(retriable_idxs) + inputs: list[api_pb2.FunctionRetryInputsItem] = await map_items_manager.prepare_items_for_retry( + retriable_idxs + ) request = api_pb2.FunctionRetryInputsRequest( function_call_jwt=function_call_jwt, inputs=inputs, @@ -209,11 +212,10 @@ async def get_all_outputs(): last_entry_id = "0-0" while not have_all_inputs or num_outputs < num_inputs: - await asyncio.sleep(3) logger.debug(f"Requesting outputs. Have {num_outputs} outputs, {num_inputs} inputs.") # Get input_jwts of all items in the WAITING_FOR_OUTPUT state. # The server uses these to track for lost inputs. - input_jwts = [await ctx.input_jwt for ctx in map_items_manager.get_items_waiting_for_output()] + input_jwts = [input_jwt for input_jwt in map_items_manager.get_input_jwts_waiting_for_output()] request = api_pb2.FunctionGetOutputsRequest( function_call_id=function_call_id, @@ -490,7 +492,7 @@ def main(): ) -class _MapItemState(Enum): +class _MapItemState(enum.Enum): # The input is being sent the server with a PutInputs request, but the response has not been received yet. SENDING = 1 # A call to either PutInputs or FunctionRetry has completed, and we are waiting to receive the output. @@ -499,6 +501,8 @@ class _MapItemState(Enum): WAITING_TO_RETRY = 3 # The input is being sent to the server with a FunctionRetry request, but the response has not been received yet. RETRYING = 4 + # The output has been received and was either successful, or failed with no more retries remaining. 
+ COMPLETE = 5 class _MapItemContext: @@ -524,31 +528,81 @@ def __init__(self, input: api_pb2.FunctionInput, retry_manager: RetryManager): self.input_jwt = self._event_loop.create_future() self.input_id = self._event_loop.create_future() - def set_state_waiting_for_output(self, input_id: str, input_jwt: str): - assert self.state == _MapItemState.SENDING, self.state - self.input_jwt.set_result(input_jwt) - self.input_id.set_result(input_id) - self.state = _MapItemState.WAITING_FOR_OUTPUT + def handle_put_inputs_response(self, item: api_pb2.FunctionPutInputsResponseItem): + self.input_jwt.set_result(item.input_jwt) + self.input_id.set_result(item.input_id) + # Set state to WAITING_FOR_OUTPUT only if current state is SENDING. If state is + # RETRYING, WAITING_TO_RETRY, or COMPLETE, then we already got the output. + if self.state == _MapItemState.SENDING: + self.state = _MapItemState.WAITING_FOR_OUTPUT - def set_state_waiting_for_output_after_retry(self, input_jwt: str): - assert self.state == _MapItemState.RETRYING, self.state - self.input_jwt.set_result(input_jwt) - self.state = _MapItemState.WAITING_FOR_OUTPUT + async def handle_get_outputs_response( + self, + item: api_pb2.FunctionGetOutputsItem, + now_seconds: int, + function_call_invocation_type: "api_pb2.FunctionCallInvocationType.ValueType", + retry_queue: TimestampPriorityQueue, + ) -> bool: + """ + Processes the output, and determines if it is complete or needs to be retried. + + Return True if input state was changed to COMPLETE, otherwise False. + """ + # If the item is already complete, this is a duplicate output and can be ignored. + # If the item's retry count doesn't match our retry count, this is probably an old output. + if self.state == _MapItemState.COMPLETE or item.retry_count != self.retry_manager.retry_count: + # We've already processed this output, so we can skip it. + # This can happen because the worker can sometimes send duplicate outputs. 
+ return False + + # retry failed inputs when the function call invocation type is SYNC + if ( + item.result.status == api_pb2.GenericResult.GENERIC_STATUS_SUCCESS + or function_call_invocation_type != api_pb2.FUNCTION_CALL_INVOCATION_TYPE_SYNC + ): + self.state = _MapItemState.COMPLETE + return True + + # Get the retry delay and increment the retry count. + # TODO(ryan): We must call this for lost inputs - even though we will set the retry delay to 0 later - + # because we must increment the retry count. That's awkward, let's come up with something better. + # TODO(ryan): To maintain parity with server-side retries, retrying lost inputs should not count towards + # the retry policy. However we use the retry_count number as a unique identifier on each attempt to: + # 1) ignore duplicate outputs + # 2) ignore late outputs received from previous attempts + # 3) avoid a server race condition between FunctionRetry and GetOutputs that results in deleted input metadata + # For now, lost inputs will count towards the retry policy. But let's address this in another PR, perhaps by + # tracking total attempts and attempts which count towards the retry policy separately. + delay_ms = self.retry_manager.get_delay_ms() + + # For system failures on the server, we retry immediately. + # Note: the retry count was already incremented above, so for now this still counts towards the retry policy. + if item.result.status == api_pb2.GenericResult.GENERIC_STATUS_INTERNAL_FAILURE: + delay_ms = 0 + + # None means the maximum number of retries has been reached, so output the error + if delay_ms is None: + self.state = _MapItemState.COMPLETE + return True - def set_state_waiting_for_retry(self): - assert self.state == _MapItemState.WAITING_FOR_OUTPUT, self.state - # When we call FunctionRetry, we pass the input_jwt from the previous request, - # which is either the original call to PutInputs, or a previous call to FunctionRetry. - # FunctionRetry then returnd a new input_jwt. 
Each retry produces a new input_jwt - # because it contains an entry_id which changes on every retry. - self.previous_input_jwt = self.input_jwt.result() - # We reset the input_jwt to a new future so it is ready when get_all_outputs awaits it. - self.input_jwt = self._event_loop.create_future() self.state = _MapItemState.WAITING_TO_RETRY + await retry_queue.put(now_seconds + (delay_ms / 1000), item.idx) + + return False - def set_state_retrying(self): - assert self.state == _MapItemState.WAITING_TO_RETRY, self.state + async def prepare_item_for_retry(self) -> api_pb2.FunctionRetryInputsItem: self.state = _MapItemState.RETRYING + input_jwt = await self.input_jwt + self.input_jwt = self._event_loop.create_future() + return api_pb2.FunctionRetryInputsItem( + input_jwt=input_jwt, + input=self.input, + retry_count=self.retry_manager.retry_count, + ) + + def handle_retry_response(self, input_jwt: str): + self.input_jwt.set_result(input_jwt) + self.state = _MapItemState.WAITING_FOR_OUTPUT class _MapItemsManager: @@ -574,84 +628,58 @@ async def add_items(self, items: list[api_pb2.FunctionPutInputsItem]): input=item.input, retry_manager=RetryManager(self._retry_policy) ) - def get_items_for_retry(self, retriable_idxs: list[int]) -> list[api_pb2.FunctionRetryInputsItem]: - items: api_pb2.FunctionRetryInputsItem = [] - for retriable_idx in retriable_idxs: - ctx = self._item_context[retriable_idx] - ctx.set_state_retrying() - items.append( - api_pb2.FunctionRetryInputsItem( - input_jwt=ctx.previous_input_jwt, - input=ctx.input, - retry_count=ctx.retry_manager.attempt_count, - ) - ) - return items + async def prepare_items_for_retry(self, retriable_idxs: list[int]) -> list[api_pb2.FunctionRetryInputsItem]: + return [await self._item_context[idx].prepare_item_for_retry() for idx in retriable_idxs] - def get_items_waiting_for_output(self) -> list[_MapItemContext]: - return [ctx for ctx in self._item_context.values() if ctx.state == _MapItemState.WAITING_FOR_OUTPUT] + def 
get_input_jwts_waiting_for_output(self) -> list[str]: + """ + Returns a list of input_jwts for inputs that are waiting for output. + """ + # If input_jwt is not done, the call to PutInputs has not completed, so omit it from results. + return [ + ctx.input_jwt.result() + for ctx in self._item_context.values() + if ctx.state == _MapItemState.WAITING_FOR_OUTPUT and ctx.input_jwt.done() + ] def _remove_item(self, item_idx: int): del self._item_context[item_idx] self._inputs_outstanding.release() def get_item_context(self, item_idx: int) -> _MapItemContext: - return self._item_context[item_idx] + return self._item_context.get(item_idx) def handle_put_inputs_response(self, items: list[api_pb2.FunctionPutInputsResponseItem]): for item in items: ctx = self._item_context.get(item.idx, None) - # If the context is None, then get_all_outputs() has already - # received a successful output, and deleted the context. + # If the context is None, then get_all_outputs() has already received a successful + # output, and deleted the context. This happens if FunctionGetOutputs completes + # before FunctionPutInputsResponse is received. if ctx is not None: - ctx.set_state_waiting_for_output(item.input_id, item.input_jwt) - - def handle_retry_response(self, items: list[api_pb2.FunctionRetryInputsResponseItem]): - for item in items: - ctx = self._item_context[item.idx] - ctx.set_state_waiting_for_output_after_retry(input_jwt=item.input_jwt) + ctx.handle_put_inputs_response(item) + + def handle_retry_response(self, input_jwts: list[str]): + for input_jwt in input_jwts: + decoded_jwt = DecodedJwt.decode_without_verification(input_jwt) + ctx = self._item_context.get(decoded_jwt.payload["idx"], None) + # If the context is None, then get_all_outputs() has already received a successful + # output, and deleted the context. This happens if FunctionGetOutputs completes + # before FunctionRetryInputsResponse is received. 
+ if ctx is not None: + ctx.handle_retry_response(input_jwt) async def handle_get_outputs_response(self, item: api_pb2.FunctionGetOutputsItem, now_seconds: int) -> bool: - output_is_complete = await self._handle_output(item, now_seconds) - if output_is_complete: - self._remove_item(item.idx) - return output_is_complete - - async def _handle_output(self, item: api_pb2.FunctionGetOutputsItem, now_seconds: int) -> bool: - """ - Determines if an output is complete or needs to be retried. - - If complete, we remove the input from the manager, and return True. - Otherwise we place it on the retry queue, and return False. - """ ctx = self._item_context.get(item.idx, None) if ctx is None: # We've already processed this output, so we can skip it. # This can happen because the worker can sometimes send duplicate outputs. return False - - # retry failed inputs when the function call invocation type is SYNC - if ( - item.result.status == api_pb2.GenericResult.GENERIC_STATUS_SUCCESS - or self.function_call_invocation_type != api_pb2.FUNCTION_CALL_INVOCATION_TYPE_SYNC - ): - return True - - # For system failures on the server, we retry immediately, - # and the failure does not count towards the retry policy. 
- delay_ms = ( - 0 - if item.result.status == api_pb2.GenericResult.GENERIC_STATUS_INTERNAL_FAILURE - else ctx.retry_manager.get_delay_ms() + output_is_complete = await ctx.handle_get_outputs_response( + item, now_seconds, self.function_call_invocation_type, self._retry_queue ) - - # None means the maximum number of retries has been reached, so output the error - if delay_ms is None: - return True - - ctx.set_state_waiting_for_retry() - await self._retry_queue.put(now_seconds + (delay_ms / 1000), item.idx) - return False + if output_is_complete: + self._remove_item(item.idx) + return output_is_complete def __len__(self): return len(self._item_context) diff --git a/modal/retries.py b/modal/retries.py index 60a14606e2..f16210c1a8 100644 --- a/modal/retries.py +++ b/modal/retries.py @@ -116,19 +116,19 @@ class RetryManager: def __init__(self, retry_policy: api_pb2.FunctionRetryPolicy): self.retry_policy = retry_policy - self.attempt_count = 0 + self.retry_count = 0 def get_delay_ms(self) -> Union[float, None]: """ Returns the delay in milliseconds before the next retry, or None if the maximum number of retries has been reached. """ - self.attempt_count += 1 + self.retry_count += 1 - if self.attempt_count > self.retry_policy.retries: + if self.retry_count > self.retry_policy.retries: return None - return self._retry_delay_ms(self.attempt_count, self.retry_policy) + return self._retry_delay_ms(self.retry_count, self.retry_policy) @staticmethod def _retry_delay_ms(attempt_count: int, retry_policy: api_pb2.FunctionRetryPolicy) -> float: diff --git a/modal_proto/api.proto b/modal_proto/api.proto index bf8068a360..9046af87bc 100644 --- a/modal_proto/api.proto +++ b/modal_proto/api.proto @@ -1674,16 +1674,8 @@ message FunctionRetryInputsRequest { repeated FunctionRetryInputsItem inputs = 2; } -// TODO(ryan): We shouldn't need to pass back the idx value, since the client can -// get it from the input_jwt. 
But I wasn't able to get the client to see authlib, -// so couldn't decode it. -message FunctionRetryInputsResponseItem { - uint32 idx = 1; - string input_jwt = 2; -} - message FunctionRetryInputsResponse { - repeated FunctionRetryInputsResponseItem items = 1; + repeated string input_jwts = 1; } message FunctionRetryPolicy { diff --git a/test/conftest.py b/test/conftest.py index 976ea4bd46..c9ffc45c0f 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -1072,7 +1072,7 @@ async def FunctionMap(self, stream): api_pb2.FunctionPutInputsResponseItem( idx=self.fcidx, input_id=input_id, - input_jwt=encode_input_jwt(self.fcidx, input_id, function_call_id), + input_jwt=encode_input_jwt(self.fcidx, input_id, function_call_id, self.next_entry_id()), ) ) @@ -1089,7 +1089,7 @@ async def FunctionRetryInputs(self, stream): request: api_pb2.FunctionRetryInputsRequest = await stream.recv_message() function_id, function_call_id = decode_function_call_jwt(request.function_call_jwt) function_call_inputs = self.client_calls.setdefault(function_call_id, []) - response_items = [] + input_jwts = [] for item in request.inputs: if item.input.WhichOneof("args_oneof") == "args": args, kwargs = deserialize(item.input.args, None) @@ -1098,13 +1098,9 @@ async def FunctionRetryInputs(self, stream): self.n_inputs += 1 idx, input_id, function_call_id, entry_id = decode_input_jwt(item.input_jwt) entry_id = self.next_entry_id() - response_items.append( - api_pb2.FunctionRetryInputsResponseItem( - idx=idx, input_jwt=encode_input_jwt(idx, input_id, function_call_id, entry_id) - ) - ) + input_jwts.append(encode_input_jwt(idx, input_id, function_call_id, entry_id)) function_call_inputs.append(((idx, input_id), (args, kwargs))) - await stream.send_message(api_pb2.FunctionRetryInputsResponse(items=response_items)) + await stream.send_message(api_pb2.FunctionRetryInputsResponse(input_jwts=input_jwts)) async def FunctionPutInputs(self, stream): request: api_pb2.FunctionPutInputsRequest = await 
stream.recv_message() diff --git a/test/map_item_context_test.py b/test/map_item_context_test.py index 7929ded686..3ec7add644 100644 --- a/test/map_item_context_test.py +++ b/test/map_item_context_test.py @@ -4,6 +4,14 @@ from modal.parallel_map import _MapItemContext, _MapItemState from modal.retries import RetryManager from modal_proto import api_pb2 +from test.supports.map_item_test_utils import ( + InputJwtData, + assert_context_is, + assert_retry_item_is, + result_failure, + result_internal_failure, + result_success, +) retry_policy = api_pb2.FunctionRetryPolicy( backoff_coefficient=1.0, @@ -12,9 +20,6 @@ retries=2, ) -result_success = api_pb2.GenericResult(status=api_pb2.GenericResult.GENERIC_STATUS_SUCCESS) -result_failure = api_pb2.GenericResult(status=api_pb2.GenericResult.GENERIC_STATUS_FAILURE) -result_internal_failure = api_pb2.GenericResult(status=api_pb2.GenericResult.GENERIC_STATUS_INTERNAL_FAILURE) now_seconds = 1738439812 input_data = api_pb2.FunctionInput(args=b"test") @@ -24,28 +29,22 @@ def retry_queue(): return TimestampPriorityQueue() -def test_map_item_context_initial_state(): - map_item_context = _MapItemContext(input=input_data, retry_manager=RetryManager(retry_policy)) - assert map_item_context.state == _MapItemState.SENDING - assert map_item_context.input == input_data - assert map_item_context.retry_manager.retry_count == 0 - assert map_item_context.input_id.done() == False - assert map_item_context.input_jwt.done() == False +def test_ctx_initial_state(): + ctx = _MapItemContext(input=input_data, retry_manager=RetryManager(retry_policy)) + assert_context_is(ctx, _MapItemState.SENDING, 0, None, None, input_data.args) @pytest.mark.asyncio async def test_successful_output(retry_queue): - map_item_context = _MapItemContext(input=input_data, retry_manager=RetryManager(retry_policy)) + ctx = _MapItemContext(input=input_data, retry_manager=RetryManager(retry_policy)) + input_jwt_data = InputJwtData.of(0, 0) # Put inputs - response_item = 
api_pb2.FunctionPutInputsResponseItem(idx=0, input_id="input-0", input_jwt="jwt-0") - map_item_context.handle_put_inputs_response(response_item) - assert map_item_context.state == _MapItemState.WAITING_FOR_OUTPUT - # We call result here rather than await because we want to test that the result has been set already. - assert map_item_context.input_id.result() == "input-0" - assert map_item_context.input_jwt.result() == "jwt-0" + response_item = api_pb2.FunctionPutInputsResponseItem(idx=0, input_id="in-0", input_jwt=input_jwt_data.to_jwt()) + ctx.handle_put_inputs_response(response_item) + assert_context_is(ctx, _MapItemState.WAITING_FOR_OUTPUT, 0, "in-0", input_jwt_data, input_data.args) # Get outputs - changed_to_complete = await map_item_context.handle_get_outputs_response( + changed_to_complete = await ctx.handle_get_outputs_response( api_pb2.FunctionGetOutputsItem(idx=0, result=result_success), now_seconds, api_pb2.FunctionCallInvocationType.FUNCTION_CALL_INVOCATION_TYPE_SYNC, @@ -53,25 +52,21 @@ async def test_successful_output(retry_queue): ) assert changed_to_complete == True - assert map_item_context.state == _MapItemState.COMPLETE - assert map_item_context.retry_manager.retry_count == 0 - assert map_item_context.input_id.result() == "input-0" - assert map_item_context.input_jwt.result() == "jwt-0" + assert_context_is(ctx, _MapItemState.COMPLETE, 0, "in-0", input_jwt_data, input_data.args) @pytest.mark.asyncio async def test_failed_output_no_retries(retry_queue): retry_policy = api_pb2.FunctionRetryPolicy(retries=0) - map_item_context = _MapItemContext(input=input_data, retry_manager=RetryManager(retry_policy)) + ctx = _MapItemContext(input=input_data, retry_manager=RetryManager(retry_policy)) + input_jwt_data = InputJwtData.of(0, 0) # Put inputs - response_item = api_pb2.FunctionPutInputsResponseItem(idx=0, input_id="input-0", input_jwt="jwt-0") - map_item_context.handle_put_inputs_response(response_item) - assert map_item_context.state == 
_MapItemState.WAITING_FOR_OUTPUT - assert map_item_context.input_id.result() == "input-0" - assert map_item_context.input_jwt.result() == "jwt-0" + response_item = api_pb2.FunctionPutInputsResponseItem(idx=0, input_id="in-0", input_jwt=input_jwt_data.to_jwt()) + ctx.handle_put_inputs_response(response_item) + assert_context_is(ctx, _MapItemState.WAITING_FOR_OUTPUT, 0, "in-0", input_jwt_data, input_data.args) # Get outputs - changed_to_complete = await map_item_context.handle_get_outputs_response( + changed_to_complete = await ctx.handle_get_outputs_response( api_pb2.FunctionGetOutputsItem(idx=0, result=result_failure), now_seconds, api_pb2.FunctionCallInvocationType.FUNCTION_CALL_INVOCATION_TYPE_SYNC, @@ -79,26 +74,22 @@ async def test_failed_output_no_retries(retry_queue): ) assert changed_to_complete == True - assert map_item_context.state == _MapItemState.COMPLETE - assert map_item_context.retry_manager.retry_count == 1 - assert map_item_context.input_id.result() == "input-0" - assert map_item_context.input_jwt.result() == "jwt-0" + assert_context_is(ctx, _MapItemState.COMPLETE, 1, "in-0", input_jwt_data, input_data.args) assert retry_queue.empty() @pytest.mark.asyncio async def test_failed_output_retries_then_succeeds(retry_queue): retry_policy = api_pb2.FunctionRetryPolicy(retries=1) - map_item_context = _MapItemContext(input=input_data, retry_manager=RetryManager(retry_policy)) + ctx = _MapItemContext(input=input_data, retry_manager=RetryManager(retry_policy)) + input_jwt_data_0 = InputJwtData.of(0, 0) # Put inputs - response_item = api_pb2.FunctionPutInputsResponseItem(idx=0, input_id="input-0", input_jwt="jwt-0") - map_item_context.handle_put_inputs_response(response_item) - assert map_item_context.state == _MapItemState.WAITING_FOR_OUTPUT - assert map_item_context.input_id.result() == "input-0" - assert map_item_context.input_jwt.result() == "jwt-0" + response_item = api_pb2.FunctionPutInputsResponseItem(idx=0, input_id="in-0", 
input_jwt=input_jwt_data_0.to_jwt()) + ctx.handle_put_inputs_response(response_item) + assert_context_is(ctx, _MapItemState.WAITING_FOR_OUTPUT, 0, "in-0", input_jwt_data_0, input_data.args) # Get outputs - changed_to_complete = await map_item_context.handle_get_outputs_response( + changed_to_complete = await ctx.handle_get_outputs_response( api_pb2.FunctionGetOutputsItem(idx=0, result=result_failure), now_seconds, api_pb2.FunctionCallInvocationType.FUNCTION_CALL_INVOCATION_TYPE_SYNC, @@ -106,52 +97,41 @@ async def test_failed_output_retries_then_succeeds(retry_queue): ) assert changed_to_complete == False - assert map_item_context.state == _MapItemState.WAITING_TO_RETRY - assert map_item_context.retry_manager.retry_count == 1 - assert map_item_context.input_id.result() == "input-0" - assert map_item_context.input_jwt.result() == "jwt-0" + assert_context_is(ctx, _MapItemState.WAITING_TO_RETRY, 1, "in-0", input_jwt_data_0, input_data.args) assert len(retry_queue) == 1 # Retry input - retry_item = await map_item_context.prepare_item_for_retry() - assert retry_item.input_jwt == "jwt-0" - assert retry_item.input == input_data - assert retry_item.retry_count == 1 - assert map_item_context.state == _MapItemState.RETRYING + retry_item = await ctx.prepare_item_for_retry() + assert_retry_item_is(retry_item, input_jwt_data_0, 1, input_data.args) + assert_context_is(ctx, _MapItemState.RETRYING, 1, "in-0", None, input_data.args) - map_item_context.handle_retry_response(api_pb2.FunctionRetryInputsResponseItem(idx=0, input_jwt="jwt-1")) - assert map_item_context.state == _MapItemState.WAITING_FOR_OUTPUT - assert map_item_context.input_id.result() == "input-0" - assert map_item_context.input_jwt.result() == "jwt-1" + input_jwt_data_1 = InputJwtData.of(0, 1) + ctx.handle_retry_response(input_jwt_data_1.to_jwt()) + assert_context_is(ctx, _MapItemState.WAITING_FOR_OUTPUT, 1, "in-0", input_jwt_data_1, input_data.args) # Get outputs - changed_to_complete = await 
map_item_context.handle_get_outputs_response( + changed_to_complete = await ctx.handle_get_outputs_response( api_pb2.FunctionGetOutputsItem(idx=0, result=result_success, retry_count=1), now_seconds, api_pb2.FunctionCallInvocationType.FUNCTION_CALL_INVOCATION_TYPE_SYNC, retry_queue, ) assert changed_to_complete == True - assert map_item_context.state == _MapItemState.COMPLETE - # retry count is incremented only on failures - assert map_item_context.retry_manager.retry_count == 1 - assert map_item_context.input_id.result() == "input-0" - assert map_item_context.input_jwt.result() == "jwt-1" + assert_context_is(ctx, _MapItemState.COMPLETE, 1, "in-0", input_jwt_data_1, input_data.args) @pytest.mark.asyncio async def test_lost_input_retries_then_succeeds(retry_queue): retry_policy = api_pb2.FunctionRetryPolicy(retries=1) - map_item_context = _MapItemContext(input=input_data, retry_manager=RetryManager(retry_policy)) + ctx = _MapItemContext(input=input_data, retry_manager=RetryManager(retry_policy)) + input_jwt_data_0 = InputJwtData.of(0, 0) # Put inputs - response_item = api_pb2.FunctionPutInputsResponseItem(idx=0, input_id="input-0", input_jwt="jwt-0") - map_item_context.handle_put_inputs_response(response_item) - assert map_item_context.state == _MapItemState.WAITING_FOR_OUTPUT - assert map_item_context.input_id.result() == "input-0" - assert map_item_context.input_jwt.result() == "jwt-0" + response_item = api_pb2.FunctionPutInputsResponseItem(idx=0, input_id="in-0", input_jwt=input_jwt_data_0.to_jwt()) + ctx.handle_put_inputs_response(response_item) + assert_context_is(ctx, _MapItemState.WAITING_FOR_OUTPUT, 0, "in-0", input_jwt_data_0, input_data.args) # Get outputs - changed_to_complete = await map_item_context.handle_get_outputs_response( + changed_to_complete = await ctx.handle_get_outputs_response( api_pb2.FunctionGetOutputsItem(idx=0, result=result_internal_failure), now_seconds, api_pb2.FunctionCallInvocationType.FUNCTION_CALL_INVOCATION_TYPE_SYNC, @@ -159,52 
+139,41 @@ async def test_lost_input_retries_then_succeeds(retry_queue): ) assert changed_to_complete == False - assert map_item_context.state == _MapItemState.WAITING_TO_RETRY - assert map_item_context.retry_manager.retry_count == 1 - assert map_item_context.input_id.result() == "input-0" - assert map_item_context.input_jwt.result() == "jwt-0" + assert_context_is(ctx, _MapItemState.WAITING_TO_RETRY, 1, "in-0", input_jwt_data_0, input_data.args) assert len(retry_queue) == 1 # Retry input - retry_item = await map_item_context.prepare_item_for_retry() - assert retry_item.input_jwt == "jwt-0" - assert retry_item.input == input_data - assert retry_item.retry_count == 1 - assert map_item_context.state == _MapItemState.RETRYING + retry_item = await ctx.prepare_item_for_retry() + assert_retry_item_is(retry_item, input_jwt_data_0, 1, input_data.args) + assert_context_is(ctx, _MapItemState.RETRYING, 1, "in-0", None, input_data.args) - map_item_context.handle_retry_response(api_pb2.FunctionRetryInputsResponseItem(idx=0, input_jwt="jwt-1")) - assert map_item_context.state == _MapItemState.WAITING_FOR_OUTPUT - assert map_item_context.input_id.result() == "input-0" - assert map_item_context.input_jwt.result() == "jwt-1" + input_jwt_data_1 = InputJwtData.of(0, 1) + ctx.handle_retry_response(input_jwt_data_1.to_jwt()) + assert_context_is(ctx, _MapItemState.WAITING_FOR_OUTPUT, 1, "in-0", input_jwt_data_1, input_data.args) # Get outputs - changed_to_complete = await map_item_context.handle_get_outputs_response( + changed_to_complete = await ctx.handle_get_outputs_response( api_pb2.FunctionGetOutputsItem(idx=0, result=result_success, retry_count=1), now_seconds, api_pb2.FunctionCallInvocationType.FUNCTION_CALL_INVOCATION_TYPE_SYNC, retry_queue, ) assert changed_to_complete == True - assert map_item_context.state == _MapItemState.COMPLETE - # retry count is incremented only on failures - assert map_item_context.retry_manager.retry_count == 1 - assert 
map_item_context.input_id.result() == "input-0" - assert map_item_context.input_jwt.result() == "jwt-1" + assert_context_is(ctx, _MapItemState.COMPLETE, 1, "in-0", input_jwt_data_1, input_data.args) @pytest.mark.asyncio async def test_failed_output_exhausts_retries(retry_queue): retry_policy = api_pb2.FunctionRetryPolicy(retries=1) - map_item_context = _MapItemContext(input=input_data, retry_manager=RetryManager(retry_policy)) + ctx = _MapItemContext(input=input_data, retry_manager=RetryManager(retry_policy)) + input_jwt_data_0 = InputJwtData.of(0, 0) # Put inputs - response_item = api_pb2.FunctionPutInputsResponseItem(idx=0, input_id="input-0", input_jwt="jwt-0") - map_item_context.handle_put_inputs_response(response_item) - assert map_item_context.state == _MapItemState.WAITING_FOR_OUTPUT - assert map_item_context.input_id.result() == "input-0" - assert map_item_context.input_jwt.result() == "jwt-0" + response_item = api_pb2.FunctionPutInputsResponseItem(idx=0, input_id="in-0", input_jwt=input_jwt_data_0.to_jwt()) + ctx.handle_put_inputs_response(response_item) + assert_context_is(ctx, _MapItemState.WAITING_FOR_OUTPUT, 0, "in-0", input_jwt_data_0, input_data.args) # Get outputs - changed_to_complete = await map_item_context.handle_get_outputs_response( + changed_to_complete = await ctx.handle_get_outputs_response( api_pb2.FunctionGetOutputsItem(idx=0, result=result_failure), now_seconds, api_pb2.FunctionCallInvocationType.FUNCTION_CALL_INVOCATION_TYPE_SYNC, @@ -212,45 +181,36 @@ async def test_failed_output_exhausts_retries(retry_queue): ) assert changed_to_complete == False - assert map_item_context.state == _MapItemState.WAITING_TO_RETRY - assert map_item_context.retry_manager.retry_count == 1 - assert map_item_context.input_id.result() == "input-0" - assert map_item_context.input_jwt.result() == "jwt-0" + assert_context_is(ctx, _MapItemState.WAITING_TO_RETRY, 1, "in-0", input_jwt_data_0, input_data.args) assert len(retry_queue) == 1 # Retry input - retry_item 
= await map_item_context.prepare_item_for_retry() - assert retry_item.input_jwt == "jwt-0" - assert retry_item.input == input_data - assert retry_item.retry_count == 1 - assert map_item_context.state == _MapItemState.RETRYING + retry_item = await ctx.prepare_item_for_retry() + assert_retry_item_is(retry_item, input_jwt_data_0, 1, input_data.args) + assert_context_is(ctx, _MapItemState.RETRYING, 1, "in-0", None, input_data.args) - map_item_context.handle_retry_response(api_pb2.FunctionRetryInputsResponseItem(idx=0, input_jwt="jwt-1")) - assert map_item_context.state == _MapItemState.WAITING_FOR_OUTPUT - assert map_item_context.input_id.result() == "input-0" - assert map_item_context.input_jwt.result() == "jwt-1" + input_jwt_data_1 = InputJwtData.of(0, 1) + ctx.handle_retry_response(input_jwt_data_1.to_jwt()) + assert_context_is(ctx, _MapItemState.WAITING_FOR_OUTPUT, 1, "in-0", input_jwt_data_1, input_data.args) # Get outputs - changed_to_complete = await map_item_context.handle_get_outputs_response( + changed_to_complete = await ctx.handle_get_outputs_response( api_pb2.FunctionGetOutputsItem(idx=0, result=result_failure, retry_count=1), now_seconds, api_pb2.FunctionCallInvocationType.FUNCTION_CALL_INVOCATION_TYPE_SYNC, retry_queue, ) assert changed_to_complete == True - assert map_item_context.state == _MapItemState.COMPLETE - # retry count is incremented only on failures - assert map_item_context.retry_manager.retry_count == 2 - assert map_item_context.input_id.result() == "input-0" - assert map_item_context.input_jwt.result() == "jwt-1" + assert_context_is(ctx, _MapItemState.COMPLETE, 2, "in-0", input_jwt_data_1, input_data.args) @pytest.mark.asyncio async def test_get_successful_output_before_put_inputs_completes(retry_queue): - map_item_context = _MapItemContext(input=input_data, retry_manager=RetryManager(retry_policy)) + ctx = _MapItemContext(input=input_data, retry_manager=RetryManager(retry_policy)) + input_jwt_data = InputJwtData.of(0, 0) # Get outputs - 
changed_to_complete = await map_item_context.handle_get_outputs_response( + changed_to_complete = await ctx.handle_get_outputs_response( api_pb2.FunctionGetOutputsItem(idx=0, result=result_success), now_seconds, api_pb2.FunctionCallInvocationType.FUNCTION_CALL_INVOCATION_TYPE_SYNC, @@ -258,46 +218,40 @@ async def test_get_successful_output_before_put_inputs_completes(retry_queue): ) assert changed_to_complete == True - assert map_item_context.state == _MapItemState.COMPLETE - assert map_item_context.retry_manager.retry_count == 0 - assert map_item_context.input_id.done() == False - assert map_item_context.input_jwt.done() == False + assert_context_is(ctx, _MapItemState.COMPLETE, 0, None, None, input_data.args) assert retry_queue.empty() # Put inputs - response_item = api_pb2.FunctionPutInputsResponseItem(idx=0, input_id="input-0", input_jwt="jwt-0") - map_item_context.handle_put_inputs_response(response_item) - assert map_item_context.state == _MapItemState.COMPLETE - assert map_item_context.input_id.result() == "input-0" - assert map_item_context.input_jwt.result() == "jwt-0" + response_item = api_pb2.FunctionPutInputsResponseItem(idx=0, input_id="in-0", input_jwt=input_jwt_data.to_jwt()) + ctx.handle_put_inputs_response(response_item) + assert_context_is(ctx, _MapItemState.COMPLETE, 0, "in-0", input_jwt_data, input_data.args) @pytest.mark.asyncio async def test_get_failed_output_before_put_inputs_completes(retry_queue): - map_item_context = _MapItemContext(input=input_data, retry_manager=RetryManager(retry_policy)) + ctx = _MapItemContext(input=input_data, retry_manager=RetryManager(retry_policy)) + input_jwt_data = InputJwtData.of(0, 0) # Get outputs - changed_to_complete = await map_item_context.handle_get_outputs_response( - api_pb2.FunctionGetOutputsItem(idx=0, result=result_success), + changed_to_complete = await ctx.handle_get_outputs_response( + api_pb2.FunctionGetOutputsItem(idx=0, result=result_failure), now_seconds, 
api_pb2.FunctionCallInvocationType.FUNCTION_CALL_INVOCATION_TYPE_SYNC, retry_queue, ) assert changed_to_complete == False - assert map_item_context.state == _MapItemState.WAITING_TO_RETRY - assert map_item_context.retry_manager.retry_count == 1 - assert map_item_context.input_id.done() == False - assert map_item_context.input_jwt.done() == False - assert retry_queue.empty() + assert_context_is(ctx, _MapItemState.WAITING_TO_RETRY, 1, None, None, input_data.args) + assert len(retry_queue) == 1 # Put inputs - response_item = api_pb2.FunctionPutInputsResponseItem(idx=0, input_id="input-0", input_jwt="jwt-0") - map_item_context.handle_put_inputs_response(response_item) - assert map_item_context.state == _MapItemState.WAITING_TO_RETRY - assert map_item_context.input_id.result() == "input-0" - assert map_item_context.input_jwt.result() == "jwt-0" + response_item = api_pb2.FunctionPutInputsResponseItem(idx=0, input_id="in-0", input_jwt=input_jwt_data.to_jwt()) + ctx.handle_put_inputs_response(response_item) + assert_context_is(ctx, _MapItemState.WAITING_TO_RETRY, 1, "in-0", input_jwt_data, input_data.args) + assert len(retry_queue) == 1 -# TODO(ryan): Add test for retrying before put inputs completes. Need to check that we await -# for put inputs to return before retrying. +# TODO(ryan): Add test for: +# - retrying before put inputs completes. Need to check that we await for put inputs to return before retrying. +# - receiving an old output after a newer output. We should ignore the old output. +# - retrying multiple times and verifying that the retry count is incremented. 
diff --git a/test/map_item_mananger_test.py b/test/map_item_mananger_test.py index 51ad72bb9b..8611127ca7 100644 --- a/test/map_item_mananger_test.py +++ b/test/map_item_mananger_test.py @@ -4,6 +4,14 @@ from modal._utils.async_utils import TimestampPriorityQueue from modal.parallel_map import _MapItemsManager, _MapItemState from modal_proto import api_pb2 +from test.supports.map_item_test_utils import ( + InputJwtData, + assert_context_is, + assert_retry_item_is, + result_failure, + result_internal_failure, + result_success, +) retry_policy = api_pb2.FunctionRetryPolicy( backoff_coefficient=1.0, @@ -11,170 +19,271 @@ max_delay_ms=500, retries=2, ) - retry_queue: TimestampPriorityQueue - -result_success = api_pb2.GenericResult(status=api_pb2.GenericResult.GENERIC_STATUS_SUCCESS) -result_failure = api_pb2.GenericResult(status=api_pb2.GenericResult.GENERIC_STATUS_FAILURE) -result_internal_failure = api_pb2.GenericResult(status=api_pb2.GenericResult.GENERIC_STATUS_INTERNAL_FAILURE) +manager: _MapItemsManager now_seconds = 1738439812 +count = 10 @pytest.fixture(autouse=True) -def reset_retry_queue(): - global retry_queue +def reset_state(): + global retry_queue, manager retry_queue = TimestampPriorityQueue() - - -@pytest.mark.asyncio -async def test_happy_path(): - # Test putting inputs, and getting sucessful outputs. Verify context has proper values throughout. 
- count = 10 manager = _MapItemsManager( retry_policy=retry_policy, function_call_invocation_type=api_pb2.FunctionCallInvocationType.FUNCTION_CALL_INVOCATION_TYPE_SYNC, retry_queue=retry_queue, ) - # Put inputs - put_items = [api_pb2.FunctionPutInputsItem(idx=i, input=api_pb2.FunctionInput(args=b"{i}")) for i in range(count)] + + +async def add_items(): + put_items = [ + api_pb2.FunctionPutInputsItem(idx=i, input=api_pb2.FunctionInput(args=f"{i}".encode())) for i in range(count) + ] await manager.add_items(put_items) assert len(manager) == count for i in range(count): ctx = manager.get_item_context(i) - assert ctx.state == _MapItemState.SENDING - assert ctx.input == put_items[i].input + assert_context_is(ctx, _MapItemState.SENDING, 0, None, None, f"{i}".encode()) + + +async def handle_put_inputs_response(state: _MapItemState): response_items = [ - api_pb2.FunctionPutInputsResponseItem(idx=i, input_id=f"in-{i}", input_jwt=f"jwt-{i}") for i in range(count) + api_pb2.FunctionPutInputsResponseItem(idx=i, input_id=f"in-{i}", input_jwt=InputJwtData.of(i, 0).to_jwt()) + for i in range(count) ] manager.handle_put_inputs_response(response_items) - for i in range(count): - ctx = manager.get_item_context(i) - assert ctx.state == _MapItemState.WAITING_FOR_OUTPUT - assert await ctx.input_id == response_items[i].input_id - assert await ctx.input_jwt == response_items[i].input_jwt - - # Get outputs - assert [await i.input_jwt for i in manager.get_items_waiting_for_output()] == [f"jwt-{i}" for i in range(count)] - for i in range(count): - output_is_complete = await manager.handle_get_outputs_response( - api_pb2.FunctionGetOutputsItem(idx=i, result=result_success), now_seconds - ) - assert output_is_complete == True - # Verify handle_get_outputs_response removed the item from the manager - assert len(manager) == 0 + if state == _MapItemState.COMPLETE: + assert len(manager) == 0 + else: + for i in range(count): + ctx = manager.get_item_context(i) + assert_context_is(ctx, state, 0, 
f"in-{i}", InputJwtData.of(i, 0), f"{i}".encode()) -@pytest.mark.asyncio -async def test_retry(): - count = 10 - manager = _MapItemsManager( - retry_policy=retry_policy, - function_call_invocation_type=api_pb2.FunctionCallInvocationType.FUNCTION_CALL_INVOCATION_TYPE_SYNC, - retry_queue=retry_queue, - ) - # Put inputs - put_items = [api_pb2.FunctionPutInputsItem(idx=i, input=api_pb2.FunctionInput(args=b"{i}")) for i in range(count)] - await manager.add_items(put_items) - assert len(manager) == count - for i in range(count): - ctx = manager.get_item_context(i) - assert ctx.state == _MapItemState.SENDING - assert ctx.input == put_items[i].input - response_items = [ - api_pb2.FunctionPutInputsResponseItem(idx=i, input_id=f"in-{i}", input_jwt=f"jwt-{i}-0") for i in range(count) +def get_input_jwts_waiting_for_output(retry_count: int): + assert [InputJwtData.from_jwt(input_jwt) for input_jwt in manager.get_input_jwts_waiting_for_output()] == [ + InputJwtData.of(i, retry_count) for i in range(count) ] - manager.handle_put_inputs_response(response_items) - for i in range(count): - ctx = manager.get_item_context(i) - assert ctx.state == _MapItemState.WAITING_FOR_OUTPUT - assert await ctx.input_id == response_items[i].input_id - assert await ctx.input_jwt == response_items[i].input_jwt - # Get outputs - assert [await i.input_jwt for i in manager.get_items_waiting_for_output()] == [f"jwt-{i}-0" for i in range(count)] + +async def handle_get_outputs_response( + result: api_pb2.GenericResult, + state: _MapItemState, + retry_count: int, + output_is_complete: bool, + include_input_jwt: bool = True, +): for i in range(count): - output_is_complete = await manager.handle_get_outputs_response( - api_pb2.FunctionGetOutputsItem(idx=i, result=result_failure), now_seconds + _output_is_complete = await manager.handle_get_outputs_response( + api_pb2.FunctionGetOutputsItem(idx=i, result=result, retry_count=retry_count), now_seconds ) - assert output_is_complete == False - # all inputs 
should still be in the manager, and waiting for retry - assert len(manager) == count - for i in range(count): - assert manager.get_item_context(i).state == _MapItemState.WAITING_TO_RETRY + assert _output_is_complete == output_is_complete + ctx = manager.get_item_context(i) + if state == _MapItemState.COMPLETE: + assert ctx is None + else: + input_jwt = InputJwtData.of(i, retry_count) if include_input_jwt else None + # we add 1 to the retry count because it gets incremented during handling of the response + assert_context_is(ctx, state, retry_count + 1, f"in-{i}", input_jwt, f"{i}".encode()) + if state == _MapItemState.COMPLETE: + assert len(manager) == 0 + else: + assert len(manager) == count + if state == _MapItemState.WAITING_TO_RETRY: + assert len(retry_queue) == count + else: + assert len(retry_queue) == 0 + - # Retry lost input - retry_items: list[api_pb2.FunctionRetryInputsItem] = manager.get_items_for_retry([i for i in range(count)]) - assert len(retry_items) == count +async def prepare_items_for_retry(retry_count: int): + retry_items: list[api_pb2.FunctionRetryInputsItem] = await manager.prepare_items_for_retry( + [i for i in range(count)] + ) for i in range(count): - assert retry_items[i].input_jwt == f"jwt-{i}-0" - assert retry_items[i].input == put_items[i].input - assert retry_items[i].retry_count == 1 + assert_retry_item_is(retry_items[i], InputJwtData.of(i, retry_count - 1), retry_count, f"{i}".encode()) - # Update the jwt to something new. It will be different because the redis entry id will have changed. 
- response_items = [api_pb2.FunctionRetryInputsResponseItem(idx=i, input_jwt=f"jwt-{i}-1") for i in range(count)] + +def handle_retry_response(retry_count: int): + response_items = [InputJwtData.of(i, retry_count).to_jwt() for i in range(count)] manager.handle_retry_response(response_items) for i in range(count): ctx = manager.get_item_context(i) - assert ctx.state == _MapItemState.WAITING_FOR_OUTPUT - assert await ctx.input_id == f"in-{i}" - # Make sure we have the updated jwt and not the old one - assert await ctx.input_jwt == f"jwt-{i}-1" - - # Get outputs - assert [await i.input_jwt for i in manager.get_items_waiting_for_output()] == [f"jwt-{i}-1" for i in range(count)] - for i in range(count): - output_is_complete = await manager.handle_get_outputs_response( - api_pb2.FunctionGetOutputsItem(idx=i, result=result_success), now_seconds + assert_context_is( + ctx, + _MapItemState.WAITING_FOR_OUTPUT, + retry_count, + f"in-{i}", + InputJwtData.of(i, retry_count), + f"{i}".encode(), ) - assert output_is_complete == True - # handle_output should have removed the item from the manager - assert len(manager) == 0 + +async def clear_retry_queue(): + """ + Clear the retry queue. Simulates reading all elements from queue using queue_batch_iterator. 
+ """ + while not retry_queue.empty(): + await retry_queue.get() + + +@pytest.mark.asyncio +async def test_happy_path(): + # pump_inputs - retry count 0 + await add_items() + await handle_put_inputs_response(_MapItemState.WAITING_FOR_OUTPUT) + # get_all_outputs + get_input_jwts_waiting_for_output(0) + await handle_get_outputs_response(result_success, _MapItemState.COMPLETE, 0, True) + + +@pytest.mark.asyncio +async def test_retry(): + # pump_inputs - retry count 0 + await add_items() + await handle_put_inputs_response(_MapItemState.WAITING_FOR_OUTPUT) + + # get_all_outputs - retry count 0 + get_input_jwts_waiting_for_output(0) + await handle_get_outputs_response(result_failure, _MapItemState.WAITING_TO_RETRY, 0, False) + + # retry_inputs - retry count 1 + await prepare_items_for_retry(1) + await clear_retry_queue() + handle_retry_response(1) + + # get_all_outputs - retry count 1 + get_input_jwts_waiting_for_output(1) + await handle_get_outputs_response(result_success, _MapItemState.COMPLETE, 1, True) @pytest.mark.asyncio async def test_retry_lost_input(): + # pump_inputs - retry count 0 + await add_items() + await handle_put_inputs_response(_MapItemState.WAITING_FOR_OUTPUT) + + # get_all_outputs - retry count 0 + get_input_jwts_waiting_for_output(0) + await handle_get_outputs_response(result_internal_failure, _MapItemState.WAITING_TO_RETRY, 0, False) + + # retry_inputs - retry count 1 + await prepare_items_for_retry(1) + await clear_retry_queue() + handle_retry_response(1) + + # get_all_outputs - retry count 1 + get_input_jwts_waiting_for_output(1) + await handle_get_outputs_response(result_success, _MapItemState.COMPLETE, 1, True) + + +@pytest.mark.asyncio +async def test_duplicate_succcesful_outputs(): + # pump_inputs - retry count 0 + await add_items() + await handle_put_inputs_response(_MapItemState.WAITING_FOR_OUTPUT) + + # get_all_outputs - retry count 0 + get_input_jwts_waiting_for_output(0) + await handle_get_outputs_response(result_success, 
_MapItemState.COMPLETE, 0, True) + + # get_all_outputs - retry count 0 (duplicate) + # No items should be waiting for output since we already processed all the outputs + assert manager.get_input_jwts_waiting_for_output() == [] + await handle_get_outputs_response(result_success, _MapItemState.COMPLETE, 0, False) + + +@pytest.mark.asyncio +async def test_duplicate_failed_outputs(): + # pump_inputs - retry count 0 + await add_items() + await handle_put_inputs_response(_MapItemState.WAITING_FOR_OUTPUT) + + # get_all_outputs - retry_count 0 + get_input_jwts_waiting_for_output(0) + await handle_get_outputs_response(result_failure, _MapItemState.WAITING_TO_RETRY, 0, False) + + # get_all_outputs - retry_count 0 (duplicate) + # No items should be waiting for output since we already processed all the outputs + assert manager.get_input_jwts_waiting_for_output() == [] + await handle_get_outputs_response(result_failure, _MapItemState.WAITING_TO_RETRY, 0, False) + + +@pytest.mark.asyncio +async def test_get_outputs_completes_before_put_inputs(): + # There is a race condition where we can send inputs to the server with PutInputs, but before it returns, + # a call to GetOutputs executing in a coroutine fetches the output and completes. Ensure we handle this + # properly. 
manager = _MapItemsManager( retry_policy=retry_policy, function_call_invocation_type=api_pb2.FunctionCallInvocationType.FUNCTION_CALL_INVOCATION_TYPE_SYNC, retry_queue=retry_queue, ) - # Put inputs - put_item = api_pb2.FunctionPutInputsItem(idx=0, input=api_pb2.FunctionInput(args=b"0")) - await manager.add_items([put_item]) - manager.handle_put_inputs_response( - [api_pb2.FunctionPutInputsResponseItem(idx=0, input_id="in-0", input_jwt="jwt-0")] - ) - # Get output that reports a lost input - output_is_complete = await manager.handle_get_outputs_response( - api_pb2.FunctionGetOutputsItem(idx=0, result=result_internal_failure), now_seconds - ) - assert output_is_complete == False - # Assert item is waiting to be retried. Because it is lost, it will be retried immediately. - idx = await retry_queue.get() - assert idx == 0 - # Retry lost input - retry_items: list[api_pb2.FunctionRetryInputsItem] = manager.get_items_for_retry([idx]) - assert len(retry_items) == 1 - retry_item = retry_items[0] - assert retry_item.input_jwt == "jwt-0" - assert retry_item.input == put_item.input - assert retry_item.retry_count == 0 - # The response will have a different input_jwt because the redis entry id will have changed - response_item = api_pb2.FunctionRetryInputsResponseItem(idx=0, input_jwt="jwt-1") - manager.handle_retry_response([response_item]) - ctx = manager.get_item_context(0) - assert ctx.state == _MapItemState.WAITING_FOR_OUTPUT - assert await ctx.input_id == "in-0" - assert await ctx.input_jwt == "jwt-1" - # Get succcessful output - output_is_complete = await manager.handle_get_outputs_response( - api_pb2.FunctionGetOutputsItem(idx=0, result=result_success), now_seconds - ) - assert output_is_complete == True + # pump_inputs - retry_count 0 - send request + await add_items() + + # get_all_outputs - retry_count 0 + # Verify there are no input_jwts waiting for output yet. The input_jwt is returned in the PutInputsResponse, + # which we have not received yet. 
+ assert manager.get_input_jwts_waiting_for_output() == [] + await handle_get_outputs_response(result_success, _MapItemState.COMPLETE, 0, True) + + # pump_inputs - retry_count 0 - receive response + await handle_put_inputs_response(_MapItemState.COMPLETE) + + +@pytest.mark.asyncio +async def test_get_outputs_completes_before_function_retry(): + # pump_inputs - retry_count 0 + await add_items() + await handle_put_inputs_response(_MapItemState.WAITING_FOR_OUTPUT) + + # get_all_outputs - retry_count 0 + get_input_jwts_waiting_for_output(0) + await handle_get_outputs_response(result_failure, _MapItemState.WAITING_TO_RETRY, 0, False) + + # First retry fails + + # retry_inputs - retry_count 1 + await prepare_items_for_retry(1) + await clear_retry_queue() + + # get_all_outputs - retry_count 1 + # The retry call has not returned yet, so there are no input_jwts waiting for output. + assert manager.get_input_jwts_waiting_for_output() == [] + await handle_get_outputs_response(result_failure, _MapItemState.WAITING_TO_RETRY, 1, False, False) + + # retry_inputs - retry_count 1 - handle response + response_items = [InputJwtData.of(i, 1).to_jwt() for i in range(count)] + manager.handle_retry_response(response_items) + for i in range(count): + # Even though this is the response for retry attempt 1, the retry count will be 2 because the above call to + # handle_get_outputs_response would have bumped the count. The jwt will still be for retry attempt 1. + assert_context_is( + manager.get_item_context(i), + _MapItemState.WAITING_FOR_OUTPUT, + 2, + f"in-{i}", + InputJwtData.of(i, 1), + f"{i}".encode(), + ) + + # Second retry succeeds + + # retry_inputs - retry_count 2 + await prepare_items_for_retry(2) + await clear_retry_queue() + + # get_all_outputs - retry_count 2 + # The retry call has not returned yet, so there are no input_jwts waiting for output.
+ assert manager.get_input_jwts_waiting_for_output() == [] + await handle_get_outputs_response(result_success, _MapItemState.COMPLETE, 2, True) + + # retry_inputs - retry_count 2 - handle response + response_items = [InputJwtData.of(i, 2).to_jwt() for i in range(count)] + manager.handle_retry_response(response_items) + assert len(manager) == 0 -# TODO(ryan): Add tests for: -# - Ensure duplicate outputs are ignored -# - If before a call to PutInputs returns, we completely process the output, -# make sure we don't put context for it in the manager. +# TODO: Add test where we try to retry an item before PutInputs has returned. This will test +# that we await the input_jwt correctly before retrying.