Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Core] Support Tailscale VPN #4025

Open
wants to merge 31 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 29 commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
3808aa9
Modify the schema for tailscale VPN.
Conless Sep 24, 2024
761033f
Add `tailscale_auth_key` into service config.
Conless Sep 24, 2024
85fcc93
Add tailscale VPN into dependency in specific cases.
Conless Sep 25, 2024
69ec763
Fix tailscale installation on replicas.
Conless Sep 25, 2024
ee4343b
Fix format issues.
Conless Sep 25, 2024
779eba7
Minor fix on format issues.
Conless Sep 28, 2024
c137e1f
Move tailscale_auth_key to env and set it on controller.
Conless Sep 29, 2024
a571c4d
Fix some bugs and finish testing current implementation.
Conless Oct 1, 2024
69d6f0a
Refactor the VPN part and add it into the main launch logic now.
Conless Oct 1, 2024
6a3a831
Add some compatibility check.
Conless Oct 2, 2024
9336ed2
Resolve conflicts with upstream/master.
Conless Oct 2, 2024
0eb67eb
Update version after merge master branch.
Conless Oct 2, 2024
d89d497
Minor format issues.
Conless Oct 2, 2024
955e050
Resolve some format issues and add logic of getting tailscale IP by API.
Conless Oct 3, 2024
33ff645
Skip open ports and add some docs for VPN config.
Conless Oct 5, 2024
ddf6556
Add VPN as a cloud implemented feature.
Conless Oct 5, 2024
3d15683
Remove warning in backend about vpn_config.
Conless Oct 5, 2024
fcc8d1d
Move the setup logic into `update_cluster_ip`.
Conless Oct 6, 2024
7420638
Minor fix on format issues.
Conless Oct 6, 2024
84b62b4
Remove useless functions in previous design.
Conless Oct 6, 2024
ed65265
Only allow one vpn on a cluster.
Conless Oct 8, 2024
ce95cd5
Make the cluster name more human-readable.
Conless Oct 9, 2024
9d0cff9
Fix issues with hostname on TPU VMs.
Conless Oct 10, 2024
95f086c
Use `{cluster_name}-{node_id}` as hostname.
Conless Oct 29, 2024
8cdd338
Move VPN setup logic to controller task.
Conless Oct 29, 2024
47b4cdd
Resolve merge conflicts.
Conless Oct 29, 2024
3178cc0
Make it optional to expose service LB.
Conless Nov 2, 2024
3f875d5
Rewrite the logic of updating VPN IPs.
Conless Nov 4, 2024
2e6fe09
Remove ports when VPN is enabled.
Conless Nov 19, 2024
d77b7ce
Merge remote-tracking branch 'upstream/master' into vpn-enhanced
Conless Nov 19, 2024
5c3917d
Remove ports in provisioner instead of task creation.
Conless Nov 30, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 63 additions & 1 deletion sky/backends/cloud_vm_ray_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@
from sky.utils import subprocess_utils
from sky.utils import timeline
from sky.utils import ux_utils
from sky.utils import vpn_utils

if typing.TYPE_CHECKING:
from sky import dag
Expand Down Expand Up @@ -2102,10 +2103,11 @@ class CloudVmRayResourceHandle(backends.backend.ResourceHandle):
- (optional) Launched resources
- (optional) Docker user name
- (optional) If TPU(s) are managed, a path to a deletion script.
- (optional) VPN configuration
"""
# Bump if any fields get added/removed/changed, and add backward
# compaitibility logic in __setstate__.
_VERSION = 9
_VERSION = 10

def __init__(
self,
Expand Down Expand Up @@ -2134,6 +2136,8 @@ def __init__(
self.launched_nodes = launched_nodes
self.launched_resources = launched_resources
self.docker_user: Optional[str] = None
# VPN configuration for the cluster.
self.vpn_config: Optional[Dict[str, Any]] = None

def __repr__(self):
return (f'ResourceHandle('
Expand Down Expand Up @@ -2212,6 +2216,10 @@ def _update_cluster_info(self):
raise exceptions.FetchClusterInfoError(
exceptions.FetchClusterInfoError.Reason.HEAD)
self.cached_cluster_info = cluster_info
if self.vpn_config is not None:
vpn_utils.rewrite_cluster_info_by_vpn(self.cached_cluster_info,
self.cluster_name,
self.vpn_config)

def update_cluster_ips(
self,
Expand Down Expand Up @@ -2250,6 +2258,11 @@ def update_cluster_ips(
"""
if cluster_info is not None:
self.cached_cluster_info = cluster_info
# Update cluster config by private IPs (if available).
if self.vpn_config is not None:
vpn_utils.rewrite_cluster_info_by_vpn(self.cached_cluster_info,
self.cluster_name,
self.vpn_config)
cluster_feasible_ips = self.cached_cluster_info.get_feasible_ips()
cluster_internal_ips = self.cached_cluster_info.get_feasible_ips(
force_internal_ips=True)
Expand All @@ -2265,6 +2278,10 @@ def is_provided_ips_valid(

use_internal_ips = self._use_internal_ips()

assert self.vpn_config is None, (
'Clouds that do not support the new provisioner should not '
'have VPN configurations.')
cblmemo marked this conversation as resolved.
Show resolved Hide resolved

# cluster_feasible_ips is the list of IPs of the nodes in the
# cluster which can be used to connect to the cluster. It is a list
# of external IPs if the cluster is assigned public IPs, otherwise
Expand Down Expand Up @@ -2443,6 +2460,23 @@ def setup_docker_user(self, cluster_config_file: str):
cluster_config_file)
self.docker_user = docker_user

