diff --git a/test/conftest.py b/test/conftest.py index f0c0291e..d65e450b 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -1,4 +1,4 @@ -from typing import Callable, Tuple, Union, Optional +from typing import Callable, Tuple, Union, Optional, Type import pytest import pytest_asyncio @@ -7,6 +7,8 @@ import fakeredis from fakeredis._server import _create_version +ServerDetails = Type[Tuple[str, Union[None, Tuple[int, ...]]]] + def _check_lua_module_supported() -> bool: redis = fakeredis.FakeRedis(lua_modules={"cjson"}) @@ -18,7 +20,7 @@ def _check_lua_module_supported() -> bool: @pytest_asyncio.fixture(scope="session") -def real_redis_version() -> Tuple[str, Union[None, Tuple[int, ...]]]: +def real_server_details() -> ServerDetails: """Returns server's version or None if server is not running""" client = None try: @@ -39,8 +41,8 @@ def real_redis_version() -> Tuple[str, Union[None, Tuple[int, ...]]]: @pytest_asyncio.fixture(name="fake_server") -def _fake_server(request, real_redis_version) -> fakeredis.FakeServer: - server_type, _ = real_redis_version +def _fake_server(request, real_server_details: ServerDetails) -> fakeredis.FakeServer: + server_type, _ = real_server_details min_server_marker = request.node.get_closest_marker("min_server") server_version = min_server_marker.args[0] if min_server_marker else "6.2" server = fakeredis.FakeServer(server_type=server_type, version=server_version) @@ -49,7 +51,7 @@ def _fake_server(request, real_redis_version) -> fakeredis.FakeServer: @pytest_asyncio.fixture -def r(request, create_redis) -> redis.Redis: +def r(request, create_redis: Callable[[int], redis.Redis]) -> redis.Redis: rconn = create_redis(db=2) connected = request.node.get_closest_marker("disconnected") is None if connected: @@ -77,7 +79,7 @@ def _marker_version_value(request, marker_name: str): ) def _create_connection(request) -> Callable[[int], redis.Redis]: cls_name = request.param - server_type, server_version = request.getfixturevalue("real_redis_version") + server_type, server_version = request.getfixturevalue("real_server_details") if not cls_name.startswith("Fake") and not server_version: pytest.skip("Redis is not running") unsupported_server_types = request.node.get_closest_marker("unsupported_server_types") @@ -112,7 +114,7 @@ def factory(db=2): params=[pytest.param("fake", marks=pytest.mark.fake), pytest.param("real", marks=pytest.mark.real)], ) async def _req_aioredis2(request) -> redis.asyncio.Redis: - server_type, server_version = request.getfixturevalue("real_redis_version") + server_type, server_version = request.getfixturevalue("real_server_details") if request.param != "fake" and not server_version: pytest.skip("Redis is not running") diff --git a/test/test_mixins/test_acl_commands.py b/test/test_mixins/test_acl_commands.py index c289cfc4..c0c58ee3 100644 --- a/test/test_mixins/test_acl_commands.py +++ b/test/test_mixins/test_acl_commands.py @@ -4,21 +4,34 @@ from fakeredis.model import get_categories, get_commands_by_category from test import testtools +from test.conftest import ServerDetails pytestmark = [] pytestmark.extend([pytest.mark.min_server("7"), testtools.run_test_if_redispy_ver("gte", "5")]) +_VALKEY_UNSUPPORTED_COMMANDS = { + "hexpiretime", + "hexpireat", + "hpexpireat", + "hexpire", + "hpttl", + "hpexpire", + "hpexpiretime", + "httl", +} -def test_acl_cat(r: redis.Redis): + +def test_acl_cat(r: redis.Redis, real_server_details: ServerDetails): categories = get_categories() categories = [cat.decode() for cat in categories] assert set(r.acl_cat()) == set(categories) for cat in categories: commands = get_commands_by_category(cat) commands = {cmd.decode() for cmd in commands} - if "hpersist" in commands: - commands.remove("hpersist") assert len(commands) > 0 + commands.discard("hpersist") + if real_server_details[0] == "valkey": + commands = commands - _VALKEY_UNSUPPORTED_COMMANDS commands = {cmd.replace(" ", "|") for cmd in commands} diff = set(commands) - set(r.acl_cat(cat)) assert len(diff) == 0, f"Commands not found in category {cat}: {diff}"