Commit

incorporate overprovision, store location in db
cblmemo committed Jan 31, 2025
1 parent cfbe767 commit 32c55be
Showing 6 changed files with 117 additions and 51 deletions.
28 changes: 19 additions & 9 deletions sky/serve/autoscalers.py
@@ -134,6 +134,7 @@ def __init__(self, service_name: str,
         self.min_replicas: int = spec.min_replicas
         self.max_replicas: int = (spec.max_replicas if spec.max_replicas
                                   is not None else spec.min_replicas)
+        self.num_overprovision: Optional[int] = spec.num_overprovision
         # Target number of replicas is initialized to min replicas
         self.target_num_replicas: int = spec.min_replicas
         self.latest_version: int = constants.INITIAL_VERSION
@@ -143,6 +144,12 @@ def __init__(self, service_name: str,
         self.latest_version_ever_ready: int = self.latest_version - 1
         self.update_mode = serve_utils.DEFAULT_UPDATE_MODE
 
+    def get_final_target_num_replicas(self) -> int:
+        """Get the final target number of replicas."""
+        if self.num_overprovision is None:
+            return self.target_num_replicas
+        return self.target_num_replicas + self.num_overprovision
+
     def _calculate_target_num_replicas(self) -> int:
         """Calculate target number of replicas."""
         raise NotImplementedError
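For context, the new method reduces to simple arithmetic: the final target is the demand-driven target plus a fixed cushion. A standalone sketch (not the actual class):

```python
from typing import Optional

def final_target(target_num_replicas: int,
                 num_overprovision: Optional[int]) -> int:
    """Mirrors get_final_target_num_replicas() above."""
    if num_overprovision is None:
        return target_num_replicas
    return target_num_replicas + num_overprovision

assert final_target(3, None) == 3  # knob unset: behavior unchanged
assert final_target(3, 2) == 5     # 3 for traffic, plus 2 warm spares
```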
@@ -207,7 +214,7 @@ def get_decision_interval(self) -> int:
         0, to make the service scale faster when the service is not running.
         This will happen when min_replicas = 0 and no traffic.
         """
-        if self.target_num_replicas == 0:
+        if self.get_final_target_num_replicas() == 0:
             return constants.AUTOSCALER_NO_REPLICA_DECISION_INTERVAL_SECONDS
         else:
             return constants.AUTOSCALER_DEFAULT_DECISION_INTERVAL_SECONDS
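The decision interval now keys off the final target, so a service with min_replicas = 0 but a nonzero overprovision cushion is not polled at the fast no-replica cadence. A sketch of the behavior; the real constants live in sky/serve/constants and these values are assumptions:

```python
# Illustrative values; the real ones are defined in sky/serve/constants.
NO_REPLICA_INTERVAL_SECONDS = 5
DEFAULT_INTERVAL_SECONDS = 20

def decision_interval(final_target_num_replicas: int) -> int:
    # Poll faster only when nothing should be running at all, so that
    # scale-from-zero reacts quickly once traffic arrives.
    if final_target_num_replicas == 0:
        return NO_REPLICA_INTERVAL_SECONDS
    return DEFAULT_INTERVAL_SECONDS

assert decision_interval(0) == 5
assert decision_interval(2) == 20
```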
@@ -236,13 +243,14 @@ def _select_outdated_replicas_to_scale_down(
         # old and latest versions are allowed in rolling update, this will
         # not affect the time it takes for the service to updated to the
         # latest version.
-        if num_latest_ready_replicas >= self.target_num_replicas:
+        if (num_latest_ready_replicas >=
+                self.get_final_target_num_replicas()):
             # Once the number of ready new replicas is greater than or equal
             # to the target, we can scale down all old replicas.
             return [info.replica_id for info in old_nonterminal_replicas]
         # If rolling update is in progress, we scale down old replicas
         # based on the number of ready new replicas.
-        num_old_replicas_to_keep = (self.target_num_replicas -
+        num_old_replicas_to_keep = (self.get_final_target_num_replicas() -
                                     num_latest_ready_replicas)
         # Remove old replicas (especially old launching replicas) and only
         # keep the required number of replicas, as we want to let the new
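During a rolling update the overprovisioned cushion is now counted too: old replicas are retained until enough new-version replicas are ready to cover the final target. A toy walk-through with assumed numbers:

```python
# Assumed: target of 4 replicas plus 1 overprovisioned, so final target 5.
final_target = 5
num_latest_ready_replicas = 2  # new-version replicas that are READY

if num_latest_ready_replicas >= final_target:
    num_old_replicas_to_keep = 0  # safe to drop every old replica
else:
    num_old_replicas_to_keep = final_target - num_latest_ready_replicas

assert num_old_replicas_to_keep == 3  # keep 3 old replicas for now
```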
Expand Down Expand Up @@ -422,6 +430,7 @@ def _set_target_num_replicas_with_hysteresis(self) -> None:
f'Old target number of replicas: {old_target_num_replicas}. '
f'Current target number of replicas: {target_num_replicas}. '
f'Final target number of replicas: {self.target_num_replicas}. '
f'Num overprovision: {self.num_overprovision}. '
f'Upscale counter: {self.upscale_counter}/'
f'{self.scale_up_threshold}. '
f'Downscale counter: {self.downscale_counter}/'
@@ -505,20 +514,21 @@ def _generate_scaling_decisions(
 
         # Case 1. when latest_nonterminal_replicas is less
         # than num_to_provision, we always scale up new replicas.
-        if len(latest_nonterminal_replicas) < self.target_num_replicas:
-            num_replicas_to_scale_up = (self.target_num_replicas -
+        target_num_replicas = self.get_final_target_num_replicas()
+        if len(latest_nonterminal_replicas) < target_num_replicas:
+            num_replicas_to_scale_up = (target_num_replicas -
                                         len(latest_nonterminal_replicas))
             logger.info('Number of replicas to scale up: '
                         f'{num_replicas_to_scale_up}')
             scaling_decisions.extend(
                 _generate_scale_up_decisions(num_replicas_to_scale_up, None))
 
         # Case 2: when latest_nonterminal_replicas is more
-        # than self.target_num_replicas, we scale down new replicas.
+        # than target_num_replicas, we scale down new replicas.
         replicas_to_scale_down = []
-        if len(latest_nonterminal_replicas) > self.target_num_replicas:
+        if len(latest_nonterminal_replicas) > target_num_replicas:
             num_replicas_to_scale_down = (len(latest_nonterminal_replicas) -
-                                          self.target_num_replicas)
+                                          target_num_replicas)
             replicas_to_scale_down = (
                 _select_nonterminal_replicas_to_scale_down(
                     num_replicas_to_scale_down, latest_nonterminal_replicas))
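Cases 1 and 2 are symmetric around the final target; with overprovision the same comparison simply uses the padded number. A worked example under assumed counts:

```python
# Assumed: final target 5 (4 demand-driven + 1 overprovision).
target_num_replicas = 5
num_nonterminal = 3  # latest-version replicas launching or running

if num_nonterminal < target_num_replicas:
    print(f'scale up by {target_num_replicas - num_nonterminal}')   # -> 2
elif num_nonterminal > target_num_replicas:
    print(f'scale down by {num_nonterminal - target_num_replicas}')
```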
@@ -633,7 +643,7 @@ def _generate_scaling_decisions(
         all_replica_ids_to_scale_down: List[int] = []
 
         # Decide how many spot instances to launch.
-        num_spot_to_provision = (self.target_num_replicas -
+        num_spot_to_provision = (self.get_final_target_num_replicas() -
                                  self.base_ondemand_fallback_replicas)
         if num_nonterminal_spot < num_spot_to_provision:
             # Not enough spot instances, scale up.
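In the fallback autoscaler the cushion lands on spot: the on-demand fallback count is carved out of the final target and the remainder is targeted as spot. A sketch with assumed numbers:

```python
# Assumed: final target 6, with 2 base on-demand fallback replicas.
final_target = 6
base_ondemand_fallback_replicas = 2

num_spot_to_provision = final_target - base_ondemand_fallback_replicas
assert num_spot_to_provision == 4  # spot count is steered toward this
```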
67 changes: 44 additions & 23 deletions sky/serve/replica_managers.py
@@ -60,6 +60,7 @@ def launch_cluster(replica_id: int,
                    task_yaml_path: str,
                    cluster_name: str,
                    resources_override: Optional[Dict[str, Any]] = None,
+                   retry_until_up: bool = True,
                    max_retry: int = 3) -> None:
     """Launch a sky serve replica cluster.
@@ -100,7 +101,7 @@ def launch_cluster(replica_id: int,
             cluster_name,
             detach_setup=True,
             detach_run=True,
-            retry_until_up=True,
+            retry_until_up=retry_until_up,
             _is_launched_by_sky_serve_controller=True)
         logger.info(f'Replica cluster {cluster_name} launched.')
     except (exceptions.InvalidClusterNameError,
@@ -248,6 +249,10 @@ class ReplicaStatusProperty:
     preempted: bool = False
     # Whether the replica is purged.
     purged: bool = False
+    # Whether the replica failed to launch due to spot availability.
+    # This is only possible when the spot placer is enabled, since
+    # retry_until_up is then disabled and launches can fail fast.
+    failed_spot_availability: bool = False
 
     def remove_terminated_replica(self) -> bool:
         """Whether to remove the replica record from the replica table.
Expand Down Expand Up @@ -387,10 +392,11 @@ def to_replica_status(self) -> serve_state.ReplicaStatus:
class ReplicaInfo:
"""Replica info for each replica."""

_VERSION = 0
_VERSION = 1

def __init__(self, replica_id: int, cluster_name: str, replica_port: str,
is_spot: bool, version: int) -> None:
is_spot: bool, location: Optional[spot_placer.Location],
version: int) -> None:
self._version = self._VERSION
self.replica_id: int = replica_id
self.cluster_name: str = cluster_name
@@ -400,6 +406,11 @@ def __init__(self, replica_id: int, cluster_name: str, replica_port: str,
         self.consecutive_failure_times: List[float] = []
         self.status_property: ReplicaStatusProperty = ReplicaStatusProperty()
         self.is_spot: bool = is_spot
+        self.location: Optional[Dict[str, Optional[str]]] = (
+            location.to_pickleable() if location is not None else None)
+
+    def get_spot_location(self) -> Optional[spot_placer.Location]:
+        return spot_placer.Location.from_pickleable(self.location)
 
     def handle(
         self,
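Because the location is now stored on the replica record (and thus pickled into the services DB), it is kept as a plain dict and rehydrated on read. The real `spot_placer.Location` is not part of this diff; a hedged sketch of the assumed round-trip, with made-up fields:

```python
from dataclasses import dataclass
from typing import Dict, Optional

@dataclass
class Location:
    """Stand-in for spot_placer.Location; the fields are assumptions."""
    cloud: str
    region: Optional[str]
    zone: Optional[str]

    def to_pickleable(self) -> Dict[str, Optional[str]]:
        return {'cloud': self.cloud, 'region': self.region, 'zone': self.zone}

    @classmethod
    def from_pickleable(
        cls, data: Optional[Dict[str, Optional[str]]]
    ) -> Optional['Location']:
        if data is None:
            return None
        return cls(**data)

loc = Location('gcp', 'us-central1', 'us-central1-a')
assert Location.from_pickleable(loc.to_pickleable()) == loc
assert Location.from_pickleable(None) is None  # non-spot replicas
```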
Expand Down Expand Up @@ -485,6 +496,7 @@ def __repr__(self) -> str:
f'version={self.version}, '
f'replica_port={self.replica_port}, '
f'is_spot={self.is_spot}, '
f'location={self.location}, '
f'status={self.status}, '
f'launched_at={info_dict["launched_at"]}{handle_str})')
return info
@@ -559,6 +571,9 @@ def __setstate__(self, state):
             # Treated similar to on-demand instances.
             self.is_spot = False
 
+        if version < 1:
+            self.location = None
+
         self.__dict__.update(state)


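The `_VERSION` bump from 0 to 1, paired with the `__setstate__` backfill, is the usual recipe for unpickling records written by an older controller: missing fields get a default before the state dict is applied. A minimal self-contained illustration of the pattern (not the real class):

```python
class Record:
    _VERSION = 1  # bumped when `location` was added

    def __init__(self) -> None:
        self._version = self._VERSION
        self.location = None

    def __setstate__(self, state: dict) -> None:
        version = state.get('_version', -1)
        if version < 1:
            # Record was pickled before `location` existed; backfill it.
            state['location'] = None
        self.__dict__.update(state)

rec = Record.__new__(Record)       # bypass __init__, as unpickling does
rec.__setstate__({'_version': 0})  # old on-disk state: no `location` key
assert rec.location is None
```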
@@ -658,36 +673,41 @@ def _launch_replica(
             self._service_name, replica_id)
         log_file_name = serve_utils.generate_replica_launch_log_file_name(
             self._service_name, replica_id)
+        retry_until_up = True
+        location = None
         if self._spot_placer is not None:
+            # For spot placer, we don't retry until up, so any launch that
+            # fails due to availability is handled by the placer.
+            retry_until_up = False
             # TODO(tian): Currently, the resources_override can only be
             # `use_spot=True/False`, which will not cause any conflict with
             # spot placer's cloud, region & zone. When we add more resources
             # to the resources_override, we need to make sure they won't
             # conflict with the spot placer's selection.
             if resources_override is None:
                 resources_override = {}
-            current_resources = []
+            current_spot_locations: List[spot_placer.Location] = []
             for info in serve_state.get_replica_infos(self._service_name):
-                handle = global_user_state.get_handle_from_cluster_name(
-                    info.cluster_name)
-                assert handle is not None and isinstance(
-                    handle, backends.CloudVmRayResourceHandle)
-                current_resources.append(handle.launched_resources)
-            resources_override.update(
-                self._spot_placer.select_next_location(current_resources))
+                if info.is_spot:
+                    spot_location = info.get_spot_location()
+                    if spot_location is not None:
+                        current_spot_locations.append(spot_location)
+            location = self._spot_placer.select_next_location(
+                current_spot_locations)
+            resources_override.update(location.to_dict())
         p = multiprocessing.Process(
             target=ux_utils.RedirectOutputForProcess(
                 launch_cluster,
                 log_file_name,
             ).run,
             args=(replica_id, self._task_yaml_path, cluster_name,
-                  resources_override),
+                  resources_override, retry_until_up),
         )
         replica_port = _get_resources_ports(self._task_yaml_path)
         use_spot = _should_use_spot(self._task_yaml_path, resources_override)
 
         info = ReplicaInfo(replica_id, cluster_name, replica_port, use_spot,
-                           self.latest_version)
+                           location, self.latest_version)
         serve_state.add_or_update_replica(self._service_name, replica_id, info)
         # Don't start right now; we will start it later in _refresh_process_pool
         # to avoid too many sky.launch running at the same time.
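The launch path now feeds the placer only the locations of live spot replicas, read back from their DB records rather than re-derived from cluster handles, and folds the chosen location into `resources_override`. A runnable sketch of that flow with a hypothetical placer; the real `select_next_location` policy is not shown in this commit:

```python
from typing import Dict, List

class FakeSpotPlacer:
    """Hypothetical placer: prefer a location no live replica occupies."""

    def __init__(self, candidates: List[Dict[str, str]]) -> None:
        self._candidates = candidates

    def select_next_location(
            self, current: List[Dict[str, str]]) -> Dict[str, str]:
        for cand in self._candidates:
            if cand not in current:
                return cand
        return self._candidates[0]  # everything occupied: reuse the first

# Locations of existing spot replicas, as read from their DB records.
current_spot_locations = [{'region': 'us-east-1'}]
placer = FakeSpotPlacer([{'region': 'us-east-1'}, {'region': 'us-west-2'}])

resources_override = {'use_spot': True}
resources_override.update(placer.select_next_location(current_spot_locations))
assert resources_override['region'] == 'us-west-2'
```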
@@ -836,7 +856,9 @@ def _handle_preemption(self, info: ReplicaInfo) -> bool:
             f'Replica {info.replica_id} is preempted{cluster_status_str}.')
         info.status_property.preempted = True
         if self._spot_placer is not None:
-            self._spot_placer.set_preemptive(handle.launched_resources)
+            spot_location = info.get_spot_location()
+            assert spot_location is not None
+            self._spot_placer.set_preemptive(spot_location)
         serve_state.add_or_update_replica(self._service_name, info.replica_id,
                                           info)
         self._terminate_replica(info.replica_id,
@@ -891,11 +913,7 @@ def _refresh_process_pool(self) -> None:
                 else:
                     info.status_property.sky_launch_status = (
                         ProcessStatus.SUCCEEDED)
-                if self._spot_placer is not None:
-                    handle = global_user_state.get_handle_from_cluster_name(
-                        info.cluster_name)
-                    assert handle is not None and isinstance(
-                        handle, backends.CloudVmRayResourceHandle)
+                if self._spot_placer is not None and info.is_spot:
                     # TODO(tian): Currently, we set the location to
                     # preemptive if the launch process failed. This is
                     # because if the error is not related to the
                     # availability of the location, then all
                     # locations would fail. We should implement a log parser
                     # to detect if the error is actually related to the
                     # availability of the location later.
+                    location = info.get_spot_location()
+                    assert location is not None
                     if p.exitcode != 0:
-                        self._spot_placer.set_preemptive(
-                            handle.launched_resources)
+                        self._spot_placer.set_preemptive(location)
+                        info.status_property.failed_spot_availability = True
                     else:
-                        self._spot_placer.set_active(
-                            handle.launched_resources)
+                        self._spot_placer.set_active(location)
                 serve_state.add_or_update_replica(self._service_name,
                                                   replica_id, info)
                 if error_in_sky_launch:
@@ -961,6 +980,8 @@ def _refresh_process_pool(self) -> None:
                     removal_reason = 'for version outdated'
                 elif info.status_property.purged:
                     removal_reason = 'for purge'
+                elif info.status_property.failed_spot_availability:
+                    removal_reason = 'for spot availability failure'
                 else:
                     logger.info(f'Termination of replica {replica_id} '
                                 'finished. Replica info is kept since some '
19 changes: 17 additions & 2 deletions sky/serve/service_spec.py
@@ -26,6 +26,7 @@ def __init__(
         readiness_timeout_seconds: int,
         min_replicas: int,
         max_replicas: Optional[int] = None,
+        num_overprovision: Optional[int] = None,
         ports: Optional[str] = None,
         target_qps_per_replica: Optional[float] = None,
         post_data: Optional[Dict[str, Any]] = None,
@@ -75,6 +76,7 @@ def __init__(
         self._readiness_timeout_seconds: int = readiness_timeout_seconds
         self._min_replicas: int = min_replicas
         self._max_replicas: Optional[int] = max_replicas
+        self._num_overprovision: Optional[int] = num_overprovision
         self._ports: Optional[str] = ports
         self._target_qps_per_replica: Optional[float] = target_qps_per_replica
         self._post_data: Optional[Dict[str, Any]] = post_data
@@ -159,13 +161,16 @@ def from_yaml_config(config: Dict[str, Any]) -> 'SkyServiceSpec':
             min_replicas = constants.DEFAULT_MIN_REPLICAS
             service_config['min_replicas'] = min_replicas
             service_config['max_replicas'] = None
+            service_config['num_overprovision'] = None
             service_config['target_qps_per_replica'] = None
             service_config['upscale_delay_seconds'] = None
             service_config['downscale_delay_seconds'] = None
         else:
             service_config['min_replicas'] = policy_section['min_replicas']
             service_config['max_replicas'] = policy_section.get(
                 'max_replicas', None)
+            service_config['num_overprovision'] = policy_section.get(
+                'num_overprovision', None)
             service_config['target_qps_per_replica'] = policy_section.get(
                 'target_qps_per_replica', None)
             service_config['upscale_delay_seconds'] = policy_section.get(
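In the service YAML this surfaces as a new `replica_policy` key. A plausible policy section, expressed here as the parsed dict that `from_yaml_config` consumes (the field name comes from this diff; the values are illustrative):

```python
policy_section = {
    'min_replicas': 2,
    'max_replicas': 6,
    'num_overprovision': 1,  # the new knob from this commit
    'target_qps_per_replica': 10,
}

service_config = {
    'num_overprovision': policy_section.get('num_overprovision', None),
}
assert service_config['num_overprovision'] == 1

# Omitting the key keeps the old behavior (no overprovisioning):
assert {}.get('num_overprovision', None) is None
```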
@@ -238,6 +243,8 @@ def add_if_not_none(section: str,
         add_if_not_none('readiness_probe', 'headers', self._readiness_headers)
         add_if_not_none('replica_policy', 'min_replicas', self.min_replicas)
         add_if_not_none('replica_policy', 'max_replicas', self.max_replicas)
+        add_if_not_none('replica_policy', 'num_overprovision',
+                        self.num_overprovision)
         add_if_not_none('replica_policy', 'target_qps_per_replica',
                         self.target_qps_per_replica)
         add_if_not_none('replica_policy', 'dynamic_ondemand_fallback',
@@ -302,9 +309,13 @@ def autoscaling_policy_str(self):
         assert self.target_qps_per_replica is not None
         # TODO(tian): Refactor to contain more information
         max_plural = '' if self.max_replicas == 1 else 's'
+        overprovision_str = ''
+        if self.num_overprovision is not None:
+            overprovision_str = (
+                f' with {self.num_overprovision} overprovisioned replicas')
         return (f'Autoscaling from {self.min_replicas} to {self.max_replicas} '
-                f'replica{max_plural} (target QPS per replica: '
-                f'{self.target_qps_per_replica})')
+                f'replica{max_plural}{overprovision_str} (target QPS per '
+                f'replica: {self.target_qps_per_replica})')
 
     def set_ports(self, ports: str) -> None:
         self._ports = ports
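For reference, the updated status line renders like this (replaying the f-strings above with illustrative values):

```python
min_replicas, max_replicas = 2, 6
num_overprovision, target_qps_per_replica = 1, 10.0

max_plural = '' if max_replicas == 1 else 's'
overprovision_str = ('' if num_overprovision is None else
                     f' with {num_overprovision} overprovisioned replicas')
print(f'Autoscaling from {min_replicas} to {max_replicas} '
      f'replica{max_plural}{overprovision_str} (target QPS per '
      f'replica: {target_qps_per_replica})')
# Autoscaling from 2 to 6 replicas with 1 overprovisioned replicas
# (target QPS per replica: 10.0)
```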
@@ -347,6 +358,10 @@ def max_replicas(self) -> Optional[int]:
         # If None, treated as having the same value of min_replicas.
         return self._max_replicas
 
+    @property
+    def num_overprovision(self) -> Optional[int]:
+        return self._num_overprovision
+
     @property
     def ports(self) -> Optional[str]:
         return self._ports
(The diffs for the remaining three changed files did not load and are not shown.)