Skip to content

Commit

Permalink
use multi-thread for building
Browse files Browse the repository at this point in the history
  • Loading branch information
imbajin committed Jan 15, 2025
1 parent b01574f commit 8dd5ab7
Showing 1 changed file with 12 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -37,27 +37,28 @@ def __init__(self, embedding: BaseEmbedding):
def _extract_names(self, vertices: list[str]) -> list[str]:
return [v.split(":")[1] for v in vertices]

def _check_primary_key(self, vertexlabels):
return all(data.get('id_strategy') == 'PRIMARY_KEY' for data in vertexlabels)
# TODO: use asyncio for IO tasks
def _get_embeddings_parallel(self, vids: list[str]) -> list[Any]:
    """Compute one text embedding per entry of ``vids`` using a thread pool.

    Each entry is passed to ``self.embedding.get_text_embedding``; progress is
    reported through ``tqdm``.

    Args:
        vids: Texts to embed.

    Returns:
        The embeddings, in the same order as ``vids`` (``Executor.map``
        yields results in input order).
    """
    from concurrent.futures import ThreadPoolExecutor

    embed_one = self.embedding.get_text_embedding
    with ThreadPoolExecutor() as pool:
        # pool.map is lazy; wrapping it in tqdm advances the bar as each
        # result is consumed, and list() drains it before the pool closes.
        progress = tqdm(pool.map(embed_one, vids), total=len(vids))
        results = list(progress)
    return results

def run(self, context: Dict[str, Any]) -> Dict[str, Any]:
flag_extract_names = self._check_primary_key(self.sm.schema.getSchema()["vertexlabels"])
vertexlabels = self.sm.schema.getSchema()["vertexlabels"]
all_pk_flag = all(data.get('id_strategy') == 'PRIMARY_KEY' for data in vertexlabels)

past_vids = self.vid_index.properties
# TODO: We should build vid vector index separately, especially when the vertices may be very large
present_vids = context["vertices"] # Warning: data truncated by fetch_graph_data.py
removed_vids = set(past_vids) - set(present_vids)
removed_num = self.vid_index.remove(removed_vids)
added_vids = list(set(present_vids) - set(past_vids))

if len(added_vids) > 0:
# TODO: Use a multi-value map when several vertices share the same name (e.g. [1:tom, 2:tom, tom] in one graph)
if flag_extract_names:
extract_added_vids = self._extract_names(added_vids)
added_embeddings = [self.embedding.get_text_embedding(v) for v in tqdm(extract_added_vids)]
else:
added_embeddings = [self.embedding.get_text_embedding(v) for v in tqdm(added_vids)]
if added_vids:
vids_to_process = self._extract_names(added_vids) if all_pk_flag else added_vids
added_embeddings = self._get_embeddings_parallel(vids_to_process)
log.info("Building vector index for %s vertices...", len(added_vids))
log.info("Vector index built for %s vertices.", len(added_embeddings))
self.vid_index.add(added_embeddings, added_vids)
self.vid_index.to_index_file(self.index_dir)
else:
Expand Down

0 comments on commit 8dd5ab7

Please sign in to comment.