Skip to content

Commit

Permalink
Update structure of context and manager
Browse files Browse the repository at this point in the history
  • Loading branch information
rculbertson committed Feb 12, 2025
1 parent 7150ba2 commit d21a467
Show file tree
Hide file tree
Showing 7 changed files with 447 additions and 365 deletions.
3 changes: 3 additions & 0 deletions modal/_utils/async_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,9 @@ async def get(self) -> Union[T, None]:
def empty(self) -> bool:
return self._queue.empty()

def __len__(self):
return self._queue.qsize()


async def queue_batch_iterator(
q: Union[asyncio.Queue, TimestampPriorityQueue], max_batch_size=100, debounce_time=0.015
Expand Down
194 changes: 111 additions & 83 deletions modal/parallel_map.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
# Copyright Modal Labs 2024
import asyncio
import enum
import time
import typing
from dataclasses import dataclass
from enum import Enum
from typing import Any, Callable, Optional

from grpclib import GRPCError, Status
Expand All @@ -30,6 +30,7 @@
_process_result,
)
from modal._utils.grpc_utils import retry_transient_errors
from modal._utils.jwt_utils import DecodedJwt
from modal.config import logger
from modal.retries import RetryManager
from modal_proto import api_pb2
Expand Down Expand Up @@ -169,7 +170,9 @@ async def retry_inputs():
async for retriable_idxs in queue_batch_iterator(retry_queue, max_batch_size=MAP_INVOCATION_CHUNK_SIZE):
# For each index, use the context in the manager to create a FunctionRetryInputsItem.
# This will also update the context state to RETRYING.
inputs: list[api_pb2.FunctionRetryInputsItem] = map_items_manager.get_items_for_retry(retriable_idxs)
inputs: list[api_pb2.FunctionRetryInputsItem] = await map_items_manager.prepare_items_for_retry(
retriable_idxs
)
request = api_pb2.FunctionRetryInputsRequest(
function_call_jwt=function_call_jwt,
inputs=inputs,
Expand Down Expand Up @@ -209,11 +212,10 @@ async def get_all_outputs():
last_entry_id = "0-0"

while not have_all_inputs or num_outputs < num_inputs:
await asyncio.sleep(3)
logger.debug(f"Requesting outputs. Have {num_outputs} outputs, {num_inputs} inputs.")
# Get input_jwts of all items in the WAITING_FOR_OUTPUT state.
# The server uses these to track for lost inputs.
input_jwts = [await ctx.input_jwt for ctx in map_items_manager.get_items_waiting_for_output()]
input_jwts = [input_jwt for input_jwt in map_items_manager.get_input_jwts_waiting_for_output()]

request = api_pb2.FunctionGetOutputsRequest(
function_call_id=function_call_id,
Expand Down Expand Up @@ -490,7 +492,7 @@ def main():
)


class _MapItemState(Enum):
class _MapItemState(enum.Enum):
# The input is being sent to the server with a PutInputs request, but the response has not been received yet.
SENDING = 1
# A call to either PutInputs or FunctionRetry has completed, and we are waiting to receive the output.
Expand All @@ -499,6 +501,8 @@ class _MapItemState(Enum):
WAITING_TO_RETRY = 3
# The input is being sent to the server with a FunctionRetry request, but the response has not been received yet.
RETRYING = 4
# The output has been received and was either successful, or failed with no more retries remaining.
COMPLETE = 5


class _MapItemContext:
Expand All @@ -524,31 +528,81 @@ def __init__(self, input: api_pb2.FunctionInput, retry_manager: RetryManager):
self.input_jwt = self._event_loop.create_future()
self.input_id = self._event_loop.create_future()

def set_state_waiting_for_output(self, input_id: str, input_jwt: str):
assert self.state == _MapItemState.SENDING, self.state
self.input_jwt.set_result(input_jwt)
self.input_id.set_result(input_id)
self.state = _MapItemState.WAITING_FOR_OUTPUT
def handle_put_inputs_response(self, item: api_pb2.FunctionPutInputsResponseItem):
self.input_jwt.set_result(item.input_jwt)
self.input_id.set_result(item.input_id)
# Set state to WAITING_FOR_OUTPUT only if current state is SENDING. If state is
# RETRYING, WAITING_TO_RETRY, or COMPLETE, then we already got the output.
if self.state == _MapItemState.SENDING:
self.state = _MapItemState.WAITING_FOR_OUTPUT

def set_state_waiting_for_output_after_retry(self, input_jwt: str):
assert self.state == _MapItemState.RETRYING, self.state
self.input_jwt.set_result(input_jwt)
self.state = _MapItemState.WAITING_FOR_OUTPUT
async def handle_get_outputs_response(
self,
item: api_pb2.FunctionGetOutputsItem,
now_seconds: int,
function_call_invocation_type: "api_pb2.FunctionCallInvocationType.ValueType",
retry_queue: TimestampPriorityQueue,
) -> bool:
"""
Processes the output, and determines if it is complete or needs to be retried.
Return True if input state was changed to COMPLETE, otherwise False.
"""
# If the item is already complete, this is a duplicate output and can be ignored.
# If the item's retry count doesn't match our retry count, this is probably an old output.
if self.state == _MapItemState.COMPLETE or item.retry_count != self.retry_manager.retry_count:
# We've already processed this output, so we can skip it.
# This can happen because the worker can sometimes send duplicate outputs.
return False

# retry failed inputs when the function call invocation type is SYNC
if (
item.result.status == api_pb2.GenericResult.GENERIC_STATUS_SUCCESS
or function_call_invocation_type != api_pb2.FUNCTION_CALL_INVOCATION_TYPE_SYNC
):
self.state = _MapItemState.COMPLETE
return True

# Get the retry delay and increment the retry count.
# TODO(ryan): We must call this for lost inputs - even though we will set the retry delay to 0 later -
# because we must increment the retry count. That's awkward, let's come up with something better.
# TODO(ryan): To maintain parity with server-side retries, retrying lost inputs should not count towards
# the retry policy. However we use the retry_count number as a unique identifier on each attempt to:
# 1) ignore duplicate outputs
# 2) ignore late outputs received from previous attempts
# 3) avoid a server race condition between FunctionRetry and GetOutputs that results in deleted input metadata
# For now, lost inputs will count towards the retry policy. But let's address this in another PR, perhaps by
# tracking total attempts and attempts which count towards the retry policy separately.
delay_ms = self.retry_manager.get_delay_ms()

# For system failures on the server, we retry immediately,
# and the failure does not count towards the retry policy.
if item.result.status == api_pb2.GenericResult.GENERIC_STATUS_INTERNAL_FAILURE:
delay_ms = 0

# None means the maximum number of retries has been reached, so output the error
if delay_ms is None:
self.state = _MapItemState.COMPLETE
return True

def set_state_waiting_for_retry(self):
assert self.state == _MapItemState.WAITING_FOR_OUTPUT, self.state
# When we call FunctionRetry, we pass the input_jwt from the previous request,
# which is either the original call to PutInputs, or a previous call to FunctionRetry.
# FunctionRetry then returns a new input_jwt. Each retry produces a new input_jwt
# because it contains an entry_id which changes on every retry.
self.previous_input_jwt = self.input_jwt.result()
# We reset the input_jwt to a new future so it is ready when get_all_outputs awaits it.
self.input_jwt = self._event_loop.create_future()
self.state = _MapItemState.WAITING_TO_RETRY
await retry_queue.put(now_seconds + (delay_ms / 1000), item.idx)

return False

def set_state_retrying(self):
assert self.state == _MapItemState.WAITING_TO_RETRY, self.state
async def prepare_item_for_retry(self) -> api_pb2.FunctionRetryInputsItem:
self.state = _MapItemState.RETRYING
input_jwt = await self.input_jwt
self.input_jwt = self._event_loop.create_future()
return api_pb2.FunctionRetryInputsItem(
input_jwt=input_jwt,
input=self.input,
retry_count=self.retry_manager.retry_count,
)

def handle_retry_response(self, input_jwt: str):
self.input_jwt.set_result(input_jwt)
self.state = _MapItemState.WAITING_FOR_OUTPUT


class _MapItemsManager:
Expand All @@ -574,84 +628,58 @@ async def add_items(self, items: list[api_pb2.FunctionPutInputsItem]):
input=item.input, retry_manager=RetryManager(self._retry_policy)
)

