
support custom_dataset #320

Merged · 2 commits · Jul 17, 2024
18 changes: 18 additions & 0 deletions README.md
@@ -281,6 +281,24 @@ Case No. | Case Type | Dataset Size | Filtering Rate | Results |

Each case provides an in-depth examination of a vector database's abilities, giving you a comprehensive view of the database's performance.

#### Custom Dataset for Performance Case

On the `/custom` page, users can define their own performance cases using local datasets. After saving, the corresponding case can be selected on the `/run_test` page to run the test.

![image](fig/custom_dataset.png)
![image](fig/custom_case_run_test.png)

The dataset must follow a strict file format:
- `Folder Path` - The path to the folder containing all the files. All files in the folder must be in `Parquet` format.
    - Vector data file: must be named `train.parquet`, with two columns: `id` (an incrementing `int`) and `emb` (an array of `float32`).
    - Query test vectors: must be named `test.parquet`, with two columns: `id` (an incrementing `int`) and `emb` (an array of `float32`).
    - Ground truth file: must be named `neighbors.parquet`, with two columns: `id` (corresponding to the query vectors) and `neighbors_id` (an array of `int`).

- `Train File Count` - If the vector file is too large, you can split it into multiple files. The split files must be named `train-[index]-of-[file_count].parquet`. For example, `train-01-of-10.parquet` is the second file (0-indexed) of 10 split files.

- `Use Shuffled Data` - If this option is checked, VectorDBBench loads the data files labeled with `shuffle`: for example, `shuffle_train.parquet` instead of `train.parquet`, and `shuffle_train-04-of-10.parquet` instead of `train-04-of-10.parquet`. The `id` column in the shuffled data may be in any order.


## Goals
The goals of this benchmark are:
### Reproducibility & Usability
Binary file added fig/custom_case_run_test.png
Binary file added fig/custom_dataset.png
1 change: 1 addition & 0 deletions pyproject.toml
@@ -80,4 +80,5 @@ zilliz_cloud = []
[project.scripts]
init_bench = "vectordb_bench.__main__:main"
vectordbbench = "vectordb_bench.cli.vectordbbench:cli"

[tool.setuptools_scm]
2 changes: 2 additions & 0 deletions vectordb_bench/__init__.py
@@ -35,6 +35,8 @@ class config:


K_DEFAULT = 100 # default return top k nearest neighbors during search
RESULTS_LOCAL_DIR = pathlib.Path(__file__).parent.joinpath("results")
CUSTOM_CONFIG_DIR = pathlib.Path(__file__).parent.joinpath("custom/custom_case.json")

CAPACITY_TIMEOUT_IN_SECONDS = 24 * 3600 # 24h
LOAD_TIMEOUT_DEFAULT = 2.5 * 3600 # 2.5h
2 changes: 1 addition & 1 deletion vectordb_bench/backend/assembler.py
@@ -14,7 +14,7 @@ class Assembler:
def assemble(cls, run_id , task: TaskConfig, source: DatasetSource) -> CaseRunner:
c_cls = task.case_config.case_id.case_cls

c = c_cls()
c = c_cls(task.case_config.custom_case)
if type(task.db_case_config) != EmptyDBCaseConfig:
task.db_case_config.metric_type = c.dataset.data.metric_type

82 changes: 64 additions & 18 deletions vectordb_bench/backend/cases.py
@@ -4,9 +4,13 @@
from typing import Type

from vectordb_bench import config
from vectordb_bench.backend.clients.api import MetricType
from vectordb_bench.base import BaseModel
from vectordb_bench.frontend.components.custom.getCustomConfig import (
CustomDatasetConfig,
)

from .dataset import Dataset, DatasetManager
from .dataset import CustomDataset, Dataset, DatasetManager


log = logging.getLogger(__name__)
@@ -44,25 +48,24 @@ class CaseType(Enum):
Performance1536D50K = 50

Custom = 100
PerformanceCustomDataset = 101

@property
def case_cls(self, custom_configs: dict | None = None) -> Type["Case"]:
if self not in type2case:
raise NotImplementedError(f"Case {self} has not been implemented. You can add it manually to vectordb_bench.backend.cases.type2case or define a custom_configs['custom_cls']")
return type2case[self]
if custom_configs is None:
return type2case.get(self)()
else:
return type2case.get(self)(**custom_configs)

@property
def case_name(self) -> str:
c = self.case_cls
def case_name(self, custom_configs: dict | None = None) -> str:
c = self.case_cls(custom_configs)
if c is not None:
return c().name
return c.name
raise ValueError("Case unsupported")

