diff --git a/docker-compose.yaml b/docker-compose.yaml
index bc0ac49b..c05a71d5 100644
--- a/docker-compose.yaml
+++ b/docker-compose.yaml
@@ -194,6 +194,7 @@ services:
       - sys_nice
     environment:
       - CUDA_VISIBLE_DEVICES=-1
+      - DISABLE_FAST_API_ACCESS_LOGGING=true
       - EMBEDDING_NIM_MODEL_NAME=${EMBEDDING_NIM_MODEL_NAME:-nvidia/llama-3.2-nv-embedqa-1b-v2}
       - INGEST_LOG_LEVEL=DEFAULT
       # Message client for development
@@ -213,27 +214,28 @@ services:
       - NVIDIA_BUILD_API_KEY=${NVIDIA_BUILD_API_KEY:-${NGC_API_KEY:-ngcapikey}}
       - OTEL_EXPORTER_OTLP_ENDPOINT=otel-collector:4317
       # Self-hosted paddle endpoints.
-      - PADDLE_GRPC_ENDPOINT=paddle:8001
+      #- PADDLE_GRPC_ENDPOINT=paddle:8001
       - PADDLE_HTTP_ENDPOINT=http://paddle:8000/v1/infer
-      - PADDLE_INFER_PROTOCOL=grpc
+      #- PADDLE_INFER_PROTOCOL=grpc
       # build.nvidia.com hosted paddle endpoints.
       #- PADDLE_HTTP_ENDPOINT=https://ai.api.nvidia.com/v1/cv/baidu/paddleocr
-      #- PADDLE_INFER_PROTOCOL=http
+      - PADDLE_INFER_PROTOCOL=http
       - READY_CHECK_ALL_COMPONENTS=True
       - REDIS_MORPHEUS_TASK_QUEUE=morpheus_task_queue
       # Self-hosted redis endpoints.
-      - YOLOX_GRPC_ENDPOINT=yolox:8001
+      #- YOLOX_GRPC_ENDPOINT=yolox:8001
       - YOLOX_HTTP_ENDPOINT=http://yolox:8000/v1/infer
-      - YOLOX_INFER_PROTOCOL=grpc
+      #- YOLOX_INFER_PROTOCOL=grpc
       # build.nvidia.com hosted yolox endpoints.
       #- YOLOX_HTTP_ENDPOINT=https://ai.api.nvidia.com/v1/cv/nvidia/nv-yolox-page-elements-v1
-      #- YOLOX_INFER_PROTOCOL=http
-      - YOLOX_GRAPHIC_ELEMENTS_GRPC_ENDPOINT=yolox-graphic-elements:8001
+      - YOLOX_INFER_PROTOCOL=http
+      #- YOLOX_GRAPHIC_ELEMENTS_GRPC_ENDPOINT=yolox-graphic-elements:8001
       - YOLOX_GRAPHIC_ELEMENTS_HTTP_ENDPOINT=http://yolox-graphic-elements:8000/v1/infer
-      - YOLOX_GRAPHIC_ELEMENTS_INFER_PROTOCOL=grpc
-      - YOLOX_TABLE_STRUCTURE_GRPC_ENDPOINT=yolox-table-structure:8001
+      - YOLOX_GRAPHIC_ELEMENTS_INFER_PROTOCOL=http
+      #- YOLOX_TABLE_STRUCTURE_GRPC_ENDPOINT=yolox-table-structure:8001
       - YOLOX_TABLE_STRUCTURE_HTTP_ENDPOINT=http://yolox-table-structure:8000/v1/infer
-      - YOLOX_TABLE_STRUCTURE_INFER_PROTOCOL=grpc
+      - YOLOX_TABLE_STRUCTURE_INFER_PROTOCOL=http
+      #- VLM_CAPTION_ENDPOINT=https://ai.api.nvidia.com/v1/gr/meta/llama-3.2-11b-vision-instruct/chat/completions
       - VLM_CAPTION_ENDPOINT=http://vlm:8000/v1/chat/completions
       - VLM_CAPTION_MODEL_NAME=meta/llama-3.2-11b-vision-instruct
     healthcheck:
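Note on the endpoint flip above: the service now talks HTTP to the self-hosted NIMs, with the gRPC settings left commented for easy rollback, and the build.nvidia.com hosted VLM endpoint stays commented so it does not shadow the self-hosted one. As a rough illustration of how a `*_INFER_PROTOCOL` value selects between an endpoint pair, consider the sketch below (the `resolve_endpoint` helper is hypothetical, not nv-ingest's actual resolution logic):

```python
# Hypothetical sketch: pick the endpoint matching *_INFER_PROTOCOL.
import os

def resolve_endpoint(prefix: str) -> tuple:
    """Return (endpoint, protocol) for a service prefix such as 'PADDLE'."""
    protocol = os.environ.get(f"{prefix}_INFER_PROTOCOL", "http").lower()
    if protocol == "grpc":
        endpoint = os.environ.get(f"{prefix}_GRPC_ENDPOINT", "")
    else:
        endpoint = os.environ.get(f"{prefix}_HTTP_ENDPOINT", "")
    return endpoint, protocol

endpoint, protocol = resolve_endpoint("PADDLE")
print(f"paddle endpoint={endpoint} protocol={protocol}")
```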
diff --git a/docker/scripts/entrypoint.sh b/docker/scripts/entrypoint.sh
index 83c9a945..0e9a91d3 100755
--- a/docker/scripts/entrypoint.sh
+++ b/docker/scripts/entrypoint.sh
@@ -27,14 +27,29 @@ SRC_FILE="/opt/docker/bin/entrypoint_source"
 
 # Check if user supplied a command
 if [ "$#" -gt 0 ]; then
-    # If a command is provided, run it
+    # If a command is provided, run it.
     exec "$@"
 else
-    # If no command is provided, run the default startup launch
+    # If no command is provided, run the default startup launch.
     if [ "${MESSAGE_CLIENT_TYPE}" != "simple" ]; then
-        # Start uvicorn if MESSAGE_CLIENT_TYPE is not 'simple'
-        uvicorn nv_ingest.main:app --workers 32 --host 0.0.0.0 --port 7670 &
+        # Determine the log level for uvicorn.
+        log_level=$(echo "${INGEST_LOG_LEVEL:-default}" | tr '[:upper:]' '[:lower:]')
+        if [ "$log_level" = "default" ]; then
+            log_level="info"
+        fi
+
+        # Build the uvicorn command with the specified log level.
+        uvicorn_cmd="uvicorn nv_ingest.main:app --workers 32 --host 0.0.0.0 --port 7670 --log-level ${log_level}"
+
+        # If DISABLE_FAST_API_ACCESS_LOGGING is true, disable access logs.
+        if [ "${DISABLE_FAST_API_ACCESS_LOGGING}" = "true" ]; then
+            uvicorn_cmd="${uvicorn_cmd} --no-access-log"
+        fi
+
+        # Start uvicorn in the background.
+        $uvicorn_cmd &
     fi
 
+    # Start the microservice entrypoint.
     python /workspace/microservice_entrypoint.py
 fi
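The entrypoint now derives uvicorn's log level from `INGEST_LOG_LEVEL` and can drop access logs entirely. For reference, a sketch of the same configuration expressed through uvicorn's Python API rather than the CLI (the container itself keeps using the CLI):

```python
# Sketch only: programmatic equivalent of the command the entrypoint builds.
import os
import uvicorn

log_level = os.environ.get("INGEST_LOG_LEVEL", "default").lower()
if log_level == "default":
    log_level = "info"

uvicorn.run(
    "nv_ingest.main:app",
    host="0.0.0.0",
    port=7670,
    workers=32,
    log_level=log_level,
    # access_log=False mirrors the --no-access-log flag.
    access_log=os.environ.get("DISABLE_FAST_API_ACCESS_LOGGING") != "true",
)
```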
f"extract_text={extract_text}, extract_images={extract_images}, " - f"extract_tables={extract_tables}, extract_charts={extract_charts}, " - f"extract_infographics={extract_infographics}" - ) - - # Decide if text_depth is PAGE or DOCUMENT - if text_depth != TextTypeEnum.PAGE: - text_depth = TextTypeEnum.DOCUMENT - - extracted_data = [] - accumulated_text = [] - - # Prepare for table/chart extraction - pages_for_tables = [] # We'll accumulate (page_idx, np_image, padding_offset) here - futures = [] # We'll keep track of all the Future objects for table/charts - - with concurrent.futures.ThreadPoolExecutor(max_workers=pdfium_config.workers_per_progress_engine) as executor: - # PAGE LOOP + trace_info: Optional[List] = None, + kwargs: Optional[Dict[str, Any]] = None, +) -> Dict[str, Any]: + """ + This worker function encapsulates pdfium operations. It opens the PDF document, + extracts text and images, and accumulates pages (as tuples of (page_index, numpy_array)) + for table/chart extraction. All data required for further processing is returned + in a dictionary. + """ + try: + kwargs = kwargs or {} + row_data = kwargs.get("row_data") + source_id = row_data["source_id"] + + text_depth = kwargs.get("text_depth", "page") + text_depth = TextTypeEnum[text_depth.upper()] + + paddle_output_format = kwargs.get("paddle_output_format", "pseudo_markdown") + paddle_output_format = TableFormatEnum[paddle_output_format.upper()] + + metadata_col = kwargs.get("metadata_column", "metadata") + pdfium_config = kwargs.get("pdfium_config", {}) + if isinstance(pdfium_config, dict): + pdfium_config = PDFiumConfigSchema(**pdfium_config) + + base_unified_metadata = row_data[metadata_col] if metadata_col in row_data.index else {} + base_source_metadata = base_unified_metadata.get("source_metadata", {}) + source_location = base_source_metadata.get("source_location", "") + collection_id = base_source_metadata.get("collection_id", "") + partition_id = base_source_metadata.get("partition_id", -1) + access_level = base_source_metadata.get("access_level", AccessLevelEnum.LEVEL_1) + + # Open the PDF document using pdfium. + doc = libpdfium.PdfDocument(pdf_stream) + pdf_metadata = extract_pdf_metadata(doc, source_id) + page_count = pdf_metadata.page_count + + source_metadata = { + "source_name": pdf_metadata.filename, + "source_id": source_id, + "source_location": source_location, + "source_type": pdf_metadata.source_type, + "collection_id": collection_id, + "date_created": pdf_metadata.date_created, + "last_modified": pdf_metadata.last_modified, + "summary": "", + "partition_id": partition_id, + "access_level": access_level, + } + + extracted_data = [] + accumulated_text = [] + pages_for_tables: List[Tuple[int, Any]] = [] + + # Process each page. for page_idx in range(page_count): page = doc.get_page(page_idx) page_width, page_height = page.get_size() - # If we want text, extract text now. + # Extract text. if extract_text: page_text = _extract_page_text(page) if text_depth == TextTypeEnum.PAGE: @@ -406,10 +405,9 @@ def pdfium_extractor( ) extracted_data.append(text_meta) else: - # doc-level => accumulate accumulated_text.append(page_text) - # If we want images, extract images now. + # Extract images. 
             if extract_images:
                 image_data = _extract_page_images(
                     page,
@@ -432,11 +430,131 @@ def pdfium_extractor(
                 )
                 pages_for_tables.append((page_idx, image[0], padding_offsets[0]))
 
-                # Whenever pages_for_tables hits YOLOX_MAX_BATCH_SIZE, submit a job
-                if len(pages_for_tables) >= YOLOX_MAX_BATCH_SIZE:
-                    future = executor.submit(
+            page.close()
+
+        # For document-level text, combine accumulated text.
+        if extract_text and text_depth == TextTypeEnum.DOCUMENT and accumulated_text:
+            doc_text_meta = construct_text_metadata(
+                accumulated_text,
+                pdf_metadata.keywords,
+                -1,
+                -1,
+                -1,
+                -1,
+                page_count,
+                text_depth,
+                source_metadata,
+                base_unified_metadata,
+            )
+            extracted_data.append(doc_text_meta)
+
+        doc.close()
+
+        return {
+            "extracted_data": extracted_data,
+            "pages_for_tables": pages_for_tables,
+            "pdfium_config": pdfium_config,
+            "page_count": page_count,
+            "source_metadata": source_metadata,
+            "base_unified_metadata": base_unified_metadata,
+            "paddle_output_format": paddle_output_format,
+            "trace_info": trace_info,
+            "extract_tables": extract_tables,
+            "extract_charts": extract_charts,
+        }
+
+    except Exception as e:
+        logger.error(f"Error in pdfium extraction worker: {e}")
+        traceback.print_exc()
+        raise
+
+
+# -----------------------------------------------------------------------------
+# Main function: calls the worker in an isolated process, then uses a threadpool
+# in the main thread to perform table/chart extraction.
+# -----------------------------------------------------------------------------
+def pdfium_extractor(
+    pdf_stream: bytes,
+    extract_text: bool,
+    extract_images: bool,
+    extract_tables: bool,
+    extract_charts: bool,
+    trace_info: Optional[List] = None,
+    **kwargs,
+) -> List[Any]:
+    """
+    Extracts text, images, and (optionally) tables/charts from a PDF stream.
+    This function launches a separate process to isolate pdfium usage and then,
+    in the main thread, offloads table/chart extraction to a ThreadPoolExecutor.
+
+    Returns:
+        A list of extracted items matching the downstream expected format.
+
+    Raises:
+        RuntimeError: If the pdfium extraction process crashes or terminates unexpectedly.
+    """
+
+    extract_infographics = kwargs.get("extract_infographics", False)
+    logger.debug("Launching pdfium extraction in a separate process.")
+    with concurrent.futures.ProcessPoolExecutor(max_workers=1, mp_context=mp.get_context("fork")) as executor:
+        future = executor.submit(
+            _pdfium_extraction_worker,
+            pdf_stream,
+            extract_text,
+            extract_images,
+            extract_infographics,
+            extract_tables,
+            extract_charts,
+            trace_info,
+            kwargs,
+        )
+        try:
+            result = future.result()  # Blocks until the worker finishes.
+        except concurrent.futures.process.BrokenProcessPool as e:
+            raise RuntimeError("Pdfium extraction process crashed or terminated unexpectedly.") from e
+
+    logger.debug("Pdfium extraction process completed; processing table/chart extraction in main thread.")
+    extracted_data = result["extracted_data"]
+
+    # Use a threadpool in the main thread for the _extract_page_elements calls.
+    if (extract_tables or extract_charts or extract_infographics) and result["pages_for_tables"]:
+        pages = result["pages_for_tables"]
+        pdfium_config = result["pdfium_config"]
+        page_count = result["page_count"]
+        source_metadata = result["source_metadata"]
+        base_unified_metadata = result["base_unified_metadata"]
+        paddle_output_format = result["paddle_output_format"]
+        trace_info = result["trace_info"]
+
+        table_chart_items = []
+        futures = []
+        # Create a ThreadPoolExecutor with the same max_workers as configured.
+        with concurrent.futures.ThreadPoolExecutor(max_workers=pdfium_config.workers_per_progress_engine) as executor:
+            batch = []
+            for item in pages:
+                batch.append(item)
+                if len(batch) >= YOLOX_MAX_BATCH_SIZE:
+                    futures.append(
+                        executor.submit(
+                            _extract_page_elements,
+                            batch.copy(),
+                            pdfium_config,
+                            page_count,
+                            source_metadata,
+                            base_unified_metadata,
+                            extract_tables,
+                            extract_charts,
+                            extract_infographics,
+                            paddle_output_format,
+                            trace_info=trace_info,
+                        )
+                    )
+                    batch = []
+            if batch:
+                futures.append(
+                    executor.submit(
                         _extract_page_elements,
-                        pages_for_tables[:],  # pass a copy
+                        batch.copy(),
                         pdfium_config,
                         page_count,
                         source_metadata,
@@ -447,48 +565,12 @@
                         paddle_output_format,
                         trace_info=trace_info,
                     )
-                    futures.append(future)
-                    pages_for_tables.clear()
+                )
 
-            page.close()
+            for fut in concurrent.futures.as_completed(futures):
+                table_chart_items.extend(fut.result())
 
-    # After page loop, if we still have leftover pages_for_tables, submit one last job
-    if (extract_tables or extract_charts or extract_infographics) and pages_for_tables:
-        future = executor.submit(
-            _extract_page_elements,
-            pages_for_tables[:],
-            pdfium_config,
-            page_count,
-            source_metadata,
-            base_unified_metadata,
-            extract_tables,
-            extract_charts,
-            extract_infographics,
-            paddle_output_format,
-            trace_info=trace_info,
-        )
-        futures.append(future)
-        pages_for_tables.clear()
-
-    # Now wait for all futures to complete
-    for fut in concurrent.futures.as_completed(futures):
-        table_chart_items = fut.result()  # blocks until finished
-        extracted_data.extend(table_chart_items)
-
-    # DOC-LEVEL TEXT added last
-    if extract_text and text_depth == TextTypeEnum.DOCUMENT and accumulated_text:
-        doc_text_meta = construct_text_metadata(
-            accumulated_text,
-            pdf_metadata.keywords,
-            -1,
-            -1,
-            -1,
-            -1,
-            page_count,
-            text_depth,
-            source_metadata,
-            base_unified_metadata,
-        )
-        extracted_data.append(doc_text_meta)
+        extracted_data.extend(table_chart_items)
 
+    logger.debug("Pdfium extraction completed; process terminated and resources freed.")
     return extracted_data
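The pdfium rework above isolates every pdfium call in a single-use forked process, so memory retained by the native library is reclaimed when that process exits, while the table/chart batching (up to `YOLOX_MAX_BATCH_SIZE`) moves to a thread pool in the parent. A minimal standalone sketch of the isolation pattern, where `leaky_extract` is a hypothetical stand-in for `_pdfium_extraction_worker`:

```python
# Run a leak-prone native workload in a single-use forked process so its
# memory is reclaimed when the process exits (fork-capable platforms only,
# matching the mp.get_context("fork") used in the diff).
import concurrent.futures
import multiprocessing as mp

def leaky_extract(payload: bytes) -> dict:
    # Pretend this calls into a C extension that may retain allocations.
    return {"n_bytes": len(payload)}

def run_isolated(payload: bytes) -> dict:
    ctx = mp.get_context("fork")
    with concurrent.futures.ProcessPoolExecutor(max_workers=1, mp_context=ctx) as pool:
        future = pool.submit(leaky_extract, payload)
        try:
            return future.result()
        except concurrent.futures.process.BrokenProcessPool as e:
            raise RuntimeError("extraction process died unexpectedly") from e

print(run_isolated(b"%PDF-1.7 ..."))
```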
diff --git a/src/nv_ingest/modules/telemetry/otel_meter.py b/src/nv_ingest/modules/telemetry/otel_meter.py
index 6ed99b14..1fb765c0 100644
--- a/src/nv_ingest/modules/telemetry/otel_meter.py
+++ b/src/nv_ingest/modules/telemetry/otel_meter.py
@@ -109,6 +109,9 @@ def update_job_latency(message):
         ts_entry = message.get_timestamp(entry_key)
 
         job_name = key.replace("trace::exit::", "")
+        if ts_entry is None or ts_exit is None:
+            continue
+
         # Sanitize job name
         sanitized_job_name = sanitize_name(job_name)
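The `continue` guard above skips trace pairs with a missing entry or exit timestamp instead of failing on the subtraction. Reduced to its essentials:

```python
# Skipping incomplete pairs avoids a TypeError on `ts_exit - ts_entry`.
from datetime import datetime
from typing import Optional

def latency_ms(ts_entry: Optional[datetime], ts_exit: Optional[datetime]) -> Optional[float]:
    if ts_entry is None or ts_exit is None:
        return None  # Incomplete trace pair; nothing to record.
    return (ts_exit - ts_entry).total_seconds() * 1000.0
```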
""" - def __init__(self, manager: Manager): - self._result = manager.Value("i", None) - self._exception = manager.Value("i", None) + def __init__(self, manager: mp.Manager) -> None: + """ + Initialize a SimpleFuture. + + Parameters + ---------- + manager : mp.Manager + The Manager instance used to create shared objects. + """ + self._manager = manager # Pin the manager until this future is resolved. + self._result = manager.Value(py_object, None) + self._exception = manager.Value(py_object, None) self._done = manager.Event() def set_result(self, result: Any) -> None: """ - Sets the result of the asynchronous task and signals task completion. + Set the result of the asynchronous task. Parameters ---------- result : Any - The result of the asynchronous task. - - Returns - ------- - None + The result produced by the task. """ self._result.value = result self._done.set() def set_exception(self, exception: Exception) -> None: """ - Sets the exception raised by the asynchronous task and signals task completion. + Set an exception raised during the execution of the asynchronous task. Parameters ---------- exception : Exception - The exception raised during task execution. - - Returns - ------- - None + The exception encountered during task execution. """ self._exception.value = exception self._done.set() def result(self) -> Any: """ - Retrieves the result of the asynchronous task or raises the exception if one occurred. - - This method blocks until the task is complete. + Block until the task completes and return the result. Returns ------- @@ -98,54 +85,63 @@ def result(self) -> Any: Raises ------ Exception - The exception raised during task execution, if any. + Re-raises any exception encountered during task execution. """ self._done.wait() if self._exception.value is not None: raise self._exception.value return self._result.value + def __getstate__(self) -> dict: + """ + Return the state for pickling, excluding the _manager to avoid pickling errors. + + Returns + ------- + dict + The object's state without the _manager attribute. + """ + state = self.__dict__.copy() + state.pop("_manager", None) + return state + class ProcessWorkerPoolSingleton: """ A singleton process worker pool for managing a fixed number of worker processes. - This class implements a process pool using the singleton pattern, ensuring that only one instance - of the pool exists. It manages worker processes that can execute tasks asynchronously. + This class implements a process pool using the singleton pattern, ensuring that only one + instance exists. It manages worker processes that execute tasks asynchronously. A background + thread periodically checks if the task queue is empty; if so, it refreshes the entire pool: + - Closes (and optionally joins) all current worker processes (without shutting down the active Manager). + - Creates a new Manager. + - Re-creates all worker processes using the new Manager. + - Swaps in the new Manager as the active manager, allowing the old Manager to eventually be garbage collected. + + The public task submission interface (submit_task) remains unchanged. Attributes ---------- - _instance : ProcessWorkerPoolSingleton or None - The singleton instance of the class. - _lock : threading.Lock - A lock to ensure thread-safe initialization of the singleton instance. + _instance : Optional[ProcessWorkerPoolSingleton] + The singleton instance. + _lock : RLock + A reentrant lock to ensure thread-safe access. _total_workers : int The total number of worker processes. 
diff --git a/src/nv_ingest/stages/multiprocessing_stage.py b/src/nv_ingest/stages/multiprocessing_stage.py
index 01db1e14..91f495e8 100644
--- a/src/nv_ingest/stages/multiprocessing_stage.py
+++ b/src/nv_ingest/stages/multiprocessing_stage.py
@@ -261,6 +261,7 @@ def work_package_input_handler(
         future = process_pool.submit_task(process_fn, (df, task_props))  # This can return/raise an exception
 
         result = future.result()
+        extra_results = []
         if isinstance(result, tuple):
             result, *extra_results = result
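Initializing `extra_results` before the conditional unpack matters because a non-tuple result would otherwise leave the name undefined for any later use. In miniature:

```python
# Without the default, a non-tuple result makes later uses of
# `extra_results` raise NameError.
def unpack(result):
    extra_results = []  # Default for the non-tuple path.
    if isinstance(result, tuple):
        result, *extra_results = result
    return result, extra_results

print(unpack("df-only"))        # ('df-only', [])
print(unpack(("df", "trace")))  # ('df', ['trace'])
```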
diff --git a/src/nv_ingest/util/multi_processing/mp_pool_singleton.py b/src/nv_ingest/util/multi_processing/mp_pool_singleton.py
index 3aae3ff2..5f9a51c4 100644
--- a/src/nv_ingest/util/multi_processing/mp_pool_singleton.py
+++ b/src/nv_ingest/util/multi_processing/mp_pool_singleton.py
@@ -2,16 +2,14 @@
 # All rights reserved.
 # SPDX-License-Identifier: Apache-2.0
 
-
 import logging
 import math
 import multiprocessing as mp
 import os
-from multiprocessing import Manager
-from threading import Lock
-from typing import Any
-from typing import Callable
-from typing import Optional
+import threading
+from ctypes import py_object
+from threading import RLock
+from typing import Any, Callable, Optional
 
 logger = logging.getLogger(__name__)
 
@@ -20,75 +18,64 @@ class SimpleFuture:
     """
     A simplified future object for handling asynchronous task results.
 
-    This class allows the storage and retrieval of the result or exception from an asynchronous task,
-    using multiprocessing primitives for inter-process communication.
-
-    Parameters
-    ----------
-    manager : multiprocessing.Manager
-        A multiprocessing manager that provides shared memory for the result and exception.
+    This class uses multiprocessing primitives to store and retrieve the result or exception
+    from an asynchronous task, and it pins the manager used to create the shared proxies so that
+    they remain valid until the future is resolved.
 
     Attributes
     ----------
-    _result : multiprocessing.Value
-        A shared memory object to store the result of the asynchronous task.
-    _exception : multiprocessing.Value
-        A shared memory object to store any exception raised during task execution.
-    _done : multiprocessing.Event
-        An event that signals the completion of the task.
-
-    Methods
-    -------
-    set_result(result)
-        Sets the result of the task and marks the task as done.
-    set_exception(exception)
-        Sets the exception of the task and marks the task as done.
-    result()
-        Waits for the task to complete and returns the result, or raises the exception if one occurred.
+    _manager : mp.Manager
+        The Manager instance used to create shared objects. It is "pinned" (kept alive) until
+        the future is resolved.
+    _result : mp.Value
+        A proxy holding the result of the asynchronous task.
+    _exception : mp.Value
+        A proxy holding any exception raised during task execution.
+    _done : mp.Event
+        A synchronization event that signals task completion.
     """
 
-    def __init__(self, manager: Manager):
-        self._result = manager.Value("i", None)
-        self._exception = manager.Value("i", None)
+    def __init__(self, manager: mp.Manager) -> None:
+        """
+        Initialize a SimpleFuture.
+
+        Parameters
+        ----------
+        manager : mp.Manager
+            The Manager instance used to create shared objects.
+        """
+        self._manager = manager  # Pin the manager until this future is resolved.
+        self._result = manager.Value(py_object, None)
+        self._exception = manager.Value(py_object, None)
         self._done = manager.Event()
 
     def set_result(self, result: Any) -> None:
         """
-        Sets the result of the asynchronous task and signals task completion.
+        Set the result of the asynchronous task.
 
         Parameters
        ----------
         result : Any
-            The result of the asynchronous task.
-
-        Returns
-        -------
-        None
+            The result produced by the task.
         """
         self._result.value = result
         self._done.set()
 
     def set_exception(self, exception: Exception) -> None:
         """
-        Sets the exception raised by the asynchronous task and signals task completion.
+        Set an exception raised during the execution of the asynchronous task.
 
         Parameters
         ----------
         exception : Exception
-            The exception raised during task execution.
-
-        Returns
-        -------
-        None
+            The exception encountered during task execution.
         """
         self._exception.value = exception
         self._done.set()
 
     def result(self) -> Any:
         """
-        Retrieves the result of the asynchronous task or raises the exception if one occurred.
-
-        This method blocks until the task is complete.
+        Block until the task completes and return the result.
 
         Returns
         -------
@@ -98,54 +85,63 @@ def result(self) -> Any:
 
         Raises
         ------
         Exception
-            The exception raised during task execution, if any.
+            Re-raises any exception encountered during task execution.
         """
         self._done.wait()
         if self._exception.value is not None:
             raise self._exception.value
         return self._result.value
 
+    def __getstate__(self) -> dict:
+        """
+        Return the state for pickling, excluding the _manager to avoid pickling errors.
+
+        Returns
+        -------
+        dict
+            The object's state without the _manager attribute.
+        """
+        state = self.__dict__.copy()
+        state.pop("_manager", None)
+        return state
+
""" - future = SimpleFuture(self._manager) - self._task_queue.put((future, process_fn, args)) - return future + with self._lock: + future = SimpleFuture(self._active_manager) + self._task_queue.put((future, process_fn, args)) + return future - def close(self) -> None: + def close(self, join: bool = True) -> None: """ - Closes the worker pool and terminates all worker processes. + Close the worker pool by sending stop signals to all workers. + Optionally waits for them to terminate (join). - This method sends a stop signal to each worker and waits for them to terminate. + The active Manager is not shut down so that outstanding references remain valid. - Returns - ------- - None + Parameters + ---------- + join : bool, optional + If True (default), waits for the worker processes to terminate. + If False, sends stop signals and returns immediately. """ - logger.debug("Closing ProcessWorkerPoolSingleton...") - for _ in range(self._total_max_workers): - self._task_queue.put(None) # Send stop signal to all workers + logger.debug("Closing ProcessWorkerPoolSingleton workers...") + for _ in range(self._total_workers): + self._task_queue.put(None) logger.debug("Sent stop signal to worker.") - for i, p in enumerate(self._processes): - p.join() - logger.debug(f"Worker process {i + 1}/{self._total_max_workers} joined: PID {p.pid}") - logger.debug("ProcessWorkerPoolSingleton closed.") + if join: + for i, process in enumerate(self._processes): + process.join() + logger.debug(f"Worker process {i + 1}/{self._total_workers} joined: PID {process.pid}") + logger.debug("Worker pool closed.") + + def shutdown_manager_monitor(self) -> None: + """ + Stop the background Manager monitoring thread. + """ + self._stop_manager_monitor = True + if hasattr(self, "_monitor_thread"): + self._monitor_thread.join(timeout=5) + logger.debug("Manager monitoring thread stopped.") diff --git a/src/nv_ingest/util/pdf/pdfium.py b/src/nv_ingest/util/pdf/pdfium.py index 9769d450..dbdb38e5 100644 --- a/src/nv_ingest/util/pdf/pdfium.py +++ b/src/nv_ingest/util/pdf/pdfium.py @@ -44,6 +44,7 @@ def convert_bitmap_to_corrected_numpy(bitmap: pdfium.PdfBitmap) -> np.ndarray: # Get the bitmap format information bitmap_info = bitmap.get_info() mode = bitmap_info.mode # Use the mode to identify the correct format + del bitmap_info # Convert to a NumPy array using the built-in method img_arr = bitmap.to_numpy().copy() @@ -109,6 +110,7 @@ def pdfium_try_get_bitmap_as_numpy(image_obj) -> np.ndarray: # Convert the bitmap to a NumPy array img_array = convert_bitmap_to_corrected_numpy(image_bitmap) + image_bitmap.close() return img_array @@ -171,6 +173,7 @@ def pdfium_pages_to_numpy( # Convert the bitmap to a PIL image pil_image = page_bitmap.to_pil() + page_bitmap.close() # Apply scaling using the thumbnail approach if specified if scale_tuple: diff --git a/src/nv_ingest/util/pipeline/pipeline_builders.py b/src/nv_ingest/util/pipeline/pipeline_builders.py index 3c245e78..db145ff7 100644 --- a/src/nv_ingest/util/pipeline/pipeline_builders.py +++ b/src/nv_ingest/util/pipeline/pipeline_builders.py @@ -2,7 +2,7 @@ # All rights reserved. 
diff --git a/src/nv_ingest/util/pdf/pdfium.py b/src/nv_ingest/util/pdf/pdfium.py
index 9769d450..dbdb38e5 100644
--- a/src/nv_ingest/util/pdf/pdfium.py
+++ b/src/nv_ingest/util/pdf/pdfium.py
@@ -44,6 +44,7 @@ def convert_bitmap_to_corrected_numpy(bitmap: pdfium.PdfBitmap) -> np.ndarray:
     # Get the bitmap format information
     bitmap_info = bitmap.get_info()
     mode = bitmap_info.mode  # Use the mode to identify the correct format
+    del bitmap_info
 
     # Convert to a NumPy array using the built-in method
     img_arr = bitmap.to_numpy().copy()
@@ -109,6 +110,7 @@ def pdfium_try_get_bitmap_as_numpy(image_obj) -> np.ndarray:
 
     # Convert the bitmap to a NumPy array
     img_array = convert_bitmap_to_corrected_numpy(image_bitmap)
+    image_bitmap.close()
 
     return img_array
 
@@ -171,6 +173,7 @@ def pdfium_pages_to_numpy(
 
         # Convert the bitmap to a PIL image
         pil_image = page_bitmap.to_pil()
+        page_bitmap.close()
 
         # Apply scaling using the thumbnail approach if specified
        if scale_tuple:
diff --git a/src/nv_ingest/util/pipeline/pipeline_builders.py b/src/nv_ingest/util/pipeline/pipeline_builders.py
index 3c245e78..db145ff7 100644
--- a/src/nv_ingest/util/pipeline/pipeline_builders.py
+++ b/src/nv_ingest/util/pipeline/pipeline_builders.py
@@ -2,7 +2,7 @@
 # All rights reserved.
 # SPDX-License-Identifier: Apache-2.0
 
-import os
+# import os
 import typing
 
 from morpheus.config import Config
@@ -17,7 +17,7 @@ def setup_ingestion_pipeline(
     pipe: Pipeline, morpheus_pipeline_config: Config, ingest_config: typing.Dict[str, typing.Any]
 ):
     default_cpu_count = get_default_cpu_count()
-    add_meter_stage = os.environ.get("MESSAGE_CLIENT_TYPE") != "simple"
+    # add_meter_stage = os.environ.get("MESSAGE_CLIENT_TYPE") != "simple"
 
     ########################################################################################################
     ## Insertion and Pre-processing stages
@@ -66,11 +66,11 @@ def setup_ingestion_pipeline(
     #######################################################################################################
     ## Telemetry (Note: everything after the sync stage is out of the hot path, please keep it that way) ##
     #######################################################################################################
-    otel_tracer_stage = add_otel_tracer_stage(pipe, morpheus_pipeline_config, ingest_config)
-    if add_meter_stage:
-        otel_meter_stage = add_otel_meter_stage(pipe, morpheus_pipeline_config, ingest_config)
-    else:
-        otel_meter_stage = None
+    # otel_tracer_stage = add_otel_tracer_stage(pipe, morpheus_pipeline_config, ingest_config)
+    # if add_meter_stage:
+    #     otel_meter_stage = add_otel_meter_stage(pipe, morpheus_pipeline_config, ingest_config)
+    # else:
+    #     otel_meter_stage = None
     completed_job_counter_stage = add_completed_job_counter_stage(pipe, morpheus_pipeline_config, ingest_config)
 
     ########################################################################################################
@@ -94,10 +94,11 @@ def setup_ingestion_pipeline(
     pipe.add_edge(embedding_storage_stage, vdb_task_sink_stage)
     pipe.add_edge(vdb_task_sink_stage, sink_stage)
 
-    if add_meter_stage:
-        pipe.add_edge(sink_stage, otel_meter_stage)
-        pipe.add_edge(otel_meter_stage, otel_tracer_stage)
-    else:
-        pipe.add_edge(sink_stage, otel_tracer_stage)
+    # if add_meter_stage:
+    #     pipe.add_edge(sink_stage, otel_meter_stage)
+    #     pipe.add_edge(otel_meter_stage, otel_tracer_stage)
+    # else:
+    #     pipe.add_edge(sink_stage, otel_tracer_stage)
 
-    pipe.add_edge(otel_tracer_stage, completed_job_counter_stage)
+    # pipe.add_edge(otel_tracer_stage, completed_job_counter_stage)
+    pipe.add_edge(sink_stage, completed_job_counter_stage)
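The last change bypasses the OTel tracer/meter stages and wires the sink directly to the completed-job counter. If those stages need to be restored later, the removed wiring suggests an env-gated variant along these lines (sketch only: `INGEST_ENABLE_OTEL` is a hypothetical flag, and the `add_*_stage` helpers are the ones already imported in pipeline_builders.py):

```python
# Hypothetical re-enable path for the telemetry stages; not part of this change.
import os

def wire_telemetry(pipe, morpheus_pipeline_config, ingest_config, sink_stage, completed_job_counter_stage):
    if os.environ.get("INGEST_ENABLE_OTEL", "false").lower() != "true":
        # Current behavior: skip telemetry entirely.
        pipe.add_edge(sink_stage, completed_job_counter_stage)
        return
    otel_tracer_stage = add_otel_tracer_stage(pipe, morpheus_pipeline_config, ingest_config)
    if os.environ.get("MESSAGE_CLIENT_TYPE") != "simple":
        otel_meter_stage = add_otel_meter_stage(pipe, morpheus_pipeline_config, ingest_config)
        pipe.add_edge(sink_stage, otel_meter_stage)
        pipe.add_edge(otel_meter_stage, otel_tracer_stage)
    else:
        pipe.add_edge(sink_stage, otel_tracer_stage)
    pipe.add_edge(otel_tracer_stage, completed_job_counter_stage)
```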