Skip to content

Commit

Permalink
added support for custom class schema in weaviate vector store (zilliztech#500)
Browse files Browse the repository at this point in the history

Signed-off-by: pranaychandekar <[email protected]>
  • Loading branch information
pranaychandekar authored Jul 29, 2023
1 parent 03a0597 commit 635dca5
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 9 deletions.
4 changes: 4 additions & 0 deletions gptcache/manager/vector_data/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,8 @@ def get(name, **kwargs):
startup_period = kwargs.get("startup_period", WEAVIATE_STARTUP_PERIOD)
embedded_options = kwargs.get("embedded_options", None)
additional_config = kwargs.get("additional_config", None)
class_name = kwargs.get("class_name", "GPTCache")
class_schema = kwargs.get("class_schema", None)

vector_base = Weaviate(
url=url,
Expand All @@ -283,6 +285,8 @@ def get(name, **kwargs):
startup_period=startup_period,
embedded_options=embedded_options,
additional_config=additional_config,
class_name=class_name,
class_schema=class_schema,
top_k=top_k,
)
else:
Expand Down
21 changes: 13 additions & 8 deletions gptcache/manager/vector_data/weaviate.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ def __init__(
startup_period: Optional[int] = 5,
embedded_options: Optional[EmbeddedOptions] = None,
additional_config: Optional[Config] = None,
class_name: str = "GPTCache",
class_schema: dict = None,
top_k: Optional[int] = 1,
) -> None:

Expand All @@ -50,26 +52,29 @@ def __init__(
additional_config=additional_config,
)

if class_schema:
self.class_schema = class_schema
self.class_name = class_schema.get("class")
else:
self.class_name = class_name
self.class_schema = self._get_default_class_schema()

self._create_class()
self.top_k = top_k

def _create_class(self):
class_schema = self._get_default_class_schema()

self.class_name = class_schema.get("class")

if self.client.schema.exists(self.class_name):
gptcache_log.warning(
"The %s collection already exists, and it will be used directly.",
self.class_name,
)
else:
self.client.schema.create_class(class_schema)
self.client.schema.create_class(self.class_schema)
return self.class_name

@staticmethod
def _get_default_class_schema() -> dict:
def _get_default_class_schema(self) -> dict:
return {
"class": "GPTCache",
"class": self.class_name,
"description": "LLM response cache",
"properties": [
{
Expand Down
28 changes: 27 additions & 1 deletion tests/unit_tests/manager/test_weaviate.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,16 @@ def test_normal(self):
size = 1000
dim = 512
top_k = 10
class_name = "Vectorcache"

db = VectorBase(
"weaviate",
class_name=class_name,
top_k=top_k
)

db._create_class()
created_class_name = db._create_class()
self.assertEqual(class_name, created_class_name)
data = np.random.randn(size, dim).astype(np.float32)
db.mul_add([VectorData(id=i, data=v) for v, i in zip(data, range(size))])
self.assertEqual(len(db.search(data[0])), top_k)
Expand All @@ -34,3 +37,26 @@ def test_normal(self):
emb = db.get_embeddings(0)
self.assertIsNone(emb)
db.close()

custom_class_name = "Customcache"
class_schema = {
"class": custom_class_name,
"description": "LLM response cache",
"properties": [
{
"name": "data_id",
"dataType": ["int"],
"description": "The data-id generated by GPTCache for vectors.",
}
],
"vectorIndexConfig": {"distance": "cosine"},
}

db = VectorBase(
"weaviate",
class_schema=class_schema,
top_k=top_k
)
created_class_name = db._create_class()
self.assertEqual(custom_class_name, created_class_name)
db.close()

0 comments on commit 635dca5

Please sign in to comment.