Skip to content

Commit

Permalink
genAI titles
Browse files Browse the repository at this point in the history
  • Loading branch information
valentinfrlch committed Nov 13, 2024
1 parent e499200 commit bdba949
Show file tree
Hide file tree
Showing 5 changed files with 155 additions and 154 deletions.
156 changes: 67 additions & 89 deletions custom_components/llmvision/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,14 @@
TEMPERATURE,
INCLUDE_FILENAME,
EXPOSE_IMAGES,
GENERATE_TITLE,
SENSOR_ENTITY,
)
from .calendar import SemanticIndex
from datetime import timedelta
from homeassistant.util import dt as dt_util
from homeassistant.config_entries import ConfigEntry
from .providers import RequestHandler
from .providers import Request
from .media_handlers import MediaProcessor
from homeassistant.core import SupportsResponse
from homeassistant.exceptions import ServiceValidationError
Expand Down Expand Up @@ -146,35 +147,8 @@ async def _remember(hass, call, start, response):
f"'Event Calendar' config not found")

semantic_index = SemanticIndex(hass, config_entry)
# Define a mapping of keywords to labels
keyword_to_label = {
"person": "Person",
"man": "Person",
"woman": "Person",
"individual": "Person",
"delivery": "Delivery",
"courier": "Courier",
"package": "Package",
"car": "Car",
"vehicle": "Car",
"bike": "Bike",
"bicycle": "Bike",
"bus": "Bus",
"truck": "Truck",
"motorcycle": "Motorcycle",
"bicycle": "Bicycle",
"dog": "Dog",
"cat": "Cat",
}

# Default label
label = "Unknown object"

# Check each keyword in the response text and update the label accordingly
for keyword, mapped_label in keyword_to_label.items():
if keyword in response["response_text"].lower():
label = mapped_label
break

title = response.get("title", "Unknown object seen")

if call.image_entities and len(call.image_entities) > 0:
camera_name = call.image_entities[0]
Expand All @@ -189,7 +163,7 @@ async def _remember(hass, call, start, response):
await semantic_index.remember(
start=start,
end=dt_util.now() + timedelta(minutes=1),
label=label + " seen",
label=title,
camera_name=camera_name,
summary=response["response_text"]
)
Expand All @@ -198,7 +172,8 @@ async def _remember(hass, call, start, response):
async def _update_sensor(hass, sensor_entity, new_value):
"""Update the value of a sensor entity."""
if sensor_entity:
_LOGGER.info(f"Updating sensor {sensor_entity} with new value: {new_value}")
_LOGGER.info(
f"Updating sensor {sensor_entity} with new value: {new_value}")
try:
hass.states.async_set(sensor_entity, new_value)
except Exception as e:
Expand Down Expand Up @@ -232,6 +207,7 @@ def __init__(self, data_call):
self.max_tokens = int(data_call.data.get(MAXTOKENS, 100))
self.include_filename = data_call.data.get(INCLUDE_FILENAME, False)
self.expose_images = data_call.data.get(EXPOSE_IMAGES, False)
self.generate_title = data_call.data.get(GENERATE_TITLE, False)
self.sensor_entity = data_call.data.get(SENSOR_ENTITY)

def get_service_call_data(self):
Expand All @@ -246,24 +222,24 @@ async def image_analyzer(data_call):
# Initialize call object with service call data
call = ServiceCallData(data_call).get_service_call_data()
# Initialize the RequestHandler client
client = RequestHandler(hass,
message=call.message,
max_tokens=call.max_tokens,
temperature=call.temperature,
)
request = Request(hass=hass,
message=call.message,
max_tokens=call.max_tokens,
temperature=call.temperature,
)

# Fetch and preprocess images
processor = MediaProcessor(hass, client)
processor = MediaProcessor(hass, request)
# Send images to RequestHandler client
client = await processor.add_images(image_entities=call.image_entities,
image_paths=call.image_paths,
target_width=call.target_width,
include_filename=call.include_filename,
expose_images=call.expose_images
)
request = await processor.add_images(image_entities=call.image_entities,
image_paths=call.image_paths,
target_width=call.target_width,
include_filename=call.include_filename,
expose_images=call.expose_images
)

# Validate configuration, input data and make the call
response = await client.forward_request(call)
response = await request.call(call)
await _remember(hass, call, start, response)
return response

