From 71bd6964ff4bb18551f0405b0ab74991b0555831 Mon Sep 17 00:00:00 2001
From: David Hall
Date: Thu, 26 Sep 2024 15:53:42 -0700
Subject: [PATCH] Tweaks to Ray TPU stuff (#747)

1. num_tpus=1 is actually a bad idea, because Ray will mask out the other TPUs.
2. Force non-docker workloads to run in a separate process for stability.
---
 infra/cluster/job-cluster.yaml             |  23 +-
 infra/launch_on_ray.py                     |  61 ++---
 src/levanter/infra/ray_tpu.py              | 291 ++++++++++++++++------
 src/levanter/utils/background_iterable.py  |   3 +-
 4 files changed, 254 insertions(+), 124 deletions(-)

diff --git a/infra/cluster/job-cluster.yaml b/infra/cluster/job-cluster.yaml
index 652771fcb..cf8703d54 100644
--- a/infra/cluster/job-cluster.yaml
+++ b/infra/cluster/job-cluster.yaml
@@ -47,10 +47,23 @@ available_node_types:
           sourceImage: projects/ubuntu-os-cloud/global/images/family/ubuntu-2204-lts

   # Worker Nodes =>>
+  tpu_slice_v4_8:
+    min_workers: 0
+    max_workers: 1024
+    resources: { "CPU": 120, "TPU": 4 }
+
+    node_config:
+      acceleratorType: v4-8
+      runtimeVersion: tpu-ubuntu2204-base
+
+    # [IMPORTANT] Configure all TPU Workers to be Preemptible!
+    schedulingConfig:
+      preemptible: true
+
   tpu_slice_v4_32:
     min_workers: 0
     max_workers: 1024
-    resources: { "CPU": 120, "TPU": 1 }
+    resources: { "CPU": 120, "TPU": 4 }

     node_config:
       acceleratorType: v4-32
@@ -63,7 +76,7 @@ available_node_types:
   tpu_slice_v4_64:
     min_workers: 0
     max_workers: 1024
-    resources: {"CPU": 120, "TPU": 1}
+    resources: {"CPU": 120, "TPU": 4}

     node_config:
       acceleratorType: v4-64
@@ -77,7 +90,7 @@ available_node_types:
   tpu_slice_v4_128:
     min_workers: 0
     max_workers: 1024
-    resources: { "CPU": 120, "TPU": 1 }
+    resources: { "CPU": 120, "TPU": 4 }

     node_config:
       acceleratorType: v4-128
@@ -90,7 +103,7 @@ available_node_types:
   tpu_slice_v4_256:
     min_workers: 0
     max_workers: 1024
-    resources: { "CPU": 120, "TPU": 1 }
+    resources: { "CPU": 120, "TPU": 4 }

     node_config:
       acceleratorType: v4-256
@@ -103,7 +116,7 @@ available_node_types:
   tpu_slice_v4_512:
     min_workers: 0
     max_workers: 1024
-    resources: { "CPU": 120, "TPU": 1 }
+    resources: { "CPU": 120, "TPU": 4 }

     node_config:
       acceleratorType: v4-512
diff --git a/infra/launch_on_ray.py b/infra/launch_on_ray.py
index 2040aff44..fa5e81f27 100755
--- a/infra/launch_on_ray.py
+++ b/infra/launch_on_ray.py
@@ -4,16 +4,15 @@
 import argparse
 import getpass
 import os
-import tempfile
 import time
 from pathlib import Path

-import draccus
 from ray.dashboard.modules.job.common import JobStatus
 from ray.dashboard.modules.job.sdk import JobSubmissionClient

 import levanter.infra.cli_helpers as cli
 import levanter.infra.docker as docker
+from levanter.infra import ray_tpu


 def main():
@@ -22,7 +21,7 @@ def main():
     cli.add_arg(parser, config, ["--docker_base_image"], default="ghcr.io/stanford-crfm/levanter-base:latest")
     cli.add_arg(parser, config, ["--docker_repository"], default="levanter")
-    cli.add_arg(parser, config, ["--address"], default="http://127.0.0.1:8265")
+    cli.add_arg(parser, config, ["--address"], default=None)
     cli.add_arg(parser, config, ["--image_name"], default=f"levanter-{getpass.getuser()}")
     cli.add_capacity_type_args(parser, config)
     cli.add_arg(parser, config, ["--project"], default=cli.gcloud_config()["project"])
@@ -112,19 +111,11 @@ def main():
     env["RUN_ID"] = run_id
     env["WANDB_DOCKER"] = full_image_id

-    # run_docker_on_pod(
-    #     full_image_id,
-    #     command=command,
-    #     tpu_type=tpu_type,
-    #     env=env,
-    #     retries=retries,
-    # )
-
     # Submit the job to the Ray cluster. We have to use the JobSubmissionClient to do this and stringify the arguments
     # we want:
-    from levanter.infra.ray_tpu import RunOnPodConfig
+    from levanter.infra.ray_tpu import RunDockerOnPodConfig

-    config = RunOnPodConfig(
+    config = RunDockerOnPodConfig(
         image_id=full_image_id,
         command=command,
         tpu_type=tpu_type,
@@ -133,26 +124,16 @@ def main():
         retries=retries,
     )

-    with tempfile.NamedTemporaryFile(suffix=".yaml", prefix=f"launch-{run_id}-", dir=".") as f:
-        yaml = draccus.dump(config)
-        f.write(yaml.encode("utf-8"))
-        f.flush()
-
-        f_name = os.path.relpath(f.name)
-        print(f"Submitting job with config path {f_name}")
-
-        client = JobSubmissionClient(args.address)
+    address = args.address or os.getenv("RAY_ADDRESS")

-        job_id = _make_unique_job_id(client, run_id)
-
-        job_id = client.submit_job(
-            entrypoint=f"python src/levanter/infra/ray_tpu.py --config_path {f_name}",
-            runtime_env={"working_dir": "./"},
-            job_id=job_id,
-        )
+    job_id = ray_tpu.submit_tpu_job_on_ray(
+        config,
+        ray_address=address,
+        run_id=run_id,
+    )

-        print(
-            f"""
+    print(
+        f"""
 -------------------------------------------------------
 Job '{job_id}' submitted successfully
 -------------------------------------------------------

 Query the logs of the job:
     ray job logs {job_id}

 Query the status of the job:
     ray job status {job_id}

 Request the job to be stopped:
     ray job stop {job_id}
 """
-        )
+    )

     if args.foreground:
+        client = JobSubmissionClient(address)

         async def tail_job(job_id):
             async for line in client.tail_job_logs(job_id):  # type: ignore
@@ -181,7 +163,6 @@ async def tail_job(job_id):
         wait_until_status(
             client, job_id, {JobStatus.RUNNING, JobStatus.FAILED, JobStatus.SUCCEEDED, JobStatus.STOPPED}
         )
-        # tail_job(job_id)
         import asyncio

         asyncio.run(tail_job(job_id))
@@ -196,19 +177,7 @@ def wait_until_status(client, job_id, status_to_wait_for, timeout_seconds=5):
             break
         time.sleep(1)

-
-# try to make the job id be the same as the run id, but if it already exists, just make it unique
-def _make_unique_job_id(client, run_id):
-    job_id = run_id
-    try:
-        while client.get_job_status(job_id) is not None:
-            job_id = f"{run_id}-{time.time_ns()}"
-    except Exception as e:  # noqa
-        if "does not exist" in str(e):
-            pass
-        else:
-            raise
-    return job_id
+    return status


 if __name__ == "__main__":
diff --git a/src/levanter/infra/ray_tpu.py b/src/levanter/infra/ray_tpu.py
index 69f25d02a..3ae5d0105 100644
--- a/src/levanter/infra/ray_tpu.py
+++ b/src/levanter/infra/ray_tpu.py
@@ -1,16 +1,23 @@
 import dataclasses
+import functools
 import logging
+import multiprocessing
 import os
 import subprocess
+import tempfile
+import time
 from dataclasses import dataclass
-from typing import Sequence
+from typing import Callable, Optional, Sequence

 import draccus
 import ray
+from ray._private.accelerators import TPUAcceleratorManager
+from ray.dashboard.modules.job.sdk import JobSubmissionClient
 from ray.exceptions import NodeDiedError, RayError, RaySystemError, RayTaskError, WorkerCrashedError
 from ray.remote_function import RemoteFunction

 from levanter.infra.cli_helpers import make_docker_run_command
+from levanter.utils.ray_utils import ser_exc_info


 # CF https://gist.github.com/allenwang28/e3400b9e9212b50aa1cda55ebeccea60
@@ -55,42 +62,61 @@ class TpuRunError(_TpuRunResult):
     error: Exception


-def run_on_pod(remote_fn: RemoteFunction, tpu_type: str):
+def run_on_pod(remote_fn: RemoteFunction | Callable, tpu_type: str) -> ray.ObjectRef:
     """
     Run a remote function on a TPU pod.

     Args:
         remote_fn: A remote function that takes no arguments
         tpu_type: The type of TPU to run on, e.g. "v4-32"
+
+    Returns:
+        A Ray ObjectRef that represents the result of the function
     """

     @ray.remote(resources={f"TPU-{tpu_type}-head": 1})
     def do_run(remote_fn) -> _TpuRunResult:
-        tpu_name = ray.util.accelerators.tpu.get_current_pod_name()  # -> my-tpu
         num_hosts = ray.util.accelerators.tpu.get_current_pod_worker_count()  # -> 4
-        remote_fn = remote_fn.options(resources={tpu_name: 1, "TPU": 1})
+        remote_fn, tpu_name = _redecorate_remote_fn_for_tpu(remote_fn, num_hosts)

         info = _TpuInfo(tpu_name, "ACTIVE", "TPU")
+        futures = [remote_fn.remote() for _ in range(num_hosts)]
         try:
-            try:
-                out = ray.get([remote_fn.remote() for _ in range(num_hosts)])
-                logger.info("TPU job finished")
-                return TpuSuccess(info, out)
-            except RayError as e:
-                return _handle_ray_error(info, e)
-        finally:
-            # remove the tpu lockfile on each host
-            logger.debug("Removing lockfiles")
-            _rm_lockfile = ray.remote(resources={tpu_name: 1, "TPU": 1})(_hacky_remove_tpu_lockfile)
-            try:
-                ray.get([_rm_lockfile.remote() for _ in range(num_hosts)])
-            except Exception:
-                logger.exception("Failed to remove lockfile")
-                # swallow the exception
+            out = ray.get(futures)
+            logger.info("TPU job finished")
+            return TpuSuccess(info, out)
+        except RayError as e:
+            for f in futures:
+                try:
+                    ray.cancel(f)
+                except Exception:
+                    logger.exception("Failed to kill job after primary failure")
+            return _handle_ray_error(info, e)

     return do_run.remote(remote_fn)


+def _redecorate_remote_fn_for_tpu(remote_fn, num_hosts):
+    """
+    Redecorate a remote function to run on a TPU pod.
+
+    Specifically, this function:
+
+    * Adds the TPU resources to the function
+    * forces the function to run in its own process to remove the TPU lockfile (and shutdown jax distributed)
+
+    """
+    remote_fn = _forkify_remote_fn(remote_fn)
+    if not isinstance(remote_fn, RemoteFunction):
+        remote_fn = ray.remote(remote_fn)
+
+    tpu_name = ray.util.accelerators.tpu.get_current_pod_name()  # -> my-tpu
+    num_tpus_per_host = TPUAcceleratorManager.get_current_node_num_accelerators()  # -> 8
+    remote_fn = remote_fn.options(resources={tpu_name: 1, "TPU": num_tpus_per_host})
+    logger.info(f"Running on TPU {tpu_name} with {num_hosts} hosts and {num_tpus_per_host} TPUs per host")
+    return remote_fn, tpu_name
+
+
 def run_on_pod_resumable(remote_fn, tpu_type, max_retries_preemption=1e6, max_retries_failure=10):
     """
     Repeatedly run a function on a TPU pod until it succeeds or a maximum number of retries is reached.
@@ -100,50 +126,63 @@ def run_on_pod_resumable(remote_fn, tpu_type, max_retries_preemption=1e6, max_re
         tpu_type: The type of TPU to run on, e.g. "v4-32"
         max_retries_preemption: The maximum number of times to retry if the job is preempted
         max_retries_failure: The maximum number of times to retry if the job fails
+
+    Returns:
+        The result of the function (not an ObjectRef)
+
     """
     num_failures = 0
     num_preemptions = 0
+    attempt = 0
+    problem: Exception | None = None
     while num_failures < max_retries_failure and num_preemptions < max_retries_preemption:
+        logger.info(f"Running on TPU {tpu_type}. Attempt {attempt}")
+        attempt += 1
+        problem = None
         try:
             out = ray.get(run_on_pod(remote_fn, tpu_type))
-            if isinstance(out, TpuSuccess):
-                result = out.result
-                logger.info("Success")
-                return result
-            elif isinstance(out, TpuPreempted):
-                e = out.error
-                num_preemptions += 1
-                print(f"Preempted {num_preemptions} times. {e}")
-                logger.warning(f"Preempted {num_preemptions} times. {e}", exc_info=e)
-            elif isinstance(out, TpuFailed):
-                num_preemptions += 1
-                logger.warning(f"TPU node failure. Treating as preempted: {num_preemptions} times")
-            elif isinstance(out, TpuRunError):
-                e = out.error
-                num_failures += 1
-                logger.warning(f"Failed {num_failures} times")
-                logger.exception(e)
-            else:
-                raise RuntimeError(f"Unexpected result: {out}")
         except ray.exceptions.RayTaskError as e:
+            problem = e
             if "preempted" in str(e):
                 num_preemptions += 1
                 logger.warning(f"Preempted {num_preemptions} times, {e}")
             else:
                 num_failures += 1
                 logger.warning(f"Failed {num_failures} times")
+            continue
         except Exception as e:
+            problem = e
             num_failures += 1
-            logger.warning(f"Failed {num_failures} times")
-            logger.exception(e)
             if num_failures >= max_retries_failure:
+                logger.exception("Failed too many times", exc_info=e)
                 raise e
+            else:
+                logger.warning(f"Failed {num_failures} times", exc_info=e)
+                continue
+
+        if isinstance(out, TpuSuccess):
+            result = out.result
+            logger.info("Success")
+            return result
+        elif isinstance(out, TpuPreempted):
+            problem = out.error
+            num_preemptions += 1
+            logger.warning(f"Preempted {num_preemptions} times. {problem}", exc_info=problem)
+        elif isinstance(out, TpuFailed):
+            num_preemptions += 1
+            logger.warning(f"TPU node failure. Treating as preempted: {num_preemptions} times")
+        elif isinstance(out, TpuRunError):
+            problem = out.error
+            num_failures += 1
+            logger.warning(f"Failed {num_failures} times", exc_info=problem)
+        else:
+            raise RuntimeError(f"Unexpected result: {out}")

-            if num_preemptions >= max_retries_preemption:
-                raise RuntimeError("Preempted too many times")
-            elif num_failures >= max_retries_failure:
-                raise RuntimeError("Failed too many times")
+    if num_preemptions >= max_retries_preemption:
+        raise RuntimeError("Preempted too many times") from problem
+    elif num_failures >= max_retries_failure:
+        raise RuntimeError("Failed too many times") from problem


 def _run_command(*args, **kwargs):
@@ -170,6 +209,7 @@ def run_docker():

 def _kill_old_container(name):
     try:
+        logger.info(f"Killing old container {name}")
         _run_command("sudo", "docker", "rm", "-f", name)
     except subprocess.CalledProcessError:
         pass
@@ -182,11 +222,9 @@ def _handle_ray_error(tpu_info: _TpuInfo, e: RayError):
     """
     # treat node failures as preemptions
     if isinstance(e, NodeDiedError):
-        print("Node died")
         logger.exception("Node died", exc_info=e)
         return TpuPreempted(tpu_info, e)
     elif isinstance(e, WorkerCrashedError):
-        print("Worker crashed")
         logger.exception("Worker crashed", exc_info=e)
         return TpuPreempted(tpu_info, e)
     elif isinstance(e, RaySystemError):
@@ -198,7 +236,6 @@ def _handle_ray_error(tpu_info: _TpuInfo, e: RayError):
         from levanter.infra.tpus import get_current_tpu_is_preempted

         if get_current_tpu_is_preempted():
-            print("Preempted")
             logger.exception("Preempted", exc_info=e)
             return TpuPreempted(tpu_info, e)

@@ -210,39 +247,70 @@ def _handle_ray_error(tpu_info: _TpuInfo, e: RayError):
     return TpuRunError(tpu_info, e)


-@dataclass
-class RunOnPodConfig:
-    image_id: str
-    command: list[str] | str
-    tpu_type: str
-    env: dict = dataclasses.field(default_factory=dict)
-    name: str = "levanter"
-    retries: int = 10
+def _forkify_remote_fn(remote_fn: RemoteFunction | Callable):
+    """
+    This is a bit of a hacky way to force a remote function to run in its own process, using multiprocessing.
+    There are a few issues we're trying to cover:
+
+    * libtpu only allows one process to access the TPU at a time, and it uses a lockfile to enforce this.
+    * Ray runs tasks in a long-running daemon, so the lockfile persists across tasks.
+    * jax.distributed likes to only be called once per process, even if you call shutdown

-@draccus.wrap()
-def main(args: RunOnPodConfig):
     """
-    Run a command on a TPU pod. This is a wrapper around `run_docker_on_pod` that takes a config object as a CLI.
+    if isinstance(remote_fn, RemoteFunction):
+        fn = remote_fn._function
+
+        @functools.wraps(fn)
+        def wrapped_fn(*args, **kwargs):
+            return _separate_process_fn(fn, args, kwargs)
+
+        # We need these arguments to be able to reconstruct the remote function
+        # def __init__(
+        #     self,
+        #     language,
+        #     function,
+        #     function_descriptor,
+        #     task_options,
+        # ):
+        remote_fn = RemoteFunction(
+            language=remote_fn._language,
+            function=wrapped_fn,
+            function_descriptor=remote_fn._function_descriptor,
+            task_options=remote_fn._default_options,
+        )
+        return remote_fn
+    else:
+        return functools.partial(_separate_process_fn, remote_fn)

-    We use this via infra/launch_on_ray.py to run docker containers on TPUs.
+
+def _separate_process_fn(underlying_function, args, kwargs):
+    """
+    Helper function for _forkify_remote_fn. This runs the function in a separate process.
     """
-    ray.init()

-    import shlex
+    def target_fn(queue, args, kwargs):
+        try:
+            # Call the original function
+            result = underlying_function(*args, **kwargs)
+            queue.put((True, result))  # Success, put the result
+        except Exception as e:
+            # Capture and return the full traceback in case of an exception
+            info = ser_exc_info(e)
+            queue.put((False, info))

-    if isinstance(args.command, str):
-        command = shlex.split(args.command)
-    else:
-        command = args.command
+    queue = multiprocessing.Queue()
+    process = multiprocessing.Process(target=target_fn, args=(queue, args, kwargs))
+    process.start()
+    process.join()

-    run_docker_on_pod(
-        args.image_id,
-        command,
-        tpu_type=args.tpu_type,
-        env=args.env,
-        name=args.name,
-    )
+    # Retrieve the result or error from the queue
+    success, value = queue.get()
+
+    if success:
+        return value
+    else:
+        value.reraise()


 def _hacky_remove_tpu_lockfile():
@@ -267,6 +335,85 @@ def _hacky_remove_tpu_lockfile():
         pass


+@dataclass
+class RunDockerOnPodConfig:
+    image_id: str
+    command: list[str] | str
+    tpu_type: str
+    env: dict = dataclasses.field(default_factory=dict)
+    name: str = "levanter"
+    retries: int = 10
+
+
+def submit_tpu_job_on_ray(config: RunDockerOnPodConfig, ray_address: str, run_id: Optional[str] = None):
+    """
+    Submit a job to run on a TPU pod on a Ray cluster. This programmatically submits a job to the Ray cluster.
+    This should be run on your local machine, not on the Ray cluster itself.
+
+    If run_id is not provided, a default run ID will be generated.
+    """
+
+    with tempfile.NamedTemporaryFile(suffix=".yaml", prefix=f"launch-{run_id}-", dir=".") as f:
+        yaml = draccus.dump(config)
+        f.write(yaml.encode("utf-8"))
+        f.flush()
+
+        f_name = os.path.relpath(f.name)
+        logger.info(f"Submitting job with config path {f_name}")
+
+        client = JobSubmissionClient(ray_address)
+
+        job_id = _make_unique_job_id(client, run_id) if run_id is not None else None
+
+        job_id = client.submit_job(
+            entrypoint=f"python -m levanter.infra.ray_tpu --config_path {f_name}",
+            runtime_env={"working_dir": ".", "env_vars": {"PYTHONPATH": "src:."}},
+            submission_id=job_id,
+        )
+
+        return job_id
+
+
+# try to make the job id be the same as the run id, but if it already exists, just make it unique
+def _make_unique_job_id(client, run_id):
+    job_id = run_id
+    try:
+        while client.get_job_status(job_id) is not None:
+            job_id = f"{run_id}-{time.time_ns()}"
+    except Exception as e:  # noqa
+        if "does not exist" in str(e):
+            pass
+        else:
+            raise
+    return job_id
+
+
+@draccus.wrap()
+def main(args: RunDockerOnPodConfig):
+    """
+    *This command is designed to run on a Ray cluster, not on your local machine. You probably want submit_tpu_job_on_ray.*
+
+    Run a command on a TPU pod. This is a wrapper around `run_docker_on_pod` that takes a config object via the CLI.
+
+    We use this via infra/launch_on_ray.py to run docker containers on TPUs.
+    """
+
+    import shlex
+
+    if isinstance(args.command, str):
+        command = shlex.split(args.command)
+    else:
+        command = args.command
+
+    run_docker_on_pod(
+        args.image_id,
+        command,
+        tpu_type=args.tpu_type,
+        env=args.env,
+        name=args.name,
+    )
+
+
 def _massage_env(env):
     # Ray pretends it's running in a TTY, which leads to a ton of log spam from tqdm.
     # Levanter uses tqdm_loggable, which tries to sniff out the TTY, but it doesn't work with Ray.
diff --git a/src/levanter/utils/background_iterable.py b/src/levanter/utils/background_iterable.py
index 4318b3f9b..11a80f8ec 100644
--- a/src/levanter/utils/background_iterable.py
+++ b/src/levanter/utils/background_iterable.py
@@ -82,7 +82,8 @@ def __del__(self):

     def stop(self, wait: bool = True):
         self._stop_event.set()
-        if self.thread is not None and wait:
+        # We sometimes see stop() called from the background thread itself (which seems impossible), so guard against joining the current thread
+        if self.thread is not None and wait and self.thread != threading.current_thread():
             self.thread.join()

     def _fill_queue_with_batches(self):
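For reference, here is a minimal sketch of how the `RunDockerOnPodConfig` / `submit_tpu_job_on_ray` API added in this patch might be called from a local machine. The image id, command, TPU type, Ray address, and run id below are illustrative placeholders, not values taken from this patch:

    # Illustrative sketch only: image id, command, TPU type, address, and run id are placeholders.
    import os

    from levanter.infra import ray_tpu
    from levanter.infra.ray_tpu import RunDockerOnPodConfig

    config = RunDockerOnPodConfig(
        image_id="ghcr.io/my-org/levanter:latest",  # hypothetical docker image
        command=["python", "-m", "levanter.main.train_lm", "--config_path", "my_config.yaml"],
        tpu_type="v4-32",
        env={"WANDB_PROJECT": "levanter"},
        retries=10,
    )

    # Mirrors the launch_on_ray.py change above: fall back to RAY_ADDRESS when --address is not given.
    address = os.getenv("RAY_ADDRESS", "http://127.0.0.1:8265")

    job_id = ray_tpu.submit_tpu_job_on_ray(config, ray_address=address, run_id="my-run")
    print(f"Submitted Ray job {job_id}")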