Skip to content

Commit

Permalink
feat: support pdf upload and add tutorial (#309)
Browse files Browse the repository at this point in the history
  • Loading branch information
shreyashankar authored Feb 10, 2025
1 parent be7add5 commit f4abe55
Show file tree
Hide file tree
Showing 7 changed files with 34,075 additions and 3 deletions.
48 changes: 46 additions & 2 deletions docetl/operations/map.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@
"""

import asyncio
import base64
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Dict, List, Optional, Tuple, Union

import requests
from jinja2 import Template
from litellm.utils import ModelResponse
from pydantic import Field, field_validator
Expand Down Expand Up @@ -38,6 +40,7 @@ class schema(BaseOperation.schema):
clustering_method: Optional[str] = None
batch_prompt: Optional[str] = None
litellm_completion_kwargs: Dict[str, Any] = {}
pdf_url_key: Optional[str] = None

@field_validator("drop_keys")
def validate_drop_keys(cls, v):
Expand Down Expand Up @@ -166,6 +169,29 @@ def _process_map_item(
) -> Tuple[Optional[Dict], float]:

prompt = strict_render(self.config["prompt"], {"input": item})
messages = [{"role": "user", "content": prompt}]
if self.config.get("pdf_url_key", None):
# Append the pdf to the prompt
try:
pdf_url = item[self.config["pdf_url_key"]]
except KeyError:
raise ValueError(
f"PDF URL key '{self.config['pdf_url_key']}' not found in input data"
)

# Download content
if pdf_url.startswith("http"):
file_data = requests.get(pdf_url).content
else:
with open(pdf_url, "rb") as f:
file_data = f.read()
encoded_file = base64.b64encode(file_data).decode("utf-8")
base64_url = f"data:application/pdf;base64,{encoded_file}"

messages[0]["content"] = [
{"type": "image_url", "image_url": {"url": base64_url}},
{"type": "text", "text": prompt},
]

def validation_fn(response: Union[Dict[str, Any], ModelResponse]):
output = (
Expand Down Expand Up @@ -196,7 +222,7 @@ def validation_fn(response: Union[Dict[str, Any], ModelResponse]):
llm_result = self.runner.api.call_llm(
self.config.get("model", self.default_model),
"map",
[{"role": "user", "content": prompt}],
messages,
self.config["output"]["schema"],
tools=self.config.get("tools", None),
scratchpad=None,
Expand Down Expand Up @@ -244,6 +270,10 @@ def validation_fn(response: Union[Dict[str, Any], ModelResponse]):
def _process_map_batch(items: List[Dict]) -> Tuple[List[Dict], float]:
total_cost = 0
if len(items) > 1 and self.config.get("batch_prompt", None):
# Raise error if pdf_url_key is set
if self.config.get("pdf_url_key", None):
raise ValueError("Batch prompts do not support PDF URLs")

batch_prompt = strict_render(
self.config["batch_prompt"], {"inputs": items}
)
Expand Down Expand Up @@ -361,6 +391,7 @@ class schema(BaseOperation.schema):
prompts: List[Dict[str, Any]]
output: Dict[str, Any]
enable_observability: bool = False
pdf_url_key: Optional[str] = None

def __init__(
self,
Expand Down Expand Up @@ -489,6 +520,19 @@ def execute(self, input_data: List[Dict]) -> Tuple[List[Dict], float]:

def process_prompt(item, prompt_config):
prompt = strict_render(prompt_config["prompt"], {"input": item})
messages = [{"role": "user", "content": prompt}]
if self.config.get("pdf_url_key", None):
try:
pdf_url = item[self.config["pdf_url_key"]]
except KeyError:
raise ValueError(
f"PDF URL key '{self.config['pdf_url_key']}' not found in input data"
)
messages[0]["content"] = [
{"type": "image_url", "image_url": {"url": pdf_url}},
{"type": "text", "text": prompt},
]

local_output_schema = {
key: output_schema[key] for key in prompt_config["output_keys"]
}
Expand All @@ -501,7 +545,7 @@ def process_prompt(item, prompt_config):
response = self.runner.api.call_llm(
model,
"parallel_map",
[{"role": "user", "content": prompt}],
messages,
local_output_schema,
tools=prompt_config.get("tools", None),
timeout_seconds=self.config.get("timeout", 120),
Expand Down
5 changes: 5 additions & 0 deletions docetl/operations/utils/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,11 @@ def truncate_messages(
messages: List[Dict[str, str]], model: str, from_agent: bool = False
) -> List[Dict[str, str]]:
"""Truncate messages to fit within model's context length."""
# if there's a pdf, don't truncate
for message in messages:
if isinstance(message["content"], list):
return messages

model_cost_info = model_cost.get(model, {})
if not model_cost_info:
# Try stripping the first part before the /
Expand Down
Loading

0 comments on commit f4abe55

Please sign in to comment.