improve indexing workflow
vemonet committed Dec 19, 2024
1 parent e7f88bb · commit 7643843
Showing 7 changed files with 142 additions and 96 deletions.
README.md: 10 changes (5 additions, 5 deletions)
@@ -204,7 +204,7 @@ print("\n".join(issues))
4. Run the script to index the resources (SPARQL endpoints listed in config file):

```sh
docker compose run api python src/sparql_llm/embed.py
docker compose run api python src/sparql_llm/index.py
```

> [!WARNING]
@@ -213,10 +213,10 @@ print("\n".join(issues))
>
> ```sh
> pip install -e ".[chat,gpu]"
> python src/sparql_llm/embed_entities.py
> python src/sparql_llm/index_entities.py
> ```
>
> Then move the CSV containing the embeddings in `data/embeddings/entities_embeddings.py` before running the `embed.py` script
> Then move the entities collection containing the embeddings in `data/qdrant/collections/entities` before starting the stack
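
For illustration only (not part of this commit): a minimal sketch of what "moving the entities collection" could look like once the embeddings have been generated on a GPU machine. The destination path comes from the note above; the source path and the use of `shutil` are assumptions, not something defined by this repository.

```python
# Hypothetical helper (not part of this commit): copy a locally built
# "entities" Qdrant collection into the data directory used by the stack.
# The source path is an assumption; adjust it to wherever index_entities.py
# wrote its local Qdrant data.
import shutil
from pathlib import Path

src = Path("gpu-output/qdrant/collections/entities")  # assumed location on the GPU machine
dst = Path("data/qdrant/collections/entities")        # path from the note above

dst.parent.mkdir(parents=True, exist_ok=True)
shutil.copytree(src, dst, dirs_exist_ok=True)
print(f"Copied {src} -> {dst}")
```
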
## 🧑‍💻 Contributing
@@ -228,13 +228,13 @@ If you reuse any part of this work, please cite [the arXiv paper](https://arxiv.
```
@misc{emonet2024llmbasedsparqlquerygeneration,
title={LLM-based SPARQL Query Generation from Natural Language over Federated Knowledge Graphs},
title={LLM-based SPARQL Query Generation from Natural Language over Federated Knowledge Graphs},
author={Vincent Emonet and Jerven Bolleman and Severine Duvaud and Tarcisio Mendes de Farias and Ana Claudia Sima},
year={2024},
eprint={2410.06062},
archivePrefix={arXiv},
primaryClass={cs.DB},
url={https://arxiv.org/abs/2410.06062},
url={https://arxiv.org/abs/2410.06062},
}
```
compose.dev.yml: 2 changes (1 addition, 1 deletion)
@@ -16,7 +16,7 @@ services:
- 8000:80
volumes:
- ./src:/app/src
- ./prestart.sh:/app/prestart.sh
# - ./prestart.sh:/app/prestart.sh
entrypoint: /start-reload.sh

# In case you need a GPU-enabled workspace
notebooks/compare_queries_examples_to_void.ipynb: 4 changes (2 additions, 2 deletions)
@@ -11,7 +11,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": null,
"metadata": {},
"outputs": [
{
@@ -357,7 +357,7 @@
"from qdrant_client.models import FieldCondition, Filter, MatchValue\n",
"\n",
"from sparql_llm.config import settings\n",
"from sparql_llm.embed import get_vectordb\n",
"from sparql_llm.index import get_vectordb\n",
"from sparql_llm.validate_sparql import get_void_dict, sparql_query_to_dict\n",
"\n",
"check_endpoints = {\n",
notebooks/rag_embed_queries.ipynb: 2 changes (1 addition, 1 deletion)
@@ -21,7 +21,7 @@
"metadata": {},
"outputs": [],
"source": [
"from sparql_llm.embed import init_vectordb\n",
"from sparql_llm.index import init_vectordb\n",
"\n",
"init_vectordb(\"localhost\")\n",
"\n",
notebooks/test_example_queries.ipynb: 4 changes (2 additions, 2 deletions)
@@ -16,7 +16,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": null,
"metadata": {},
"outputs": [
{
@@ -319,7 +319,7 @@
"from SPARQLWrapper import JSON, TURTLE, SPARQLWrapper\n",
"\n",
"from sparql_llm.config import settings\n",
"from sparql_llm.embed import get_vectordb\n",
"from sparql_llm.index import get_vectordb\n",
"\n",
"vectordb = get_vectordb(\"localhost\")\n",
"all_queries, _ = vectordb.scroll(\n",
src/sparql_llm/embed.py → src/sparql_llm/index.py: 125 changes (63 additions, 62 deletions)
@@ -12,7 +12,7 @@
from rdflib import RDF, ConjunctiveGraph, Namespace

from sparql_llm.config import get_embedding_model, get_vectordb, settings
from sparql_llm.embed_entities import load_entities_embeddings_to_vectordb
from sparql_llm.index_entities import load_entities_embeddings_to_vectordb
from sparql_llm.sparql_examples_loader import SparqlExamplesLoader
from sparql_llm.sparql_void_shapes_loader import SparqlVoidShapesLoader
from sparql_llm.utils import get_prefixes_for_endpoints
@@ -134,78 +134,79 @@ def init_vectordb(vectordb_host: str = settings.vectordb_host) -> None:
"""Initialize the vectordb with example queries and ontology descriptions from the SPARQL endpoints"""
vectordb = get_vectordb(vectordb_host)

if not vectordb.collection_exists(settings.docs_collection_name):
vectordb.create_collection(
collection_name=settings.docs_collection_name,
vectors_config=VectorParams(size=settings.embedding_dimensions, distance=Distance.COSINE),
)
embedding_model = get_embedding_model()
docs: list[Document] = []
# if not vectordb.collection_exists(settings.docs_collection_name):
vectordb.create_collection(
collection_name=settings.docs_collection_name,
vectors_config=VectorParams(size=settings.embedding_dimensions, distance=Distance.COSINE),
)
embedding_model = get_embedding_model()
docs: list[Document] = []

endpoints_urls = [endpoint["endpoint_url"] for endpoint in settings.endpoints]
prefix_map = get_prefixes_for_endpoints(endpoints_urls)
endpoints_urls = [endpoint["endpoint_url"] for endpoint in settings.endpoints]
prefix_map = get_prefixes_for_endpoints(endpoints_urls)

for endpoint in settings.endpoints:
print(f"\n 🔎 Getting metadata for {endpoint['label']} at {endpoint['endpoint_url']}")
queries_loader = SparqlExamplesLoader(endpoint["endpoint_url"], verbose=True)
docs += queries_loader.load()
for endpoint in settings.endpoints:
print(f"\n 🔎 Getting metadata for {endpoint['label']} at {endpoint['endpoint_url']}")
queries_loader = SparqlExamplesLoader(endpoint["endpoint_url"], verbose=True)
docs += queries_loader.load()

void_loader = SparqlVoidShapesLoader(
endpoint["endpoint_url"],
prefix_map=prefix_map,
verbose=True,
)
docs += void_loader.load()
void_loader = SparqlVoidShapesLoader(
endpoint["endpoint_url"],
prefix_map=prefix_map,
verbose=True,
)
docs += void_loader.load()

docs += load_schemaorg_description(endpoint)
# NOTE: we dont use the ontology for now, schema from shex is better
# docs += load_ontology(endpoint)
docs += load_schemaorg_description(endpoint)
# NOTE: we dont use the ontology for now, schema from shex is better
# docs += load_ontology(endpoint)

# NOTE: Manually add infos for UniProt since we cant retrieve it for now. Taken from https://www.uniprot.org/help/about
uniprot_description_question = "What is the SIB resource UniProt about?"
docs.append(
Document(
page_content=uniprot_description_question,
metadata={
"question": uniprot_description_question,
"answer": """The Universal Protein Resource (UniProt) is a comprehensive resource for protein sequence and annotation data. The UniProt databases are the UniProt Knowledgebase (UniProtKB), the UniProt Reference Clusters (UniRef), and the UniProt Archive (UniParc). The UniProt consortium and host institutions EMBL-EBI, SIB and PIR are committed to the long-term preservation of the UniProt databases.
# NOTE: Manually add infos for UniProt since we cant retrieve it for now. Taken from https://www.uniprot.org/help/about
uniprot_description_question = "What is the SIB resource UniProt about?"
docs.append(
Document(
page_content=uniprot_description_question,
metadata={
"question": uniprot_description_question,
"answer": """The Universal Protein Resource (UniProt) is a comprehensive resource for protein sequence and annotation data. The UniProt databases are the UniProt Knowledgebase (UniProtKB), the UniProt Reference Clusters (UniRef), and the UniProt Archive (UniParc). The UniProt consortium and host institutions EMBL-EBI, SIB and PIR are committed to the long-term preservation of the UniProt databases.
UniProt is a collaboration between the European Bioinformatics Institute (EMBL-EBI), the SIB Swiss Institute of Bioinformatics and the Protein Information Resource (PIR). Across the three institutes more than 100 people are involved through different tasks such as database curation, software development and support.
UniProt is a collaboration between the European Bioinformatics Institute (EMBL-EBI), the SIB Swiss Institute of Bioinformatics and the Protein Information Resource (PIR). Across the three institutes more than 100 people are involved through different tasks such as database curation, software development and support.
EMBL-EBI and SIB together used to produce Swiss-Prot and TrEMBL, while PIR produced the Protein Sequence Database (PIR-PSD). These two data sets coexisted with different protein sequence coverage and annotation priorities. TrEMBL (Translated EMBL Nucleotide Sequence Data Library) was originally created because sequence data was being generated at a pace that exceeded Swiss-Prot's ability to keep up. Meanwhile, PIR maintained the PIR-PSD and related databases, including iProClass, a database of protein sequences and curated families. In 2002 the three institutes decided to pool their resources and expertise and formed the UniProt consortium.
EMBL-EBI and SIB together used to produce Swiss-Prot and TrEMBL, while PIR produced the Protein Sequence Database (PIR-PSD). These two data sets coexisted with different protein sequence coverage and annotation priorities. TrEMBL (Translated EMBL Nucleotide Sequence Data Library) was originally created because sequence data was being generated at a pace that exceeded Swiss-Prot's ability to keep up. Meanwhile, PIR maintained the PIR-PSD and related databases, including iProClass, a database of protein sequences and curated families. In 2002 the three institutes decided to pool their resources and expertise and formed the UniProt consortium.
The UniProt consortium is headed by Alex Bateman, Alan Bridge and Cathy Wu, supported by key staff, and receives valuable input from an independent Scientific Advisory Board.
""",
"endpoint_url": "https://sparql.uniprot.org/sparql/",
"iri": "http://www.uniprot.org/help/about",
"doc_type": "schemaorg_description",
},
)
The UniProt consortium is headed by Alex Bateman, Alan Bridge and Cathy Wu, supported by key staff, and receives valuable input from an independent Scientific Advisory Board.
""",
"endpoint_url": "https://sparql.uniprot.org/sparql/",
"iri": "http://www.uniprot.org/help/about",
"doc_type": "schemaorg_description",
},
)
)

print(f"Generating embeddings for {len(docs)} documents")
embeddings = embedding_model.embed([q.page_content for q in docs])
start_time = time.time()
vectordb.upsert(
collection_name=settings.docs_collection_name,
points=models.Batch(
ids=list(range(1, len(docs) + 1)),
vectors=embeddings,
payloads=[doc.metadata for doc in docs],
),
# wait=False, # Waiting for indexing to finish or not
)
print(
f"Done generating and indexing {len(docs)} documents into the vectordb in {time.time() - start_time} seconds"
)
print(f"Generating embeddings for {len(docs)} documents")
embeddings = embedding_model.embed([q.page_content for q in docs])
start_time = time.time()
vectordb.upsert(
collection_name=settings.docs_collection_name,
points=models.Batch(
ids=list(range(1, len(docs) + 1)),
vectors=embeddings,
payloads=[doc.metadata for doc in docs],
),
# wait=False, # Waiting for indexing to finish or not
)
print(
f"Done generating and indexing {len(docs)} documents into the vectordb in {time.time() - start_time} seconds"
)

if not vectordb.collection_exists(settings.entities_collection_name):
vectordb.create_collection(
collection_name=settings.entities_collection_name,
vectors_config=VectorParams(size=settings.embedding_dimensions, distance=Distance.COSINE),
)
# if vectordb.get_collection(settings.entities_collection_name).points_count == 0:
load_entities_embeddings_to_vectordb()
# NOTE: Loading entities embeddings to the vectordb is done in another script, because too long
# if not vectordb.collection_exists(settings.entities_collection_name):
# vectordb.create_collection(
# collection_name=settings.entities_collection_name,
# vectors_config=VectorParams(size=settings.embedding_dimensions, distance=Distance.COSINE),
# )
# # if vectordb.get_collection(settings.entities_collection_name).points_count == 0:
# load_entities_embeddings_to_vectordb()


# docs = []
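
Not part of this commit, but as a usage sketch: the reworked `init_vectordb` above can be exercised end to end roughly as follows, assuming a Qdrant instance reachable on `localhost` and the settings and collection names used in this file. The example question matches the UniProt document indexed above; the `question` payload key and the exact return type of `embedding_model.embed()` depend on the embedding backend configured in `sparql_llm.config`.

```python
# Hypothetical usage sketch (not part of this commit).
from sparql_llm.config import get_embedding_model, get_vectordb, settings
from sparql_llm.index import init_vectordb

# Recreate the docs collection and index all configured SPARQL endpoints.
init_vectordb("localhost")

# Sanity check: embed a question and search the freshly indexed docs collection.
embedding_model = get_embedding_model()
vectordb = get_vectordb("localhost")
query_vector = next(iter(embedding_model.embed(["What is the SIB resource UniProt about?"])))
hits = vectordb.search(
    collection_name=settings.docs_collection_name,
    query_vector=list(query_vector),  # assumes embed() yields one vector per input text
    limit=3,
)
for hit in hits:
    print(hit.score, hit.payload.get("question"))
```
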
src/sparql_llm/embed_entities.py → src/sparql_llm/index_entities.py
@@ -4,7 +4,7 @@
from ast import literal_eval

from langchain_core.documents import Document
from qdrant_client import models
from qdrant_client import QdrantClient, models
from tqdm import tqdm

from sparql_llm.config import get_embedding_model, get_vectordb, settings
@@ -266,20 +266,34 @@ def generate_embeddings_for_entities():
if not os.path.exists(entities_embeddings_dir):
os.makedirs(entities_embeddings_dir)

with open(entities_embeddings_filepath, mode="w", newline="") as file:
writer = csv.writer(file)
header = ["label", "uri", "endpoint_url", "entity_type", "embedding"]
writer.writerow(header)

for doc, embedding in zip(docs, embeddings):
row = [
doc.metadata["label"],
doc.metadata["uri"],
doc.metadata["endpoint_url"],
doc.metadata["entity_type"],
embedding.tolist(),
]
writer.writerow(row)
vectordb_local = QdrantClient(
path="data/qdrant",
# host=host,
prefer_grpc=True,
)
vectordb_local.upsert(
collection_name=settings.entities_collection_name,
points=models.Batch(
ids=list(range(1, len(docs) + 1)),
vectors=embeddings,
payloads=[doc.metadata for doc in docs],
),
)

# NOTE: saving to a CSV makes it too slow to then read and upload to the vectordb, so we now directly upload to the vectordb
# with open(entities_embeddings_filepath, mode="w", newline="") as file:
# writer = csv.writer(file)
# header = ["label", "uri", "endpoint_url", "entity_type", "embedding"]
# writer.writerow(header)
# for doc, embedding in zip(docs, embeddings):
# row = [
# doc.metadata["label"],
# doc.metadata["uri"],
# doc.metadata["endpoint_url"],
# doc.metadata["entity_type"],
# embedding.tolist(),
# ]
# writer.writerow(row)

print(f"Done generating embeddings for {len(docs)} entities in {time.time() - start_time} seconds")

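Not part of this commit, but for context: since the entities are now upserted into a local Qdrant store under `data/qdrant`, a quick check of the resulting collection before moving it under the main stack could look like this, assuming the same path and collection name as above.

```python
# Hypothetical check (not part of this commit): inspect the locally built
# entities collection before moving it to the main vectordb's data directory.
from qdrant_client import QdrantClient

from sparql_llm.config import settings

local_db = QdrantClient(path="data/qdrant")
info = local_db.get_collection(settings.entities_collection_name)
print(f"{info.points_count} entity embeddings stored in '{settings.entities_collection_name}'")
```
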
@@ -295,11 +309,14 @@ def generate_embeddings_for_entities():
# )


# NOTE: not used anymore, we now directly load to a local vectordb then move the loaded entities collection to the main vectordb
def load_entities_embeddings_to_vectordb():
start_time = time.time()
vectordb = get_vectordb()
docs = []
embeddings = []
batch_size = 100000
embeddings_count = 0

print(f"Reading entities embeddings from the .csv file at {entities_embeddings_filepath}")
with open(entities_embeddings_filepath) as file:
@@ -317,17 +334,45 @@ def load_entities_embeddings_to_vectordb():
)
)
embeddings.append(literal_eval(row["embedding"]))
embeddings_count += 1

if len(docs) == batch_size:
vectordb.upsert(
collection_name=settings.entities_collection_name,
points=models.Batch(
ids=list(range(embeddings_count - batch_size + 1, embeddings_count + 1)),
vectors=embeddings,
payloads=[doc.metadata for doc in docs],
),
)
docs = []
embeddings = []


print(
f"Found embeddings for {len(docs)} entities in {time.time() - start_time} seconds. Now adding them to the vectordb"
)
vectordb.upsert(
collection_name=settings.entities_collection_name,
points=models.Batch(
ids=list(range(1, len(docs) + 1)),
vectors=embeddings,
payloads=[doc.metadata for doc in docs],
),
)

# for i in range(0, len(docs), batch_size):
# batch_docs = docs[i:i + batch_size]
# batch_embeddings = embeddings[i:i + batch_size]
# vectordb.upsert(
# collection_name=settings.entities_collection_name,
# points=models.Batch(
# ids=list(range(i + 1, i + len(batch_docs) + 1)),
# vectors=batch_embeddings,
# payloads=[doc.metadata for doc in batch_docs],
# ),
# )

# vectordb.upsert(
# collection_name=settings.entities_collection_name,
# points=models.Batch(
# ids=list(range(1, len(docs) + 1)),
# vectors=embeddings,
# payloads=[doc.metadata for doc in docs],
# ),
# )

print(f"Done uploading embeddings for {len(docs)} entities in the vectordb in {time.time() - start_time} seconds")

