SlurmGCP. Make delete instances status tracking "asynchronous" #3419

Draft · wants to merge 1 commit into base: develop
@@ -0,0 +1,158 @@
# Copyright 2024 "Google LLC"
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Optional

import json

import logging
from dataclasses import dataclass, asdict
from pathlib import Path

import util

log = logging.getLogger()

SUPPORTED_OPERATION_TYPES = frozenset({"delete"})

@dataclass(frozen=True)
class _Record:
    # common fields
    name: str
    type: str

    region: Optional[str] = None
    zone: Optional[str] = None

    # operation-type specific fields
    nodes: Optional[List[str]] = None

    @classmethod
    def from_json(cls, jo: dict) -> "_Record":
        return cls(**jo)

    def to_json(self) -> dict:
        return {k: v for k, v in asdict(self).items() if v is not None}

    @classmethod
    def from_op(cls, op: dict, **extra) -> "_Record":
        base = dict(
            name=op["name"],
            type=op["operationType"])
        if "region" in op:
            base["region"] = util.trim_self_link(op["region"])
        if "zone" in op:
            base["zone"] = util.trim_self_link(op["zone"])
        return cls.from_json({**base, **extra})

def _records_dir() -> Path:
    return Path("/var/spool/slurm_gcp/watched_ops")  # !!! create it

def _record_path(r: _Record) -> Path:
    return _records_dir() / f"{r.name}.json"
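The "# !!! create it" note above flags that nothing creates this directory yet. A minimal sketch of one way to guarantee it exists before records are written (the helper name `_ensure_records_dir` is hypothetical, and it assumes default ownership and permissions are acceptable):

def _ensure_records_dir() -> Path:
    d = Path("/var/spool/slurm_gcp/watched_ops")
    d.mkdir(parents=True, exist_ok=True)  # idempotent; also creates missing parent directories
    return d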

def _list_records() -> List[_Record]:
    res = []
    for p in _records_dir().glob("*.json"):
        try:
            jo = json.loads(p.read_text())
            res.append(_Record.from_json(jo))
        except Exception:
            log.exception(f"Failed to read {p}")
    return res

def _add_record(r: _Record) -> None:
    path = _record_path(r)
    assert not path.exists(), f"{path}"
    # No concern about reading partial writes,
    # since json deserialization would simply fail
    path.write_text(json.dumps(r.to_json()))

def _remove_record(r: _Record) -> None:
    _record_path(r).unlink(missing_ok=False)
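The comment above deliberately tolerates a reader occasionally seeing a partially written file: the JSON parse simply fails and is logged. If that trade-off ever becomes a problem, a common alternative is an atomic write via a temporary file; a minimal sketch, with the hypothetical helper name `_write_record_atomically`:

import os
import tempfile

def _write_record_atomically(path: Path, payload: dict) -> None:
    # Write to a temp file in the same directory, then atomically swap it in,
    # so readers never observe a half-written record.
    fd, tmp = tempfile.mkstemp(dir=path.parent, suffix=".tmp")
    try:
        with os.fdopen(fd, "w") as f:
            f.write(json.dumps(payload))
        os.replace(tmp, path)
    except BaseException:
        os.unlink(tmp)
        raise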

def _get_op_req(r: _Record, lkp: util.Lookup) -> object:
    """
    Builds a request to query the state of the operation.
    NOTE: it DOES NOT "wait" for the operation to complete.
    """
    if r.zone:
        return lkp.compute.zoneOperations().get(project=lkp.project, zone=r.zone, operation=r.name)
    elif r.region:
        return lkp.compute.regionOperations().get(project=lkp.project, region=r.region, operation=r.name)
    raise NotImplementedError("GlobalOperations are not supported")


def _sync_delete_op(r: _Record, lkp: util.Lookup) -> None:
    """
    Processes a VM delete-operation.
    If the operation is still running - do nothing.
    If the operation failed - log an error & remove the op from the watch list.
    If the operation is done - remove the op from the watch list.

    To avoid querying the status of each op individually, use the list of VM instances
    as a source of data. Don't query the op for instance X if instance X is not present
    (presumably deleted).
    NOTE: This optimization can lead to false-positives -
    absence of error-logs in case the op failed, but the VM got deleted by other means.
    """
    assert len(r.nodes) == 1 and r.type == "delete", f"{r}"
    node = r.nodes[0]
    inst = lkp.instance(node)

    if not inst:
        log.debug(f"Stop watching op {r.name}, VM {node} appears to be deleted")
        return _remove_record(r)  # potentially a false-positive

    log.debug(f"Watching delete-instance op={r.name}. VM {node} status={inst.status}")
    if inst.status == "TERMINATED":
        log.debug(f"Stop watching op {r.name}, VM {node} is TERMINATED")
        return _remove_record(r)  # potentially a false-positive

    if inst.status == "STOPPING":
        log.debug(f"Skipping op {r.name}, VM {node} is STOPPING")
        return  # try later

    op = _get_op_req(r, lkp).execute()  # don't handle exceptions; they will be logged and the op re-tried
    if op["status"] != "DONE":
        log.debug(f"Op {r.name} is still not done ({op['status']})")
        return  # try later

    if "error" in op:
        log.error(f"Operation {r.name} to delete {node} finished with error: {op['error']}")
    else:
        log.debug(f"Operation {r.name} to delete {node} successfully finished")
    return _remove_record(r)

def _sync_deletes(records: List[_Record], lkp: util.Lookup) -> None:
    log.info(f"Processing {len(records)} delete-instance operations")

    for r in records:
        try:
            _sync_delete_op(r, lkp)
        except Exception:
            log.exception(f"Failed to process {r}")
            # DO NOT skip processing of the other ops

def sync_ops(lkp: util.Lookup) -> None:
    records = _list_records()
    for t, records in util.groupby_unsorted(records, lambda r: r.type):
        if t == "delete":
            _sync_deletes(list(records), lkp)
        else:
            log.error(f"Unknown type {t} for {len(records)} operations")

def watch_delete_op(op: dict, node: str) -> None:
    assert op["operationType"] == "delete"
    _add_record(_Record.from_op(op, nodes=[node]))
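Taken together with the diffs below, the intended flow is: suspend.py records every submitted delete operation via watch_delete_op, and the controller's periodic slurmsync pass later polls them via sync_ops. A minimal sketch of that flow, using names that appear in the diffs below:

# In suspend.py, right after batch_execute() returns the submitted operations:
for node, op in submitted.items():
    ops_watch.watch_delete_op(op, node)

# In slurmsync.py main(), on the controller only:
ops_watch.sync_ops(lookup())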
@@ -35,7 +35,7 @@
)
import conf

from slurmsync import sync_slurm
import slurmsync

from setup_network_storage import (
    setup_network_storage,
@@ -373,7 +373,7 @@ def setup_controller():
run("systemctl status slurmctld", timeout=30)
run("systemctl status slurmrestd", timeout=30)

sync_slurm()
slurmsync.sync_nodes()
run("systemctl enable slurm_load_bq.timer", timeout=30)
run("systemctl start slurm_load_bq.timer", timeout=30)
run("systemctl status slurm_load_bq.timer", timeout=30)
@@ -22,7 +22,6 @@
import sys
import shlex
from datetime import datetime, timedelta
from enum import Enum
from itertools import chain
from pathlib import Path
from dataclasses import dataclass
@@ -48,6 +47,7 @@
from suspend import delete_instances
from resume import start_tpu
import conf
import ops_watch

log = logging.getLogger()

@@ -379,7 +379,7 @@ def sync_placement_groups():
delete_placement_groups(list(placement_groups.values()))


def sync_slurm():
def sync_nodes():
    compute_instances = {
        name for name, inst in lookup().instances().items() if inst.role == "compute"
    }
@@ -580,16 +580,21 @@ def sync_opportunistic_maintenance(lkp: util.Lookup) -> None:
    for job_name, node in create_jobs.items():
        create_maintenance_job(job_name, node)


def main():
    try:
        reconfigure_slurm()
    except Exception:
        log.exception("failed to reconfigure slurm")

    if lookup().is_controller:
    lkp = lookup()
    if lkp.is_controller:
        try:
            ops_watch.sync_ops(lkp)
        except Exception:
            log.exception("failed to sync long running operations")

        try:
            sync_slurm()
            sync_nodes()
        except Exception:
            log.exception("failed to sync instances")

@@ -599,17 +604,17 @@ def main():
log.exception("failed to sync placement groups")

try:
update_topology(lookup())
update_topology(lkp)
except Exception:
log.exception("failed to update topology")

try:
sync_maintenance_reservation(lookup())
sync_maintenance_reservation(lkp)
except Exception:
log.exception("failed to sync slurm reservation for scheduled maintenance")

try:
sync_opportunistic_maintenance(lookup())
sync_opportunistic_maintenance(lkp)
except Exception:
log.exception("failed to sync opportunistic reservation for scheduled maintenance")

@@ -21,16 +21,14 @@

import util
from util import (
    groupby_unsorted,
    log_api_request,
    batch_execute,
    to_hostlist,
    wait_for_operations,
    separate,
    execute_with_futures,
)
from util import lookup, TPU

import ops_watch
import slurm_gcp_plugins

log = logging.getLogger()
@@ -87,22 +85,23 @@ def delete_tpu_instances(instances):

def delete_instances(instances):
    """delete instances individually"""
    # TODO: consider not doing an expensive call to `instances()`
    invalid, valid = separate(lambda inst: bool(lookup().instance(inst)), instances)
    if len(invalid) > 0:
        log.debug("instances do not exist: {}".format(",".join(invalid)))
        log.info(f"instances do not exist: {to_hostlist(invalid)}")
    if len(valid) == 0:
        log.debug("No instances to delete")
        log.info("No instances to delete")
        return

    requests = {inst: delete_instance_request(inst) for inst in valid}

    log.info(f"delete {len(valid)} instances ({to_hostlist(valid)})")
    done, failed = batch_execute(requests)
    log.info(f"to delete {len(valid)} instances ({to_hostlist(valid)})")
    submitted, failed = batch_execute(requests)
    for node, (_, err) in failed.items():
        log.error(f"instance {node} failed to delete: {err}")
    wait_for_operations(done.values())
    # TODO do we need to check each operation for success? That is a lot more API calls
    log.info(f"deleted {len(done)} instances {to_hostlist(done.keys())}")
    log.info(f"deleting {len(submitted)} instances {to_hostlist(submitted.keys())}")
    for node, op in submitted.items():  # Track status of submitted operations
        ops_watch.watch_delete_op(op, node)
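Each value in `submitted` is the Compute Engine operation resource returned for the corresponding delete request; ops_watch only reads a handful of its fields. A rough sketch of that shape, with hypothetical values:

op = {
    "name": "operation-1712345678901-abcdef",  # hypothetical operation name
    "operationType": "delete",
    "status": "RUNNING",
    # a zonal operation carries "zone"; a regional one carries "region" (as a self-link)
    "zone": "https://www.googleapis.com/compute/v1/projects/my-project/zones/us-central1-a",
}
ops_watch.watch_delete_op(op, node="mycluster-compute-0")  # hypothetical node name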


def suspend_nodes(nodes: List[str]) -> None:
@@ -980,13 +980,11 @@ def ensure_execute(request):
break


def batch_execute(requests, retry_cb=None, log_err=log.error):
def batch_execute(requests: Dict, retry_cb=None, log_err=log.error):
    """execute list or dict<req_id, request> as batch requests
    retry if retry_cb returns true
    """
    BATCH_LIMIT = 1000
    if not isinstance(requests, dict):
        requests = {str(k): v for k, v in enumerate(requests)}  # rid generated here
    done = {}
    failed = {}
    timestamps = []
@@ -1776,6 +1774,10 @@ def instances(self) -> Dict[str, object]:
fields = f"items.zones.instances({instance_fields}),nextPageToken"
flt = f"labels.slurm_cluster_name={self.cfg.slurm_cluster_name} AND name:{self.cfg.slurm_cluster_name}-*"
act = self.compute.instances()
# https://cloud.google.com/compute/docs/reference/rest/v1/instances/aggregatedList
# > The performance of this method degrades when a filter is specified on a project
# > that has a very large number of instances.
# TODO: consider issuing in parallel multiple zonal list requests
op = act.aggregatedList(project=self.project, fields=fields, filter=flt)

def properties(inst):
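The TODO above suggests replacing the filtered aggregatedList call with per-zone list requests issued in parallel. A minimal sketch of that idea, assuming the caller supplies the zones to scan; the helper name `_list_instances_per_zone` is hypothetical, and note that instances.list nests results under `items(...)` rather than `items.zones.instances(...)`:

from concurrent.futures import ThreadPoolExecutor

def _list_instances_per_zone(compute, project, zones, flt, instance_fields):
    """Issue one instances.list request per zone in parallel and merge the pages."""
    fields = f"items({instance_fields}),nextPageToken"

    def list_zone(zone):
        items = []
        req = compute.instances().list(project=project, zone=zone, filter=flt, fields=fields)
        while req is not None:
            resp = req.execute()
            items.extend(resp.get("items", []))
            req = compute.instances().list_next(previous_request=req, previous_response=resp)
        return items

    with ThreadPoolExecutor(max_workers=16) as pool:
        return [inst for zone_items in pool.map(list_zone, zones) for inst in zone_items]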