Skip to content

Commit

Permalink
Add the ability to provide an additional description for each image
Browse files Browse the repository at this point in the history
  • Loading branch information
Smiley73 committed Feb 8, 2025
1 parent aafe77b commit fe034b5
Show file tree
Hide file tree
Showing 5 changed files with 63 additions and 21 deletions.
6 changes: 5 additions & 1 deletion custom_components/llmvision/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
TARGET_WIDTH,
IMAGE_FILE,
IMAGE_ENTITY,
IMAGE_DESCRIPTIONS,
VIDEO_FILE,
EVENT_ID,
FRIGATE_RETRY_ATTEMPTS,
Expand Down Expand Up @@ -295,13 +296,15 @@ def __init__(self, data_call):
self.max_tokens = int(data_call.data.get(MAXTOKENS, 100))
self.include_filename = data_call.data.get(INCLUDE_FILENAME, False)
self.expose_images = data_call.data.get(EXPOSE_IMAGES, False)
self.image_descriptions = data_call.data.get(IMAGE_DESCRIPTIONS, {})
self.expose_images_persist = data_call.data.get(
EXPOSE_IMAGES_PERSIST, False)
self.generate_title = data_call.data.get(GENERATE_TITLE, False)
self.sensor_entity = data_call.data.get(SENSOR_ENTITY)
# ------------ Added during call ------------
# self.base64_images : List[str] = []
# self.filenames : List[str] = []
# self.descriptions : List[Optional[str]] = []  (an entry is None when that image has no description)

def get_service_call_data(self):
    """Return this object itself.

    The parsed service-call parameters are stored as attributes on the
    instance, so callers read them directly from the returned object.
    """
    return self
Expand Down Expand Up @@ -329,7 +332,8 @@ async def image_analyzer(data_call):
target_width=call.target_width,
include_filename=call.include_filename,
expose_images=call.expose_images,
expose_images_persist=call.expose_images_persist
expose_images_persist=call.expose_images_persist,
image_descriptions=call.image_descriptions
)

# Validate configuration, input data and make the call
Expand Down
1 change: 1 addition & 0 deletions custom_components/llmvision/const.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
EXPOSE_IMAGES_PERSIST = 'expose_images_persist'
GENERATE_TITLE = 'generate_title'
SENSOR_ENTITY = 'sensor_entity'
IMAGE_DESCRIPTIONS = 'image_descriptions'

# Error messages
ERROR_NOT_CONFIGURED = "{provider} is not configured"
Expand Down
46 changes: 33 additions & 13 deletions custom_components/llmvision/media_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import logging
import time
import asyncio
from pathlib import Path
from homeassistant.helpers.aiohttp_client import async_get_clientsession
from functools import partial
from PIL import Image, UnidentifiedImageError
Expand All @@ -27,7 +28,7 @@ def __init__(self, hass, client):
self.filenames = []
self.path = "/config/www/llmvision"
self.key_frame = ""

self.descriptions = []

async def _encode_image(self, img):
"""Encode image as base64"""
Expand Down Expand Up @@ -288,10 +289,21 @@ async def record_camera(image_entity, camera_number):
filename=frame_name
)

async def add_images(self, image_entities, image_paths, target_width, include_filename, expose_images, expose_images_persist):
async def add_images(self, image_entities, image_paths, target_width, include_filename, expose_images, expose_images_persist, image_descriptions):
"""Wrapper for client.add_frame for images"""
if image_entities:
for image_entity in image_entities:
friendly_name = self.hass.states.get(image_entity).attributes.get('friendly_name')
# Look up the description for this image when descriptions were provided (keyed by the entity's friendly name)
description = None
if len(image_descriptions) > 0:
if friendly_name in image_descriptions:
description = image_descriptions.get(friendly_name)
else:
raise ServiceValidationError(
f"Image descriptions are defined, but are missing the one for '{friendly_name}'")
_LOGGER.debug(f"friendly_name={friendly_name} description={description}")

try:
base_url = get_url(self.hass)
image_url = base_url + \
Expand All @@ -309,8 +321,8 @@ async def add_images(self, image_entities, image_paths, target_width, include_fi
resized_image = await self.resize_image(target_width=target_width, image_data=image_data)
self.client.add_frame(
base64_image=resized_image,
filename=self.hass.states.get(
image_entity).attributes.get('friendly_name') if include_filename else ""
filename=friendly_name if include_filename else "",
description=description
)

if expose_images:
Expand All @@ -324,17 +336,24 @@ async def add_images(self, image_entities, image_paths, target_width, include_fi
f"Entity {image_entity} does not exist")
if image_paths:
for image_path in image_paths:
image_path = image_path.strip()
filename=Path(image_path).stem
# Look up the description for this image when descriptions were provided (keyed by the file name without its extension)
description = None
if len(image_descriptions) > 0:
if filename in image_descriptions:
description = image_descriptions.get(filename)
else:
raise ServiceValidationError(
f"Image descriptions are defined, but are missing the one for '{filename}'")
_LOGGER.debug(f"filename={filename} description={description}")

try:
image_path = image_path.strip()
if include_filename and os.path.exists(image_path):
self.client.add_frame(
base64_image=await self.resize_image(target_width=target_width, image_path=image_path),
filename=image_path.split('/')[-1].split('.')[-2]
)
elif os.path.exists(image_path):
if os.path.exists(image_path):
self.client.add_frame(
base64_image=await self.resize_image(target_width=target_width, image_path=image_path),
filename=""
filename=filename if include_filename else "",
description=description
)
if not os.path.exists(image_path):
raise ServiceValidationError(
Expand Down Expand Up @@ -511,6 +530,7 @@ async def add_visual_data(self, image_entities, image_paths, target_width, inclu
image_paths=image_paths,
target_width=target_width,
include_filename=include_filename,
expose_images=False
expose_images=False,
image_descriptions=None
)
return self.client
25 changes: 18 additions & 7 deletions custom_components/llmvision/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ def __init__(self, hass, message, max_tokens, temperature):
self.temperature = temperature
self.base64_images = []
self.filenames = []
self.descriptions = []

@staticmethod
def sanitize_data(data):
Expand Down Expand Up @@ -133,6 +134,7 @@ async def call(self, call):
provider = Request.get_provider(self.hass, entry_id)
call.base64_images = self.base64_images
call.filenames = self.filenames
call.descriptions = self.descriptions

self.validate(call)

Expand Down Expand Up @@ -224,9 +226,10 @@ async def call(self, call):
else:
return {"response_text": response_text}

def add_frame(self, base64_image, filename):
def add_frame(self, base64_image, filename, description=None):
self.base64_images.append(base64_image)
self.filenames.append(filename)
self.descriptions.append(description)

async def _resolve_error(self, response, provider):
"""Translate response status to error message"""
Expand Down Expand Up @@ -369,9 +372,10 @@ def _prepare_vision_data(self, call) -> list:
"temperature": call.temperature
}

for image, filename in zip(call.base64_images, call.filenames):
for image, filename, description in zip(call.base64_images, call.filenames, call.descriptions):
tag = ("Image " + str(call.base64_images.index(image) + 1)
) if filename == "" else filename
tag += (f"\n{description}") if description is not None else ""
payload["messages"][0]["content"].append(
{"type": "text", "text": tag + ":"})
payload["messages"][0]["content"].append({"type": "image_url", "image_url": {
Expand Down Expand Up @@ -430,9 +434,10 @@ def _prepare_vision_data(self, call) -> list:
"temperature": call.temperature,
"stream": False
}
for image, filename in zip(call.base64_images, call.filenames):
for image, filename, description in zip(call.base64_images, call.filenames, call.descriptions):
tag = ("Image " + str(call.base64_images.index(image) + 1)
) if filename == "" else filename
tag += (f"\n{description}") if description is not None else ""
payload["messages"][0]["content"].append(
{"type": "text", "text": tag + ":"})
payload["messages"][0]["content"].append({"type": "image_url", "image_url": {
Expand Down Expand Up @@ -491,9 +496,10 @@ def _prepare_vision_data(self, call) -> dict:
"max_tokens": call.max_tokens,
"temperature": call.temperature
}
for image, filename in zip(call.base64_images, call.filenames):
for image, filename, description in zip(call.base64_images, call.filenames, call.descriptions):
tag = ("Image " + str(call.base64_images.index(image) + 1)
) if filename == "" else filename
tag += (f"\n{description}") if description is not None else ""
data["messages"][0]["content"].append(
{"type": "text", "text": tag + ":"})
data["messages"][0]["content"].append({"type": "image", "source": {
Expand Down Expand Up @@ -547,9 +553,10 @@ async def _make_request(self, data) -> str:
def _prepare_vision_data(self, call) -> dict:
data = {"contents": [], "generationConfig": {
"maxOutputTokens": call.max_tokens, "temperature": call.temperature}}
for image, filename in zip(call.base64_images, call.filenames):
for image, filename, description in zip(call.base64_images, call.filenames, call.descriptions):
tag = ("Image " + str(call.base64_images.index(image) + 1)
) if filename == "" else filename
tag += (f"\n{description}") if description is not None else ""
data["contents"].append({"role": "user", "parts": [
{"text": tag + ":"}, {"inline_data": {"mime_type": "image/jpeg", "data": image}}]})
data["contents"].append(
Expand Down Expand Up @@ -654,9 +661,10 @@ async def _make_request(self, data) -> str:
def _prepare_vision_data(self, call) -> dict:
data = {"model": call.model, "messages": [{"role": "user", "content": [
]}], "max_tokens": call.max_tokens, "temperature": call.temperature}
for image, filename in zip(call.base64_images, call.filenames):
for image, filename, description in zip(call.base64_images, call.filenames, call.descriptions):
tag = ("Image " + str(call.base64_images.index(image) + 1)
) if filename == "" else filename
tag += (f"\n{description}") if description is not None else ""
data["messages"][0]["content"].append(
{"type": "text", "text": tag + ":"})
data["messages"][0]["content"].append(
Expand Down Expand Up @@ -704,6 +712,7 @@ async def _make_request(self, data) -> str:
port=port,
protocol=protocol
)
_LOGGER.debug(f"Calling Ollama url:{endpoint} data:{Request.sanitize_data(data)}")

response = await self._post(url=endpoint, headers={}, data=data)
response_text = response.get("message").get("content")
Expand All @@ -712,9 +721,10 @@ async def _make_request(self, data) -> str:
def _prepare_vision_data(self, call) -> dict:
data = {"model": call.model, "messages": [], "stream": False, "options": {
"num_predict": call.max_tokens, "temperature": call.temperature}}
for image, filename in zip(call.base64_images, call.filenames):
for image, filename, description in zip(call.base64_images, call.filenames, call.descriptions):
tag = ("Image " + str(call.base64_images.index(image) + 1)
) if filename == "" else filename
tag += (f"\n{description}") if description is not None else ""
image_message = {"role": "user",
"content": tag + ":", "images": [image]}
data["messages"].append(image_message)
Expand Down Expand Up @@ -829,6 +839,7 @@ def _prepare_vision_data(self, call) -> list:
for image, filename in zip(call.base64_images, call.filenames):
tag = ("Image " + str(call.base64_images.index(image) + 1)
) if filename == "" else filename
tag += (f"\n{description}") if description is not None else ""
data["messages"][0]["content"].append(
{"text": tag + ":"})
data["messages"][0]["content"].append({
Expand Down
6 changes: 6 additions & 0 deletions custom_components/llmvision/services.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,12 @@ image_analyzer:
entity:
domain: ["image", "camera"]
multiple: true
image_descriptions:
name: Image Descriptions
required: false
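description: 'Dict with an extra description for each image. Each key is either the image file name without its extension or the friendly name of the image/camera entity; each value is the description text sent along with that image.'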
selector:
object:
include_filename:
name: Include Filename
required: true
Expand Down

0 comments on commit fe034b5

Please sign in to comment.