diff --git a/src/airunner/alembic/env.py b/src/airunner/alembic/env.py index 4995a7f97..2acccf138 100644 --- a/src/airunner/alembic/env.py +++ b/src/airunner/alembic/env.py @@ -4,14 +4,15 @@ from logging.config import fileConfig from sqlalchemy import engine_from_config, pool, MetaData from alembic import context +from airunner.settings import DB_PATH + config = context.config -db_path = os.path.expanduser("~/.local/share/airunner/data/airunner.db") -config.set_main_option("sqlalchemy.url", f"sqlite:///{db_path}") +config.set_main_option("sqlalchemy.url", f"sqlite:///{DB_PATH}") # check if db file exists -if not os.path.exists(db_path): - print(f"Database file not found at {db_path}") +if not os.path.exists(DB_PATH): + print(f"Database file not found at {DB_PATH}") # Import your models here from airunner.data.models.settings_models import ( diff --git a/src/airunner/alembic/versions/6579bf48ed83_add_chat_store_key_column_to_.py b/src/airunner/alembic/versions/6579bf48ed83_add_chat_store_key_column_to_.py new file mode 100644 index 000000000..b0e8fd7e7 --- /dev/null +++ b/src/airunner/alembic/versions/6579bf48ed83_add_chat_store_key_column_to_.py @@ -0,0 +1,32 @@ +"""add chat_store_key column to Conversation table + +Revision ID: 6579bf48ed83 +Revises: c2c5d4cd4b80 +Create Date: 2025-02-14 02:05:31.854218 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import sqlite + +# revision identifiers, used by Alembic. +revision: str = '6579bf48ed83' +down_revision: Union[str, None] = 'c2c5d4cd4b80' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.add_column('conversations', sa.Column('key', sa.String(), nullable=True)) + op.add_column('conversations', sa.Column('value', sa.JSON(), nullable=True)) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column('conversations', 'key') + op.drop_column('converations', 'value') + # ### end Alembic commands ### diff --git a/src/airunner/data/models/settings_models.py b/src/airunner/data/models/settings_models.py index 25c3309b0..09d3c2a24 100644 --- a/src/airunner/data/models/settings_models.py +++ b/src/airunner/data/models/settings_models.py @@ -563,6 +563,8 @@ class Conversation(Base): title = Column(String, nullable=True) # New column added messages = relationship("Message", back_populates="conversation", cascade="all, delete-orphan") bot_mood = Column(Text, default="") + key = Column(String, nullable=True) + value = Column(JSON, nullable=False) class Message(Base): diff --git a/src/airunner/handlers/llm/agent/mistral_agent.py b/src/airunner/handlers/llm/agent/mistral_agent.py index 638d33358..601787e91 100644 --- a/src/airunner/handlers/llm/agent/mistral_agent.py +++ b/src/airunner/handlers/llm/agent/mistral_agent.py @@ -2,14 +2,15 @@ Simple wrapper around AgentRunner + MistralAgentWorker. 
""" -import os +import uuid from typing import ( Any, List, Optional, Union, + Dict, ) -from pydantic import Field +import json import datetime import platform from PySide6.QtCore import QObject @@ -35,6 +36,7 @@ from airunner.handlers.llm.agent.memory.chat_memory_buffer import ChatMemoryBuffer from llama_index.core.memory import BaseMemory from airunner.handlers.llm.agent.tools.react_agent_tool import ReActAgentTool +from airunner.settings import CHAT_STORE_DB_PATH DEFAULT_MAX_FUNCTION_CALLS = 5 @@ -149,7 +151,6 @@ def rag_engine_tool(self) -> RAGEngineTool: ) return self._rag_engine_tool - @property def do_interrupt(self): return self._do_interrupt @@ -161,7 +162,7 @@ def do_interrupt(self, value): @property def conversation(self) -> Optional[Conversation]: if self._conversation is None: - self.conversation = self.create_conversation() + self.conversation = self.create_conversation(self.chat_store_key) return self._conversation @conversation.setter @@ -288,26 +289,27 @@ def _rag_system_prompt(self) -> str: @property def chat_store(self) -> SQLiteChatStore: if not self._chat_store: - db_path = os.path.expanduser( - os.path.join( - "~", - ".local", - "share", - "airunner", - "data", - "chat_store.db" - ) - ) + db_path = CHAT_STORE_DB_PATH self._chat_store = SQLiteChatStore.from_uri(f"sqlite:///{db_path}") return self._chat_store - + + @property + def chat_store_key(self) -> str: + if self._conversation: + return self._conversation.key + conversation = self.session.query(Conversation).order_by(Conversation.id.desc()).first() + if conversation: + self._conversation = conversation + return conversation.key + return "STK_" + uuid.uuid4().hex + @property def chat_memory(self) -> ChatMemoryBuffer: if not self._chat_memory: self._chat_memory = ChatMemoryBuffer.from_defaults( token_limit=3000, chat_store=self.chat_store, - chat_store_key="user1" + chat_store_key=self.chat_store_key ) return self._chat_memory @@ -331,8 +333,17 @@ def reload_rag(self): self._reload_rag() self._rag_engine_tool = None - def clear_history(self): - pass + def clear_history(self, data: Optional[Dict] = None): + data = data or {} + conversation_id = data.get("conversation_id") + self._conversation = self.session.query(Conversation).filter_by(id=conversation_id).first() + if self._conversation: + self._chat_memory.chat_store_key = self._conversation.key + messages = self._chat_store.get_messages(self._conversation.key) + if messages: + self._chat_memory.set(json.dumps(messages)) + if self._chat_engine: + self._chat_engine.memory = self._chat_memory def chat( self, diff --git a/src/airunner/handlers/llm/causal_lm_transformer_base_handler.py b/src/airunner/handlers/llm/causal_lm_transformer_base_handler.py index 9a2226ae6..d4831716f 100644 --- a/src/airunner/handlers/llm/causal_lm_transformer_base_handler.py +++ b/src/airunner/handlers/llm/causal_lm_transformer_base_handler.py @@ -2,7 +2,7 @@ import os import torch from llama_index.llms.groq import Groq -from typing import Optional +from typing import Optional, Dict from transformers.utils.quantization_config import BitsAndBytesConfig, GPTQConfig from transformers import AutoModelForCausalLM, AutoTokenizer from transformers.generation.streamers import TextIteratorStreamer @@ -216,14 +216,14 @@ def do_interrupt(self): if self._chat_agent: self._chat_agent.interrupt_process() - def clear_history(self): + def clear_history(self, data: Optional[Dict] = None): """ Public method to clear the chat agent history """ if not self._chat_agent: return self.logger.debug("Clearing chat 
history") - self._chat_agent.clear_history() + self._chat_agent.clear_history(data) def add_chatbot_response_to_history(self, message): """ diff --git a/src/airunner/handlers/llm/storage/chat_store/sqlite.py b/src/airunner/handlers/llm/storage/chat_store/sqlite.py index fdc9c5c7d..fbb64c02c 100644 --- a/src/airunner/handlers/llm/storage/chat_store/sqlite.py +++ b/src/airunner/handlers/llm/storage/chat_store/sqlite.py @@ -1,6 +1,7 @@ import json from typing import Any, Optional from urllib.parse import urlparse +import datetime from sqlalchemy import ( Index, @@ -18,39 +19,7 @@ from sqlalchemy.dialects.sqlite import JSON, VARCHAR from llama_index.core.bridge.pydantic import Field, PrivateAttr from llama_index.core.storage.chat_store.base import BaseChatStore - - -def get_data_model( - base: type, - index_name: str, - schema_name: str, - use_jsonb: bool = False, -) -> Any: - """ - This part create a dynamic sqlalchemy model with a new table. - """ - tablename = f"data_{index_name}" # dynamic table name - class_name = f"Data{index_name}" # dynamic class name - - chat_dtype = JSON - - class AbstractData(base): # type: ignore - __abstract__ = True # this line is necessary - id = Column(Integer, primary_key=True, autoincrement=True) # Add primary key - key = Column(VARCHAR, nullable=False) - value = Column(chat_dtype) - - return type( - class_name, - (AbstractData,), - { - "__tablename__": tablename, - "__table_args__": ( - UniqueConstraint("key", name=f"{tablename}:unique_key"), - Index(f"{tablename}:idx_key", "key"), - ), - }, - ) +from airunner.data.models.settings_models import Conversation class SQLiteChatStore(BaseChatStore): @@ -78,17 +47,8 @@ def __init__( schema_name=schema_name.lower(), ) - # sqlalchemy model - base = declarative_base() - self._table_class = get_data_model( - base, - table_name, - schema_name, - use_jsonb=use_jsonb, - ) self._session = session self._async_session = async_session - self._initialize(base) @classmethod def from_params( @@ -143,84 +103,92 @@ def _connect( async_session = sessionmaker(_async_engine, class_=AsyncSession) return session, async_session - def _create_schema_if_not_exists(self) -> None: - # SQLite does not support schemas, so this method can be skipped - pass - - def _create_tables_if_not_exists(self, base) -> None: - with self._session() as session, session.begin(): - base.metadata.create_all(session.connection()) - - def _initialize(self, base) -> None: - self._create_tables_if_not_exists(base) - def set_messages(self, key: str, messages: list[ChatMessage]) -> None: """Set messages for a key.""" with self._session() as session: if messages is None or len(messages) == 0: # Retrieve the existing messages - result = session.execute(select(self._table_class).filter_by(key=key)).scalars().first() + result = session.query(Conversation).filter_by(key=key).first() if result: messages = result.value else: messages = [] - - # Update the database with the new list of messages - stmt = text( - f""" - INSERT INTO {self._table_class.__tablename__} (key, value) - VALUES (:key, :value) - ON CONFLICT (key) - DO UPDATE SET - value = :value; - """ - ) - params = { - "key": key, - "value": json.dumps([ - model.model_dump() if type(model) is ChatMessage else - model for model in messages - ])} - try: - session.execute(stmt, params) - session.commit() - except Exception as e: - print(e) + value = json.dumps([ + model.model_dump() if type(model) is ChatMessage else + model for model in messages + ]) + conversation = 
session.query(Conversation).filter_by(key=key).first() + if conversation: + conversation.value = value + else: + conversation = Conversation( + timestamp=datetime.datetime.now(datetime.timezone.utc), + title="", + key=key, + value=value + ) + session.add(conversation) + session.commit() async def aset_messages(self, key: str, messages: list[ChatMessage]) -> None: """Async version of Get messages for a key.""" async with self._async_session() as session: - stmt = text( - f""" - INSERT INTO {self._table_class.__tablename__} (key, value) - VALUES (:key, :value) - ON CONFLICT (key) - DO UPDATE SET - value = EXCLUDED.value; - """ - ) - - params = { - "key": key, - "value": json.dumps([ - message.model_dump() for message in messages - ]), - } - - # Execute the bulk upsert - await session.execute(stmt, params) + if messages is None or len(messages) == 0: + # Retrieve the existing messages + result = session.query(Conversation).filter_by(key=key).first() + if result: + messages = result.value + else: + messages = [] + value = json.dumps([ + model.model_dump() if type(model) is ChatMessage else + model for model in messages + ]) + conversation = await session.query(Conversation).filter_by(key=key).first() + if conversation: + conversation.value = value + else: + conversation = Conversation( + timestamp=datetime.datetime.now(datetime.timezone.utc), + title="", + key=key, + value=value + ) + session.add(conversation) await session.commit() + + def get_latest_chatstore(self) -> dict: + """Get the latest chatstore.""" + with self._session() as session: + result = session.query(Conversation).order_by( + Conversation.id.desc() + ).first() + return { + "key": result.key, + "value": result.value, + } if result else None + + def get_chatstores(self) -> list[dict]: + """Get all chatstores.""" + with self._session() as session: + result = session.query(Conversation).all() + return [ + { + "key": item.key, + "value": item.value, + } for item in result + ] def get_messages(self, key: str) -> list[ChatMessage]: """Get messages for a key.""" + messages = None with self._session() as session: - result = session.execute(select(self._table_class).filter_by(key=key)) - result = result.scalars().first() - if result: - if result: - messages = result.value - else: - messages = None + conversation = session.query(Conversation).filter_by(key=key).first() + if conversation: + data = conversation.value + if data: + messages = json.loads(data) + print("MESSAGES", messages) return [ ChatMessage.model_validate( ChatMessage( @@ -233,12 +201,11 @@ def get_messages(self, key: str) -> list[ChatMessage]: async def aget_messages(self, key: str) -> list[ChatMessage]: """Async version of Get messages for a key.""" async with self._async_session() as session: - result = await session.execute(select(self._table_class).filter_by(key=key)) - result = result.scalars().first() - if result: - messages = json.loads(result.value) + conversation = await session.query(Conversation).filter_by(key=key).first() + if conversation: + messages = json.loads(conversation.value) else: - messages = [] + messages = None return [ ChatMessage.model_validate( ChatMessage( @@ -252,78 +219,51 @@ def add_message(self, key: str, message: ChatMessage) -> None: """Add a message for a key.""" with self._session() as session: # Retrieve the existing messages - result = session.execute(select(self._table_class).filter_by(key=key)).scalars().first() - if result: + conversation = session.query(Conversation).filter_by(key=key).first() + if conversation: try: - messages = 
json.loads(result.value) + messages = json.loads(conversation.value) except TypeError: - messages = result.value + messages = conversation.value else: messages = [] + messages = [] if messages is None else messages # Append the new message messages.append(message.model_dump()) - - # Update the database with the new list of messages - stmt = text( - f""" - INSERT INTO {self._table_class.__tablename__} (key, value) - VALUES (:key, :value) - ON CONFLICT (key) - DO UPDATE SET - value = :value; - """ - ) - params = {"key": key, "value": json.dumps(messages)} - try: - session.execute(stmt, params) - session.commit() - except Exception as e: - print(e) + session.query(Conversation).filter_by(key=key).update({"value": json.dumps(messages)}) + session.commit() async def async_add_message(self, key: str, message: ChatMessage) -> None: """Async version of Add a message for a key.""" async with self._async_session() as session: # Retrieve the existing messages - result = await session.execute(select(self._table_class).filter_by(key=key)) - result = result.scalars().first() - if result: - messages = json.loads(result.value) + conversation = await session.query(Conversation).filter_by(key=key).first() + if conversation: + try: + messages = json.loads(conversation.value) + except TypeError: + messages = conversation.value else: messages = [] + messages = [] if messages is None else messages # Append the new message messages.append(message.model_dump()) - - # Update the database with the new list of messages - stmt = text( - f""" - INSERT INTO {self._table_class.__tablename__} (key, value) - VALUES (:key, :value) - ON CONFLICT (key) - DO UPDATE SET - value = :value; - """ - ) - params = {"key": key, "value": json.dumps(messages)} - try: - await session.execute(stmt, params) - await session.commit() - except Exception as e: - print(e) + await session.query(Conversation).filter_by(key=key).update({"value": json.dumps(messages)}) await session.commit() def delete_messages(self, key: str) -> Optional[list[ChatMessage]]: """Delete messages for a key.""" with self._session() as session: - session.execute(delete(self._table_class).filter_by(key=key)) + session.query(Conversation).filter_by(key=key).delete() session.commit() return None async def adelete_messages(self, key: str) -> Optional[list[ChatMessage]]: """Async version of Delete messages for a key.""" async with self._async_session() as session: - await session.execute(delete(self._table_class).filter_by(key=key)) + await session.query(Conversation).filter_by(key=key).delete() await session.commit() return None @@ -331,125 +271,109 @@ def delete_message(self, key: str, idx: int) -> Optional[ChatMessage]: """Delete specific message for a key.""" with self._session() as session: # First, retrieve the current list of messages - stmt = select(self._table_class.value).where(self._table_class.key == key) - result = session.execute(stmt).scalar_one_or_none() + conversation = session.query(Conversation).filter_by(key=key).first() + if conversation: + try: + messages = json.loads(conversation.value) + except TypeError: + messages = conversation.value + else: + messages = None - if result is None or idx < 0 or idx >= len(result): + if messages is None or idx < 0 or idx >= len(messages): # If the key doesn't exist or the index is out of bounds return None - + # Remove the message at the given index - removed_message = result[idx] - - stmt = text( - f""" - UPDATE {self._table_class.__tablename__} - SET value = json_remove({self._table_class.__tablename__}.value, 
'$[{idx}]') - WHERE key = :key; - """ - ) - - params = {"key": key} - session.execute(stmt, params) + removed_message = messages[idx] + messages.pop(idx) + session.query(Conversation).filter_by(key=key).update({"value": json.dumps(messages)}) session.commit() - return ChatMessage.model_validate(removed_message) async def adelete_message(self, key: str, idx: int) -> Optional[ChatMessage]: """Async version of Delete specific message for a key.""" async with self._async_session() as session: # First, retrieve the current list of messages - stmt = select(self._table_class.value).where(self._table_class.key == key) - result = (await session.execute(stmt)).scalar_one_or_none() + conversation = await session.query(Conversation).filter_by(key=key).first() + if conversation: + try: + messages = json.loads(conversation.value) + except TypeError: + messages = conversation.value + else: + messages = None - if result is None or idx < 0 or idx >= len(result): + if messages is None or idx < 0 or idx >= len(messages): # If the key doesn't exist or the index is out of bounds return None - + # Remove the message at the given index - removed_message = result[idx] - - stmt = text( - f""" - UPDATE {self._table_class.__tablename__} - SET value = json_remove({self._table_class.__tablename__}.value, '$[{idx}]') - WHERE key = :key; - """ - ) - - params = {"key": key} - await session.execute(stmt, params) + removed_message = messages[idx] + messages.pop(idx) + await session.query(Conversation).filter_by(key=key).update({"value": json.dumps(messages)}) await session.commit() - return ChatMessage.model_validate(removed_message) def delete_last_message(self, key: str) -> Optional[ChatMessage]: """Delete last message for a key.""" with self._session() as session: # First, retrieve the current list of messages - stmt = select(self._table_class.value).where(self._table_class.key == key) - result = session.execute(stmt).scalar_one_or_none() + conversation = session.query(Conversation).filter_by(key=key).first() + if conversation: + try: + messages = json.loads(conversation.value) + except TypeError: + messages = conversation.value + else: + messages = None - if result is None or len(result) == 0: + if messages is None or len(messages) == 0: # If the key doesn't exist or the array is empty return None - + # Remove the message at the given index - removed_message = result[-1] - - stmt = text( - f""" - UPDATE {self._table_class.__tablename__} - SET value = json_remove({self._table_class.__tablename__}.value, '$[#-1]') - WHERE key = :key; - """ - ) - params = {"key": key} - session.execute(stmt, params) + removed_message = messages[-1] + messages.pop(-1) + session.query(Conversation).filter_by(key=key).update({"value": json.dumps(messages)}) session.commit() - return ChatMessage.model_validate(removed_message) async def adelete_last_message(self, key: str) -> Optional[ChatMessage]: """Async version of Delete last message for a key.""" async with self._async_session() as session: # First, retrieve the current list of messages - stmt = select(self._table_class.value).where(self._table_class.key == key) - result = (await session.execute(stmt)).scalar_one_or_none() + conversation = await session.query(Conversation).filter_by(key=key).first() + if conversation: + try: + messages = json.loads(conversation.value) + except TypeError: + messages = conversation.value + else: + messages = None - if result is None or len(result) == 0: + if messages is None or len(messages) == 0: # If the key doesn't exist or the array is empty return 
None - + # Remove the message at the given index - removed_message = result[-1] - - stmt = text( - f""" - UPDATE {self._table_class.__tablename__} - SET value = json_remove({self._table_class.__tablename__}.value, '$[#-1]') - WHERE key = :key; - """ - ) - params = {"key": key} - await session.execute(stmt, params) + removed_message = messages[-1] + messages.pop(-1) + await session.query(Conversation).filter_by(key=key).update({"value": json.dumps(messages)}) await session.commit() - return ChatMessage.model_validate(removed_message) def get_keys(self) -> list[str]: """Get all keys.""" with self._session() as session: - stmt = select(self._table_class.key) - - return session.execute(stmt).scalars().all() + conversations = session.query(Conversation).all() + return [conversation.key for conversation in conversations] async def aget_keys(self) -> list[str]: """Async version of Get all keys.""" async with self._async_session() as session: - stmt = select(self._table_class.key) - - return (await session.execute(stmt)).scalars().all() + conversations = await session.query(Conversation).all() + return [conversation.key for conversation in conversations] def params_from_uri(uri: str) -> dict: diff --git a/src/airunner/main.py b/src/airunner/main.py index 578b29a96..757aa213e 100755 --- a/src/airunner/main.py +++ b/src/airunner/main.py @@ -11,7 +11,7 @@ # variables for the application. ################################################################ # import facehuggershield -from airunner.settings import NLTK_DOWNLOAD_DIR +from airunner.settings import DB_PATH import os base_path = os.path.join(os.path.expanduser("~"), ".local", "share", "airunner") # facehuggershield.huggingface.activate( @@ -45,7 +45,7 @@ from alembic.config import Config from alembic import command from pathlib import Path -from airunner.data.models.settings_models import ApplicationSettings, PathSettings +from airunner.data.models.settings_models import ApplicationSettings from sqlalchemy import create_engine from sqlalchemy.orm import sessionmaker, scoped_session @@ -68,7 +68,7 @@ def main(): setup_database() # Get the first ApplicationSettings record from the database and check for run_setup_wizard boolean - engine = create_engine("sqlite:///" + os.path.join(base_dir, "airunner.db")) + engine = create_engine(f"sqlite:///{DB_PATH}") session = scoped_session(sessionmaker(bind=engine)) application_settings = session.query(ApplicationSettings).first() diff --git a/src/airunner/settings.py b/src/airunner/settings.py index ead744f6f..25667c451 100644 --- a/src/airunner/settings.py +++ b/src/airunner/settings.py @@ -368,3 +368,14 @@ "signal": SignalCode.REFRESH_STYLESHEET_SIGNAL.value }, ] +DB_PATH = os.path.expanduser( + os.path.join( + "~", + ".local", + "share", + "airunner", + "data", + "airunner.db" + ) +) +CHAT_STORE_DB_PATH = DB_PATH diff --git a/src/airunner/widgets/generator_form/generator_form_widget.py b/src/airunner/widgets/generator_form/generator_form_widget.py index 6a85a266e..e87e61e9a 100644 --- a/src/airunner/widgets/generator_form/generator_form_widget.py +++ b/src/airunner/widgets/generator_form/generator_form_widget.py @@ -485,7 +485,6 @@ def stop_progress_bar(self, do_clear=False): progressbar.setFormat("Complete") def _set_keyboard_shortcuts(self): - generate_image_key = self.session.query(ShortcutKeys).filter_by(display_name="Generate Image").first() interrupt_key = self.session.query(ShortcutKeys).filter_by(display_name="Interrupt").first() if generate_image_key: diff --git 
a/src/airunner/widgets/llm/chat_prompt_widget.py b/src/airunner/widgets/llm/chat_prompt_widget.py index 1a5bb0a00..f2c099f78 100644 --- a/src/airunner/widgets/llm/chat_prompt_widget.py +++ b/src/airunner/widgets/llm/chat_prompt_widget.py @@ -1,3 +1,5 @@ +import uuid + from PySide6.QtCore import Slot, QTimer, QPropertyAnimation from PySide6.QtWidgets import QSpacerItem, QSizePolicy from PySide6.QtCore import Qt @@ -6,6 +8,7 @@ from airunner.widgets.base_widget import BaseWidget from airunner.widgets.llm.templates.chat_prompt_ui import Ui_chat_prompt from airunner.widgets.llm.message_widget import MessageWidget +from airunner.data.models.settings_models import Conversation class ChatPromptWidget(BaseWidget): @@ -150,7 +153,11 @@ def _clear_conversation(self): self._create_conversation() def _create_conversation(self): - conversation = self.create_conversation() + conversation = self.session.query(Conversation).order_by( + Conversation.id.desc() + ).first() + if not conversation: + conversation = self.create_conversation("cpw_" + uuid.uuid4().hex) conversation_id = conversation.id self.emit_signal(SignalCode.LLM_CLEAR_HISTORY_SIGNAL, { "conversation_id": conversation_id diff --git a/src/airunner/widgets/llm/llm_history_widget.py b/src/airunner/widgets/llm/llm_history_widget.py index 880dd8d25..6d1b3a51f 100644 --- a/src/airunner/widgets/llm/llm_history_widget.py +++ b/src/airunner/widgets/llm/llm_history_widget.py @@ -7,6 +7,8 @@ from airunner.widgets.base_widget import BaseWidget from airunner.widgets.llm.llm_history_item_widget import LLMHistoryItemWidget from airunner.widgets.llm.templates.llm_history_widget_ui import Ui_llm_history_widget +from airunner.data.models.settings_models import Conversation + class LLMHistoryWidget(BaseWidget): widget_class_ = Ui_llm_history_widget @@ -43,8 +45,8 @@ def load_conversations(self): self.ui.scrollAreaWidgetContents.setLayout(layout) for conversation in conversations: - if conversation.title == "": - continue + # if conversation.title == "": + # continue llm_history_item_widget = LLMHistoryItemWidget( conversation=conversation ) diff --git a/src/airunner/windows/main/settings_mixin.py b/src/airunner/windows/main/settings_mixin.py index 43de3e7c0..30142f413 100644 --- a/src/airunner/windows/main/settings_mixin.py +++ b/src/airunner/windows/main/settings_mixin.py @@ -1,6 +1,6 @@ import logging import datetime -import os +import json from typing import List, Type from sqlalchemy import create_engine @@ -14,6 +14,7 @@ MemorySettings, Message, Conversation, Summary, ImageFilterValue, TargetFiles, WhisperSettings, Base, User from airunner.enums import SignalCode from airunner.utils.image.convert_binary_to_image import convert_binary_to_image +from airunner.settings import DB_PATH class SettingsMixinSharedInstance: @@ -28,16 +29,7 @@ def __new__(cls, *args, **kwargs): def __init__(self): if self._initialized: return - self.db_path = os.path.expanduser( - os.path.join( - "~", - ".local", - "share", - "airunner", - "data", - "airunner.db" - ) - ) + self.db_path = DB_PATH self.engine = create_engine(f'sqlite:///{self.db_path}') Base.metadata.create_all(self.engine) self.Session = scoped_session(sessionmaker(bind=self.engine)) @@ -738,9 +730,11 @@ def get_chatbot_by_id(self, chatbot_id) -> Type[Chatbot]: chatbot = self.session.query(Chatbot).options(joinedload(Chatbot.target_files)).first() return chatbot - def create_conversation(self): + def create_conversation(self, chat_store_key: str): # find conversation which has no title, bot_mood or messages - 
conversation = self.session.query(Conversation).filter_by(title="", bot_mood="").first() + conversation = self.session.query(Conversation).filter_by( + key=chat_store_key + ).first() if conversation: # ensure there are no messages in the conversation message = self.session.query(Message).filter_by(conversation_id=conversation.id).first() @@ -748,7 +742,9 @@ def create_conversation(self): return conversation conversation = Conversation( timestamp=datetime.datetime.now(datetime.timezone.utc), - title="" + title="", + key=chat_store_key, + value=json.dumps([]) ) self.session.add(conversation) self.session.commit() @@ -769,7 +765,7 @@ def add_summary(self, content, conversation_id): ) self.session.add(summary) self.session.commit() - + def get_all_conversations(self): conversations = self.session.query(Conversation).all() return conversations diff --git a/src/airunner/workers/llm_generate_worker.py b/src/airunner/workers/llm_generate_worker.py index 430008113..c46530314 100644 --- a/src/airunner/workers/llm_generate_worker.py +++ b/src/airunner/workers/llm_generate_worker.py @@ -1,4 +1,5 @@ import threading +from typing import Dict, Optional from airunner.handlers.llm.causal_lm_transformer_base_handler import CausalLMTransformerBaseHandler from airunner.enums import SignalCode @@ -47,9 +48,9 @@ def on_llm_on_unload_signal(self, data=None): def on_llm_load_model_signal(self, data): self._load_llm_thread(data) - def on_llm_clear_history_signal(self): + def on_llm_clear_history_signal(self, data:Optional[Dict] = None): if self.llm: - self.llm.clear_history() + self.llm.clear_history(data) def on_llm_request_signal(self, message: dict): self.add_to_queue(message)
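
A note on the async chat-store methods above: aset_messages, aget_messages, async_add_message, and the adelete_* variants in src/airunner/handlers/llm/storage/chat_store/sqlite.py call session.query(...) on an AsyncSession, sometimes behind an await. SQLAlchemy's asyncio extension does not expose the legacy Query API, so those calls cannot work at runtime; the 2.0-style pattern is to build a select() and await session.execute(). The following is only a minimal sketch of that pattern, assuming the Conversation model added in this patch (key/value columns, with value holding a JSON string) and an async session factory like the one created in _connect(); the helper names are illustrative, not part of the codebase.

import json
from typing import Optional

from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from llama_index.core.llms import ChatMessage

from airunner.data.models.settings_models import Conversation


async def aget_conversation(session: AsyncSession, key: str) -> Optional[Conversation]:
    # AsyncSession has no .query(); issue a 2.0-style select() and await its execution.
    result = await session.execute(select(Conversation).filter_by(key=key))
    return result.scalars().first()


async def aget_messages_sketch(async_session_factory, key: str) -> list[ChatMessage]:
    # Read the stored JSON string and rebuild ChatMessage objects from it.
    async with async_session_factory() as session:
        conversation = await aget_conversation(session, key)
        if conversation is None or not conversation.value:
            return []
        return [ChatMessage.model_validate(m) for m in json.loads(conversation.value)]


async def aadd_message_sketch(async_session_factory, key: str, message: ChatMessage) -> None:
    # Append one message to an existing conversation row.
    async with async_session_factory() as session:
        conversation = await aget_conversation(session, key)
        if conversation is None:
            return
        messages = json.loads(conversation.value) if conversation.value else []
        messages.append(message.model_dump())
        conversation.value = json.dumps(messages)  # ORM assignment is flushed on commit
        await session.commit()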
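
A related note on MistralAgent.clear_history above: if the ChatMemoryBuffer used here follows llama-index's BaseMemory API, set() expects a list of ChatMessage objects, and SQLiteChatStore.get_messages already returns ChatMessage instances, so json.dumps(messages) would raise (ChatMessage is not JSON serializable) before set() is even reached. A small illustrative helper, assuming only those two behaviors; rebind_chat_memory is a hypothetical name, not part of the codebase.

from llama_index.core.llms import ChatMessage


def rebind_chat_memory(chat_memory, chat_store, conversation_key: str) -> None:
    """Point an existing memory buffer at a stored conversation.

    chat_memory is expected to expose chat_store_key and set(list[ChatMessage]),
    as the ChatMemoryBuffer in this patch appears to; chat_store is expected to
    expose get_messages(key) returning ChatMessage objects.
    """
    chat_memory.chat_store_key = conversation_key
    messages: list[ChatMessage] = chat_store.get_messages(conversation_key)
    chat_memory.set(messages)  # no json.dumps: set() takes ChatMessage objects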