diff --git a/.gitignore b/.gitignore index 1d60f34..25a08ca 100644 --- a/.gitignore +++ b/.gitignore @@ -127,4 +127,7 @@ dmypy.json # Pyre type checker .pyre/ -*~ \ No newline at end of file +*~ + +# vscode +.vscode diff --git a/hnsqlite/collection.py b/hnsqlite/collection.py index 5af74a9..2119be6 100644 --- a/hnsqlite/collection.py +++ b/hnsqlite/collection.py @@ -535,6 +535,10 @@ def search(self, vector: np.array, k = 12, filter=None) -> List[SearchResponse]: """ if isinstance(vector, list): vector = np.array(vector) + + vector_dim = vector.shape[0] + if vector_dim != self.config.dim: + raise ValueError(f"Dim mismatch: vector: {vector_dim}, collection: {self.config.dim}") def _filter(id): with Session(self.db_engine) as session: diff --git a/test/test_collection.py b/test/test_collection.py index fb843d3..77c3c99 100644 --- a/test/test_collection.py +++ b/test/test_collection.py @@ -112,7 +112,7 @@ def test_search_invalid(self): logger.info("test_search_invalid") collection = Collection("test-collection7", 128) vector = np.random.rand(64).astype(np.float32) # Wrong length - with self.assertRaises(Exception): + with self.assertRaises(ValueError): collection.search(vector) diff --git a/test/test_filter.py b/test/test_filter.py index c9971b3..9f06bd6 100644 --- a/test/test_filter.py +++ b/test/test_filter.py @@ -1,6 +1,6 @@ import unittest -from filter import filter_item +from hnsqlite.filter import filter_item class TestFilterItem(unittest.TestCase): def test_filter_item(self):