Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Serve] Simplify endpoint logic in sky.serve.controller and improve input validation #4043

Open
wants to merge 8 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 5 additions & 12 deletions sky/serve/autoscalers.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

from sky import sky_logging
from sky.serve import constants
from sky.serve import schemas
from sky.serve import serve_state
from sky.serve import serve_utils

Expand Down Expand Up @@ -97,7 +98,7 @@ def update_version(self, version: int, spec: 'service_spec.SkyServiceSpec',
self.update_mode = update_mode

def collect_request_information(
self, request_aggregator_info: Dict[str, Any]) -> None:
self, request_aggregator_info: schemas.RequestAggregator) -> None:
"""Collect request information from aggregator for autoscaling."""
raise NotImplementedError

Expand Down Expand Up @@ -222,17 +223,9 @@ def update_version(self, version: int, spec: 'service_spec.SkyServiceSpec',
self.downscale_counter = 0

def collect_request_information(
self, request_aggregator_info: Dict[str, Any]) -> None:
"""Collect request information from aggregator for autoscaling.

request_aggregator_info should be a dict with the following format:

{
'timestamps': [timestamp1 (float), timestamp2 (float), ...]
}
"""
self.request_timestamps.extend(
request_aggregator_info.get('timestamps', []))
self, request_aggregator_info: schemas.RequestAggregator) -> None:
"""Collect request information from aggregator for autoscaling."""
self.request_timestamps.extend(request_aggregator_info.timestamps)
current_time = time.time()
index = bisect.bisect_left(self.request_timestamps,
current_time - self.qps_window_size)
Expand Down
89 changes: 39 additions & 50 deletions sky/serve/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import threading
import time
import traceback
from typing import Any, Dict, List

import fastapi
from fastapi import responses
Expand All @@ -17,6 +16,7 @@
from sky import sky_logging
from sky.serve import autoscalers
from sky.serve import replica_managers
from sky.serve import schemas
from sky.serve import serve_state
from sky.serve import serve_utils
from sky.utils import common_utils
Expand Down Expand Up @@ -98,12 +98,10 @@ def run(self) -> None:

@self._app.post('/controller/load_balancer_sync')
async def load_balancer_sync(
request: fastapi.Request) -> fastapi.Response:
request_data = await request.json()
# TODO(MaoZiming): Check aggregator type.
request_aggregator: Dict[str, Any] = request_data.get(
'request_aggregator', {})
timestamps: List[int] = request_aggregator.get('timestamps', [])
request: schemas.LoadBalancerRequest) -> fastapi.Response:
request_aggregator = request.request_aggregator
andylizf marked this conversation as resolved.
Show resolved Hide resolved
timestamps = request_aggregator.timestamps

