Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[DRAFT] audio integration #324

Open
wants to merge 21 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ RUN apt-get update && apt-get install -y \
bzip2 \
ca-certificates \
curl \
ffmpeg \
libgl1-mesa-glx \
software-properties-common \
wget \
Expand Down
6 changes: 6 additions & 0 deletions client/src/nv_ingest_client/primitives/tasks/extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@
"svg": "image",
"tiff": "image",
"xml": "lxml",
"mp3": "audio",
"wav": "audio",
}

_Type_Extract_Method_PDF = Literal[
Expand All @@ -63,6 +65,8 @@

_Type_Extract_Method_Image = Literal["image"]

_Type_Extract_Method_Audio = Literal["audio"]

_Type_Extract_Method_Map = {
"docx": get_args(_Type_Extract_Method_DOCX),
"jpeg": get_args(_Type_Extract_Method_Image),
Expand All @@ -72,6 +76,8 @@
"pptx": get_args(_Type_Extract_Method_PPTX),
"svg": get_args(_Type_Extract_Method_Image),
"tiff": get_args(_Type_Extract_Method_Image),
"mp3": get_args(_Type_Extract_Method_Audio),
"wav": get_args(_Type_Extract_Method_Audio),
}

_Type_Extract_Tables_Method_PDF = Literal["yolox", "pdfium", "nemoretriever_parse"]
Expand Down
4 changes: 4 additions & 0 deletions client/src/nv_ingest_client/util/file_processing/extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ class DocumentTypeEnum(str, Enum):
svg = "svg"
tiff = "tiff"
txt = "text"
mp3 = "mp3"
wav = "wav"


# Maps MIME types to DocumentTypeEnum
Expand Down Expand Up @@ -64,6 +66,8 @@ class DocumentTypeEnum(str, Enum):
"svg": DocumentTypeEnum.svg,
"tiff": DocumentTypeEnum.tiff,
"txt": DocumentTypeEnum.txt,
"mp3": DocumentTypeEnum.mp3,
"wav": DocumentTypeEnum.wav,
# Add more as needed
}

Expand Down
2 changes: 2 additions & 0 deletions conda/environments/nv_ingest_environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ dependencies:
- click>=8.1.7
- fastapi>=0.115.6
- fastparquet>=2024.11.0
- ffmpeg-python>=0.2.0
- fsspec>=2024.10.0
- httpx>=0.28.1
- isodate>=0.7.2
Expand Down Expand Up @@ -46,6 +47,7 @@ dependencies:
- pip
- pip:
- llama-index-embeddings-nvidia
- nvidia-riva-client
- opencv-python # For some reason conda cant solve our req set with py-opencv so we need to use pip
- pymilvus>=2.5.0
- pymilvus[bulk_writer, model]
27 changes: 27 additions & 0 deletions docker-compose.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,31 @@ services:
profiles:
- vlm

audio:
image: nvcr.io/nvidia/riva/riva-speech:2.18.0
shm_size: 2gb
ports:
- "8019:50051" # grpc
- "8020:50000" # http
user: root
environment:
- MODEL_DEPLOY_KEY=tlt_encode
- NGC_CLI_API_KEY=${RIVA_NGC_API_KEY}
- NGC_CLI_ORG=nvidia
- NGC_CLI_TEAM=riva
- CUDA_VISIBLE_DEVICES=0
deploy:
resources:
reservations:
devices:
- driver: nvidia
device_ids: ["1"]
capabilities: [gpu]
runtime: nvidia
command: bash -c "download_and_deploy_ngc_models nvidia/riva/rmir_asr_conformer_en_us_ofl:2.18.0 && start-riva"
profiles:
- audio

nv-ingest-ms-runtime:
image: nvcr.io/nvidia/nemo-microservices/nv-ingest:24.12
build:
Expand All @@ -215,6 +240,8 @@ services:
cap_add:
- sys_nice
environment:
- AUDIO_GRPC_ENDPOINT=audio:50051
- AUDIO_INFER_PROTOCOL=grpc
- CUDA_VISIBLE_DEVICES=-1
- MAX_INGEST_PROCESS_WORKERS=${MAX_PROCESS_WORKERS:-16}
- EMBEDDING_NIM_MODEL_NAME=${EMBEDDING_NIM_MODEL_NAME:-nvidia/llama-3.2-nv-embedqa-1b-v2}
Expand Down
11 changes: 10 additions & 1 deletion src/nv_ingest/modules/injectors/metadata_injector.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# SPDX-License-Identifier: Apache-2.0

import logging
import traceback

import mrc
import pandas as pd
Expand Down Expand Up @@ -46,6 +47,9 @@ def on_data(message: IngestControlMessage):
"type": content_type.name.lower(),
},
"error_metadata": None,
"audio_metadata": (
None if content_type != ContentTypeEnum.AUDIO else {"audio_type": row["document_type"]}
),
"image_metadata": (
None if content_type != ContentTypeEnum.IMAGE else {"image_type": row["document_type"]}
),
Expand Down Expand Up @@ -86,7 +90,12 @@ def _metadata_injection(builder: mrc.Builder):
annotation_id=MODULE_NAME, raise_on_failure=validated_config.raise_on_failure, skip_processing_if_failed=True
)
def _on_data(message: IngestControlMessage) -> IngestControlMessage:
return on_data(message)
try:
return on_data(message)
except Exception as e:
logger.error(f"Unhandled exception in metadata_injector: {e}")
traceback.print_exc()
raise

