diff --git a/docling/datamodel/pipeline_options.py b/docling/datamodel/pipeline_options.py index fc78e27c..1a19bbd2 100644 --- a/docling/datamodel/pipeline_options.py +++ b/docling/datamodel/pipeline_options.py @@ -1,11 +1,19 @@ import logging import os +import re import warnings from enum import Enum from pathlib import Path from typing import Annotated, Any, Dict, List, Literal, Optional, Tuple, Type, Union -from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator +from pydantic import ( + BaseModel, + ConfigDict, + Field, + field_validator, + model_validator, + validator, +) from pydantic_settings import ( BaseSettings, PydanticBaseSettingsSource, @@ -33,6 +41,20 @@ class AcceleratorOptions(BaseSettings): num_threads: int = 4 device: AcceleratorDevice = AcceleratorDevice.AUTO + @validator("device") + def validate_device(cls, value): + # Allow both Enum and str inputs + if isinstance(value, AcceleratorDevice): + return value + # Validate as a string + if value in {d.value for d in AcceleratorDevice} or re.match( + r"^cuda(:\d+)?$", value + ): + return AcceleratorDevice(value) + raise ValueError( + "Invalid device option. Use 'auto', 'cpu', 'mps', 'cuda', or 'cuda:N'." + ) + @model_validator(mode="before") @classmethod def check_alternative_envvars(cls, data: Any) -> Any: diff --git a/docling/utils/accelerator_utils.py b/docling/utils/accelerator_utils.py index 59b04796..8c930250 100644 --- a/docling/utils/accelerator_utils.py +++ b/docling/utils/accelerator_utils.py @@ -7,36 +7,62 @@ _log = logging.getLogger(__name__) -def decide_device(accelerator_device: AcceleratorDevice) -> str: +def decide_device(accelerator_device: str) -> str: r""" - Resolve the device based on the acceleration options and the available devices in the system + Resolve the device based on the acceleration options and the available devices in the system. + Rules: 1. AUTO: Check for the best available device on the system. 2. User-defined: Check if the device actually exists, otherwise fall-back to CPU """ - cuda_index = 0 device = "cpu" has_cuda = torch.backends.cuda.is_built() and torch.cuda.is_available() has_mps = torch.backends.mps.is_built() and torch.backends.mps.is_available() - if accelerator_device == AcceleratorDevice.AUTO: + if accelerator_device == AcceleratorDevice.AUTO.value: # Handle 'auto' if has_cuda: - device = f"cuda:{cuda_index}" + device = "cuda:0" elif has_mps: device = "mps" - else: - if accelerator_device == AcceleratorDevice.CUDA: - if has_cuda: - device = f"cuda:{cuda_index}" - else: - _log.warning("CUDA is not available in the system. Fall back to 'CPU'") - elif accelerator_device == AcceleratorDevice.MPS: - if has_mps: - device = "mps" + elif accelerator_device.startswith("cuda"): + if has_cuda: + # if cuda device index specified extract device id + parts = accelerator_device.split(":") + if len(parts) == 2 and parts[1].isdigit(): + # select cuda device's id + cuda_index = int(parts[1]) + if cuda_index < torch.cuda.device_count(): + device = f"cuda:{cuda_index}" + else: + _log.warning( + "CUDA device 'cuda:%d' is not available. Fall back to 'CPU'.", + cuda_index, + ) + elif len(parts) == 1: # just "cuda" + device = "cuda:0" else: - _log.warning("MPS is not available in the system. Fall back to 'CPU'") + _log.warning( + "Invalid CUDA device format '%s'. Fall back to 'CPU'", + accelerator_device, + ) + else: + _log.warning("CUDA is not available in the system. Fall back to 'CPU'") + + elif accelerator_device == AcceleratorDevice.MPS.value: + if has_mps: + device = "mps" + else: + _log.warning("MPS is not available in the system. Fall back to 'CPU'") + + elif accelerator_device == AcceleratorDevice.CPU.value: + device = "cpu" + + else: + _log.warning( + "Unknown device option '%s'. Fall back to 'CPU'", accelerator_device + ) _log.info("Accelerator device: '%s'", device) return device diff --git a/docs/examples/run_with_accelerator.py b/docs/examples/run_with_accelerator.py index 5985401d..86715e6e 100644 --- a/docs/examples/run_with_accelerator.py +++ b/docs/examples/run_with_accelerator.py @@ -30,6 +30,11 @@ def main(): # num_threads=8, device=AcceleratorDevice.CUDA # ) + # easyocr doesnt support cuda:N allocation + # accelerator_options = AcceleratorOptions( + # num_threads=8, device="cuda:1" + # ) + pipeline_options = PdfPipelineOptions() pipeline_options.accelerator_options = accelerator_options pipeline_options.do_ocr = True