diff --git a/mem0/configs/embeddings/base.py b/mem0/configs/embeddings/base.py index a3d989ee1c..2219907da0 100644 --- a/mem0/configs/embeddings/base.py +++ b/mem0/configs/embeddings/base.py @@ -27,6 +27,8 @@ def __init__( http_client_proxies: Optional[Union[Dict, str]] = None, # VertexAI specific vertex_credentials_json: Optional[str] = None, + # Jina specific + jina_base_url: Optional[str] = None, ): """ Initializes a configuration class instance for the Embeddings. @@ -47,6 +49,8 @@ def __init__( :type azure_kwargs: Optional[Dict[str, Any]], defaults a dict inside init :param http_client_proxies: The proxy server settings used to create self.http_client, defaults to None :type http_client_proxies: Optional[Dict | str], optional + :param jina_base_url: Base URL for the Jina API, defaults to None + :type jina_base_url: Optional[str], optional """ self.model = model @@ -68,3 +72,6 @@ def __init__( # VertexAI specific self.vertex_credentials_json = vertex_credentials_json + + # Jina specific + self.jina_base_url = jina_base_url diff --git a/mem0/embeddings/jina.py b/mem0/embeddings/jina.py new file mode 100644 index 0000000000..8395c8a533 --- /dev/null +++ b/mem0/embeddings/jina.py @@ -0,0 +1,52 @@ +import os +from typing import Optional + +import requests + +from mem0.configs.embeddings.base import BaseEmbedderConfig +from mem0.embeddings.base import EmbeddingBase + + +class JinaEmbedding(EmbeddingBase): + def __init__(self, config: Optional[BaseEmbedderConfig] = None): + super().__init__(config) + + self.config.model = self.config.model or "jina-embeddings-v3" + self.config.embedding_dims = self.config.embedding_dims or 768 + + api_key = self.config.api_key or os.getenv("JINA_API_KEY") + if not api_key: + raise ValueError("Jina API key is required. Set it in config or JINA_API_KEY environment variable.") + + base_url = self.config.jina_base_url or os.getenv("JINA_API_BASE", "https://api.jina.ai") + + self.base_url = f"{base_url}/v1/embeddings" + self.headers = { + "Content-Type": "application/json", + "Authorization": f"Bearer {api_key}" + } + + def embed(self, text): + """ + Get the embedding for the given text using Jina AI. + + Args: + text (str): The text to embed. + + Returns: + list: The embedding vector. + """ + text = text.replace("\n", " ") + + data = { + "model": self.config.model, + "input": [{"text": text}] + } + + if self.config.model_kwargs: + data.update(self.config.model_kwargs) + + response = requests.post(self.base_url, headers=self.headers, json=data) + response.raise_for_status() + + return response.json()["data"][0]["embedding"] \ No newline at end of file diff --git a/mem0/utils/factory.py b/mem0/utils/factory.py index 0489bfa60b..82f54f7260 100644 --- a/mem0/utils/factory.py +++ b/mem0/utils/factory.py @@ -45,6 +45,7 @@ class EmbedderFactory: "gemini": "mem0.embeddings.gemini.GoogleGenAIEmbedding", "vertexai": "mem0.embeddings.vertexai.VertexAIEmbedding", "together": "mem0.embeddings.together.TogetherEmbedding", + "jina": "mem0.embeddings.jina.JinaEmbedding", } @classmethod diff --git a/tests/embeddings/test_jina_embeddings.py b/tests/embeddings/test_jina_embeddings.py new file mode 100644 index 0000000000..c2d22961c8 --- /dev/null +++ b/tests/embeddings/test_jina_embeddings.py @@ -0,0 +1,170 @@ +from unittest.mock import Mock, patch + +import pytest + +from mem0.configs.embeddings.base import BaseEmbedderConfig +from mem0.embeddings.jina import JinaEmbedding + + +@pytest.fixture +def mock_jina_client(monkeypatch): + # Clear any existing env var + monkeypatch.delenv("JINA_API_KEY", raising=False) + with patch("mem0.embeddings.jina.requests") as mock_req: + yield mock_req + + +def test_embed_default_model(mock_jina_client, monkeypatch): + monkeypatch.setenv("JINA_API_KEY", "default_key") # Set a default key + config = BaseEmbedderConfig() + embedder = JinaEmbedding(config) + mock_response = Mock() + mock_response.json.return_value = {"data": [{"embedding": [0.1, 0.2, 0.3]}]} + mock_jina_client.post.return_value = mock_response + + result = embedder.embed("Test embedding") + + mock_jina_client.post.assert_called_once_with( + "https://api.jina.ai/v1/embeddings", + headers={ + "Content-Type": "application/json", + "Authorization": "Bearer default_key" # Use the default key + }, + json={ + "model": "jina-embeddings-v3", + "input": [{"text": "Test embedding"}] + } + ) + assert result == [0.1, 0.2, 0.3] + + +def test_embed_custom_model(mock_jina_client, monkeypatch): + monkeypatch.setenv("JINA_API_KEY", "test_key") + config = BaseEmbedderConfig(model="jina-embeddings-v3", embedding_dims=1024) + embedder = JinaEmbedding(config) + + mock_response = Mock() + mock_response.json.return_value = {"data": [{"embedding": [0.4, 0.5, 0.6]}]} + mock_jina_client.post.return_value = mock_response + + result = embedder.embed("Test embedding") + + mock_jina_client.post.assert_called_once_with( + "https://api.jina.ai/v1/embeddings", + headers={ + "Content-Type": "application/json", + "Authorization": "Bearer test_key" + }, + json={ + "model": "jina-embeddings-v3", + "input": [{"text": "Test embedding"}] + } + ) + assert result == [0.4, 0.5, 0.6] + + +def test_embed_removes_newlines(mock_jina_client, monkeypatch): + monkeypatch.setenv("JINA_API_KEY", "test_key") + config = BaseEmbedderConfig() + embedder = JinaEmbedding(config) + + mock_response = Mock() + mock_response.json.return_value = {"data": [{"embedding": [0.7, 0.8, 0.9]}]} + mock_jina_client.post.return_value = mock_response + + result = embedder.embed("Hello\nworld") + + mock_jina_client.post.assert_called_once_with( + "https://api.jina.ai/v1/embeddings", + headers={ + "Content-Type": "application/json", + "Authorization": "Bearer test_key" + }, + json={ + "model": "jina-embeddings-v3", + "input": [{"text": "Hello world"}] + } + ) + assert result == [0.7, 0.8, 0.9] + + +def test_embed_with_model_kwargs(mock_jina_client, monkeypatch): + monkeypatch.setenv("JINA_API_KEY", "test_key") + config = BaseEmbedderConfig(model_kwargs={"dimensions": 512, "normalized": True}) + embedder = JinaEmbedding(config) + + mock_response = Mock() + mock_response.json.return_value = {"data": [{"embedding": [1.0, 1.1, 1.2]}]} + mock_jina_client.post.return_value = mock_response + + result = embedder.embed("Test with kwargs") + + mock_jina_client.post.assert_called_once_with( + "https://api.jina.ai/v1/embeddings", + headers={ + "Content-Type": "application/json", + "Authorization": "Bearer test_key" + }, + json={ + "model": "jina-embeddings-v3", + "input": [{"text": "Test with kwargs"}], + "dimensions": 512, + "normalized": True + } + ) + assert result == [1.0, 1.1, 1.2] + + +def test_embed_without_api_key_env_var(mock_jina_client): + config = BaseEmbedderConfig(api_key="test_key") + embedder = JinaEmbedding(config) + + mock_response = Mock() + mock_response.json.return_value = {"data": [{"embedding": [1.3, 1.4, 1.5]}]} + mock_jina_client.post.return_value = mock_response + + result = embedder.embed("Testing API key") + + mock_jina_client.post.assert_called_once_with( + "https://api.jina.ai/v1/embeddings", + headers={ + "Content-Type": "application/json", + "Authorization": "Bearer test_key" + }, + json={ + "model": "jina-embeddings-v3", + "input": [{"text": "Testing API key"}] + } + ) + assert result == [1.3, 1.4, 1.5] + + +def test_embed_uses_environment_api_key(mock_jina_client, monkeypatch): + monkeypatch.setenv("JINA_API_KEY", "env_key") + config = BaseEmbedderConfig() + embedder = JinaEmbedding(config) + + mock_response = Mock() + mock_response.json.return_value = {"data": [{"embedding": [1.6, 1.7, 1.8]}]} + mock_jina_client.post.return_value = mock_response + + result = embedder.embed("Environment key test") + + mock_jina_client.post.assert_called_once_with( + "https://api.jina.ai/v1/embeddings", + headers={ + "Content-Type": "application/json", + "Authorization": "Bearer env_key" + }, + json={ + "model": "jina-embeddings-v3", + "input": [{"text": "Environment key test"}] + } + ) + assert result == [1.6, 1.7, 1.8] + + +def test_raises_error_without_api_key(): + config = BaseEmbedderConfig() + with pytest.raises(ValueError, match="Jina API key is required"): + JinaEmbedding(config) \ No newline at end of file