node = builder.make_node("metadata_injector", _on_data)

Expand Down
127 changes: 127 additions & 0 deletions src/nv_ingest/schemas/audio_extractor_schema.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
# SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES.
# All rights reserved.
# SPDX-License-Identifier: Apache-2.0


import logging
from typing import Optional
from typing import Tuple

from pydantic import BaseModel
from pydantic import root_validator

logger = logging.getLogger(__name__)


class AudioConfigSchema(BaseModel):
"""
Configuration schema for audio extraction endpoints and options.

Parameters
----------
auth_token : Optional[str], default=None
Authentication token required for secure services.

audio_endpoints : Tuple[str, str]
A tuple containing the gRPC and HTTP services for the audio_retriever endpoint.
Either the gRPC or HTTP service can be empty, but not both.

Methods
-------
validate_endpoints(values)
Validates that at least one of the gRPC or HTTP services is provided for each endpoint.

Raises
------
ValueError
If both gRPC and HTTP services are empty for any endpoint.

Config
------
extra : str
Pydantic config option to forbid extra fields.
"""

auth_token: Optional[str] = None
audio_endpoints: Tuple[Optional[str], Optional[str]] = (None, None)
audio_infer_protocol: Optional[str] = None

@root_validator(pre=True)
def validate_endpoints(cls, values):
"""
Validates the gRPC and HTTP services for all endpoints.

Parameters
----------
values : dict
Dictionary containing the values of the attributes for the class.

Returns
-------
dict
The validated dictionary of values.

Raises
------
ValueError
If both gRPC and HTTP services are empty for any endpoint.
"""

def clean_service(service):
"""Set service to None if it's an empty string or contains only spaces or quotes."""
if service is None or not service.strip() or service.strip(" \"'") == "":
return None
return service

endpoint_name = "audio_endpoints"
grpc_service, http_service = values.get(endpoint_name)
grpc_service = clean_service(grpc_service)
http_service = clean_service(http_service)

if not grpc_service and not http_service:
raise ValueError(f"Both gRPC and HTTP services cannot be empty for {endpoint_name}.")

values[endpoint_name] = (grpc_service, http_service)

protocol_name = "audio_infer_protocol"
protocol_value = values.get(protocol_name)

if not protocol_value:
protocol_value = "http" if http_service else "grpc" if grpc_service else ""

protocol_value = protocol_value.lower()
values[protocol_name] = protocol_value

return values

class Config:
extra = "forbid"


class AudioExtractorSchema(BaseModel):
"""
Configuration schema for the PDF extractor settings.

Parameters
----------
max_queue_size : int, default=1
The maximum number of items allowed in the processing queue.

n_workers : int, default=16
The number of worker threads to use for processing.

raise_on_failure : bool, default=False
A flag indicating whether to raise an exception on processing failure.

audio_extraction_config: Optional[AudioConfigSchema], default=None
Configuration schema for the audio extraction stage.
"""

max_queue_size: int = 1
n_workers: int = 16
raise_on_failure: bool = False

audio_extraction_config: Optional[AudioConfigSchema] = None

class Config:
extra = "forbid"
2 changes: 2 additions & 0 deletions src/nv_ingest/schemas/ingest_job_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ class DocumentTypeEnum(str, Enum):
svg = "svg"
tiff = "tiff"
txt = "text"
mp3 = "mp3"
wav = "wav"


class TaskTypeEnum(str, Enum):
Expand Down
1 change: 1 addition & 0 deletions src/nv_ingest/schemas/ingest_pipeline_config_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@


class PipelineConfigSchema(BaseModel):
# TODO(Devin): Audio
chart_extractor_module: ChartExtractorSchema = ChartExtractorSchema()
text_splitter_module: TextSplitterSchema = TextSplitterSchema()
embedding_storage_module: EmbeddingStorageModuleSchema = EmbeddingStorageModuleSchema()
Expand Down
12 changes: 10 additions & 2 deletions src/nv_ingest/schemas/metadata_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,11 @@ class ChartMetadataSchema(BaseModelNoExt):
uploaded_image_uri: str = ""


class AudioMetadataSchema(BaseModelNoExt):
audio_transcript: str = ""
audio_type: str = ""


# TODO consider deprecating this in favor of info msg...
class ErrorMetadataSchema(BaseModelNoExt):
task: TaskTypeEnum
Expand All @@ -325,6 +330,7 @@ class MetadataSchema(BaseModelNoExt):
embedding: Optional[List[float]] = None
source_metadata: Optional[SourceMetadataSchema] = None
content_metadata: Optional[ContentMetadataSchema] = None
audio_metadata: Optional[AudioMetadataSchema] = None
text_metadata: Optional[TextMetadataSchema] = None
image_metadata: Optional[ImageMetadataSchema] = None
table_metadata: Optional[TableMetadataSchema] = None
Expand All @@ -338,10 +344,12 @@ class MetadataSchema(BaseModelNoExt):
@classmethod
def check_metadata_type(cls, values):
content_type = values.get("content_metadata", {}).get("type", None)
if content_type != ContentTypeEnum.TEXT:
values["text_metadata"] = None
if content_type != ContentTypeEnum.AUDIO:
values["audio_metadata"] = None
if content_type != ContentTypeEnum.IMAGE:
values["image_metadata"] = None
if content_type != ContentTypeEnum.TEXT:
values["text_metadata"] = None
if content_type != ContentTypeEnum.STRUCTURED:
values["table_metadata"] = None
return values
Expand Down
Loading
Loading