diff --git a/src/lightning/fabric/cli.py b/src/lightning/fabric/cli.py index 5f18884e83d79..3f93472b044f6 100644 --- a/src/lightning/fabric/cli.py +++ b/src/lightning/fabric/cli.py @@ -36,7 +36,7 @@ _CLICK_AVAILABLE = RequirementCache("click") _LIGHTNING_SDK_AVAILABLE = RequirementCache("lightning_sdk") -_SUPPORTED_ACCELERATORS = ("cpu", "gpu", "cuda", "mps", "tpu") +_SUPPORTED_ACCELERATORS = ("cpu", "gpu", "cuda", "mps", "tpu", "auto") def _get_supported_strategies() -> list[str]: @@ -208,6 +208,14 @@ def _set_env_variables(args: Namespace) -> None: def _get_num_processes(accelerator: str, devices: str) -> int: """Parse the `devices` argument to determine how many processes need to be launched on the current machine.""" + if accelerator == "auto": + if torch.cuda.is_available(): + accelerator = "cuda" + elif getattr(torch.backends, "mps", None) and torch.backends.mps.is_available(): + accelerator = "mps" + else: + accelerator = "cpu" + if accelerator == "gpu": parsed_devices = _parse_gpu_ids(devices, include_cuda=True, include_mps=True) elif accelerator == "cuda":