diff --git a/README.md b/README.md
index eeda249ae..120a2431b 100644
--- a/README.md
+++ b/README.md
@@ -281,6 +281,24 @@ Case No. | Case Type | Dataset Size | Filtering Rate | Results |
Each case provides an in-depth examination of a vector database's abilities, providing you a comprehensive view of the database's performance.
+#### Custom Dataset for Performance Case
+
+On the `/custom` page, users can define their own performance cases backed by local datasets. After saving, the new case can be selected on the `/run_test` page to run the test.
+
+![image](fig/custom_dataset.png)
+![image](fig/custom_case_run_test.png)
+
+The dataset files must follow a strict format; please prepare them accordingly (a minimal generation sketch follows this list).
+- `Folder Path` - The path to the folder containing all the files. Please ensure that all files in the folder are in the `Parquet` format.
+  - Vector data file: The file must be named `train.parquet` and should have two columns: `id` as an incrementing `int` and `emb` as an array of `float32`.
+  - Query vector file: The file must be named `test.parquet` and should have two columns: `id` as an incrementing `int` and `emb` as an array of `float32`.
+  - Ground truth file: The file must be named `neighbors.parquet` and should have two columns: `id` corresponding to the query vectors and `neighbors_id` as an array of `int`.
+
+- `Train File Count` - If the train file is too large, you can split it into multiple files. The split files must be named `train-[index]-of-[file_count].parquet`; for example, `train-01-of-10.parquet` is the second file (0-indexed) of 10 split files.
+
+- `Use Shuffled Data` - If this option is checked, VectorDBBench loads the vector data files prefixed with `shuffle_` instead: for example, `shuffle_train.parquet` instead of `train.parquet`, and `shuffle_train-04-of-10.parquet` instead of `train-04-of-10.parquet`. The `id` column in the shuffled data may be in any order.
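+
+Below is a minimal sketch of how such files could be generated with `pandas` and `numpy` (assuming a Parquet engine such as `pyarrow` is installed). The output path reuses the `/my_dataset_path` example from `custom_case.json`; the sizes, random vectors, and brute-force ground truth are for illustration only and should be replaced with your real data.
+
+```python
+import numpy as np
+import pandas as pd
+
+out_dir = "/my_dataset_path"  # matches `Folder Path`
+dim, train_size, test_size, top_k = 1024, 10_000, 100, 100
+
+# train.parquet: incrementing int `id` plus float32 vectors in `emb`
+train_emb = np.random.rand(train_size, dim).astype(np.float32)
+pd.DataFrame({"id": np.arange(train_size), "emb": list(train_emb)}).to_parquet(
+    f"{out_dir}/train.parquet"
+)
+
+# test.parquet: query vectors, same two columns
+test_emb = np.random.rand(test_size, dim).astype(np.float32)
+pd.DataFrame({"id": np.arange(test_size), "emb": list(test_emb)}).to_parquet(
+    f"{out_dir}/test.parquet"
+)
+
+# neighbors.parquet: for each query `id`, the ids of its true top-k nearest
+# train vectors (computed here by brute force with squared L2 distance)
+d2 = (
+    (test_emb**2).sum(1)[:, None]
+    - 2 * test_emb @ train_emb.T
+    + (train_emb**2).sum(1)[None, :]
+)
+neighbors = np.argsort(d2, axis=1)[:, :top_k]
+pd.DataFrame({"id": np.arange(test_size), "neighbors_id": list(neighbors)}).to_parquet(
+    f"{out_dir}/neighbors.parquet"
+)
+```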
+
+
## Goals
Our goals of this benchmark are:
### Reproducibility & Usability
diff --git a/fig/custom_case_run_test.png b/fig/custom_case_run_test.png
new file mode 100644
index 000000000..8817b3439
Binary files /dev/null and b/fig/custom_case_run_test.png differ
diff --git a/fig/custom_dataset.png b/fig/custom_dataset.png
new file mode 100644
index 000000000..9d665891a
Binary files /dev/null and b/fig/custom_dataset.png differ
diff --git a/pyproject.toml b/pyproject.toml
index 1311cceef..edc3575c5 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -80,4 +80,5 @@ zilliz_cloud = []
[project.scripts]
init_bench = "vectordb_bench.__main__:main"
vectordbbench = "vectordb_bench.cli.vectordbbench:cli"
+
[tool.setuptools_scm]
diff --git a/vectordb_bench/__init__.py b/vectordb_bench/__init__.py
index 018cd7732..f7664502f 100644
--- a/vectordb_bench/__init__.py
+++ b/vectordb_bench/__init__.py
@@ -35,6 +35,7 @@ class config:
K_DEFAULT = 100 # default return top k nearest neighbors during search
+ CUSTOM_CONFIG_DIR = pathlib.Path(__file__).parent.joinpath("custom/custom_case.json")
CAPACITY_TIMEOUT_IN_SECONDS = 24 * 3600 # 24h
LOAD_TIMEOUT_DEFAULT = 2.5 * 3600 # 2.5h
diff --git a/vectordb_bench/backend/assembler.py b/vectordb_bench/backend/assembler.py
index 6b0e3c81d..e7da4d49f 100644
--- a/vectordb_bench/backend/assembler.py
+++ b/vectordb_bench/backend/assembler.py
@@ -14,7 +14,7 @@ class Assembler:
def assemble(cls, run_id , task: TaskConfig, source: DatasetSource) -> CaseRunner:
c_cls = task.case_config.case_id.case_cls
- c = c_cls()
+ c = c_cls(task.case_config.custom_case)
if type(task.db_case_config) != EmptyDBCaseConfig:
task.db_case_config.metric_type = c.dataset.data.metric_type
diff --git a/vectordb_bench/backend/cases.py b/vectordb_bench/backend/cases.py
index 6f4b35974..6c43bb910 100644
--- a/vectordb_bench/backend/cases.py
+++ b/vectordb_bench/backend/cases.py
@@ -4,9 +4,13 @@
from typing import Type
from vectordb_bench import config
+from vectordb_bench.backend.clients.api import MetricType
from vectordb_bench.base import BaseModel
+from vectordb_bench.frontend.components.custom.getCustomConfig import (
+ CustomDatasetConfig,
+)
-from .dataset import Dataset, DatasetManager
+from .dataset import CustomDataset, Dataset, DatasetManager
log = logging.getLogger(__name__)
@@ -44,25 +48,24 @@ class CaseType(Enum):
Performance1536D50K = 50
Custom = 100
+ PerformanceCustomDataset = 101
- @property
def case_cls(self, custom_configs: dict | None = None) -> Type["Case"]:
- if self not in type2case:
- raise NotImplementedError(f"Case {self} has not implemented. You can add it manually to vectordb_bench.backend.cases.type2case or define a custom_configs['custom_cls']")
- return type2case[self]
+ if custom_configs is None:
+ return type2case.get(self)()
+ else:
+ return type2case.get(self)(**custom_configs)
- @property
- def case_name(self) -> str:
- c = self.case_cls
+ def case_name(self, custom_configs: dict | None = None) -> str:
+ c = self.case_cls(custom_configs)
if c is not None:
- return c().name
+ return c.name
raise ValueError("Case unsupported")
- @property
- def case_description(self) -> str:
- c = self.case_cls
+ def case_description(self, custom_configs: dict | None = None) -> str:
+ c = self.case_cls(custom_configs)
if c is not None:
- return c().description
+ return c.description
raise ValueError("Case unsupported")
@@ -289,26 +292,69 @@ class Performance1536D50K(PerformanceCase):
optimize_timeout: float | int | None = 15 * 60
+def metric_type_map(s: str) -> MetricType:
+ if s.lower() == "cosine":
+ return MetricType.COSINE
+ if s.lower() == "l2" or s.lower() == "euclidean":
+ return MetricType.L2
+ if s.lower() == "ip":
+ return MetricType.IP
+ err_msg = f"Not support metric_type: {s}"
+ log.error(err_msg)
+ raise RuntimeError(err_msg)
+
+
+class PerformanceCustomDataset(PerformanceCase):
+ case_id: CaseType = CaseType.PerformanceCustomDataset
+ name: str = "Performance With Custom Dataset"
+ description: str = ""
+ dataset: DatasetManager
+
+ def __init__(
+ self,
+ name,
+ description,
+ load_timeout,
+ optimize_timeout,
+ dataset_config,
+ **kwargs,
+ ):
+ dataset_config = CustomDatasetConfig(**dataset_config)
+ dataset = CustomDataset(
+ name=dataset_config.name,
+ size=dataset_config.size,
+ dim=dataset_config.dim,
+ metric_type=metric_type_map(dataset_config.metric_type),
+ use_shuffled=dataset_config.use_shuffled,
+ with_gt=dataset_config.with_gt,
+ dir=dataset_config.dir,
+ file_num=dataset_config.file_count,
+ )
+ super().__init__(
+ name=name,
+ description=description,
+ load_timeout=load_timeout,
+ optimize_timeout=optimize_timeout,
+ dataset=DatasetManager(data=dataset),
+ )
+
+
type2case = {
CaseType.CapacityDim960: CapacityDim960,
CaseType.CapacityDim128: CapacityDim128,
-
CaseType.Performance768D100M: Performance768D100M,
CaseType.Performance768D10M: Performance768D10M,
CaseType.Performance768D1M: Performance768D1M,
-
CaseType.Performance768D10M1P: Performance768D10M1P,
CaseType.Performance768D1M1P: Performance768D1M1P,
CaseType.Performance768D10M99P: Performance768D10M99P,
CaseType.Performance768D1M99P: Performance768D1M99P,
-
CaseType.Performance1536D500K: Performance1536D500K,
CaseType.Performance1536D5M: Performance1536D5M,
-
CaseType.Performance1536D500K1P: Performance1536D500K1P,
CaseType.Performance1536D5M1P: Performance1536D5M1P,
-
CaseType.Performance1536D500K99P: Performance1536D500K99P,
CaseType.Performance1536D5M99P: Performance1536D5M99P,
CaseType.Performance1536D50K: Performance1536D50K,
+ CaseType.PerformanceCustomDataset: PerformanceCustomDataset,
}
diff --git a/vectordb_bench/backend/dataset.py b/vectordb_bench/backend/dataset.py
index 2b630eae3..d559eb6be 100644
--- a/vectordb_bench/backend/dataset.py
+++ b/vectordb_bench/backend/dataset.py
@@ -33,6 +33,7 @@ class BaseDataset(BaseModel):
use_shuffled: bool
with_gt: bool = False
_size_label: dict[int, SizeLabel] = PrivateAttr()
+ isCustom: bool = False
@validator("size")
def verify_size(cls, v):
@@ -52,7 +53,27 @@ def dir_name(self) -> str:
def file_count(self) -> int:
return self._size_label.get(self.size).file_count
+class CustomDataset(BaseDataset):
+ dir: str
+ file_num: int
+ isCustom: bool = True
+
+ @validator("size")
+ def verify_size(cls, v):
+ return v
+
+ @property
+ def label(self) -> str:
+ return "Custom"
+ @property
+ def dir_name(self) -> str:
+ return self.dir
+
+ @property
+ def file_count(self) -> int:
+ return self.file_num
+
class LAION(BaseDataset):
name: str = "LAION"
dim: int = 768
@@ -186,11 +207,12 @@ def prepare(self,
gt_file, test_file = utils.compose_gt_file(filters), "test.parquet"
all_files.extend([gt_file, test_file])
- source.reader().read(
- dataset=self.data.dir_name.lower(),
- files=all_files,
- local_ds_root=self.data_dir,
- )
+ if not self.data.isCustom:
+ source.reader().read(
+ dataset=self.data.dir_name.lower(),
+ files=all_files,
+ local_ds_root=self.data_dir,
+ )
if gt_file is not None and test_file is not None:
self.test_data = self._read_file(test_file)
diff --git a/vectordb_bench/custom/custom_case.json b/vectordb_bench/custom/custom_case.json
new file mode 100644
index 000000000..48ca8d8c4
--- /dev/null
+++ b/vectordb_bench/custom/custom_case.json
@@ -0,0 +1,18 @@
+[
+ {
+ "name": "My Dataset (Performace Case)",
+ "description": "this is a customized dataset.",
+ "load_timeout": 36000,
+ "optimize_timeout": 36000,
+ "dataset_config": {
+ "name": "My Dataset",
+ "dir": "/my_dataset_path",
+ "size": 1000000,
+ "dim": 1024,
+ "metric_type": "L2",
+ "file_count": 1,
+ "use_shuffled": false,
+ "with_gt": true
+ }
+ }
+]
\ No newline at end of file
diff --git a/vectordb_bench/frontend/components/check_results/charts.py b/vectordb_bench/frontend/components/check_results/charts.py
index 7e28d1e66..c2b2813b8 100644
--- a/vectordb_bench/frontend/components/check_results/charts.py
+++ b/vectordb_bench/frontend/components/check_results/charts.py
@@ -1,19 +1,19 @@
from vectordb_bench.backend.cases import Case
from vectordb_bench.frontend.components.check_results.expanderStyle import initMainExpanderStyle
from vectordb_bench.metric import metricOrder, isLowerIsBetterMetric, metricUnitMap
-from vectordb_bench.frontend.const.styles import *
+from vectordb_bench.frontend.config.styles import *
from vectordb_bench.models import ResultLabel
import plotly.express as px
-def drawCharts(st, allData, failedTasks, cases: list[Case]):
+def drawCharts(st, allData, failedTasks, caseNames: list[str]):
initMainExpanderStyle(st)
- for case in cases:
- chartContainer = st.expander(case.name, True)
- data = [data for data in allData if data["case_name"] == case.name]
+ for caseName in caseNames:
+ chartContainer = st.expander(caseName, True)
+ data = [data for data in allData if data["case_name"] == caseName]
drawChart(data, chartContainer)
- errorDBs = failedTasks[case.name]
+ errorDBs = failedTasks[caseName]
showFailedDBs(chartContainer, errorDBs)
diff --git a/vectordb_bench/frontend/components/check_results/data.py b/vectordb_bench/frontend/components/check_results/data.py
index c092da3a0..1e6bba00e 100644
--- a/vectordb_bench/frontend/components/check_results/data.py
+++ b/vectordb_bench/frontend/components/check_results/data.py
@@ -8,9 +8,9 @@
def getChartData(
tasks: list[CaseResult],
dbNames: list[str],
- cases: list[Case],
+ caseNames: list[str],
):
- filterTasks = getFilterTasks(tasks, dbNames, cases)
+ filterTasks = getFilterTasks(tasks, dbNames, caseNames)
mergedTasks, failedTasks = mergeTasks(filterTasks)
return mergedTasks, failedTasks
@@ -18,14 +18,13 @@ def getChartData(
def getFilterTasks(
tasks: list[CaseResult],
dbNames: list[str],
- cases: list[Case],
+ caseNames: list[str],
) -> list[CaseResult]:
- case_ids = [case.case_id for case in cases]
filterTasks = [
task
for task in tasks
if task.task_config.db_name in dbNames
- and task.task_config.case_config.case_id in case_ids
+ and task.task_config.case_config.case_id.case_cls(task.task_config.case_config.custom_case).name in caseNames
]
return filterTasks
@@ -36,16 +35,17 @@ def mergeTasks(tasks: list[CaseResult]):
db_name = task.task_config.db_name
db = task.task_config.db.value
db_label = task.task_config.db_config.db_label or ""
- case_id = task.task_config.case_config.case_id
- dbCaseMetricsMap[db_name][case_id] = {
+ case = task.task_config.case_config.case_id.case_cls(task.task_config.case_config.custom_case)
+ dbCaseMetricsMap[db_name][case.name] = {
"db": db,
"db_label": db_label,
"metrics": mergeMetrics(
- dbCaseMetricsMap[db_name][case_id].get("metrics", {}),
+ dbCaseMetricsMap[db_name][case.name].get("metrics", {}),
asdict(task.metrics),
),
"label": getBetterLabel(
- dbCaseMetricsMap[db_name][case_id].get("label", ResultLabel.FAILED),
+ dbCaseMetricsMap[db_name][case.name].get(
+ "label", ResultLabel.FAILED),
task.label,
),
}
@@ -53,12 +53,11 @@ def mergeTasks(tasks: list[CaseResult]):
mergedTasks = []
failedTasks = defaultdict(lambda: defaultdict(str))
for db_name, caseMetricsMap in dbCaseMetricsMap.items():
- for case_id, metricInfo in caseMetricsMap.items():
+ for case_name, metricInfo in caseMetricsMap.items():
metrics = metricInfo["metrics"]
db = metricInfo["db"]
db_label = metricInfo["db_label"]
label = metricInfo["label"]
- case_name = case_id.case_name
if label == ResultLabel.NORMAL:
mergedTasks.append(
{
@@ -80,7 +79,8 @@ def mergeMetrics(metrics_1: dict, metrics_2: dict) -> dict:
metrics = {**metrics_1}
for key, value in metrics_2.items():
metrics[key] = (
- getBetterMetric(key, value, metrics[key]) if key in metrics else value
+ getBetterMetric(
+ key, value, metrics[key]) if key in metrics else value
)
return metrics
diff --git a/vectordb_bench/frontend/components/check_results/expanderStyle.py b/vectordb_bench/frontend/components/check_results/expanderStyle.py
index 9496313e8..436eeec38 100644
--- a/vectordb_bench/frontend/components/check_results/expanderStyle.py
+++ b/vectordb_bench/frontend/components/check_results/expanderStyle.py
@@ -1,7 +1,7 @@
def initMainExpanderStyle(st):
st.markdown(
"""""",
+ unsafe_allow_html=True,
+ )
\ No newline at end of file
diff --git a/vectordb_bench/frontend/components/run_test/autoRefresh.py b/vectordb_bench/frontend/components/run_test/autoRefresh.py
index fe31d8205..034ab5017 100644
--- a/vectordb_bench/frontend/components/run_test/autoRefresh.py
+++ b/vectordb_bench/frontend/components/run_test/autoRefresh.py
@@ -1,5 +1,5 @@
from streamlit_autorefresh import st_autorefresh
-from vectordb_bench.frontend.const.styles import *
+from vectordb_bench.frontend.config.styles import *
def autoRefresh():
diff --git a/vectordb_bench/frontend/components/run_test/caseSelector.py b/vectordb_bench/frontend/components/run_test/caseSelector.py
index 49b839163..58799deff 100644
--- a/vectordb_bench/frontend/components/run_test/caseSelector.py
+++ b/vectordb_bench/frontend/components/run_test/caseSelector.py
@@ -1,9 +1,13 @@
-from vectordb_bench.frontend.const.styles import *
+
+from vectordb_bench.frontend.config.styles import *
from vectordb_bench.backend.cases import CaseType
-from vectordb_bench.frontend.const.dbCaseConfigs import *
+from vectordb_bench.frontend.config.dbCaseConfigs import *
+from collections import defaultdict
+
+from vectordb_bench.frontend.utils import addHorizontalLine
-def caseSelector(st, activedDbList):
+def caseSelector(st, activedDbList: list[DB]):
st.markdown(
"
",
unsafe_allow_html=True,
@@ -14,41 +18,49 @@ def caseSelector(st, activedDbList):
unsafe_allow_html=True,
)
- caseIsActived = {case: False for case in CASE_LIST}
- allCaseConfigs = {db: {case: {} for case in CASE_LIST} for db in DB_LIST}
- for caseOrDivider in CASE_LIST_WITH_DIVIDER:
- if caseOrDivider == DIVIDER:
- caseItemContainer.markdown(
- "",
- unsafe_allow_html=True,
- )
+ activedCaseList: list[CaseConfig] = []
+ dbToCaseClusterConfigs = defaultdict(lambda: defaultdict(dict))
+ dbToCaseConfigs = defaultdict(lambda: defaultdict(dict))
+ caseClusters = UI_CASE_CLUSTERS + [get_custom_case_cluter()]
+ for caseCluster in caseClusters:
+ activedCaseList += caseClusterExpander(
+ st, caseCluster, dbToCaseClusterConfigs, activedDbList)
+ for db in dbToCaseClusterConfigs:
+ for uiCaseItem in dbToCaseClusterConfigs[db]:
+ for case in uiCaseItem.cases:
+ dbToCaseConfigs[db][case] = dbToCaseClusterConfigs[db][uiCaseItem]
+
+ return activedCaseList, dbToCaseConfigs
+
+
+def caseClusterExpander(st, caseCluster: UICaseItemCluster, dbToCaseClusterConfigs, activedDbList: list[DB]):
+ expander = st.expander(caseCluster.label, False)
+ activedCases: list[CaseConfig] = []
+ for uiCaseItem in caseCluster.uiCaseItems:
+ if uiCaseItem.isLine:
+ addHorizontalLine(expander)
else:
- case = caseOrDivider
- caseItemContainer = st.container()
- caseIsActived[case] = caseItem(
- caseItemContainer, allCaseConfigs, case, activedDbList
- )
- activedCaseList = [case for case in CASE_LIST if caseIsActived[case]]
- return activedCaseList, allCaseConfigs
+ activedCases += caseItemCheckbox(expander,
+ dbToCaseClusterConfigs, uiCaseItem, activedDbList)
+ return activedCases
-def caseItem(st, allCaseConfigs, case: CaseType, activedDbList):
- selected = st.checkbox(case.case_name)
+def caseItemCheckbox(st, dbToCaseClusterConfigs, uiCaseItem: UICaseItem, activedDbList: list[DB]):
+ selected = st.checkbox(uiCaseItem.label)
st.markdown(
- f"{case.case_description}
",
+ f"{uiCaseItem.description}
",
unsafe_allow_html=True,
)
if selected:
- caseConfigSettingContainer = st.container()
caseConfigSetting(
- caseConfigSettingContainer, allCaseConfigs, case, activedDbList
+ st.container(), dbToCaseClusterConfigs, uiCaseItem, activedDbList
)
- return selected
+ return uiCaseItem.cases if selected else []
-def caseConfigSetting(st, allCaseConfigs, case, activedDbList):
+def caseConfigSetting(st, dbToCaseClusterConfigs, uiCaseItem: UICaseItem, activedDbList: list[DB]):
for db in activedDbList:
columns = st.columns(1 + CASE_CONFIG_SETTING_COLUMNS)
# column 0 - title
@@ -57,12 +69,12 @@ def caseConfigSetting(st, allCaseConfigs, case, activedDbList):
f"{db.name}
",
unsafe_allow_html=True,
)
- caseConfig = allCaseConfigs[db][case]
k = 0
- for config in CASE_CONFIG_MAP.get(db, {}).get(case.case_cls().label, []):
+ caseConfig = dbToCaseClusterConfigs[db][uiCaseItem]
+ for config in CASE_CONFIG_MAP.get(db, {}).get(uiCaseItem.caseLabel, []):
if config.isDisplayed(caseConfig):
column = columns[1 + k % CASE_CONFIG_SETTING_COLUMNS]
- key = "%s-%s-%s" % (db, case, config.label.value)
+ key = "%s-%s-%s" % (db, uiCaseItem.label, config.label.value)
if config.inputType == InputType.Text:
caseConfig[config.label] = column.text_input(
config.displayLabel if config.displayLabel else config.label.value,
diff --git a/vectordb_bench/frontend/components/run_test/dbConfigSetting.py b/vectordb_bench/frontend/components/run_test/dbConfigSetting.py
index ffd52721f..8f4f35c93 100644
--- a/vectordb_bench/frontend/components/run_test/dbConfigSetting.py
+++ b/vectordb_bench/frontend/components/run_test/dbConfigSetting.py
@@ -1,13 +1,9 @@
from pydantic import ValidationError
-from vectordb_bench.frontend.const.styles import *
+from vectordb_bench.frontend.config.styles import *
from vectordb_bench.frontend.utils import inputIsPassword
def dbConfigSettings(st, activedDbList):
- st.markdown(
- "",
- unsafe_allow_html=True,
- )
expander = st.expander("Configurations for the selected databases", True)
dbConfigs = {}
diff --git a/vectordb_bench/frontend/components/run_test/dbSelector.py b/vectordb_bench/frontend/components/run_test/dbSelector.py
index 5fcbd8c08..ccf0168c6 100644
--- a/vectordb_bench/frontend/components/run_test/dbSelector.py
+++ b/vectordb_bench/frontend/components/run_test/dbSelector.py
@@ -1,7 +1,6 @@
from streamlit.runtime.media_file_storage import MediaFileStorageError
-
-from vectordb_bench.frontend.const.styles import *
-from vectordb_bench.frontend.const.dbCaseConfigs import DB_LIST
+from vectordb_bench.frontend.config.styles import DB_SELECTOR_COLUMNS, DB_TO_ICON
+from vectordb_bench.frontend.config.dbCaseConfigs import DB_LIST
def dbSelector(st):
@@ -18,17 +17,6 @@ def dbSelector(st):
dbContainerColumns = st.columns(DB_SELECTOR_COLUMNS, gap="small")
dbIsActived = {db: False for db in DB_LIST}
- # style - image; column gap; checkbox font;
- st.markdown(
- """
-
- """,
- unsafe_allow_html=True,
- )
for i, db in enumerate(DB_LIST):
column = dbContainerColumns[i % DB_SELECTOR_COLUMNS]
dbIsActived[db] = column.checkbox(db.name)
diff --git a/vectordb_bench/frontend/components/run_test/generateTasks.py b/vectordb_bench/frontend/components/run_test/generateTasks.py
index 55f3c8399..828913f30 100644
--- a/vectordb_bench/frontend/components/run_test/generateTasks.py
+++ b/vectordb_bench/frontend/components/run_test/generateTasks.py
@@ -1,17 +1,15 @@
+from vectordb_bench.backend.clients import DB
from vectordb_bench.models import CaseConfig, CaseConfigParamType, TaskConfig
-def generate_tasks(activedDbList, dbConfigs, activedCaseList, allCaseConfigs):
+def generate_tasks(activedDbList: list[DB], dbConfigs, activedCaseList: list[CaseConfig], allCaseConfigs):
tasks = []
for db in activedDbList:
for case in activedCaseList:
task = TaskConfig(
db=db.value,
db_config=dbConfigs[db],
- case_config=CaseConfig(
- case_id=case.value,
- custom_case={},
- ),
+ case_config=case,
db_case_config=db.case_config_cls(
allCaseConfigs[db][case].get(CaseConfigParamType.IndexType, None)
)(**{key.value: value for key, value in allCaseConfigs[db][case].items()}),
diff --git a/vectordb_bench/frontend/components/run_test/initStyle.py b/vectordb_bench/frontend/components/run_test/initStyle.py
new file mode 100644
index 000000000..59dd438e1
--- /dev/null
+++ b/vectordb_bench/frontend/components/run_test/initStyle.py
@@ -0,0 +1,14 @@
+def initStyle(st):
+ st.markdown(
+ """""",
+ unsafe_allow_html=True,
+ )
\ No newline at end of file
diff --git a/vectordb_bench/frontend/components/run_test/submitTask.py b/vectordb_bench/frontend/components/run_test/submitTask.py
index 26cb1ef70..f824dd9d9 100644
--- a/vectordb_bench/frontend/components/run_test/submitTask.py
+++ b/vectordb_bench/frontend/components/run_test/submitTask.py
@@ -1,5 +1,5 @@
from datetime import datetime
-from vectordb_bench.frontend.const.styles import *
+from vectordb_bench.frontend.config.styles import *
from vectordb_bench.interface import benchMarkRunner
diff --git a/vectordb_bench/frontend/const/dbCaseConfigs.py b/vectordb_bench/frontend/config/dbCaseConfigs.py
similarity index 78%
rename from vectordb_bench/frontend/const/dbCaseConfigs.py
rename to vectordb_bench/frontend/config/dbCaseConfigs.py
index ed101ac69..ce8a3a4ae 100644
--- a/vectordb_bench/frontend/const/dbCaseConfigs.py
+++ b/vectordb_bench/frontend/config/dbCaseConfigs.py
@@ -1,43 +1,147 @@
-from enum import IntEnum
+from enum import IntEnum, Enum
import typing
from pydantic import BaseModel
from vectordb_bench.backend.cases import CaseLabel, CaseType
from vectordb_bench.backend.clients import DB
from vectordb_bench.backend.clients.api import IndexType
+from vectordb_bench.frontend.components.custom.getCustomConfig import get_custom_configs
-from vectordb_bench.models import CaseConfigParamType
+from vectordb_bench.models import CaseConfig, CaseConfigParamType
MAX_STREAMLIT_INT = (1 << 53) - 1
DB_LIST = [d for d in DB if d != DB.Test]
-DIVIDER = "DIVIDER"
-CASE_LIST_WITH_DIVIDER = [
+
+class Delimiter(Enum):
+ Line = "line"
+
+
+class BatchCaseConfig(BaseModel):
+ label: str = ""
+ description: str = ""
+ cases: list[CaseConfig] = []
+
+
+class UICaseItem(BaseModel):
+ isLine: bool = False
+ label: str = ""
+ description: str = ""
+ cases: list[CaseConfig] = []
+ caseLabel: CaseLabel = CaseLabel.Performance
+
+ def __init__(
+ self,
+ isLine: bool = False,
+ case_id: CaseType = None,
+ custom_case: dict = {},
+ cases: list[CaseConfig] = [],
+ label: str = "",
+ description: str = "",
+ caseLabel: CaseLabel = CaseLabel.Performance,
+ ):
+ if isLine is True:
+ super().__init__(isLine=True)
+ elif case_id is not None and isinstance(case_id, CaseType):
+ c = case_id.case_cls(custom_case)
+ super().__init__(
+ label=c.name,
+ description=c.description,
+ cases=[CaseConfig(case_id=case_id, custom_case=custom_case)],
+ caseLabel=c.label,
+ )
+ else:
+ super().__init__(
+ label=label,
+ description=description,
+ cases=cases,
+ caseLabel=caseLabel,
+ )
+
+ def __hash__(self) -> int:
+ return hash(self.json())
+
+
+class UICaseItemCluster(BaseModel):
+ label: str = ""
+ uiCaseItems: list[UICaseItem] = []
+
+
+def get_custom_case_items() -> list[UICaseItem]:
+ custom_configs = get_custom_configs()
+ return [
+ UICaseItem(
+ case_id=CaseType.PerformanceCustomDataset, custom_case=custom_config.dict()
+ )
+ for custom_config in custom_configs
+ ]
+
+
+def get_custom_case_cluter() -> UICaseItemCluster:
+ return UICaseItemCluster(
+ label="Custom Search Performance Test", uiCaseItems=get_custom_case_items()
+ )
+
+
+UI_CASE_CLUSTERS: list[UICaseItemCluster] = [
+ UICaseItemCluster(
+ label="Search Performance Test",
+ uiCaseItems=[
+ UICaseItem(case_id=CaseType.Performance768D100M),
+ UICaseItem(case_id=CaseType.Performance768D10M),
+ UICaseItem(case_id=CaseType.Performance768D1M),
+ UICaseItem(isLine=True),
+ UICaseItem(case_id=CaseType.Performance1536D5M),
+ UICaseItem(case_id=CaseType.Performance1536D500K),
+ UICaseItem(case_id=CaseType.Performance1536D50K),
+ ],
+ ),
+ UICaseItemCluster(
+ label="Filter Search Performance Test",
+ uiCaseItems=[
+ UICaseItem(case_id=CaseType.Performance768D10M1P),
+ UICaseItem(case_id=CaseType.Performance768D10M99P),
+ UICaseItem(case_id=CaseType.Performance768D1M1P),
+ UICaseItem(case_id=CaseType.Performance768D1M99P),
+ UICaseItem(isLine=True),
+ UICaseItem(case_id=CaseType.Performance1536D5M1P),
+ UICaseItem(case_id=CaseType.Performance1536D5M99P),
+ UICaseItem(case_id=CaseType.Performance1536D500K1P),
+ UICaseItem(case_id=CaseType.Performance1536D500K99P),
+ ],
+ ),
+ UICaseItemCluster(
+ label="Capacity Test",
+ uiCaseItems=[
+ UICaseItem(case_id=CaseType.CapacityDim960),
+ UICaseItem(case_id=CaseType.CapacityDim128),
+ ],
+ ),
+]
+
+# DIVIDER = "DIVIDER"
+DISPLAY_CASE_ORDER: list[CaseType] = [
CaseType.Performance768D100M,
CaseType.Performance768D10M,
CaseType.Performance768D1M,
- DIVIDER,
CaseType.Performance1536D5M,
CaseType.Performance1536D500K,
CaseType.Performance1536D50K,
- DIVIDER,
CaseType.Performance768D10M1P,
CaseType.Performance768D1M1P,
- DIVIDER,
CaseType.Performance1536D5M1P,
CaseType.Performance1536D500K1P,
- DIVIDER,
CaseType.Performance768D10M99P,
CaseType.Performance768D1M99P,
- DIVIDER,
CaseType.Performance1536D5M99P,
CaseType.Performance1536D500K99P,
- DIVIDER,
CaseType.CapacityDim960,
CaseType.CapacityDim128,
]
+CASE_NAME_ORDER = [case.case_cls().name for case in DISPLAY_CASE_ORDER]
-CASE_LIST = [item for item in CASE_LIST_WITH_DIVIDER if isinstance(item, CaseType)]
+# CASE_LIST = [
+# item for item in CASE_LIST_WITH_DIVIDER if isinstance(item, CaseType)]
class InputType(IntEnum):
@@ -53,7 +157,7 @@ class CaseConfigInput(BaseModel):
inputHelp: str = ""
displayLabel: str = ""
# todo type should be a function
- isDisplayed: typing.Any = lambda x: True
+ isDisplayed: typing.Any = lambda config: True
CaseConfigParamInput_IndexType = CaseConfigInput(
@@ -146,7 +250,7 @@ class CaseConfigInput(BaseModel):
CaseConfigParamInput_maintenance_work_mem_PgVector = CaseConfigInput(
label=CaseConfigParamType.maintenance_work_mem,
inputHelp="Recommended value: 1.33x the index size, not to exceed the available free memory."
- "Specify in gigabytes. e.g. 8GB",
+ "Specify in gigabytes. e.g. 8GB",
inputType=InputType.Text,
inputConfig={
"value": "8GB",
@@ -157,7 +261,7 @@ class CaseConfigInput(BaseModel):
label=CaseConfigParamType.max_parallel_workers,
displayLabel="Max parallel workers",
inputHelp="Recommended value: (cpu cores - 1). This will set the parameters: max_parallel_maintenance_workers,"
- " max_parallel_workers & table(parallel_workers)",
+ " max_parallel_workers & table(parallel_workers)",
inputType=InputType.Number,
inputConfig={
"min": 0,
@@ -514,7 +618,8 @@ class CaseConfigInput(BaseModel):
"options": ["x4", "x8", "x16", "x32", "x64"],
},
isDisplayed=lambda config: config.get(CaseConfigParamType.quantizationType, None)
- == "product" and config.get(CaseConfigParamType.IndexType, None)
+ == "product"
+ and config.get(CaseConfigParamType.IndexType, None)
in [
IndexType.HNSW.value,
IndexType.IVFFlat.value,
@@ -582,22 +687,24 @@ class CaseConfigInput(BaseModel):
CaseConfigParamInput_NumCandidates_ES,
]
-PgVectorLoadingConfig = [CaseConfigParamInput_IndexType_PgVector,
- CaseConfigParamInput_Lists_PgVector,
- CaseConfigParamInput_m,
- CaseConfigParamInput_EFConstruction_PgVector,
- CaseConfigParamInput_maintenance_work_mem_PgVector,
- CaseConfigParamInput_max_parallel_workers_PgVector,
- ]
-PgVectorPerformanceConfig = [CaseConfigParamInput_IndexType_PgVector,
- CaseConfigParamInput_m,
- CaseConfigParamInput_EFConstruction_PgVector,
- CaseConfigParamInput_EFSearch_PgVector,
- CaseConfigParamInput_Lists_PgVector,
- CaseConfigParamInput_Probes_PgVector,
- CaseConfigParamInput_maintenance_work_mem_PgVector,
- CaseConfigParamInput_max_parallel_workers_PgVector,
- ]
+PgVectorLoadingConfig = [
+ CaseConfigParamInput_IndexType_PgVector,
+ CaseConfigParamInput_Lists_PgVector,
+ CaseConfigParamInput_m,
+ CaseConfigParamInput_EFConstruction_PgVector,
+ CaseConfigParamInput_maintenance_work_mem_PgVector,
+ CaseConfigParamInput_max_parallel_workers_PgVector,
+]
+PgVectorPerformanceConfig = [
+ CaseConfigParamInput_IndexType_PgVector,
+ CaseConfigParamInput_m,
+ CaseConfigParamInput_EFConstruction_PgVector,
+ CaseConfigParamInput_EFSearch_PgVector,
+ CaseConfigParamInput_Lists_PgVector,
+ CaseConfigParamInput_Probes_PgVector,
+ CaseConfigParamInput_maintenance_work_mem_PgVector,
+ CaseConfigParamInput_max_parallel_workers_PgVector,
+]
PgVectoRSLoadingConfig = [
CaseConfigParamInput_IndexType,
diff --git a/vectordb_bench/frontend/const/dbPrices.py b/vectordb_bench/frontend/config/dbPrices.py
similarity index 100%
rename from vectordb_bench/frontend/const/dbPrices.py
rename to vectordb_bench/frontend/config/dbPrices.py
diff --git a/vectordb_bench/frontend/const/styles.py b/vectordb_bench/frontend/config/styles.py
similarity index 100%
rename from vectordb_bench/frontend/const/styles.py
rename to vectordb_bench/frontend/config/styles.py
diff --git a/vectordb_bench/frontend/pages/concurrent.py b/vectordb_bench/frontend/pages/concurrent.py
index 0c1415efc..b4eae339c 100644
--- a/vectordb_bench/frontend/pages/concurrent.py
+++ b/vectordb_bench/frontend/pages/concurrent.py
@@ -1,18 +1,14 @@
-
-
-
import streamlit as st
-from vectordb_bench.backend.cases import CaseType
from vectordb_bench.frontend.components.check_results.footer import footer
-from vectordb_bench.frontend.components.check_results.expanderStyle import initMainExpanderStyle
-from vectordb_bench.frontend.components.check_results.priceTable import priceTable
from vectordb_bench.frontend.components.check_results.headerIcon import drawHeaderIcon
-from vectordb_bench.frontend.components.check_results.nav import NavToResults, NavToRunTest
-from vectordb_bench.frontend.components.check_results.charts import drawMetricChart
+from vectordb_bench.frontend.components.check_results.nav import (
+ NavToResults,
+ NavToRunTest,
+)
from vectordb_bench.frontend.components.check_results.filters import getshownData
from vectordb_bench.frontend.components.concurrent.charts import drawChartsByCase
from vectordb_bench.frontend.components.get_results.saveAsImage import getResults
-from vectordb_bench.frontend.const.styles import *
+from vectordb_bench.frontend.config.styles import FAVICON
from vectordb_bench.interface import benchMarkRunner
from vectordb_bench.models import TestResult
@@ -30,26 +26,23 @@ def main():
drawHeaderIcon(st)
allResults = benchMarkRunner.get_results()
-
+
def check_conc_data(res: TestResult):
case_results = res.results
count = 0
for case_result in case_results:
if len(case_result.metrics.conc_num_list) > 0:
count += 1
-
+
return count > 0
-
+
checkedResults = [res for res in allResults if check_conc_data(res)]
-
st.title("VectorDB Benchmark (Concurrent Performance)")
# results selector
resultSelectorContainer = st.sidebar.container()
- shownData, _, showCases = getshownData(
- checkedResults, resultSelectorContainer)
-
+ shownData, _, showCaseNames = getshownData(checkedResults, resultSelectorContainer)
resultSelectorContainer.divider()
@@ -61,8 +54,8 @@ def check_conc_data(res: TestResult):
# save or share
resultesContainer = st.sidebar.container()
getResults(resultesContainer, "vectordb_bench_concurrent")
-
- drawChartsByCase(shownData, showCases, st.container())
+
+ drawChartsByCase(shownData, showCaseNames, st.container())
# footer
footer(st.container())
diff --git a/vectordb_bench/frontend/pages/custom.py b/vectordb_bench/frontend/pages/custom.py
new file mode 100644
index 000000000..28c249f78
--- /dev/null
+++ b/vectordb_bench/frontend/pages/custom.py
@@ -0,0 +1,64 @@
+import streamlit as st
+from vectordb_bench.frontend.components.check_results.headerIcon import drawHeaderIcon
+from vectordb_bench.frontend.components.custom.displayCustomCase import displayCustomCase
+from vectordb_bench.frontend.components.custom.displaypPrams import displayParams
+from vectordb_bench.frontend.components.custom.getCustomConfig import CustomCaseConfig, generate_custom_case, get_custom_configs, save_custom_configs
+from vectordb_bench.frontend.components.custom.initStyle import initStyle
+from vectordb_bench.frontend.config.styles import FAVICON, PAGE_TITLE
+
+
+class CustomCaseManager():
+ customCaseItems: list[CustomCaseConfig]
+
+ def __init__(self):
+ self.customCaseItems = get_custom_configs()
+
+ def addCase(self):
+ new_custom_case = generate_custom_case()
+ new_custom_case.dataset_config.name = f"{new_custom_case.dataset_config.name} {len(self.customCaseItems)}"
+ self.customCaseItems += [new_custom_case]
+ self.save()
+
+ def deleteCase(self, idx: int):
+ self.customCaseItems.pop(idx)
+ self.save()
+
+ def save(self):
+ save_custom_configs(self.customCaseItems)
+
+
+def main():
+ st.set_page_config(
+ page_title=PAGE_TITLE,
+ page_icon=FAVICON,
+ # layout="wide",
+ # initial_sidebar_state="collapsed",
+ )
+
+ # header
+ drawHeaderIcon(st)
+
+ # init style
+ initStyle(st)
+
+ st.title("Custom Dataset")
+ displayParams(st)
+ customCaseManager = CustomCaseManager()
+
+ for idx, customCase in enumerate(customCaseManager.customCaseItems):
+ expander = st.expander(customCase.dataset_config.name, expanded=True)
+ key = f"custom_case_{idx}"
+ displayCustomCase(customCase, expander, key=key)
+
+ columns = expander.columns(8)
+ columns[0].button(
+ "Save", key=f"{key}_", type="secondary", on_click=lambda: customCaseManager.save())
+ columns[1].button(":red[Delete]", key=f"{key}_delete", type="secondary",
+ on_click=lambda: customCaseManager.deleteCase(idx))
+
+ st.button("\+ New Dataset", key=f"add_custom_configs",
+ type="primary", on_click=lambda: customCaseManager.addCase())
+
+
+if __name__ == "__main__":
+ main()
diff --git a/vectordb_bench/frontend/pages/quries_per_dollar.py b/vectordb_bench/frontend/pages/quries_per_dollar.py
index 10c1ac8f1..0bb05294b 100644
--- a/vectordb_bench/frontend/pages/quries_per_dollar.py
+++ b/vectordb_bench/frontend/pages/quries_per_dollar.py
@@ -8,7 +8,7 @@
from vectordb_bench.frontend.components.check_results.charts import drawMetricChart
from vectordb_bench.frontend.components.check_results.filters import getshownData
from vectordb_bench.frontend.components.get_results.saveAsImage import getResults
-from vectordb_bench.frontend.const.styles import *
+from vectordb_bench.frontend.config.styles import *
from vectordb_bench.interface import benchMarkRunner
from vectordb_bench.metric import QURIES_PER_DOLLAR_METRIC
@@ -26,7 +26,7 @@ def main():
# results selector
resultSelectorContainer = st.sidebar.container()
- shownData, _, showCases = getshownData(allResults, resultSelectorContainer)
+ shownData, _, showCaseNames = getshownData(allResults, resultSelectorContainer)
resultSelectorContainer.divider()
@@ -45,8 +45,8 @@ def main():
priceMap = priceTable(priceTableContainer, shownData)
# charts
- for case in showCases:
- data = [data for data in shownData if data["case_name"] == case.name]
+ for caseName in showCaseNames:
+ data = [data for data in shownData if data["case_name"] == caseName]
dataWithMetric = []
metric = QURIES_PER_DOLLAR_METRIC
for d in data:
@@ -56,7 +56,7 @@ def main():
d[metric] = d["qps"] / price * 3.6
dataWithMetric.append(d)
if len(dataWithMetric) > 0:
- chartContainer = st.expander(case.name, True)
+ chartContainer = st.expander(caseName, True)
drawMetricChart(data, metric, chartContainer)
# footer
diff --git a/vectordb_bench/frontend/pages/run_test.py b/vectordb_bench/frontend/pages/run_test.py
index 0712bb6cc..1297743ae 100644
--- a/vectordb_bench/frontend/pages/run_test.py
+++ b/vectordb_bench/frontend/pages/run_test.py
@@ -5,6 +5,7 @@
from vectordb_bench.frontend.components.run_test.dbSelector import dbSelector
from vectordb_bench.frontend.components.run_test.generateTasks import generate_tasks
from vectordb_bench.frontend.components.run_test.hideSidebar import hideSidebar
+from vectordb_bench.frontend.components.run_test.initStyle import initStyle
from vectordb_bench.frontend.components.run_test.submitTask import submitTask
from vectordb_bench.frontend.components.check_results.nav import NavToResults
from vectordb_bench.frontend.components.check_results.headerIcon import drawHeaderIcon
@@ -15,6 +16,9 @@ def main():
# set page config
initRunTestPageConfig(st)
+ # init style
+ initStyle(st)
+
# header
drawHeaderIcon(st)
diff --git a/vectordb_bench/frontend/pages/tables.py b/vectordb_bench/frontend/pages/tables.py
index a4dab68a6..c088dc930 100644
--- a/vectordb_bench/frontend/pages/tables.py
+++ b/vectordb_bench/frontend/pages/tables.py
@@ -1,7 +1,7 @@
import streamlit as st
from vectordb_bench.frontend.components.check_results.headerIcon import drawHeaderIcon
from vectordb_bench.frontend.components.tables.data import getNewResults
-from vectordb_bench.frontend.const.styles import FAVICON
+from vectordb_bench.frontend.config.styles import FAVICON
def main():
@@ -21,4 +21,4 @@ def main():
if __name__ == "__main__":
- main()
\ No newline at end of file
+ main()
diff --git a/vectordb_bench/frontend/utils.py b/vectordb_bench/frontend/utils.py
index 139854af6..787b67d03 100644
--- a/vectordb_bench/frontend/utils.py
+++ b/vectordb_bench/frontend/utils.py
@@ -1,6 +1,22 @@
-from vectordb_bench.models import CaseType
+import random
+import string
+
passwordKeys = ["password", "api_key"]
+
+
def inputIsPassword(key: str) -> bool:
return key.lower() in passwordKeys
+
+def addHorizontalLine(st):
+ st.markdown(
+ "",
+ unsafe_allow_html=True,
+ )
+
+
+def generate_random_string(length):
+ letters = string.ascii_letters + string.digits
+ result = ''.join(random.choice(letters) for _ in range(length))
+ return result
diff --git a/vectordb_bench/frontend/vdb_benchmark.py b/vectordb_bench/frontend/vdb_benchmark.py
index 0be43470e..b859c68b8 100644
--- a/vectordb_bench/frontend/vdb_benchmark.py
+++ b/vectordb_bench/frontend/vdb_benchmark.py
@@ -6,7 +6,7 @@
from vectordb_bench.frontend.components.check_results.charts import drawCharts
from vectordb_bench.frontend.components.check_results.filters import getshownData
from vectordb_bench.frontend.components.get_results.saveAsImage import getResults
-from vectordb_bench.frontend.const.styles import *
+from vectordb_bench.frontend.config.styles import *
from vectordb_bench.interface import benchMarkRunner
@@ -24,7 +24,7 @@ def main():
# results selector and filter
resultSelectorContainer = st.sidebar.container()
- shownData, failedTasks, showCases = getshownData(
+ shownData, failedTasks, showCaseNames = getshownData(
allResults, resultSelectorContainer
)
@@ -40,7 +40,7 @@ def main():
getResults(resultesContainer, "vectordb_bench")
# charts
- drawCharts(st, shownData, failedTasks, showCases)
+ drawCharts(st, shownData, failedTasks, showCaseNames)
# footer
footer(st.container())
diff --git a/vectordb_bench/models.py b/vectordb_bench/models.py
index aa9c930ea..56034796e 100644
--- a/vectordb_bench/models.py
+++ b/vectordb_bench/models.py
@@ -94,6 +94,10 @@ def k(self, value):
self._k = value
'''
+ def __hash__(self) -> int:
+ return hash(self.json())
+
+
class TaskStage(StrEnum):
"""Enumerations of various stages of the task"""
@@ -250,18 +254,18 @@ def append_return(x, y):
max_db = max(map(len, [f.task_config.db.name for f in filtered_results]))
max_db_labels = (
- max(map(len, [f.task_config.db_config.db_label for f in filtered_results]))
- + 3
+ max(map(len, [f.task_config.db_config.db_label for f in filtered_results]))
+ + 3
)
max_case = max(
map(len, [f.task_config.case_config.case_id.name for f in filtered_results])
)
max_load_dur = (
- max(map(len, [str(f.metrics.load_duration) for f in filtered_results])) + 3
+ max(map(len, [str(f.metrics.load_duration) for f in filtered_results])) + 3
)
max_qps = max(map(len, [str(f.metrics.qps) for f in filtered_results])) + 3
max_recall = (
- max(map(len, [str(f.metrics.recall) for f in filtered_results])) + 3
+ max(map(len, [str(f.metrics.recall) for f in filtered_results])) + 3
)
max_db_labels = 8 if max_db_labels < 8 else max_db_labels
diff --git a/vectordb_bench/results/getLeaderboardData.py b/vectordb_bench/results/getLeaderboardData.py
index 50f458533..c6484514d 100644
--- a/vectordb_bench/results/getLeaderboardData.py
+++ b/vectordb_bench/results/getLeaderboardData.py
@@ -2,7 +2,7 @@
import ujson
import pathlib
from vectordb_bench.backend.cases import CaseType
-from vectordb_bench.frontend.const.dbPrices import DB_DBLABEL_TO_PRICE
+from vectordb_bench.frontend.config.dbPrices import DB_DBLABEL_TO_PRICE
from vectordb_bench.interface import benchMarkRunner
from vectordb_bench.models import CaseResult, ResultLabel, TestResult