From 8ed17071f6bc27d93750685f54cda1c5cbbf41b6 Mon Sep 17 00:00:00 2001 From: jphillips Date: Sun, 6 Oct 2024 08:42:23 -0500 Subject: [PATCH] odr_caption init Signed-off-by: jphillips --- modules/odr_caption/instructions.md | 2 + modules/odr_caption/odr_caption/__init__.py | 0 .../odr_caption/agents/ImageCaptioner.py | 160 ++++++++++++++++++ .../odr_caption/client/__init__.py | 0 .../odr_caption/schemas/__init__.py | 58 +++++++ .../odr_caption/schemas/vllm_schemas.py | 88 ++++++++++ .../odr_caption/server/__init__.py | 0 modules/odr_caption/odr_caption/server/app.py | 72 ++++++++ .../odr_caption/odr_caption/utils/logger.py | 38 +++++ .../odr_caption/utils/message_logger.py | 83 +++++++++ modules/odr_caption/pyproject.toml | 28 +++ modules/odr_caption/requirements.txt | 0 12 files changed, 529 insertions(+) create mode 100644 modules/odr_caption/instructions.md create mode 100644 modules/odr_caption/odr_caption/__init__.py create mode 100644 modules/odr_caption/odr_caption/agents/ImageCaptioner.py create mode 100644 modules/odr_caption/odr_caption/client/__init__.py create mode 100644 modules/odr_caption/odr_caption/schemas/__init__.py create mode 100644 modules/odr_caption/odr_caption/schemas/vllm_schemas.py create mode 100644 modules/odr_caption/odr_caption/server/__init__.py create mode 100644 modules/odr_caption/odr_caption/server/app.py create mode 100644 modules/odr_caption/odr_caption/utils/logger.py create mode 100644 modules/odr_caption/odr_caption/utils/message_logger.py create mode 100644 modules/odr_caption/pyproject.toml create mode 100644 modules/odr_caption/requirements.txt diff --git a/modules/odr_caption/instructions.md b/modules/odr_caption/instructions.md new file mode 100644 index 0000000..99068f8 --- /dev/null +++ b/modules/odr_caption/instructions.md @@ -0,0 +1,2 @@ +References : +https://docs.vllm.ai/en/latest/models/vlm.html diff --git a/modules/odr_caption/odr_caption/__init__.py b/modules/odr_caption/odr_caption/__init__.py new file mode 
class ImageCaptioner:
    """Caption images through an OpenAI-compatible chat-completions endpoint.

    An image is read from disk, converted to RGB, shrunk so that neither edge
    exceeds ``max_size``, PNG-encoded and sent inline as a base64 data URL.
    Each request/response pair is handed to the optional ``MessageLogger``.
    """

    # Used when the caller supplies no (or an empty) prompt.
    DEFAULT_PROMPT = "Describe this image in detail."

    def __init__(
        self,
        vllm_server_url: str,
        model_name: str,
        message_logger: Optional[MessageLogger] = None,
        max_tokens: int = 2048,
        temperature: float = 0.35,
        max_size: int = 1280,
        repetition_penalty: float = 1.0,
    ):
        """Create the async client and remember the sampling settings.

        Args:
            vllm_server_url: Base URL passed to ``AsyncOpenAI``.
            model_name: Model identifier sent with every request.
            message_logger: Optional sink that receives every interaction.
            max_tokens: Completion token limit per request.
            temperature: Sampling temperature.
            max_size: Upper bound, in pixels, for either edge of the image.
            repetition_penalty: Sent to the server as the
                ``repetition_penalty`` sampling parameter.
        """
        self.client = AsyncOpenAI(api_key="EMPTY", base_url=vllm_server_url)
        self.model_name = model_name
        self.max_tokens = max_tokens
        self.temperature = temperature
        self.max_size = max_size
        self.repetition_penalty = repetition_penalty
        self.message_logger = message_logger
        logger.info(
            f"ImageCaptioner initialized with model_name: {self.model_name} and base_url: {vllm_server_url}"
        )

    def encode_image(self, image_path: str) -> str:
        """Return the image at ``image_path`` as a base64-encoded PNG string.

        The image is converted to RGB and downscaled in place (aspect ratio
        preserved) when it is larger than ``max_size`` on either edge.
        """
        with Image.open(image_path) as img:
            img = img.convert("RGB")

            # thumbnail() only ever shrinks, so small images are left alone.
            original_size = img.size
            img.thumbnail((self.max_size, self.max_size))
            resized_size = img.size

            if original_size != resized_size:
                logger.info(f"Image resized from {original_size} to {resized_size}")

            buffered = io.BytesIO()
            img.save(buffered, format="PNG")
            return base64.b64encode(buffered.getvalue()).decode("utf-8")

    @staticmethod
    def _build_messages(
        system_message: str, prompt: Optional[str], encoded_image: str
    ) -> list:
        """Build the chat message list for one captioning request."""
        return [
            {"role": "system", "content": system_message},
            {
                "role": "user",
                "content": [
                    {
                        "type": "text",
                        "text": prompt or ImageCaptioner.DEFAULT_PROMPT,
                    },
                    {
                        "type": "image_url",
                        "image_url": {"url": f"data:image/png;base64,{encoded_image}"},
                    },
                ],
            },
        ]

    async def caption_image(
        self, image_path: str, system_message: str, prompt: Optional[str] = None
    ) -> Union[VLLMResponse, str]:
        """Request a caption for the image at ``image_path``.

        Args:
            image_path: Path of the image file to caption.
            system_message: Content of the system message.
            prompt: User prompt; a generic "describe" prompt is used if empty.

        Returns:
            The server response converted to ``VLLMResponse``. If the request
            fails, a synthetic response whose single choice has
            ``finish_reason == "error"`` is returned instead of raising.

        Raises:
            Whatever ``encode_image`` raises for an unreadable image; image
            loading happens before the guarded request.
        """
        timestamp_start = time.time()
        # Image decoding and PNG encoding are blocking; run them in a worker
        # thread so other coroutines keep making progress.
        encoded_image = await asyncio.to_thread(self.encode_image, image_path)
        messages = self._build_messages(system_message, prompt, encoded_image)

        try:
            response = await self.client.chat.completions.create(
                model=self.model_name,
                messages=messages,
                max_tokens=self.max_tokens,
                temperature=self.temperature,
                # repetition_penalty is not an OpenAI parameter and is not the
                # same thing as presence_penalty (its neutral value is 1.0,
                # presence_penalty's is 0.0), so it travels as an extra
                # sampling parameter in the request body.
                extra_body={"repetition_penalty": self.repetition_penalty},
            )
            result = self._convert_to_vllm_response(response)
        except Exception as e:
            logger.error(f"Error captioning image: {e}", exc_info=True)
            result = self._create_error_response(str(e))

        timestamp_end = time.time()
        self._log_interaction(
            messages,
            result,
            timestamp_start,
            timestamp_end,
            timestamp_end - timestamp_start,
        )
        return result

    def _convert_to_vllm_response(self, openai_response) -> VLLMResponse:
        """Copy the fields of an OpenAI client response into ``VLLMResponse``."""
        usage = openai_response.usage  # may be absent on some responses
        return VLLMResponse(
            id=openai_response.id,
            object=openai_response.object,
            created=openai_response.created,
            model=openai_response.model,
            choices=[
                {
                    "index": choice.index,
                    "message": {
                        "role": choice.message.role,
                        "content": choice.message.content,
                    },
                    "finish_reason": choice.finish_reason,
                }
                for choice in openai_response.choices
            ],
            usage={
                "prompt_tokens": usage.prompt_tokens if usage else 0,
                "completion_tokens": usage.completion_tokens if usage else 0,
                "total_tokens": usage.total_tokens if usage else 0,
            },
        )

    def _create_error_response(self, error_message: str) -> VLLMResponse:
        """Build a placeholder response that carries ``error_message``."""
        return VLLMResponse(
            id="error",
            object="chat.completion",
            created=int(time.time()),
            model=self.model_name,
            choices=[
                {
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "content": f"Error: {error_message}",
                    },
                    "finish_reason": "error",
                }
            ],
            usage={"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
        )

    def _log_interaction(
        self,
        request,
        response: VLLMResponse,
        timestamp_start: float,
        timestamp_end: float,
        duration: float,
    ):
        """Forward one request/response pair to the message logger, if any.

        Timestamps are truncated to whole seconds; ``duration`` keeps its
        fractional part.
        """
        if self.message_logger:
            interaction = Interaction(
                request=VLLMRequest(messages=request, model=self.model_name),
                response=response,
                timestamp_start=int(timestamp_start),
                timestamp_end=int(timestamp_end),
                duration=duration,
            )
            self.message_logger.log_interaction(interaction)
