From def84a2c1ec9cc7ee18c0ac9ec4e80d371ce5044 Mon Sep 17 00:00:00 2001
From: SoJGooo <102796027+MrJs133@users.noreply.github.com>
Date: Wed, 15 Jan 2025 20:10:44 +0800
Subject: [PATCH] refactor(llm): change vid embedding x:yy to yy & use
 multi-thread (#158)

* remove num prefix in ok mode

* use multi-thread for building

---------

Co-authored-by: imbajin
---
 .../index_op/build_semantic_index.py | 30 ++++++++++++++----
 1 file changed, 23 insertions(+), 7 deletions(-)

diff --git a/hugegraph-llm/src/hugegraph_llm/operators/index_op/build_semantic_index.py b/hugegraph-llm/src/hugegraph_llm/operators/index_op/build_semantic_index.py
index c8ce907b..b2bff495 100644
--- a/hugegraph-llm/src/hugegraph_llm/operators/index_op/build_semantic_index.py
+++ b/hugegraph-llm/src/hugegraph_llm/operators/index_op/build_semantic_index.py
@@ -20,33 +20,49 @@ from typing import Any, Dict
 
 from tqdm import tqdm
+
 from hugegraph_llm.config import resource_path, huge_settings
-from hugegraph_llm.models.embeddings.base import BaseEmbedding
 from hugegraph_llm.indices.vector_index import VectorIndex
+from hugegraph_llm.models.embeddings.base import BaseEmbedding
 from hugegraph_llm.utils.log import log
-
+from hugegraph_llm.operators.hugegraph_op.schema_manager import SchemaManager
 
 class BuildSemanticIndex:
     def __init__(self, embedding: BaseEmbedding):
         self.index_dir = str(os.path.join(resource_path, huge_settings.graph_name, "graph_vids"))
         self.vid_index = VectorIndex.from_index_file(self.index_dir)
         self.embedding = embedding
+        self.sm = SchemaManager(huge_settings.graph_name)
+
+    def _extract_names(self, vertices: list[str]) -> list[str]:
+        return [v.split(":")[1] for v in vertices]
+
+    # TODO: use asyncio for IO tasks
+    def _get_embeddings_parallel(self, vids: list[str]) -> list[Any]:
+        from concurrent.futures import ThreadPoolExecutor
+        with ThreadPoolExecutor() as executor:
+            embeddings = list(tqdm(executor.map(self.embedding.get_text_embedding, vids), total=len(vids)))
+        return embeddings
 
     def run(self, context: Dict[str, Any]) -> Dict[str, Any]:
+        vertexlabels = self.sm.schema.getSchema()["vertexlabels"]
+        all_pk_flag = all(data.get('id_strategy') == 'PRIMARY_KEY' for data in vertexlabels)
+
         past_vids = self.vid_index.properties
         # TODO: We should build vid vector index separately, especially when the vertices may be very large
         present_vids = context["vertices"]  # Warning: data truncated by fetch_graph_data.py
         removed_vids = set(past_vids) - set(present_vids)
         removed_num = self.vid_index.remove(removed_vids)
         added_vids = list(set(present_vids) - set(past_vids))
-        if len(added_vids) > 0:
-            log.debug("Building vector index for %s vertices...", len(added_vids))
-            added_embeddings = [self.embedding.get_text_embedding(v) for v in tqdm(added_vids)]
-            log.debug("Vector index built for %s vertices.", len(added_embeddings))
+
+        if added_vids:
+            vids_to_process = self._extract_names(added_vids) if all_pk_flag else added_vids
+            added_embeddings = self._get_embeddings_parallel(vids_to_process)
+            log.info("Building vector index for %s vertices...", len(added_vids))
             self.vid_index.add(added_embeddings, added_vids)
             self.vid_index.to_index_file(self.index_dir)
         else:
-            log.debug("No vertices to build vector index.")
+            log.debug("No update vertices to build vector index.")
 
         context.update({
             "removed_vid_vector_num": removed_num,
             "added_vid_vector_num": len(added_vids)