Commit

langchain[minor]: allow CacheBackedEmbeddings to cache queries (langchain-ai#20073)

Add optional caching of queries to cache backed embeddings
jokester authored May 13, 2024
1 parent a156aac commit b53548d
Showing 3 changed files with 92 additions and 21 deletions.
3 changes: 2 additions & 1 deletion docs/docs/how_to/caching_embeddings.ipynb
@@ -18,11 +18,12 @@
"- document_embedding_cache: Any [`ByteStore`](/docs/integrations/stores/) for caching document embeddings.\n",
"- batch_size: (optional, defaults to `None`) The number of documents to embed between store updates.\n",
"- namespace: (optional, defaults to `\"\"`) The namespace to use for document cache. This namespace is used to avoid collisions with other caches. For example, set it to the name of the embedding model used.\n",
"- query_embedding_cache: (optional, defaults to `None` or not caching) A [`ByteStore`](/docs/integrations/stores/) for caching query embeddings, or `True` to use the same store as `document_embedding_cache`.\n",
"\n",
"**Attention**:\n",
"\n",
"- Be sure to set the `namespace` parameter to avoid collisions of the same text embedded using different embeddings models.\n",
"- Currently `CacheBackedEmbeddings` does not cache embedding created with `embed_query()` `aembed_query()` methods."
"- `CacheBackedEmbeddings` does not cache query embeddings by default. To enable query caching, one need to specify a `query_embedding_cache`."
]
},
{
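As a quick illustration of the documented option, here is a minimal sketch of enabling query caching from the notebook's setup; the `OpenAIEmbeddings` model and the `./cache/` path are assumptions for the example, not requirements of this change:

from langchain.embeddings import CacheBackedEmbeddings
from langchain.storage import LocalFileStore
from langchain_openai import OpenAIEmbeddings  # any Embeddings implementation works

underlying = OpenAIEmbeddings()
store = LocalFileStore("./cache/")  # example on-disk ByteStore

cached_embedder = CacheBackedEmbeddings.from_bytes_store(
    underlying,
    document_embedding_cache=store,
    namespace=underlying.model,  # avoid collisions between embedding models
    query_embedding_cache=True,  # new: reuse the document store for queries
)

# With query caching enabled, the second call is served from the store.
vec1 = cached_embedder.embed_query("What is LangChain?")
vec2 = cached_embedder.embed_query("What is LangChain?")
assert vec1 == vec2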
75 changes: 55 additions & 20 deletions libs/langchain/langchain/embeddings/cache.py
@@ -6,6 +6,7 @@
The text is hashed and the hash is used as the key in the cache.
"""

from __future__ import annotations

import hashlib
@@ -59,6 +60,9 @@ class CacheBackedEmbeddings(Embeddings):
If need be, the interface can be extended to accept other implementations
of the value serializer and deserializer, as well as the key encoder.
Note that by default only document embeddings are cached. To cache query
embeddings too, pass in a query_embedding_store to the constructor.
Examples:
.. code-block:: python
@@ -87,16 +91,20 @@ def __init__(
document_embedding_store: BaseStore[str, List[float]],
*,
batch_size: Optional[int] = None,
query_embedding_store: Optional[BaseStore[str, List[float]]] = None,
) -> None:
"""Initialize the embedder.
Args:
underlying_embeddings: the embedder to use for computing embeddings.
document_embedding_store: The store to use for caching document embeddings.
batch_size: The number of documents to embed between store updates.
query_embedding_store: The store to use for caching query embeddings.
If None, query embeddings are not cached.
"""
super().__init__()
self.document_embedding_store = document_embedding_store
self.query_embedding_store = query_embedding_store
self.underlying_embeddings = underlying_embeddings
self.batch_size = batch_size

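For context on the constructor path (as opposed to `from_bytes_store`), the stores passed here must already map strings to vectors, and the raw query text is used as the cache key. A minimal sketch under those assumptions, with a toy `FakeEmbeddings` defined only for illustration:

from typing import List

from langchain.embeddings import CacheBackedEmbeddings
from langchain.storage import InMemoryStore
from langchain_core.embeddings import Embeddings


class FakeEmbeddings(Embeddings):
    """Toy embedder for illustration: the vector is just the text length."""

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        return [[float(len(t))] for t in texts]

    def embed_query(self, text: str) -> List[float]:
        return [float(len(text))]


doc_store = InMemoryStore()    # caches embed_documents() results
query_store = InMemoryStore()  # caches embed_query() results; omit to leave queries uncached

embedder = CacheBackedEmbeddings(
    FakeEmbeddings(),
    doc_store,
    batch_size=32,
    query_embedding_store=query_store,  # keyword-only; defaults to None
)

assert embedder.embed_query("hello") == [5.0]
assert list(query_store.yield_keys()) == ["hello"]  # raw text is the key here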
@@ -173,42 +181,48 @@ async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
def embed_query(self, text: str) -> List[float]:
"""Embed query text.
This method does not support caching at the moment.
Support for caching queries is easily to implement, but might make
sense to hold off to see the most common patterns.
If the cache has an eviction policy, we may need to be a bit more careful
about sharing the cache between documents and queries. Generally,
one is OK evicting query caches, but document caches should be kept.
By default, this method does not cache queries. To enable caching, specify a
`query_embedding_cache` when initializing the embedder.
Args:
text: The text to embed.
Returns:
The embedding for the given text.
"""
return self.underlying_embeddings.embed_query(text)
if not self.query_embedding_store:
return self.underlying_embeddings.embed_query(text)

(cached,) = self.query_embedding_store.mget([text])
if cached is not None:
return cached

vector = self.underlying_embeddings.embed_query(text)
self.query_embedding_store.mset([(text, vector)])
return vector

async def aembed_query(self, text: str) -> List[float]:
"""Embed query text.
This method does not support caching at the moment.
Support for caching queries is easily to implement, but might make
sense to hold off to see the most common patterns.
If the cache has an eviction policy, we may need to be a bit more careful
about sharing the cache between documents and queries. Generally,
one is OK evicting query caches, but document caches should be kept.
By default, this method does not cache queries. To enable caching, specify a
`query_embedding_cache` when initializing the embedder.
Args:
text: The text to embed.
Returns:
The embedding for the given text.
"""
return await self.underlying_embeddings.aembed_query(text)
if not self.query_embedding_store:
return await self.underlying_embeddings.aembed_query(text)

(cached,) = await self.query_embedding_store.amget([text])
if cached is not None:
return cached

vector = await self.underlying_embeddings.aembed_query(text)
await self.query_embedding_store.amset([(text, vector)])
return vector

@classmethod
def from_bytes_store(
@@ -218,6 +232,7 @@ def from_bytes_store(
*,
namespace: str = "",
batch_size: Optional[int] = None,
query_embedding_cache: Union[bool, ByteStore] = False,
) -> CacheBackedEmbeddings:
"""On-ramp that adds the necessary serialization and encoding to the store.
@@ -229,13 +244,33 @@ def from_bytes_store(
This namespace is used to avoid collisions with other caches.
For example, set it to the name of the embedding model used.
batch_size: The number of documents to embed between store updates.
query_embedding_cache: The cache to use for storing query embeddings.
True to use the same cache as document embeddings.
False to not cache query embeddings.
"""
namespace = namespace
key_encoder = _create_key_encoder(namespace)
encoder_backed_store = EncoderBackedStore[str, List[float]](
document_embedding_store = EncoderBackedStore[str, List[float]](
document_embedding_cache,
key_encoder,
_value_serializer,
_value_deserializer,
)
return cls(underlying_embeddings, encoder_backed_store, batch_size=batch_size)
if query_embedding_cache is True:
query_embedding_store = document_embedding_store
elif query_embedding_cache is False:
query_embedding_store = None
else:
query_embedding_store = EncoderBackedStore[str, List[float]](
query_embedding_cache,
key_encoder,
_value_serializer,
_value_deserializer,
)

return cls(
underlying_embeddings,
document_embedding_store,
batch_size=batch_size,
query_embedding_store=query_embedding_store,
)
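The branch above supports three configurations for `query_embedding_cache`: `False` (the default, queries uncached), `True` (share the document cache), or a dedicated `ByteStore`. A sketch of the dedicated-store case, assuming `InMemoryByteStore` is available from `langchain.storage` and reusing the toy `FakeEmbeddings` from the earlier sketch; a separate store lets the query cache sit behind an eviction policy while document embeddings stay durable:

from langchain.embeddings import CacheBackedEmbeddings
from langchain.storage import InMemoryByteStore

doc_cache = InMemoryByteStore()    # long-lived cache for document embeddings
query_cache = InMemoryByteStore()  # separate, evictable cache for query embeddings

embedder = CacheBackedEmbeddings.from_bytes_store(
    FakeEmbeddings(),  # toy embedder from the earlier sketch
    document_embedding_cache=doc_cache,
    namespace="fake-model-v1",  # hypothetical; use the embedding model's name
    query_embedding_cache=query_cache,
)

embedder.embed_query("hello")
# Both stores are wrapped in the same EncoderBackedStore, so keys are
# namespaced hashes of the text rather than the raw text.
print(list(query_cache.yield_keys()))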
35 changes: 35 additions & 0 deletions libs/langchain/tests/unit_tests/embeddings/test_caching.py
@@ -43,6 +43,20 @@ def cache_embeddings_batch() -> CacheBackedEmbeddings:
)


@pytest.fixture
def cache_embeddings_with_query() -> CacheBackedEmbeddings:
"""Create a cache backed embeddings with query caching."""
doc_store = InMemoryStore()
query_store = InMemoryStore()
embeddings = MockEmbeddings()
return CacheBackedEmbeddings.from_bytes_store(
embeddings,
document_embedding_cache=doc_store,
namespace="test_namespace",
query_embedding_cache=query_store,
)


def test_embed_documents(cache_embeddings: CacheBackedEmbeddings) -> None:
texts = ["1", "22", "a", "333"]
vectors = cache_embeddings.embed_documents(texts)
@@ -73,6 +87,17 @@ def test_embed_query(cache_embeddings: CacheBackedEmbeddings) -> None:
vector = cache_embeddings.embed_query(text)
expected_vector = [5.0, 6.0]
assert vector == expected_vector
assert cache_embeddings.query_embedding_store is None


def test_embed_cached_query(cache_embeddings_with_query: CacheBackedEmbeddings) -> None:
text = "query_text"
vector = cache_embeddings_with_query.embed_query(text)
expected_vector = [5.0, 6.0]
assert vector == expected_vector
keys = list(cache_embeddings_with_query.query_embedding_store.yield_keys()) # type: ignore[union-attr]
assert len(keys) == 1
assert keys[0] == "test_namespace89ec3dae-a4d9-5636-a62e-ff3b56cdfa15"


async def test_aembed_documents(cache_embeddings: CacheBackedEmbeddings) -> None:
@@ -112,3 +137,13 @@ async def test_aembed_query(cache_embeddings: CacheBackedEmbeddings) -> None:
vector = await cache_embeddings.aembed_query(text)
expected_vector = [5.0, 6.0]
assert vector == expected_vector


async def test_aembed_query_cached(
cache_embeddings_with_query: CacheBackedEmbeddings,
) -> None:
text = "query_text"
await cache_embeddings_with_query.aembed_query(text)
keys = list(cache_embeddings_with_query.query_embedding_store.yield_keys()) # type: ignore[union-attr]
assert len(keys) == 1
assert keys[0] == "test_namespace89ec3dae-a4d9-5636-a62e-ff3b56cdfa15"
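A natural follow-up test (a sketch, not part of this commit) would assert that repeated queries are idempotent, i.e. the second call is served from the store and the key set does not grow:

def test_embed_cached_query_repeated(
    cache_embeddings_with_query: CacheBackedEmbeddings,
) -> None:
    """Embedding the same query twice should reuse the single cache entry."""
    text = "query_text"
    first = cache_embeddings_with_query.embed_query(text)
    second = cache_embeddings_with_query.embed_query(text)
    assert first == second
    keys = list(cache_embeddings_with_query.query_embedding_store.yield_keys())  # type: ignore[union-attr]
    assert len(keys) == 1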
