Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
cunla committed Dec 25, 2024
1 parent 7a08385 commit 960a464
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 10 deletions.
16 changes: 9 additions & 7 deletions test/conftest.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Callable, Tuple, Union, Optional
from typing import Callable, Tuple, Union, Optional, Type

import pytest
import pytest_asyncio
Expand All @@ -7,6 +7,8 @@
import fakeredis
from fakeredis._server import _create_version

ServerDetails = Type[Tuple[str, Union[None, Tuple[int, ...]]]]


def _check_lua_module_supported() -> bool:
redis = fakeredis.FakeRedis(lua_modules={"cjson"})
Expand All @@ -18,7 +20,7 @@ def _check_lua_module_supported() -> bool:


@pytest_asyncio.fixture(scope="session")
def real_redis_version() -> Tuple[str, Union[None, Tuple[int, ...]]]:
def real_server_details() -> ServerDetails:
"""Returns server's version or None if server is not running"""
client = None
try:
Expand All @@ -39,8 +41,8 @@ def real_redis_version() -> Tuple[str, Union[None, Tuple[int, ...]]]:


@pytest_asyncio.fixture(name="fake_server")
def _fake_server(request, real_redis_version) -> fakeredis.FakeServer:
server_type, _ = real_redis_version
def _fake_server(request, real_server_details: ServerDetails) -> fakeredis.FakeServer:
server_type, _ = real_server_details
min_server_marker = request.node.get_closest_marker("min_server")
server_version = min_server_marker.args[0] if min_server_marker else "6.2"
server = fakeredis.FakeServer(server_type=server_type, version=server_version)
Expand All @@ -49,7 +51,7 @@ def _fake_server(request, real_redis_version) -> fakeredis.FakeServer:


@pytest_asyncio.fixture
def r(request, create_redis) -> redis.Redis:
def r(request, create_redis: Callable[[int], redis.Redis]) -> redis.Redis:
rconn = create_redis(db=2)
connected = request.node.get_closest_marker("disconnected") is None
if connected:
Expand Down Expand Up @@ -77,7 +79,7 @@ def _marker_version_value(request, marker_name: str):
)
def _create_connection(request) -> Callable[[int], redis.Redis]:
cls_name = request.param
server_type, server_version = request.getfixturevalue("real_redis_version")
server_type, server_version = request.getfixturevalue("real_server_details")
if not cls_name.startswith("Fake") and not server_version:
pytest.skip("Redis is not running")
unsupported_server_types = request.node.get_closest_marker("unsupported_server_types")
Expand Down Expand Up @@ -112,7 +114,7 @@ def factory(db=2):
params=[pytest.param("fake", marks=pytest.mark.fake), pytest.param("real", marks=pytest.mark.real)],
)
async def _req_aioredis2(request) -> redis.asyncio.Redis:
server_type, server_version = request.getfixturevalue("real_redis_version")
server_type, server_version = request.getfixturevalue("real_server_details")
if request.param != "fake" and not server_version:
pytest.skip("Redis is not running")

Expand Down
19 changes: 16 additions & 3 deletions test/test_mixins/test_acl_commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,21 +4,34 @@

from fakeredis.model import get_categories, get_commands_by_category
from test import testtools
from test.conftest import ServerDetails

pytestmark = []
pytestmark.extend([pytest.mark.min_server("7"), testtools.run_test_if_redispy_ver("gte", "5")])

_VALKEY_UNSUPPORTED_COMMANDS = {
"hexpiretime",
"hexpireat",
"hpexpireat",
"hexpire",
"hpttl",
"hpexpire",
"hpexpiretime",
"httl",
}

def test_acl_cat(r: redis.Redis):

def test_acl_cat(r: redis.Redis, real_server_details: ServerDetails):
categories = get_categories()
categories = [cat.decode() for cat in categories]
assert set(r.acl_cat()) == set(categories)
for cat in categories:
commands = get_commands_by_category(cat)
commands = {cmd.decode() for cmd in commands}
if "hpersist" in commands:
commands.remove("hpersist")
assert len(commands) > 0
commands.discard("hpersist")
if real_server_details[0] == "valkey":
commands = commands - _VALKEY_UNSUPPORTED_COMMANDS
commands = {cmd.replace(" ", "|") for cmd in commands}
diff = set(commands) - set(r.acl_cat(cat))
assert len(diff) == 0, f"Commands not found in category {cat}: {diff}"
Expand Down

0 comments on commit 960a464

Please sign in to comment.