diff --git a/gptcache/manager/scalar_data/redis_storage.py b/gptcache/manager/scalar_data/redis_storage.py index b0cb0e13..dd4d3c83 100644 --- a/gptcache/manager/scalar_data/redis_storage.py +++ b/gptcache/manager/scalar_data/redis_storage.py @@ -20,17 +20,32 @@ from redis_om import JsonModel, EmbeddedJsonModel, NotFoundError, Field, Migrator -def get_models(global_key): +def get_models(global_key: str, redis_connection: Redis): + """ + Get all the models for the given global key and redis connection. + :param global_key: Global key will be used as a prefix for all the keys + :type global_key: str + + :param redis_connection: Redis connection to use for all the models. + Note: This needs to be explicitly mentioned in `Meta` class for each Object Model, + otherwise it will use the default connection from the pool. + :type redis_connection: Redis + """ + class Counter: + """ + counter collection + """ key_name = global_key + ":counter" + database = redis_connection @classmethod - def incr(cls, con: Redis): - con.incr(cls.key_name) + def incr(cls): + cls.database.incr(cls.key_name) @classmethod - def get(cls, con: Redis): - return con.get(cls.key_name) + def get(cls): + return cls.database.get(cls.key_name) class Embedding: """ @@ -75,6 +90,9 @@ class Answers(EmbeddedJsonModel): answer: str answer_type: int + class Meta: + database = redis_connection + class Questions(JsonModel): """ questions collection @@ -89,6 +107,7 @@ class Questions(JsonModel): class Meta: global_key_prefix = global_key model_key_prefix = "questions" + database = redis_connection class Sessions(JsonModel): """ @@ -98,6 +117,7 @@ class Sessions(JsonModel): class Meta: global_key_prefix = global_key model_key_prefix = "sessions" + database = redis_connection session_id: str = Field(index=True) session_question: str @@ -111,6 +131,7 @@ class QuestionDeps(JsonModel): class Meta: global_key_prefix = global_key model_key_prefix = "ques_deps" + database = redis_connection question_id: str = Field(index=True) dep_name: str @@ -125,6 +146,7 @@ class Report(JsonModel): class Meta: global_key_prefix = global_key model_key_prefix = "report" + database = redis_connection user_question: str cache_question_id: int = Field(index=True) @@ -194,7 +216,7 @@ def __init__( self._session, self._counter, self._report, - ) = get_models(global_key_prefix) + ) = get_models(global_key_prefix, redis_connection=self.con) Migrator().run() @@ -202,8 +224,8 @@ def create(self): pass def _insert(self, data: CacheData, pipeline: Pipeline = None): - self._counter.incr(self.con) - pk = str(self._counter.get(self.con)) + self._counter.incr() + pk = str(self._counter.get()) answers = data.answers if isinstance(data.answers, list) else [data.answers] all_data = [] for answer in answers: @@ -360,7 +382,8 @@ def delete_session(self, keys: List[str]): self._session.delete_many(sessions_to_delete, pipeline) pipeline.execute() - def report_cache(self, user_question, cache_question, cache_question_id, cache_answer, similarity_value, cache_delta_time): + def report_cache(self, user_question, cache_question, cache_question_id, cache_answer, similarity_value, + cache_delta_time): self._report( user_question=user_question, cache_question=cache_question, diff --git a/tests/unit_tests/manager/test_redis_cache_storage.py b/tests/unit_tests/manager/test_redis_cache_storage.py index 459d1978..cef3d985 100644 --- a/tests/unit_tests/manager/test_redis_cache_storage.py +++ b/tests/unit_tests/manager/test_redis_cache_storage.py @@ -4,28 +4,30 @@ import numpy as np from gptcache.manager.scalar_data.base import CacheData, Question -from gptcache.manager.scalar_data.redis_storage import RedisCacheStorage +from gptcache.manager.scalar_data.redis_storage import RedisCacheStorage, get_models from gptcache.utils import import_redis import_redis() -from redis_om import get_redis_connection +from redis_om import get_redis_connection, RedisModel class TestRedisStorage(unittest.TestCase): test_dbname = "gptcache_test" + url = "redis://default:default@localhost:6379" def setUp(cls) -> None: cls._clear_test_db() @staticmethod def _clear_test_db(): - r = get_redis_connection() + r = get_redis_connection(url=TestRedisStorage.url) r.flushall() r.flushdb() time.sleep(1) def test_normal(self): - redis_storage = RedisCacheStorage(global_key_prefix=self.test_dbname) + redis_storage = RedisCacheStorage(global_key_prefix=self.test_dbname, + url=self.url) data = [] for i in range(1, 10): data.append( @@ -61,7 +63,8 @@ def test_normal(self): assert redis_storage.count(is_all=True) == 7 def test_with_deps(self): - redis_storage = RedisCacheStorage(global_key_prefix=self.test_dbname) + redis_storage = RedisCacheStorage(global_key_prefix=self.test_dbname, + url=self.url) data_id = redis_storage.batch_insert( [ CacheData( @@ -98,7 +101,8 @@ def test_with_deps(self): assert ret.question.deps[1].dep_type == 1 def test_create_on(self): - redis_storage = RedisCacheStorage(global_key_prefix=self.test_dbname) + redis_storage = RedisCacheStorage(global_key_prefix=self.test_dbname, + url=self.url) redis_storage.create() data = [] for i in range(1, 10): @@ -124,7 +128,8 @@ def test_create_on(self): assert last_access1 < last_access2 def test_session(self): - redis_storage = RedisCacheStorage(global_key_prefix=self.test_dbname) + redis_storage = RedisCacheStorage(global_key_prefix=self.test_dbname, + url=self.url) data = [] for i in range(1, 11): data.append(