From 988701b795b6cc513da11e28f308d1a1f503c29b Mon Sep 17 00:00:00 2001 From: Vasyl Dizhak Date: Sun, 9 Jun 2024 20:03:38 +0100 Subject: [PATCH] Add support for the set functions from issue #597 Co-authored-by: Ali Rezaei --- changelog.d/730.feature | 1 + django_redis/cache.py | 68 +++++++++ django_redis/client/default.py | 245 +++++++++++++++++++++++++++++++- django_redis/client/sharded.py | 149 ++++++++++++++++++- django_redis/compressors/lz4.py | 2 +- tests/test_backend.py | 151 ++++++++++++++++++++ 6 files changed, 613 insertions(+), 3 deletions(-) create mode 100644 changelog.d/730.feature diff --git a/changelog.d/730.feature b/changelog.d/730.feature new file mode 100644 index 00000000..d41ae639 --- /dev/null +++ b/changelog.d/730.feature @@ -0,0 +1 @@ +Support for sets and support basic operations, sadd, scard, sdiff, sdiffstore, sinter, sinterstore, smismember, sismember, smembers, smove, spop, srandmember, srem, sscan, sscan_iter, sunion, sunionstore \ No newline at end of file diff --git a/django_redis/cache.py b/django_redis/cache.py index d26c33fa..f7b943a3 100644 --- a/django_redis/cache.py +++ b/django_redis/cache.py @@ -185,6 +185,74 @@ def close(self, **kwargs): def touch(self, *args, **kwargs): return self.client.touch(*args, **kwargs) + @omit_exception + def sadd(self, *args, **kwargs): + return self.client.sadd(*args, **kwargs) + + @omit_exception + def scard(self, *args, **kwargs): + return self.client.scard(*args, **kwargs) + + @omit_exception + def sdiff(self, *args, **kwargs): + return self.client.sdiff(*args, **kwargs) + + @omit_exception + def sdiffstore(self, *args, **kwargs): + return self.client.sdiffstore(*args, **kwargs) + + @omit_exception + def sinter(self, *args, **kwargs): + return self.client.sinter(*args, **kwargs) + + @omit_exception + def sinterstore(self, *args, **kwargs): + return self.client.sinterstore(*args, **kwargs) + + @omit_exception + def sismember(self, *args, **kwargs): + return self.client.sismember(*args, **kwargs) + + @omit_exception + def smembers(self, *args, **kwargs): + return self.client.smembers(*args, **kwargs) + + @omit_exception + def smove(self, *args, **kwargs): + return self.client.smove(*args, **kwargs) + + @omit_exception + def spop(self, *args, **kwargs): + return self.client.spop(*args, **kwargs) + + @omit_exception + def srandmember(self, *args, **kwargs): + return self.client.srandmember(*args, **kwargs) + + @omit_exception + def srem(self, *args, **kwargs): + return self.client.srem(*args, **kwargs) + + @omit_exception + def sscan(self, *args, **kwargs): + return self.client.sscan(*args, **kwargs) + + @omit_exception + def sscan_iter(self, *args, **kwargs): + return self.client.sscan_iter(*args, **kwargs) + + @omit_exception + def smismember(self, *args, **kwargs): + return self.client.smismember(*args, **kwargs) + + @omit_exception + def sunion(self, *args, **kwargs): + return self.client.sunion(*args, **kwargs) + + @omit_exception + def sunionstore(self, *args, **kwargs): + return self.client.sunionstore(*args, **kwargs) + @omit_exception def hset(self, *args, **kwargs): return self.client.hset(*args, **kwargs) diff --git a/django_redis/client/default.py b/django_redis/client/default.py index b9a5c1b0..eaa2890e 100644 --- a/django_redis/client/default.py +++ b/django_redis/client/default.py @@ -3,7 +3,7 @@ import socket from collections import OrderedDict from contextlib import suppress -from typing import Any, Dict, Iterable, Iterator, List, Optional, Tuple, Union +from typing import Any, Dict, Iterable, Iterator, List, Optional, Set, Tuple, Union from django.conf import settings from django.core.cache.backends.base import DEFAULT_TIMEOUT, BaseCache, get_key_func @@ -778,6 +778,249 @@ def make_pattern( return CacheKey(self._backend.key_func(pattern, prefix, version_str)) + def sadd( + self, + key: Any, + *values: Any, + version: Optional[int] = None, + client: Optional[Redis] = None, + ) -> int: + if client is None: + client = self.get_client(write=True) + + key = self.make_key(key, version=version) + values = [self.encode(value) for value in values] + return int(client.sadd(key, *values)) + + def scard( + self, + key: Any, + version: Optional[int] = None, + client: Optional[Redis] = None, + ) -> int: + if client is None: + client = self.get_client(write=False) + + key = self.make_key(key, version=version) + return int(client.scard(key)) + + def sdiff( + self, + *keys, + version: Optional[int] = None, + client: Optional[Redis] = None, + ) -> Set: + if client is None: + client = self.get_client(write=False) + + keys = [self.make_key(key, version=version) for key in keys] + return {self.decode(value) for value in client.sdiff(*keys)} + + def sdiffstore( + self, + dest: Any, + *keys, + version_dest: Optional[int] = None, + version_keys: Optional[int] = None, + client: Optional[Redis] = None, + ) -> int: + if client is None: + client = self.get_client(write=True) + + dest = self.make_key(dest, version=version_dest) + keys = [self.make_key(key, version=version_keys) for key in keys] + return int(client.sdiffstore(dest, *keys)) + + def sinter( + self, + *keys, + version: Optional[int] = None, + client: Optional[Redis] = None, + ) -> Set: + if client is None: + client = self.get_client(write=False) + + keys = [self.make_key(key, version=version) for key in keys] + return {self.decode(value) for value in client.sinter(*keys)} + + def sinterstore( + self, + dest: Any, + *keys, + version: Optional[int] = None, + client: Optional[Redis] = None, + ) -> int: + if client is None: + client = self.get_client(write=True) + + dest = self.make_key(dest, version=version) + keys = [self.make_key(key, version=version) for key in keys] + return int(client.sinterstore(dest, *keys)) + + def smismember( + self, + key: Any, + *members, + version: Optional[int] = None, + client: Optional[Redis] = None, + ) -> bool: + if client is None: + client = self.get_client(write=False) + + key = self.make_key(key, version=version) + encoded_members = [self.encode(member) for member in members] + + return [bool(value) for value in client.smismember(key, *encoded_members)] + + def sismember( + self, + key: Any, + member: Any, + version: Optional[int] = None, + client: Optional[Redis] = None, + ) -> bool: + if client is None: + client = self.get_client(write=False) + + key = self.make_key(key, version=version) + member = self.encode(member) + return bool(client.sismember(key, member)) + + def smembers( + self, + key: Any, + version: Optional[int] = None, + client: Optional[Redis] = None, + ) -> Set: + if client is None: + client = self.get_client(write=False) + + key = self.make_key(key, version=version) + return {self.decode(value) for value in client.smembers(key)} + + def smove( + self, + source: Any, + destination: Any, + member: Any, + version: Optional[int] = None, + client: Optional[Redis] = None, + ) -> bool: + if client is None: + client = self.get_client(write=True) + + source = self.make_key(source, version=version) + destination = self.make_key(destination) + member = self.encode(member) + return bool(client.smove(source, destination, member)) + + def spop( + self, + key: Any, + count: Optional[int] = None, + version: Optional[int] = None, + client: Optional[Redis] = None, + ) -> Union[Set, Any]: + if client is None: + client = self.get_client(write=True) + + key = self.make_key(key, version=version) + result = client.spop(key, count) + if isinstance(result, list): + return {self.decode(value) for value in result} + return self.decode(result) + + def srandmember( + self, + key: Any, + count: Optional[int] = None, + version: Optional[int] = None, + client: Optional[Redis] = None, + ) -> Union[Set, Any]: + if client is None: + client = self.get_client(write=False) + + key = self.make_key(key, version=version) + result = client.srandmember(key, count) + if isinstance(result, list): + return {self.decode(value) for value in result} + return self.decode(result) + + def srem( + self, + key: Any, + *members, + version: Optional[int] = None, + client: Optional[Redis] = None, + ) -> int: + if client is None: + client = self.get_client(write=True) + + key = self.make_key(key, version=version) + members = [self.encode(member) for member in members] + return int(client.srem(key, *members)) + + def sscan( + self, + key: Any, + match: Optional[str] = None, + count: Optional[int] = 10, + version: Optional[int] = None, + client: Optional[Redis] = None, + ) -> Set[Any]: + if client is None: + client = self.get_client(write=False) + + key = self.make_key(key, version=version) + + cursor, result = client.sscan( + key, match=self.encode(match) if match else None, count=count + ) + return {self.decode(value) for value in result} + + def sscan_iter( + self, + key: Any, + match: Optional[str] = None, + count: Optional[int] = 10, + version: Optional[int] = None, + client: Optional[Redis] = None, + ) -> Iterator[Any]: + if client is None: + client = self.get_client(write=False) + + key = self.make_key(key, version=version) + for value in client.sscan_iter( + key, match=self.encode(match) if match else None, count=count + ): + yield self.decode(value) + + def sunion( + self, + *keys, + version: Optional[int] = None, + client: Optional[Redis] = None, + ) -> Set: + if client is None: + client = self.get_client(write=False) + + keys = [self.make_key(key, version=version) for key in keys] + return {self.decode(value) for value in client.sunion(*keys)} + + def sunionstore( + self, + destination: Any, + *keys, + version: Optional[int] = None, + client: Optional[Redis] = None, + ) -> int: + if client is None: + client = self.get_client(write=True) + + destination = self.make_key(destination, version=version) + keys = [self.make_key(key, version=version) for key in keys] + return int(client.sunionstore(destination, *keys)) + def close(self) -> None: close_flag = self._options.get( "CLOSE_CONNECTION", diff --git a/django_redis/client/sharded.py b/django_redis/client/sharded.py index dbb1d200..b480ea17 100644 --- a/django_redis/client/sharded.py +++ b/django_redis/client/sharded.py @@ -1,8 +1,9 @@ import re from collections import OrderedDict from datetime import datetime -from typing import Union +from typing import Any, Iterator, Optional, Set, Union +from redis import Redis from redis.exceptions import ConnectionError from django_redis.client.default import DEFAULT_TIMEOUT, DefaultClient @@ -335,3 +336,149 @@ def touch(self, key, timeout=DEFAULT_TIMEOUT, version=None, client=None): def clear(self, client=None): for connection in self._serverdict.values(): connection.flushdb() + + def sadd( + self, + key: Any, + *values: Any, + version: Optional[int] = None, + client: Optional[Redis] = None, + ) -> int: + if client is None: + key = self.make_key(key, version=version) + client = self.get_server(key) + return super().sadd(key, *values, version=version, client=client) + + def scard( + self, + key: Any, + version: Optional[int] = None, + client: Optional[Redis] = None, + ) -> int: + if client is None: + key = self.make_key(key, version=version) + client = self.get_server(key) + return super().scard(key=key, version=version, client=client) + + def smembers( + self, + key: Any, + version: Optional[int] = None, + client: Optional[Redis] = None, + ) -> Set: + if client is None: + key = self.make_key(key, version=version) + client = self.get_server(key) + return super().smembers(key=key, version=version, client=client) + + # TODO + def smove( + self, + source: Any, + destination: Any, + member: Any, + version: Optional[int] = None, + client: Optional[Redis] = None, + ): + if client is None: + source = self.make_key(source, version=version) + client = self.get_server(source) + destination = self.make_key(destination, version=version) + + return super().smove( + source=source, + destination=destination, + member=member, + version=version, + client=client, + ) + + def srem( + self, + key: Any, + *members, + version: Optional[int] = None, + client: Optional[Redis] = None, + ) -> int: + if client is None: + key = self.make_key(key, version=version) + client = self.get_server(key) + return super().srem(key, *members, version=version, client=client) + + def sscan( + self, + key: Any, + match: Optional[str] = None, + count: Optional[int] = 10, + version: Optional[int] = None, + client: Optional[Redis] = None, + ) -> Set[Any]: + if client is None: + key = self.make_key(key, version=version) + client = self.get_server(key) + return super().sscan( + key=key, match=match, count=count, version=version, client=client + ) + + def sscan_iter( + self, + key: Any, + match: Optional[str] = None, + count: Optional[int] = 10, + version: Optional[int] = None, + client: Optional[Redis] = None, + ) -> Iterator[Any]: + if client is None: + key = self.make_key(key, version=version) + client = self.get_server(key) + return super().sscan_iter( + key=key, match=match, count=count, version=version, client=client + ) + + def srandmember( + self, + key: Any, + count: Optional[int] = None, + version: Optional[int] = None, + client: Optional[Redis] = None, + ) -> Union[Set, Any]: + if client is None: + key = self.make_key(key, version=version) + client = self.get_server(key) + return super().srandmember(key=key, count=count, version=version, client=client) + + def sismember( + self, + key: Any, + member: Any, + version: Optional[int] = None, + client: Optional[Redis] = None, + ) -> bool: + if client is None: + key = self.make_key(key, version=version) + client = self.get_server(key) + return super().sismember(key, member, version=version, client=client) + + def spop( + self, + key: Any, + count: Optional[int] = None, + version: Optional[int] = None, + client: Optional[Redis] = None, + ) -> Union[Set, Any]: + if client is None: + key = self.make_key(key, version=version) + client = self.get_server(key) + return super().spop(key=key, count=count, version=version, client=client) + + def smismember( + self, + key: Any, + *members, + version: Optional[int] = None, + client: Optional[Redis] = None, + ) -> bool: + if client is None: + key = self.make_key(key, version=version) + client = self.get_server(key) + return super().smismember(key, *members, version=version, client=client) diff --git a/django_redis/compressors/lz4.py b/django_redis/compressors/lz4.py index 32183321..940c96d5 100644 --- a/django_redis/compressors/lz4.py +++ b/django_redis/compressors/lz4.py @@ -16,5 +16,5 @@ def compress(self, value: bytes) -> bytes: def decompress(self, value: bytes) -> bytes: try: return _decompress(value) - except Exception as e: # noqa: BLE001 + except Exception as e: raise CompressorError from e diff --git a/tests/test_backend.py b/tests/test_backend.py index 4ff60983..5b9562ee 100644 --- a/tests/test_backend.py +++ b/tests/test_backend.py @@ -856,3 +856,154 @@ def test_hexists(self, cache: RedisCache): cache.hset("foo_hash5", "foo1", "bar1") assert cache.hexists("foo_hash5", "foo1") assert not cache.hexists("foo_hash5", "foo") + + def test_sadd(self, cache: RedisCache): + assert cache.sadd("foo", "bar") == 1 + assert cache.smembers("foo") == {"bar"} + + def test_scard(self, cache: RedisCache): + cache.sadd("foo", "bar", "bar2") + assert cache.scard("foo") == 2 + + def test_sdiff(self, cache: RedisCache): + if isinstance(cache.client, ShardClient): + pytest.skip("ShardClient doesn't support get_client") + + cache.sadd("foo1", "bar1", "bar2") + cache.sadd("foo2", "bar2", "bar3") + assert cache.sdiff("foo1", "foo2") == {"bar1"} + + def test_sdiffstore(self, cache: RedisCache): + if isinstance(cache.client, ShardClient): + pytest.skip("ShardClient doesn't support get_client") + + cache.sadd("foo1", "bar1", "bar2") + cache.sadd("foo2", "bar2", "bar3") + assert cache.sdiffstore("foo3", "foo1", "foo2") == 1 + assert cache.smembers("foo3") == {"bar1"} + + def test_sdiffstore_with_keys_version(self, cache: RedisCache): + if isinstance(cache.client, ShardClient): + pytest.skip("ShardClient doesn't support get_client") + + cache.sadd("foo1", "bar1", "bar2", version=2) + cache.sadd("foo2", "bar2", "bar3", version=2) + assert cache.sdiffstore("foo3", "foo1", "foo2", version_keys=2) == 1 + assert cache.smembers("foo3") == {"bar1"} + + def test_sdiffstore_with_different_keys_versions_without_initial_set_in_version( + self, cache: RedisCache + ): + if isinstance(cache.client, ShardClient): + pytest.skip("ShardClient doesn't support get_client") + + cache.sadd("foo1", "bar1", "bar2", version=1) + cache.sadd("foo2", "bar2", "bar3", version=2) + assert cache.sdiffstore("foo3", "foo1", "foo2", version_keys=2) == 0 + + def test_sdiffstore_with_different_keys_versions_with_initial_set_in_version( + self, cache: RedisCache + ): + if isinstance(cache.client, ShardClient): + pytest.skip("ShardClient doesn't support get_client") + + cache.sadd("foo1", "bar1", "bar2", version=2) + cache.sadd("foo2", "bar2", "bar3", version=1) + assert cache.sdiffstore("foo3", "foo1", "foo2", version_keys=2) == 2 + + def test_sinter(self, cache: RedisCache): + if isinstance(cache.client, ShardClient): + pytest.skip("ShardClient doesn't support get_client") + + cache.sadd("foo1", "bar1", "bar2") + cache.sadd("foo2", "bar2", "bar3") + assert cache.sinter("foo1", "foo2") == {"bar2"} + + def test_interstore(self, cache: RedisCache): + if isinstance(cache.client, ShardClient): + pytest.skip("ShardClient doesn't support get_client") + + cache.sadd("foo1", "bar1", "bar2") + cache.sadd("foo2", "bar2", "bar3") + assert cache.sinterstore("foo3", "foo1", "foo2") == 1 + assert cache.smembers("foo3") == {"bar2"} + + def test_sismember(self, cache: RedisCache): + cache.sadd("foo", "bar") + assert cache.sismember("foo", "bar") is True + assert cache.sismember("foo", "bar2") is False + + def test_smove(self, cache: RedisCache): + if isinstance(cache.client, ShardClient): + pytest.skip("ShardClient doesn't support get_client") + + cache.sadd("foo1", "bar1", "bar2") + cache.sadd("foo2", "bar2", "bar3") + assert cache.smove("foo1", "foo2", "bar1") is True + assert cache.smove("foo1", "foo2", "bar4") is False + assert cache.smembers("foo1") == {"bar2"} + assert cache.smembers("foo2") == {"bar1", "bar2", "bar3"} + + def test_spop_default_count(self, cache: RedisCache): + cache.sadd("foo", "bar1", "bar2") + assert cache.spop("foo") in {"bar1", "bar2"} + assert cache.smembers("foo") in [{"bar1"}, {"bar2"}] + + def test_spop(self, cache: RedisCache): + cache.sadd("foo", "bar1", "bar2") + assert cache.spop("foo", 1) in [{"bar1"}, {"bar2"}] + assert cache.smembers("foo") in [{"bar1"}, {"bar2"}] + + def test_srandmember_default_count(self, cache: RedisCache): + cache.sadd("foo", "bar1", "bar2") + assert cache.srandmember("foo") in {"bar1", "bar2"} + + def test_srandmember(self, cache: RedisCache): + cache.sadd("foo", "bar1", "bar2") + assert cache.srandmember("foo", 1) in [{"bar1"}, {"bar2"}] + + def test_srem(self, cache: RedisCache): + cache.sadd("foo", "bar1", "bar2") + assert cache.srem("foo", "bar1") == 1 + assert cache.srem("foo", "bar3") == 0 + + def test_sscan(self, cache: RedisCache): + cache.sadd("foo", "bar1", "bar2") + items = cache.sscan("foo") + assert items == {"bar1", "bar2"} + + def test_sscan_with_match(self, cache: RedisCache): + cache.sadd("foo", "bar1", "bar2", "zoo") + items = cache.sscan("foo", match="zoo") + assert items == {"zoo"} + + def test_sscan_iter(self, cache: RedisCache): + cache.sadd("foo", "bar1", "bar2") + items = cache.sscan_iter("foo") + assert set(items) == {"bar1", "bar2"} + + def test_sscan_iter_with_match(self, cache: RedisCache): + cache.sadd("foo", "bar1", "bar2", "zoo") + items = cache.sscan_iter("foo", match="bar*") + assert set(items) == {"bar1", "bar2"} + + def test_smismember(self, cache: RedisCache): + cache.sadd("foo", "bar1", "bar2", "bar3") + assert cache.smismember("foo", "bar1", "bar2", "xyz") == [True, True, False] + + def test_sunion(self, cache: RedisCache): + if isinstance(cache.client, ShardClient): + pytest.skip("ShardClient doesn't support get_client") + + cache.sadd("foo1", "bar1", "bar2") + cache.sadd("foo2", "bar2", "bar3") + assert cache.sunion("foo1", "foo2") == {"bar1", "bar2", "bar3"} + + def test_sunionstore(self, cache: RedisCache): + if isinstance(cache.client, ShardClient): + pytest.skip("ShardClient doesn't support get_client") + + cache.sadd("foo1", "bar1", "bar2") + cache.sadd("foo2", "bar2", "bar3") + assert cache.sunionstore("foo3", "foo1", "foo2") == 3 + assert cache.smembers("foo3") == {"bar1", "bar2", "bar3"}