class VLLMRequestMessage(BaseModel):
    """One chat message in a request sent to the server.

    ``content`` is either plain text or a list of content parts (the image
    captioner sends a text part plus an ``image_url`` part).
    """

    role: str
    # A nullable annotation without a default is still a *required* field in
    # Pydantic v2; the explicit default lets messages omit ``content``.
    content: Union[str, List[Dict[str, Any]], None] = None
    name: Optional[str] = None
    function_call: Optional[Dict[str, Any]] = None
def report_cuda_status() -> None:
    """Print whether CUDA is usable and, if so, which device is current.

    The device queries raise when no CUDA device is available, so they are
    only made after ``torch.cuda.is_available()`` returns true. This keeps
    the module importable on CPU-only hosts.
    """
    cuda_available = torch.cuda.is_available()
    print(f"CUDA available: {cuda_available}")
    if not cuda_available:
        return

    device_index = torch.cuda.current_device()
    print(f"Current device: {device_index}")
    # Report the name of the device that was just printed as current.
    print(f"Device name: {torch.cuda.get_device_name(device_index)}")


report_cuda_status()
class _ColoredFormatter(logging.Formatter):
    """Formatter that prefixes the source location and colors by level."""

    COLORS = {
        "DEBUG": "blue",
        "INFO": "green",
        "WARNING": "yellow",
        "ERROR": "red",
        "CRITICAL": "red",
    }

    def format(self, record):
        """Return the formatted record wrapped in the color for its level."""
        message = super().format(record)
        meta_message = f" ({record.pathname}:{record.lineno}): {message}"
        # Unknown level names fall back to white.
        return colored(meta_message, self.COLORS.get(record.levelname, "white"))


