From 4d4e4ea147eb3cc587921404e132b3f51fc976b1 Mon Sep 17 00:00:00 2001 From: Mayank Solanki Date: Sat, 28 Sep 2024 03:43:37 +0530 Subject: [PATCH 1/4] async add function impl --- mem0/embeddings/openai.py | 17 ++- mem0/llms/openai.py | 59 +++++++++- mem0/llms/openai_structured.py | 38 ++++++- mem0/memory/graph_memory.py | 194 +++++++++++++++++++++++++++++++++ mem0/memory/main.py | 168 ++++++++++++++++++++++++++++ 5 files changed, 473 insertions(+), 3 deletions(-) diff --git a/mem0/embeddings/openai.py b/mem0/embeddings/openai.py index b68b8ffc09..0ba2a7518f 100644 --- a/mem0/embeddings/openai.py +++ b/mem0/embeddings/openai.py @@ -1,7 +1,7 @@ import os from typing import Optional -from openai import OpenAI +from openai import AsyncOpenAI, OpenAI from mem0.configs.embeddings.base import BaseEmbedderConfig from mem0.embeddings.base import EmbeddingBase @@ -17,6 +17,7 @@ def __init__(self, config: Optional[BaseEmbedderConfig] = None): api_key = self.config.api_key or os.getenv("OPENAI_API_KEY") base_url = self.config.openai_base_url or os.getenv("OPENAI_API_BASE") self.client = OpenAI(api_key=api_key, base_url=base_url) + self.async_client = AsyncOpenAI(api_key=api_key, base_url=base_url) def embed(self, text): """ @@ -30,3 +31,17 @@ def embed(self, text): """ text = text.replace("\n", " ") return self.client.embeddings.create(input=[text], model=self.config.model).data[0].embedding + + async def aembed(self, text): + """ + Get the embedding for the given text using OpenAI. + + Args: + text (str): The text to embed. + + Returns: + list: The embedding vector. + """ + text = text.replace("\n", " ") + response = await self.async_client.embeddings.create(input=[text], model=self.config.model) + return response.data[0].embedding \ No newline at end of file diff --git a/mem0/llms/openai.py b/mem0/llms/openai.py index c585162f29..80b352c3c8 100644 --- a/mem0/llms/openai.py +++ b/mem0/llms/openai.py @@ -2,7 +2,7 @@ import os from typing import Dict, List, Optional -from openai import OpenAI +from openai import AsyncOpenAI, OpenAI from mem0.configs.llms.base import BaseLlmConfig from mem0.llms.base import LLMBase @@ -20,10 +20,15 @@ def __init__(self, config: Optional[BaseLlmConfig] = None): api_key=os.environ.get("OPENROUTER_API_KEY"), base_url=self.config.openrouter_base_url or os.getenv("OPENROUTER_API_BASE") or "https://openrouter.ai/api/v1", ) + self.async_client = AsyncOpenAI( + api_key=os.environ.get("OPENROUTER_API_KEY"), + base_url=self.config.openrouter_base_url or os.getenv("OPENROUTER_API_BASE") or "https://openrouter.ai/api/v1", + ) else: api_key = self.config.api_key or os.getenv("OPENAI_API_KEY") base_url = self.config.openai_base_url or os.getenv("OPENAI_API_BASE") or "https://api.openai.com/v1" self.client = OpenAI(api_key=api_key, base_url=base_url) + self.async_client = AsyncOpenAI(api_key=api_key, base_url=base_url) def _parse_response(self, response, tools): """ @@ -106,3 +111,55 @@ def generate_response( response = self.client.chat.completions.create(**params) return self._parse_response(response, tools) + + async def agenerate_response( + self, + messages: List[Dict[str, str]], + response_format=None, + tools: Optional[List[Dict]] = None, + tool_choice: str = "auto", + ): + """ + Generate a response based on the given messages using OpenAI. + + Args: + messages (list): List of message dicts containing 'role' and 'content'. + response_format (str or object, optional): Format of the response. Defaults to "text". + tools (list, optional): List of tools that the model can call. Defaults to None. + tool_choice (str, optional): Tool choice method. Defaults to "auto". + + Returns: + str: The generated response. + """ + params = { + "model": self.config.model, + "messages": messages, + "temperature": self.config.temperature, + "max_tokens": self.config.max_tokens, + "top_p": self.config.top_p, + } + + if os.getenv("OPENROUTER_API_KEY"): + openrouter_params = {} + if self.config.models: + openrouter_params["models"] = self.config.models + openrouter_params["route"] = self.config.route + params.pop("model") + + if self.config.site_url and self.config.app_name: + extra_headers = { + "HTTP-Referer": self.config.site_url, + "X-Title": self.config.app_name, + } + openrouter_params["extra_headers"] = extra_headers + + params.update(**openrouter_params) + + if response_format: + params["response_format"] = response_format + if tools: # TODO: Remove tools if no issues found with new memory addition logic + params["tools"] = tools + params["tool_choice"] = tool_choice + + response = await self.async_client.chat.completions.create(**params) + return self._parse_response(response, tools) diff --git a/mem0/llms/openai_structured.py b/mem0/llms/openai_structured.py index 1ccba28a39..2382db8709 100644 --- a/mem0/llms/openai_structured.py +++ b/mem0/llms/openai_structured.py @@ -2,7 +2,7 @@ import os from typing import Dict, List, Optional -from openai import OpenAI +from openai import AsyncOpenAI, OpenAI from mem0.configs.llms.base import BaseLlmConfig from mem0.llms.base import LLMBase @@ -18,6 +18,7 @@ def __init__(self, config: Optional[BaseLlmConfig] = None): api_key = self.config.api_key or os.getenv("OPENAI_API_KEY") base_url = self.config.openai_base_url or os.getenv("OPENAI_API_BASE") or "https://api.openai.com/v1" self.client = OpenAI(api_key=api_key, base_url=base_url) + self.async_client = AsyncOpenAI(api_key=api_key, base_url=base_url) def _parse_response(self, response, tools): """ @@ -85,3 +86,38 @@ def generate_response( response = self.client.beta.chat.completions.parse(**params) return self._parse_response(response, tools) + + async def agenerate_response( + self, + messages: List[Dict[str, str]], + response_format=None, + tools: Optional[List[Dict]] = None, + tool_choice: str = "auto", + ): + """ + Generate a response based on the given messages using OpenAI. + + Args: + messages (list): List of message dicts containing 'role' and 'content'. + response_format (str or object, optional): Format of the response. Defaults to "text". + tools (list, optional): List of tools that the model can call. Defaults to None. + tool_choice (str, optional): Tool choice method. Defaults to "auto". + + Returns: + str: The generated response. + """ + params = { + "model": self.config.model, + "messages": messages, + "temperature": self.config.temperature, + } + + if response_format: + params["response_format"] = response_format + if tools: + params["tools"] = tools + params["tool_choice"] = tool_choice + + response = await self.async_client.beta.chat.completions.parse(**params) + + return self._parse_response(response, tools) diff --git a/mem0/memory/graph_memory.py b/mem0/memory/graph_memory.py index 987b923e39..c30145fc0e 100644 --- a/mem0/memory/graph_memory.py +++ b/mem0/memory/graph_memory.py @@ -359,3 +359,197 @@ def _update_relationship(self, source, target, relationship, filters): if not result: raise Exception(f"Failed to update or create relationship between {source} and {target}") + + async def aadd(self, data, filters): + """ + Adds data to the graph. + + Args: + data (str): The data to add to the graph. + filters (dict): A dictionary containing filters to be applied during the addition. + """ + if self.config.llm.provider not in ['openai', 'openai_structured']: + raise NotImplementedError(f"The async add is not implemented for {self.config.llm.provider}.") + + if self.config.embedder.provider not in ['openai']: + raise NotImplementedError(f"The async add is not implemented for {self.config.embedder.provider}.") + + # retrieve the search results + search_output = await self._asearch(data, filters) + + if self.config.graph_store.custom_prompt: + messages = [ + { + "role": "system", + "content": EXTRACT_ENTITIES_PROMPT.replace("USER_ID", self.user_id).replace( + "CUSTOM_PROMPT", f"4. {self.config.graph_store.custom_prompt}" + ), + }, + {"role": "user", "content": data}, + ] + else: + messages = [ + { + "role": "system", + "content": EXTRACT_ENTITIES_PROMPT.replace("USER_ID", self.user_id), + }, + {"role": "user", "content": data}, + ] + + _tools = [ADD_MESSAGE_TOOL] + if self.llm_provider in ["azure_openai_structured", "openai_structured"]: + _tools = [ADD_MESSAGE_STRUCT_TOOL] + + extracted_entities = await self.llm.agenerate_response( + messages=messages, + tools=_tools, + ) + + if extracted_entities["tool_calls"]: + extracted_entities = extracted_entities["tool_calls"][0]["arguments"]["entities"] + else: + extracted_entities = [] + + logger.debug(f"Extracted entities: {extracted_entities}") + + update_memory_prompt = get_update_memory_messages(search_output, extracted_entities) + + _tools = [UPDATE_MEMORY_TOOL_GRAPH, ADD_MEMORY_TOOL_GRAPH, NOOP_TOOL] + if self.llm_provider in ["azure_openai_structured", "openai_structured"]: + _tools = [ + UPDATE_MEMORY_STRUCT_TOOL_GRAPH, + ADD_MEMORY_STRUCT_TOOL_GRAPH, + NOOP_STRUCT_TOOL, + ] + + memory_updates = await self.llm.agenerate_response( + messages=update_memory_prompt, + tools=_tools, + ) + + to_be_added = [] + + for item in memory_updates["tool_calls"]: + if item["name"] == "add_graph_memory": + to_be_added.append(item["arguments"]) + elif item["name"] == "update_graph_memory": + self._update_relationship( + item["arguments"]["source"], + item["arguments"]["destination"], + item["arguments"]["relationship"], + filters, + ) + elif item["name"] == "noop": + continue + + returned_entities = [] + + for item in to_be_added: + source = item["source"].lower().replace(" ", "_") + source_type = item["source_type"].lower().replace(" ", "_") + relation = item["relationship"].lower().replace(" ", "_") + destination = item["destination"].lower().replace(" ", "_") + destination_type = item["destination_type"].lower().replace(" ", "_") + + returned_entities.append({"source": source, "relationship": relation, "target": destination}) + + # Create embeddings + source_embedding = await self.embedding_model.aembed(source) + dest_embedding = await self.embedding_model.aembed(destination) + + # Updated Cypher query to include node types and embeddings + cypher = f""" + MERGE (n:{source_type} {{name: $source_name, user_id: $user_id}}) + ON CREATE SET n.created = timestamp(), n.embedding = $source_embedding + ON MATCH SET n.embedding = $source_embedding + MERGE (m:{destination_type} {{name: $dest_name, user_id: $user_id}}) + ON CREATE SET m.created = timestamp(), m.embedding = $dest_embedding + ON MATCH SET m.embedding = $dest_embedding + MERGE (n)-[rel:{relation}]->(m) + ON CREATE SET rel.created = timestamp() + RETURN n, rel, m + """ + + params = { + "source_name": source, + "dest_name": destination, + "source_embedding": source_embedding, + "dest_embedding": dest_embedding, + "user_id": filters["user_id"], + } + + _ = self.graph.query(cypher, params=params) + + logger.info(f"Added {len(to_be_added)} new memories to the graph") + + return returned_entities + + async def _asearch(self, query, filters): + _tools = [SEARCH_TOOL] + if self.llm_provider in ["azure_openai_structured", "openai_structured"]: + _tools = [SEARCH_STRUCT_TOOL] + search_results = await self.llm.agenerate_response( + messages=[ + { + "role": "system", + "content": f"You are a smart assistant who understands the entities, their types, and relations in a given text. If user message contains self reference such as 'I', 'me', 'my' etc. then use {filters['user_id']} as the source node. Extract the entities.", + }, + {"role": "user", "content": query}, + ], + tools=_tools, + ) + + node_list = [] + relation_list = [] + + for item in search_results["tool_calls"]: + if item["name"] == "search": + try: + node_list.extend(item["arguments"]["nodes"]) + except Exception as e: + logger.error(f"Error in search tool: {e}") + + node_list = list(set(node_list)) + relation_list = list(set(relation_list)) + + node_list = [node.lower().replace(" ", "_") for node in node_list] + relation_list = [relation.lower().replace(" ", "_") for relation in relation_list] + + logger.debug(f"Node list for search query : {node_list}") + + result_relations = [] + + for node in node_list: + n_embedding = await self.embedding_model.aembed(node) + + cypher_query = """ + MATCH (n) + WHERE n.embedding IS NOT NULL AND n.user_id = $user_id + WITH n, + round(reduce(dot = 0.0, i IN range(0, size(n.embedding)-1) | dot + n.embedding[i] * $n_embedding[i]) / + (sqrt(reduce(l2 = 0.0, i IN range(0, size(n.embedding)-1) | l2 + n.embedding[i] * n.embedding[i])) * + sqrt(reduce(l2 = 0.0, i IN range(0, size($n_embedding)-1) | l2 + $n_embedding[i] * $n_embedding[i]))), 4) AS similarity + WHERE similarity >= $threshold + MATCH (n)-[r]->(m) + RETURN n.name AS source, elementId(n) AS source_id, type(r) AS relation, elementId(r) AS relation_id, m.name AS destination, elementId(m) AS destination_id, similarity + UNION + MATCH (n) + WHERE n.embedding IS NOT NULL AND n.user_id = $user_id + WITH n, + round(reduce(dot = 0.0, i IN range(0, size(n.embedding)-1) | dot + n.embedding[i] * $n_embedding[i]) / + (sqrt(reduce(l2 = 0.0, i IN range(0, size(n.embedding)-1) | l2 + n.embedding[i] * n.embedding[i])) * + sqrt(reduce(l2 = 0.0, i IN range(0, size($n_embedding)-1) | l2 + $n_embedding[i] * $n_embedding[i]))), 4) AS similarity + WHERE similarity >= $threshold + MATCH (m)-[r]->(n) + RETURN m.name AS source, elementId(m) AS source_id, type(r) AS relation, elementId(r) AS relation_id, n.name AS destination, elementId(n) AS destination_id, similarity + ORDER BY similarity DESC + """ + params = { + "n_embedding": n_embedding, + "threshold": self.threshold, + "user_id": filters["user_id"], + } + ans = self.graph.query(cypher_query, params=params) + result_relations.extend(ans) + + return result_relations \ No newline at end of file diff --git a/mem0/memory/main.py b/mem0/memory/main.py index 7d82c63238..68ead1a89d 100644 --- a/mem0/memory/main.py +++ b/mem0/memory/main.py @@ -1,3 +1,4 @@ +import asyncio import concurrent import hashlib import json @@ -581,3 +582,170 @@ def reset(self): def chat(self, query): raise NotImplementedError("Chat function not implemented yet.") + + async def _aadd_to_vector_store(self, messages, metadata, filters): + parsed_messages = parse_messages(messages) + + if self.custom_prompt: + system_prompt = self.custom_prompt + user_prompt = f"Input: {parsed_messages}" + else: + system_prompt, user_prompt = get_fact_retrieval_messages(parsed_messages) + + response = await self.llm.agenerate_response( + messages=[ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": user_prompt}, + ], + response_format={"type": "json_object"}, + ) + + try: + new_retrieved_facts = json.loads(response)["facts"] + except Exception as e: + logging.error(f"Error in new_retrieved_facts: {e}") + new_retrieved_facts = [] + + retrieved_old_memory = [] + new_message_embeddings = {} + for new_mem in new_retrieved_facts: + messages_embeddings = await self.embedding_model.aembed(new_mem) + new_message_embeddings[new_mem] = messages_embeddings + # todo: add async vector store call + existing_memories = self.vector_store.search( + query=messages_embeddings, + limit=5, + filters=filters, + ) + for mem in existing_memories: + retrieved_old_memory.append({"id": mem.id, "text": mem.payload["data"]}) + + logging.info(f"Total existing memories: {len(retrieved_old_memory)}") + + function_calling_prompt = get_update_memory_messages(retrieved_old_memory, new_retrieved_facts) + new_memories_with_actions = await self.llm.agenerate_response( + messages=[{"role": "user", "content": function_calling_prompt}], + response_format={"type": "json_object"}, + ) + new_memories_with_actions = json.loads(new_memories_with_actions) + + returned_memories = [] + try: + for resp in new_memories_with_actions["memory"]: + logging.info(resp) + try: + if resp["event"] == "ADD": + _ = self._create_memory(data=resp["text"], existing_embeddings=new_message_embeddings, metadata=metadata) + returned_memories.append( + { + "memory": resp["text"], + "event": resp["event"], + } + ) + elif resp["event"] == "UPDATE": + self._update_memory(memory_id=resp["id"], data=resp["text"], existing_embeddings=new_message_embeddings, metadata=metadata) + returned_memories.append( + { + "memory": resp["text"], + "event": resp["event"], + "previous_memory": resp["old_memory"], + } + ) + elif resp["event"] == "DELETE": + self._delete_memory(memory_id=resp["id"]) + returned_memories.append( + { + "memory": resp["text"], + "event": resp["event"], + } + ) + elif resp["event"] == "NONE": + logging.info("NOOP for Memory.") + except Exception as e: + logging.error(f"Error in new_memories_with_actions: {e}") + except Exception as e: + logging.error(f"Error in new_memories_with_actions: {e}") + + capture_event("mem0.add", self) + + return returned_memories + + async def _aadd_to_graph(self, messages, filters): + added_entities = [] + if self.version == "v1.1" and self.enable_graph: + if filters["user_id"]: + self.graph.user_id = filters["user_id"] + else: + self.graph.user_id = "USER" + data = "\n".join([msg["content"] for msg in messages if "content" in msg and msg["role"] != "system"]) + await self.graph.aadd(data, filters) + + return added_entities + + async def aadd( + self, + messages, + user_id=None, + agent_id=None, + run_id=None, + metadata=None, + filters=None, + prompt=None, + ): + """ + Create a new memory asynchronously. + + Args: + messages (str or List[Dict[str, str]]): Messages to store in the memory. + user_id (str, optional): ID of the user creating the memory. Defaults to None. + agent_id (str, optional): ID of the agent creating the memory. Defaults to None. + run_id (str, optional): ID of the run creating the memory. Defaults to None. + metadata (dict, optional): Metadata to store with the memory. Defaults to None. + filters (dict, optional): Filters to apply to the search. Defaults to None. + prompt (str, optional): Prompt to use for memory deduction. Defaults to None. + + Returns: + dict: A dictionary containing the result of the memory addition operation. + """ + if self.config.llm.provider not in ['openai', 'openai_structured']: + raise NotImplementedError(f"The async add is not implemented for {self.config.llm.provider}.") + + if self.config.embedder.provider not in ['openai']: + raise NotImplementedError(f"The async add is not implemented for {self.config.embedder.provider}.") + + if metadata is None: + metadata = {} + + filters = filters or {} + if user_id: + filters["user_id"] = metadata["user_id"] = user_id + if agent_id: + filters["agent_id"] = metadata["agent_id"] = agent_id + if run_id: + filters["run_id"] = metadata["run_id"] = run_id + + if not any(key in filters for key in ("user_id", "agent_id", "run_id")): + raise ValueError("One of the filters: user_id, agent_id or run_id is required!") + + if isinstance(messages, str): + messages = [{"role": "user", "content": messages}] + + vector_store_result, graph_result = await asyncio.gather( + self._aadd_to_vector_store(messages, metadata, filters), + self._aadd_to_graph(messages, filters) + ) + + if self.version == "v1.1": + return { + "results": vector_store_result, + "relations": graph_result, + } + else: + warnings.warn( + "The current add API output format is deprecated. " + "To use the latest format, set `api_version='v1.1'`. " + "The current format will be removed in mem0ai 1.1.0 and later versions.", + category=DeprecationWarning, + stacklevel=2, + ) + return {"message": "ok"} \ No newline at end of file From 6b0ce98d2635fe0091a99317aae09e9ece806519 Mon Sep 17 00:00:00 2001 From: Mayank Solanki Date: Mon, 30 Sep 2024 18:51:27 +0530 Subject: [PATCH 2/4] tests edited --- tests/embeddings/test_openai_embeddings.py | 22 +++++++++++++++------- 1 file changed, 15 insertions(+), 7 deletions(-) diff --git a/tests/embeddings/test_openai_embeddings.py b/tests/embeddings/test_openai_embeddings.py index 113a8e64c9..9ddc5b2e6e 100644 --- a/tests/embeddings/test_openai_embeddings.py +++ b/tests/embeddings/test_openai_embeddings.py @@ -1,7 +1,9 @@ -import pytest from unittest.mock import Mock, patch -from mem0.embeddings.openai import OpenAIEmbedding + +import pytest + from mem0.configs.embeddings.base import BaseEmbedderConfig +from mem0.embeddings.openai import OpenAIEmbedding @pytest.fixture @@ -11,8 +13,14 @@ def mock_openai_client(): mock_openai.return_value = mock_client yield mock_client +@pytest.fixture +def mock_openai_async_client(): + with patch("mem0.embeddings.openai.AsyncOpenAI") as mock_async_openai: + mock_async_client = Mock() + mock_async_openai.return_value = mock_async_client + yield mock_async_client -def test_embed_default_model(mock_openai_client): +def test_embed_default_model(mock_openai_client, mock_openai_async_client): config = BaseEmbedderConfig() embedder = OpenAIEmbedding(config) mock_response = Mock() @@ -25,7 +33,7 @@ def test_embed_default_model(mock_openai_client): assert result == [0.1, 0.2, 0.3] -def test_embed_custom_model(mock_openai_client): +def test_embed_custom_model(mock_openai_client, mock_openai_async_client): config = BaseEmbedderConfig(model="text-embedding-2-medium", embedding_dims=1024) embedder = OpenAIEmbedding(config) mock_response = Mock() @@ -40,7 +48,7 @@ def test_embed_custom_model(mock_openai_client): assert result == [0.4, 0.5, 0.6] -def test_embed_removes_newlines(mock_openai_client): +def test_embed_removes_newlines(mock_openai_client, mock_openai_async_client): config = BaseEmbedderConfig() embedder = OpenAIEmbedding(config) mock_response = Mock() @@ -53,7 +61,7 @@ def test_embed_removes_newlines(mock_openai_client): assert result == [0.7, 0.8, 0.9] -def test_embed_without_api_key_env_var(mock_openai_client): +def test_embed_without_api_key_env_var(mock_openai_client, mock_openai_async_client): config = BaseEmbedderConfig(api_key="test_key") embedder = OpenAIEmbedding(config) mock_response = Mock() @@ -68,7 +76,7 @@ def test_embed_without_api_key_env_var(mock_openai_client): assert result == [1.0, 1.1, 1.2] -def test_embed_uses_environment_api_key(mock_openai_client, monkeypatch): +def test_embed_uses_environment_api_key(mock_openai_client, monkeypatch, mock_openai_async_client): monkeypatch.setenv("OPENAI_API_KEY", "env_key") config = BaseEmbedderConfig() embedder = OpenAIEmbedding(config) From 05cac06b83f397a955fda4234aec0e8e814f222f Mon Sep 17 00:00:00 2001 From: Mayank Solanki Date: Mon, 30 Sep 2024 21:30:53 +0530 Subject: [PATCH 3/4] minor --- tests/llms/test_openai.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/tests/llms/test_openai.py b/tests/llms/test_openai.py index 9e62c6f2cb..9d0125a631 100644 --- a/tests/llms/test_openai.py +++ b/tests/llms/test_openai.py @@ -1,5 +1,6 @@ -from unittest.mock import Mock, patch import os +from unittest.mock import Mock, patch + import pytest from mem0.configs.llms.base import BaseLlmConfig @@ -13,6 +14,13 @@ def mock_openai_client(): mock_openai.return_value = mock_client yield mock_client +@pytest.fixture +def mock_openai_async_client(): + with patch("mem0.llms.openai.AsyncOpenAI") as mock_async_openai: + mock_client = Mock() + mock_async_openai.return_value = mock_client + yield mock_client + def test_openai_llm_base_url(): # case1: default config: with openai official base url @@ -38,7 +46,7 @@ def test_openai_llm_base_url(): assert str(llm.client.base_url) == config_base_url + "/" -def test_generate_response_without_tools(mock_openai_client): +def test_generate_response_without_tools(mock_openai_client, mock_openai_async_client): config = BaseLlmConfig(model="gpt-4o", temperature=0.7, max_tokens=100, top_p=1.0) llm = OpenAILLM(config) messages = [ @@ -58,7 +66,7 @@ def test_generate_response_without_tools(mock_openai_client): assert response == "I'm doing well, thank you for asking!" -def test_generate_response_with_tools(mock_openai_client): +def test_generate_response_with_tools(mock_openai_client, mock_openai_async_client): config = BaseLlmConfig(model="gpt-4o", temperature=0.7, max_tokens=100, top_p=1.0) llm = OpenAILLM(config) messages = [ From 3102c099cefc1c278f98922dce5878b1a524740d Mon Sep 17 00:00:00 2001 From: Mayank Solanki Date: Tue, 8 Oct 2024 23:03:02 +0530 Subject: [PATCH 4/4] executor added --- mem0/embeddings/base.py | 14 +++++++++++++ mem0/llms/base.py | 14 +++++++++++++ mem0/utils/concurrency.py | 43 +++++++++++++++++++++++++++++++++++++++ 3 files changed, 71 insertions(+) create mode 100644 mem0/utils/concurrency.py diff --git a/mem0/embeddings/base.py b/mem0/embeddings/base.py index c5d12b6e23..43cea14b6f 100644 --- a/mem0/embeddings/base.py +++ b/mem0/embeddings/base.py @@ -2,6 +2,7 @@ from typing import Optional from mem0.configs.embeddings.base import BaseEmbedderConfig +from mem0.utils.concurrency import run_in_executor class EmbeddingBase(ABC): @@ -16,6 +17,19 @@ def __init__(self, config: Optional[BaseEmbedderConfig] = None): self.config = BaseEmbedderConfig() else: self.config = config + + async def aembed(self, text): + """Async version of the embed method. + + The default implementation delegates to the synchronous generate_response method using + `run_in_executor`. Subclasses that need to provide a true async implementation + should override this method to reduce the overhead of using `run_in_executor`. + """ + return await run_in_executor( + None, + self.embed, + text + ) @abstractmethod def embed(self, text): diff --git a/mem0/llms/base.py b/mem0/llms/base.py index a1ed3c794f..b025d116f8 100644 --- a/mem0/llms/base.py +++ b/mem0/llms/base.py @@ -2,6 +2,7 @@ from typing import Optional from mem0.configs.llms.base import BaseLlmConfig +from mem0.utils.concurrency import run_in_executor class LLMBase(ABC): @@ -15,7 +16,20 @@ def __init__(self, config: Optional[BaseLlmConfig] = None): self.config = BaseLlmConfig() else: self.config = config + + async def agenerate_response(self, message): + """Async version of the generate_response method. + The default implementation delegates to the synchronous generate_response method using + `run_in_executor`. Subclasses that need to provide a true async implementation + should override this method to reduce the overhead of using `run_in_executor`. + """ + return await run_in_executor( + None, + self.generate_response, + message + ) + @abstractmethod def generate_response(self, messages): """ diff --git a/mem0/utils/concurrency.py b/mem0/utils/concurrency.py new file mode 100644 index 0000000000..d3e838793e --- /dev/null +++ b/mem0/utils/concurrency.py @@ -0,0 +1,43 @@ +import asyncio +from contextvars import copy_context +from functools import partial +from typing import Callable, TypeVar, cast + +from typing_extensions import ParamSpec + +P = ParamSpec("P") +T = TypeVar("T") + +async def run_in_executor( + func: Callable[P, T], + *args: P.args, + **kwargs: P.kwargs, +) -> T: + """Run a function in an executor. + + Args: + func (Callable[P, Output]): The function. + *args (Any): The positional arguments to the function. + **kwargs (Any): The keyword arguments to the function. + + Returns: + Output: The output of the function. + + Raises: + RuntimeError: If the function raises a StopIteration. + """ + + def wrapper() -> T: + try: + return func(*args, **kwargs) + except StopIteration as exc: + # StopIteration can't be set on an asyncio.Future + # it raises a TypeError and leaves the Future pending forever + # so we need to convert it to a RuntimeError + raise RuntimeError from exc + + #TODO: Rethink of executor. + return await asyncio.get_running_loop().run_in_executor( + None, + cast(Callable[..., T], partial(copy_context().run, wrapper)), + ) \ No newline at end of file