def get_items_for_retry(self, retriable_idxs: list[int]) -> list[api_pb2.FunctionRetryInputsItem]:
items: api_pb2.FunctionRetryInputsItem = []
for retriable_idx in retriable_idxs:
ctx = self._item_context[retriable_idx]
ctx.set_state_retrying()
items.append(
api_pb2.FunctionRetryInputsItem(
input_jwt=ctx.previous_input_jwt,
input=ctx.input,
retry_count=ctx.retry_manager.attempt_count,
)
)
return items
async def prepare_items_for_retry(self, retriable_idxs: list[int]) -> list[api_pb2.FunctionRetryInputsItem]:
return [await self._item_context[idx].prepare_item_for_retry() for idx in retriable_idxs]

def get_items_waiting_for_output(self) -> list[_MapItemContext]:
return [ctx for ctx in self._item_context.values() if ctx.state == _MapItemState.WAITING_FOR_OUTPUT]
def get_input_jwts_waiting_for_output(self) -> list[str]:
"""
Returns a list of input_jwts for inputs that are waiting for output.
"""
# If input_jwt is not done, the call to PutInputs has not completed, so omit it from results.
return [
ctx.input_jwt.result()
for ctx in self._item_context.values()
if ctx.state == _MapItemState.WAITING_FOR_OUTPUT and ctx.input_jwt.done()
]