def setup_logger():
    """Return this module's logger with a colored stream handler attached.

    The logger level is set to DEBUG. Calling the function again returns the
    same logger without attaching a second colored handler, so records are
    not emitted more than once.
    """
    logger = logging.getLogger(__name__)
    logger.setLevel(logging.DEBUG)

    # Handlers added by other code are left alone; only our own is deduplicated.
    already_configured = any(
        isinstance(existing.formatter, _ColoredFormatter)
        for existing in logger.handlers
    )
    if already_configured:
        return logger

    handler = logging.StreamHandler()
    handler.setFormatter(
        _ColoredFormatter(
            "%(asctime)s - %(name)s - %(levelname)s - %(message)s",
            datefmt="%Y-%m-%d %H:%M:%S",
        )
    )
    logger.addHandler(handler)

    return logger
class MessageLogger:
    """Collect request/response interactions with inline image data redacted.

    Interactions are kept in memory in ``self.interactions`` and can be
    written to a JSON file with ``export_interactions``.
    """

    # Replaces inline image payloads so logs stay small and free of image data.
    REDACTED_PLACEHOLDER = "[HIDDEN]"

    def __init__(self, output_dir: str):
        """Remember the default export directory and start with no records."""
        self.output_dir = output_dir
        self.interactions: List[Interaction] = []

    @staticmethod
    def _redact_image_urls(messages: list) -> list:
        """Return a deep copy of ``messages`` with image URLs replaced.

        Only content parts of the form ``{"type": "image_url", ...}`` inside
        list-valued ``content`` are touched. The replacement keeps the
        ``{"url": ...}`` shape of the original part. The input is not mutated.
        """
        redacted = deepcopy(messages)
        for message in redacted:
            content = message.get("content")
            if not isinstance(content, list):
                continue  # plain-text or empty content carries no image
            for part in content:
                if isinstance(part, dict) and part.get("type") == "image_url":
                    part["image_url"] = {"url": MessageLogger.REDACTED_PLACEHOLDER}
        return redacted

    def _sanitize_request(self, request: VLLMRequest) -> VLLMRequest:
        """Return a copy of ``request`` whose image parts are redacted."""
        payload = request.model_dump()
        payload["messages"] = self._redact_image_urls(payload.get("messages") or [])
        return VLLMRequest(**payload)

    def log_interaction(self, interaction: Interaction):
        """Store a redacted copy of ``interaction`` and log it at DEBUG."""
        sanitized = Interaction(
            request=self._sanitize_request(interaction.request),
            response=interaction.response,
            timestamp_start=interaction.timestamp_start,
            timestamp_end=interaction.timestamp_end,
            duration=interaction.duration,
        )
        self.interactions.append(sanitized)
        # Serializing the whole interaction is wasted work unless it is shown.
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug(f"Logged interaction: {sanitized.model_dump_json(indent=2)}")

    def export_interactions(self, file_path: str | None = None):
        """Write all stored interactions to ``file_path`` as JSON.

        Args:
            file_path: Destination file. Defaults to
                ``message_interactions.json`` inside ``output_dir``.

        Export is best-effort: any failure is logged and not re-raised.
        """
        if not file_path:
            file_path = f"{self.output_dir}/message_interactions.json"
        try:
            records = []
            for interaction in self.interactions:
                record = interaction.model_dump(by_alias=True)
                # Redact again in case an entry was appended without going
                # through log_interaction().
                record["request"] = self._sanitize_request(
                    interaction.request
                ).model_dump()
                records.append(record)

            parent_dir = os.path.dirname(file_path)
            if parent_dir:
                os.makedirs(parent_dir, exist_ok=True)

            with open(file_path, "w", encoding="utf-8") as f:
                json.dump(records, f, indent=4)
            logger.info(f"Message interactions exported to {file_path}")
        except Exception as e:
            logger.error(f"Error exporting message interactions: {e}", exc_info=True)