diff --git a/README.md b/README.md
index 275628a..66af4fa 100644
--- a/README.md
+++ b/README.md
@@ -214,7 +214,8 @@ Requirements: Docker, nodejs (to build the frontend), and optionally [`uv`](http
 > **Experimental entities indexing**: generating embeddings for entities can take a lot of time, so we recommend running the script on a machine with a GPU (it does not need to be a powerful one, but it should have a GPU; check out the [fastembed GPU docs](https://qdrant.github.io/fastembed/examples/FastEmbed_GPU/) to install the GPU drivers and dependencies)
 >
 > ```sh
-> uv run --with gpu packages/expasy-agent/src/expasy_agent/indexing/index_entities.py
+> cd packages/expasy-agent
+> nohup uv run --extra gpu src/expasy_agent/indexing/index_entities.py &
 > ```
 >
 > Then move the entities collection containing the embeddings into `data/qdrant/collections/entities` before starting the stack
diff --git a/chat-with-context/index.html b/chat-with-context/index.html
index d9dc4c9..2c59cb0 100644
--- a/chat-with-context/index.html
+++ b/chat-with-context/index.html
@@ -90,7 +90,7 @@

diff --git a/chat-with-context/src/providers.ts b/chat-with-context/src/providers.ts
index 92bfb9d..e8c8135 100644
--- a/chat-with-context/src/providers.ts
+++ b/chat-with-context/src/providers.ts
@@ -116,37 +116,33 @@ async function processLangGraphChunk(state: ChatState, chunk: any) {
     throw new Error(`An error occurred. Please try again. ${chunk.data.error}: ${chunk.data.message}`);
   }
   // console.log(chunk);
-  // Handle updates to the state
+  // Handle updates to the state (usually nodes that retrieve things without querying the LLM)
   if (chunk.event === "updates") {
     // console.log("UPDATES", chunk);
     for (const nodeId of Object.keys(chunk.data)) {
       const nodeData = chunk.data[nodeId];
-      // Retrieved docs sent
       if (nodeData.retrieved_docs) {
+        // Retrieved docs sent
         state.appendStepToLastMsg(
           `📚️ Using ${nodeData.retrieved_docs.length} documents`,
           nodeId,
           nodeData.retrieved_docs,
         );
-      }
-      // Handle entities extracted from the user input
-      if (nodeData.extracted_entities) {
+      } else if (nodeData.extracted_entities) {
+        // Handle entities extracted from the user input
        state.appendStepToLastMsg(
          `⚗️ Extracted ${nodeData.extracted_entities.length} potential entities`,
          nodeId,
          [],
          nodeData.extracted_entities.map((entity: any) =>
-            `\n\nEntities found in the user question for "${entity.term.join(" ")}":\n\n` +
+            `\n\nEntities found in the user question for "${entity.text}":\n\n` +
            entity.matchs.map((match: any) =>
              `- ${match.payload.label} with IRI <${match.payload.uri}> in endpoint ${match.payload.endpoint_url}\n\n`
            ).join('')
          ).join(''),
        );
-      }
-      // Handle things extracted from output (SPARQL queries here)
-      if (nodeData.structured_output) {
-        // const extractedEntities = nodeData.extracted_entities;
-        // Retrieved extracted_entities sent
+      } else if (nodeData.structured_output) {
+        // Handle things extracted from the output (SPARQL queries here)
         if (nodeData.structured_output.sparql_query) {
           state.lastMsg().setLinks([
             {
@@ -158,9 +154,8 @@ async function processLangGraphChunk(state: ChatState, chunk: any) {
             },
           ]);
         }
-      }
-      // Handle post-generation validation
-      if (nodeData.validation) {
+      } else if (nodeData.validation) {
+        // Handle post-generation validation
         for (const validationStep of nodeData.validation) {
           // Handle messages related to tools (includes post-generation validation)
           state.appendStepToLastMsg(validationStep.label, nodeId, [], validationStep.details);
@@ -172,6 +167,9 @@ async function processLangGraphChunk(state: ChatState, chunk: any) {
             state.lastMsg().setContent(validationStep.fixed_message);
           }
         }
+      } else {
+        // Handle other updates
+        state.appendStepToLastMsg(`💭 ${nodeId.replace("_", " ")}`, nodeId);
       }
     }
   }
diff --git a/chat-with-context/vite.config.ts b/chat-with-context/vite.config.ts
index bc4ede2..d34609b 100644
--- a/chat-with-context/vite.config.ts
+++ b/chat-with-context/vite.config.ts
@@ -13,7 +13,7 @@ export default defineConfig({
     port: 3000,
   },
   envDir: "../",
-  envPrefix: "EXPASY_",
+  envPrefix: "CHAT_",
   build: {
     outDir: "dist",
     target: ["esnext"],
diff --git a/compose.dev.yml b/compose.dev.yml
index d2aafd9..631e48d 100644
--- a/compose.dev.yml
+++ b/compose.dev.yml
@@ -24,8 +24,8 @@ services:
       # - DEFAULT_LLM_MODEL=deepseek/deepseek-chat
      # - DEFAULT_LLM_MODEL=azure/Mistral-Large-2411
      # - DEFAULT_LLM_MODEL=azure/DeepSeek-R1
-      - DEFAULT_LLM_MODEL=openai/o3-mini
-      # - DEFAULT_LLM_MODEL=openai/gpt-4o-mini
+      # - DEFAULT_LLM_MODEL=openai/o3-mini
+      - DEFAULT_LLM_MODEL=openai/gpt-4o-mini
     entrypoint: ["uv", "run", "uvicorn", "src.expasy_agent.main:app", "--host", "0.0.0.0", "--port", "80", "--reload"]

     # langgraph-redis:
diff --git a/deploy.sh b/deploy.sh
index 3a48d37..91441ca 100755
--- a/deploy.sh
+++ b/deploy.sh
@@ -28,6 +28,11 @@ elif [ "$1" = "index" ]; then
     echo "🔎 Indexing endpoints in the vector database"
     ssh_cmd "podman-compose exec api python src/expasy_agent/indexing/index_endpoints.py"

+elif [ "$1" = "import-entities-index" ]; then
+    echo "Importing the entities index from adsicore, which has a GPU to generate it"
+    scp -r adsicore:/mnt/scratch/sparql-llm/packages/expasy-agent/src/expasy_agent/data/qdrant/collections/entities ./data/qdrant/collections/entities
+    # scp -r adsicore:/mnt/scratch/sparql-llm/packages/expasy-agent/src/expasy_agent/data/qdrant/collections/entities expasychatpodman:/var/containers/podman/sparql-llm/data/qdrant/collections/entities
+
 elif [ "$1" = "likes" ]; then
     mkdir -p data/prod
     scp expasychat:/var/containers/podman/sparql-llm/data/logs/likes.jsonl ./data/prod/
diff --git a/notebooks/test_expasy_chat.ipynb b/notebooks/test_expasy_chat.ipynb
index 9abc36f..9992002 100644
--- a/notebooks/test_expasy_chat.ipynb
+++ b/notebooks/test_expasy_chat.ipynb
@@ -989,7 +989,7 @@
 "} LIMIT 20\"\"\",\n",
 "    },\n",
 "    {\n",
-"        \"question\": \"What are the orthologs in rat for protein Q9Y2T1 ? Return ?ratProtein ?ratUniProtXref\",\n",
+"        \"question\": \"What are the orthologs in rat for protein Q9Y2T1? Return ?ratProtein ?ratUniProtXref\",\n",
 "        \"endpoint\": \"https://sparql.omabrowser.org/sparql/\",\n",
 "        \"query\": \"\"\"PREFIX up: \n",
 "PREFIX rdfs: \n",
@@ -1397,7 +1397,7 @@
 "\n",
 "list_of_approaches = {\n",
 "    \"No RAG\": answer_no_rag,\n",
-"    # \"RAG without validation\": answer_rag_without_validation,\n",
+"    \"RAG without validation\": answer_rag_without_validation,\n",
 "    \"RAG with validation\": answer_rag_with_validation,\n",
 "}\n",
 "\n",
diff --git a/notebooks/test_expasy_chat_with_training_data.ipynb b/notebooks/test_expasy_chat_with_training_data.ipynb
index 4562e38..d906a32 100644
--- a/notebooks/test_expasy_chat_with_training_data.ipynb
+++ b/notebooks/test_expasy_chat_with_training_data.ipynb
@@ -44,7 +44,7 @@
 "from sparql_llm.utils import extract_sparql_queries\n",
 "\n",
 "load_dotenv()\n",
-"expasy_api_key = os.getenv(\"EXPASY_API_KEY\")\n",
+"expasy_api_key = os.getenv(\"CHAT_API_KEY\")\n",
 "\n",
 "example_queries = [\n",
 "    # {\n",
diff --git a/packages/expasy-agent/pyproject.toml b/packages/expasy-agent/pyproject.toml
index 44d1cf3..2e7aa86 100644
--- a/packages/expasy-agent/pyproject.toml
+++ b/packages/expasy-agent/pyproject.toml
@@ -43,6 +43,8 @@ dependencies = [
     "litellm",
     "jinja2",
     "tqdm",
+    "scispacy",
+    "en_core_sci_md @ https://s3-us-west-2.amazonaws.com/ai2-s2-scispacy/releases/v0.5.4/en_core_sci_md-0.5.4.tar.gz",
 ]

 [project.optional-dependencies]
diff --git a/packages/expasy-agent/src/expasy_agent/config.py b/packages/expasy-agent/src/expasy_agent/config.py
index 5c5772f..aae13ef 100644
--- a/packages/expasy-agent/src/expasy_agent/config.py
+++ b/packages/expasy-agent/src/expasy_agent/config.py
@@ -174,7 +174,7 @@ class Configuration:
         default=settings.default_temperature,
         metadata={
             "description": "The temperature of the language model."
-            "Should be between 0.0 and 1.0. Higher values make the model more creative but less deterministic."
+            " Should be between 0.0 and 2.0. Higher values make the model more creative but less deterministic."
         },
     )
     max_tokens: Annotated[int, {"__template_metadata__": {"kind": "llm"}}] = field(
diff --git a/packages/expasy-agent/src/expasy_agent/graph.py b/packages/expasy-agent/src/expasy_agent/graph.py
index 7d445a1..f0a95fc 100644
--- a/packages/expasy-agent/src/expasy_agent/graph.py
+++ b/packages/expasy-agent/src/expasy_agent/graph.py
@@ -61,12 +61,12 @@ def route_model_output(state: State, config: RunnableConfig) -> Literal["__end__
 # Add edges
 builder.add_edge("__start__", "retrieve")
 builder.add_edge("retrieve", "call_model")
-builder.add_edge("extract_entities", "call_model")
 builder.add_edge("call_model", "validate_output")

 # Entity extraction node
 builder.add_node(extract_entities)
 builder.add_edge("__start__", "extract_entities")
+builder.add_edge("extract_entities", "call_model")

 # Add a conditional edge to determine the next step after `validate_output`
 builder.add_conditional_edges(
diff --git a/packages/expasy-agent/src/expasy_agent/indexing/index_entities.py b/packages/expasy-agent/src/expasy_agent/indexing/index_entities.py
index 9241a26..c6600b3 100644
--- a/packages/expasy-agent/src/expasy_agent/indexing/index_entities.py
+++ b/packages/expasy-agent/src/expasy_agent/indexing/index_entities.py
@@ -1,17 +1,16 @@
-import csv
 import os
 import time
-from ast import literal_eval

 from langchain_core.documents import Document
 from qdrant_client import QdrantClient, models
-from tqdm import tqdm
-
-from expasy_agent.config import get_embedding_model, get_vectordb, settings
 from sparql_llm.utils import query_sparql

-# Run the script to extract entities from endpoints and generate embeddings for them (long):
-# uv run python src/sparql_llm/embed_entities.py
+from expasy_agent.config import get_embedding_model, settings
+
+# NOTE: Run this script to extract entities from endpoints and generate embeddings for them (takes a long time):
+# ssh adsicore
+# cd /mnt/scratch/sparql-llm/packages/expasy-agent
+# nohup uv run --extra gpu src/expasy_agent/indexing/index_entities.py &

 entities_embeddings_dir = os.path.join("data", "embeddings")

@@ -43,6 +42,7 @@ def retrieve_index_data(entity: dict, docs: list[Document], pagination: (int, in
 def generate_embeddings_for_entities():
     start_time = time.time()
     embedding_model = get_embedding_model(gpu=True)
+    print("Start indexing entities")

     entities_list = {
         "genex:AnatomicalEntity": {
@@ -258,7 +258,7 @@ def generate_embeddings_for_entities():
     print(f"Done querying in {time.time() - start_time} seconds, generating embeddings for {len(docs)} entities")

-    # To test with a smaller number of entities
+    # Uncomment the next line to test with a smaller number of entities
     # docs = docs[:10]

     embeddings = embedding_model.embed([q.page_content for q in docs])

@@ -280,101 +280,7 @@ def generate_embeddings_for_entities():
         ),
     )

-    # NOTE: saving to a CSV makes it too slow to then read and upload to the vectordb, so we now directly upload to the vectordb
-    # with open(entities_embeddings_filepath, mode="w", newline="") as file:
-    #     writer = csv.writer(file)
-    #     header = ["label", "uri", "endpoint_url", "entity_type", "embedding"]
-    #     writer.writerow(header)
-    #     for doc, embedding in zip(docs, embeddings):
-    #         row = [
-    #             doc.metadata["label"],
-    #             doc.metadata["uri"],
-    #             doc.metadata["endpoint_url"],
-    #             doc.metadata["entity_type"],
-    #             embedding.tolist(),
-    #         ]
-    #         writer.writerow(row)
-
-    print(f"Done generating embeddings for {len(docs)} entities in {time.time() - start_time} seconds")
-
-    # vectordb = get_vectordb()
-    # vectordb.upsert(
-    #     collection_name="entities",
-    #     points=models.Batch(
-    #         ids=list(range(1, len(docs) + 1)),
-    #         vectors=embeddings,
-    #         payloads=[doc.metadata for doc in docs],
-    #     ),
-    #     # wait=False, # Waiting for indexing to finish or not
-    # )
-
-
-# NOTE: not used anymore, we now directly load to a local vectordb then move the loaded entities collection to the main vectordb
-def load_entities_embeddings_to_vectordb():
-    start_time = time.time()
-    vectordb = get_vectordb()
-    docs = []
-    embeddings = []
-    batch_size = 100000
-    embeddings_count = 0
-
-    print(f"Reading entities embeddings from the .csv file at {entities_embeddings_filepath}")
-    with open(entities_embeddings_filepath) as file:
-        reader = csv.DictReader(file)
-        for row in tqdm(reader, desc="Extracting embeddings from CSV file"):
-            docs.append(
-                Document(
-                    page_content=row["label"],
-                    metadata={
-                        "label": row["label"],
-                        "uri": row["uri"],
-                        "endpoint_url": row["endpoint_url"],
-                        "entity_type": row["entity_type"],
-                    },
-                )
-            )
-            embeddings.append(literal_eval(row["embedding"]))
-            embeddings_count += 1
-
-            if len(docs) == batch_size:
-                vectordb.upsert(
-                    collection_name=settings.entities_collection_name,
-                    points=models.Batch(
-                        ids=list(range(embeddings_count - batch_size + 1, embeddings_count + 1)),
-                        vectors=embeddings,
-                        payloads=[doc.metadata for doc in docs],
-                    ),
-                )
-                docs = []
-                embeddings = []
-
-
-    print(
-        f"Found embeddings for {len(docs)} entities in {time.time() - start_time} seconds. Now adding them to the vectordb"
-    )
-
-    # for i in range(0, len(docs), batch_size):
-    #     batch_docs = docs[i:i + batch_size]
-    #     batch_embeddings = embeddings[i:i + batch_size]
-    #     vectordb.upsert(
-    #         collection_name=settings.entities_collection_name,
-    #         points=models.Batch(
-    #             ids=list(range(i + 1, i + len(batch_docs) + 1)),
-    #             vectors=batch_embeddings,
-    #             payloads=[doc.metadata for doc in batch_docs],
-    #         ),
-    #     )
-
-    # vectordb.upsert(
-    #     collection_name=settings.entities_collection_name,
-    #     points=models.Batch(
-    #         ids=list(range(1, len(docs) + 1)),
-    #         vectors=embeddings,
-    #         payloads=[doc.metadata for doc in docs],
-    #     ),
-    # )
-
-    print(f"Done uploading embeddings for {len(docs)} entities in the vectordb in {time.time() - start_time} seconds")
+    print(f"Done generating embeddings for {len(docs)} entities in {(time.time() - start_time) / 60:.2f} minutes")


 if __name__ == "__main__":
diff --git a/packages/expasy-agent/src/expasy_agent/nodes/extraction.py b/packages/expasy-agent/src/expasy_agent/nodes/extraction.py
index c2f38cf..acde877 100644
--- a/packages/expasy-agent/src/expasy_agent/nodes/extraction.py
+++ b/packages/expasy-agent/src/expasy_agent/nodes/extraction.py
@@ -5,6 +5,7 @@ from langchain_core.runnables import RunnableConfig
 from qdrant_client import QdrantClient

 from sparql_llm.utils import get_message_text
+import spacy

 from expasy_agent.config import Configuration, get_embedding_model, settings
 from expasy_agent.state import State
@@ -15,7 +16,7 @@ def format_extracted_entities(entities_list: list[Any]) -> str:
         return "No entities found in the user question that match entities in the endpoints."
     prompt = ""
     for entity in entities_list:
-        prompt += f'\n\nEntities found in the user question for "{" ".join(entity["term"])}":\n\n'
+        prompt += f'\n\nEntities found in the user question for "{entity["text"]}":\n\n'
         for match in entity["matchs"]:
             prompt += f"- {match.payload['label']} with IRI <{match.payload['uri']}> in endpoint {match.payload['endpoint_url']}\n\n"
     # prompt += "\nIf the user is asking for a named entity, and this entity cannot be found in the endpoint, warn them about the fact we could not find it in the endpoints.\n\n"
@@ -37,49 +38,95 @@ async def extract_entities(state: State, config: RunnableConfig) -> dict[str, li
         dict[str, list[Document]]: A dictionary with a single key "retrieved_docs"
             containing a list of retrieved Document objects.
     """
-    human_input = get_message_text(state.messages[-1])
+    user_input = get_message_text(state.messages[-1])
     vectordb = QdrantClient(url=settings.vectordb_url, prefer_grpc=True)
     embedding_model = get_embedding_model()
     score_threshold = 0.8
-    sentence_splitted = re.findall(r"\b\w+\b", human_input)
-    window_size = len(sentence_splitted)
     entities_list = []
-    while window_size > 0 and window_size <= len(sentence_splitted):
-        window_start = 0
-        window_end = window_start + window_size
-        while window_end <= len(sentence_splitted):
-            term = sentence_splitted[window_start:window_end]
-            # print("term", term)
-            term_embeddings = next(iter(embedding_model.embed([" ".join(term)])))
-            query_hits = vectordb.search(
-                collection_name=settings.entities_collection_name,
-                query_vector=term_embeddings,
-                limit=10,
-            )
-            matchs = []
-            for query_hit in query_hits:
-                if query_hit.score > score_threshold:
-                    matchs.append(query_hit)
-            if len(matchs) > 0:
-                entities_list.append(
-                    {
-                        "matchs": matchs,
-                        "term": term,
-                        "start_index": window_start,
-                        "end_index": window_end,
-                    }
-                )
-            # term_search = reduce(lambda x, y: "{} {}".format(x, y), sentence_splitted[window_start:window_end])
-            # resultSearch = index.search(term_search)
-            # if resultSearch is not None and len(resultSearch) > 0:
-            #     selected_hit = resultSearch[0]
-            #     if selected_hit['score'] > MAX_SCORE_PARSER_TRIPLES:
-            #         selected_hit = None
-            #     if selected_hit is not None and selected_hit not in matchs:
-            #         matchs.append(selected_hit)
-            window_start += 1
-            window_end = window_start + window_size
-        window_size -= 1
+
+    # Extract potential entities with scispaCy https://allenai.github.io/scispacy/
+    # NOTE: a more expensive alternative would be to use a BioBERT model
+    nlp = spacy.load("en_core_sci_md")
+    potential_entities = nlp(user_input).ents
+    print(potential_entities)
+
+    # Search for matches in the indexed entities
+    entities_embeddings = embedding_model.embed([entity.text for entity in potential_entities])
+    for i, entity_embeddings in enumerate(entities_embeddings):
+        query_hits = vectordb.search(
+            collection_name=settings.entities_collection_name,
+            query_vector=entity_embeddings,
+            limit=10,
+        )
+        matchs = []
+        for query_hit in query_hits:
+            if query_hit.score > score_threshold:
+                matchs.append(query_hit)
+        entities_list.append(
+            {
+                "matchs": matchs,
+                "text": potential_entities[i].text,
+                # "start_index": None,
+                # "end_index": None,
+            }
+        )

     return {"extracted_entities": entities_list}
+
+    ## Using BioBERT
+    # from transformers import AutoTokenizer, AutoModelForTokenClassification
+    # from transformers import pipeline
+
+    # # Load BioBERT model and tokenizer
+    # model_name = "dmis-lab/biobert-v1.1"
+    # tokenizer = AutoTokenizer.from_pretrained(model_name)
+    # model = AutoModelForTokenClassification.from_pretrained(model_name)
+
+    # # Create NER pipeline
+    # ner_pipeline = pipeline("ner", model=model, tokenizer=tokenizer, aggregation_strategy="simple")
+
+    # # Extract entities
+    # results = ner_pipeline(user_input)
+    # for entity in results:
+    #     print(f"{entity['word']} ({entity['entity_group']})")

+    ## Old way
+    # sentence_splitted = re.findall(r"\b\w+\b", user_input)
+    # window_size = len(sentence_splitted)
+    # while window_size > 0 and window_size <= len(sentence_splitted):
+    #     window_start = 0
+    #     window_end = window_start + window_size
+    #     while window_end <= len(sentence_splitted):
+    #         term = sentence_splitted[window_start:window_end]
+    #         # print("term", term)
+    #         term_embeddings = next(iter(embedding_model.embed([" ".join(term)])))
+    #         query_hits = vectordb.search(
+    #             collection_name=settings.entities_collection_name,
+    #             query_vector=term_embeddings,
+    #             limit=10,
+    #         )
+    #         matchs = []
+    #         for query_hit in query_hits:
+    #             if query_hit.score > score_threshold:
+    #                 matchs.append(query_hit)
+    #         if len(matchs) > 0:
+    #             entities_list.append(
+    #                 {
+    #                     "matchs": matchs,
+    #                     "term": term,
+    #                     "start_index": window_start,
+    #                     "end_index": window_end,
+    #                 }
+    #             )
+    #         # term_search = reduce(lambda x, y: "{} {}".format(x, y), sentence_splitted[window_start:window_end])
+    #         # resultSearch = index.search(term_search)
+    #         # if resultSearch is not None and len(resultSearch) > 0:
+    #         #     selected_hit = resultSearch[0]
+    #         #     if selected_hit['score'] > MAX_SCORE_PARSER_TRIPLES:
+    #         #         selected_hit = None
+    #         #     if selected_hit is not None and selected_hit not in matchs:
+    #         #         matchs.append(selected_hit)
+    #         window_start += 1
+    #         window_end = window_start + window_size
+    #     window_size -= 1
diff --git a/packages/expasy-agent/src/expasy_agent/nodes/tools.py b/packages/expasy-agent/src/expasy_agent/nodes/tools.py
index 7e1134b..f46a589 100644
--- a/packages/expasy-agent/src/expasy_agent/nodes/tools.py
+++ b/packages/expasy-agent/src/expasy_agent/nodes/tools.py
@@ -20,7 +20,7 @@ def multiply(a: int, b: int) -> int:
 # TODO: Extract potential entities from the user question (experimental)
 # entities_list = extract_entities(question)
 # for entity in entities_list:
-#     prompt += f'\n\nEntities found in the user question for "{" ".join(entity["term"])}":\n\n'
+#     prompt += f'\n\nEntities found in the user question for "{entity["text"]}":\n\n'
 #     for match in entity["matchs"]:
 #         prompt += f"- {match.payload['label']} with IRI <{match.payload['uri']}> in endpoint {match.payload['endpoint_url']}\n\n"
 # if len(entities_list) == 0:
diff --git a/packages/expasy-agent/src/expasy_agent/prompts.py b/packages/expasy-agent/src/expasy_agent/prompts.py
index 8e6b968..80c4605 100644
--- a/packages/expasy-agent/src/expasy_agent/prompts.py
+++ b/packages/expasy-agent/src/expasy_agent/prompts.py
@@ -16,6 +16,8 @@

 {extracted_entities}
 """

+
+
 # try to make it as efficient as possible to avoid timeout due to how large the datasets are, make sure the query written is valid SPARQL,
 # If the answer to the question is in the provided context, do not provide a query, just provide the answer, unless explicitly asked.
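
For reviewers, below is a minimal end-to-end sketch of the entity-linking flow this diff switches to (scispaCy NER followed by a vector search over the indexed entities). The embedding model name, Qdrant URL, collection name, and score threshold are illustrative assumptions, not the repository's exact configuration; the agent resolves them through `get_embedding_model()` and `settings`: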
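```python
import spacy
from fastembed import TextEmbedding
from qdrant_client import QdrantClient

# scispaCy model added to pyproject.toml above
nlp = spacy.load("en_core_sci_md")
# Model name, URL, collection name, and threshold below are assumptions for illustration
embedding_model = TextEmbedding("BAAI/bge-small-en-v1.5")
vectordb = QdrantClient(url="http://localhost:6333", prefer_grpc=True)
score_threshold = 0.8

question = "What are the orthologs in rat for protein Q9Y2T1?"
# 1. scispaCy proposes candidate biomedical entities from the question
potential_entities = nlp(question).ents

# 2. Embed each candidate and look it up in the indexed entities collection
embeddings = embedding_model.embed([ent.text for ent in potential_entities])
for entity, embedding in zip(potential_entities, embeddings):
    hits = vectordb.search(
        collection_name="entities",
        query_vector=embedding.tolist(),
        limit=10,
    )
    # 3. Keep only matches above the similarity threshold
    for hit in hits:
        if hit.score > score_threshold:
            print(f'{entity.text} -> {hit.payload["label"]} <{hit.payload["uri"]}>')
```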
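Compared to the old sliding-window loop that embedded and searched every n-gram of the question, running NER first means only a handful of candidate spans hit the vector database, which should make entity extraction cheap enough to run on every user question.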