def _remove_item(self, item_idx: int):
del self._item_context[item_idx]
self._inputs_outstanding.release()

def get_item_context(self, item_idx: int) -> _MapItemContext:
return self._item_context[item_idx]
return self._item_context.get(item_idx)

def handle_put_inputs_response(self, items: list[api_pb2.FunctionPutInputsResponseItem]):
for item in items:
ctx = self._item_context.get(item.idx, None)
# If the context is None, then get_all_outputs() has already
# received a successful output, and deleted the context.
# If the context is None, then get_all_outputs() has already received a successful
# output, and deleted the context. This happens if FunctionGetOutputs completes
# before FunctionPutInputsResponse is received.
if ctx is not None:
ctx.set_state_waiting_for_output(item.input_id, item.input_jwt)

def handle_retry_response(self, items: list[api_pb2.FunctionRetryInputsResponseItem]):
for item in items:
ctx = self._item_context[item.idx]
ctx.set_state_waiting_for_output_after_retry(input_jwt=item.input_jwt)
ctx.handle_put_inputs_response(item)

def handle_retry_response(self, input_jwts: list[str]):
for input_jwt in input_jwts:
decoded_jwt = DecodedJwt.decode_without_verification(input_jwt)
ctx = self._item_context.get(decoded_jwt.payload["idx"], None)
# If the context is None, then get_all_outputs() has already received a successful
# output, and deleted the context. This happens if FunctionGetOutputs completes
# before FunctionRetryInputsResponse is received.
if ctx is not None:
ctx.handle_retry_response(input_jwt)

async def handle_get_outputs_response(self, item: api_pb2.FunctionGetOutputsItem, now_seconds: int) -> bool:
output_is_complete = await self._handle_output(item, now_seconds)
if output_is_complete:
self._remove_item(item.idx)
return output_is_complete

async def _handle_output(self, item: api_pb2.FunctionGetOutputsItem, now_seconds: int) -> bool:
"""
Determines if an output is complete or needs to be retried.
If complete, we remove the input from the manager, and return True.
Otherwise we place it on the retry queue, and return False.
"""
ctx = self._item_context.get(item.idx, None)
if ctx is None:
# We've already processed this output, so we can skip it.
# This can happen because the worker can sometimes send duplicate outputs.
return False

# retry failed inputs when the function call invocation type is SYNC
if (
item.result.status == api_pb2.GenericResult.GENERIC_STATUS_SUCCESS
or self.function_call_invocation_type != api_pb2.FUNCTION_CALL_INVOCATION_TYPE_SYNC
):
return True

