Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use dill for serialization #121

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 14 additions & 10 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -203,20 +203,24 @@ async def transform(msg):

### Serialization

Dispatch uses the [pickle] library to serialize coroutines.
Serialization of coroutines is enabled by a CPython extension that
exposes internal details about stack frames.

[pickle]: https://docs.python.org/3/library/pickle.html
Dispatch then uses the [dill] library to serialize these stack frames.
Note that `dill` is an extension of [pickle] from the standard library.

Serialization of coroutines is enabled by a CPython extension.
[dill]: https://dill.readthedocs.io/en/latest/
[pickle]: https://docs.python.org/3/library/pickle.html

The user must ensure that the contents of their stack frames are
serializable. That is, users should avoid using variables inside
coroutines that cannot be pickled.

If a pickle error is encountered, serialization tracing can be enabled
with the `DISPATCH_TRACE=1` environment variable to debug the issue. The
stacks of coroutines and generators will be printed to stdout before
the pickle library attempts serialization.
serializable. That is, users should either avoid using variables inside
coroutines that cannot be pickled, or should wrap them in a container
that is serializable.

If a serialization error is encountered, tracing can be enabled with the
`DISPATCH_TRACE=1` environment variable. The object graph, and the stacks
of coroutines and generators, will be printed to stdout before serialization
is attempted. This allows users to pinpoint where serialization issues occur.

For help with serialization issues, please submit a [GitHub issue][issues].

Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ dependencies = [
"grpc-stubs >= 1.53.0.5",
"http-message-signatures >= 0.4.4",
"tblib >= 3.0.0",
"dill >= 0.3.8"
]

[project.optional-dependencies]
Expand Down
65 changes: 50 additions & 15 deletions src/dispatch/scheduler.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
import logging
import os
import pickle
import sys
from dataclasses import dataclass
from typing import Any, Callable, Protocol, TypeAlias

import dill # type: ignore
import dill.detect # type: ignore

from dispatch.coroutine import Gather
from dispatch.error import IncompatibleStateError
from dispatch.experimental.durable.function import DurableCoroutine, DurableGenerator
Expand Down Expand Up @@ -144,9 +148,12 @@ def __repr__(self):
class State:
"""State of the scheduler and the coroutines it's managing."""

version: str
interpreter_version: str
scheduler_version: int

suspended: dict[CoroutineID, Coroutine]
ready: list[Coroutine]

next_coroutine_id: int
next_call_id: int

Expand All @@ -155,6 +162,11 @@ class State:
outstanding_calls: int


# Version of the scheduler and its state. Increment this when a breaking
# change is introduced.
SCHEDULER_VERSION = 1


class OneShotScheduler:
"""Scheduler for local coroutines.

Expand All @@ -165,7 +177,7 @@ class OneShotScheduler:

__slots__ = (
"entry_point",
"version",
"interpreter_version",
"poll_min_results",
"poll_max_results",
"poll_max_wait_seconds",
Expand All @@ -174,7 +186,7 @@ class OneShotScheduler:
def __init__(
self,
entry_point: Callable,
version: str = sys.version,
interpreter_version: str = sys.version,
poll_min_results: int = 1,
poll_max_results: int = 10,
poll_max_wait_seconds: int | None = None,
Expand All @@ -184,9 +196,9 @@ def __init__(
Args:
entry_point: Entry point for the main coroutine.

version: Version string to attach to scheduler/coroutine state.
If the scheduler sees a version mismatch, it will respond to
Dispatch with an INCOMPATIBLE_STATE status code.
interpreter_version: Version string to attach to scheduler /
coroutine state. If the scheduler sees a version mismatch it will
respond to Dispatch with an INCOMPATIBLE_STATE status code.

poll_min_results: Minimum number of call results to wait for before
coroutine execution should continue. Dispatch waits until this
Expand All @@ -200,14 +212,15 @@ def __init__(
while waiting for call results. Optional.
"""
self.entry_point = entry_point
self.version = version
self.interpreter_version = interpreter_version
self.poll_min_results = poll_min_results
self.poll_max_results = poll_max_results
self.poll_max_wait_seconds = poll_max_wait_seconds
logger.debug(
"booting coroutine scheduler with entry point '%s' version '%s'",
"booting coroutine scheduler with entry point '%s', interpreter version '%s', scheduler version %d",
entry_point.__qualname__,
version,
self.interpreter_version,
SCHEDULER_VERSION,
)

