Skip to content

Commit

Permalink
Experiment using LangChain to generate entity embeddings
Browse files Browse the repository at this point in the history
  • Loading branch information
vemonet committed Feb 6, 2025
1 parent 462c942 commit 716797a
Show file tree
Hide file tree
Showing 5 changed files with 61 additions and 19 deletions.
2 changes: 2 additions & 0 deletions packages/expasy-agent/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ dependencies = [
"langchain-openai >=0.1.22",
"langchain-azure-ai >=0.1.0",
"langchain-groq >=0.2.4",
# "langchain-huggingface", # This would install PyTorch and many heavy NVIDIA dependencies
"langchain-together",
"langchain-deepseek-official >=0.1.0",
# "langchain-anthropic >=0.1.23",
# "langchain-fireworks >=0.1.7",
Expand Down
3 changes: 1 addition & 2 deletions packages/expasy-agent/src/expasy_agent/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,8 +108,7 @@ class Settings(BaseSettings):
# Settings for the vector store
# https://qdrant.github.io/fastembed/examples/Supported_Models/
embedding_model: str = "BAAI/bge-large-en-v1.5"
# embedding_dimensions: int = 1024
# sparse_embedding_model: str = "Qdrant/bm25"
sparse_embedding_model: str = "Qdrant/bm25"
vectordb_url: str = "http://vectordb:6334/"
docs_collection_name: str = "expasy"
entities_collection_name: str = "entities"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ def init_vectordb() -> None:
url=settings.vectordb_url,
collection_name=settings.docs_collection_name,
embedding=make_text_encoder(settings.embedding_model),
# sparse_embedding=FastEmbedSparse(model_name="Qdrant/bm25"),
# sparse_embedding=FastEmbedSparse(model_name=settings.sparse_embedding_model),
# retrieval_mode=RetrievalMode.HYBRID,
prefer_grpc=True,
force_recreate=True,
Expand Down
49 changes: 33 additions & 16 deletions packages/expasy-agent/src/expasy_agent/indexing/index_entities.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,22 @@
import os
from http import client
import time

from langchain_core.documents import Document
from langchain_qdrant import FastEmbedSparse, QdrantVectorStore, RetrievalMode
from qdrant_client import QdrantClient, models
from sparql_llm.utils import query_sparql

from expasy_agent.config import get_embedding_model, settings
from expasy_agent.nodes.retrieval import make_text_encoder

# NOTE: Run this script to extract entities from the endpoints and generate embeddings for them (takes a long time):
# ssh adsicore
# cd /mnt/scratch/sparql-llm/packages/expasy-agent
# nohup uv run --extra gpu src/expasy_agent/indexing/index_entities.py &


entities_embeddings_dir = os.path.join("data", "embeddings")
entities_embeddings_filepath = os.path.join(entities_embeddings_dir, "entities_embeddings.csv")
# entities_embeddings_dir = os.path.join("data", "embeddings")
# entities_embeddings_filepath = os.path.join(entities_embeddings_dir, "entities_embeddings.csv")


def retrieve_index_data(entity: dict, docs: list[Document], pagination: (int, int) = None):
Expand Down Expand Up @@ -256,31 +258,46 @@ def generate_embeddings_for_entities():
# entities_res = query_sparql(entity["query"], entity["endpoint"])["results"]["bindings"]
# print(f"Found {len(entities_res)} entities for {entity['label']} in {entity['endpoint']}")

print(f"Done querying in {time.time() - start_time} seconds, generating embeddings for {len(docs)} entities")
print(f"Done querying SPARQL endpoints in {(time.time() - start_time) / 60:.2f} minutes, generating embeddings for {len(docs)} entities...")

# Uncomment the next line to test with a smaller number of entities
# docs = docs[:10]

embeddings = embedding_model.embed([q.page_content for q in docs])

if not os.path.exists(entities_embeddings_dir):
os.makedirs(entities_embeddings_dir)

vectordb_local = QdrantClient(
path="data/qdrant",
# host=host,
prefer_grpc=True,
)
vectordb_local.upsert(

# Using LangChain
QdrantVectorStore.from_documents(
docs,
# url=settings.vectordb_url,
client=vectordb_local,
collection_name=settings.entities_collection_name,
points=models.Batch(
ids=list(range(1, len(docs) + 1)),
vectors=embeddings,
payloads=[doc.metadata for doc in docs],
),
embedding=make_text_encoder(settings.embedding_model),
sparse_embedding=FastEmbedSparse(model_name=settings.sparse_embedding_model),
retrieval_mode=RetrievalMode.HYBRID,
prefer_grpc=True,
force_recreate=True,
)

print(f"Done generating embeddings for {len(docs)} entities in {(time.time() - start_time) / 60:.2f} minutes")
# Directly using Qdrant client
# embeddings = embedding_model.embed([q.page_content for q in docs])
# vectordb_local.recreate_collection(
# collection_name="demo_collection",
# vectors_config=VectorParams(size=1536, distance=Distance.COSINE),
# )
# vectordb_local.upsert(
# collection_name=settings.entities_collection_name,
# points=models.Batch(
# ids=list(range(1, len(docs) + 1)),
# vectors=embeddings,
# payloads=[doc.metadata for doc in docs],
# ),
# )

print(f"Done generating and indexing embeddings in collection {settings.entities_collection_name} for {len(docs)} entities in {(time.time() - start_time) / 60:.2f} minutes")


if __name__ == "__main__":
Expand Down
24 changes: 24 additions & 0 deletions packages/expasy-agent/src/expasy_agent/nodes/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,30 @@ def load_chat_model(configuration: Configuration) -> BaseChatModel:
temperature=configuration.temperature,
)

if provider == "together":
# https://python.langchain.com/docs/integrations/chat/together/
from langchain_together import ChatTogether
return ChatTogether(
# model="meta-llama/Llama-3-70b-chat-hf",
model=model_name,
temperature=configuration.temperature,
max_tokens=configuration.max_tokens,
timeout=None,
max_retries=2,
)

if provider == "hf":
# https://python.langchain.com/docs/integrations/chat/huggingface/
from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
return ChatHuggingFace(llm=HuggingFaceEndpoint(
# repo_id="HuggingFaceH4/zephyr-7b-beta",
repo_id=model_name,
task="text-generation",
max_new_tokens=configuration.max_tokens,
do_sample=False,
repetition_penalty=1.03,
))

if provider == "azure":
# https://learn.microsoft.com/en-us/azure/ai-studio/how-to/develop/langchain
from langchain_azure_ai.chat_models import AzureAIChatCompletionsModel
Expand Down

0 comments on commit 716797a

Please sign in to comment.