Merge pull request #3580 from mr0re1/instance_class2
Use `dataclass` for GCP instance
mr0re1 authored Jan 24, 2025
2 parents bf8351a + aeae793 commit 47f966a
Showing 6 changed files with 84 additions and 88 deletions.
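
The core of the change: the ad-hoc `NSDict` records returned by `Lookup.instances()` are replaced with a typed, frozen `Instance` dataclass in the `util` module, and call sites move from camelCase dict keys to parsed snake_case attributes. A trimmed-down sketch of the pattern (the full definition appears in the `util` diff at the end; the instance values here are illustrative):

```python
from dataclasses import dataclass
from datetime import datetime, timezone
from typing import Optional

# Trimmed-down sketch of the new util.Instance (full version in the util diff below).
@dataclass(frozen=True)  # frozen => hashable and safe to hand out from an lru_cache'd map
class Instance:
    name: str
    zone: str
    status: str
    creation_timestamp: datetime  # parsed once at fetch time, not at every call site
    role: Optional[str] = None

inst = Instance("c-compute-0", "us-central1-a", "RUNNING",
                datetime.now(timezone.utc), role="compute")
# replaces the old datetime.now() - parse_gcp_timestamp(inst.creationTimestamp)
age = datetime.now(timezone.utc) - inst.creation_timestamp
```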
@@ -563,7 +563,7 @@ def add_nodeset_topology(
         except Exception:
             continue

-        phys_host = inst.resourceStatus.get("physicalHost", "")
+        phys_host = inst.resource_status.get("physicalHost", "")
         bldr.summary.physical_host[inst.name] = phys_host
         up_nodes.add(inst.name)

@@ -25,7 +25,7 @@
 from itertools import chain
 from pathlib import Path
 from dataclasses import dataclass
-from typing import Dict, Tuple, List, Optional, Protocol
+from typing import Dict, Tuple, List, Optional, Protocol, Any
 from functools import lru_cache

 import util
@@ -119,10 +119,13 @@ def apply(self, nodes:List[str]) -> None:
         hostlist = util.to_hostlist(nodes)
         log.error(f"{len(nodes)} nodes have unexpected {self.slurm_state} and instance state:{self.instance_state}, ({hostlist})")

-def start_instance_op(inst):
+def start_instance_op(node: str) -> Any:
+    inst = lookup().instance(node)
+    assert inst
+
     return lookup().compute.instances().start(
         project=lookup().project,
-        zone=lookup().instance(inst).zone,
+        zone=inst.zone,
         instance=inst,
     )

@@ -132,7 +135,7 @@ def start_instances(node_list):
     lkp = lookup()
     # TODO: use code from resume.py to assign proper placement
     normal, tpu_nodes = separate(lkp.node_is_tpu, node_list)
-    ops = {inst: start_instance_op(inst) for inst in normal}
+    ops = {node: start_instance_op(node) for node in normal}

     done, failed = batch_execute(ops)

@@ -280,7 +283,7 @@ def get_node_action(nodename: str) -> NodeAction:
     elif (state is None or "POWERED_DOWN" in state.flags) and inst.status == "RUNNING":
         log.info("%s is potential orphan node", nodename)
         threshold = timedelta(seconds=90)
-        age = datetime.now() - parse_gcp_timestamp(inst.creationTimestamp)
+        age = datetime.now() - inst.creation_timestamp
         log.info(f"{nodename} state: {state}, age: {age}")
         if age < threshold:
             log.info(f"{nodename} not marked as orphan, it started less than {threshold.seconds}s ago ({age.seconds}s)")
@@ -464,9 +467,9 @@ def get_slurm_reservation_maintenance(lkp: util.Lookup) -> Dict[str, datetime]:
 def get_upcoming_maintenance(lkp: util.Lookup) -> Dict[str, Tuple[str, datetime]]:
     upc_maint_map = {}

-    for node, properties in lkp.instances().items():
-        if 'upcomingMaintenance' in properties:
-            start_time = parse_gcp_timestamp(properties['upcomingMaintenance']['startTimeWindow']['earliest'])
+    for node, inst in lkp.instances().items():
+        if inst.upcoming_maintenance:
+            start_time = parse_gcp_timestamp(inst.upcoming_maintenance['startTimeWindow']['earliest'])
             upc_maint_map[node + "_maintenance"] = (node, start_time)

     return upc_maint_map
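
Note the shape change here: the old code probed the raw payload for an `upcomingMaintenance` key, while the new `Instance.upcoming_maintenance` field is already `None` whenever the API omitted that block, so a plain truthiness check suffices. A small illustrative sketch of that contract (the `startTimeWindow.earliest` layout is the one shown above; `fromisoformat` stands in for `parse_gcp_timestamp`):

```python
from datetime import datetime
from typing import Optional

def maintenance_start(upcoming: Optional[dict]) -> Optional[datetime]:
    # None means the instance reported no maintenance window.
    if not upcoming:
        return None
    return datetime.fromisoformat(upcoming["startTimeWindow"]["earliest"])

assert maintenance_start(None) is None
assert maintenance_start(
    {"startTimeWindow": {"earliest": "2018-09-03T20:56:35.450686+00:00"}}
) == datetime.fromisoformat("2018-09-03T20:56:35.450686+00:00")
```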
@@ -15,7 +15,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import List
+from typing import List, Any
 import argparse
 import logging

@@ -46,11 +46,14 @@ def truncate_iter(iterable, max_count):
         yield el


-def delete_instance_request(instance):
+def delete_instance_request(name: str) -> Any:
+    inst = lookup().instance(name)
+    assert inst
+
     request = lookup().compute.instances().delete(
         project=lookup().project,
-        zone=lookup().instance(instance).zone,
-        instance=instance,
+        zone=inst.zone,
+        instance=name,
     )
     log_api_request(request)
     return request
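
Both op builders (`start_instance_op` earlier and `delete_instance_request` here) now follow the same shape: take the node name, resolve it once through the cached `lookup().instance(name)`, assert the lookup succeeded, and build the API request from the typed fields. Callers then batch the requests by name, roughly like this (a non-standalone sketch: `delete_instance_request` and `batch_execute` are the scripts' own helpers shown in this commit, and the node names are illustrative):

```python
# Calling-pattern sketch; relies on helpers defined in these scripts.
nodes = ["c-compute-0", "c-compute-1"]
ops = {node: delete_instance_request(node) for node in nodes}
done, failed = batch_execute(ops)
```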
@@ -15,13 +15,16 @@
 from typing import Optional, Any
 import sys
 from dataclasses import dataclass, field
+from datetime import datetime

 SCRIPTS_DIR = "community/modules/scheduler/schedmd-slurm-gcp-v6-controller/modules/slurm_files/scripts"
 if SCRIPTS_DIR not in sys.path:
     sys.path.append(SCRIPTS_DIR) # TODO: make this more robust

 import util

+
+SOME_TS = datetime.fromisoformat("2018-09-03T20:56:35.450686+00:00")
 # TODO: use "real" classes once they are defined (instead of NSDict)

 @dataclass
@@ -83,17 +86,19 @@ class TstMachineConf:
 class TstTemplateInfo:
     gpu: Optional[util.AcceleratorInfo]

-@dataclass
-class TstInstance:
-    name: str
-    region: str = "gondor"
-    zone: str = "anorien"
-    placementPolicyId: Optional[str] = None
-    physicalHost: Optional[str] = None
-
-    @property
-    def resourceStatus(self):
-        return {"physicalHost": self.physicalHost}
+def tstInstance(name: str, physical_host: Optional[str] = None):
+    return util.Instance(
+        name=name,
+        zone="anorien",
+        status="RUNNING",
+        creation_timestamp=SOME_TS,
+        resource_status=util.NSDict(
+            physicalHost=physical_host
+        ),
+        scheduling=util.NSDict(),
+        upcoming_maintenance=None,
+        role="compute",
+    )

 def make_to_hostnames_mock(tbl: Optional[dict[str, list[str]]]):
     tbl = tbl or {}
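
With the lookalike `TstInstance` class gone, the tests now build genuine `util.Instance` values through this small factory, so any future field change in `util.Instance` fails at construction time instead of silently diverging from production. Usage, as the topology tests below exercise it (a sketch that assumes `common` and `util` are importable, per the test setup):

```python
inst = tstInstance("m22-blue-1", physical_host="/a/a/b")
assert inst.status == "RUNNING"                              # factory default
assert inst.resource_status.get("physicalHost") == "/a/a/b"  # NSDict is dict-like
```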
@@ -16,7 +16,7 @@
 import json
 import mock
 from pytest_unordered import unordered
-from common import TstCfg, TstNodeset, TstTPU, TstInstance
+from common import TstCfg, TstNodeset, TstTPU, tstInstance
 import sort_nodes

 import util
@@ -62,13 +62,13 @@ def tpu_se(ns: str, lkp) -> TstTPU:
     lkp = util.Lookup(cfg)
     lkp.instances = lambda: { n.name: n for n in [
         # nodeset blue
-        TstInstance("m22-blue-0"), # no physicalHost
-        TstInstance("m22-blue-0", physicalHost="/a/a/a"),
-        TstInstance("m22-blue-1", physicalHost="/a/a/b"),
-        TstInstance("m22-blue-2", physicalHost="/a/b/a"),
-        TstInstance("m22-blue-3", physicalHost="/b/a/a"),
+        tstInstance("m22-blue-0"), # no physicalHost
+        tstInstance("m22-blue-0", physical_host="/a/a/a"),
+        tstInstance("m22-blue-1", physical_host="/a/a/b"),
+        tstInstance("m22-blue-2", physical_host="/a/b/a"),
+        tstInstance("m22-blue-3", physical_host="/b/a/a"),
         # nodeset green
-        TstInstance("m22-green-3", physicalHost="/a/a/c"),
+        tstInstance("m22-green-3", physical_host="/a/a/c"),
     ]}

     uncompressed = conf.gen_topology(lkp)
@@ -173,19 +173,19 @@ def test_gen_topology_conf_update():
     # don't dump

     # set empty physicalHost - no reconfigure
-    lkp.instances = lambda: { n.name: n for n in [TstInstance("m22-green-0", physicalHost="")]}
+    lkp.instances = lambda: { n.name: n for n in [tstInstance("m22-green-0", physical_host="")]}
     upd, sum = conf.gen_topology_conf(lkp)
     assert upd == False
     # don't dump

     # set physicalHost - reconfigure
-    lkp.instances = lambda: { n.name: n for n in [TstInstance("m22-green-0", physicalHost="/a/b/c")]}
+    lkp.instances = lambda: { n.name: n for n in [tstInstance("m22-green-0", physical_host="/a/b/c")]}
     upd, sum = conf.gen_topology_conf(lkp)
     assert upd == True
     sum.dump(lkp)

     # change physicalHost - reconfigure
-    lkp.instances = lambda: { n.name: n for n in [TstInstance("m22-green-0", physicalHost="/a/b/z")]}
+    lkp.instances = lambda: { n.name: n for n in [tstInstance("m22-green-0", physical_host="/a/b/z")]}
     upd, sum = conf.gen_topology_conf(lkp)
     assert upd == True
     sum.dump(lkp)
@@ -184,7 +184,35 @@ def sockets(self) -> int:
             self.family, 1, # assume 1 socket for all other families
         )


+@dataclass(frozen=True)
+class Instance:
+    name: str
+    zone: str
+    status: str
+    creation_timestamp: datetime
+    role: Optional[str]
+
+    # TODO: use proper InstanceResourceStatus class
+    resource_status: NSDict
+    # TODO: use proper InstanceScheduling class
+    scheduling: NSDict
+    # TODO: use proper UpcomingMaintenance class
+    upcoming_maintenance: Optional[NSDict] = None
+
+    @classmethod
+    def from_json(cls, jo: dict) -> "Instance":
+        labels = jo.get("labels", {})
+
+        return cls(
+            name=jo["name"],
+            zone=trim_self_link(jo["zone"]),
+            status=jo["status"],
+            creation_timestamp=parse_gcp_timestamp(jo["creationTimestamp"]),
+            resource_status=NSDict(jo.get("resourceStatus")),
+            scheduling=NSDict(jo.get("scheduling")),
+            upcoming_maintenance=NSDict(jo["upcomingMaintenance"]) if "upcomingMaintenance" in jo else None,
+            role=labels.get("slurm_instance_role"),
+        )

 @lru_cache(maxsize=1)
 def default_credentials():
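
A quick illustrative check of `Instance.from_json` on a trimmed `aggregatedList` item. The values are made up; the key set matches what `from_json` reads, and the `zone` expectation assumes `trim_self_link` keeps the last path segment of a self link:

```python
sample = {
    "name": "c-compute-0",
    "zone": "https://www.googleapis.com/compute/v1/projects/p/zones/us-central1-a",
    "status": "RUNNING",
    "creationTimestamp": "2018-09-03T20:56:35.450686+00:00",
    "resourceStatus": {"physicalHost": "/a/a/a"},
    "scheduling": {"onHostMaintenance": "TERMINATE"},
    "labels": {"slurm_instance_role": "compute"},
    # no "upcomingMaintenance" key -> field defaults to None
}
inst = Instance.from_json(sample)
assert inst.zone == "us-central1-a"  # assuming trim_self_link keeps the trailing segment
assert inst.role == "compute"
assert inst.upcoming_maintenance is None
```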
@@ -1500,84 +1528,41 @@ def node_state(self, nodename: str) -> Optional[NodeState]:


     @lru_cache(maxsize=1)
-    def instances(self) -> Dict[str, object]:
+    def instances(self) -> Dict[str, Instance]:
         instance_information_fields = [
-            "advancedMachineFeatures",
-            "cpuPlatform",
             "creationTimestamp",
-            "disks",
-            "disks",
-            "fingerprint",
-            "guestAccelerators",
-            "hostname",
-            "id",
-            "kind",
-            "labelFingerprint",
-            "labels",
-            "lastStartTimestamp",
-            "lastStopTimestamp",
-            "lastSuspendedTimestamp",
-            "machineType",
-            "metadata",
             "name",
-            "networkInterfaces",
             "resourceStatus",
             "scheduling",
-            "selfLink",
-            "serviceAccounts",
-            "shieldedInstanceConfig",
-            "shieldedInstanceIntegrityPolicy",
-            "sourceMachineImage",
             "status",
-            "statusMessage",
-            "tags",
+            "labels.slurm_instance_role",
             "zone",
-            # "deletionProtection",
-            # "startRestricted",
         ]

         # TODO: Merge this with all fields when upcoming maintenance is
         # supported in beta.
         if endpoint_version(ApiEndpoint.COMPUTE) == 'alpha':
             instance_information_fields.append("upcomingMaintenance")

-        instance_information_fields = sorted(set(instance_information_fields))
-        instance_fields = ",".join(instance_information_fields)
+        instance_fields = ",".join(sorted(instance_information_fields))
         fields = f"items.zones.instances({instance_fields}),nextPageToken"
         flt = f"labels.slurm_cluster_name={self.cfg.slurm_cluster_name} AND name:{self.cfg.slurm_cluster_name}-*"
         act = self.compute.instances()
         op = act.aggregatedList(project=self.project, fields=fields, filter=flt)

-        def properties(inst):
-            """change instance properties to a preferred format"""
-            inst["zone"] = trim_self_link(inst["zone"])
-            inst["machineType"] = trim_self_link(inst["machineType"])
-            # metadata is fetched as a dict of dicts like:
-            # {'key': key, 'value': value}, kinda silly
-            metadata = {i["key"]: i["value"] for i in inst["metadata"].get("items", [])}
-            if "slurm_instance_role" not in metadata:
-                return None
-            inst["role"] = metadata["slurm_instance_role"]
-            inst["metadata"] = metadata
-            # del inst["metadata"] # no need to store all the metadata
-            return NSDict(inst)
-
         instances = {}
         while op is not None:
             result = ensure_execute(op)
-            instance_iter = (
-                (inst["name"], properties(inst))
-                for inst in chain.from_iterable(
-                    zone.get("instances", []) for zone in result.get("items", {}).values()
-                )
-            )
-            instances.update(
-                {name: props for name, props in instance_iter if props is not None}
-            )
+            for zone in result.get("items", {}).values():
+                for jo in zone.get("instances", []):
+                    inst = Instance.from_json(jo)
+                    if inst.name in instances:
+                        log.error(f"Duplicate VM name {inst.name} across multiple zones")
+                    instances[inst.name] = inst
             op = act.aggregatedList_next(op, result)
         return instances

-    def instance(self, instance_name: str) -> Optional[object]:
+    def instance(self, instance_name: str) -> Optional[Instance]:
         return self.instances().get(instance_name)

     @lru_cache()
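
Net effect for consumers of `Lookup`: `instances()` still pages through `aggregatedList` and is still `lru_cache`d, but it now yields immutable `Instance` objects and logs an error on a cross-zone name collision instead of silently keeping one record. A typical call site then looks roughly like this (a sketch: `lookup()` and `log` are the scripts' module-level helpers, and the node name is illustrative):

```python
lkp = lookup()                       # cached Lookup, as used throughout these scripts
inst = lkp.instance("c-compute-0")   # Optional[Instance]
if inst and inst.status == "RUNNING":
    log.info("%s created at %s (role=%s)", inst.name, inst.creation_timestamp, inst.role)
```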