def run(self, input: Input) -> Output:
Expand All @@ -231,7 +244,8 @@ def _init_state(self, input: Input) -> State:
raise ValueError("entry point is not a @dispatch.function")

return State(
version=sys.version,
interpreter_version=sys.version,
scheduler_version=SCHEDULER_VERSION,
suspended={},
ready=[Coroutine(id=0, parent_id=None, coroutine=main)],
next_coroutine_id=1,
Expand All @@ -245,15 +259,22 @@ def _rebuild_state(self, input: Input):
"resuming scheduler with %d bytes of state", len(input.coroutine_state)
)
try:
state = pickle.loads(input.coroutine_state)
state = deserialize(input.coroutine_state)
if not isinstance(state, State):
raise ValueError("invalid state")
if state.version != self.version:

if state.interpreter_version != self.interpreter_version:
raise ValueError(
f"interpreter version mismatch: '{state.interpreter_version}' vs. current '{self.interpreter_version}'"
)
if state.scheduler_version != SCHEDULER_VERSION:
raise ValueError(
f"version mismatch: '{state.version}' vs. current '{self.version}'"
f"scheduler version mismatch: {state.scheduler_version} vs. current {SCHEDULER_VERSION}"
)

return state
except (pickle.PickleError, ValueError) as e:

except (pickle.PickleError, AttributeError, ValueError) as e:
logger.warning("state is incompatible", exc_info=True)
raise IncompatibleStateError from e

Expand Down Expand Up @@ -421,7 +442,7 @@ def _run(self, input: Input) -> Output:
# Serialize coroutines and scheduler state.
logger.debug("serializing state")
try:
serialized_state = pickle.dumps(state)
serialized_state = serialize(state)
except pickle.PickleError as e:
logger.exception("state could not be serialized")
return Output.error(Error.from_exception(e, status=Status.PERMANENT_ERROR))
Expand All @@ -446,6 +467,20 @@ def _run(self, input: Input) -> Output:
)


# Read once at import time: set DISPATCH_TRACE to any non-empty value to
# enable serialization tracing (see serialize() below, which checks this
# flag for truthiness).
TRACE = os.getenv("DISPATCH_TRACE")


def serialize(obj: Any) -> bytes:
    """Serialize *obj* (the scheduler State) to bytes with dill.

    ``byref=True`` asks dill to pickle module-level objects by reference
    rather than by value, keeping the payload small.

    If the ``DISPATCH_TRACE`` environment variable is set (see the
    module-level ``TRACE`` flag), dill's trace output — the object graph
    being pickled — is printed so users can pinpoint which object fails
    to serialize.

    Raises:
        pickle.PickleError: if *obj* contains something dill cannot
            serialize (dill's errors subclass the pickle ones).
    """
    # Local stdlib import keeps this fix self-contained.
    from contextlib import nullcontext

    # Tracing only changes the surrounding context manager; the dumps
    # call itself is written once so the two paths cannot drift apart.
    ctx = dill.detect.trace() if TRACE else nullcontext()
    with ctx:
        return dill.dumps(obj, byref=True)


def deserialize(state: bytes) -> Any:
    """Rebuild an object graph previously produced by ``serialize``.

    NOTE(review): like pickle, ``dill.loads`` can execute arbitrary code
    embedded in *state* — only call this on trusted payloads.
    """
    # No byref flag is needed on load; dill resolves references itself.
    restored = dill.loads(state)
    return restored


def correlation_id(coroutine_id: CoroutineID, call_id: CallID) -> CorrelationID:
    """Pack a coroutine ID and a call ID into one correlation ID.

    The coroutine ID occupies the bits above position 32 and the call ID
    the low 32 bits, so the pair can round-trip through a single integer.
    """
    packed = coroutine_id << 32
    return packed | call_id

Expand Down