Skip to content

Commit

Permalink
#4 1. distance_matrix -> pdist 2. added transposition before reshapin…
Browse files Browse the repository at this point in the history
…g to avoid array copy
  • Loading branch information
DmitryKey committed Oct 30, 2021
1 parent 242abdf commit bc8ac93
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 2 deletions.
3 changes: 2 additions & 1 deletion src/algorithms/sharding/kanndi/shard_by_distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,8 @@ def shard_by_dist(data_file: str, dist: float, output_index_path: str, dtype:np.
need_seed_update = False
else:
# seed is up to date and we continue building the shard
dist_j = distance_matrix(np.array([seed_point]), np.array([in_loop_points[j]]))
# dist_j = distance_matrix(np.array([seed_point]), np.array([in_loop_points[j]]))
dist_j = pdist(np.array([seed_point, in_loop_points[j]]))
if VERBOSE:
print("got dist between seed_point and points[{}]: {}".format(j, dist_j))
if dist_j <= dist:
Expand Down
11 changes: 10 additions & 1 deletion src/util/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import nmslib
import linecache
import os
import gc


def ts():
Expand Down Expand Up @@ -118,7 +119,12 @@ def read_bin(filename, dtype, start_idx=0, chunk_size=None):
type_multiplier = 4

arr = np.fromfile(f, count=nvecs * dim, dtype=dtype, offset=start_idx * dim * type_multiplier)
return arr.reshape(-1, dim)
# Reshaping an array may or may not involve a copy; the reasons are explained in the source linked below.
# For instance, reshaping a 2D matrix does not involve a copy unless it is transposed
# (or, more generally, non-contiguous).
# NOTE(review): np.fromfile returns a 1-D array, for which `.T` is a no-op — confirm the transposition below has any effect.
# Source: https://ipython-books.github.io/45-understanding-the-internals-of-numpy-to-avoid-unnecessary-array-copying
# return arr.reshape(-1, dim)
return arr.T.reshape(-1, dim)


def read_ibin(filename, start_idx=0, chunk_size=None):
Expand Down Expand Up @@ -192,6 +198,7 @@ def mmap_bin(filename, dtype):
print(text)
"""


# by UKPLab
# From https://github.com/UKPLab/sentence-transformers/blob/master/sentence_transformers/util.py (MIT License)
def pytorch_cos_sim(a: Tensor, b: Tensor):
Expand Down Expand Up @@ -250,6 +257,8 @@ def add_points(path, shard: Shard):
index.addDataPointBatch(shard.points, shard.pointids)
index.createIndex(print_progress=False)
index.saveIndex(shardpath, save_data=True)
del index
gc.collect()


# Loads index from disk
Expand Down

0 comments on commit bc8ac93

Please sign in to comment.