diff --git a/modal/_container_entrypoint.py b/modal/_container_entrypoint.py index bba1fed06..39b2e26a0 100644 --- a/modal/_container_entrypoint.py +++ b/modal/_container_entrypoint.py @@ -2,6 +2,7 @@ # ruff: noqa: E402 import os +from modal._runtime import execution_context from modal._runtime.user_code_imports import Service, import_class_service, import_single_function_service telemetry_socket = os.environ.get("MODAL_TELEMETRY_SOCKET") @@ -428,21 +429,22 @@ def main(container_args: api_pb2.ContainerArguments, client: Client): param_args = () param_kwargs = {} - if function_def.is_class: - service = import_class_service( - function_def, - ser_cls, - param_args, - param_kwargs, - ) - else: - service = import_single_function_service( - function_def, - ser_cls, - ser_fun, - param_args, - param_kwargs, - ) + with execution_context._import_context(): + if function_def.is_class: + service = import_class_service( + function_def, + ser_cls, + param_args, + param_kwargs, + ) + else: + service = import_single_function_service( + function_def, + ser_cls, + ser_fun, + param_args, + param_kwargs, + ) # If the cls/function decorator was applied in local scope, but the app is global, we can look it up if service.app is not None: diff --git a/modal/_runtime/execution_context.py b/modal/_runtime/execution_context.py index 95f480699..d3592ac41 100644 --- a/modal/_runtime/execution_context.py +++ b/modal/_runtime/execution_context.py @@ -1,4 +1,5 @@ # Copyright Modal Labs 2024 +from contextlib import contextmanager from contextvars import ContextVar from typing import Callable, Optional @@ -87,3 +88,15 @@ def _reset_current_context_ids(): _current_input_id: ContextVar = ContextVar("_current_input_id") _current_function_call_id: ContextVar = ContextVar("_current_function_call_id") + +_is_currently_importing = False # we set this to True while a container is importing user code + + +@contextmanager +def _import_context(): + global _is_currently_importing + _is_currently_importing = True + try: + yield + finally: + _is_currently_importing = False diff --git a/modal/runner.py b/modal/runner.py index 7c773d892..3b1b92ae4 100644 --- a/modal/runner.py +++ b/modal/runner.py @@ -12,6 +12,7 @@ from grpclib import GRPCError, Status from synchronicity.async_wrap import asynccontextmanager +import modal._runtime.execution_context import modal_proto.api_pb2 from modal_proto import api_pb2 @@ -19,7 +20,6 @@ from ._object import _get_environment_name, _Object from ._pty import get_pty_info from ._resolver import Resolver -from ._runtime.execution_context import is_local from ._traceback import print_server_warnings, traceback_contains_remote_call from ._utils.async_utils import TaskContext, gather_cancel_on_exc, synchronize_api from ._utils.deprecation import deprecation_error @@ -262,12 +262,9 @@ async def _run_app( if environment_name is None: environment_name = typing.cast(str, config.get("environment")) - if not is_local(): - raise InvalidError( - "Can not run an app from within a container." - " Are you calling app.run() directly?" - " Consider using the `modal run` shell command." - ) + if modal._runtime.execution_context._is_currently_importing: + raise InvalidError("Can not run an app in global scope within a container") + if app._running_app: raise InvalidError( "App is already running and can't be started again.\n" diff --git a/test/container_test.py b/test/container_test.py index c153b83a4..6beb20617 100644 --- a/test/container_test.py +++ b/test/container_test.py @@ -536,11 +536,8 @@ def test_grpc_failure(servicer, event_loop): def test_missing_main_conditional(servicer, capsys): _run_container(servicer, "test.supports.missing_main_conditional", "square") output = capsys.readouterr() - assert "Can not run an app from within a container" in output.err - + assert "Can not run an app in global scope within a container" in output.err assert servicer.task_result.status == api_pb2.GenericResult.GENERIC_STATUS_FAILURE - assert "modal run" in servicer.task_result.traceback - exc = deserialize(servicer.task_result.data, None) assert isinstance(exc, InvalidError)