Skip to content

Commit

Permalink
Bulkwriter supports parquet
Browse files Browse the repository at this point in the history
Signed-off-by: yhmo <[email protected]>
  • Loading branch information
yhmo committed Nov 29, 2023
1 parent 85e6d50 commit 6c77576
Show file tree
Hide file tree
Showing 7 changed files with 150 additions and 85 deletions.
145 changes: 76 additions & 69 deletions examples/example_bulkwriter.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,8 @@
HOST = '127.0.0.1'
PORT = '19530'

CSV_COLLECTION_NAME = "test_csv"
ALL_TYPES_COLLECTION_NAME = "test_all_types"
SIMPLE_COLLECTION_NAME = "for_bulkwriter"
ALL_TYPES_COLLECTION_NAME = "all_types_for_bulkwriter"
DIM = 512

def gen_binary_vector():
Expand All @@ -62,10 +62,10 @@ def create_connection():
print(f"\nConnected")


def build_csv_collection():
def build_simple_collection():
print(f"\n===================== create collection ====================")
if utility.has_collection(CSV_COLLECTION_NAME):
utility.drop_collection(CSV_COLLECTION_NAME)
if utility.has_collection(SIMPLE_COLLECTION_NAME):
utility.drop_collection(SIMPLE_COLLECTION_NAME)

fields = [
FieldSchema(name="id", dtype=DataType.INT64, is_primary=True, auto_id=True),
Expand All @@ -74,7 +74,7 @@ def build_csv_collection():
FieldSchema(name="label", dtype=DataType.VARCHAR, max_length=512),
]
schema = CollectionSchema(fields=fields)
collection = Collection(name=CSV_COLLECTION_NAME, schema=schema)
collection = Collection(name=SIMPLE_COLLECTION_NAME, schema=schema)
print(f"Collection '{collection.name}' created")
return collection.schema

Expand Down Expand Up @@ -110,37 +110,59 @@ def read_sample_data(file_path: str, writer: [LocalBulkWriter, RemoteBulkWriter]

writer.append_row(row)

def test_local_writer_json(schema: CollectionSchema):
print(f"\n===================== test local JSON writer ====================")
def local_writer(schema: CollectionSchema, file_type: BulkFileType):
print(f"\n===================== local writer ({file_type.name}) ====================")
with LocalBulkWriter(
schema=schema,
local_path="/tmp/bulk_writer",
segment_size=4*1024*1024,
file_type=BulkFileType.JSON_RB,
segment_size=128*1024*1024,
file_type=file_type,
) as local_writer:
# read data from csv
read_sample_data("./data/train_embeddings.csv", local_writer)

# append rows
for i in range(100000):
local_writer.append_row({"path": f"path_{i}", "vector": gen_float_vector(), "label": f"label_{i}"})

print(f"{local_writer.total_row_count} rows appends")
print(f"{local_writer.buffer_row_count} rows in buffer not flushed")
local_writer.commit()
batch_files = local_writer.batch_files

print(f"Test local writer done! output local files: {batch_files}")
print(f"Local writer done! output local files: {batch_files}")


def test_local_writer_npy(schema: CollectionSchema):
print(f"\n===================== test local npy writer ====================")
with LocalBulkWriter(
def remote_writer(schema: CollectionSchema, file_type: BulkFileType):
print(f"\n===================== remote writer ({file_type.name}) ====================")
with RemoteBulkWriter(
schema=schema,
local_path="/tmp/bulk_writer",
segment_size=4*1024*1024,
) as local_writer:
read_sample_data("./data/train_embeddings.csv", local_writer)
local_writer.commit()
batch_files = local_writer.batch_files
remote_path="bulk_data",
connect_param=RemoteBulkWriter.ConnectParam(
endpoint=MINIO_ADDRESS,
access_key=MINIO_ACCESS_KEY,
secret_key=MINIO_SECRET_KEY,
bucket_name="a-bucket",
),
segment_size=512 * 1024 * 1024,
file_type=file_type,
) as remote_writer:
# read data from csv
read_sample_data("./data/train_embeddings.csv", remote_writer)

# append rows
for i in range(10000):
remote_writer.append_row({"path": f"path_{i}", "vector": gen_float_vector(), "label": f"label_{i}"})

print(f"Test local writer done! output local files: {batch_files}")
print(f"{remote_writer.total_row_count} rows appends")
print(f"{remote_writer.buffer_row_count} rows in buffer not flushed")
remote_writer.commit()
batch_files = remote_writer.batch_files

print(f"Remote writer done! output remote files: {batch_files}")

def test_parallel_append(schema: CollectionSchema):
print(f"\n===================== test parallel append ====================")
def parallel_append(schema: CollectionSchema):
print(f"\n===================== parallel append ====================")
def _append_row(writer: LocalBulkWriter, begin: int, end: int):
try:
for i in range(begin, end):
Expand Down Expand Up @@ -169,6 +191,8 @@ def _append_row(writer: LocalBulkWriter, begin: int, end: int):
th.join()
print(f"Thread '{th.name}' finished")

print(f"{local_writer.total_row_count} rows appends")
print(f"{local_writer.buffer_row_count} rows in buffer not flushed")
local_writer.commit()
print(f"Append finished, {thread_count*rows_per_thread} rows")

Expand All @@ -192,8 +216,8 @@ def _append_row(writer: LocalBulkWriter, begin: int, end: int):
print("Data is correct")


def test_remote_writer(schema: CollectionSchema):
print(f"\n===================== test remote writer ====================")
def all_types_writer(bin_vec: bool, schema: CollectionSchema, file_type: BulkFileType)->list:
print(f"\n===================== all field types ({file_type.name}) ====================")
with RemoteBulkWriter(
schema=schema,
remote_path="bulk_data",
Expand All @@ -203,33 +227,7 @@ def test_remote_writer(schema: CollectionSchema):
secret_key=MINIO_SECRET_KEY,
bucket_name="a-bucket",
),
segment_size=50 * 1024 * 1024,
) as remote_writer:
# read data from csv
read_sample_data("./data/train_embeddings.csv", remote_writer)
remote_writer.commit()

