
Commit

Improve entities extraction by using spacy
vemonet committed Feb 5, 2025
1 parent 44c1f01 commit 462c942
Showing 15 changed files with 129 additions and 169 deletions.
3 changes: 2 additions & 1 deletion README.md
@@ -214,7 +214,8 @@ Requirements: Docker, nodejs (to build the frontend), and optionally [`uv`](http
> **Experimental entities indexing**: generating embeddings for entities can take a lot of time, so we recommend running the script on a machine with a GPU (it does not need to be a powerful one, but a GPU is needed; check out the [fastembed GPU docs](https://qdrant.github.io/fastembed/examples/FastEmbed_GPU/) to install the GPU drivers and dependencies)
>
> ```sh
> uv run --with gpu packages/expasy-agent/src/expasy_agent/indexing/index_entities.py
> cd packages/expasy-agent
> nohup uv run --extra gpu src/expasy_agent/indexing/index_entities.py &
> ```
>
> Then move the entities collection containing the embeddings to `data/qdrant/collections/entities` before starting the stack
2 changes: 1 addition & 1 deletion chat-with-context/index.html
@@ -90,7 +90,7 @@ <h2 class="text-xl text-center font-semibold border-b pb-2">
<chat-with-context
chat-endpoint="http://localhost:8000/chat"
feedback-endpoint="http://localhost:8000/feedback"
api-key="%EXPASY_API_KEY%"
api-key="%CHAT_API_KEY%"
model="openai/gpt-4o-mini"
examples="Which resources are available at the SIB?,How can I get the HGNC symbol for the protein P68871?,What are the rat orthologs of the human TP53?,Where is expressed the gene ACE2 in human?,Anatomical entities where the INS zebrafish gene is expressed and its gene GO annotations,List the genes in primates orthologous to genes expressed in the fruit fly eye"
></chat-with-context>
26 changes: 12 additions & 14 deletions chat-with-context/src/providers.ts
@@ -116,37 +116,33 @@ async function processLangGraphChunk(state: ChatState, chunk: any) {
throw new Error(`An error occurred. Please try again. ${chunk.data.error}: ${chunk.data.message}`);
}
// console.log(chunk);
// Handle updates to the state
// Handle updates to the state (usually nodes that retrieve data without querying the LLM)
if (chunk.event === "updates") {
// console.log("UPDATES", chunk);
for (const nodeId of Object.keys(chunk.data)) {
const nodeData = chunk.data[nodeId];
// Retrieved docs sent
if (nodeData.retrieved_docs) {
// Retrieved docs sent
state.appendStepToLastMsg(
`📚️ Using ${nodeData.retrieved_docs.length} documents`,
nodeId,
nodeData.retrieved_docs,
);
}
// Handle entities extracted from the user input
if (nodeData.extracted_entities) {
} else if (nodeData.extracted_entities) {
// Handle entities extracted from the user input
state.appendStepToLastMsg(
`⚗️ Extracted ${nodeData.extracted_entities.length} potential entities`,
nodeId,
[],
nodeData.extracted_entities.map((entity: any) =>
`\n\nEntities found in the user question for "${entity.term.join(" ")}":\n\n` +
`\n\nEntities found in the user question for "${entity.text}":\n\n` +
entity.matchs.map((match: any) =>
`- ${match.payload.label} with IRI <${match.payload.uri}> in endpoint ${match.payload.endpoint_url}\n\n`
).join('')
).join(''),
);
}
// Handle things extracted from output (SPARQL queries here)
if (nodeData.structured_output) {
// const extractedEntities = nodeData.extracted_entities;
// Retrieved extracted_entities sent
} else if (nodeData.structured_output) {
// Handle things extracted from output (SPARQL queries here)
if (nodeData.structured_output.sparql_query) {
state.lastMsg().setLinks([
{
@@ -158,9 +154,8 @@
},
]);
}
}
// Handle post-generation validation
if (nodeData.validation) {
} else if (nodeData.validation) {
// Handle post-generation validation
for (const validationStep of nodeData.validation) {
// Handle messages related to tools (includes post generation validation)
state.appendStepToLastMsg(validationStep.label, nodeId, [], validationStep.details);
@@ -172,6 +167,9 @@
state.lastMsg().setContent(validationStep.fixed_message);
}
}
} else {
// Handle other updates
state.appendStepToLastMsg(`💭 ${nodeId.replace("_", " ")}`, nodeId);
}
}
}
2 changes: 1 addition & 1 deletion chat-with-context/vite.config.ts
@@ -13,7 +13,7 @@ export default defineConfig({
port: 3000,
},
envDir: "../",
envPrefix: "EXPASY_",
envPrefix: "CHAT_",
build: {
outDir: "dist",
target: ["esnext"],
4 changes: 2 additions & 2 deletions compose.dev.yml
@@ -24,8 +24,8 @@ services:
# - DEFAULT_LLM_MODEL=deepseek/deepseek-chat
# - DEFAULT_LLM_MODEL=azure/Mistral-Large-2411
# - DEFAULT_LLM_MODEL=azure/DeepSeek-R1
- DEFAULT_LLM_MODEL=openai/o3-mini
# - DEFAULT_LLM_MODEL=openai/gpt-4o-mini
# - DEFAULT_LLM_MODEL=openai/o3-mini
- DEFAULT_LLM_MODEL=openai/gpt-4o-mini
entrypoint: ["uv", "run", "uvicorn", "src.expasy_agent.main:app", "--host", "0.0.0.0", "--port", "80", "--reload"]

# langgraph-redis:
5 changes: 5 additions & 0 deletions deploy.sh
@@ -28,6 +28,11 @@ elif [ "$1" = "index" ]; then
echo "🔎 Indexing endpoints in the vector database"
ssh_cmd "podman-compose exec api python src/expasy_agent/indexing/index_endpoints.py"

elif [ "$1" = "import-entities-index" ]; then
echo "Imported from adsicore which has GPU to generate them"
scp -R adsicore:/mnt/scratch/sparql-llm/packages/expasy-agent/src/expasy_agent/data/qdrant/collections/entities ./data/qdrant/collections/entities
# scp -R adsicore:/mnt/scratch/sparql-llm/packages/expasy-agent/src/expasy_agent/data/qdrant/collections/entities expasychatpodman:/var/containers/podman/sparql-llm/data/qdrant/collections/entities

elif [ "$1" = "likes" ]; then
mkdir -p data/prod
scp expasychat:/var/containers/podman/sparql-llm/data/logs/likes.jsonl ./data/prod/
4 changes: 2 additions & 2 deletions notebooks/test_expasy_chat.ipynb
@@ -989,7 +989,7 @@
"} LIMIT 20\"\"\",\n",
" },\n",
" {\n",
" \"question\": \"What are the orthologs in rat for protein Q9Y2T1 ? Return ?ratProtein ?ratUniProtXref\",\n",
" \"question\": \"What are the orthologs in rat for protein Q9Y2T1? Return ?ratProtein ?ratUniProtXref\",\n",
" \"endpoint\": \"https://sparql.omabrowser.org/sparql/\",\n",
" \"query\": \"\"\"PREFIX up: <http://purl.uniprot.org/core/>\n",
"PREFIX rdfs: <http://www.w3.org/2000/01/rdf-schema#>\n",
@@ -1397,7 +1397,7 @@
"\n",
"list_of_approaches = {\n",
" \"No RAG\": answer_no_rag,\n",
" # \"RAG without validation\": answer_rag_without_validation,\n",
" \"RAG without validation\": answer_rag_without_validation,\n",
" \"RAG with validation\": answer_rag_with_validation,\n",
"}\n",
"\n",
2 changes: 1 addition & 1 deletion notebooks/test_expasy_chat_with_training_data.ipynb
@@ -44,7 +44,7 @@
"from sparql_llm.utils import extract_sparql_queries\n",
"\n",
"load_dotenv()\n",
"expasy_api_key = os.getenv(\"EXPASY_API_KEY\")\n",
"expasy_api_key = os.getenv(\"CHAT_API_KEY\")\n",
"\n",
"example_queries = [\n",
" # {\n",
2 changes: 2 additions & 0 deletions packages/expasy-agent/pyproject.toml
@@ -43,6 +43,8 @@ dependencies = [
"litellm",
"jinja2",
"tqdm",
"scispacy",
"en_core_sci_md @ https://s3-us-west-2.amazonaws.com/ai2-s2-scispacy/releases/v0.5.4/en_core_sci_md-0.5.4.tar.gz",
]

[project.optional-dependencies]
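The `scispacy` dependency and the `en_core_sci_md` model added above are what the commit title refers to. The actual `extract_entities` node is not visible in this diff, so the following is only a minimal sketch, assuming the node runs user questions through the scispacy pipeline; the function name and the `start`/`end` keys are hypothetical, while the `text` field matches the `entity.text` read in `providers.ts` above:

```python
import spacy

# Scientific/biomedical spaCy pipeline installed via the pyproject.toml entry above
nlp = spacy.load("en_core_sci_md")

def extract_potential_entities(question: str) -> list[dict]:
    """Return the spans scispacy recognises as potential entities in a question."""
    doc = nlp(question)
    return [
        {"text": ent.text, "start": ent.start_char, "end": ent.end_char}
        for ent in doc.ents
    ]

if __name__ == "__main__":
    # Example question taken from the chat-with-context examples above
    for span in extract_potential_entities("What are the rat orthologs of the human TP53?"):
        print(span["text"])
```

Each extracted span would then presumably be embedded and searched against the Qdrant `entities` collection built by `index_entities.py` below, producing the matches (label, URI, endpoint) that `providers.ts` renders in the chat UI.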
2 changes: 1 addition & 1 deletion packages/expasy-agent/src/expasy_agent/config.py
@@ -174,7 +174,7 @@ class Configuration:
default=settings.default_temperature,
metadata={
"description": "The temperature of the language model."
"Should be between 0.0 and 1.0. Higher values make the model more creative but less deterministic."
"Should be between 0.0 and 2.0. Higher values make the model more creative but less deterministic."
},
)
max_tokens: Annotated[int, {"__template_metadata__": {"kind": "llm"}}] = field(
2 changes: 1 addition & 1 deletion packages/expasy-agent/src/expasy_agent/graph.py
@@ -61,12 +61,12 @@ def route_model_output(state: State, config: RunnableConfig) -> Literal["__end__
# Add edges
builder.add_edge("__start__", "retrieve")
builder.add_edge("retrieve", "call_model")
builder.add_edge("extract_entities", "call_model")
builder.add_edge("call_model", "validate_output")

# Entity extraction node
builder.add_node(extract_entities)
builder.add_edge("__start__", "extract_entities")
builder.add_edge("extract_entities", "call_model")

# Add a conditional edge to determine the next step after `validate_output`
builder.add_conditional_edges(
112 changes: 9 additions & 103 deletions packages/expasy-agent/src/expasy_agent/indexing/index_entities.py
@@ -1,17 +1,16 @@
import csv
import os
import time
from ast import literal_eval

from langchain_core.documents import Document
from qdrant_client import QdrantClient, models
from tqdm import tqdm

from expasy_agent.config import get_embedding_model, get_vectordb, settings
from sparql_llm.utils import query_sparql

# Run the script to extract entities from endpoints and generate embeddings for them (long):
# uv run python src/sparql_llm/embed_entities.py
from expasy_agent.config import get_embedding_model, settings

# NOTE: Run the script to extract entities from endpoints and generate embeddings for them (long):
# ssh adsicore
# cd /mnt/scratch/sparql-llm/packages/expasy-agent
# nohup uv run --extra gpu src/expasy_agent/indexing/index_entities.py &


entities_embeddings_dir = os.path.join("data", "embeddings")
@@ -43,6 +42,7 @@ def retrieve_index_data(entity: dict, docs: list[Document], pagination: (int, in
def generate_embeddings_for_entities():
start_time = time.time()
embedding_model = get_embedding_model(gpu=True)
print("Start indexing entities")

entities_list = {
"genex:AnatomicalEntity": {
@@ -258,7 +258,7 @@ def generate_embeddings_for_entities():

print(f"Done querying in {time.time() - start_time} seconds, generating embeddings for {len(docs)} entities")

# To test with a smaller number of entities
# Uncomment the next line to test with a smaller number of entities
# docs = docs[:10]

embeddings = embedding_model.embed([q.page_content for q in docs])
@@ -280,101 +280,7 @@ def generate_embeddings_for_entities():
),
)

# NOTE: saving to a CSV makes it too slow to then read and upload to the vectordb, so we now directly upload to the vectordb
# with open(entities_embeddings_filepath, mode="w", newline="") as file:
# writer = csv.writer(file)
# header = ["label", "uri", "endpoint_url", "entity_type", "embedding"]
# writer.writerow(header)
# for doc, embedding in zip(docs, embeddings):
# row = [
# doc.metadata["label"],
# doc.metadata["uri"],
# doc.metadata["endpoint_url"],
# doc.metadata["entity_type"],
# embedding.tolist(),
# ]
# writer.writerow(row)

print(f"Done generating embeddings for {len(docs)} entities in {time.time() - start_time} seconds")

# vectordb = get_vectordb()
# vectordb.upsert(
# collection_name="entities",
# points=models.Batch(
# ids=list(range(1, len(docs) + 1)),
# vectors=embeddings,
# payloads=[doc.metadata for doc in docs],
# ),
# # wait=False, # Waiting for indexing to finish or not
# )


# NOTE: not used anymore, we now directly load to a local vectordb then move the loaded entities collection to the main vectordb
def load_entities_embeddings_to_vectordb():
start_time = time.time()
vectordb = get_vectordb()
docs = []
embeddings = []
batch_size = 100000
embeddings_count = 0

print(f"Reading entities embeddings from the .csv file at {entities_embeddings_filepath}")
with open(entities_embeddings_filepath) as file:
reader = csv.DictReader(file)
for row in tqdm(reader, desc="Extracting embeddings from CSV file"):
docs.append(
Document(
page_content=row["label"],
metadata={
"label": row["label"],
"uri": row["uri"],
"endpoint_url": row["endpoint_url"],
"entity_type": row["entity_type"],
},
)
)
embeddings.append(literal_eval(row["embedding"]))
embeddings_count += 1

if len(docs) == batch_size:
vectordb.upsert(
collection_name=settings.entities_collection_name,
points=models.Batch(
ids=list(range(embeddings_count - batch_size + 1, embeddings_count + 1)),
vectors=embeddings,
payloads=[doc.metadata for doc in docs],
),
)
docs = []
embeddings = []


print(
f"Found embeddings for {len(docs)} entities in {time.time() - start_time} seconds. Now adding them to the vectordb"
)

# for i in range(0, len(docs), batch_size):
# batch_docs = docs[i:i + batch_size]
# batch_embeddings = embeddings[i:i + batch_size]
# vectordb.upsert(
# collection_name=settings.entities_collection_name,
# points=models.Batch(
# ids=list(range(i + 1, i + len(batch_docs) + 1)),
# vectors=batch_embeddings,
# payloads=[doc.metadata for doc in batch_docs],
# ),
# )

# vectordb.upsert(
# collection_name=settings.entities_collection_name,
# points=models.Batch(
# ids=list(range(1, len(docs) + 1)),
# vectors=embeddings,
# payloads=[doc.metadata for doc in docs],
# ),
# )

print(f"Done uploading embeddings for {len(docs)} entities in the vectordb in {time.time() - start_time} seconds")
print(f"Done generating embeddings for {len(docs)} entities in {(time.time() - start_time) / 60:.2f} minutes")


if __name__ == "__main__":