From a0ec04549575de5547faa5381fa69fef61405ac8 Mon Sep 17 00:00:00 2001 From: Christophe Bornet Date: Thu, 1 Feb 2024 02:10:47 +0100 Subject: [PATCH] Add async methods to BaseStore (#16669) - **Description:** The BaseStore methods are currently blocking. Some implementations (AstraDBStore, RedisStore) would benefit from having async methods. Also once we have async methods for BaseStore, we can implement the async `aembed_documents` in CacheBackedEmbeddings to cache the embeddings asynchronously. * adds async methods amget, amset, amedelete and ayield_keys to BaseStore * implements the async methods for InMemoryStore * adds tests for InMemoryStore async methods - **Twitter handle:** cbornet_ --- libs/core/langchain_core/stores.py | 64 ++++++++++++++++++- libs/langchain/langchain/storage/in_memory.py | 49 ++++++++++++++ .../unit_tests/storage/test_in_memory.py | 47 ++++++++++++++ 3 files changed, 159 insertions(+), 1 deletion(-) diff --git a/libs/core/langchain_core/stores.py b/libs/core/langchain_core/stores.py index 8363fca3891f9..bb2f09f929ac6 100644 --- a/libs/core/langchain_core/stores.py +++ b/libs/core/langchain_core/stores.py @@ -1,5 +1,17 @@ from abc import ABC, abstractmethod -from typing import Generic, Iterator, List, Optional, Sequence, Tuple, TypeVar, Union +from typing import ( + AsyncIterator, + Generic, + Iterator, + List, + Optional, + Sequence, + Tuple, + TypeVar, + Union, +) + +from langchain_core.runnables import run_in_executor K = TypeVar("K") V = TypeVar("V") @@ -20,6 +32,18 @@ def mget(self, keys: Sequence[K]) -> List[Optional[V]]: If a key is not found, the corresponding value will be None. """ + async def amget(self, keys: Sequence[K]) -> List[Optional[V]]: + """Get the values associated with the given keys. + + Args: + keys (Sequence[K]): A sequence of keys. + + Returns: + A sequence of optional values associated with the keys. + If a key is not found, the corresponding value will be None. + """ + return await run_in_executor(None, self.mget, keys) + @abstractmethod def mset(self, key_value_pairs: Sequence[Tuple[K, V]]) -> None: """Set the values for the given keys. @@ -28,6 +52,14 @@ def mset(self, key_value_pairs: Sequence[Tuple[K, V]]) -> None: key_value_pairs (Sequence[Tuple[K, V]]): A sequence of key-value pairs. """ + async def amset(self, key_value_pairs: Sequence[Tuple[K, V]]) -> None: + """Set the values for the given keys. + + Args: + key_value_pairs (Sequence[Tuple[K, V]]): A sequence of key-value pairs. + """ + return await run_in_executor(None, self.mset, key_value_pairs) + @abstractmethod def mdelete(self, keys: Sequence[K]) -> None: """Delete the given keys and their associated values. @@ -36,6 +68,14 @@ def mdelete(self, keys: Sequence[K]) -> None: keys (Sequence[K]): A sequence of keys to delete. """ + async def amdelete(self, keys: Sequence[K]) -> None: + """Delete the given keys and their associated values. + + Args: + keys (Sequence[K]): A sequence of keys to delete. + """ + return await run_in_executor(None, self.mdelete, keys) + @abstractmethod def yield_keys( self, *, prefix: Optional[str] = None @@ -52,5 +92,27 @@ def yield_keys( depending on what makes more sense for the given store. """ + async def ayield_keys( + self, *, prefix: Optional[str] = None + ) -> Union[AsyncIterator[K], AsyncIterator[str]]: + """Get an iterator over keys that match the given prefix. + + Args: + prefix (str): The prefix to match. + + Returns: + Iterator[K | str]: An iterator over keys that match the given prefix. + + This method is allowed to return an iterator over either K or str + depending on what makes more sense for the given store. + """ + iterator = await run_in_executor(None, self.yield_keys, prefix=prefix) + done = object() + while True: + item = await run_in_executor(None, lambda it: next(it, done), iterator) + if item is done: + break + yield item + ByteStore = BaseStore[str, bytes] diff --git a/libs/langchain/langchain/storage/in_memory.py b/libs/langchain/langchain/storage/in_memory.py index 03679a34909d9..310f81ce28f7c 100644 --- a/libs/langchain/langchain/storage/in_memory.py +++ b/libs/langchain/langchain/storage/in_memory.py @@ -5,6 +5,7 @@ """ from typing import ( Any, + AsyncIterator, Dict, Generic, Iterator, @@ -60,6 +61,18 @@ def mget(self, keys: Sequence[str]) -> List[Optional[V]]: """ return [self.store.get(key) for key in keys] + async def amget(self, keys: Sequence[str]) -> List[Optional[V]]: + """Get the values associated with the given keys. + + Args: + keys (Sequence[str]): A sequence of keys. + + Returns: + A sequence of optional values associated with the keys. + If a key is not found, the corresponding value will be None. + """ + return self.mget(keys) + def mset(self, key_value_pairs: Sequence[Tuple[str, V]]) -> None: """Set the values for the given keys. @@ -72,6 +85,17 @@ def mset(self, key_value_pairs: Sequence[Tuple[str, V]]) -> None: for key, value in key_value_pairs: self.store[key] = value + async def amset(self, key_value_pairs: Sequence[Tuple[str, V]]) -> None: + """Set the values for the given keys. + + Args: + key_value_pairs (Sequence[Tuple[str, V]]): A sequence of key-value pairs. + + Returns: + None + """ + return self.mset(key_value_pairs) + def mdelete(self, keys: Sequence[str]) -> None: """Delete the given keys and their associated values. @@ -82,6 +106,14 @@ def mdelete(self, keys: Sequence[str]) -> None: if key in self.store: del self.store[key] + async def amdelete(self, keys: Sequence[str]) -> None: + """Delete the given keys and their associated values. + + Args: + keys (Sequence[str]): A sequence of keys to delete. + """ + self.mdelete(keys) + def yield_keys(self, prefix: Optional[str] = None) -> Iterator[str]: """Get an iterator over keys that match the given prefix. @@ -98,6 +130,23 @@ def yield_keys(self, prefix: Optional[str] = None) -> Iterator[str]: if key.startswith(prefix): yield key + async def ayield_keys(self, prefix: Optional[str] = None) -> AsyncIterator[str]: + """Get an async iterator over keys that match the given prefix. + + Args: + prefix (str, optional): The prefix to match. Defaults to None. + + Returns: + AsyncIterator[str]: An async iterator over keys that match the given prefix. + """ + if prefix is None: + for key in self.store.keys(): + yield key + else: + for key in self.store.keys(): + if key.startswith(prefix): + yield key + InMemoryStore = InMemoryBaseStore[Any] InMemoryByteStore = InMemoryBaseStore[bytes] diff --git a/libs/langchain/tests/unit_tests/storage/test_in_memory.py b/libs/langchain/tests/unit_tests/storage/test_in_memory.py index 4e969a01dd8f1..a12233b4f67b2 100644 --- a/libs/langchain/tests/unit_tests/storage/test_in_memory.py +++ b/libs/langchain/tests/unit_tests/storage/test_in_memory.py @@ -13,6 +13,18 @@ def test_mget() -> None: assert non_existent_value == [None] +async def test_amget() -> None: + store = InMemoryStore() + await store.amset([("key1", "value1"), ("key2", "value2")]) + + values = await store.amget(["key1", "key2"]) + assert values == ["value1", "value2"] + + # Test non-existent key + non_existent_value = await store.amget(["key3"]) + assert non_existent_value == [None] + + def test_mset() -> None: store = InMemoryStore() store.mset([("key1", "value1"), ("key2", "value2")]) @@ -21,6 +33,14 @@ def test_mset() -> None: assert values == ["value1", "value2"] +async def test_amset() -> None: + store = InMemoryStore() + await store.amset([("key1", "value1"), ("key2", "value2")]) + + values = await store.amget(["key1", "key2"]) + assert values == ["value1", "value2"] + + def test_mdelete() -> None: store = InMemoryStore() store.mset([("key1", "value1"), ("key2", "value2")]) @@ -34,6 +54,19 @@ def test_mdelete() -> None: store.mdelete(["key3"]) # No error should be raised +async def test_amdelete() -> None: + store = InMemoryStore() + await store.amset([("key1", "value1"), ("key2", "value2")]) + + await store.amdelete(["key1"]) + + values = await store.amget(["key1", "key2"]) + assert values == [None, "value2"] + + # Test deleting non-existent key + await store.amdelete(["key3"]) # No error should be raised + + def test_yield_keys() -> None: store = InMemoryStore() store.mset([("key1", "value1"), ("key2", "value2"), ("key3", "value3")]) @@ -46,3 +79,17 @@ def test_yield_keys() -> None: keys_with_invalid_prefix = list(store.yield_keys(prefix="x")) assert keys_with_invalid_prefix == [] + + +async def test_ayield_keys() -> None: + store = InMemoryStore() + await store.amset([("key1", "value1"), ("key2", "value2"), ("key3", "value3")]) + + keys = [key async for key in store.ayield_keys()] + assert set(keys) == {"key1", "key2", "key3"} + + keys_with_prefix = [key async for key in store.ayield_keys(prefix="key")] + assert set(keys_with_prefix) == {"key1", "key2", "key3"} + + keys_with_invalid_prefix = [key async for key in store.ayield_keys(prefix="x")] + assert keys_with_invalid_prefix == []