@property
def case_description(self) -> str:
c = self.case_cls
def case_description(self, custom_configs: dict | None = None) -> str:
c = self.case_cls(custom_configs)
if c is not None:
return c().description
return c.description
raise ValueError("Case unsupported")


@@ -289,26 +292,69 @@ class Performance1536D50K(PerformanceCase):
optimize_timeout: float | int | None = 15 * 60


def metric_type_map(s: str) -> MetricType:
if s.lower() == "cosine":
return MetricType.COSINE
if s.lower() == "l2" or s.lower() == "euclidean":
return MetricType.L2
if s.lower() == "ip":
return MetricType.IP
err_msg = f"Unsupported metric_type: {s}"
log.error(err_msg)
raise RuntimeError(err_msg)


class PerformanceCustomDataset(PerformanceCase):
case_id: CaseType = CaseType.PerformanceCustomDataset
name: str = "Performance With Custom Dataset"
description: str = ""
dataset: DatasetManager

def __init__(
self,
name,
description,
load_timeout,
optimize_timeout,
dataset_config,
**kwargs,
):
dataset_config = CustomDatasetConfig(**dataset_config)
dataset = CustomDataset(
name=dataset_config.name,
size=dataset_config.size,
dim=dataset_config.dim,
metric_type=metric_type_map(dataset_config.metric_type),
use_shuffled=dataset_config.use_shuffled,
with_gt=dataset_config.with_gt,
dir=dataset_config.dir,
file_num=dataset_config.file_count,
)
super().__init__(
name=name,
description=description,
load_timeout=load_timeout,
optimize_timeout=optimize_timeout,
dataset=DatasetManager(data=dataset),
)


type2case = {
CaseType.CapacityDim960: CapacityDim960,
CaseType.CapacityDim128: CapacityDim128,

CaseType.Performance768D100M: Performance768D100M,
CaseType.Performance768D10M: Performance768D10M,
CaseType.Performance768D1M: Performance768D1M,

CaseType.Performance768D10M1P: Performance768D10M1P,
CaseType.Performance768D1M1P: Performance768D1M1P,
CaseType.Performance768D10M99P: Performance768D10M99P,
CaseType.Performance768D1M99P: Performance768D1M99P,

CaseType.Performance1536D500K: Performance1536D500K,
CaseType.Performance1536D5M: Performance1536D5M,

CaseType.Performance1536D500K1P: Performance1536D500K1P,
CaseType.Performance1536D5M1P: Performance1536D5M1P,

CaseType.Performance1536D500K99P: Performance1536D500K99P,
CaseType.Performance1536D5M99P: Performance1536D5M99P,
CaseType.Performance1536D50K: Performance1536D50K,
CaseType.PerformanceCustomDataset: PerformanceCustomDataset,
}
32 changes: 27 additions & 5 deletions vectordb_bench/backend/dataset.py
@@ -33,6 +33,7 @@ class BaseDataset(BaseModel):
use_shuffled: bool
with_gt: bool = False
_size_label: dict[int, SizeLabel] = PrivateAttr()
isCustom: bool = False

@validator("size")
def verify_size(cls, v):
@@ -52,7 +53,27 @@ def dir_name(self) -> str:
def file_count(self) -> int:
return self._size_label.get(self.size).file_count

class CustomDataset(BaseDataset):
dir: str
file_num: int
isCustom: bool = True

@validator("size")
def verify_size(cls, v):
return v

@property
def label(self) -> str:
return "Custom"

@property
def dir_name(self) -> str:
return self.dir

@property
def file_count(self) -> int:
return self.file_num