def setup_vpn(self, vpn_config: vpn_utils.VPNConfig) -> None:
Conless marked this conversation as resolved.
Show resolved Hide resolved
self.vpn_config = vpn_config.to_backend_config()
runners = self.get_command_runners()

def _run_setup_commands(id_runner):
node_id, runner = id_runner
cblmemo marked this conversation as resolved.
Show resolved Hide resolved
command = vpn_config.get_setup_command(self.cluster_name, node_id)
returncode, stdout, stderr = runner.run(command,
require_outputs=True,
stream_logs=False)
subprocess_utils.handle_returncode(
returncode, command, 'Failed to setup VPN on the cluster. '
f'Stdout: {stdout}. Stderr: {stderr}')

subprocess_utils.run_in_parallel(_run_setup_commands,
enumerate(runners))

@property
def cluster_yaml(self):
return os.path.expanduser(self._cluster_yaml)
Expand Down Expand Up @@ -2518,6 +2552,9 @@ def __setstate__(self, state):
state['launched_resources'] = launched_resources.copy(
region=context)

if version < 10:
self.vpn_config = None

self.__dict__.update(state)

# Because the update_cluster_ips and update_ssh_ports
Expand Down Expand Up @@ -2617,6 +2654,7 @@ def check_resources_fit_cluster(

mismatch_str = (f'To fix: specify a new cluster name, or down the '
f'existing cluster first: sky down {cluster_name}')

valid_resource = None
requested_resource_list = []
for resource in task.resources:
Expand All @@ -2630,6 +2668,15 @@ def check_resources_fit_cluster(
else:
requested_resource_list.append(f'{task.num_nodes}x {resource}')

# VPN check
vpn_check_result = vpn_utils.check_vpn_unchanged(
task.vpn_config, handle.vpn_config)
if vpn_check_result is not None:
with ux_utils.print_exception_no_traceback():
raise exceptions.ResourcesMismatchError(
'VPN configuration mismatch: '
f'{vpn_check_result}\n{mismatch_str}.')

if valid_resource is None:
for example_resource in task.resources:
if (example_resource.region is not None and
Expand Down Expand Up @@ -2855,6 +2902,7 @@ def _provision(
provision_record=provision_record,
custom_resource=resources_vars.get('custom_resources'),
log_dir=self.log_dir)

# We use the IPs from the cluster_info to update_cluster_ips,
# when the provisioning is done, to make sure the cluster IPs
# are up-to-date.
Expand All @@ -2867,6 +2915,11 @@ def _provision(
cluster_info=cluster_info)
handle.update_ssh_ports(max_attempts=_FETCH_IP_MAX_ATTEMPTS)

# If VPN is used, we need to reconfigure cluster IPs.
if task.vpn_config is not None:
handle.setup_vpn(task.vpn_config)
handle.update_cluster_ips(cluster_info=cluster_info)

# Update launched resources.
handle.launched_resources = handle.launched_resources.copy(
region=provision_record.region, zone=provision_record.zone)
Expand Down Expand Up @@ -2950,6 +3003,9 @@ def _get_zone(runner):
return handle

def _open_ports(self, handle: CloudVmRayResourceHandle) -> None:
if handle.vpn_config is not None and handle.vpn_config['use_vpn_ip']:
# Skip opening any ports if VPN IP is used.
return
cloud = handle.launched_resources.cloud
logger.debug(
f'Opening ports {handle.launched_resources.ports} for {cloud}')
Expand Down Expand Up @@ -4098,6 +4154,12 @@ def post_teardown_cleanup(self,
f'{common_utils.format_exception(e, use_bracket=True)}')
else:
raise
if terminate and handle.vpn_config is not None:
# Delete the VPN records when terminating the cluster.
if handle.cached_cluster_info is not None:
vpn_utils.remove_nodes_from_vpn(handle.cached_cluster_info,
handle.cluster_name,
handle.vpn_config)

# The cluster file must exist because the cluster_yaml will only
# be removed after the cluster entry in the database is removed.
Expand Down
1 change: 1 addition & 0 deletions sky/clouds/cloud.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ class CloudImplementationFeatures(enum.Enum):
STORAGE_MOUNTING = 'storage_mounting'
HOST_CONTROLLERS = 'host_controllers' # Can run jobs/serve controllers
AUTO_TERMINATE = 'auto_terminate' # Pod/VM can stop or down itself
VPN = 'vpn' # Setup with VPN


class Region(collections.namedtuple('Region', ['name'])):
Expand Down
2 changes: 2 additions & 0 deletions sky/clouds/ibm.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ def _unsupported_features_for_resources(
),
clouds.CloudImplementationFeatures.OPEN_PORTS:
(f'Opening ports is currently not supported on {cls._REPR}.'),
clouds.CloudImplementationFeatures.VPN:
(f'VPN is currently not supported on {cls._REPR}.'),
}
if resources.use_spot:
features[clouds.CloudImplementationFeatures.STOP] = (
Expand Down
1 change: 1 addition & 0 deletions sky/clouds/lambda_cloud.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ class Lambda(clouds.Cloud):
clouds.CloudImplementationFeatures.CUSTOM_DISK_TIER: f'Custom disk tiers are not supported in {_REPR}.',
clouds.CloudImplementationFeatures.OPEN_PORTS: f'Opening ports is currently not supported on {_REPR}.',
clouds.CloudImplementationFeatures.HOST_CONTROLLERS: f'Host controllers are not supported in {_REPR}.',
clouds.CloudImplementationFeatures.VPN: f'VPN is currently not supported in {_REPR}.',
}

PROVISIONER_VERSION = clouds.ProvisionerVersion.SKYPILOT
Expand Down
2 changes: 2 additions & 0 deletions sky/clouds/oci.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,8 @@ def _unsupported_features_for_resources(
'`run` section in task.yaml.'),
clouds.CloudImplementationFeatures.OPEN_PORTS:
(f'Opening ports is currently not supported on {cls._REPR}.'),
clouds.CloudImplementationFeatures.VPN:
(f'VPN is currently not supported on {cls._REPR}.'),
}
if resources.use_spot:
features[clouds.CloudImplementationFeatures.STOP] = (
Expand Down
2 changes: 2 additions & 0 deletions sky/clouds/scp.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@ class SCP(clouds.Cloud):
(f'Custom disk tiers are not supported in {_REPR}.'),
clouds.CloudImplementationFeatures.OPEN_PORTS:
(f'Opening ports is currently not supported on {_REPR}.'),
clouds.CloudImplementationFeatures.VPN:
(f'VPN is currently not supported on {_REPR}.'),
}

_INDENT_PREFIX = ' '
Expand Down
2 changes: 2 additions & 0 deletions sky/clouds/vsphere.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ class Vsphere(clouds.Cloud):
(f'Custom disk tiers are not supported in {_REPR}.'),
clouds.CloudImplementationFeatures.OPEN_PORTS:
(f'Opening ports is currently not supported on {_REPR}.'),
clouds.CloudImplementationFeatures.VPN:
(f'VPN is currently not supported on {_REPR}.'),
}