# append rows
for i in range(10000):
remote_writer.append_row({"path": f"path_{i}", "vector": gen_float_vector(), "label": f"label_{i}"})
remote_writer.commit()

batch_files = remote_writer.batch_files

print(f"Test remote writer done! output remote files: {batch_files}")


def test_all_types_writer(bin_vec: bool, schema: CollectionSchema)->list:
print(f"\n===================== all types test ====================")
with RemoteBulkWriter(
schema=schema,
remote_path="bulk_data",
connect_param=RemoteBulkWriter.ConnectParam(
endpoint=MINIO_ADDRESS,
access_key=MINIO_ACCESS_KEY,
secret_key=MINIO_SECRET_KEY,
bucket_name="a-bucket",
),
file_type=file_type,
) as remote_writer:
print("Append rows")
batch_count = 10000
Expand Down Expand Up @@ -267,14 +265,16 @@ def test_all_types_writer(bin_vec: bool, schema: CollectionSchema)->list:
f"dynamic_{i}": i,
})

print(f"{remote_writer.total_row_count} rows appends")
print(f"{remote_writer.buffer_row_count} rows in buffer not flushed")
print("Generate data files...")
remote_writer.commit()
print(f"Data files have been uploaded: {remote_writer.batch_files}")
return remote_writer.batch_files


def test_call_bulkinsert(schema: CollectionSchema, batch_files: list):
print(f"\n===================== test call bulkinsert ====================")
def call_bulkinsert(schema: CollectionSchema, batch_files: list):
print(f"\n===================== call bulkinsert ====================")
if utility.has_collection(ALL_TYPES_COLLECTION_NAME):
utility.drop_collection(ALL_TYPES_COLLECTION_NAME)

Expand Down Expand Up @@ -302,7 +302,7 @@ def test_call_bulkinsert(schema: CollectionSchema, batch_files: list):
print(f"Collection row number: {collection.num_entities}")