# For system failures on the server, we retry immediately,
# and the failure does not count towards the retry policy.
delay_ms = (
0
if item.result.status == api_pb2.GenericResult.GENERIC_STATUS_INTERNAL_FAILURE
else ctx.retry_manager.get_delay_ms()
output_is_complete = await ctx.handle_get_outputs_response(
item, now_seconds, self.function_call_invocation_type, self._retry_queue
)

# None means the maximum number of retries has been reached, so output the error
if delay_ms is None:
return True

ctx.set_state_waiting_for_retry()
await self._retry_queue.put(now_seconds + (delay_ms / 1000), item.idx)
return False
if output_is_complete:
self._remove_item(item.idx)
return output_is_complete

def __len__(self):
return len(self._item_context)
8 changes: 4 additions & 4 deletions modal/retries.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,19 +116,19 @@ class RetryManager:

def __init__(self, retry_policy: api_pb2.FunctionRetryPolicy):
self.retry_policy = retry_policy
self.attempt_count = 0
self.retry_count = 0

def get_delay_ms(self) -> Union[float, None]:
"""
Returns the delay in milliseconds before the next retry, or None
if the maximum number of retries has been reached.
"""
self.attempt_count += 1
self.retry_count += 1

if self.attempt_count > self.retry_policy.retries:
if self.retry_count > self.retry_policy.retries:
return None

return self._retry_delay_ms(self.attempt_count, self.retry_policy)
return self._retry_delay_ms(self.retry_count, self.retry_policy)

@staticmethod
def _retry_delay_ms(attempt_count: int, retry_policy: api_pb2.FunctionRetryPolicy) -> float:
Expand Down
10 changes: 1 addition & 9 deletions modal_proto/api.proto
Original file line number Diff line number Diff line change
Expand Up @@ -1674,16 +1674,8 @@ message FunctionRetryInputsRequest {
repeated FunctionRetryInputsItem inputs = 2;
}

// TODO(ryan): We shouldn't need to pass back the idx value, since the client can
// get it from the input_jwt. But I wasn't able to get the client to see authlib,
// so couldn't decode it.
message FunctionRetryInputsResponseItem {
uint32 idx = 1;
string input_jwt = 2;
}

message FunctionRetryInputsResponse {
repeated FunctionRetryInputsResponseItem items = 1;
repeated string input_jwts = 1;
}

message FunctionRetryPolicy {
Expand Down
12 changes: 4 additions & 8 deletions test/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -1072,7 +1072,7 @@ async def FunctionMap(self, stream):
api_pb2.FunctionPutInputsResponseItem(
idx=self.fcidx,
input_id=input_id,
input_jwt=encode_input_jwt(self.fcidx, input_id, function_call_id),
input_jwt=encode_input_jwt(self.fcidx, input_id, function_call_id, self.next_entry_id()),
)
)

Expand All @@ -1089,7 +1089,7 @@ async def FunctionRetryInputs(self, stream):
request: api_pb2.FunctionRetryInputsRequest = await stream.recv_message()
function_id, function_call_id = decode_function_call_jwt(request.function_call_jwt)
function_call_inputs = self.client_calls.setdefault(function_call_id, [])
response_items = []
input_jwts = []
for item in request.inputs:
if item.input.WhichOneof("args_oneof") == "args":
args, kwargs = deserialize(item.input.args, None)
Expand All @@ -1098,13 +1098,9 @@ async def FunctionRetryInputs(self, stream):
self.n_inputs += 1
idx, input_id, function_call_id, entry_id = decode_input_jwt(item.input_jwt)
entry_id = self.next_entry_id()
response_items.append(
api_pb2.FunctionRetryInputsResponseItem(
idx=idx, input_jwt=encode_input_jwt(idx, input_id, function_call_id, entry_id)
)
)
input_jwts.append(encode_input_jwt(idx, input_id, function_call_id, entry_id))
function_call_inputs.append(((idx, input_id), (args, kwargs)))
await stream.send_message(api_pb2.FunctionRetryInputsResponse(items=response_items))
await stream.send_message(api_pb2.FunctionRetryInputsResponse(input_jwts=input_jwts))

async def FunctionPutInputs(self, stream):
request: api_pb2.FunctionPutInputsRequest = await stream.recv_message()
Expand Down
Loading

0 comments on commit d21a467

Please sign in to comment.