diff --git a/projects/fal/src/fal/api.py b/projects/fal/src/fal/api.py index 607506dd..3b96fcca 100644 --- a/projects/fal/src/fal/api.py +++ b/projects/fal/src/fal/api.py @@ -76,9 +76,7 @@ class Host(Generic[ArgsT, ReturnT]): is executed.""" _SUPPORTED_KEYS: ClassVar[frozenset[str]] = frozenset() - _GATEWAY_KEYS: ClassVar[frozenset[str]] = frozenset( - {"serve", "exposed_port", "max_concurrency"} - ) + _GATEWAY_KEYS: ClassVar[frozenset[str]] = frozenset({"serve", "exposed_port"}) def __post_init__(self): assert not self._SUPPORTED_KEYS.intersection( @@ -118,7 +116,6 @@ def register( self, func: Callable[ArgsT, ReturnT], options: Options, - max_concurrency: int | None = None, application_name: str | None = None, application_auth_mode: Literal["public", "shared", "private"] | None = None, metadata: dict[str, Any] | None = None, @@ -311,6 +308,7 @@ class FalServerlessHost(Host): { "machine_type", "keep_alive", + "max_concurrency", "max_multiplexing", "setup_function", "metadata", @@ -341,7 +339,6 @@ def register( self, func: Callable[ArgsT, ReturnT], options: Options, - max_concurrency: int | None = None, application_name: str | None = None, application_auth_mode: Literal["public", "shared", "private"] | None = None, metadata: dict[str, Any] | None = None, @@ -354,9 +351,8 @@ def register( "machine_type", FAL_SERVERLESS_DEFAULT_MACHINE_TYPE ) keep_alive = options.host.get("keep_alive", FAL_SERVERLESS_DEFAULT_KEEP_ALIVE) - max_multiplexing = options.host.get( - "max_multiplexing", FAL_SERVERLESS_DEFAULT_MAX_MULTIPLEXING - ) + max_concurrency = options.host.get("max_concurrency") + max_multiplexing = options.host.get("max_multiplexing") base_image = options.host.get("_base_image", None) scheduler = options.host.get("_scheduler", None) scheduler_options = options.host.get("_scheduler_options", None) @@ -370,6 +366,7 @@ def register( scheduler=scheduler, scheduler_options=scheduler_options, max_multiplexing=max_multiplexing, + max_concurrency=max_concurrency, ) partial_func = _prepare_partial_func(func) @@ -394,7 +391,6 @@ def register( application_name=application_name, application_auth_mode=application_auth_mode, machine_requirements=machine_requirements, - max_concurrency=max_concurrency, metadata=metadata, ): for log in partial_result.logs: @@ -419,9 +415,8 @@ def run( "machine_type", FAL_SERVERLESS_DEFAULT_MACHINE_TYPE ) keep_alive = options.host.get("keep_alive", FAL_SERVERLESS_DEFAULT_KEEP_ALIVE) - max_multiplexing = options.host.get( - "max_multiplexing", FAL_SERVERLESS_DEFAULT_MAX_MULTIPLEXING - ) + max_concurrency = options.host.get("max_concurrency") + max_multiplexing = options.host.get("max_multiplexing") base_image = options.host.get("_base_image", None) scheduler = options.host.get("_scheduler", None) scheduler_options = options.host.get("_scheduler_options", None) @@ -436,6 +431,7 @@ def run( scheduler=scheduler, scheduler_options=scheduler_options, max_multiplexing=max_multiplexing, + max_concurrency=max_concurrency, ) return_value = _UNSET diff --git a/projects/fal/src/fal/cli.py b/projects/fal/src/fal/cli.py index dd3c2ee8..38709dd9 100644 --- a/projects/fal/src/fal/cli.py +++ b/projects/fal/src/fal/cli.py @@ -17,7 +17,7 @@ from fal.logging.isolate import IsolateLogPrinter from fal.logging.trace import get_tracer from fal.rest_client import REST_CLIENT -from fal.sdk import KeyScope +from fal.sdk import AliasInfo, KeyScope from isolate.logs import Log, LogLevel, LogSource from rich.table import Table @@ -280,13 +280,11 @@ def register_application( "Must expose port 8080 for now. This will be configurable in the future." ) - max_concurrency = gateway_options.get("max_concurrency") id = host.register( func=isolated_function.func, options=isolated_function.options, application_name=alias, application_auth_mode=auth_mode, - max_concurrency=max_concurrency, metadata={}, ) @@ -342,57 +340,77 @@ def alias_cli(ctx, host: str, port: str): ctx.obj = api.FalServerlessClient(f"{host}:{port}") +def _alias_table(aliases: list[AliasInfo]): + table = Table(title="Function Aliases") + table.add_column("Alias") + table.add_column("Revision") + table.add_column("Auth") + table.add_column("Max Concurrency") + table.add_column("Max Multiplexing") + table.add_column("Keep Alive") + + for app_alias in aliases: + table.add_row( + app_alias.alias, + app_alias.revision, + app_alias.auth_mode, + str(app_alias.max_concurrency), + str(app_alias.max_multiplexing), + str(app_alias.keep_alive), + ) + + return table + + @alias_cli.command("list") @click.pass_obj def alias_list(client: api.FalServerlessClient): with client.connect() as connection: - table = Table(title="Function Aliases") - table.add_column("Alias") - table.add_column("Revision") - table.add_column("Auth") - table.add_column("Max Concurrency") - - for app_alias in connection.list_aliases(): - table.add_row( - app_alias.alias, - app_alias.revision, - app_alias.auth_mode, - str(app_alias.max_concurrency), - ) + aliases = connection.list_aliases() + table = _alias_table(aliases) console.print(table) -@alias_cli.command("scale") -@click.argument("alias", required=True) -@click.argument("max_concurrency", required=True, type=int) -@click.pass_obj -def alias_scale(client: api.FalServerlessClient, alias: str, max_concurrency: int): - with client.connect() as connection: - connection.scale(application_name=alias, max_concurrency=max_concurrency) - - @alias_cli.command("update") @click.argument("alias", required=True) -@click.option("--keep-alive", type=int) -@click.option("--max-multiplexing", type=int) +@click.option("--keep-alive", "-k", type=int) +@click.option("--max-multiplexing", "-m", type=int) +@click.option("--max-concurrency", "-c", type=int) @click.pass_obj def alias_update( client: api.FalServerlessClient, alias: str, keep_alive: int | None, max_multiplexing: int | None, + max_concurrency: int | None, ): with client.connect() as connection: - if not (keep_alive or max_multiplexing): + if keep_alive is None and max_multiplexing is None and max_concurrency is None: console.log("No parameters for update were provided, ignoring.") return - connection.update_application( + alias_info = connection.update_application( application_name=alias, keep_alive=keep_alive, max_multiplexing=max_multiplexing, + max_concurrency=max_concurrency, ) + table = _alias_table([alias_info]) + + console.print(table) + + +@alias_cli.command("scale") +@click.argument("alias", required=True) +@click.argument("max_concurrency", required=True, type=int) +def alias_scale(alias: str, max_concurrency: int): + alias_update.callback( + alias=alias, + keep_alive=None, + max_multiplexing=None, + max_concurrency=max_concurrency, + ) # type: ignore ##### Secrets group ##### diff --git a/projects/fal/src/fal/sdk.py b/projects/fal/src/fal/sdk.py index 9a89f268..7382e7e8 100644 --- a/projects/fal/src/fal/sdk.py +++ b/projects/fal/src/fal/sdk.py @@ -184,7 +184,9 @@ class AliasInfo: alias: str revision: str auth_mode: str + keep_alive: int max_concurrency: int + max_multiplexing: int @dataclass @@ -258,7 +260,9 @@ def _from_grpc_alias_info(message: isolate_proto.AliasInfo) -> AliasInfo: alias=message.alias, revision=message.revision, auth_mode=auth_mode, + keep_alive=message.keep_alive, max_concurrency=message.max_concurrency, + max_multiplexing=message.max_multiplexing, ) @@ -306,7 +310,8 @@ class MachineRequirements: exposed_port: int | None = None scheduler: str | None = None scheduler_options: dict[str, Any] | None = None - max_multiplexing: int = FAL_SERVERLESS_DEFAULT_MAX_MULTIPLEXING + max_concurrency: int | None = None + max_multiplexing: int | None = None @dataclass @@ -386,7 +391,6 @@ def register( application_name: str | None = None, application_auth_mode: Literal["public", "private", "shared"] | None = None, *, - max_concurrency: int | None = None, serialization_method: str = _DEFAULT_SERIALIZATION_METHOD, machine_requirements: MachineRequirements | None = None, metadata: dict[str, Any] | None = None, @@ -402,6 +406,7 @@ def register( scheduler_options=to_struct( machine_requirements.scheduler_options or {} ), + max_concurrency=machine_requirements.max_concurrency, max_multiplexing=machine_requirements.max_multiplexing, ) else: @@ -423,7 +428,6 @@ def register( function=wrapped_function, environments=environments, machine_requirements=wrapped_requirements, - max_concurrency=max_concurrency, application_name=application_name, auth_mode=auth_mode, metadata=struct_metadata, @@ -432,24 +436,25 @@ def register( yield from_grpc(partial_result) def scale(self, application_name: str, max_concurrency: int | None = None) -> None: - request = isolate_proto.ScaleApplicationRequest( - application_name=application_name, - max_concurrency=max_concurrency, - ) - self.stub.ScaleApplication(request) + raise NotImplementedError def update_application( self, application_name: str, keep_alive: int | None = None, max_multiplexing: int | None = None, - ) -> None: + max_concurrency: int | None = None, + ) -> AliasInfo: request = isolate_proto.UpdateApplicationRequest( application_name=application_name, keep_alive=keep_alive, max_multiplexing=max_multiplexing, + max_concurrency=max_concurrency, + ) + res: isolate_proto.UpdateApplicationResult = self.stub.UpdateApplication( + request ) - self.stub.UpdateApplication(request) + return from_grpc(res.alias_info) def run( self, @@ -471,6 +476,7 @@ def run( scheduler_options=to_struct( machine_requirements.scheduler_options or {} ), + max_concurrency=machine_requirements.max_concurrency, max_multiplexing=machine_requirements.max_multiplexing, ) else: diff --git a/projects/fal/tests/test_apps.py b/projects/fal/tests/test_apps.py index df9e057c..dc5783f7 100644 --- a/projects/fal/tests/test_apps.py +++ b/projects/fal/tests/test_apps.py @@ -29,6 +29,7 @@ class Output(BaseModel): keep_alive=60, machine_type="S", serve=True, + max_concurrency=1, ) def addition_app(input: Input) -> Output: print("starting...") @@ -41,7 +42,6 @@ def addition_app(input: Input) -> Output: app_alias = addition_app.host.register( func=addition_app.func, options=addition_app.options, - max_concurrency=1, ) user_id = _get_user_id() yield f"{user_id}-{app_alias}"