_MAX_CLUSTER_NAME_LEN_LIMIT = 80 # The name can't exceeds 80 characters
Expand Down
8 changes: 8 additions & 0 deletions sky/provision/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,14 @@ def num_instances(self) -> int:
"""Get the number of instances in the cluster."""
return sum(len(instances) for instances in self.instances.values())

def get_instances(self) -> List[InstanceInfo]:
"""Get all instances."""
head_instance = self.get_head_instance()
worker_instances = self.get_worker_instances()
if head_instance is not None:
return [head_instance] + worker_instances
return worker_instances
cblmemo marked this conversation as resolved.
Show resolved Hide resolved

def get_head_instance(self) -> Optional[InstanceInfo]:
"""Get the instance metadata of the head node"""
if self.head_instance_id is None:
Expand Down
25 changes: 24 additions & 1 deletion sky/serve/core.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""SkyServe core APIs."""
import copy
import re
import tempfile
from typing import Any, Dict, List, Optional, Tuple, Union
Expand All @@ -24,6 +25,7 @@
from sky.utils import rich_utils
from sky.utils import subprocess_utils
from sky.utils import ux_utils
from sky.utils import vpn_utils

logger = sky_logging.init_logger(__name__)

Expand Down Expand Up @@ -182,7 +184,16 @@ def up(
}
controller_task.set_resources(controller_resources)

# # Set service_name so the backend will know to modify default ray
# Set VPN configuration on the controller, so the controller can
# start the VPN service on the replicas.
if task.vpn_config is not None:
controller_vpn_config = copy.copy(task.vpn_config)
if task.service is not None and task.service.expose_service:
controller_vpn_config.disable_vpn_ip()
controller_task.set_vpn_config(controller_vpn_config)
controller_task.update_envs(task.vpn_config.get_setup_env_vars())

# Set service_name so the backend will know to modify default ray
# task CPU usage to custom value instead of default 0.5 vCPU. We need
# to set it to a smaller value to support a larger number of services.
controller_task.service_name = service_name
Expand Down Expand Up @@ -330,6 +341,18 @@ def update(
f'use {ux_utils.BOLD}sky serve up{ux_utils.RESET_BOLD}',
)

vpn_check_result = vpn_utils.check_vpn_unchanged(
task.vpn_config, handle.vpn_config,
task.service.expose_service if task.service is not None else False)
if vpn_check_result is not None:
mismatch_str = (
f'To fix: specify a new cluster name, or down the '
f'existing cluster first: sky down {handle.cluster_name}')
with ux_utils.print_exception_no_traceback():
raise exceptions.ResourcesMismatchError(
'VPN configuration mismatch: '
f'{vpn_check_result}.\n{mismatch_str}')

backend = backend_utils.get_backend_from_handle(handle)
assert isinstance(backend, backends.CloudVmRayBackend)

Expand Down
9 changes: 9 additions & 0 deletions sky/serve/service_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ def __init__(
base_ondemand_fallback_replicas: Optional[int] = None,
upscale_delay_seconds: Optional[int] = None,
downscale_delay_seconds: Optional[int] = None,
expose_service: Optional[bool] = None,
) -> None:
if max_replicas is not None and max_replicas < min_replicas:
with ux_utils.print_exception_no_traceback():
Expand Down Expand Up @@ -69,6 +70,7 @@ def __init__(
int] = base_ondemand_fallback_replicas
self._upscale_delay_seconds: Optional[int] = upscale_delay_seconds
self._downscale_delay_seconds: Optional[int] = downscale_delay_seconds
self._expose_service: Optional[bool] = expose_service

self._use_ondemand_fallback: bool = (
self.dynamic_ondemand_fallback is not None and
Expand Down Expand Up @@ -149,6 +151,8 @@ def from_yaml_config(config: Dict[str, Any]) -> 'SkyServiceSpec':
'base_ondemand_fallback_replicas', None)
service_config['dynamic_ondemand_fallback'] = policy_section.get(
'dynamic_ondemand_fallback', None)
if config.get('expose_service', False):
service_config['expose_service'] = True

return SkyServiceSpec(**service_config)

Expand Down Expand Up @@ -205,6 +209,7 @@ def add_if_not_none(section, key, value, no_empty: bool = False):
self.upscale_delay_seconds)
add_if_not_none('replica_policy', 'downscale_delay_seconds',
self.downscale_delay_seconds)
add_if_not_none('expose_service', None, self._expose_service)
return config

def probe_str(self):
Expand Down Expand Up @@ -307,6 +312,10 @@ def upscale_delay_seconds(self) -> Optional[int]:
def downscale_delay_seconds(self) -> Optional[int]:
return self._downscale_delay_seconds

@property
def expose_service(self) -> Optional[bool]:
return self._expose_service

@property
def use_ondemand_fallback(self) -> bool:
return self._use_ondemand_fallback
Loading