#4 refactoring + fixes to logic + helper script
DmitryKey committed Oct 3, 2021
1 parent ba14663 commit 3db1039
Showing 3 changed files with 52 additions and 30 deletions.
69 changes: 43 additions & 26 deletions src/algorithms/sharding/kanndi/shard_by_distance.py
@@ -50,7 +50,7 @@ def shard_by_dist(data_file: str, dist: float, output_index_path: str, shards_m:
complete_shards = 0

total_num_elements = get_total_nvecs_fbin(data_file)
dimensionality = get_total_dim_fbin(data_file)
# dimensionality = get_total_dim_fbin(data_file)
print(f"Total number of points to process: {total_num_elements}")
print(f"Reading data from {data_file} in {BATCH_SIZE} chunks")

@@ -63,44 +63,61 @@ def shard_by_dist(data_file: str, dist: float, output_index_path: str, shards_m:
# expected number of elements per shard
expected_shard_size = total_num_elements / shards_m

# get the seed point
points = read_fbin(data_file, start_idx=0, chunk_size=1)
seed_point_id = 0
seed_point = points[seed_point_id]
# mark the seed point as processed
processed_point_ids.add(seed_point_id)
shard = [seed_point]
shard_ids = [seed_point_id]

need_seed_update = False

# repeat, while number of shards did not reach the target level M
while complete_shards < shards_m:
complete_shards = len(shards.keys())
#shard = np.zeros(shape=(shards_m, dimensionality))
shard = []
shard_ids = []
# step through the dataset with BATCH_SIZE window

# step through the dataset batch by batch
for i in range(0, range_upper, BATCH_SIZE):
print(f"Processing index={i}")
points = read_fbin(data_file, start_idx=i, chunk_size=BATCH_SIZE)
# print(points.shape)
# print(points.shape[0])

# Proceed with processing the batch, if this point has not been visited previously.
# TODO: Picking the first point in each batch as the starting point for sharding might be suboptimal.
if i not in processed_point_ids:
# fix the starting point
first_point = points[0]
# mark it visited
processed_point_ids.add(i)
shard.append(first_point)
shard_ids.append(i)
## drop it from the input points
# points = np.delete(points, first_point, axis=0)
print("going inside inner loop by j over current batch of points, skipping the first point")
for j in range(1, points.shape[0]):
# id of the candidate is a combination of the running i-th batch and offset j within it
candidate_point_id = i + j
if candidate_point_id not in processed_point_ids:
dist_j = linalg.norm(first_point-points[j])

print("going inside inner loop by j over current batch of points, skipping the seed point")
for j in range(0, points.shape[0]):
if j == seed_point_id:
continue
# id of the shard candidate is a combination of the running i-th batch and offset j within it
candidate_point_id = i + j
if candidate_point_id not in processed_point_ids:
# update seed point?
if need_seed_update:
seed_point = points[j]
shard = [seed_point]
shard_ids = [i]
need_seed_update = False
else:
# seed is up to date and we continue building the shard
dist_j = linalg.norm(seed_point - points[j])
if dist_j <= dist:
processed_point_ids.add(candidate_point_id)
shard.append(points[j])
shard_ids.append(candidate_point_id)

# check if we saturated the shard
if len(shard) == expected_shard_size:
print("Saturated shard with id={}".format(i))
add_points(output_index_path, str(i), shard_ids, shard)
shards[i] = len(shard)
need_seed_update = True
break

print("Size of current shard after going through the current batch: {}".format(len(shard)))

# check if we saturated the shard
if len(shard) == expected_shard_size:
add_points(output_index_path, "shard_" + str(i), shard_ids, shard)
add_points(output_index_path, str(i), shard_ids, shard)
shards[i] = len(shard)

print("Processed points: {}".format(len(processed_point_ids)))

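For readers following the logic change, below is a minimal, self-contained sketch of the distance-based sharding idea the new seed-point loop implements: seed a shard, absorb unassigned points within dist of the seed, and re-seed once the shard reaches its expected size. The function name shard_points_by_distance and the synthetic in-memory data are illustrative only and are not part of the repository, which streams batches from disk instead.

# Illustrative sketch (not repository code): distance-based sharding over an
# in-memory array, mirroring the seed-point logic introduced in this commit.
import numpy as np
from numpy import linalg


def shard_points_by_distance(points, dist, shards_m):
    # Greedily group points into up to shards_m shards of roughly equal size.
    # A shard starts from a seed point and absorbs unassigned points whose
    # Euclidean distance to the seed is <= dist; once the shard reaches the
    # expected size, the next unassigned point becomes the new seed.
    n = points.shape[0]
    expected_shard_size = n // shards_m
    assigned = np.zeros(n, dtype=bool)
    shards = {}  # seed point id -> list of member point ids

    for seed_id in range(n):
        if assigned[seed_id]:
            continue
        assigned[seed_id] = True
        member_ids = [seed_id]
        for j in range(n):
            if assigned[j]:
                continue
            if linalg.norm(points[seed_id] - points[j]) <= dist:
                assigned[j] = True
                member_ids.append(j)
                if len(member_ids) >= expected_shard_size:
                    break
        shards[seed_id] = member_ids
        if len(shards) == shards_m:
            break
    return shards


if __name__ == "__main__":
    rng = np.random.default_rng(42)
    data = rng.random((1000, 16), dtype=np.float32)
    result = shard_points_by_distance(data, dist=1.0, shards_m=10)
    print({seed: len(ids) for seed, ids in result.items()})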
3 changes: 3 additions & 0 deletions src/run_kanndi.sh
@@ -0,0 +1,3 @@
# fix the util package import issue
export PYTHONPATH=.
python algorithms/sharding/kanndi/shard_by_distance.py /datadrive/big-ann-benchmarks/data/bigann.bak/base.1B.u8bin.crop_nb_100000000 /datadrive/big-ann/data/
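The script appears intended to be run from the src/ directory: PYTHONPATH=. makes the local util package importable, and shard_by_distance.py is invoked with two positional arguments, the input dataset file and the output directory for the shard indexes.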
10 changes: 6 additions & 4 deletions src/util/utils.py
@@ -95,7 +95,10 @@ def read_fbin(filename, start_idx=0, chunk_size=None):
nvecs = (nvecs - start_idx) if chunk_size is None else chunk_size
arr = np.fromfile(f, count=nvecs * dim, dtype=np.float32,
offset=start_idx * 4 * dim)
return arr.reshape(nvecs, dim)
if arr.size > 0:
return arr.reshape(nvecs, dim)
else:
return np.zeros(shape=(1, dim))


# by Leo Joffe
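As a usage note, here is a small hedged sketch of streaming an .fbin file in fixed-size chunks with this helper, clamping the last chunk so the read never runs past the end of the file; the util.utils import path assumes the PYTHONPATH set up in run_kanndi.sh, and the BATCH_SIZE value is illustrative.

# Illustrative sketch: iterate a vector file in fixed-size chunks with the repo helpers.
from util.utils import get_total_nvecs_fbin, read_fbin

BATCH_SIZE = 1_000_000  # illustrative chunk size


def iterate_fbin_batches(data_file):
    # Yields (start_index, chunk) pairs until the whole file has been read.
    total = get_total_nvecs_fbin(data_file)
    for start in range(0, total, BATCH_SIZE):
        # clamp the final chunk so read_fbin is never asked for rows past the end
        chunk = read_fbin(data_file, start_idx=start,
                          chunk_size=min(BATCH_SIZE, total - start))
        yield start, chunk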
@@ -230,9 +233,8 @@ def add_points(path, name, ids, points):
"""
Adds a batch of points to a specific shard
"""
shardpath = shard_filename(path,name)
shardpath = shard_filename(path, name)
shard = nmslib.init(method='hnsw', space='l2')
shard.loadIndex(shardpath, load_data=True)
shard.addDataPointBatch(points, ids)
shard.createIndex(print_progress=False)
shard.saveIndex(shardpath,save_data=True)
shard.saveIndex(shardpath, save_data=True)
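To round out the picture, a hedged sketch of loading one of the saved shard indexes back and querying it with nmslib's knnQuery call; query_shard and the k value are illustrative, and the util.utils import path for shard_filename is an assumption based on the helper used in add_points.

# Illustrative sketch: load a saved shard index and run a nearest-neighbour query.
import nmslib
import numpy as np
from util.utils import shard_filename  # assumed module path for the repo helper


def query_shard(path, name, query_vector, k=10):
    # Recreate an index with the same method/space used in add_points, then load it.
    shard = nmslib.init(method='hnsw', space='l2')
    shard.loadIndex(shard_filename(path, name), load_data=True)
    ids, dists = shard.knnQuery(np.asarray(query_vector, dtype=np.float32), k=k)
    return list(zip(ids, dists))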
