From 6c77576a404cdfd0460d01dc6ac6a6c660a5fdbf Mon Sep 17 00:00:00 2001 From: yhmo Date: Wed, 22 Nov 2023 15:24:20 +0800 Subject: [PATCH] Bulkwriter supports parquet Signed-off-by: yhmo --- examples/example_bulkwriter.py | 145 +++++++++++---------- pymilvus/bulk_writer/buffer.py | 39 +++++- pymilvus/bulk_writer/bulk_writer.py | 6 + pymilvus/bulk_writer/constants.py | 3 + pymilvus/bulk_writer/local_bulk_writer.py | 35 +++-- pymilvus/bulk_writer/remote_bulk_writer.py | 6 +- requirements.txt | 1 + 7 files changed, 150 insertions(+), 85 deletions(-) diff --git a/examples/example_bulkwriter.py b/examples/example_bulkwriter.py index 50cd00a5a..4dfc3f981 100644 --- a/examples/example_bulkwriter.py +++ b/examples/example_bulkwriter.py @@ -44,8 +44,8 @@ HOST = '127.0.0.1' PORT = '19530' -CSV_COLLECTION_NAME = "test_csv" -ALL_TYPES_COLLECTION_NAME = "test_all_types" +SIMPLE_COLLECTION_NAME = "for_bulkwriter" +ALL_TYPES_COLLECTION_NAME = "all_types_for_bulkwriter" DIM = 512 def gen_binary_vector(): @@ -62,10 +62,10 @@ def create_connection(): print(f"\nConnected") -def build_csv_collection(): +def build_simple_collection(): print(f"\n===================== create collection ====================") - if utility.has_collection(CSV_COLLECTION_NAME): - utility.drop_collection(CSV_COLLECTION_NAME) + if utility.has_collection(SIMPLE_COLLECTION_NAME): + utility.drop_collection(SIMPLE_COLLECTION_NAME) fields = [ FieldSchema(name="id", dtype=DataType.INT64, is_primary=True, auto_id=True), @@ -74,7 +74,7 @@ def build_csv_collection(): FieldSchema(name="label", dtype=DataType.VARCHAR, max_length=512), ] schema = CollectionSchema(fields=fields) - collection = Collection(name=CSV_COLLECTION_NAME, schema=schema) + collection = Collection(name=SIMPLE_COLLECTION_NAME, schema=schema) print(f"Collection '{collection.name}' created") return collection.schema @@ -110,37 +110,59 @@ def read_sample_data(file_path: str, writer: [LocalBulkWriter, RemoteBulkWriter] writer.append_row(row) -def test_local_writer_json(schema: CollectionSchema): - print(f"\n===================== test local JSON writer ====================") +def local_writer(schema: CollectionSchema, file_type: BulkFileType): + print(f"\n===================== local writer ({file_type.name}) ====================") with LocalBulkWriter( schema=schema, local_path="/tmp/bulk_writer", - segment_size=4*1024*1024, - file_type=BulkFileType.JSON_RB, + segment_size=128*1024*1024, + file_type=file_type, ) as local_writer: + # read data from csv read_sample_data("./data/train_embeddings.csv", local_writer) + + # append rows + for i in range(100000): + local_writer.append_row({"path": f"path_{i}", "vector": gen_float_vector(), "label": f"label_{i}"}) + + print(f"{local_writer.total_row_count} rows appends") + print(f"{local_writer.buffer_row_count} rows in buffer not flushed") local_writer.commit() batch_files = local_writer.batch_files - print(f"Test local writer done! output local files: {batch_files}") + print(f"Local writer done! output local files: {batch_files}") -def test_local_writer_npy(schema: CollectionSchema): - print(f"\n===================== test local npy writer ====================") - with LocalBulkWriter( +def remote_writer(schema: CollectionSchema, file_type: BulkFileType): + print(f"\n===================== remote writer ({file_type.name}) ====================") + with RemoteBulkWriter( schema=schema, - local_path="/tmp/bulk_writer", - segment_size=4*1024*1024, - ) as local_writer: - read_sample_data("./data/train_embeddings.csv", local_writer) - local_writer.commit() - batch_files = local_writer.batch_files + remote_path="bulk_data", + connect_param=RemoteBulkWriter.ConnectParam( + endpoint=MINIO_ADDRESS, + access_key=MINIO_ACCESS_KEY, + secret_key=MINIO_SECRET_KEY, + bucket_name="a-bucket", + ), + segment_size=512 * 1024 * 1024, + file_type=file_type, + ) as remote_writer: + # read data from csv + read_sample_data("./data/train_embeddings.csv", remote_writer) + + # append rows + for i in range(10000): + remote_writer.append_row({"path": f"path_{i}", "vector": gen_float_vector(), "label": f"label_{i}"}) - print(f"Test local writer done! output local files: {batch_files}") + print(f"{remote_writer.total_row_count} rows appends") + print(f"{remote_writer.buffer_row_count} rows in buffer not flushed") + remote_writer.commit() + batch_files = remote_writer.batch_files + print(f"Remote writer done! output remote files: {batch_files}") -def test_parallel_append(schema: CollectionSchema): - print(f"\n===================== test parallel append ====================") +def parallel_append(schema: CollectionSchema): + print(f"\n===================== parallel append ====================") def _append_row(writer: LocalBulkWriter, begin: int, end: int): try: for i in range(begin, end): @@ -169,6 +191,8 @@ def _append_row(writer: LocalBulkWriter, begin: int, end: int): th.join() print(f"Thread '{th.name}' finished") + print(f"{local_writer.total_row_count} rows appends") + print(f"{local_writer.buffer_row_count} rows in buffer not flushed") local_writer.commit() print(f"Append finished, {thread_count*rows_per_thread} rows") @@ -192,8 +216,8 @@ def _append_row(writer: LocalBulkWriter, begin: int, end: int): print("Data is correct") -def test_remote_writer(schema: CollectionSchema): - print(f"\n===================== test remote writer ====================") +def all_types_writer(bin_vec: bool, schema: CollectionSchema, file_type: BulkFileType)->list: + print(f"\n===================== all field types ({file_type.name}) ====================") with RemoteBulkWriter( schema=schema, remote_path="bulk_data", @@ -203,33 +227,7 @@ def test_remote_writer(schema: CollectionSchema): secret_key=MINIO_SECRET_KEY, bucket_name="a-bucket", ), - segment_size=50 * 1024 * 1024, - ) as remote_writer: - # read data from csv - read_sample_data("./data/train_embeddings.csv", remote_writer) - remote_writer.commit() - - # append rows - for i in range(10000): - remote_writer.append_row({"path": f"path_{i}", "vector": gen_float_vector(), "label": f"label_{i}"}) - remote_writer.commit() - - batch_files = remote_writer.batch_files - - print(f"Test remote writer done! output remote files: {batch_files}") - - -def test_all_types_writer(bin_vec: bool, schema: CollectionSchema)->list: - print(f"\n===================== all types test ====================") - with RemoteBulkWriter( - schema=schema, - remote_path="bulk_data", - connect_param=RemoteBulkWriter.ConnectParam( - endpoint=MINIO_ADDRESS, - access_key=MINIO_ACCESS_KEY, - secret_key=MINIO_SECRET_KEY, - bucket_name="a-bucket", - ), + file_type=file_type, ) as remote_writer: print("Append rows") batch_count = 10000 @@ -267,14 +265,16 @@ def test_all_types_writer(bin_vec: bool, schema: CollectionSchema)->list: f"dynamic_{i}": i, }) + print(f"{remote_writer.total_row_count} rows appends") + print(f"{remote_writer.buffer_row_count} rows in buffer not flushed") print("Generate data files...") remote_writer.commit() print(f"Data files have been uploaded: {remote_writer.batch_files}") return remote_writer.batch_files -def test_call_bulkinsert(schema: CollectionSchema, batch_files: list): - print(f"\n===================== test call bulkinsert ====================") +def call_bulkinsert(schema: CollectionSchema, batch_files: list): + print(f"\n===================== call bulkinsert ====================") if utility.has_collection(ALL_TYPES_COLLECTION_NAME): utility.drop_collection(ALL_TYPES_COLLECTION_NAME) @@ -302,7 +302,7 @@ def test_call_bulkinsert(schema: CollectionSchema, batch_files: list): print(f"Collection row number: {collection.num_entities}") -def test_retrieve_imported_data(bin_vec: bool): +def retrieve_imported_data(bin_vec: bool): collection = Collection(name=ALL_TYPES_COLLECTION_NAME) print("Create index...") index_param = { @@ -326,7 +326,7 @@ def test_retrieve_imported_data(bin_vec: bool): for item in results: print(item) -def test_cloud_bulkinsert(): +def cloud_bulkinsert(): url = "https://_your_cloud_server_url_" api_key = "_api_key_for_the_url_" cluster_id = "_your_cloud_instance_id_" @@ -371,24 +371,31 @@ def test_cloud_bulkinsert(): if __name__ == '__main__': create_connection() - schema = build_csv_collection() - test_local_writer_json(schema) - test_local_writer_npy(schema) - test_remote_writer(schema) - test_parallel_append(schema) + file_types = [BulkFileType.JSON_RB, BulkFileType.NPY, BulkFileType.PARQUET] + + schema = build_simple_collection() + for file_type in file_types: + local_writer(schema=schema, file_type=file_type) + + for file_type in file_types: + remote_writer(schema=schema, file_type=file_type) + + parallel_append(schema) # float vectors + all scalar types schema = build_all_type_schema(bin_vec=False) - batch_files = test_all_types_writer(bin_vec=False, schema=schema) - test_call_bulkinsert(schema, batch_files) - test_retrieve_imported_data(bin_vec=False) + for file_type in file_types: + batch_files = all_types_writer(bin_vec=False, schema=schema, file_type=file_type) + call_bulkinsert(schema, batch_files) + retrieve_imported_data(bin_vec=False) # binary vectors + all scalar types schema = build_all_type_schema(bin_vec=True) - batch_files = test_all_types_writer(bin_vec=True, schema=schema) - test_call_bulkinsert(schema, batch_files) - test_retrieve_imported_data(bin_vec=True) + for file_type in file_types: + batch_files = all_types_writer(bin_vec=True, schema=schema, file_type=file_type) + call_bulkinsert(schema, batch_files) + retrieve_imported_data(bin_vec=True) - # # to test cloud bulkinsert api, you need to apply a cloud service from Zilliz Cloud(https://zilliz.com/cloud) - # test_cloud_bulkinsert() + # # to call cloud bulkinsert api, you need to apply a cloud service from Zilliz Cloud(https://zilliz.com/cloud) + # cloud_bulkinsert() diff --git a/pymilvus/bulk_writer/buffer.py b/pymilvus/bulk_writer/buffer.py index 09f70d061..ea46c3a34 100644 --- a/pymilvus/bulk_writer/buffer.py +++ b/pymilvus/bulk_writer/buffer.py @@ -15,6 +15,8 @@ from pathlib import Path import numpy as np +import pandas as pd +import pyarrow.parquet as pq from pymilvus.client.types import ( DataType, @@ -107,8 +109,10 @@ def persist(self, local_path: str) -> list: # output files if self._file_type == BulkFileType.NPY: return self._persist_npy(local_path) - if self._file_type == BulkFileType.JSON_RB: + elif self._file_type == BulkFileType.JSON_RB: return self._persist_json_rows(local_path) + elif self._file_type == BulkFileType.PARQUET: + return self._persist_parquet(local_path) self._throw(f"Unsupported file tpye: {self._file_type}") return [] @@ -174,3 +178,36 @@ def _persist_json_rows(self, local_path: str): logger.info(f"Successfully persist row-based file {file_path}") return [str(file_path)] + + def _persist_parquet(self, local_path: str): + file_path = Path(local_path + ".parquet") + + data = {} + for k in self._buffer: + field_schema = self._fields[k] + if field_schema.dtype == DataType.JSON: + # for JSON field, store as string array + str_arr = [] + for val in self._buffer[k]: + str_arr.append(json.dumps(val)) + data[k] = pd.Series(str_arr, dtype=None) + elif field_schema.dtype == DataType.FLOAT_VECTOR: + arr = [] + for val in self._buffer[k]: + arr.append(np.array(val, dtype=np.dtype("float32"))) + data[k] = pd.Series(arr) + elif field_schema.dtype == DataType.BINARY_VECTOR: + arr = [] + for val in self._buffer[k]: + arr.append(np.array(val, dtype=np.dtype("uint8"))) + data[k] = pd.Series(arr) + elif field_schema.dtype.name in NUMPY_TYPE_CREATOR: + dt = NUMPY_TYPE_CREATOR[field_schema.dtype.name] + data[k] = pd.Series(self._buffer[k], dtype=dt) + + # write to Parquet file + df = pd.DataFrame(data=data) + df.to_parquet(file_path, engine="pyarrow") # don't use fastparquet + + logger.info(f"Successfully persist parquet file {file_path}") + return [str(file_path)] \ No newline at end of file diff --git a/pymilvus/bulk_writer/bulk_writer.py b/pymilvus/bulk_writer/bulk_writer.py index 88bc44c9b..86255cb5d 100644 --- a/pymilvus/bulk_writer/bulk_writer.py +++ b/pymilvus/bulk_writer/bulk_writer.py @@ -43,6 +43,7 @@ def __init__( self._schema = schema self._buffer_size = 0 self._buffer_row_count = 0 + self._total_row_count = 0 self._segment_size = segment_size self._file_type = file_type self._buffer_lock = Lock() @@ -64,6 +65,10 @@ def buffer_size(self): def buffer_row_count(self): return self._buffer_row_count + @property + def total_row_count(self): + return self._total_row_count + @property def segment_size(self): return self._segment_size @@ -165,3 +170,4 @@ def _verify_row(self, row: dict): with self._buffer_lock: self._buffer_size = self._buffer_size + row_size self._buffer_row_count = self._buffer_row_count + 1 + self._total_row_count = self._total_row_count + 1 diff --git a/pymilvus/bulk_writer/constants.py b/pymilvus/bulk_writer/constants.py index ba233b979..e66a88f6b 100644 --- a/pymilvus/bulk_writer/constants.py +++ b/pymilvus/bulk_writer/constants.py @@ -46,6 +46,7 @@ DataType.JSON.name: lambda x: isinstance(x, dict) and len(x) <= 65535, DataType.FLOAT_VECTOR.name: lambda x, dim: isinstance(x, list) and len(x) == dim, DataType.BINARY_VECTOR.name: lambda x, dim: isinstance(x, list) and len(x) * 8 == dim, + DataType.ARRAY.name: lambda x: isinstance(x, list), } NUMPY_TYPE_CREATOR = { @@ -60,9 +61,11 @@ DataType.JSON.name: None, DataType.FLOAT_VECTOR.name: np.dtype("float32"), DataType.BINARY_VECTOR.name: np.dtype("uint8"), + DataType.ARRAY.name: None, } class BulkFileType(IntEnum): NPY = 1 JSON_RB = 2 + PARQUET = 3 diff --git a/pymilvus/bulk_writer/local_bulk_writer.py b/pymilvus/bulk_writer/local_bulk_writer.py index f21ac90f6..cabd4e8a2 100644 --- a/pymilvus/bulk_writer/local_bulk_writer.py +++ b/pymilvus/bulk_writer/local_bulk_writer.py @@ -45,7 +45,9 @@ def __init__( self._working_thread = {} self._working_thread_lock = Lock() self._local_files = [] - self._make_dir() + + self._mkdir_lock = Lock() + self._make_dir(make_uid=True) @property def uuid(self): @@ -61,10 +63,6 @@ def __del__(self): self._exit() def _exit(self): - # remove the uuid folder - if Path(self._local_path).exists() and not any(Path(self._local_path).iterdir()): - Path(self._local_path).rmdir() - logger.info(f"Delete empty directory '{self._local_path}'") # wait flush thread if len(self._working_thread) > 0: @@ -72,13 +70,26 @@ def _exit(self): logger.info(f"Wait flush thread '{k}' to finish") th.join() - def _make_dir(self): - Path(self._local_path).mkdir(exist_ok=True) - logger.info(f"Data path created: {self._local_path}") - uidir = Path(self._local_path).joinpath(self._uuid) - self._local_path = uidir - Path(uidir).mkdir(exist_ok=True) - logger.info(f"Data path created: {uidir}") + + self._rm_dir() + + def _make_dir(self, make_uid: bool): + with self._mkdir_lock: + Path(self._local_path).mkdir(exist_ok=True) + logger.info(f"Data path created: {self._local_path}") + + if make_uid: + uidir = Path(self._local_path).joinpath(self._uuid) + self._local_path = uidir + Path(uidir).mkdir(exist_ok=True) + logger.info(f"Data path created: {uidir}") + + def _rm_dir(self): + with self._mkdir_lock: + # remove the uuid folder if it is empty + if Path(self._local_path).exists() and not any(Path(self._local_path).iterdir()): + Path(self._local_path).rmdir() + logger.info(f"Delete local directory '{self._local_path}'") def append_row(self, row: dict, **kwargs): super().append_row(row, **kwargs) diff --git a/pymilvus/bulk_writer/remote_bulk_writer.py b/pymilvus/bulk_writer/remote_bulk_writer.py index 60635170d..376f241d0 100644 --- a/pymilvus/bulk_writer/remote_bulk_writer.py +++ b/pymilvus/bulk_writer/remote_bulk_writer.py @@ -127,9 +127,9 @@ def _local_rm(self, file: str): try: Path(file).unlink() parent_dir = Path(file).parent - if not any(Path(parent_dir).iterdir()): + if parent_dir != self._local_path and (not any(Path(parent_dir).iterdir())): Path(parent_dir).rmdir() - logger.info(f"Delete empty directory '{parent_dir!s}'") + logger.info(f"Delete empty directory '{parent_dir}'") except Exception: logger.warning(f"Failed to delete local file: {file}") @@ -144,7 +144,7 @@ def _upload(self, file_list: list): for file_path in file_list: ext = Path(file_path).suffix - if ext not in {".json", ".npy"}: + if ext not in [".json", ".npy", ".parquet"]: continue relative_file_path = str(file_path).replace(str(super().data_path), "") diff --git a/requirements.txt b/requirements.txt index d345a3cae..52ed9964a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -28,6 +28,7 @@ sphinxcontrib-serializinghtml sphinxcontrib-napoleon sphinxcontrib-prettyspecialmethods tqdm==4.65.0 +pyarrow>=14.0.1 pytest>=5.3.4 pytest-cov==2.8.1 pytest-timeout==1.3.4