class LAION(BaseDataset):
name: str = "LAION"
dim: int = 768
@@ -186,11 +207,12 @@ def prepare(self,
gt_file, test_file = utils.compose_gt_file(filters), "test.parquet"
all_files.extend([gt_file, test_file])

source.reader().read(
dataset=self.data.dir_name.lower(),
files=all_files,
local_ds_root=self.data_dir,
)
if not self.data.isCustom:
source.reader().read(
dataset=self.data.dir_name.lower(),
files=all_files,
local_ds_root=self.data_dir,
)

if gt_file is not None and test_file is not None:
self.test_data = self._read_file(test_file)
18 changes: 18 additions & 0 deletions vectordb_bench/custom/custom_case.json
@@ -0,0 +1,18 @@
[
{
"name": "My Dataset (Performance Case)",
"description": "This is a customized dataset.",
"load_timeout": 36000,
"optimize_timeout": 36000,
"dataset_config": {
"name": "My Dataset",
"dir": "/my_dataset_path",
"size": 1000000,
"dim": 1024,
"metric_type": "L2",
"file_count": 1,
"use_shuffled": false,
"with_gt": true
}
}
]
12 changes: 6 additions & 6 deletions vectordb_bench/frontend/components/check_results/charts.py
@@ -1,19 +1,19 @@
from vectordb_bench.backend.cases import Case
from vectordb_bench.frontend.components.check_results.expanderStyle import initMainExpanderStyle
from vectordb_bench.metric import metricOrder, isLowerIsBetterMetric, metricUnitMap
from vectordb_bench.frontend.const.styles import *
from vectordb_bench.frontend.config.styles import *
from vectordb_bench.models import ResultLabel
import plotly.express as px


def drawCharts(st, allData, failedTasks, cases: list[Case]):
def drawCharts(st, allData, failedTasks, caseNames: list[str]):
initMainExpanderStyle(st)
for case in cases:
chartContainer = st.expander(case.name, True)
data = [data for data in allData if data["case_name"] == case.name]
for caseName in caseNames:
chartContainer = st.expander(caseName, True)
data = [data for data in allData if data["case_name"] == caseName]
drawChart(data, chartContainer)

errorDBs = failedTasks[case.name]
errorDBs = failedTasks[caseName]
showFailedDBs(chartContainer, errorDBs)


24 changes: 12 additions & 12 deletions vectordb_bench/frontend/components/check_results/data.py
@@ -8,24 +8,23 @@
def getChartData(
tasks: list[CaseResult],
dbNames: list[str],
cases: list[Case],
caseNames: list[str],
):
filterTasks = getFilterTasks(tasks, dbNames, cases)
filterTasks = getFilterTasks(tasks, dbNames, caseNames)
mergedTasks, failedTasks = mergeTasks(filterTasks)
return mergedTasks, failedTasks


def getFilterTasks(
tasks: list[CaseResult],
dbNames: list[str],
cases: list[Case],
caseNames: list[str],
) -> list[CaseResult]:
case_ids = [case.case_id for case in cases]
filterTasks = [
task
for task in tasks
if task.task_config.db_name in dbNames
and task.task_config.case_config.case_id in case_ids
and task.task_config.case_config.case_id.case_cls(task.task_config.case_config.custom_case).name in caseNames
]
return filterTasks

@@ -36,29 +35,29 @@ def mergeTasks(tasks: list[CaseResult]):
db_name = task.task_config.db_name
db = task.task_config.db.value
db_label = task.task_config.db_config.db_label or ""
case_id = task.task_config.case_config.case_id
dbCaseMetricsMap[db_name][case_id] = {
case = task.task_config.case_config.case_id.case_cls(task.task_config.case_config.custom_case)
dbCaseMetricsMap[db_name][case.name] = {
"db": db,
"db_label": db_label,
"metrics": mergeMetrics(
dbCaseMetricsMap[db_name][case_id].get("metrics", {}),
dbCaseMetricsMap[db_name][case.name].get("metrics", {}),
asdict(task.metrics),
),
"label": getBetterLabel(
dbCaseMetricsMap[db_name][case_id].get("label", ResultLabel.FAILED),
dbCaseMetricsMap[db_name][case.name].get(
"label", ResultLabel.FAILED),
task.label,
),
}

mergedTasks = []
failedTasks = defaultdict(lambda: defaultdict(str))
for db_name, caseMetricsMap in dbCaseMetricsMap.items():
for case_id, metricInfo in caseMetricsMap.items():
for case_name, metricInfo in caseMetricsMap.items():
metrics = metricInfo["metrics"]
db = metricInfo["db"]
db_label = metricInfo["db_label"]
label = metricInfo["label"]
case_name = case_id.case_name
if label == ResultLabel.NORMAL:
mergedTasks.append(
{
@@ -80,7 +79,8 @@ def mergeMetrics(metrics_1: dict, metrics_2: dict) -> dict:
metrics = {**metrics_1}
for key, value in metrics_2.items():
metrics[key] = (
getBetterMetric(key, value, metrics[key]) if key in metrics else value
getBetterMetric(
key, value, metrics[key]) if key in metrics else value
)

return metrics
@@ -1,7 +1,7 @@
def initMainExpanderStyle(st):
st.markdown(
"""<style>
.main .streamlit-expanderHeader p {font-size: 20px; font-weight: 600;}
.main div[data-testid='stExpander'] p {font-size: 18px; font-weight: 600;}
.main div[data-testid='stExpander'] {
background-color: #F6F8FA;
border: 1px solid #A9BDD140;