odr_caption init
Signed-off-by: jphillips <[email protected]>
fearnworks committed Oct 6, 2024
1 parent 6cb8395 commit 8ed1707
Showing 12 changed files with 529 additions and 0 deletions.
2 changes: 2 additions & 0 deletions modules/odr_caption/instructions.md
@@ -0,0 +1,2 @@
References:
https://docs.vllm.ai/en/latest/models/vlm.html
Empty file.
160 changes: 160 additions & 0 deletions modules/odr_caption/odr_caption/agents/ImageCaptioner.py
@@ -0,0 +1,160 @@
import base64
from typing import Optional, Union
from PIL import Image
import io
import time
import asyncio
from openai import AsyncOpenAI
from odr_caption.utils.logger import logger
from odr_caption.schemas.vllm_schemas import VLLMResponse, VLLMRequest, Interaction
from odr_caption.utils.message_logger import MessageLogger


class ImageCaptioner:
def __init__(
self,
vllm_server_url: str,
model_name: str,
message_logger: Optional[MessageLogger] = None,
max_tokens: int = 2048,
temperature: float = 0.35,
max_size: int = 1280,
repetition_penalty: float = 1.0,
):
self.client = AsyncOpenAI(api_key="EMPTY", base_url=vllm_server_url)
self.model_name = model_name
self.max_tokens = max_tokens
self.temperature = temperature
self.max_size = max_size
self.repetition_penalty = repetition_penalty
self.message_logger = message_logger
logger.info(
f"ImageCaptioner initialized with model_name: {self.model_name} and base_url: {vllm_server_url}"
)

def encode_image(self, image_path: str) -> str:
with Image.open(image_path) as img:
img = img.convert("RGB")

# Resize image if the longest edge is greater than max_size
original_size = img.size
img.thumbnail((self.max_size, self.max_size))
resized_size = img.size

if original_size != resized_size:
logger.info(f"Image resized from {original_size} to {resized_size}")

buffered = io.BytesIO()
img.save(buffered, format="PNG")
return base64.b64encode(buffered.getvalue()).decode("utf-8")

async def caption_image(
self, image_path: str, system_message: str, prompt: Optional[str] = None
) -> Union[VLLMResponse, str]:
timestamp_start = time.time()
encoded_image = self.encode_image(image_path)
messages = [
{"role": "system", "content": system_message},
{
"role": "user",
"content": [
{
"type": "text",
"text": prompt or "Describe this image in detail.",
},
{
"type": "image_url",
"image_url": {"url": f"data:image/png;base64,{encoded_image}"},
},
],
},
]

try:
response = await self.client.chat.completions.create(
model=self.model_name,
messages=messages,
max_tokens=self.max_tokens,
temperature=self.temperature,
                # The OpenAI client exposes no repetition_penalty parameter, so the
                # configured value is forwarded as presence_penalty.
                presence_penalty=self.repetition_penalty,
)

vllm_response = self._convert_to_vllm_response(response)
timestamp_end = time.time()
duration = timestamp_end - timestamp_start

self._log_interaction(
messages, vllm_response, timestamp_start, timestamp_end, duration
)
return vllm_response

except Exception as e:
logger.error(f"Error captioning image: {e}", exc_info=True)
error_response = self._create_error_response(str(e))
timestamp_end = time.time()
duration = timestamp_end - timestamp_start
self._log_interaction(
messages, error_response, timestamp_start, timestamp_end, duration
)
return error_response

def _convert_to_vllm_response(self, openai_response) -> VLLMResponse:
return VLLMResponse(
id=openai_response.id,
object=openai_response.object,
created=openai_response.created,
model=openai_response.model,
choices=[
{
"index": choice.index,
"message": {
"role": choice.message.role,
"content": choice.message.content,
},
"finish_reason": choice.finish_reason,
}
for choice in openai_response.choices
],
usage={
"prompt_tokens": openai_response.usage.prompt_tokens,
"completion_tokens": openai_response.usage.completion_tokens,
"total_tokens": openai_response.usage.total_tokens,
},
)

def _create_error_response(self, error_message: str) -> VLLMResponse:
return VLLMResponse(
id="error",
object="chat.completion",
created=int(time.time()),
model=self.model_name,
choices=[
{
"index": 0,
"message": {
"role": "assistant",
"content": f"Error: {error_message}",
},
"finish_reason": "error",
}
],
usage={"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
)

def _log_interaction(
self,
request,
response: VLLMResponse,
timestamp_start: float,
timestamp_end: float,
duration: float,
):
if self.message_logger:
interaction = Interaction(
request=VLLMRequest(messages=request, model=self.model_name),
response=response,
timestamp_start=int(timestamp_start),
timestamp_end=int(timestamp_end),
duration=duration,
)
self.message_logger.log_interaction(interaction)
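
A minimal usage sketch for the captioner, assuming a locally running vLLM OpenAI-compatible server; the URL and model name below are placeholders, not values taken from this commit:

import asyncio
from odr_caption.agents.ImageCaptioner import ImageCaptioner

async def main():
    # Hypothetical endpoint and model; substitute whatever the vLLM server actually serves.
    captioner = ImageCaptioner(
        vllm_server_url="http://localhost:8000/v1",
        model_name="Qwen/Qwen2-VL-7B-Instruct",
    )
    response = await captioner.caption_image(
        image_path="sample.jpg",
        system_message="You are an expert image captioner.",
    )
    print(response.choices[0].message.content)

asyncio.run(main())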
Empty file.
58 changes: 58 additions & 0 deletions modules/odr_caption/odr_caption/schemas/__init__.py
@@ -0,0 +1,58 @@
from pydantic import BaseModel, Field
from typing import List, Dict, Any, Optional
from odr_caption.schemas.vllm_schemas import (
VLLMFunction,
VLLMTool,
VLLMRequestMessage,
VLLMRequest,
ToolCall,
VLLMMessage,
VLLMChoice,
VLLMUsage,
VLLMResponse,
Interaction,
)


class TestCase(BaseModel):
image_path: str
expected_result: str
expected_keywords: List[str]


class TestSuite(BaseModel):
name: str
output_file: str
system_message: str
test_cases: List[TestCase]


class GlobalConfig(BaseModel):

model: str
max_tokens: int
temperature: float
vllm_server_url: str


class Config(BaseModel):
global_config: GlobalConfig = Field(..., alias="global")
test_suites: Dict[str, Dict[str, TestSuite]]


class ResponseAnalysis(BaseModel):
task_type: str
response_received: bool
response_content: str
total_tokens: int
expected_result: str
evaluation: str
keyword_match_percentage: float
image_path: str
test_suite_name: str


class ImageCaptionInputs(BaseModel):
system_message: str
image_data: str
prompt: str | None = None
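
A sketch of a config dict that validates against the Config model above, mainly to show the "global" alias and the nested test-suite mapping; all concrete values (model, URL, paths) are placeholder assumptions:

from odr_caption.schemas import Config

raw = {
    "global": {
        "model": "Qwen/Qwen2-VL-7B-Instruct",
        "max_tokens": 2048,
        "temperature": 0.35,
        "vllm_server_url": "http://localhost:8000/v1",
    },
    "test_suites": {
        "captioning": {
            "basic": {
                "name": "basic",
                "output_file": "results/basic.json",
                "system_message": "Describe the image.",
                "test_cases": [
                    {
                        "image_path": "images/cat.png",
                        "expected_result": "A cat sitting on a couch.",
                        "expected_keywords": ["cat", "couch"],
                    }
                ],
            }
        }
    },
}

config = Config(**raw)  # the "global" key populates global_config via its alias
print(config.global_config.model)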
88 changes: 88 additions & 0 deletions modules/odr_caption/odr_caption/schemas/vllm_schemas.py
@@ -0,0 +1,88 @@
# schemas.py
from pydantic import BaseModel, Field
from typing import List, Dict, Any, Optional, Union


class VLLMFunction(BaseModel):
name: str
description: Optional[str] = None
parameters: Dict[str, Any]


class VLLMTool(BaseModel):
type: str
function: VLLMFunction


class VLLMRequestMessage(BaseModel):
role: str
content: Union[str, List[Dict[str, Any]], None]
name: Optional[str] = None
function_call: Optional[Dict[str, Any]] = None


class VLLMRequest(BaseModel):
model: str
messages: List[VLLMRequestMessage]
functions: Optional[List[VLLMFunction]] = None
function_call: Optional[Union[str, Dict[str, Any]]] = None
tools: Optional[List[VLLMTool]] = None
tool_choice: Optional[Union[str, Dict[str, Any]]] = None
temperature: Optional[float] = None
top_p: Optional[float] = None
n: Optional[int] = None
stream: Optional[bool] = None
stop: Optional[Union[str, List[str]]] = None
max_tokens: Optional[int] = None
presence_penalty: Optional[float] = None
frequency_penalty: Optional[float] = None
logit_bias: Optional[Dict[str, float]] = None
user: Optional[str] = None


class ToolCall(BaseModel):
id: str
type: str
function: Dict[str, Any]


class VLLMMessage(BaseModel):
role: str
content: Optional[str] = None
tool_calls: Optional[List[ToolCall]] = None


class VLLMChoice(BaseModel):
index: int
message: VLLMMessage
finish_reason: str


class VLLMUsage(BaseModel):
prompt_tokens: int
completion_tokens: int
total_tokens: int


class VLLMResponse(BaseModel):
id: str
object: str
created: int
model: str
choices: List[VLLMChoice]
usage: VLLMUsage
system_fingerprint: Optional[str] = None

class Config:
allow_population_by_field_name = True


class Interaction(BaseModel):
request: VLLMRequest
response: VLLMResponse
timestamp_start: int
timestamp_end: int
duration: float

class Config:
allow_population_by_field_name = True
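
A small sketch showing that an OpenAI-style chat completion payload validates directly into VLLMResponse; the payload values are illustrative only:

from odr_caption.schemas.vllm_schemas import VLLMResponse

payload = {
    "id": "chatcmpl-123",
    "object": "chat.completion",
    "created": 1728000000,
    "model": "Qwen/Qwen2-VL-7B-Instruct",
    "choices": [
        {
            "index": 0,
            "message": {"role": "assistant", "content": "A cat on a couch."},
            "finish_reason": "stop",
        }
    ],
    "usage": {"prompt_tokens": 512, "completion_tokens": 12, "total_tokens": 524},
}

response = VLLMResponse(**payload)
print(response.choices[0].message.content)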
Empty file.
72 changes: 72 additions & 0 deletions modules/odr_caption/odr_caption/server/app.py
@@ -0,0 +1,72 @@
"""
file: main.py
description: Main entry point for the vision worker
keywords: fastapi, florence, vision, caption
"""

from fastapi import FastAPI, File, UploadFile, HTTPException, Body, Form
from fastapi.middleware.cors import CORSMiddleware

from PIL import Image
import io
from typing import Optional
import os
import torch
import base64
import threading
from functools import lru_cache
import logging

print(f"CUDA available: {torch.cuda.is_available()}")
print(f"Current device: {torch.cuda.current_device()}")
print(f"Device name: {torch.cuda.get_device_name(0)}")


logger = logging.getLogger(__name__)

thread_local = threading.local()
app = FastAPI()


# Add CORS middleware
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)


async def generate_caption(image, task: str = "more_detailed_caption", client=None):
    # Placeholder in this initial commit: the captioning backend is not wired up yet.
    return None


def get_client(client_type: Optional[str] = None):
    # Placeholder in this initial commit: client selection is not implemented yet.
    return None


@app.post("/generate_caption")
async def generate_caption_endpoint(
file: UploadFile = File(...),
task: str = Form("more_detailed_caption"),
client_type: Optional[str] = Form(default=None),
):
try:
image = Image.open(file.file)
except Exception as e:
raise HTTPException(status_code=400, detail=f"Invalid image file: {str(e)}")

# Generate caption

caption = await generate_caption(image, task=task, client=get_client(client_type))

if caption is None:
raise HTTPException(status_code=500, detail="Failed to generate caption")

return {"content": caption}


@app.get("/health")
async def health_check():
return {"status": "healthy"}