logger.info(f'Received {len(timestamps)} inflight requests.')
self._autoscaler.collect_request_information(request_aggregator)
return responses.JSONResponse(content={
Expand All @@ -113,49 +111,40 @@ async def load_balancer_sync(
status_code=200)

@self._app.post('/controller/update_service')
async def update_service(request: fastapi.Request) -> fastapi.Response:
request_data = await request.json()
try:
version = request_data.get('version', None)
if version is None:
return responses.JSONResponse(
content={'message': 'Error: version is not specified.'},
status_code=400)
update_mode_str = request_data.get(
'mode', serve_utils.DEFAULT_UPDATE_MODE.value)
update_mode = serve_utils.UpdateMode(update_mode_str)
logger.info(f'Update to new version {version} with '
f'update_mode {update_mode}.')
# The yaml with the name latest_task_yaml will be synced
# See sky/serve/core.py::update
latest_task_yaml = serve_utils.generate_task_yaml_file_name(
self._service_name, version)
service = serve.SkyServiceSpec.from_yaml(latest_task_yaml)
logger.info(
f'Update to new version version {version}: {service}')

self._replica_manager.update_version(version,
service,
update_mode=update_mode)
new_autoscaler = autoscalers.Autoscaler.from_spec(
self._service_name, service)
if not isinstance(self._autoscaler, type(new_autoscaler)):
logger.info('Autoscaler type changed to '
f'{type(new_autoscaler)}, updating autoscaler.')
old_autoscaler = self._autoscaler
self._autoscaler = new_autoscaler
self._autoscaler.load_dynamic_states(
old_autoscaler.dump_dynamic_states())
self._autoscaler.update_version(version,
service,
update_mode=update_mode)
return responses.JSONResponse(content={'message': 'Success'},
status_code=200)
except Exception as e: # pylint: disable=broad-except
logger.error(f'Error in update_service: '
f'{common_utils.format_exception(e)}')
return responses.JSONResponse(content={'message': 'Error'},
status_code=500)
async def update_service(
request: schemas.UpdateServiceRequest) -> fastapi.Response:
andylizf marked this conversation as resolved.
Show resolved Hide resolved
version = request.version
update_mode = request.mode
logger.info(f'Update to new version {version} '
f'with update_mode {update_mode}.')

# The yaml with the name latest_task_yaml will be synced
# See sky/serve/core.py::update
latest_task_yaml = serve_utils.generate_task_yaml_file_name(
self._service_name, version)
service = serve.SkyServiceSpec.from_yaml(latest_task_yaml)
logger.info(f'Update to new version {version}: {service}')

self._replica_manager.update_version(version,
service,
update_mode=update_mode)

new_autoscaler = autoscalers.Autoscaler.from_spec(
self._service_name, service)
if not isinstance(self._autoscaler, type(new_autoscaler)):
logger.info(f'Autoscaler type changed to '
f'{type(new_autoscaler)}, '
f'updating autoscaler.')
andylizf marked this conversation as resolved.
Show resolved Hide resolved
old_autoscaler = self._autoscaler
self._autoscaler = new_autoscaler
self._autoscaler.load_dynamic_states(
old_autoscaler.dump_dynamic_states())
self._autoscaler.update_version(version,
service,
update_mode=update_mode)

return responses.JSONResponse(content={'message': 'Success'},
status_code=200)

threading.Thread(target=self._run_autoscaler).start()

Expand Down
26 changes: 26 additions & 0 deletions sky/serve/schemas.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
"""This file defines the schemas for the requests to the controller.
"""

from typing import List

import pydantic

from sky.serve import serve_utils


class RequestAggregator(pydantic.BaseModel):
timestamps: List[float]
cblmemo marked this conversation as resolved.
Show resolved Hide resolved


class LoadBalancerRequest(pydantic.BaseModel):
request_aggregator: RequestAggregator


class UpdateServiceRequest(pydantic.BaseModel):
version: int
mode: serve_utils.UpdateMode = serve_utils.DEFAULT_UPDATE_MODE


class TerminateReplicaRequest(pydantic.BaseModel):
replica_id: int
purge: bool
13 changes: 11 additions & 2 deletions sky/serve/serve_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def error_type(self) -> Type[Exception]:
return _SIGNAL_TO_ERROR[self]


class UpdateMode(enum.Enum):
class UpdateMode(str, enum.Enum):
"""Update mode for updating a service."""
ROLLING = 'rolling'
BLUE_GREEN = 'blue_green'
Expand Down Expand Up @@ -291,12 +291,21 @@ def update_service_encoded(service_name: str, version: int, mode: str) -> str:
if service_status is None:
raise ValueError(f'Service {service_name!r} does not exist.')
controller_port = service_status['controller_port']

try:
update_mode = UpdateMode(mode)
except ValueError:
with ux_utils.print_exception_no_traceback():
# pylint: disable=raise-missing-from
andylizf marked this conversation as resolved.
Show resolved Hide resolved
raise ValueError(f'Invalid update mode: {mode}. '
f'Please specify one of {UpdateMode}.')

resp = requests.post(
_CONTROLLER_URL.format(CONTROLLER_PORT=controller_port) +
'/controller/update_service',
json={
'version': version,
'mode': mode,
'mode': update_mode.value,
})
if resp.status_code == 404:
raise ValueError('The service is up-ed in an old version and does not '
Expand Down
Loading