Skip to content

Commit

Permalink
[k8s] Enable multiple kubernetes contexts for failover (#3968)
Browse files Browse the repository at this point in the history
* wip

* Fix

* format

* format

* Fix context and namespace used

* update

* fix

* Fix feasibility check

* fix image for k8s

* patch k8s tests

* format

* format

* format

* Fix tests

* avoid -s

* Fix acc detection

* format

* Update docs/source/reference/config.rst

Co-authored-by: Romil Bhardwaj <[email protected]>

* refactor a little

* Add docs for k8s context update

* Use all pods in a context

* Add policy

* Fix unsupported features and other kube calls

* Add policies

* Fix backward compatibility

* Add smoke test

* set

* fix typing

* Add check for local k8s cluster in smoke test

* Add skypilot config

* Fix smoke

* Make logging log once

* format

* format

---------

Co-authored-by: Romil Bhardwaj <[email protected]>
  • Loading branch information
Michaelvll and romilbhardwaj authored Sep 26, 2024
1 parent 4740ea8 commit 4e46cf4
Show file tree
Hide file tree
Showing 21 changed files with 599 additions and 165 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/pytest.yml
Original file line number Diff line number Diff line change
Expand Up @@ -57,4 +57,4 @@ jobs:
pip install pytest pytest-xdist pytest-env>=0.6 memory-profiler==0.61.0
- name: Run tests with pytest
run: SKYPILOT_DISABLE_USAGE_COLLECTION=1 SKYPILOT_SKIP_CLOUD_IDENTITY_CHECK=1 pytest -n 1 --dist no ${{ matrix.test-path }}
run: SKYPILOT_DISABLE_USAGE_COLLECTION=1 SKYPILOT_SKIP_CLOUD_IDENTITY_CHECK=1 pytest -n 0 --dist no ${{ matrix.test-path }}
16 changes: 16 additions & 0 deletions docs/source/cloud-setup/policy.rst
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ Example usage:
- :ref:`disable-public-ip-policy`
- :ref:`use-spot-for-gpu-policy`
- :ref:`enforce-autostop-policy`
- :ref:`dynamic-kubernetes-contexts-update-policy`


To implement and use an admin policy:
Expand Down Expand Up @@ -193,3 +194,18 @@ Enforce Autostop for all Tasks
.. literalinclude:: ../../../examples/admin_policy/enforce_autostop.yaml
:language: yaml
:caption: `Config YAML for using EnforceAutostopPolicy <https://github.com/skypilot-org/skypilot/blob/master/examples/admin_policy/enforce_autostop.yaml>`_


.. _dynamic-kubernetes-contexts-update-policy:

Dynamically Update Kubernetes Contexts to Use
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. literalinclude:: ../../../examples/admin_policy/example_policy/example_policy/skypilot_policy.py
:language: python
:pyobject: DynamicKubernetesContextsUpdatePolicy
:caption: `DynamicKubernetesContextsUpdatePolicy <https://github.com/skypilot-org/skypilot/blob/master/examples/admin_policy/example_policy/example_policy/skypilot_policy.py>`_

.. literalinclude:: ../../../examples/admin_policy/dynamic_kubernetes_contexts_update.yaml
:language: yaml
:caption: `Config YAML for using DynamicKubernetesContextsUpdatePolicy <https://github.com/skypilot-org/skypilot/blob/master/examples/admin_policy/dynamic_kubernetes_contexts_update.yaml>`_
13 changes: 13 additions & 0 deletions docs/source/reference/config.rst
Original file line number Diff line number Diff line change
Expand Up @@ -495,6 +495,19 @@ Available fields and semantics:
# Default: 'SERVICE_ACCOUNT'.
remote_identity: my-k8s-service-account
# Allowed context names to use for Kubernetes clusters (optional).
#
# SkyPilot will try to provision on these Kubernetes contexts, failing over
# from one to the next in the order they are specified here. E.g., SkyPilot
# will try using context1 first. If it is out of resources or unreachable,
# it will fail over and try context2.
#
# If not specified, only the current active context is used for launching
# new clusters.
allowed_contexts:
- context1
- context2
# Attach custom metadata to Kubernetes objects created by SkyPilot
#
# Uses the same schema as Kubernetes metadata object: https://kubernetes.io/docs/reference/generated/kubernetes-api/v1.26/#objectmeta-v1-meta
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
admin_policy: example_policy.DynamicKubernetesContextsUpdatePolicy
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Example admin policy module and prebuilt policies."""
from example_policy.skypilot_policy import AddLabelsPolicy
from example_policy.skypilot_policy import DisablePublicIpPolicy
from example_policy.skypilot_policy import DynamicKubernetesContextsUpdatePolicy
from example_policy.skypilot_policy import EnforceAutostopPolicy
from example_policy.skypilot_policy import RejectAllPolicy
from example_policy.skypilot_policy import UseSpotForGpuPolicy
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
"""Example prebuilt admin policies."""
import subprocess

import sky


Expand Down Expand Up @@ -119,3 +121,47 @@ def validate_and_mutate(
return sky.MutatedUserRequest(
task=user_request.task,
skypilot_config=user_request.skypilot_config)


def update_current_kubernetes_clusters_from_registry():
    """Mock implementation of updating kubernetes clusters from registry.

    For every cluster name in the (mocked) registry, runs
    `gcloud container clusters get-credentials` so that the cluster's
    credentials are added to the local kubeconfig.

    The update is best-effort: the command is run with `check=False`, so a
    failing `gcloud` invocation does not raise.
    """
    # Imported locally so that this function carries its own dependency.
    # pylint: disable=import-outside-toplevel
    import shlex

    # All cluster names can be fetched from an organization's internal API.
    new_cluster_names = ['my-cluster']
    for cluster_name in new_cluster_names:
        # Update the local kubeconfig with the new cluster credentials.
        # The command goes through a shell, so the cluster name is quoted:
        # a name coming from an external registry must not be able to inject
        # extra shell syntax. Plain names (e.g. 'my-cluster') are unchanged
        # by quoting.
        subprocess.run(
            'gcloud container clusters get-credentials '
            f'{shlex.quote(cluster_name)} '
            '--region us-central1-c',
            shell=True,
            check=False)


def get_allowed_contexts():
    """Mock implementation of getting allowed kubernetes contexts.

    Returns:
        The leading (at most two) entries of whatever
        `get_all_kube_config_context_names()` reports.
    """
    # pylint: disable=import-outside-toplevel
    from sky.provision.kubernetes import utils as kubernetes_utils
    all_context_names = kubernetes_utils.get_all_kube_config_context_names()
    # Only the first two contexts are treated as allowed in this mock.
    return all_context_names[:2]


class DynamicKubernetesContextsUpdatePolicy(sky.AdminPolicy):
    """Example policy: update the kubernetes context to use."""

    @classmethod
    def validate_and_mutate(
            cls, user_request: sky.UserRequest) -> sky.MutatedUserRequest:
        """Refreshes the kubeconfig and sets the allowed kubernetes contexts.

        Args:
            user_request: The incoming request; its `skypilot_config` is
                updated in place and its `task` is passed through unchanged.

        Returns:
            A `sky.MutatedUserRequest` carrying the original task and the
            config with `kubernetes.allowed_contexts` set.
        """
        # Step 1: bring the local kubeconfig up to date. A real
        # implementation could query an organization's internal Kubernetes
        # cluster registry (an internal API or a secret vault) and append
        # the new credentials to the local kubeconfig.
        update_current_kubernetes_clusters_from_registry()

        # Step 2: look up which contexts the user may use. This could
        # likewise come from an organization's internal API.
        contexts_for_user = get_allowed_contexts()

        # Step 3: record the allowed contexts in the skypilot config.
        user_request.skypilot_config.set_nested(
            ('kubernetes', 'allowed_contexts'), contexts_for_user)
        return sky.MutatedUserRequest(
            task=user_request.task,
            skypilot_config=user_request.skypilot_config)
18 changes: 10 additions & 8 deletions sky/adaptors/kubernetes.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,15 +75,17 @@ def _load_config(context: Optional[str] = None):
suffix += f' Error: {str(e)}'
# Check if exception was due to no current-context
if 'Expected key current-context' in str(e):
err_str = ('Failed to load Kubernetes configuration. '
'Kubeconfig does not contain any valid context(s).'
f'{suffix}\n'
' If you were running a local Kubernetes '
'cluster, run `sky local up` to start the cluster.')
err_str = (
f'Failed to load Kubernetes configuration for {context!r}. '
'Kubeconfig does not contain any valid context(s).'
f'{suffix}\n'
' If you were running a local Kubernetes '
'cluster, run `sky local up` to start the cluster.')
else:
err_str = ('Failed to load Kubernetes configuration. '
'Please check if your kubeconfig file exists at '
f'~/.kube/config and is valid.{suffix}')
err_str = (
f'Failed to load Kubernetes configuration for {context!r}. '
'Please check if your kubeconfig file exists at '
f'~/.kube/config and is valid.{suffix}')
err_str += '\nTo disable Kubernetes for SkyPilot: run `sky check`.'
with ux_utils.print_exception_no_traceback():
raise ValueError(err_str) from None
Expand Down
16 changes: 10 additions & 6 deletions sky/authentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,11 +378,11 @@ def setup_kubernetes_authentication(config: Dict[str, Any]) -> Dict[str, Any]:
public_key_path = os.path.expanduser(PUBLIC_SSH_KEY_PATH)
secret_name = clouds.Kubernetes.SKY_SSH_KEY_SECRET_NAME
secret_field_name = clouds.Kubernetes().ssh_key_secret_field_name
namespace = config['provider'].get(
'namespace',
kubernetes_utils.get_current_kube_config_context_namespace())
context = config['provider'].get(
'context', kubernetes_utils.get_current_kube_config_context_name())
namespace = config['provider'].get(
'namespace',
kubernetes_utils.get_kube_config_context_namespace(context))
k8s = kubernetes.kubernetes
with open(public_key_path, 'r', encoding='utf-8') as f:
public_key = f.read()
Expand Down Expand Up @@ -425,8 +425,8 @@ def setup_kubernetes_authentication(config: Dict[str, Any]) -> Dict[str, Any]:
ssh_jump_name,
nodeport_mode,
private_key_path=private_key_path,
namespace=namespace,
context=context)
context=context,
namespace=namespace)
elif network_mode == port_forward_mode:
# Using `kubectl port-forward` creates a direct tunnel to the pod and
# does not require a ssh jump pod.
Expand All @@ -441,7 +441,11 @@ def setup_kubernetes_authentication(config: Dict[str, Any]) -> Dict[str, Any]:
# on GKE.
ssh_target = config['cluster_name'] + '-head'
ssh_proxy_cmd = kubernetes_utils.get_ssh_proxy_command(
ssh_target, port_forward_mode, private_key_path=private_key_path)
ssh_target,
port_forward_mode,
private_key_path=private_key_path,
context=context,
namespace=namespace)
else:
# This should never happen because we check for this in from_str above.
raise ValueError(f'Unsupported networking mode: {network_mode_str}')
Expand Down
15 changes: 14 additions & 1 deletion sky/backends/cloud_vm_ray_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -2082,7 +2082,7 @@ class CloudVmRayResourceHandle(backends.backend.ResourceHandle):
"""
# Bump if any fields get added/removed/changed, and add backward
# compatibility logic in __setstate__.
_VERSION = 8
_VERSION = 9

def __init__(
self,
Expand Down Expand Up @@ -2516,6 +2516,19 @@ def __setstate__(self, state):
if version < 8:
self.cached_cluster_info = None

if version < 9:
# For backward compatibility, we should update the region of a
# SkyPilot cluster on Kubernetes to the actual context it is using.
# pylint: disable=import-outside-toplevel
launched_resources = state['launched_resources']
if isinstance(launched_resources.cloud, clouds.Kubernetes):
yaml_config = common_utils.read_yaml(
os.path.expanduser(state['_cluster_yaml']))
context = kubernetes_utils.get_context_from_config(
yaml_config['provider'])
state['launched_resources'] = launched_resources.copy(
region=context)

self.__dict__.update(state)

# Because the update_cluster_ips and update_ssh_ports
Expand Down
20 changes: 10 additions & 10 deletions sky/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -3026,14 +3026,11 @@ def show_gpus(
kubernetes_is_enabled = sky_clouds.cloud_in_iterable(
sky_clouds.Kubernetes(), global_user_state.get_cached_enabled_clouds())

if cloud_is_kubernetes and region is not None:
raise click.UsageError(
'The --region flag cannot be set with --cloud kubernetes.')

def _list_to_str(lst):
return ', '.join([str(e) for e in lst])

def _get_kubernetes_realtime_gpu_table(
context: Optional[str] = None,
name_filter: Optional[str] = None,
quantity_filter: Optional[int] = None):
if quantity_filter:
Expand All @@ -3048,7 +3045,7 @@ def _get_kubernetes_realtime_gpu_table(
gpus_only=True,
clouds='kubernetes',
name_filter=name_filter,
region_filter=region,
region_filter=context,
quantity_filter=quantity_filter,
case_sensitive=False)
assert (set(counts.keys()) == set(capacity.keys()) == set(
Expand Down Expand Up @@ -3078,11 +3075,11 @@ def _get_kubernetes_realtime_gpu_table(
])
return realtime_gpu_table

def _get_kubernetes_node_info_table():
def _get_kubernetes_node_info_table(context: Optional[str]):
node_table = log_utils.create_table(
['NODE_NAME', 'GPU_NAME', 'TOTAL_GPUS', 'FREE_GPUS'])

node_info_dict = kubernetes_utils.get_kubernetes_node_info()
node_info_dict = kubernetes_utils.get_kubernetes_node_info(context)
for node_name, node_info in node_info_dict.items():
node_table.add_row([
node_name, node_info.gpu_type,
Expand Down Expand Up @@ -3116,11 +3113,13 @@ def _output():
print_section_titles = False
# If cloud is kubernetes, we want to show real-time capacity
if kubernetes_is_enabled and (cloud is None or cloud_is_kubernetes):
context = region
try:
# If --cloud kubernetes is not specified, we want to catch
# the case where no GPUs are available on the cluster and
# print the warning at the end.
k8s_realtime_table = _get_kubernetes_realtime_gpu_table()
k8s_realtime_table = _get_kubernetes_realtime_gpu_table(
context)
except ValueError as e:
if not cloud_is_kubernetes:
# Make it a note if cloud is not kubernetes
Expand All @@ -3129,9 +3128,10 @@ def _output():
else:
print_section_titles = True
yield (f'{colorama.Fore.CYAN}{colorama.Style.BRIGHT}'
f'Kubernetes GPUs{colorama.Style.RESET_ALL}\n')
f'Kubernetes GPUs (Context: {context})'
f'{colorama.Style.RESET_ALL}\n')
yield from k8s_realtime_table.get_string()
k8s_node_table = _get_kubernetes_node_info_table()
k8s_node_table = _get_kubernetes_node_info_table(context)
yield '\n\n'
yield (f'{colorama.Fore.CYAN}{colorama.Style.BRIGHT}'
f'Kubernetes per node GPU availability'
Expand Down
Loading

0 comments on commit 4e46cf4

Please sign in to comment.