Skip to content

Commit

Permalink
Enable the embedding task for all content types including image capti…
Browse files Browse the repository at this point in the history
…ons (#336)
  • Loading branch information
edknv authored Jan 16, 2025
1 parent 9ef8e0b commit cf8c5f5
Show file tree
Hide file tree
Showing 8 changed files with 156 additions and 131 deletions.
4 changes: 1 addition & 3 deletions client/src/nv_ingest_client/nv_ingest_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@
--task 'extract:{"document_type":"pdf", "extract_method":"unstructured_io"}'
--task 'extract:{"document_type":"docx", "extract_text":true, "extract_images":true}'
--task 'store:{"content_type":"image", "store_method":"minio", "endpoint":"minio:9000"}'
--task 'embed:{"text":true, "tables":true}'
--task 'embed'
--task 'vdb_upload'
--task 'caption:{}'
Expand All @@ -143,8 +143,6 @@
- embed: Computes embeddings on multimodal extractions.
Options:
- filter_errors (bool): Flag to filter embedding errors. Optional.
- tables (bool): Flag to create embeddings for table extractions. Optional.
- text (bool): Flag to create embeddings for text extractions. Optional.
\b
- extract: Extracts content from documents, customizable per document type.
Can be specified multiple times for different 'document_type' values.
Expand Down
37 changes: 27 additions & 10 deletions client/src/nv_ingest_client/primitives/tasks/embed.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,18 +9,30 @@
import logging
from typing import Dict

from pydantic import BaseModel
from pydantic import BaseModel, root_validator

from .task_base import Task

logger = logging.getLogger(__name__)


class EmbedTaskSchema(BaseModel):
text: bool = True
tables: bool = True
filter_errors: bool = False

@root_validator(pre=True)
def handle_deprecated_fields(cls, values):
if "text" in values:
logger.warning(
"'text' parameter is deprecated and will be ignored. Future versions will remove this argument."
)
values.pop("text")
if "tables" in values:
logger.warning(
"'tables' parameter is deprecated and will be ignored. Future versions will remove this argument."
)
values.pop("tables")
return values

class Config:
extra = "forbid"

Expand All @@ -30,13 +42,22 @@ class EmbedTask(Task):
Object for document embedding task
"""

def __init__(self, text: bool = True, tables: bool = True, filter_errors: bool = False) -> None:
def __init__(self, text: bool = None, tables: bool = None, filter_errors: bool = False) -> None:
"""
Setup Embed Task Config
"""
super().__init__()
self._text = text
self._tables = tables

if text is not None:
logger.warning(
"'text' parameter is deprecated and will be ignored. Future versions will remove this argument."
)

if tables is not None:
logger.warning(
"'tables' parameter is deprecated and will be ignored. Future versions will remove this argument."
)

self._filter_errors = filter_errors

def __str__(self) -> str:
Expand All @@ -45,8 +66,6 @@ def __str__(self) -> str:
"""
info = ""
info += "Embed Task:\n"
info += f" text: {self._text}\n"
info += f" tables: {self._tables}\n"
info += f" filter_errors: {self._filter_errors}\n"
return info

Expand All @@ -56,8 +75,6 @@ def to_dict(self) -> Dict:
"""

task_properties = {
"text": self._text,
"tables": self._tables,
"filter_errors": False,
}

Expand Down
211 changes: 113 additions & 98 deletions src/nv_ingest/modules/transforms/embed_extractions.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,22 +281,54 @@ def _add_embeddings(row, embeddings, info_msgs):
return row


def _get_text_content(row):
def _get_pandas_text_content(row):
"""
A pandas UDF used to select extracted text content to be used to create embeddings.
"""

return row["content"]


def _get_table_content(row):
def _get_pandas_table_content(row):
"""
A pandas UDF used to select extracted table/chart content to be used to create embeddings.
"""

return row["table_metadata"]["table_content"]


def _get_pandas_image_content(row):
"""
A pandas UDF used to select extracted image captions to be used to create embeddings.
"""

return row["image_metadata"]["caption"]


def _get_cudf_text_content(df: cudf.DataFrame):
"""
A cuDF UDF used to select extracted text content to be used to create embeddings.
"""

return df.struct.field("content")


def _get_cudf_table_content(df: cudf.DataFrame):
"""
A cuDF UDF used to select extracted table/chart content to be used to create embeddings.
"""

return df.struct.field("table_metadata").struct.field("table_content")


def _get_cudf_image_content(df: cudf.DataFrame):
"""
A cuDF UDF used to select extracted image captions to be used to create embeddings.
"""

return df.struct.field("image_metadata").struct.field("caption")


def _batch_generator(iterable: Iterable, batch_size=10):
"""
A generator to yield batches of size `batch_size` from an interable.
Expand Down Expand Up @@ -349,7 +381,6 @@ def _generate_batches(prompts: List[str], batch_size: int = 100):

def _generate_embeddings(
ctrl_msg: ControlMessage,
content_type: ContentTypeEnum,
event_loop: asyncio.SelectorEventLoop,
batch_size: int,
api_key: str,
Expand All @@ -361,8 +392,10 @@ def _generate_embeddings(
filter_errors: bool,
):
"""
A function to generate embeddings for the supplied `ContentTypeEnum`. The `ContentTypeEnum` will
drive filtering criteria used to select rows of data to enrich with embeddings.
A function to generate text embeddings for supported content types (TEXT, STRUCTURED, IMAGE).
This function dynamically selects the appropriate metadata field based on content type and
calculates embeddings using the NIM embedding service. AUDIO and VIDEO types are stubbed and skipped.
Parameters
----------
Expand Down Expand Up @@ -403,53 +436,71 @@ def _generate_embeddings(
content_mask : cudf.Series
A boolean mask representing rows filtered to calculate embeddings.
"""
cudf_content_extractor = {
ContentTypeEnum.TEXT: _get_cudf_text_content,
ContentTypeEnum.STRUCTURED: _get_cudf_table_content,
ContentTypeEnum.IMAGE: _get_cudf_image_content,
ContentTypeEnum.AUDIO: lambda _: None, # Not supported yet.
ContentTypeEnum.VIDEO: lambda _: None, # Not supported yet.
}
pandas_content_extractor = {
ContentTypeEnum.TEXT: _get_pandas_text_content,
ContentTypeEnum.STRUCTURED: _get_pandas_table_content,
ContentTypeEnum.IMAGE: _get_pandas_image_content,
ContentTypeEnum.AUDIO: lambda _: None, # Not supported yet.
ContentTypeEnum.VIDEO: lambda _: None, # Not supported yet.
}

logger.debug("Generating text embeddings for supported content types: TEXT, STRUCTURED, IMAGE.")

embedding_dataframes = []
content_masks = []

with ctrl_msg.payload().mutable_dataframe() as mdf:
if mdf.empty:
return None, None

# generate table text mask
if content_type == ContentTypeEnum.TEXT:
content_mask = (mdf["document_type"] == content_type.value) & (
mdf["metadata"].struct.field("content") != ""
).fillna(False)
content_getter = _get_text_content
elif content_type == ContentTypeEnum.STRUCTURED:
table_mask = mdf["document_type"] == content_type.value
if not table_mask.any():
return None, None
content_mask = table_mask & (
mdf["metadata"].struct.field("table_metadata").struct.field("table_content") != ""
).fillna(False)
content_getter = _get_table_content

# exit if matches found
if not content_mask.any():
return None, None

df_text = mdf.loc[content_mask].to_pandas().reset_index(drop=True)
# get text list
filtered_text = df_text["metadata"].apply(content_getter)
# calculate embeddings
filtered_text_batches = _generate_batches(filtered_text.tolist(), batch_size)
text_embeddings = _async_runner(
filtered_text_batches,
api_key,
embedding_nim_endpoint,
embedding_model,
encoding_format,
input_type,
truncate,
event_loop,
filter_errors,
)
# update embeddings in metadata
df_text[["metadata", "document_type", "_contains_embeddings"]] = df_text.apply(
_add_embeddings, **text_embeddings, axis=1
)[["metadata", "document_type", "_contains_embeddings"]]
df_text["_content"] = filtered_text
return ctrl_msg

for content_type, content_getter in pandas_content_extractor.items():
if not content_getter:
logger.debug(f"Skipping unsupported content type: {content_type}")
continue

content_mask = mdf["document_type"] == content_type.value
if not content_mask.any():
continue

cudf_content_getter = cudf_content_extractor[content_type]
content_mask = (content_mask & (cudf_content_getter(mdf["metadata"]) != "")).fillna(False)
if not content_mask.any():
continue

df_content = mdf.loc[content_mask].to_pandas().reset_index(drop=True)
filtered_content = df_content["metadata"].apply(content_getter)
# calculate embeddings
filtered_content_batches = _generate_batches(filtered_content.tolist(), batch_size)
content_embeddings = _async_runner(
filtered_content_batches,
api_key,
embedding_nim_endpoint,
embedding_model,
encoding_format,
input_type,
truncate,
event_loop,
filter_errors,
)
# update embeddings in metadata
df_content[["metadata", "document_type", "_contains_embeddings"]] = df_content.apply(
_add_embeddings, **content_embeddings, axis=1
)[["metadata", "document_type", "_contains_embeddings"]]
df_content["_content"] = filtered_content

embedding_dataframes.append(df_content)
content_masks.append(content_mask)

message = _concatenate_extractions(ctrl_msg, embedding_dataframes, content_masks)

return df_text, content_mask
return message


def _concatenate_extractions(ctrl_msg: ControlMessage, dataframes: List[pd.DataFrame], masks: List[cudf.Series]):
Expand Down Expand Up @@ -493,8 +544,8 @@ def _concatenate_extractions(ctrl_msg: ControlMessage, dataframes: List[pd.DataF
@register_module(MODULE_NAME, MODULE_NAMESPACE)
def _embed_extractions(builder: mrc.Builder):
"""
A pipeline module that receives incoming messages in ControlMessage format and calculates embeddings for
supported document types.
A pipeline module that receives incoming messages in ControlMessage format
and calculates text embeddings for all supported content types.
Parameters
----------
Expand All @@ -519,56 +570,20 @@ def embed_extractions_fn(message: ControlMessage):
try:
task_props = message.remove_task("embed")
model_dump = task_props.model_dump()
embed_text = model_dump.get("text")
embed_tables = model_dump.get("tables")
filter_errors = model_dump.get("filter_errors", False)

logger.debug(f"Generating embeddings: text={embed_text}, tables={embed_tables}")
embedding_dataframes = []
content_masks = []

if embed_text:
df_text, content_mask = _generate_embeddings(
message,
ContentTypeEnum.TEXT,
event_loop,
validated_config.batch_size,
validated_config.api_key,
validated_config.embedding_nim_endpoint,
validated_config.embedding_model,
validated_config.encoding_format,
validated_config.input_type,
validated_config.truncate,
filter_errors,
)
if df_text is not None:
embedding_dataframes.append(df_text)
content_masks.append(content_mask)

if embed_tables:
df_tables, table_mask = _generate_embeddings(
message,
ContentTypeEnum.STRUCTURED,
event_loop,
validated_config.batch_size,
validated_config.api_key,
validated_config.embedding_nim_endpoint,
validated_config.embedding_model,
validated_config.encoding_format,
validated_config.input_type,
validated_config.truncate,
filter_errors,
)
if df_tables is not None:
embedding_dataframes.append(df_tables)
content_masks.append(table_mask)

if len(content_masks) == 0:
return message

message = _concatenate_extractions(message, embedding_dataframes, content_masks)

return message
return _generate_embeddings(
message,
event_loop,
validated_config.batch_size,
validated_config.api_key,
validated_config.embedding_nim_endpoint,
validated_config.embedding_model,
validated_config.encoding_format,
validated_config.input_type,
validated_config.truncate,
filter_errors,
)

except Exception as e:
traceback.print_exc()
Expand Down
2 changes: 0 additions & 2 deletions src/nv_ingest/schemas/ingest_job_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,8 +130,6 @@ class IngestTaskDedupSchema(BaseModelNoExt):


class IngestTaskEmbedSchema(BaseModelNoExt):
text: bool = True
tables: bool = True
filter_errors: bool = False


Expand Down
8 changes: 5 additions & 3 deletions src/nv_ingest/schemas/metadata_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,14 @@ class AccessLevelEnum(int, Enum):


class ContentTypeEnum(str, Enum):
TEXT = "text"
AUDIO = "audio"
EMBEDDING = "embedding"
IMAGE = "image"
INFO_MSG = "info_message"
STRUCTURED = "structured"
TEXT = "text"
UNSTRUCTURED = "unstructured"
INFO_MSG = "info_message"
EMBEDDING = "embedding"
VIDEO = "video"


class StdContentDescEnum(str, Enum):
Expand Down
Loading

0 comments on commit cf8c5f5

Please sign in to comment.