diff --git a/modal/sandbox.py b/modal/sandbox.py index ea460a4e7..e5f575673 100644 --- a/modal/sandbox.py +++ b/modal/sandbox.py @@ -23,6 +23,7 @@ from ._utils.deprecation import deprecation_error from ._utils.grpc_utils import retry_transient_errors from ._utils.mount_utils import validate_network_file_systems, validate_volumes +from .app import _App from .client import _Client from .config import config from .container_process import _ContainerProcess @@ -58,6 +59,7 @@ class _Sandbox(_Object, type_prefix="sb"): _stdin: _StreamWriter _task_id: Optional[str] = None _tunnels: Optional[dict[int, Tunnel]] = None + _enable_snapshot: bool = False @staticmethod def _new( @@ -81,6 +83,7 @@ def _new( unencrypted_ports: Sequence[int] = [], proxy: Optional[_Proxy] = None, _experimental_scheduler_placement: Optional[SchedulerPlacement] = None, + enable_snapshot: bool = False, ) -> "_Sandbox": """mdmd:hidden""" @@ -177,6 +180,7 @@ async def _load(self: _Sandbox, resolver: Resolver, _existing_object_id: Optiona open_ports=api_pb2.PortSpecs(ports=open_ports), network_access=network_access, proxy_id=(proxy.object_id if proxy else None), + enable_snapshot=enable_snapshot, ) # Note - `resolver.app_id` will be `None` for app-less sandboxes @@ -224,13 +228,13 @@ async def create( unencrypted_ports: Sequence[int] = [], # Reference to a Modal Proxy to use in front of this Sandbox. proxy: Optional[_Proxy] = None, + # Enable memory snapshots. + enable_snapshot: bool = False, _experimental_scheduler_placement: Optional[ SchedulerPlacement ] = None, # Experimental controls over fine-grained scheduling (alpha). client: Optional[_Client] = None, ) -> "_Sandbox": - from .app import _App - environment_name = _get_environment_name(environment_name) # If there are no entrypoint args, we'll sleep forever so that the sandbox will stay @@ -261,7 +265,9 @@ async def create( unencrypted_ports=unencrypted_ports, proxy=proxy, _experimental_scheduler_placement=_experimental_scheduler_placement, + enable_snapshot=enable_snapshot, ) + obj._enable_snapshot = enable_snapshot app_id: Optional[str] = None app_client: Optional[_Client] = None @@ -534,6 +540,33 @@ async def exec( by_line = bufsize == 1 return _ContainerProcess(resp.exec_id, self._client, stdout=stdout, stderr=stderr, text=text, by_line=by_line) + async def snapshot(self) -> str: + if not self._enable_snapshot: + raise ValueError( + "Memory snapshots are not supported for this sandbox. To enable memory snapshots, " + "set `enable_snapshot=True` when creating the sandbox." + ) + req = api_pb2.SandboxSnapshotRequest(sandbox_id=self.object_id) + resp = await retry_transient_errors(self._client.stub.SandboxSnapshot, req) + snapshot_id = resp.snapshot_id + wait_req = api_pb2.SandboxSnapshotWaitRequest(snapshot_id=resp.snapshot_id, timeout=55.0) + resp = await retry_transient_errors(self._client.stub.SandboxSnapshotWait, wait_req) + if resp.result.status != api_pb2.GenericResult.GENERIC_STATUS_SUCCESS: + raise ExecutionError(resp.result.exception) + return snapshot_id + + @staticmethod + async def from_snapshot(snapshot_id: str, client: Optional[_Client] = None): + client = client or await _Client.from_env() + + req = api_pb2.SandboxRestoreRequest(snapshot_id=snapshot_id) + resp: api_pb2.SandboxRestoreResponse = await retry_transient_errors(client.stub.SandboxRestore, req) + sandbox = await _Sandbox.from_id(resp.sandbox_id, client) + wait_req = api_pb2.SandboxWaitRequest(sandbox_id=resp.sandbox_id, timeout=0) + resp = await retry_transient_errors(client.stub.SandboxWait, wait_req) + print("from_snapshot resp", resp) + return sandbox + @overload async def open( self, diff --git a/modal_proto/api.proto b/modal_proto/api.proto index ec04754a3..38686a281 100644 --- a/modal_proto/api.proto +++ b/modal_proto/api.proto @@ -2318,9 +2318,7 @@ message SandboxListResponse { } message SandboxRestoreRequest { - string app_id = 1 [ (modal.options.audit_target_attr) = true ]; - string snapshot_id = 2; - string environment_name = 3; + string snapshot_id = 1; } message SandboxRestoreResponse {