From f76db0d51d351e4729778a043a0a317bf622c729 Mon Sep 17 00:00:00 2001
From: Christopher Cooper <cooperc@assemblesys.com>
Date: Mon, 16 Dec 2024 09:59:43 -0800
Subject: [PATCH] [core] skip provider.availability_zone in the cluster config
 hash (#4463)

skip provider.availability_zone in the cluster config hash
---
 sky/backends/backend_utils.py | 36 ++++++++++++++++++++++++++++++++---
 1 file changed, 33 insertions(+), 3 deletions(-)

diff --git a/sky/backends/backend_utils.py b/sky/backends/backend_utils.py
index a3651bdba9a..0f55b8a7f17 100644
--- a/sky/backends/backend_utils.py
+++ b/sky/backends/backend_utils.py
@@ -173,6 +173,16 @@
     ('available_node_types', 'ray.head.default', 'node_config',
      'azure_arm_parameters', 'cloudInitSetupCommands'),
 ]
+# These keys are expected to change when provisioning on an existing cluster,
+# but they don't actually represent a change that requires re-provisioning the
+# cluster.  If the cluster yaml is the same except for these keys, we can safely
+# skip reprovisioning. See _deterministic_cluster_yaml_hash.
+_RAY_YAML_KEYS_TO_REMOVE_FOR_HASH = [
+    # On first launch, availability_zones will include all possible zones. Once
+    # the cluster exists, it will only include the zone that the cluster is
+    # actually in.
+    ('provider', 'availability_zone'),
+]
 
 
 def is_ip(s: str) -> bool:
@@ -1087,7 +1097,7 @@ def _deterministic_cluster_yaml_hash(yaml_path: str) -> str:
     yaml file and all the files in the file mounts, then hash the byte sequence.
 
     The format of the byte sequence is:
-    32 bytes - sha256 hash of the yaml file
+    32 bytes - sha256 hash of the yaml
     for each file mount:
       file mount remote destination (UTF-8), \0
       if the file mount source is a file:
@@ -1111,14 +1121,29 @@ def _deterministic_cluster_yaml_hash(yaml_path: str) -> str:
     we construct it incrementally by using hash.update() to add new bytes.
     """
 
+    # Load the yaml contents so that we can directly remove keys.
+    yaml_config = common_utils.read_yaml(yaml_path)
+    for key_list in _RAY_YAML_KEYS_TO_REMOVE_FOR_HASH:
+        dict_to_remove_from = yaml_config
+        found_key = True
+        for key in key_list[:-1]:
+            if (not isinstance(dict_to_remove_from, dict) or
+                    key not in dict_to_remove_from):
+                found_key = False
+                break
+            dict_to_remove_from = dict_to_remove_from[key]
+        if found_key and key_list[-1] in dict_to_remove_from:
+            dict_to_remove_from.pop(key_list[-1])
+
     def _hash_file(path: str) -> bytes:
         return common_utils.hash_file(path, 'sha256').digest()
 
     config_hash = hashlib.sha256()
 
-    config_hash.update(_hash_file(yaml_path))
+    yaml_hash = hashlib.sha256(
+        common_utils.dump_yaml_str(yaml_config).encode('utf-8'))
+    config_hash.update(yaml_hash.digest())
 
-    yaml_config = common_utils.read_yaml(yaml_path)
     file_mounts = yaml_config.get('file_mounts', {})
     # Remove the file mounts added by the newline.
     if '' in file_mounts:
@@ -1126,6 +1151,11 @@ def _hash_file(path: str) -> bytes:
         file_mounts.pop('')
 
     for dst, src in sorted(file_mounts.items()):
+        if src == yaml_path:
+            # Skip the yaml file itself. We have already hashed a modified
+            # version of it. The file may include fields we don't want to hash.
+            continue
+
         expanded_src = os.path.expanduser(src)
         config_hash.update(dst.encode('utf-8') + b'\0')