diff --git a/Dockerfile b/Dockerfile index e926ad64..7dff7a2c 100644 --- a/Dockerfile +++ b/Dockerfile @@ -26,6 +26,7 @@ RUN apt-get update && apt-get install -y \ bzip2 \ ca-certificates \ curl \ + ffmpeg \ libgl1-mesa-glx \ software-properties-common \ wget \ diff --git a/client/src/nv_ingest_client/primitives/tasks/extract.py b/client/src/nv_ingest_client/primitives/tasks/extract.py index 5017eadb..2294e37a 100644 --- a/client/src/nv_ingest_client/primitives/tasks/extract.py +++ b/client/src/nv_ingest_client/primitives/tasks/extract.py @@ -45,6 +45,8 @@ "svg": "image", "tiff": "image", "xml": "lxml", + "mp3": "audio", + "wav": "audio", } _Type_Extract_Method_PDF = Literal[ @@ -63,6 +65,8 @@ _Type_Extract_Method_Image = Literal["image"] +_Type_Extract_Method_Audio = Literal["audio"] + _Type_Extract_Method_Map = { "docx": get_args(_Type_Extract_Method_DOCX), "jpeg": get_args(_Type_Extract_Method_Image), @@ -72,6 +76,8 @@ "pptx": get_args(_Type_Extract_Method_PPTX), "svg": get_args(_Type_Extract_Method_Image), "tiff": get_args(_Type_Extract_Method_Image), + "mp3": get_args(_Type_Extract_Method_Audio), + "wav": get_args(_Type_Extract_Method_Audio), } _Type_Extract_Tables_Method_PDF = Literal["yolox", "pdfium", "nemoretriever_parse"] diff --git a/client/src/nv_ingest_client/util/file_processing/extract.py b/client/src/nv_ingest_client/util/file_processing/extract.py index 97851481..ab430d67 100644 --- a/client/src/nv_ingest_client/util/file_processing/extract.py +++ b/client/src/nv_ingest_client/util/file_processing/extract.py @@ -32,6 +32,8 @@ class DocumentTypeEnum(str, Enum): svg = "svg" tiff = "tiff" txt = "text" + mp3 = "mp3" + wav = "wav" # Maps MIME types to DocumentTypeEnum @@ -64,6 +66,8 @@ class DocumentTypeEnum(str, Enum): "svg": DocumentTypeEnum.svg, "tiff": DocumentTypeEnum.tiff, "txt": DocumentTypeEnum.txt, + "mp3": DocumentTypeEnum.mp3, + "wav": DocumentTypeEnum.wav, # Add more as needed } diff --git a/conda/environments/nv_ingest_environment.yml b/conda/environments/nv_ingest_environment.yml index 3de4ff47..4508f100 100644 --- a/conda/environments/nv_ingest_environment.yml +++ b/conda/environments/nv_ingest_environment.yml @@ -10,6 +10,7 @@ dependencies: - click>=8.1.7 - fastapi>=0.115.6 - fastparquet>=2024.11.0 + - ffmpeg-python>=0.2.0 - fsspec>=2024.10.0 - httpx>=0.28.1 - isodate>=0.7.2 @@ -46,6 +47,7 @@ dependencies: - pip - pip: - llama-index-embeddings-nvidia + - nvidia-riva-client - opencv-python # For some reason conda cant solve our req set with py-opencv so we need to use pip - pymilvus>=2.5.0 - pymilvus[bulk_writer, model] diff --git a/docker-compose.yaml b/docker-compose.yaml index 4ae552f6..4258b416 100644 --- a/docker-compose.yaml +++ b/docker-compose.yaml @@ -196,6 +196,31 @@ services: profiles: - vlm + audio: + image: nvcr.io/nvidia/riva/riva-speech:2.18.0 + shm_size: 2gb + ports: + - "8019:50051" # grpc + - "8020:50000" # http + user: root + environment: + - MODEL_DEPLOY_KEY=tlt_encode + - NGC_CLI_API_KEY=${RIVA_NGC_API_KEY} + - NGC_CLI_ORG=nvidia + - NGC_CLI_TEAM=riva + - CUDA_VISIBLE_DEVICES=0 + deploy: + resources: + reservations: + devices: + - driver: nvidia + device_ids: ["1"] + capabilities: [gpu] + runtime: nvidia + command: bash -c "download_and_deploy_ngc_models nvidia/riva/rmir_asr_conformer_en_us_ofl:2.18.0 && start-riva" + profiles: + - audio + nv-ingest-ms-runtime: image: nvcr.io/nvidia/nemo-microservices/nv-ingest:24.12 build: @@ -215,6 +240,8 @@ services: cap_add: - sys_nice environment: + - AUDIO_GRPC_ENDPOINT=audio:50051 + - AUDIO_INFER_PROTOCOL=grpc - CUDA_VISIBLE_DEVICES=-1 - MAX_INGEST_PROCESS_WORKERS=${MAX_PROCESS_WORKERS:-16} - EMBEDDING_NIM_MODEL_NAME=${EMBEDDING_NIM_MODEL_NAME:-nvidia/llama-3.2-nv-embedqa-1b-v2} diff --git a/src/nv_ingest/modules/injectors/metadata_injector.py b/src/nv_ingest/modules/injectors/metadata_injector.py index c3812c59..8bc2e23a 100644 --- a/src/nv_ingest/modules/injectors/metadata_injector.py +++ b/src/nv_ingest/modules/injectors/metadata_injector.py @@ -3,6 +3,7 @@ # SPDX-License-Identifier: Apache-2.0 import logging +import traceback import mrc import pandas as pd @@ -46,6 +47,9 @@ def on_data(message: IngestControlMessage): "type": content_type.name.lower(), }, "error_metadata": None, + "audio_metadata": ( + None if content_type != ContentTypeEnum.AUDIO else {"audio_type": row["document_type"]} + ), "image_metadata": ( None if content_type != ContentTypeEnum.IMAGE else {"image_type": row["document_type"]} ), @@ -86,7 +90,12 @@ def _metadata_injection(builder: mrc.Builder): annotation_id=MODULE_NAME, raise_on_failure=validated_config.raise_on_failure, skip_processing_if_failed=True ) def _on_data(message: IngestControlMessage) -> IngestControlMessage: - return on_data(message) + try: + return on_data(message) + except Exception as e: + logger.error(f"Unhandled exception in metadata_injector: {e}") + traceback.print_exc() + raise node = builder.make_node("metadata_injector", _on_data) diff --git a/src/nv_ingest/schemas/audio_extractor_schema.py b/src/nv_ingest/schemas/audio_extractor_schema.py new file mode 100755 index 00000000..6b00b4e2 --- /dev/null +++ b/src/nv_ingest/schemas/audio_extractor_schema.py @@ -0,0 +1,127 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + + +import logging +from typing import Optional +from typing import Tuple + +from pydantic import BaseModel +from pydantic import root_validator + +logger = logging.getLogger(__name__) + + +class AudioConfigSchema(BaseModel): + """ + Configuration schema for audio extraction endpoints and options. + + Parameters + ---------- + auth_token : Optional[str], default=None + Authentication token required for secure services. + + audio_endpoints : Tuple[str, str] + A tuple containing the gRPC and HTTP services for the audio_retriever endpoint. + Either the gRPC or HTTP service can be empty, but not both. + + Methods + ------- + validate_endpoints(values) + Validates that at least one of the gRPC or HTTP services is provided for each endpoint. + + Raises + ------ + ValueError + If both gRPC and HTTP services are empty for any endpoint. + + Config + ------ + extra : str + Pydantic config option to forbid extra fields. + """ + + auth_token: Optional[str] = None + audio_endpoints: Tuple[Optional[str], Optional[str]] = (None, None) + audio_infer_protocol: Optional[str] = None + + @root_validator(pre=True) + def validate_endpoints(cls, values): + """ + Validates the gRPC and HTTP services for all endpoints. + + Parameters + ---------- + values : dict + Dictionary containing the values of the attributes for the class. + + Returns + ------- + dict + The validated dictionary of values. + + Raises + ------ + ValueError + If both gRPC and HTTP services are empty for any endpoint. + """ + + def clean_service(service): + """Set service to None if it's an empty string or contains only spaces or quotes.""" + if service is None or not service.strip() or service.strip(" \"'") == "": + return None + return service + + endpoint_name = "audio_endpoints" + grpc_service, http_service = values.get(endpoint_name) + grpc_service = clean_service(grpc_service) + http_service = clean_service(http_service) + + if not grpc_service and not http_service: + raise ValueError(f"Both gRPC and HTTP services cannot be empty for {endpoint_name}.") + + values[endpoint_name] = (grpc_service, http_service) + + protocol_name = "audio_infer_protocol" + protocol_value = values.get(protocol_name) + + if not protocol_value: + protocol_value = "http" if http_service else "grpc" if grpc_service else "" + + protocol_value = protocol_value.lower() + values[protocol_name] = protocol_value + + return values + + class Config: + extra = "forbid" + + +class AudioExtractorSchema(BaseModel): + """ + Configuration schema for the PDF extractor settings. + + Parameters + ---------- + max_queue_size : int, default=1 + The maximum number of items allowed in the processing queue. + + n_workers : int, default=16 + The number of worker threads to use for processing. + + raise_on_failure : bool, default=False + A flag indicating whether to raise an exception on processing failure. + + audio_extraction_config: Optional[AudioConfigSchema], default=None + Configuration schema for the audio extraction stage. + """ + + max_queue_size: int = 1 + n_workers: int = 16 + raise_on_failure: bool = False + + audio_extraction_config: Optional[AudioConfigSchema] = None + + class Config: + extra = "forbid" diff --git a/src/nv_ingest/schemas/ingest_job_schema.py b/src/nv_ingest/schemas/ingest_job_schema.py index ae3e38fe..b42643cb 100644 --- a/src/nv_ingest/schemas/ingest_job_schema.py +++ b/src/nv_ingest/schemas/ingest_job_schema.py @@ -32,6 +32,8 @@ class DocumentTypeEnum(str, Enum): svg = "svg" tiff = "tiff" txt = "text" + mp3 = "mp3" + wav = "wav" class TaskTypeEnum(str, Enum): diff --git a/src/nv_ingest/schemas/ingest_pipeline_config_schema.py b/src/nv_ingest/schemas/ingest_pipeline_config_schema.py index 84b7fafd..b8774cd4 100644 --- a/src/nv_ingest/schemas/ingest_pipeline_config_schema.py +++ b/src/nv_ingest/schemas/ingest_pipeline_config_schema.py @@ -30,6 +30,7 @@ class PipelineConfigSchema(BaseModel): + # TODO(Devin): Audio chart_extractor_module: ChartExtractorSchema = ChartExtractorSchema() text_splitter_module: TextSplitterSchema = TextSplitterSchema() embedding_storage_module: EmbeddingStorageModuleSchema = EmbeddingStorageModuleSchema() diff --git a/src/nv_ingest/schemas/metadata_schema.py b/src/nv_ingest/schemas/metadata_schema.py index 0f1f406d..6c1a4ea7 100644 --- a/src/nv_ingest/schemas/metadata_schema.py +++ b/src/nv_ingest/schemas/metadata_schema.py @@ -303,6 +303,11 @@ class ChartMetadataSchema(BaseModelNoExt): uploaded_image_uri: str = "" +class AudioMetadataSchema(BaseModelNoExt): + audio_transcript: str = "" + audio_type: str = "" + + # TODO consider deprecating this in favor of info msg... class ErrorMetadataSchema(BaseModelNoExt): task: TaskTypeEnum @@ -325,6 +330,7 @@ class MetadataSchema(BaseModelNoExt): embedding: Optional[List[float]] = None source_metadata: Optional[SourceMetadataSchema] = None content_metadata: Optional[ContentMetadataSchema] = None + audio_metadata: Optional[AudioMetadataSchema] = None text_metadata: Optional[TextMetadataSchema] = None image_metadata: Optional[ImageMetadataSchema] = None table_metadata: Optional[TableMetadataSchema] = None @@ -338,10 +344,12 @@ class MetadataSchema(BaseModelNoExt): @classmethod def check_metadata_type(cls, values): content_type = values.get("content_metadata", {}).get("type", None) - if content_type != ContentTypeEnum.TEXT: - values["text_metadata"] = None + if content_type != ContentTypeEnum.AUDIO: + values["audio_metadata"] = None if content_type != ContentTypeEnum.IMAGE: values["image_metadata"] = None + if content_type != ContentTypeEnum.TEXT: + values["text_metadata"] = None if content_type != ContentTypeEnum.STRUCTURED: values["table_metadata"] = None return values diff --git a/src/nv_ingest/stages/extractors/audio_extractor_stage.py b/src/nv_ingest/stages/extractors/audio_extractor_stage.py new file mode 100755 index 00000000..9b9aef8d --- /dev/null +++ b/src/nv_ingest/stages/extractors/audio_extractor_stage.py @@ -0,0 +1,180 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +import logging +import functools +import traceback + +import pandas as pd +from typing import Any +from typing import Dict +from typing import Optional +from typing import Tuple + +from morpheus.config import Config +from nv_ingest.schemas.audio_extractor_schema import AudioExtractorSchema +from nv_ingest.stages.multiprocessing_stage import MultiProcessingBaseStage + +from nv_ingest.util.audio.parakeet import call_audio_inference_model +from nv_ingest.util.audio.parakeet import create_audio_inference_client + +logger = logging.getLogger(f"morpheus.{__name__}") + + +def _update_metadata(row: pd.Series, audio_client: Any, trace_info: Dict) -> Dict: + """ + Modifies the metadata of a row if the conditions for table extraction are met. + + Parameters + ---------- + row : pd.Series + A row from the DataFrame containing metadata for the audio extraction. + + audio_client : Any + The client used to call the audio inference model. + + trace_info : Dict + Trace information used for logging or debugging. + + Returns + ------- + Dict + The modified metadata if conditions are met, otherwise the original metadata. + + Raises + ------ + ValueError + If critical information (such as metadata) is missing from the row. + """ + + metadata = row.get("metadata") + + if metadata is None: + logger.error("Row does not contain 'metadata'.") + raise ValueError("Row does not contain 'metadata'.") + + base64_audio = metadata.pop("content") + content_metadata = metadata.get("content_metadata", {}) + + # Only modify if content type is audio + if content_metadata.get("type") != "audio": + return metadata + + # Modify audio metadata with the result from the inference model + try: + audio_result = call_audio_inference_model(audio_client, base64_audio, trace_info=trace_info) + metadata["audio_metadata"] = {"audio_transcript": audio_result} + except Exception as e: + logger.error(f"Unhandled error calling audio inference model: {e}", exc_info=True) + traceback.print_exc() + raise + + return metadata + + +def _transcribe_audio( + df: pd.DataFrame, task_props: Dict[str, Any], validated_config: Any, trace_info: Optional[Dict] = None +) -> Tuple[pd.DataFrame, Dict]: + """ + Extracts audio data from a DataFrame. + + Parameters + ---------- + df : pd.DataFrame + DataFrame containing the content from which audio data is to be extracted. + + task_props : Dict[str, Any] + Dictionary containing task properties and configurations. + + validated_config : Any + The validated configuration object for audio extraction. + + trace_info : Optional[Dict], optional + Optional trace information for debugging or logging. Defaults to None. + + Returns + ------- + Tuple[pd.DataFrame, Dict] + A tuple containing the updated DataFrame and the trace information. + + Raises + ------ + Exception + If any error occurs during the audio data extraction process. + """ + logger.debug(f"Entering audio extraction stage with {len(df)} rows.") + + _ = task_props + + parakeet_client = create_audio_inference_client( + validated_config.audio_extraction_config.audio_endpoints, + auth_token=validated_config.audio_extraction_config.auth_token, + ) + + if trace_info is None: + trace_info = {} + logger.debug("No trace_info provided. Initialized empty trace_info dictionary.") + + try: + # Apply the _update_metadata function to each row in the DataFrame + df["metadata"] = df.apply(_update_metadata, axis=1, args=(parakeet_client, trace_info)) + + return df, trace_info + + except Exception as e: + traceback.print_exc() + logger.error(f"Error occurred while extracting audio data: {e}", exc_info=True) + raise + + +def generate_audio_extractor_stage( + c: Config, + stage_config: Dict[str, Any], + task: str = "audio_data_extract", + task_desc: str = "audio_data_extraction", + pe_count: int = 1, +): + """ + Generates a multiprocessing stage to perform audio data extraction. + + Parameters + ---------- + c : Config + Morpheus global configuration object. + + stage_config : Dict[str, Any] + Configuration parameters for the audio content extractor, passed as a dictionary + validated against the `AudioExtractorSchema`. + + task : str, optional + The task name for the stage worker function, defining the specific audio extraction process. + Default is "audio_data_extract". + + task_desc : str, optional + A descriptor used for latency tracing and logging during audio extraction. + Default is "audio_data_extraction". + + pe_count : int, optional + The number of process engines to use for audio data extraction. This value controls + how many worker processes will run concurrently. Default is 1. + + Returns + ------- + MultiProcessingBaseStage + A configured Morpheus stage with an applied worker function that handles audio data extraction + from PDF content. + """ + + validated_config = AudioExtractorSchema(**stage_config) + _wrapped_process_fn = functools.partial(_transcribe_audio, validated_config=validated_config) + + return MultiProcessingBaseStage( + c=c, + pe_count=pe_count, + task=task, + task_desc=task_desc, + process_fn=_wrapped_process_fn, + # document_type="regex:^(mp3|wav)$", + document_type="wav", + ) diff --git a/src/nv_ingest/util/audio/__init__.py b/src/nv_ingest/util/audio/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/nv_ingest/util/audio/parakeet.py b/src/nv_ingest/util/audio/parakeet.py new file mode 100644 index 00000000..2c0a11c4 --- /dev/null +++ b/src/nv_ingest/util/audio/parakeet.py @@ -0,0 +1,283 @@ +import base64 +import logging +from typing import Any +from typing import List +from typing import Optional +from typing import Tuple + +import ffmpeg +import grpc +import requests +import riva.client + +from nv_ingest.util.tracing.tagging import traceable_func + +logger = logging.getLogger(__name__) + + +class ParakeetClient: + """ + A simple interface for handling inference with a Parakeet model (e.g., speech, audio-related). + """ + + def __init__( + self, + endpoint: str, + auth_token: Optional[str] = None, + use_ssl: bool = False, + ssl_cert: Optional[str] = None, + auth_metadata: Optional[Tuple[str, str]] = None, + ): + self.endpoint = endpoint + self.auth_token = auth_token + self.use_ssl = use_ssl + self.ssl_cert = ssl_cert + self.auth_metadata = auth_metadata or [] + if self.auth_token: + self.auth_metadata.append(("authorization", f"Bearer {self.auth_token}")) + + @traceable_func(trace_name="{stage_name}::{model_name}") + def infer(self, data: dict, model_name: str, **kwargs) -> Any: + """ + Perform inference using the specified model and input data. + + Parameters + ---------- + data : dict + The input data for inference. + model_name : str + The model name. + kwargs : dict + Additional parameters for inference. + + Returns + ------- + Any + The processed inference results, coalesced in the same order as the input images. + """ + + response = self.transcribe_file(data) + if response is None: + return None, None + segments, transcript = process_transcription_response(response) + logger.debug("Processing Parakeet inference results (pass-through).") + + return transcript + + def transcribe_file( + self, + audio_content: str, + language_code: str = "en-US", + automatic_punctuation: bool = True, + word_time_offsets: bool = True, + max_alternatives: int = 1, + profanity_filter: bool = False, + verbatim_transcripts: bool = True, + speaker_diarization: bool = False, + boosted_lm_words: Optional[List[str]] = None, + boosted_lm_score: float = 0.0, + diarization_max_speakers: int = 0, + start_history: float = 0.0, + start_threshold: float = 0.0, + stop_history: float = 0.0, + stop_history_eou: bool = False, + stop_threshold: float = 0.0, + stop_threshold_eou: bool = False, + ): + # Create authentication and ASR service objects. + auth = riva.client.Auth(self.ssl_cert, self.use_ssl, self.endpoint, self.auth_metadata) + asr_service = riva.client.ASRService(auth) + + # Build the recognition configuration. + recognition_config = riva.client.RecognitionConfig( + language_code=language_code, + max_alternatives=max_alternatives, + profanity_filter=profanity_filter, + enable_automatic_punctuation=automatic_punctuation, + verbatim_transcripts=verbatim_transcripts, + enable_word_time_offsets=word_time_offsets, + ) + + # Add additional configuration parameters. + riva.client.add_word_boosting_to_config( + recognition_config, + boosted_lm_words or [], + boosted_lm_score, + ) + riva.client.add_speaker_diarization_to_config( + recognition_config, + speaker_diarization, + diarization_max_speakers, + ) + riva.client.add_endpoint_parameters_to_config( + recognition_config, + start_history, + start_threshold, + stop_history, + stop_history_eou, + stop_threshold, + stop_threshold_eou, + ) + audio_bytes = base64.b64decode(audio_content) + mono_audio_bytes = convert_to_mono_wav(audio_bytes) + + # Perform offline recognition and print the transcript. + try: + response = asr_service.offline_recognize(mono_audio_bytes, recognition_config) + return response + except grpc.RpcError as e: + logger.error(f"Error transcribing audio file: {e.details()}") + return None + + +def convert_to_mono_wav(audio_bytes): + process = ( + ffmpeg.input("pipe:") + .output("pipe:", format="wav", acodec="pcm_s16le", ar="44100", ac=1) # Added ac=1 + .run_async(pipe_stdin=True, pipe_stdout=True) + ) + + out, _ = process.communicate(input=audio_bytes) + + return out + + +def process_transcription_response(response): + """ + Process a Riva transcription response (a protobuf message) to extract: + - final_transcript: the complete transcript. + - segments: a list of segments with start/end times and text. + + Parameters: + response: The Riva transcription response message. + + Returns: + segments (list): Each segment is a dict with keys "start", "end", and "text". + final_transcript (str): The overall transcript. + """ + words_list = [] + # Iterate directly over the results. + for result in response.results: + # Ensure there is at least one alternative. + if not result.alternatives: + continue + alternative = result.alternatives[0] + # Each alternative has a repeated field "words" + for word_info in alternative.words: + words_list.append(word_info) + + # Build the overall transcript by joining the word strings. + final_transcript = " ".join(word.word for word in words_list) + + # Now, segment the transcript based on punctuation. + segments = [] + current_words = [] + segment_start = None + segment_end = None + punctuation_marks = {".", "?", "!"} + + for word in words_list: + # Mark the start of a segment if not already set. + if segment_start is None: + segment_start = word.start_time + segment_end = word.end_time + current_words.append(word.word) + + # End the segment when a word ends with punctuation. + if word.word and word.word[-1] in punctuation_marks: + segments.append({"start": segment_start, "end": segment_end, "text": " ".join(current_words)}) + current_words = [] + segment_start = None + segment_end = None + + # Add any remaining words as a segment. + if current_words: + segments.append({"start": segment_start, "end": segment_end, "text": " ".join(current_words)}) + + return segments, final_transcript + + +def create_audio_inference_client( + endpoints: Tuple[str, str], + auth_token: Optional[str] = None, + infer_protocol: Optional[str] = None, + timeout: float = 120.0, + max_retries: int = 5, +): + """ + Create a NimClient for interfacing with a model inference server. + + Parameters + ---------- + endpoints : tuple + A tuple containing the gRPC and HTTP endpoints. + model_interface : ModelInterface + The model interface implementation to use. + auth_token : str, optional + Authorization token for HTTP requests (default: None). + infer_protocol : str, optional + The protocol to use ("grpc" or "http"). If not specified, it is inferred from the endpoints. + + Returns + ------- + NimClient + The initialized NimClient. + + Raises + ------ + ValueError + If an invalid infer_protocol is specified. + """ + + grpc_endpoint, http_endpoint = endpoints + + if (infer_protocol is None) and (grpc_endpoint and grpc_endpoint.strip()): + infer_protocol = "grpc" + elif infer_protocol is None and http_endpoint: + infer_protocol = "http" + + if infer_protocol not in ["grpc", "http"]: + raise ValueError("Invalid infer_protocol specified. Must be 'grpc' or 'http'.") + + return ParakeetClient(grpc_endpoint, auth_token=auth_token) + + +def call_audio_inference_model(client, audio_content: str, trace_info: dict): + """ + Calls an audio inference model using the provided client. + If the client is a gRPC client, the inference is performed using gRPC. Otherwise, it is performed using HTTP. + Parameters + ---------- + client : + The inference client, which is an HTTP client. + audio_content: str + The audio source to transcribe. + audio_id: str + The unique identifier for the audio content. + trace_info: dict + Trace information for debugging or logging. + Returns + ------- + str or None + The result of the inference as a string if successful, otherwise `None`. + Raises + ------ + RuntimeError + If the HTTP request fails or if the response format is not as expected. + """ + + try: + parakeet_result = client.infer( + audio_content, + model_name="parakeet", + trace_info=trace_info, # traceable_func arg + stage_name="audio_extraction", + ) + + return parakeet_result + except requests.exceptions.RequestException as e: + raise RuntimeError(f"HTTP request failed: {e}") + except KeyError as e: + raise RuntimeError(f"Missing expected key in response: {e}") + except Exception as e: + raise RuntimeError(f"An error occurred during inference: {e}") diff --git a/src/nv_ingest/util/converters/type_mappings.py b/src/nv_ingest/util/converters/type_mappings.py index 4fbfb0a9..2a8a8626 100644 --- a/src/nv_ingest/util/converters/type_mappings.py +++ b/src/nv_ingest/util/converters/type_mappings.py @@ -11,12 +11,14 @@ DocumentTypeEnum.docx: ContentTypeEnum.STRUCTURED, DocumentTypeEnum.html: ContentTypeEnum.STRUCTURED, DocumentTypeEnum.jpeg: ContentTypeEnum.IMAGE, + DocumentTypeEnum.mp3: ContentTypeEnum.AUDIO, DocumentTypeEnum.pdf: ContentTypeEnum.STRUCTURED, DocumentTypeEnum.png: ContentTypeEnum.IMAGE, DocumentTypeEnum.pptx: ContentTypeEnum.STRUCTURED, DocumentTypeEnum.svg: ContentTypeEnum.IMAGE, DocumentTypeEnum.tiff: ContentTypeEnum.IMAGE, DocumentTypeEnum.txt: ContentTypeEnum.TEXT, + DocumentTypeEnum.wav: ContentTypeEnum.AUDIO, } diff --git a/src/nv_ingest/util/pipeline/pipeline_builders.py b/src/nv_ingest/util/pipeline/pipeline_builders.py index 3c245e78..0ec42250 100644 --- a/src/nv_ingest/util/pipeline/pipeline_builders.py +++ b/src/nv_ingest/util/pipeline/pipeline_builders.py @@ -34,6 +34,7 @@ def setup_ingestion_pipeline( image_extractor_stage = add_image_extractor_stage(pipe, morpheus_pipeline_config, ingest_config, default_cpu_count) docx_extractor_stage = add_docx_extractor_stage(pipe, morpheus_pipeline_config, ingest_config, default_cpu_count) pptx_extractor_stage = add_pptx_extractor_stage(pipe, morpheus_pipeline_config, ingest_config, default_cpu_count) + audio_extractor_stage = add_audio_extractor_stage(pipe, morpheus_pipeline_config, ingest_config, default_cpu_count) ######################################################################################################## ######################################################################################################## @@ -81,7 +82,8 @@ def setup_ingestion_pipeline( pipe.add_edge(pdf_extractor_stage, image_extractor_stage) pipe.add_edge(image_extractor_stage, docx_extractor_stage) pipe.add_edge(docx_extractor_stage, pptx_extractor_stage) - pipe.add_edge(pptx_extractor_stage, image_dedup_stage) + pipe.add_edge(pptx_extractor_stage, audio_extractor_stage) + pipe.add_edge(audio_extractor_stage, image_dedup_stage) pipe.add_edge(image_dedup_stage, image_filter_stage) pipe.add_edge(image_filter_stage, table_extraction_stage) pipe.add_edge(table_extraction_stage, chart_extraction_stage) diff --git a/src/nv_ingest/util/pipeline/stage_builders.py b/src/nv_ingest/util/pipeline/stage_builders.py index 1a73abd2..e33cd513 100644 --- a/src/nv_ingest/util/pipeline/stage_builders.py +++ b/src/nv_ingest/util/pipeline/stage_builders.py @@ -20,6 +20,7 @@ from nv_ingest.modules.transforms.text_splitter import TextSplitterLoaderFactory from nv_ingest.stages.docx_extractor_stage import generate_docx_extractor_stage from nv_ingest.stages.embeddings.text_embeddings import generate_text_embed_extractor_stage +from nv_ingest.stages.extractors.audio_extractor_stage import generate_audio_extractor_stage from nv_ingest.stages.extractors.image_extractor_stage import generate_image_extractor_stage from nv_ingest.stages.filters import generate_dedup_stage from nv_ingest.stages.filters import generate_image_filter_stage @@ -351,6 +352,60 @@ def add_pptx_extractor_stage(pipe, morpheus_pipeline_config, ingest_config, defa return pptx_extractor_stage +def get_audio_retrieval_service(env_var_prefix): + prefix = env_var_prefix.upper() + grpc_endpoint = os.environ.get( + "AUDIO_GRPC_ENDPOINT", + "", + ) + http_endpoint = os.environ.get( + "AUDIO_HTTP_ENDPOINT", + "", + ) + auth_token = os.environ.get( + "NVIDIA_BUILD_API_KEY", + "", + ) or os.environ.get( + "NGC_API_KEY", + "", + ) + infer_protocol = os.environ.get( + "AUDIO_INFER_PROTOCOL", + "http" if http_endpoint else "grpc" if grpc_endpoint else "", + ) + + logger.info(f"{prefix}_GRPC_TRITON: {grpc_endpoint}") + logger.info(f"{prefix}_HTTP_TRITON: {http_endpoint}") + logger.info(f"{prefix}_INFER_PROTOCOL: {infer_protocol}") + + return grpc_endpoint, http_endpoint, auth_token, infer_protocol + + +def add_audio_extractor_stage(pipe, morpheus_pipeline_config, ingest_config, default_cpu_count): + audio_grpc, audio_http, audio_auth, audio_infer_protocol = get_audio_retrieval_service("audio") + audio_extractor_config = ingest_config.get( + "audio_extraction_module", + { + "audio_extraction_config": { + "audio_endpoints": (audio_grpc, audio_http), + "audio_infer_protocol": audio_infer_protocol, + "auth_token": audio_auth, + # All auth tokens are the same for the moment + } + }, + ) + audio_extractor_stage = pipe.add_stage( + generate_audio_extractor_stage( + morpheus_pipeline_config, + stage_config=audio_extractor_config, + pe_count=8, + task="extract", + task_desc="audio_content_extractor", + ) + ) + return audio_extractor_stage + + def add_image_dedup_stage(pipe, morpheus_pipeline_config, ingest_config, default_cpu_count): image_dedup_config = ingest_config.get("dedup_module", {}) image_dedup_stage = pipe.add_stage(