def test_retrieve_imported_data(bin_vec: bool):
def retrieve_imported_data(bin_vec: bool):
collection = Collection(name=ALL_TYPES_COLLECTION_NAME)
print("Create index...")
index_param = {
Expand All @@ -326,7 +326,7 @@ def test_retrieve_imported_data(bin_vec: bool):
for item in results:
print(item)

def test_cloud_bulkinsert():
def cloud_bulkinsert():
url = "https://_your_cloud_server_url_"
api_key = "_api_key_for_the_url_"
cluster_id = "_your_cloud_instance_id_"
Expand Down Expand Up @@ -371,24 +371,31 @@ def test_cloud_bulkinsert():
if __name__ == '__main__':
create_connection()

schema = build_csv_collection()
test_local_writer_json(schema)
test_local_writer_npy(schema)
test_remote_writer(schema)
test_parallel_append(schema)
file_types = [BulkFileType.JSON_RB, BulkFileType.NPY, BulkFileType.PARQUET]

schema = build_simple_collection()
for file_type in file_types:
local_writer(schema=schema, file_type=file_type)

for file_type in file_types:
remote_writer(schema=schema, file_type=file_type)

parallel_append(schema)

# float vectors + all scalar types
schema = build_all_type_schema(bin_vec=False)
batch_files = test_all_types_writer(bin_vec=False, schema=schema)
test_call_bulkinsert(schema, batch_files)
test_retrieve_imported_data(bin_vec=False)
for file_type in file_types:
batch_files = all_types_writer(bin_vec=False, schema=schema, file_type=file_type)
call_bulkinsert(schema, batch_files)
retrieve_imported_data(bin_vec=False)

# binary vectors + all scalar types
schema = build_all_type_schema(bin_vec=True)
batch_files = test_all_types_writer(bin_vec=True, schema=schema)
test_call_bulkinsert(schema, batch_files)
test_retrieve_imported_data(bin_vec=True)
for file_type in file_types:
batch_files = all_types_writer(bin_vec=True, schema=schema, file_type=file_type)
call_bulkinsert(schema, batch_files)
retrieve_imported_data(bin_vec=True)

# # to test cloud bulkinsert api, you need to apply a cloud service from Zilliz Cloud(https://zilliz.com/cloud)
# test_cloud_bulkinsert()
# # to call cloud bulkinsert api, you need to apply a cloud service from Zilliz Cloud(https://zilliz.com/cloud)
# cloud_bulkinsert()

39 changes: 38 additions & 1 deletion pymilvus/bulk_writer/buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
from pathlib import Path

import numpy as np
import pandas as pd
import pyarrow.parquet as pq

from pymilvus.client.types import (
DataType,
Expand Down Expand Up @@ -107,8 +109,10 @@ def persist(self, local_path: str) -> list:
# output files
if self._file_type == BulkFileType.NPY:
return self._persist_npy(local_path)
if self._file_type == BulkFileType.JSON_RB:
elif self._file_type == BulkFileType.JSON_RB:
return self._persist_json_rows(local_path)
elif self._file_type == BulkFileType.PARQUET:
return self._persist_parquet(local_path)

self._throw(f"Unsupported file tpye: {self._file_type}")
return []
Expand Down Expand Up @@ -174,3 +178,36 @@ def _persist_json_rows(self, local_path: str):

logger.info(f"Successfully persist row-based file {file_path}")
return [str(file_path)]

def _persist_parquet(self, local_path: str):
file_path = Path(local_path + ".parquet")

data = {}
for k in self._buffer:
field_schema = self._fields[k]
if field_schema.dtype == DataType.JSON:
# for JSON field, store as string array
str_arr = []
for val in self._buffer[k]:
str_arr.append(json.dumps(val))
data[k] = pd.Series(str_arr, dtype=None)
elif field_schema.dtype == DataType.FLOAT_VECTOR:
arr = []
for val in self._buffer[k]:
arr.append(np.array(val, dtype=np.dtype("float32")))
data[k] = pd.Series(arr)
elif field_schema.dtype == DataType.BINARY_VECTOR:
arr = []
for val in self._buffer[k]:
arr.append(np.array(val, dtype=np.dtype("uint8")))
data[k] = pd.Series(arr)
elif field_schema.dtype.name in NUMPY_TYPE_CREATOR:
dt = NUMPY_TYPE_CREATOR[field_schema.dtype.name]
data[k] = pd.Series(self._buffer[k], dtype=dt)

# write to Parquet file
df = pd.DataFrame(data=data)
df.to_parquet(file_path, engine="pyarrow") # don't use fastparquet

logger.info(f"Successfully persist parquet file {file_path}")
return [str(file_path)]
6 changes: 6 additions & 0 deletions pymilvus/bulk_writer/bulk_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ def __init__(
self._schema = schema
self._buffer_size = 0
self._buffer_row_count = 0
self._total_row_count = 0
self._segment_size = segment_size
self._file_type = file_type
self._buffer_lock = Lock()
Expand All @@ -64,6 +65,10 @@ def buffer_size(self):
def buffer_row_count(self):
return self._buffer_row_count

@property
def total_row_count(self):
return self._total_row_count

@property
def segment_size(self):
return self._segment_size
Expand Down Expand Up @@ -165,3 +170,4 @@ def _verify_row(self, row: dict):
with self._buffer_lock:
self._buffer_size = self._buffer_size + row_size
self._buffer_row_count = self._buffer_row_count + 1
self._total_row_count = self._total_row_count + 1
3 changes: 3 additions & 0 deletions pymilvus/bulk_writer/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
DataType.JSON.name: lambda x: isinstance(x, dict) and len(x) <= 65535,
DataType.FLOAT_VECTOR.name: lambda x, dim: isinstance(x, list) and len(x) == dim,
DataType.BINARY_VECTOR.name: lambda x, dim: isinstance(x, list) and len(x) * 8 == dim,
DataType.ARRAY.name: lambda x: isinstance(x, list),
}

NUMPY_TYPE_CREATOR = {
Expand All @@ -60,9 +61,11 @@
DataType.JSON.name: None,
DataType.FLOAT_VECTOR.name: np.dtype("float32"),
DataType.BINARY_VECTOR.name: np.dtype("uint8"),
DataType.ARRAY.name: None,
}


class BulkFileType(IntEnum):
NPY = 1
JSON_RB = 2
PARQUET = 3
Loading

0 comments on commit 6c77576

Please sign in to comment.