#4 1. fix for global shard id 2. switched from copy() to view()
DmitryKey committed Nov 15, 2021
1 parent ca94b0c commit 3e93232
Showing 2 changed files with 22 additions and 12 deletions.
24 changes: 14 additions & 10 deletions src/algorithms/sharding/kanndi/shard_by_distance.py
@@ -171,13 +171,14 @@ def shard_by_dist(data_file: str, dist: float, output_index_path: str, dtype: np
        if len(points_to_resample) == SAMPLE_SIZE:
            computed_dist_max = compute_median_dist(np.array(points_to_resample))
            print(f"computed {computed_dist_max}", flush=True)
-           print("Updating median distance to this value")
+           print(f"Current dist value: {dist}")
            if computed_dist_max > dist:
+               print(f"Updating median distance to this value")
                dist = computed_dist_max
            else:
                # fallback: apply distance multiplier to increase the chances we will make this
                dist = DIST_MULTIPLIER * dist
-               print(f"Increased the dist to 2x: {dist}", flush=True)
+               print(f"Increased the dist to {DIST_MULTIPLIER}x: {dist}", flush=True)
            # unset the starving shard flag to actually start using this new re-sampled median distance
            is_last_shard_starving = False
        else:
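The hunk above makes the median-distance update conditional: the re-sampled median replaces `dist` only when it is larger, and otherwise the current radius is widened by `DIST_MULTIPLIER` so the starving shard still gets a chance to fill. A minimal standalone sketch of that decision rule, assuming `DIST_MULTIPLIER = 2` (implied by the old "2x" message, but an assumption here):

    # Minimal sketch of the guarded distance update shown in the hunk above.
    # DIST_MULTIPLIER = 2 is an assumption based on the old "2x" log message.
    DIST_MULTIPLIER = 2

    def update_dist(dist: float, computed_dist_max: float) -> float:
        # Adopt the re-sampled median only if it grows the search radius.
        if computed_dist_max > dist:
            return computed_dist_max
        # Fallback: widen the current radius so neighbors become reachable.
        return DIST_MULTIPLIER * dist

    assert update_dist(10.0, 15.0) == 15.0  # median grew: adopt it
    assert update_dist(10.0, 8.0) == 20.0   # median shrank: widen 2x instead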
@@ -225,25 +226,26 @@ def shard_by_dist(data_file: str, dist: float, output_index_path: str, dtype: np
    # save the current starving shards' points only if we have them ;)
    if running_shard_point_id > 0:
        # TODO: apply same saturation threshold as for normal shards?
-       for idx, p in enumerate(shard_points):
+       for idx in range(0, running_shard_point_id+1):
            special_shard_points[idx + running_special_shard_point_id,] = shard_points[idx]

        running_special_shard_point_id = idx + 1

        special_shard_point_ids.extend(shard_point_ids)
-       print("!!! Appended to the special_shard, its running size: {}".format(len(special_shard_points)), flush=True)
+       print("!!! Appended to the special_shard, its running size: {}".format(running_special_shard_point_id), flush=True)

        special_shard_saturation_percent = (len(special_shard_point_ids) / expected_shard_size) * 100

        if special_shard_saturation_percent > SHARD_SATURATION_PERCENT_MINIMUM:
+           global_shard_id += 1
            if running_special_shard_point_id < expected_shard_size:
-               shard = Shard(special_shard_point_ids[0],
+               shard = Shard(global_shard_id,
                              special_shard_point_ids,
                              special_shard_points[0:running_special_shard_point_id],
                              size=running_special_shard_point_id,
                              shard_saturation_percent=special_shard_saturation_percent)
            else:
-               shard = Shard(special_shard_point_ids[0],
+               shard = Shard(global_shard_id,
                              special_shard_point_ids,
                              special_shard_points,
                              size=running_special_shard_point_id,
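The added `global_shard_id += 1` is the "fix for global shard id" from the commit message: previously the special shard was keyed by the id of its first point (`special_shard_point_ids[0]`), which lives in a different id space than the sequential shard ids and can collide with an existing key. A toy illustration of the failure mode and the fix (the dict-of-shards setup is an assumption for illustration, not the repository's exact bookkeeping):

    # Toy illustration (assumed setup): shards stored in a dict keyed by shard id.
    shards = {0: "shard-0", 1: "shard-1"}       # regular shards use sequential ids

    # Old scheme: the special shard reuses its first point's id as the shard id.
    first_point_id = 1                          # point ids live in another id space
    shards[first_point_id] = "special-shard"    # collides with shard 1...
    assert len(shards) == 2                     # ...which is silently overwritten

    # New scheme: a single global counter keeps every shard id unique.
    global_shard_id = max(shards)               # last id handed out
    global_shard_id += 1
    shards[global_shard_id] = "special-shard-2" # unique key, nothing clobbered
    assert len(shards) == 3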
@@ -310,27 +312,30 @@ def process_batch(centroids: List[SpacePoint], dist, expected_shard_size: int, o
            running_shard_point_id = 1
            shard.size = running_shard_point_id

-           print(f"Seed point for shard {shard.shardid}: {seed_point}")
+           print(f"Seed point for shard id {shard.shardid}: {seed_point}")

            centroid = SpacePoint(shard.shardid, seed_point)
            centroids.append(centroid)

            need_seed_update = False
        else:
+           in_loop_point_copy = in_loop_points[j].view()
            # seed is up to date and we continue building the shard
            points_pair[0] = centroids[-1].point
-           points_pair[1] = in_loop_points[j]
+           points_pair[1] = in_loop_point_copy
            if VERBOSE:
                print(f"points_pair[0]={points_pair[0]}")
                print(f"points_pair[1]={points_pair[1]}")
            dist_j = pdist(points_pair)

            if VERBOSE:
                print("got dist between seed_point and points[{}]: {}".format(j, dist_j))

            if dist_j <= dist:
                if VERBOSE:
                    print("Got a neighbor!")

-               shard.points[running_shard_point_id,] = in_loop_points[j]
+               shard.points[running_shard_point_id,] = in_loop_point_copy
                shard.pointids[running_shard_point_id] = candidate_point_id
                shard.size += 1
                processed_point_ids[candidate_point_id] = True
@@ -352,7 +357,6 @@ def process_batch(centroids: List[SpacePoint], dist, expected_shard_size: int, o
print(f"Shards built so far: {shards} with {len(shards.keys())} keys", flush=True)
print(f"Collected {len(centroids)} centroids")
assert len(shards.keys()) == len(centroids), "Number of shards and collected centroids do not match"
continue

accumulated_points_in_shard = running_shard_point_id
# if the shard is in point collection phase
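The switch from `copy()` to `view()` in the hunk above is the second change in the commit message: a NumPy view shares the source array's buffer instead of duplicating it, so no new allocation happens for each candidate point. A minimal sketch of the semantic difference (array contents are illustrative):

    import numpy as np

    points = np.arange(6, dtype=np.uint8).reshape(2, 3)

    v = points[0].view()   # view: shares memory with `points`, no new buffer
    c = points[0].copy()   # copy: allocates a fresh buffer and duplicates the row

    points[0, 0] = 99
    assert v[0] == 99      # the view observes the in-place mutation...
    assert c[0] == 0       # ...the copy does not

    # view() is what makes per-point handling cheap at this scale; the trade-off
    # is that a view tracks any later in-place changes to the source array.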
10 changes: 8 additions & 2 deletions src/kanndi_index.sh
@@ -5,10 +5,16 @@ export PYTHONPATH=.
###

# 100M points
#python algorithms/sharding/kanndi/shard_by_distance.py --input_file /datadrive/big-ann-benchmarks/data/bigann.bak/base.1B.u8bin.crop_nb_100000000 --output_dir /datadrive/big-ann/index/bigann/data.100M -M 100
+# CPU profiling
+#python -m cProfile -o /Users/dmitry/Desktop/BigANN/datasets/bigann/program.prof algorithms/sharding/kanndi/shard_by_distance.py --input_file /Users/dmitry/Desktop/BigANN/datasets/bigann/learn.100M.u8bin --output_dir /Users/dmitry/Desktop/BigANN/datasets/bigann/data.100M/ -M 100 --dtype uint8
+# RAM profiling
+# python -m memory_profiler algorithms/sharding/kanndi/shard_by_distance.py --input_file /Users/dmitry/Desktop/BigANN/datasets/bigann/learn.100M.u8bin --output_dir /Users/dmitry/Desktop/BigANN/datasets/bigann/data.100M/ -M 100 --dtype uint8

# 10M points
+python -m memory_profiler algorithms/sharding/kanndi/shard_by_distance.py --input_file /datadrive/big-ann-benchmarks/data/bigann/base.1B.u8bin.crop_nb_10000000 --output_dir /datadrive/big-ann/index/bigann/data.10M/ -M 10 --dtype uint8

# 1B points
-python algorithms/sharding/kanndi/shard_by_distance.py --input_file /datadrive/big-ann-benchmarks/data/bigann/base.1B.u8bin --output_dir /datadrive/big-ann/index/bigann/data.1B/ -M 1000 --dtype uint8
+# python algorithms/sharding/kanndi/shard_by_distance.py --input_file /datadrive/big-ann-benchmarks/data/bigann/base.1B.u8bin --output_dir /datadrive/big-ann/index/bigann/data.1B/ -M 1000 --dtype uint8


###
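As a note on the profiling commands above: `python -m memory_profiler` reports line-by-line memory usage, but only for functions marked with memory_profiler's `@profile` decorator. A minimal sketch of how an entry point might be annotated (`build_shards` is a hypothetical stand-in, not the script's actual function name):

    from memory_profiler import profile

    @profile  # per-line memory increments are printed when this function runs
    def build_shards():
        big = [0] * 10_000_000  # a large allocation shows up as a clear increment
        del big                 # and its release shows up as a decrement

    if __name__ == "__main__":
        build_shards()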
