diff --git a/README.md b/README.md
index 275628a..66af4fa 100644
--- a/README.md
+++ b/README.md
@@ -214,7 +214,8 @@ Requirements: Docker, nodejs (to build the frontend), and optionally [`uv`](http
> **Experimental entities indexing**: generating embeddings for entities can take a lot of time, so we recommend running the indexing script on a machine with a GPU (it does not need to be a powerful one; check out the [fastembed GPU docs](https://qdrant.github.io/fastembed/examples/FastEmbed_GPU/) to install the GPU drivers and dependencies).
>
> ```sh
-> uv run --with gpu packages/expasy-agent/src/expasy_agent/indexing/index_entities.py
+> cd packages/expasy-agent
+> nohup uv run --extra gpu src/expasy_agent/indexing/index_entities.py &
> ```
>
> Then move the entities collection containing the embeddings to `data/qdrant/collections/entities` before starting the stack.
diff --git a/chat-with-context/index.html b/chat-with-context/index.html
index d9dc4c9..2c59cb0 100644
--- a/chat-with-context/index.html
+++ b/chat-with-context/index.html
@@ -90,7 +90,7 @@
diff --git a/chat-with-context/src/providers.ts b/chat-with-context/src/providers.ts
index 92bfb9d..e8c8135 100644
--- a/chat-with-context/src/providers.ts
+++ b/chat-with-context/src/providers.ts
@@ -116,37 +116,33 @@ async function processLangGraphChunk(state: ChatState, chunk: any) {
throw new Error(`An error occurred. Please try again. ${chunk.data.error}: ${chunk.data.message}`);
}
// console.log(chunk);
- // Handle updates to the state
+  // Handle updates to the state (usually nodes that retrieve data without querying the LLM)
if (chunk.event === "updates") {
// console.log("UPDATES", chunk);
for (const nodeId of Object.keys(chunk.data)) {
const nodeData = chunk.data[nodeId];
- // Retrieved docs sent
if (nodeData.retrieved_docs) {
+ // Retrieved docs sent
state.appendStepToLastMsg(
`📚️ Using ${nodeData.retrieved_docs.length} documents`,
nodeId,
nodeData.retrieved_docs,
);
- }
- // Handle entities extracted from the user input
- if (nodeData.extracted_entities) {
+ } else if (nodeData.extracted_entities) {
+ // Handle entities extracted from the user input
state.appendStepToLastMsg(
`⚗️ Extracted ${nodeData.extracted_entities.length} potential entities`,
nodeId,
[],
nodeData.extracted_entities.map((entity: any) =>
- `\n\nEntities found in the user question for "${entity.term.join(" ")}":\n\n` +
+ `\n\nEntities found in the user question for "${entity.text}":\n\n` +
entity.matchs.map((match: any) =>
`- ${match.payload.label} with IRI <${match.payload.uri}> in endpoint ${match.payload.endpoint_url}\n\n`
).join('')
).join(''),
);
- }
- // Handle things extracted from output (SPARQL queries here)
- if (nodeData.structured_output) {
- // const extractedEntities = nodeData.extracted_entities;
- // Retrieved extracted_entities sent
+ } else if (nodeData.structured_output) {
+ // Handle things extracted from output (SPARQL queries here)
if (nodeData.structured_output.sparql_query) {
state.lastMsg().setLinks([
{
@@ -158,9 +154,8 @@ async function processLangGraphChunk(state: ChatState, chunk: any) {
},
]);
}
- }
- // Handle post-generation validation
- if (nodeData.validation) {
+ } else if (nodeData.validation) {
+ // Handle post-generation validation
for (const validationStep of nodeData.validation) {
// Handle messages related to tools (includes post generation validation)
state.appendStepToLastMsg(validationStep.label, nodeId, [], validationStep.details);
@@ -172,6 +167,9 @@ async function processLangGraphChunk(state: ChatState, chunk: any) {
state.lastMsg().setContent(validationStep.fixed_message);
}
}
+ } else {
+ // Handle other updates
+      state.appendStepToLastMsg(`💭 ${nodeId.replaceAll("_", " ")}`, nodeId);
}
}
}
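For context on the `updates` events handled above: streaming a compiled LangGraph graph with `stream_mode="updates"` yields, at each step, a dict mapping each node id to the partial state update that node returned, which is what the frontend reads as `chunk.data[nodeId]`. A minimal, self-contained sketch (illustrative only, not code from this repo):

```python
# Minimal sketch of how "updates" chunks are produced server-side with LangGraph.
# Assumption: this mirrors the agent's behavior; it is not code from this repo.
import asyncio
from typing import TypedDict

from langgraph.graph import StateGraph, START, END


class State(TypedDict):
    extracted_entities: list


def extract_entities(state: State) -> dict:
    # A node that retrieves data without querying the LLM
    return {"extracted_entities": [{"text": "Q9Y2T1", "matchs": []}]}


builder = StateGraph(State)
builder.add_node("extract_entities", extract_entities)
builder.add_edge(START, "extract_entities")
builder.add_edge("extract_entities", END)
graph = builder.compile()


async def main():
    async for chunk in graph.astream({"extracted_entities": []}, stream_mode="updates"):
        # e.g. {"extract_entities": {"extracted_entities": [...]}}
        print(chunk)


asyncio.run(main())
```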
diff --git a/chat-with-context/vite.config.ts b/chat-with-context/vite.config.ts
index bc4ede2..d34609b 100644
--- a/chat-with-context/vite.config.ts
+++ b/chat-with-context/vite.config.ts
@@ -13,7 +13,7 @@ export default defineConfig({
port: 3000,
},
envDir: "../",
- envPrefix: "EXPASY_",
+ envPrefix: "CHAT_",
build: {
outDir: "dist",
target: ["esnext"],
diff --git a/compose.dev.yml b/compose.dev.yml
index d2aafd9..631e48d 100644
--- a/compose.dev.yml
+++ b/compose.dev.yml
@@ -24,8 +24,8 @@ services:
# - DEFAULT_LLM_MODEL=deepseek/deepseek-chat
# - DEFAULT_LLM_MODEL=azure/Mistral-Large-2411
# - DEFAULT_LLM_MODEL=azure/DeepSeek-R1
- - DEFAULT_LLM_MODEL=openai/o3-mini
- # - DEFAULT_LLM_MODEL=openai/gpt-4o-mini
+ # - DEFAULT_LLM_MODEL=openai/o3-mini
+ - DEFAULT_LLM_MODEL=openai/gpt-4o-mini
entrypoint: ["uv", "run", "uvicorn", "src.expasy_agent.main:app", "--host", "0.0.0.0", "--port", "80", "--reload"]
# langgraph-redis:
diff --git a/deploy.sh b/deploy.sh
index 3a48d37..91441ca 100755
--- a/deploy.sh
+++ b/deploy.sh
@@ -28,6 +28,11 @@ elif [ "$1" = "index" ]; then
echo "🔎 Indexing endpoints in the vector database"
ssh_cmd "podman-compose exec api python src/expasy_agent/indexing/index_endpoints.py"
+elif [ "$1" = "import-entities-index" ]; then
+  echo "📥 Importing the entities collection from adsicore (which has a GPU to generate the embeddings)"
+  scp -r adsicore:/mnt/scratch/sparql-llm/packages/expasy-agent/src/expasy_agent/data/qdrant/collections/entities ./data/qdrant/collections/entities
+  # scp -r adsicore:/mnt/scratch/sparql-llm/packages/expasy-agent/src/expasy_agent/data/qdrant/collections/entities expasychatpodman:/var/containers/podman/sparql-llm/data/qdrant/collections/entities
+
elif [ "$1" = "likes" ]; then
mkdir -p data/prod
scp expasychat:/var/containers/podman/sparql-llm/data/logs/likes.jsonl ./data/prod/
diff --git a/notebooks/test_expasy_chat.ipynb b/notebooks/test_expasy_chat.ipynb
index 9abc36f..9992002 100644
--- a/notebooks/test_expasy_chat.ipynb
+++ b/notebooks/test_expasy_chat.ipynb
@@ -989,7 +989,7 @@
"} LIMIT 20\"\"\",\n",
" },\n",
" {\n",
- " \"question\": \"What are the orthologs in rat for protein Q9Y2T1 ? Return ?ratProtein ?ratUniProtXref\",\n",
+ " \"question\": \"What are the orthologs in rat for protein Q9Y2T1? Return ?ratProtein ?ratUniProtXref\",\n",
" \"endpoint\": \"https://sparql.omabrowser.org/sparql/\",\n",
" \"query\": \"\"\"PREFIX up: \n",
"PREFIX rdfs: \n",
@@ -1397,7 +1397,7 @@
"\n",
"list_of_approaches = {\n",
" \"No RAG\": answer_no_rag,\n",
- " # \"RAG without validation\": answer_rag_without_validation,\n",
+ " \"RAG without validation\": answer_rag_without_validation,\n",
" \"RAG with validation\": answer_rag_with_validation,\n",
"}\n",
"\n",
diff --git a/notebooks/test_expasy_chat_with_training_data.ipynb b/notebooks/test_expasy_chat_with_training_data.ipynb
index 4562e38..d906a32 100644
--- a/notebooks/test_expasy_chat_with_training_data.ipynb
+++ b/notebooks/test_expasy_chat_with_training_data.ipynb
@@ -44,7 +44,7 @@
"from sparql_llm.utils import extract_sparql_queries\n",
"\n",
"load_dotenv()\n",
- "expasy_api_key = os.getenv(\"EXPASY_API_KEY\")\n",
+ "expasy_api_key = os.getenv(\"CHAT_API_KEY\")\n",
"\n",
"example_queries = [\n",
" # {\n",
diff --git a/packages/expasy-agent/pyproject.toml b/packages/expasy-agent/pyproject.toml
index 44d1cf3..2e7aa86 100644
--- a/packages/expasy-agent/pyproject.toml
+++ b/packages/expasy-agent/pyproject.toml
@@ -43,6 +43,8 @@ dependencies = [
"litellm",
"jinja2",
"tqdm",
+ "scispacy",
+ "en_core_sci_md @ https://s3-us-west-2.amazonaws.com/ai2-s2-scispacy/releases/v0.5.4/en_core_sci_md-0.5.4.tar.gz",
]
[project.optional-dependencies]
diff --git a/packages/expasy-agent/src/expasy_agent/config.py b/packages/expasy-agent/src/expasy_agent/config.py
index 5c5772f..aae13ef 100644
--- a/packages/expasy-agent/src/expasy_agent/config.py
+++ b/packages/expasy-agent/src/expasy_agent/config.py
@@ -174,7 +174,7 @@ class Configuration:
default=settings.default_temperature,
metadata={
"description": "The temperature of the language model."
- "Should be between 0.0 and 1.0. Higher values make the model more creative but less deterministic."
+ "Should be between 0.0 and 2.0. Higher values make the model more creative but less deterministic."
},
)
max_tokens: Annotated[int, {"__template_metadata__": {"kind": "llm"}}] = field(
diff --git a/packages/expasy-agent/src/expasy_agent/graph.py b/packages/expasy-agent/src/expasy_agent/graph.py
index 7d445a1..f0a95fc 100644
--- a/packages/expasy-agent/src/expasy_agent/graph.py
+++ b/packages/expasy-agent/src/expasy_agent/graph.py
@@ -61,12 +61,12 @@ def route_model_output(state: State, config: RunnableConfig) -> Literal["__end__
# Add edges
builder.add_edge("__start__", "retrieve")
builder.add_edge("retrieve", "call_model")
-builder.add_edge("extract_entities", "call_model")
builder.add_edge("call_model", "validate_output")
# Entity extraction node
builder.add_node(extract_entities)
builder.add_edge("__start__", "extract_entities")
+builder.add_edge("extract_entities", "call_model")
# Add a conditional edge to determine the next step after `validate_output`
builder.add_conditional_edges(
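The wiring above fans out from `__start__` to both `retrieve` and `extract_entities`, so they run in parallel, and `call_model` runs once afterwards with their merged updates. A minimal sketch of this fan-out/join pattern, with placeholder state and nodes (not the repo's actual ones):

```python
# Sketch of the fan-out/join used in graph.py: two nodes start in parallel
# from START and both feed a third node, which then runs once.
from operator import add
from typing import Annotated, TypedDict

from langgraph.graph import StateGraph, START, END


class State(TypedDict):
    # A reducer is required so the two parallel branches can both append
    # to the same key without a concurrent-update error
    steps: Annotated[list[str], add]


def retrieve(state: State) -> dict:
    return {"steps": ["retrieve"]}


def extract_entities(state: State) -> dict:
    return {"steps": ["extract_entities"]}


def call_model(state: State) -> dict:
    return {"steps": ["call_model"]}


builder = StateGraph(State)
builder.add_node("retrieve", retrieve)
builder.add_node("extract_entities", extract_entities)
builder.add_node("call_model", call_model)
builder.add_edge(START, "retrieve")
builder.add_edge(START, "extract_entities")
builder.add_edge("retrieve", "call_model")
builder.add_edge("extract_entities", "call_model")
builder.add_edge("call_model", END)
graph = builder.compile()

print(graph.invoke({"steps": []})["steps"])
# ['retrieve', 'extract_entities', 'call_model'] (first two in either order)
```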
diff --git a/packages/expasy-agent/src/expasy_agent/indexing/index_entities.py b/packages/expasy-agent/src/expasy_agent/indexing/index_entities.py
index 9241a26..c6600b3 100644
--- a/packages/expasy-agent/src/expasy_agent/indexing/index_entities.py
+++ b/packages/expasy-agent/src/expasy_agent/indexing/index_entities.py
@@ -1,17 +1,16 @@
-import csv
import os
import time
-from ast import literal_eval
from langchain_core.documents import Document
from qdrant_client import QdrantClient, models
-from tqdm import tqdm
-
-from expasy_agent.config import get_embedding_model, get_vectordb, settings
from sparql_llm.utils import query_sparql
-# Run the script to extract entities from endpoints and generate embeddings for them (long):
-# uv run python src/sparql_llm/embed_entities.py
+from expasy_agent.config import get_embedding_model, settings
+
+# NOTE: Run the script to extract entities from endpoints and generate embeddings for them (long):
+# ssh adsicore
+# cd /mnt/scratch/sparql-llm/packages/expasy-agent
+# nohup uv run --extra gpu src/expasy_agent/indexing/index_entities.py &
entities_embeddings_dir = os.path.join("data", "embeddings")
@@ -43,6 +42,7 @@ def retrieve_index_data(entity: dict, docs: list[Document], pagination: (int, in
def generate_embeddings_for_entities():
start_time = time.time()
embedding_model = get_embedding_model(gpu=True)
+ print("Start indexing entities")
entities_list = {
"genex:AnatomicalEntity": {
@@ -258,7 +258,7 @@ def generate_embeddings_for_entities():
print(f"Done querying in {time.time() - start_time} seconds, generating embeddings for {len(docs)} entities")
- # To test with a smaller number of entities
+ # Uncomment the next line to test with a smaller number of entities
# docs = docs[:10]
embeddings = embedding_model.embed([q.page_content for q in docs])
@@ -280,101 +280,7 @@ def generate_embeddings_for_entities():
),
)
- # NOTE: saving to a CSV makes it too slow to then read and upload to the vectordb, so we now directly upload to the vectordb
- # with open(entities_embeddings_filepath, mode="w", newline="") as file:
- # writer = csv.writer(file)
- # header = ["label", "uri", "endpoint_url", "entity_type", "embedding"]
- # writer.writerow(header)
- # for doc, embedding in zip(docs, embeddings):
- # row = [
- # doc.metadata["label"],
- # doc.metadata["uri"],
- # doc.metadata["endpoint_url"],
- # doc.metadata["entity_type"],
- # embedding.tolist(),
- # ]
- # writer.writerow(row)
-
- print(f"Done generating embeddings for {len(docs)} entities in {time.time() - start_time} seconds")
-
- # vectordb = get_vectordb()
- # vectordb.upsert(
- # collection_name="entities",
- # points=models.Batch(
- # ids=list(range(1, len(docs) + 1)),
- # vectors=embeddings,
- # payloads=[doc.metadata for doc in docs],
- # ),
- # # wait=False, # Waiting for indexing to finish or not
- # )
-
-
-# NOTE: not used anymore, we now directly load to a local vectordb then move the loaded entities collection to the main vectordb
-def load_entities_embeddings_to_vectordb():
- start_time = time.time()
- vectordb = get_vectordb()
- docs = []
- embeddings = []
- batch_size = 100000
- embeddings_count = 0
-
- print(f"Reading entities embeddings from the .csv file at {entities_embeddings_filepath}")
- with open(entities_embeddings_filepath) as file:
- reader = csv.DictReader(file)
- for row in tqdm(reader, desc="Extracting embeddings from CSV file"):
- docs.append(
- Document(
- page_content=row["label"],
- metadata={
- "label": row["label"],
- "uri": row["uri"],
- "endpoint_url": row["endpoint_url"],
- "entity_type": row["entity_type"],
- },
- )
- )
- embeddings.append(literal_eval(row["embedding"]))
- embeddings_count += 1
-
- if len(docs) == batch_size:
- vectordb.upsert(
- collection_name=settings.entities_collection_name,
- points=models.Batch(
- ids=list(range(embeddings_count - batch_size + 1, embeddings_count + 1)),
- vectors=embeddings,
- payloads=[doc.metadata for doc in docs],
- ),
- )
- docs = []
- embeddings = []
-
-
- print(
- f"Found embeddings for {len(docs)} entities in {time.time() - start_time} seconds. Now adding them to the vectordb"
- )
-
- # for i in range(0, len(docs), batch_size):
- # batch_docs = docs[i:i + batch_size]
- # batch_embeddings = embeddings[i:i + batch_size]
- # vectordb.upsert(
- # collection_name=settings.entities_collection_name,
- # points=models.Batch(
- # ids=list(range(i + 1, i + len(batch_docs) + 1)),
- # vectors=batch_embeddings,
- # payloads=[doc.metadata for doc in batch_docs],
- # ),
- # )
-
- # vectordb.upsert(
- # collection_name=settings.entities_collection_name,
- # points=models.Batch(
- # ids=list(range(1, len(docs) + 1)),
- # vectors=embeddings,
- # payloads=[doc.metadata for doc in docs],
- # ),
- # )
-
- print(f"Done uploading embeddings for {len(docs)} entities in the vectordb in {time.time() - start_time} seconds")
+ print(f"Done generating embeddings for {len(docs)} entities in {(time.time() - start_time) / 60:.2f} minutes")
if __name__ == "__main__":
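Before moving the generated collection under `data/qdrant/collections/entities`, a quick sanity check could look like this sketch, assuming the script wrote to local on-disk Qdrant storage under `data/qdrant` and the default `entities` collection name (a hypothetical check, not part of the script):

```python
# Hypothetical sanity check: count the points in the freshly generated
# "entities" collection before moving it to the main vectordb.
# Assumption: the script used Qdrant local mode with storage under data/qdrant.
from qdrant_client import QdrantClient

client = QdrantClient(path="data/qdrant")
print(client.count(collection_name="entities", exact=True).count)
```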
diff --git a/packages/expasy-agent/src/expasy_agent/nodes/extraction.py b/packages/expasy-agent/src/expasy_agent/nodes/extraction.py
index c2f38cf..acde877 100644
--- a/packages/expasy-agent/src/expasy_agent/nodes/extraction.py
+++ b/packages/expasy-agent/src/expasy_agent/nodes/extraction.py
@@ -5,6 +5,7 @@
from langchain_core.runnables import RunnableConfig
from qdrant_client import QdrantClient
from sparql_llm.utils import get_message_text
+import spacy
from expasy_agent.config import Configuration, get_embedding_model, settings
from expasy_agent.state import State
@@ -15,7 +16,7 @@ def format_extracted_entities(entities_list: list[Any]) -> str:
return "No entities found in the user question that matches entities in the endpoints. "
prompt = ""
for entity in entities_list:
- prompt += f'\n\nEntities found in the user question for "{" ".join(entity["term"])}":\n\n'
+    prompt += f'\n\nEntities found in the user question for "{entity["text"]}":\n\n'
for match in entity["matchs"]:
prompt += f"- {match.payload['label']} with IRI <{match.payload['uri']}> in endpoint {match.payload['endpoint_url']}\n\n"
# prompt += "\nIf the user is asking for a named entity, and this entity cannot be found in the endpoint, warn them about the fact we could not find it in the endpoints.\n\n"
@@ -37,49 +38,94 @@ async def extract_entities(state: State, config: RunnableConfig) -> dict[str, li
dict[str, list[Document]]: A dictionary with a single key "retrieved_docs"
containing a list of retrieved Document objects.
"""
- human_input = get_message_text(state.messages[-1])
+ user_input = get_message_text(state.messages[-1])
vectordb = QdrantClient(url=settings.vectordb_url, prefer_grpc=True)
embedding_model = get_embedding_model()
-
score_threshold = 0.8
- sentence_splitted = re.findall(r"\b\w+\b", human_input)
- window_size = len(sentence_splitted)
entities_list = []
- while window_size > 0 and window_size <= len(sentence_splitted):
- window_start = 0
- window_end = window_start + window_size
- while window_end <= len(sentence_splitted):
- term = sentence_splitted[window_start:window_end]
- # print("term", term)
- term_embeddings = next(iter(embedding_model.embed([" ".join(term)])))
- query_hits = vectordb.search(
- collection_name=settings.entities_collection_name,
- query_vector=term_embeddings,
- limit=10,
- )
- matchs = []
- for query_hit in query_hits:
- if query_hit.score > score_threshold:
- matchs.append(query_hit)
- if len(matchs) > 0:
- entities_list.append(
- {
- "matchs": matchs,
- "term": term,
- "start_index": window_start,
- "end_index": window_end,
- }
- )
- # term_search = reduce(lambda x, y: "{} {}".format(x, y), sentence_splitted[window_start:window_end])
- # resultSearch = index.search(term_search)
- # if resultSearch is not None and len(resultSearch) > 0:
- # selected_hit = resultSearch[0]
- # if selected_hit['score'] > MAX_SCORE_PARSER_TRIPLES:
- # selected_hit = None
- # if selected_hit is not None and selected_hit not in matchs:
- # matchs.append(selected_hit)
- window_start += 1
- window_end = window_start + window_size
- window_size -= 1
+
+ # Extract potential entities with scispaCy https://allenai.github.io/scispacy/
+    # NOTE: a more expensive alternative would be to use a BioBERT NER model.
+    # Loading the spaCy model at each call is slow; it could be cached at module level.
+ nlp = spacy.load("en_core_sci_md")
+ potential_entities = nlp(user_input).ents
+ print(potential_entities)
+
+ # Search for matches in the indexed entities
+ entities_embeddings = embedding_model.embed([entity.text for entity in potential_entities])
+ for i, entity_embeddings in enumerate(entities_embeddings):
+ query_hits = vectordb.search(
+ collection_name=settings.entities_collection_name,
+ query_vector=entity_embeddings,
+ limit=10,
+ )
+ matchs = []
+ for query_hit in query_hits:
+ if query_hit.score > score_threshold:
+ matchs.append(query_hit)
+ entities_list.append(
+ {
+ "matchs": matchs,
+ "text": potential_entities[i].text,
+ # "start_index": None,
+ # "end_index": None,
+ }
+ )
return {"extracted_entities": entities_list}
+
+ ## Using BioBERT
+ # from transformers import AutoTokenizer, AutoModelForTokenClassification
+ # from transformers import pipeline
+
+ # # Load BioBERT model and tokenizer
+ # model_name = "dmis-lab/biobert-v1.1"
+ # tokenizer = AutoTokenizer.from_pretrained(model_name)
+ # model = AutoModelForTokenClassification.from_pretrained(model_name)
+
+ # # Create NER pipeline
+ # ner_pipeline = pipeline("ner", model=model, tokenizer=tokenizer, aggregation_strategy="simple")
+
+ # # Extract entities
+ # results = ner_pipeline(user_input)
+ # for entity in results:
+ # print(f"{entity['word']} ({entity['entity_group']})")
+
+ ## Old way
+ # sentence_splitted = re.findall(r"\b\w+\b", user_input)
+ # window_size = len(sentence_splitted)
+ # while window_size > 0 and window_size <= len(sentence_splitted):
+ # window_start = 0
+ # window_end = window_start + window_size
+ # while window_end <= len(sentence_splitted):
+ # term = sentence_splitted[window_start:window_end]
+ # # print("term", term)
+ # term_embeddings = next(iter(embedding_model.embed([" ".join(term)])))
+ # query_hits = vectordb.search(
+ # collection_name=settings.entities_collection_name,
+ # query_vector=term_embeddings,
+ # limit=10,
+ # )
+ # matchs = []
+ # for query_hit in query_hits:
+ # if query_hit.score > score_threshold:
+ # matchs.append(query_hit)
+ # if len(matchs) > 0:
+ # entities_list.append(
+ # {
+ # "matchs": matchs,
+ # "term": term,
+ # "start_index": window_start,
+ # "end_index": window_end,
+ # }
+ # )
+ # # term_search = reduce(lambda x, y: "{} {}".format(x, y), sentence_splitted[window_start:window_end])
+ # # resultSearch = index.search(term_search)
+ # # if resultSearch is not None and len(resultSearch) > 0:
+ # # selected_hit = resultSearch[0]
+ # # if selected_hit['score'] > MAX_SCORE_PARSER_TRIPLES:
+ # # selected_hit = None
+ # # if selected_hit is not None and selected_hit not in matchs:
+ # # matchs.append(selected_hit)
+ # window_start += 1
+ # window_end = window_start + window_size
+ # window_size -= 1
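To see what the new scispaCy step yields on its own, here is a minimal sketch (assuming `scispacy` and the `en_core_sci_md` model added in pyproject.toml above are installed); the extracted `.ents` spans are what gets embedded and matched against the `entities` collection:

```python
# Minimal sketch of the scispaCy extraction step in isolation: the .ents spans
# are the potential entities that get embedded and searched in the vectordb.
import spacy

nlp = spacy.load("en_core_sci_md")
doc = nlp("What are the orthologs in rat for protein Q9Y2T1?")
for ent in doc.ents:
    print(ent.text, ent.start_char, ent.end_char)
```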
diff --git a/packages/expasy-agent/src/expasy_agent/nodes/tools.py b/packages/expasy-agent/src/expasy_agent/nodes/tools.py
index 7e1134b..f46a589 100644
--- a/packages/expasy-agent/src/expasy_agent/nodes/tools.py
+++ b/packages/expasy-agent/src/expasy_agent/nodes/tools.py
@@ -20,7 +20,7 @@ def multiply(a: int, b: int) -> int:
# TODO: Extract potential entities from the user question (experimental)
# entities_list = extract_entities(question)
# for entity in entities_list:
-# prompt += f'\n\nEntities found in the user question for "{" ".join(entity["term"])}":\n\n'
+# prompt += f'\n\nEntities found in the user question for "{entity["text"]}":\n\n'
# for match in entity["matchs"]:
# prompt += f"- {match.payload['label']} with IRI <{match.payload['uri']}> in endpoint {match.payload['endpoint_url']}\n\n"
# if len(entities_list) == 0:
diff --git a/packages/expasy-agent/src/expasy_agent/prompts.py b/packages/expasy-agent/src/expasy_agent/prompts.py
index 8e6b968..80c4605 100644
--- a/packages/expasy-agent/src/expasy_agent/prompts.py
+++ b/packages/expasy-agent/src/expasy_agent/prompts.py
@@ -16,6 +16,8 @@
{extracted_entities}
"""
+
+
# try to make it as efficient as possible to avoid timeout due to how large the datasets are, make sure the query written is valid SPARQL,
# If the answer to the question is in the provided context, do not provide a query, just provide the answer, unless explicitly asked.