Skip to content

Commit

Permalink
refactor(llm): change vid embedding x:yy to yy & use multi-thread (#158)
Browse files Browse the repository at this point in the history
* remove num prefix in pk (primary-key) mode

* use multi-thread for building

---------

Co-authored-by: imbajin <[email protected]>
  • Loading branch information
MrJs133 and imbajin authored Jan 15, 2025
1 parent 7f7c82c commit def84a2
Showing 1 changed file with 23 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,33 +20,49 @@
from typing import Any, Dict

from tqdm import tqdm

from hugegraph_llm.config import resource_path, huge_settings
from hugegraph_llm.models.embeddings.base import BaseEmbedding
from hugegraph_llm.indices.vector_index import VectorIndex
from hugegraph_llm.models.embeddings.base import BaseEmbedding
from hugegraph_llm.utils.log import log

from hugegraph_llm.operators.hugegraph_op.schema_manager import SchemaManager

class BuildSemanticIndex:
def __init__(self, embedding: BaseEmbedding):
    """Prepare the vid vector index, the embedding model and the schema manager.

    Args:
        embedding: Embedding model kept on the instance as ``self.embedding``.
    """
    graph_name = huge_settings.graph_name
    # The vid index is stored under <resource_path>/<graph_name>/graph_vids.
    index_path = os.path.join(resource_path, graph_name, "graph_vids")
    self.index_dir = str(index_path)
    self.vid_index = VectorIndex.from_index_file(self.index_dir)
    self.embedding = embedding
    self.sm = SchemaManager(graph_name)

def _extract_names(self, vertices: list[str]) -> list[str]:
return [v.split(":")[1] for v in vertices]

# TODO: use asyncio for IO tasks
def _get_embeddings_parallel(self, vids: list[str]) -> list[Any]:
    """Embed every entry of ``vids`` using a thread pool.

    Args:
        vids: Texts passed one by one to ``self.embedding.get_text_embedding``.

    Returns:
        One embedding per input, in the same order as ``vids``
        (``Executor.map`` yields results in input order).
    """
    from concurrent.futures import ThreadPoolExecutor

    embed_one = self.embedding.get_text_embedding
    with ThreadPoolExecutor() as pool:
        # tqdm only wraps the lazy result iterator to show progress; the
        # list() call drains it before the pool is shut down.
        pending = pool.map(embed_one, vids)
        progress = tqdm(pending, total=len(vids))
        return list(progress)

def run(self, context: Dict[str, Any]) -> Dict[str, Any]:
vertexlabels = self.sm.schema.getSchema()["vertexlabels"]
all_pk_flag = all(data.get('id_strategy') == 'PRIMARY_KEY' for data in vertexlabels)

past_vids = self.vid_index.properties
# TODO: We should build vid vector index separately, especially when the vertices may be very large
present_vids = context["vertices"] # Warning: data truncated by fetch_graph_data.py
removed_vids = set(past_vids) - set(present_vids)
removed_num = self.vid_index.remove(removed_vids)
added_vids = list(set(present_vids) - set(past_vids))
if len(added_vids) > 0:
log.debug("Building vector index for %s vertices...", len(added_vids))
added_embeddings = [self.embedding.get_text_embedding(v) for v in tqdm(added_vids)]
log.debug("Vector index built for %s vertices.", len(added_embeddings))

if added_vids:
vids_to_process = self._extract_names(added_vids) if all_pk_flag else added_vids
added_embeddings = self._get_embeddings_parallel(vids_to_process)
log.info("Building vector index for %s vertices...", len(added_vids))
self.vid_index.add(added_embeddings, added_vids)
self.vid_index.to_index_file(self.index_dir)
else:
log.debug("No vertices to build vector index.")
log.debug("No update vertices to build vector index.")
context.update({
"removed_vid_vector_num": removed_num,
"added_vid_vector_num": len(added_vids)
Expand Down

0 comments on commit def84a2

Please sign in to comment.