Skip to content

Commit

Permalink
(WIP) Add memory and system prompt
Browse files Browse the repository at this point in the history
  • Loading branch information
valentinfrlch committed Feb 17, 2025
1 parent 53a19c8 commit 1bf93dd
Show file tree
Hide file tree
Showing 9 changed files with 401 additions and 37 deletions.
36 changes: 31 additions & 5 deletions custom_components/llmvision/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@
CONF_CUSTOM_OPENAI_API_KEY,
CONF_CUSTOM_OPENAI_DEFAULT_MODEL,
CONF_RETENTION_TIME,
CONF_MEMORY_PATHS,
CONG_MEMORY_IMAGES_ENCODED,
CONF_MEMORY_STRINGS,
CONF_SYSTEM_PROMPT,
CONF_AWS_ACCESS_KEY_ID,
CONF_AWS_SECRET_ACCESS_KEY,
CONF_AWS_REGION_NAME,
Expand All @@ -30,6 +34,7 @@
CONF_OPENWEBUI_DEFAULT_MODEL,
MESSAGE,
REMEMBER,
USE_MEMORY,
MODEL,
PROVIDER,
MAXTOKENS,
Expand All @@ -52,6 +57,7 @@
)
from .calendar import SemanticIndex
from .providers import Request
from .memory import Memory
from .media_handlers import MediaProcessor
import re
from datetime import timedelta
Expand Down Expand Up @@ -89,6 +95,10 @@ async def async_setup_entry(hass, entry):
custom_openai_default_model = entry.data.get(
CONF_CUSTOM_OPENAI_DEFAULT_MODEL)
retention_time = entry.data.get(CONF_RETENTION_TIME)
memory_paths = entry.data.get(CONF_MEMORY_PATHS)
memory_images_encoded = entry.data.get(CONG_MEMORY_IMAGES_ENCODED)
memory_strings = entry.data.get(CONF_MEMORY_STRINGS)
system_prompt = entry.data.get(CONF_SYSTEM_PROMPT)
aws_access_key_id = entry.data.get(CONF_AWS_ACCESS_KEY_ID)
aws_secret_access_key = entry.data.get(CONF_AWS_SECRET_ACCESS_KEY)
aws_region_name = entry.data.get(CONF_AWS_REGION_NAME)
Expand Down Expand Up @@ -123,6 +133,10 @@ async def async_setup_entry(hass, entry):
CONF_CUSTOM_OPENAI_API_KEY: custom_openai_api_key,
CONF_CUSTOM_OPENAI_DEFAULT_MODEL: custom_openai_default_model,
CONF_RETENTION_TIME: retention_time,
CONF_MEMORY_PATHS: memory_paths,
CONG_MEMORY_IMAGES_ENCODED: memory_images_encoded,
CONF_MEMORY_STRINGS: memory_strings,
CONF_SYSTEM_PROMPT: system_prompt,
CONF_AWS_ACCESS_KEY_ID: aws_access_key_id,
CONF_AWS_SECRET_ACCESS_KEY: aws_secret_access_key,
CONF_AWS_REGION_NAME: aws_region_name,
Expand Down Expand Up @@ -153,7 +167,6 @@ async def async_remove_entry(hass, entry):
"""Remove config entry from hass.data"""
# Use the entry_id from the config entry as the UID
entry_uid = entry.entry_id

if entry_uid in hass.data[DOMAIN]:
# Remove the entry from hass.data
_LOGGER.info(f"Removing {entry.title} from hass.data")
Expand All @@ -162,20 +175,17 @@ async def async_remove_entry(hass, entry):
else:
_LOGGER.warning(
f"Entry {entry.title} not found but was requested to be removed")

return True


async def async_unload_entry(hass, entry) -> bool:
    """Unload a config entry.

    An entry that stores a retention time is treated as the calendar entry,
    so its "calendar" platform is unloaded. Any other entry has nothing to
    unload and reports success immediately.

    Returns:
        True when unloading succeeded (or nothing had to be unloaded),
        otherwise the result reported by ``async_unload_platforms``.
    """
    _LOGGER.debug(f"Unloading {entry.title} from hass.data")

    # Entries without a retention time are not the calendar entry.
    if entry.data.get(CONF_RETENTION_TIME) is None:
        return True

    return await hass.config_entries.async_unload_platforms(entry, ["calendar"])


