Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add option to disable asynchronous memory addition in completion API #2124

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 27 additions & 8 deletions mem0/proxy/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@
subprocess.check_call([sys.executable, "-m", "pip", "install", "litellm"])
import litellm
except subprocess.CalledProcessError:
print("Failed to install 'litellm'. Please install it manually using 'pip install litellm'.")
print(
"Failed to install 'litellm'. Please install it manually using 'pip install litellm'."
)
sys.exit(1)
else:
raise ImportError("The required 'litellm' library is not installed.")
Expand Down Expand Up @@ -96,6 +98,8 @@ def create(
api_version: Optional[str] = None,
api_key: Optional[str] = None,
model_list: Optional[list] = None, # pass in a list of api_base,keys, etc.
# New parameter
add_to_memory: bool = True,
):
if not any([user_id, agent_id, run_id]):
raise ValueError("One of user_id, agent_id, run_id must be provided")
Expand All @@ -107,10 +111,17 @@ def create(

prepared_messages = self._prepare_messages(messages)
if prepared_messages[-1]["role"] == "user":
self._async_add_to_memory(messages, user_id, agent_id, run_id, metadata, filters)
relevant_memories = self._fetch_relevant_memories(messages, user_id, agent_id, run_id, filters, limit)
if add_to_memory:
self._async_add_to_memory(
messages, user_id, agent_id, run_id, metadata, filters
)
relevant_memories = self._fetch_relevant_memories(
messages, user_id, agent_id, run_id, filters, limit
)
logger.debug(f"Retrieved {len(relevant_memories)} relevant memories")
prepared_messages[-1]["content"] = self._format_query_with_memories(messages, relevant_memories)
prepared_messages[-1]["content"] = self._format_query_with_memories(
messages, relevant_memories
)

response = litellm.completion(
model=model,
Expand Down Expand Up @@ -151,7 +162,9 @@ def _prepare_messages(self, messages: List[dict]) -> List[dict]:
return [{"role": "system", "content": MEMORY_ANSWER_PROMPT}] + messages
return messages

def _async_add_to_memory(self, messages, user_id, agent_id, run_id, metadata, filters):
def _async_add_to_memory(
self, messages, user_id, agent_id, run_id, metadata, filters
):
def add_task():
logger.debug("Adding to memory asynchronously")
self.mem0_client.add(
Expand All @@ -165,9 +178,13 @@ def add_task():

threading.Thread(target=add_task, daemon=True).start()

def _fetch_relevant_memories(self, messages, user_id, agent_id, run_id, filters, limit):
def _fetch_relevant_memories(
self, messages, user_id, agent_id, run_id, filters, limit
):
# Currently, only pass the last 6 messages to the search API to prevent long query
message_input = [f"{message['role']}: {message['content']}" for message in messages][-6:]
message_input = [
f"{message['role']}: {message['content']}" for message in messages
][-6:]
# TODO: Make it better by summarizing the past conversation
return self.mem0_client.search(
query="\n".join(message_input),
def _format_query_with_memories(self, messages, relevant_memories):
    """Build the final user prompt by prepending relevant memories.

    Args:
        messages: Full chat history; only the last message's content is
            embedded in the returned prompt as the "User Question".
        relevant_memories: Result of the memory search. Shape depends on
            the client: the OSS ``Memory`` client returns a dict with a
            ``"results"`` list, while the platform ``MemoryClient``
            returns a flat list of memory dicts.

    Returns:
        A formatted string containing the joined memory texts followed by
        the user's question.

    Raises:
        TypeError: If ``self.mem0_client`` is neither ``Memory`` nor
            ``MemoryClient``.
    """
    # Dispatch on the concrete client type: each returns a different
    # container shape for search results (see docstring).
    if isinstance(self.mem0_client, mem0.memory.main.Memory):
        memories_text = "\n".join(
            memory["memory"] for memory in relevant_memories["results"]
        )
    elif isinstance(self.mem0_client, mem0.client.main.MemoryClient):
        memories_text = "\n".join(memory["memory"] for memory in relevant_memories)
    else:
        # Previously an unknown client type fell through both branches and
        # the f-string below raised an opaque NameError on memories_text.
        # Fail fast with an explicit, diagnosable error instead.
        raise TypeError(
            f"Unsupported mem0 client type: {type(self.mem0_client).__name__}"
        )
    return f"- Relevant Memories/Facts: {memories_text}\n\n- User Question: {messages[-1]['content']}"