Expand All @@ -272,20 +248,20 @@ async def video_analyzer(data_call):
start = dt_util.now()
call = ServiceCallData(data_call).get_service_call_data()
call.message = "The attached images are frames from a video. " + call.message
client = RequestHandler(hass,
message=call.message,
max_tokens=call.max_tokens,
temperature=call.temperature,
)
processor = MediaProcessor(hass, client)
client = await processor.add_videos(video_paths=call.video_paths,
event_ids=call.event_id,
max_frames=call.max_frames,
target_width=call.target_width,
include_filename=call.include_filename,
expose_images=call.expose_images
)
response = await client.forward_request(call)
request = Request(hass,
message=call.message,
max_tokens=call.max_tokens,
temperature=call.temperature,
)
processor = MediaProcessor(hass, request)
request = await processor.add_videos(video_paths=call.video_paths,
event_ids=call.event_id,
max_frames=call.max_frames,
target_width=call.target_width,
include_filename=call.include_filename,
expose_images=call.expose_images
)
response = await request.call(call)
await _remember(hass, call, start, response)
return response

Expand All @@ -294,21 +270,21 @@ async def stream_analyzer(data_call):
start = dt_util.now()
call = ServiceCallData(data_call).get_service_call_data()
call.message = "The attached images are frames from a live camera feed. " + call.message
client = RequestHandler(hass,
message=call.message,
max_tokens=call.max_tokens,
temperature=call.temperature,
)
processor = MediaProcessor(hass, client)
client = await processor.add_streams(image_entities=call.image_entities,
duration=call.duration,
max_frames=call.max_frames,
target_width=call.target_width,
include_filename=call.include_filename,
expose_images=call.expose_images
)

response = await client.forward_request(call)
request = Request(hass,
message=call.message,
max_tokens=call.max_tokens,
temperature=call.temperature,
)
processor = MediaProcessor(hass, request)
request = await processor.add_streams(image_entities=call.image_entities,
duration=call.duration,
max_frames=call.max_frames,
target_width=call.target_width,
include_filename=call.include_filename,
expose_images=call.expose_images
)

response = await request.call(call)
await _remember(hass, call, start, response)
return response

Expand All @@ -321,12 +297,12 @@ def is_number(s):
return True
except ValueError:
return False

start = dt_util.now()
call = ServiceCallData(data_call).get_service_call_data()
sensor_entity = data_call.data.get("sensor_entity")
_LOGGER.info(f"Sensor entity: {sensor_entity}")

# get current value to determine data type
state = hass.states.get(sensor_entity).state
_LOGGER.info(f"Current state: {state}")
Expand All @@ -338,24 +314,26 @@ def is_number(s):
data_type = "number"
else:
if "options" in hass.states.get(sensor_entity).attributes:
data_type = "one of these options: " + ", ".join([f"'{option}'" for option in hass.states.get(sensor_entity).attributes["options"]])
data_type = "one of these options: " + \
", ".join([f"'{option}'" for option in hass.states.get(
sensor_entity).attributes["options"]])
else:
data_type = "string"

message = f"Your job is to extract data from images. Return a {data_type} only. No additional text or other options allowed!. If unsure, choose the option that best matches. Follow these instructions: " + call.message
_LOGGER.info(f"Message: {message}")
client = RequestHandler(hass,
message=message,
max_tokens=call.max_tokens,
temperature=call.temperature,
)
processor = MediaProcessor(hass, client)
client = await processor.add_visual_data(image_entities=call.image_entities,
image_paths=call.image_paths,
target_width=call.target_width,
include_filename=call.include_filename
)
response = await client.forward_request(call)
request = Request(hass,
message=message,
max_tokens=call.max_tokens,
temperature=call.temperature,
)
processor = MediaProcessor(hass, request)
request = await processor.add_visual_data(image_entities=call.image_entities,
image_paths=call.image_paths,
target_width=call.target_width,
include_filename=call.include_filename
)
response = await request.call(call)
_LOGGER.info(f"Response: {response}")
# udpate sensor in data_call.data.get("sensor_entity")
await _update_sensor(hass, sensor_entity, response["response_text"])
Expand Down
1 change: 1 addition & 0 deletions custom_components/llmvision/const.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
TEMPERATURE = 'temperature'
INCLUDE_FILENAME = 'include_filename'
EXPOSE_IMAGES = 'expose_images'
GENERATE_TITLE = 'generate_title'
SENSOR_ENTITY = 'sensor_entity'

# Error messages
Expand Down
2 changes: 1 addition & 1 deletion custom_components/llmvision/manifest.json
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@
"domain": "llmvision",
"name": "LLM Vision",
"codeowners": ["@valentinfrlch"],
"requirements": ["openai==1.54.3"],
"config_flow": true,
"documentation": "https://github.com/valentinfrlch/ha-llmvision",
"iot_class": "cloud_polling",
"issue_tracker": "https://github.com/valentinfrlch/ha-llmvision/issues",
"requirements": ["openai==1.54.3"],
"version": "1.3.2"
}
Loading

0 comments on commit bdba949

Please sign in to comment.