Expand Down Expand Up @@ -291,6 +301,7 @@ def __init__(self, data_call):
MODEL))
self.message = str(data_call.data.get(MESSAGE, "")[0:2000])
self.remember = data_call.data.get(REMEMBER, False)
self.use_memory = data_call.data.get(USE_MEMORY, False)
self.image_paths = data_call.data.get(IMAGE_FILE, "").split(
"\n") if data_call.data.get(IMAGE_FILE) else None
self.image_entities = data_call.data.get(IMAGE_ENTITY)
Expand Down Expand Up @@ -345,7 +356,6 @@ async def image_analyzer(data_call):
max_tokens=call.max_tokens,
temperature=call.temperature,
)

# Fetch and preprocess images
processor = MediaProcessor(hass, request)
# Send images to RequestHandler client
Expand All @@ -357,6 +367,10 @@ async def image_analyzer(data_call):
expose_images_persist=call.expose_images_persist
)

if call.use_memory:
call.memory = Memory(hass)
await call.memory._update_memory()

# Validate configuration, input data and make the call
response = await request.call(call)
# Add processor.key_frame to response if it exists
Expand Down Expand Up @@ -392,6 +406,10 @@ async def video_analyzer(data_call):
frigate_retry_attempts=call.frigate_retry_attempts,
frigate_retry_seconds=call.frigate_retry_seconds
)
if call.use_memory:
call.memory = Memory(hass)
await call.memory._update_memory()

response = await request.call(call)
# Add processor.key_frame to response if it exists
if processor.key_frame:
Expand Down Expand Up @@ -425,6 +443,10 @@ async def stream_analyzer(data_call):
expose_images_persist=call.expose_images_persist
)

if call.use_memory:
call.memory = Memory(hass)
await call.memory._update_memory()

response = await request.call(call)
# Add processor.key_frame to response if it exists
if processor.key_frame:
Expand Down Expand Up @@ -481,6 +503,10 @@ async def data_analyzer(data_call):
target_width=call.target_width,
include_filename=call.include_filename
)
if call.use_memory:
call.memory = Memory(hass)
await call.memory._update_memory()

response = await request.call(call)
_LOGGER.info(f"Response: {response}")
_LOGGER.info(f"Sensor type: {type}")
Expand Down
62 changes: 60 additions & 2 deletions custom_components/llmvision/config_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@
CONF_CUSTOM_OPENAI_ENDPOINT,
CONF_CUSTOM_OPENAI_DEFAULT_MODEL,
CONF_RETENTION_TIME,
CONF_MEMORY_PATHS,
CONF_MEMORY_STRINGS,
CONF_SYSTEM_PROMPT,
CONF_AWS_ACCESS_KEY_ID,
CONF_AWS_SECRET_ACCESS_KEY,
CONF_AWS_REGION_NAME,
Expand All @@ -56,11 +59,12 @@ class llmvisionConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):

