Skip to content

Commit

Permalink
fix usage of chatstore memory
Browse files Browse the repository at this point in the history
  • Loading branch information
w4ffl35 committed Feb 14, 2025
1 parent f27fafc commit 99be0a9
Show file tree
Hide file tree
Showing 13 changed files with 258 additions and 272 deletions.
9 changes: 5 additions & 4 deletions src/airunner/alembic/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,15 @@
from logging.config import fileConfig
from sqlalchemy import engine_from_config, pool, MetaData
from alembic import context
from airunner.settings import DB_PATH


config = context.config

# Point Alembic at the application's SQLite database file.
database_url = f"sqlite:///{DB_PATH}"
config.set_main_option("sqlalchemy.url", database_url)

# Warn (without aborting) when the database file does not exist yet;
# Alembic will still run and create it on the first migration.
if not os.path.exists(DB_PATH):
    print(f"Database file not found at {DB_PATH}")

# Import your models here
from airunner.data.models.settings_models import (
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
"""add chat_store_key column to Conversation table
Revision ID: 6579bf48ed83
Revises: c2c5d4cd4b80
Create Date: 2025-02-14 02:05:31.854218
"""
from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import sqlite

# revision identifiers, used by Alembic.
# This migration adds the chat-store `key`/`value` columns to the
# conversations table (see upgrade() below).
revision: str = '6579bf48ed83'
# Applied on top of revision c2c5d4cd4b80.
down_revision: Union[str, None] = 'c2c5d4cd4b80'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    """Add chat-store ``key`` and ``value`` columns to ``conversations``."""
    # ### commands auto generated by Alembic - please adjust! ###
    # Both columns are nullable so existing rows remain valid after upgrade.
    # NOTE(review): the ORM model declares `value` with nullable=False —
    # confirm which side (model vs. migration) is intended.
    op.add_column('conversations', sa.Column('key', sa.String(), nullable=True))
    op.add_column('conversations', sa.Column('value', sa.JSON(), nullable=True))
    # ### end Alembic commands ###


def downgrade() -> None:
    """Remove the chat-store ``key`` and ``value`` columns from ``conversations``."""
    # ### commands auto generated by Alembic - please adjust! ###
    op.drop_column('conversations', 'key')
    # Fix: table name was misspelled 'converations', which would make the
    # downgrade fail — upgrade() added both columns to 'conversations'.
    op.drop_column('conversations', 'value')
    # ### end Alembic commands ###
2 changes: 2 additions & 0 deletions src/airunner/data/models/settings_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -563,6 +563,8 @@ class Conversation(Base):
title = Column(String, nullable=True) # New column added
messages = relationship("Message", back_populates="conversation", cascade="all, delete-orphan")
bot_mood = Column(Text, default="")
key = Column(String, nullable=True)
value = Column(JSON, nullable=False)


class Message(Base):
Expand Down
47 changes: 29 additions & 18 deletions src/airunner/handlers/llm/agent/mistral_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,15 @@
Simple wrapper around AgentRunner + MistralAgentWorker.
"""
import os
import uuid
from typing import (
Any,
List,
Optional,
Union,
Dict,
)
from pydantic import Field
import json
import datetime
import platform
from PySide6.QtCore import QObject
Expand All @@ -35,6 +36,7 @@
from airunner.handlers.llm.agent.memory.chat_memory_buffer import ChatMemoryBuffer
from llama_index.core.memory import BaseMemory
from airunner.handlers.llm.agent.tools.react_agent_tool import ReActAgentTool
from airunner.settings import CHAT_STORE_DB_PATH


DEFAULT_MAX_FUNCTION_CALLS = 5
Expand Down Expand Up @@ -149,7 +151,6 @@ def rag_engine_tool(self) -> RAGEngineTool:
)
return self._rag_engine_tool


@property
def do_interrupt(self):
    """Return the interrupt flag (backed by ``self._do_interrupt``)."""
    return self._do_interrupt
Expand All @@ -161,7 +162,7 @@ def do_interrupt(self, value):
@property
def conversation(self) -> Optional[Conversation]:
    """Lazily create and cache the active conversation."""
    cached = self._conversation
    if cached is None:
        # Route through the setter so its side effects still run,
        # then re-read the cached value it stored.
        self.conversation = self.create_conversation(self.chat_store_key)
        cached = self._conversation
    return cached

@conversation.setter
Expand Down Expand Up @@ -288,26 +289,27 @@ def _rag_system_prompt(self) -> str:
@property
def chat_store(self) -> SQLiteChatStore:
    """Lazily open (and cache) the SQLite-backed chat store."""
    if not self._chat_store:
        # Build the SQLAlchemy URI from the configured chat-store path.
        uri = f"sqlite:///{CHAT_STORE_DB_PATH}"
        self._chat_store = SQLiteChatStore.from_uri(uri)
    return self._chat_store


@property
def chat_store_key(self) -> str:
    """Return the chat-store key for the active conversation.

    Falls back to the most recent conversation in the database, and
    finally to a freshly generated key. Fix: the `key` column is nullable
    (see the migration), so the original could return ``None`` despite
    the declared ``str`` return type; ``None`` is now skipped.
    """
    if self._conversation and self._conversation.key:
        return self._conversation.key
    conversation = self.session.query(Conversation).order_by(Conversation.id.desc()).first()
    if conversation:
        # Cache the latest conversation even if its key is missing,
        # matching the original caching behavior.
        self._conversation = conversation
        if conversation.key:
            return conversation.key
    # No usable stored key: generate one. NOTE(review): this key is not
    # persisted here, so repeated access may yield different keys —
    # confirm create_conversation() stores it.
    return "STK_" + uuid.uuid4().hex

@property
def chat_memory(self) -> ChatMemoryBuffer:
    """Lazily build the memory buffer keyed to the active conversation."""
    if not self._chat_memory:
        token_limit = 3000
        buffer = ChatMemoryBuffer.from_defaults(
            token_limit=token_limit,
            chat_store=self.chat_store,
            chat_store_key=self.chat_store_key,
        )
        self._chat_memory = buffer
    return self._chat_memory

Expand All @@ -331,8 +333,17 @@ def reload_rag(self):
self._reload_rag()
self._rag_engine_tool = None

def clear_history(self):
pass
def clear_history(self, data: Optional[Dict] = None):
    """Switch chat memory to the conversation identified in ``data``.

    Args:
        data: Optional mapping; ``data["conversation_id"]`` selects the
            conversation whose stored messages should be restored. When
            absent or unmatched, only the engine's memory is refreshed.
    """
    data = data or {}
    conversation_id = data.get("conversation_id")
    self._conversation = (
        self.session.query(Conversation)
        .filter_by(id=conversation_id)
        .first()
    )
    if self._conversation:
        # Fix: go through the lazy properties — the private attributes
        # (_chat_memory, _chat_store) are None until first property access,
        # so the original raised AttributeError when history was cleared
        # before any chat had taken place.
        memory = self.chat_memory
        memory.chat_store_key = self._conversation.key
        messages = self.chat_store.get_messages(self._conversation.key)
        if messages:
            # NOTE(review): ChatMemoryBuffer.set() normally expects message
            # objects; confirm this wrapper accepts a JSON string.
            memory.set(json.dumps(messages))
    if self._chat_engine:
        # Use the property here as well so the engine never receives None.
        self._chat_engine.memory = self.chat_memory

def chat(
self,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import os
import torch
from llama_index.llms.groq import Groq
from typing import Optional
from typing import Optional, Dict
from transformers.utils.quantization_config import BitsAndBytesConfig, GPTQConfig
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation.streamers import TextIteratorStreamer
Expand Down Expand Up @@ -216,14 +216,14 @@ def do_interrupt(self):
if self._chat_agent:
self._chat_agent.interrupt_process()

def clear_history(self):
def clear_history(self, data: Optional[Dict] = None):
    """Forward a history-clear request to the chat agent, if one exists."""
    agent = self._chat_agent
    if agent:
        self.logger.debug("Clearing chat history")
        agent.clear_history(data)

def add_chatbot_response_to_history(self, message):
"""
Expand Down
Loading

0 comments on commit 99be0a9

Please sign in to comment.