Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add config option for vertex embedding tasks #2266

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions docs/components/embedders/config.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,9 @@ Here's a comprehensive list of all parameters that can be used across different
| `azure_kwargs` | Key-Value arguments for the AzureOpenAI embedding model |
| `openai_base_url` | Base URL for OpenAI API | OpenAI |
| `vertex_credentials_json` | Path to the Google Cloud credentials JSON file for VertexAI |
| `memory_add_embedding_type` | The type of embedding to use for the add memory action | VertexAI |
| `memory_update_embedding_type` | The type of embedding to use for the update memory action | VertexAI |
| `memory_search_embedding_type` | The type of embedding to use for the search memory action | VertexAI |


## Supported Embedding Models
Expand Down
17 changes: 15 additions & 2 deletions docs/components/embedders/models/vertexai.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,25 @@ config = {
"embedder": {
"provider": "vertexai",
"config": {
"model": "text-embedding-004"
"model": "text-embedding-004",
"memory_add_embedding_type": "RETRIEVAL_DOCUMENT",
"memory_update_embedding_type": "RETRIEVAL_DOCUMENT",
"memory_search_embedding_type": "RETRIEVAL_QUERY"
}
}
}

m = Memory.from_config(config)
m.add("I'm visiting Paris", user_id="john")
```

The embedding types can be one of the following:

- `SEMANTIC_SIMILARITY`
- `CLASSIFICATION`
- `CLUSTERING`
- `RETRIEVAL_DOCUMENT`
- `RETRIEVAL_QUERY`
- `QUESTION_ANSWERING`
- `FACT_VERIFICATION`
- `CODE_RETRIEVAL_QUERY`

Check out the [Vertex AI documentation](https://cloud.google.com/vertex-ai/generative-ai/docs/embeddings/task-types#supported_task_types) for more information.

### Config

Here are the parameters available for configuring the Vertex AI embedder:
Expand All @@ -34,3 +44,6 @@ Here are the parameters available for configuring the Vertex AI embedder:
| `model` | The name of the Vertex AI embedding model to use | `text-embedding-004` |
| `vertex_credentials_json` | Path to the Google Cloud credentials JSON file | `None` |
| `embedding_dims` | Dimensions of the embedding model | `256` |
| `memory_add_embedding_type` | The type of embedding to use for the add memory action | `RETRIEVAL_DOCUMENT` |
| `memory_update_embedding_type` | The type of embedding to use for the update memory action | `RETRIEVAL_DOCUMENT` |
| `memory_search_embedding_type` | The type of embedding to use for the search memory action | `RETRIEVAL_QUERY` |
14 changes: 14 additions & 0 deletions mem0/configs/embeddings/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@ def __init__(
http_client_proxies: Optional[Union[Dict, str]] = None,
# VertexAI specific
vertex_credentials_json: Optional[str] = None,
memory_add_embedding_type: Optional[str] = None,
memory_update_embedding_type: Optional[str] = None,
memory_search_embedding_type: Optional[str] = None,
):
"""
Initializes a configuration class instance for the Embeddings.
Expand All @@ -47,6 +50,14 @@ def __init__(
:type azure_kwargs: Optional[Dict[str, Any]], defaults a dict inside init
:param http_client_proxies: The proxy server settings used to create self.http_client, defaults to None
:type http_client_proxies: Optional[Dict | str], optional
:param vertex_credentials_json: The path to the Vertex AI credentials JSON file, defaults to None
:type vertex_credentials_json: Optional[str], optional
:param memory_add_embedding_type: The type of embedding to use for the add memory action, defaults to None
:type memory_add_embedding_type: Optional[str], optional
:param memory_update_embedding_type: The type of embedding to use for the update memory action, defaults to None
:type memory_update_embedding_type: Optional[str], optional
:param memory_search_embedding_type: The type of embedding to use for the search memory action, defaults to None
:type memory_search_embedding_type: Optional[str], optional
"""

self.model = model
Expand All @@ -68,3 +79,6 @@ def __init__(

# VertexAI specific
self.vertex_credentials_json = vertex_credentials_json
self.memory_add_embedding_type = memory_add_embedding_type
self.memory_update_embedding_type = memory_update_embedding_type
self.memory_search_embedding_type = memory_search_embedding_type
6 changes: 3 additions & 3 deletions mem0/embeddings/azure_openai.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import os
from typing import Optional
from typing import Literal, Optional

from openai import AzureOpenAI

Expand All @@ -26,13 +26,13 @@ def __init__(self, config: Optional[BaseEmbedderConfig] = None):
default_headers=default_headers,
)

def embed(self, text):
def embed(self, text, memory_action:Optional[Literal["add", "search", "update"]] = None):
"""
Get the embedding for the given text using OpenAI.

Args:
text (str): The text to embed.

memory_action (optional): The type of embedding to use. Must be one of "add", "search", or "update". Defaults to None.
Returns:
list: The embedding vector.
"""
Expand Down
6 changes: 3 additions & 3 deletions mem0/embeddings/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import Optional
from typing import Literal, Optional

from mem0.configs.embeddings.base import BaseEmbedderConfig

Expand All @@ -18,13 +18,13 @@ def __init__(self, config: Optional[BaseEmbedderConfig] = None):
self.config = config

@abstractmethod
def embed(self, text):
def embed(self, text, memory_action:Optional[Literal["add", "search", "update"]]):
"""
Get the embedding for the given text.

Args:
text (str): The text to embed.

memory_action (optional): The type of embedding to use. Must be one of "add", "search", or "update". Defaults to None.
Returns:
list: The embedding vector.
"""
Expand Down
5 changes: 3 additions & 2 deletions mem0/embeddings/gemini.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import os
from typing import Optional
from typing import Literal, Optional

import google.generativeai as genai

Expand All @@ -18,11 +18,12 @@ def __init__(self, config: Optional[BaseEmbedderConfig] = None):

genai.configure(api_key=api_key)

def embed(self, text):
def embed(self, text, memory_action:Optional[Literal["add", "search", "update"]] = None):
"""
Get the embedding for the given text using Google Generative AI.
Args:
text (str): The text to embed.
memory_action (optional): The type of embedding to use. Must be one of "add", "search", or "update". Defaults to None.
Returns:
list: The embedding vector.
"""
Expand Down
6 changes: 3 additions & 3 deletions mem0/embeddings/huggingface.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional
from typing import Literal, Optional

from sentence_transformers import SentenceTransformer

Expand All @@ -16,13 +16,13 @@ def __init__(self, config: Optional[BaseEmbedderConfig] = None):

self.config.embedding_dims = self.config.embedding_dims or self.model.get_sentence_embedding_dimension()

def embed(self, text):
def embed(self, text, memory_action:Optional[Literal["add", "search", "update"]] = None):
"""
Get the embedding for the given text using Hugging Face.

Args:
text (str): The text to embed.

memory_action (optional): The type of embedding to use. Must be one of "add", "search", or "update". Defaults to None.
Returns:
list: The embedding vector.
"""
Expand Down
6 changes: 3 additions & 3 deletions mem0/embeddings/ollama.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import subprocess
import sys
from typing import Optional
from typing import Literal, Optional

from mem0.configs.embeddings.base import BaseEmbedderConfig
from mem0.embeddings.base import EmbeddingBase
Expand Down Expand Up @@ -39,13 +39,13 @@ def _ensure_model_exists(self):
if not any(model.get("name") == self.config.model for model in local_models):
self.client.pull(self.config.model)

def embed(self, text):
def embed(self, text, memory_action:Optional[Literal["add", "search", "update"]] = None):
"""
Get the embedding for the given text using Ollama.

Args:
text (str): The text to embed.

memory_action (optional): The type of embedding to use. Must be one of "add", "search", or "update". Defaults to None.
Returns:
list: The embedding vector.
"""
Expand Down
6 changes: 3 additions & 3 deletions mem0/embeddings/openai.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import os
from typing import Optional
from typing import Literal, Optional

from openai import OpenAI

Expand All @@ -18,13 +18,13 @@ def __init__(self, config: Optional[BaseEmbedderConfig] = None):
base_url = self.config.openai_base_url or os.getenv("OPENAI_API_BASE")
self.client = OpenAI(api_key=api_key, base_url=base_url)

def embed(self, text):
def embed(self, text, memory_action:Optional[Literal["add", "search", "update"]] = None):
"""
Get the embedding for the given text using OpenAI.

Args:
text (str): The text to embed.

memory_action (optional): The type of embedding to use. Must be one of "add", "search", or "update". Defaults to None.
Returns:
list: The embedding vector.
"""
Expand Down
6 changes: 3 additions & 3 deletions mem0/embeddings/together.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import os
from typing import Optional
from typing import Literal, Optional

from together import Together

Expand All @@ -17,13 +17,13 @@ def __init__(self, config: Optional[BaseEmbedderConfig] = None):
self.config.embedding_dims = self.config.embedding_dims or 768
self.client = Together(api_key=api_key)

def embed(self, text):
def embed(self, text, memory_action:Optional[Literal["add", "search", "update"]] = None):
"""
Get the embedding for the given text using OpenAI.

Args:
text (str): The text to embed.

memory_action (optional): The type of embedding to use. Must be one of "add", "search", or "update". Defaults to None.
Returns:
list: The embedding vector.
"""
Expand Down
26 changes: 20 additions & 6 deletions mem0/embeddings/vertexai.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import os
from typing import Optional
from typing import Literal, Optional

from vertexai.language_models import TextEmbeddingModel
from vertexai.language_models import TextEmbeddingInput, TextEmbeddingModel

from mem0.configs.embeddings.base import BaseEmbedderConfig
from mem0.embeddings.base import EmbeddingBase
Expand All @@ -13,7 +13,13 @@ def __init__(self, config: Optional[BaseEmbedderConfig] = None):

self.config.model = self.config.model or "text-embedding-004"
self.config.embedding_dims = self.config.embedding_dims or 256


self.embedding_types = {
"add": self.config.memory_add_embedding_type or "RETRIEVAL_DOCUMENT",
"update": self.config.memory_update_embedding_type or "RETRIEVAL_DOCUMENT",
"search": self.config.memory_search_embedding_type or "RETRIEVAL_QUERY"
}

credentials_path = self.config.vertex_credentials_json

if credentials_path:
Expand All @@ -25,16 +31,24 @@ def __init__(self, config: Optional[BaseEmbedderConfig] = None):

self.model = TextEmbeddingModel.from_pretrained(self.config.model)

def embed(self, text):
def embed(self, text, memory_action:Optional[Literal["add", "search", "update"]] = None):
"""
Get the embedding for the given text using Vertex AI.

Args:
text (str): The text to embed.

memory_action (optional): The type of embedding to use. Must be one of "add", "search", or "update". Defaults to None.
Returns:
list: The embedding vector.
"""
embeddings = self.model.get_embeddings(texts=[text], output_dimensionality=self.config.embedding_dims)
embedding_type = "SEMANTIC_SIMILARITY"
if memory_action is not None:
if memory_action not in self.embedding_types:
raise ValueError(f"Invalid memory action: {memory_action}")

embedding_type = self.embedding_types[memory_action]

text_input = TextEmbeddingInput(text=text, task_type=embedding_type)
embeddings = self.model.get_embeddings(texts=[text_input], output_dimensionality=self.config.embedding_dims)

return embeddings[0].values
13 changes: 7 additions & 6 deletions mem0/memory/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

import pytz
from pydantic import ValidationError
from mem0.memory.utils import parse_vision_messages

from mem0.configs.base import MemoryConfig, MemoryItem
from mem0.configs.prompts import get_update_memory_messages
from mem0.memory.base import MemoryBase
Expand All @@ -19,6 +19,7 @@
from mem0.memory.utils import (
get_fact_retrieval_messages,
parse_messages,
parse_vision_messages,
remove_code_blocks,
)
from mem0.utils.factory import EmbedderFactory, LlmFactory, VectorStoreFactory
Expand Down Expand Up @@ -167,7 +168,7 @@ def _add_to_vector_store(self, messages, metadata, filters):
retrieved_old_memory = []
new_message_embeddings = {}
for new_mem in new_retrieved_facts:
messages_embeddings = self.embedding_model.embed(new_mem)
messages_embeddings = self.embedding_model.embed(new_mem, "add")
new_message_embeddings[new_mem] = messages_embeddings
existing_memories = self.vector_store.search(
query=messages_embeddings,
Expand Down Expand Up @@ -446,7 +447,7 @@ def search(self, query, user_id=None, agent_id=None, run_id=None, limit=100, fil
return original_memories

def _search_vector_store(self, query, filters, limit):
embeddings = self.embedding_model.embed(query)
embeddings = self.embedding_model.embed(query, "search")
memories = self.vector_store.search(query=embeddings, limit=limit, filters=filters)

excluded_keys = {
Expand Down Expand Up @@ -494,7 +495,7 @@ def update(self, memory_id, data):
"""
capture_event("mem0.update", self, {"memory_id": memory_id})

existing_embeddings = {data: self.embedding_model.embed(data)}
existing_embeddings = {data: self.embedding_model.embed(data, "update")}

self._update_memory(memory_id, data, existing_embeddings)
return {"message": "Memory updated successfully!"}
Expand Down Expand Up @@ -562,7 +563,7 @@ def _create_memory(self, data, existing_embeddings, metadata=None):
if data in existing_embeddings:
embeddings = existing_embeddings[data]
else:
embeddings = self.embedding_model.embed(data)
embeddings = self.embedding_model.embed(data, "add")
memory_id = str(uuid.uuid4())
metadata = metadata or {}
metadata["data"] = data
Expand Down Expand Up @@ -603,7 +604,7 @@ def _update_memory(self, memory_id, data, existing_embeddings, metadata=None):
if data in existing_embeddings:
embeddings = existing_embeddings[data]
else:
embeddings = self.embedding_model.embed(data)
embeddings = self.embedding_model.embed(data, "update")
self.vector_store.update(
vector_id=memory_id,
vector=embeddings,
Expand Down
Loading