async def handle_provider(self, provider):
provider_steps = {
"Event Calendar": self.async_step_semantic_index,
"Memory": self.async_step_memory,
"Anthropic": self.async_step_anthropic,
"AWS Bedrock": self.async_step_aws_bedrock,
"Azure": self.async_step_azure,
"Custom OpenAI": self.async_step_custom_openai,
"Event Calendar": self.async_step_semantic_index,
"Google": self.async_step_google,
"Groq": self.async_step_groq,
"LocalAI": self.async_step_localai,
Expand All @@ -81,7 +85,7 @@ async def async_step_user(self, user_input=None):
vol.Required("provider", default="Event Calendar"): selector({
"select": {
# Azure removed until fixed
"options": ["Anthropic", "AWS Bedrock", "Google", "Groq", "LocalAI", "Ollama", "OpenAI", "OpenWebUI", "Custom OpenAI", "Event Calendar"],
"options": ["Event Calendar", "Memory", "Anthropic", "AWS Bedrock", "Google", "Groq", "LocalAI", "Ollama", "OpenAI", "OpenWebUI", "Custom OpenAI"],
"mode": "dropdown",
"sort": False,
"custom_value": False
Expand Down Expand Up @@ -554,6 +558,60 @@ async def async_step_semantic_index(self, user_input=None):
data_schema=data_schema,
)

async def async_step_memory(self, user_input=None):
    """Config flow step that creates or reconfigures the "Memory" entry.

    The form collects reference image paths, the matching text labels and
    a system prompt. Without ``user_input`` the form is shown; with it the
    entry is either updated (reconfigure) or created (new entry).

    Only one Memory entry may exist: creating a second one aborts the flow
    with reason ``already_configured``.
    """
    data_schema = vol.Schema({
        vol.Optional(CONF_MEMORY_PATHS, default="/config/llmvision/memory/example.jpg"): selector({
            "text": {
                "multiline": False,
                "multiple": True
            }
        }),
        vol.Optional(CONF_MEMORY_STRINGS, default="Alice"): selector({
            "text": {
                "multiline": False,
                "multiple": True
            }
        }),
        vol.Optional(CONF_SYSTEM_PROMPT, default="You are a helpful AI assistant."): selector({
            "text": {
                "multiline": True,
                "multiple": False
            }
        }),
    })

    is_reconfigure = self.source == config_entries.SOURCE_RECONFIGURE
    if is_reconfigure:
        # load existing configuration and add it to the dialog
        self.init_info = self._get_reconfigure_entry().data
        data_schema = self.add_suggested_values_to_schema(
            data_schema, self.init_info
        )

    if user_input is None:
        return self.async_show_form(
            step_id="memory",
            data_schema=data_schema,
        )

    user_input["provider"] = self.init_info["provider"]

    if is_reconfigure:
        # We're reconfiguring an existing config. The duplicate check below
        # must not run here: the entry being edited would match itself.
        return self.async_update_reload_and_abort(
            self._get_reconfigure_entry(),
            data_updates=user_input,
        )

    # New config entry: refuse to create a second Memory entry. The abort
    # result has to be returned, otherwise the flow simply carries on.
    for existing in self.hass.config_entries.async_entries(DOMAIN):
        if existing.data.get("provider") == "Memory":
            return self.async_abort(reason="already_configured")

    return self.async_create_entry(title="LLM Vision Memory", data=user_input)

async def async_step_aws_bedrock(self, user_input=None):
data_schema = vol.Schema({
vol.Required(CONF_AWS_REGION_NAME, default="us-east-1"): str,
Expand Down
5 changes: 5 additions & 0 deletions custom_components/llmvision/const.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@
CONF_CUSTOM_OPENAI_API_KEY = 'custom_openai_api_key'
CONF_CUSTOM_OPENAI_DEFAULT_MODEL = 'custom_openai_default_model'
CONF_RETENTION_TIME = 'retention_time'
# Memory entry: paths of the reference images on disk
CONF_MEMORY_PATHS = 'memory_paths'
# Memory entry: base64-encoded versions of the reference images.
# NOTE(review): "CONG_" looks like a typo for "CONF_". Renaming it requires
# updating every module that imports this name at the same time.
CONG_MEMORY_IMAGES_ENCODED = 'memory_images_encoded'
# Memory entry: text labels used to tag the reference images
CONF_MEMORY_STRINGS = 'memory_strings'
# Memory entry: system prompt text
CONF_SYSTEM_PROMPT = 'system_prompt'
CONF_AWS_ACCESS_KEY_ID = 'aws_access_key_id'
CONF_AWS_SECRET_ACCESS_KEY = 'aws_secret_access_key'
CONF_AWS_REGION_NAME = 'aws_region_name'
Expand All @@ -35,6 +39,7 @@
# service call constants
MESSAGE = 'message'
REMEMBER = 'remember'
USE_MEMORY = 'use_memory'
PROVIDER = 'provider'
MAXTOKENS = 'max_tokens'
TARGET_WIDTH = 'target_width'
Expand Down
135 changes: 135 additions & 0 deletions custom_components/llmvision/memory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
from .const import (
DOMAIN,
CONF_MEMORY_PATHS,
CONG_MEMORY_IMAGES_ENCODED,
CONF_MEMORY_STRINGS,
CONF_SYSTEM_PROMPT
)
import base64
import io
from PIL import Image
import logging

_LOGGER = logging.getLogger(__name__)


class Memory:
    """Reference images, their labels and the system prompt for LLM requests.

    The data is read from the config entry whose ``provider`` is "Memory".
    When no such entry exists, an error is logged and the instance behaves
    as an empty memory (no images, no labels, empty system prompt) instead
    of raising while being constructed.
    """

    # Instruction placed in front of the reference images for every provider.
    _REFERENCE_INTRO = "The following images along with descriptions serve as reference. They are not to be mentioned in the response."
    # Provider formats understood by _get_memory_images.
    _SUPPORTED_TYPES = ("OpenAI", "Anthropic", "Google", "AWS")
    # Longest edge, in pixels, of an encoded reference image.
    _MAX_EDGE = 512

    def __init__(self, hass):
        """Load the memory configuration from the Memory config entry."""
        self.hass = hass
        self.entry = self._find_memory_entry()
        # Fall back to an empty mapping so a missing entry yields empty memory.
        data = self.entry.data if self.entry is not None else {}
        self.system_prompt = data.get(CONF_SYSTEM_PROMPT) or ""
        self.memory_strings = data.get(CONF_MEMORY_STRINGS) or []
        self.memory_paths = data.get(CONF_MEMORY_PATHS) or []
        self.memory_images = data.get(CONG_MEMORY_IMAGES_ENCODED) or []

    def get_memory_strings(self):
        """Return the text labels that belong to the reference images."""
        return self.memory_strings

    def _get_memory_images(self, type="OpenAI"):
        """Build the provider-specific message content for the reference images.

        Labels and images are paired by position. If the two lists differ in
        length, the surplus items of the longer list are ignored.

        Args:
            type: Provider format, one of "OpenAI", "Anthropic", "Google"
                or "AWS".

        Returns:
            A list of content parts (intro text followed by a label and an
            image part per reference image), or None for an unknown type.
        """
        if type not in self._SUPPORTED_TYPES:
            return None

        content = [{"type": "text", "text": self._REFERENCE_INTRO}]
        # Pair by position rather than list.index(): index() is quadratic and
        # returns the first match, mislabelling identical encoded images.
        for tag, image in zip(self.memory_strings, self.memory_images):
            label = tag + ":"
            if type == "AWS":
                content.append({"text": label})
                content.append({"image": {
                    "format": "jpeg", "source": {"bytes": image}}})
                continue

            content.append({"type": "text", "text": label})
            if type == "OpenAI":
                content.append({"type": "image_url", "image_url": {
                    "url": f"data:image/jpeg;base64,{image}"}})
            elif type == "Anthropic":
                content.append({"type": "image", "source": {
                    "type": "base64", "media_type": "image/jpeg", "data": f"{image}"}})
            else:  # "Google"
                content.append({"type": "image", "source": {
                    "type": "base64", "data": f"{image}"}})

        return content

    def get_system_prompt(self):
        """Return the configured system prompt prefixed with "System prompt: "."""
        return "System prompt: " + self.system_prompt

    def _find_memory_entry(self):
        """Return the config entry whose provider is "Memory", or None.

        An error is logged when no such entry has been set up.
        """
        for entry in self.hass.config_entries.async_entries(DOMAIN):
            # .get() so an entry without a "provider" key is skipped
            # instead of raising KeyError.
            if entry.data.get("provider") == "Memory":
                return entry

        _LOGGER.error("Memory entry not set up")
        return None

    @staticmethod
    def _encode_image(image_path):
        """Open, downscale and JPEG-encode one image; return it as base64 text.

        This does blocking file and CPU work and is meant to run in an executor.
        """
        with Image.open(image_path) as img:
            # Scale so the longer edge becomes _MAX_EDGE, keeping aspect ratio.
            width, height = img.size
            aspect_ratio = width / height
            if aspect_ratio > 1:
                new_width = Memory._MAX_EDGE
                new_height = int(Memory._MAX_EDGE / aspect_ratio)
            else:
                new_height = Memory._MAX_EDGE
                new_width = int(Memory._MAX_EDGE * aspect_ratio)
            # Extreme aspect ratios would otherwise round one side down to 0.
            resized = img.resize((max(1, new_width), max(1, new_height)))

        # JPEG cannot store alpha or palette images; convert those first.
        if resized.mode not in ("RGB", "L"):
            resized = resized.convert("RGB")

        buffer = io.BytesIO()
        resized.save(buffer, format='JPEG')
        return base64.b64encode(buffer.getvalue()).decode('utf-8')

    async def _encode_images(self, image_paths):
        """Encode images as base64.

        Each image is processed in the executor so that decoding, resizing
        and encoding do not block the event loop. Errors raised while
        reading an image propagate to the caller; skipping a failed image
        would shift the position-based pairing with the labels.
        """
        encoded_images = []
        for image_path in image_paths:
            encoded = await self.hass.loop.run_in_executor(
                None, self._encode_image, image_path)
            encoded_images.append(encoded)
        return encoded_images

    async def _update_memory(self):
        """Encode the reference images when needed and persist the result.

        Re-encoding is triggered when the number of configured paths differs
        from the number of stored encoded images. Replacing a path while
        keeping the count the same is therefore not detected.
        """
        if self.entry is None:
            # Nothing to encode or persist without a Memory entry.
            return

        if len(self.memory_paths) != len(self.memory_images):
            self.memory_images = await self._encode_images(self.memory_paths)

            # Persist under the same key __init__ reads from, so the encoded
            # images are reused on the next call instead of being rebuilt.
            memory = dict(self.entry.data)
            memory[CONG_MEMORY_IMAGES_ENCODED] = self.memory_images
            self.hass.config_entries.async_update_entry(
                self.entry, data=memory)

    def __str__(self):
        return f"Memory:({self.memory_strings}, {self.memory_paths})"
Loading

0 comments on commit 1bf93dd

Please sign in to comment.