Reranking #8

Merged · 2 commits · Sep 16, 2024
Changes from all commits:
2 changes: 2 additions & 0 deletions app/__init__.py
@@ -2,10 +2,12 @@
 from app.config.elasticsearch_config import create_index_if_not_exists
 from app.models.sentence_transformer import get_sentence_transformer
 from app.config.settings import settings
+from app.services.reranking import RerankingService
 
 
 embedding_model = get_sentence_transformer()
 es_client = create_index_if_not_exists(settings.elasticsearch.index_name)
+reranker_service = RerankingService()
 
 
 def create_app():
2 changes: 1 addition & 1 deletion app/config/settings.py
@@ -47,4 +47,4 @@ def __init__(self):
 # Logging setup
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
-logging.getLogger("elastic_transport").setLevel(logging.WARNING)
\ No newline at end of file
+logging.getLogger("elastic_transport").setLevel(logging.WARNING)
Empty file added app/data_models/__init__.py
8 changes: 8 additions & 0 deletions app/data_models/search_result.py
@@ -0,0 +1,8 @@
+from dataclasses import dataclass
+
+
+@dataclass
+class SearchResult:
+    score: float
+    content: str
+    metadata: dict
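The new dataclass replaces the dict-shaped hits used elsewhere in the codebase (note the result["metadata"] → result.metadata changes further down in this diff). A minimal usage sketch, with invented values for illustration:

    from app.data_models.search_result import SearchResult

    # Invented values, for illustration only.
    hit = SearchResult(
        score=0.87,
        content="Chunk text returned by Elasticsearch.",
        metadata={"resource_id": "abc-123"},
    )

    assert hit.metadata["resource_id"] == "abc-123"  # attribute access, not hit["metadata"]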
44 changes: 20 additions & 24 deletions app/evaluation/retrieval/retrieval_metrics.py
@@ -13,7 +13,8 @@ def evaluate_resources_summaries_retrieval(
     qa_references: list[dict],
     search_text_boost: int = 1,
     search_embedding_boost: int = 1,
-    k: int = 5
+    k: int = 5,
+    rerank_top_k: int = 0,
 ) -> dict:
     # Initialize counters and sums for metrics
     total_questions = 0
@@ -29,8 +30,7 @@
         reference_resource_id = response["custom_id"]
         content = response["response"]["body"]["choices"][0]["message"]["content"]
 
-        questions_and_answers = json.loads(
-            content)["questions_and_answers"]
+        questions_and_answers = json.loads(content)["questions_and_answers"]
 
         if len(questions_and_answers) > 0:
             # Sample one random question per resource_id to evaluate
@@ -42,13 +42,15 @@
             total_questions += 1
 
             # Query question
-            search_results = search_query(question,
-                                          embedding_model,
-                                          es_client,
-                                          k=k,
-                                          text_boost=search_text_boost,
-                                          embedding_boost=search_embedding_boost)
-
+            search_results = search_query(
+                question,
+                embedding_model,
+                es_client,
+                k=k,
+                text_boost=search_text_boost,
+                embedding_boost=search_embedding_boost,
+                rerank_top_k=rerank_top_k,
+            )
             # Evaluate if any returned chunk belongs to the correct resource_id
             found = False
             rank = 0
@@ -59,19 +61,18 @@
 
             if search_results != {"detail": "Not Found"}:
                 for i, result in enumerate(search_results):
-                    if result["metadata"]["resource_id"] == reference_resource_id:
+                    if result.metadata["resource_id"] == reference_resource_id:
                         if not found:
                             total_contexts_found += 1
                             rank = i + 1
                             reciprocal_rank_sum += 1 / rank
                             found = True
                         retrieved_relevant_chunks += 1
             elif search_results == {"detail": "Not Found"}:
-                search_results = {}
+                search_results = []
 
             # Calculate precision and recall for this specific question
-            precision = retrieved_relevant_chunks / \
-                len(search_results) if len(search_results) > 0 else 0
+            precision = retrieved_relevant_chunks / len(search_results) if len(search_results) > 0 else 0
             recall = retrieved_relevant_chunks / relevant_chunks if relevant_chunks > 0 else 0
 
             precision_sum += precision
@@ -81,16 +82,11 @@
             position_sum += rank
 
     # Calculate final metrics
-    retrieval_accuracy = round(
-        total_contexts_found / total_questions, 3) if total_questions > 0 else 0
-    average_position = round(
-        position_sum / total_contexts_found, 3) if total_contexts_found > 0 else 0
-    mrr = round(reciprocal_rank_sum / total_questions,
-                3) if total_questions > 0 else 0
-    average_precision = round(
-        precision_sum / total_questions, 3) if total_questions > 0 else 0
-    average_recall = round(recall_sum / total_questions,
-                           3) if total_questions > 0 else 0
+    retrieval_accuracy = round(total_contexts_found / total_questions, 3) if total_questions > 0 else 0
+    average_position = round(position_sum / total_contexts_found, 3) if total_contexts_found > 0 else 0
+    mrr = round(reciprocal_rank_sum / total_questions, 3) if total_questions > 0 else 0
+    average_precision = round(precision_sum / total_questions, 3) if total_questions > 0 else 0
+    average_recall = round(recall_sum / total_questions, 3) if total_questions > 0 else 0
 
     return {
         # The percentage of questions for which the system successfully retrieved at least one relevant chunk.
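As a sanity check on the metric definitions above, here is a toy computation with invented ranks (None meaning no relevant chunk was retrieved for that question):

    # Three questions; rank of the first relevant chunk among the k results, or None.
    ranks = [1, 3, None]  # invented values for illustration

    total_questions = len(ranks)
    found = [r for r in ranks if r is not None]

    retrieval_accuracy = len(found) / total_questions  # 2/3 ≈ 0.667
    mrr = sum(1 / r for r in found) / total_questions  # (1 + 1/3) / 3 ≈ 0.444
    average_position = sum(found) / len(found)         # (1 + 3) / 2 = 2.0

Note that MRR divides by the total number of questions, as in the function above, so questions with no retrieved context pull the score down.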
23 changes: 7 additions & 16 deletions app/routes/database_endpoints.py
@@ -23,8 +23,7 @@ async def bulk_load(file: UploadFile = File(...), text_key: str = Form(...)):
         json_data = csv_to_dict(data)
 
     else:
-        raise HTTPException(
-            status_code=400, detail="Unsupported file format. Only JSON and CSV are supported.")
+        raise HTTPException(status_code=400, detail="Unsupported file format. Only JSON and CSV are supported.")
 
     try:
         helpers.bulk(
@@ -43,33 +42,25 @@
 @router.delete("/delete_all_documents")
 async def delete_all_documents(index_name: str):
     try:
-        es_client.delete_by_query(index=index_name, body={
-            "query": {"match_all": {}}})
+        es_client.delete_by_query(index=index_name, body={"query": {"match_all": {}}})
         logger.info(f"All documents deleted from index '{index_name}'")
         return {"status": "success", "message": f"All documents deleted from index '{index_name}'"}
     except Exception as e:
         logger.error(f"Failed to delete documents: {str(e)}")
-        raise HTTPException(
-            status_code=500, detail=f"Failed to delete documents: {str(e)}")
+        raise HTTPException(status_code=500, detail=f"Failed to delete documents: {str(e)}")
 
 
 @router.get("/get_all_documents")
-async def get_all_documents(index_name: str = settings.elasticsearch.index_name,
-                            size: int = 2000):
+async def get_all_documents(index_name: str = settings.elasticsearch.index_name, size: int = 2000):
     try:
-        documents = fetch_all_documents(
-            index_name=index_name,
-            es_client=es_client,
-            size=size)
+        documents = fetch_all_documents(index_name=index_name, es_client=es_client, size=size)
         return documents
     except Exception as e:
         logger.error(f"Error retrieving documents: {str(e)}")
-        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
-                            detail=f"Error retrieving documents: {str(e)}")
+        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Error retrieving documents: {str(e)}")
 
 
 @router.get("/search")
 async def search_documents(query: str, k: int = 5, text_boost: float = 0.25, embedding_boost: float = 4.0):
-    results = search_query(query, embedding_model, es_client, k=k,
-                           text_boost=text_boost, embedding_boost=embedding_boost)
+    results = search_query(query, embedding_model, es_client, k=k, text_boost=text_boost, embedding_boost=embedding_boost)
     return results
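A hypothetical client call against the /search endpoint (host, port, and any router prefix are assumptions; only the query parameters themselves are visible in this diff):

    import requests

    # Base URL is an assumption; adjust for the actual deployment.
    resp = requests.get(
        "http://localhost:8000/search",
        params={"query": "blood pressure medication", "k": 5,
                "text_boost": 0.25, "embedding_boost": 4.0},
    )
    print(resp.json())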
61 changes: 30 additions & 31 deletions app/routes/evaluation_endpoints.py
@@ -15,59 +15,55 @@
 
 
 @router.post("/evaluate_retrieval")
-async def evaluate_retrieval(file: UploadFile = File(...),
-                             index_name: str = Form(
-                                 settings.elasticsearch.index_name),
-                             size: int = Form(2000),
-                             search_text_boost: float = Form(1),
-                             search_embedding_boost: float = Form(1),
-                             k: int = Form(5),
-                             urls_in_resources: bool = Form(None),
-                             questions_with_ids_and_dates: str = Form(None),
-                             chunk_size: int = Form(None),
-                             chunk_overlap: int = Form(None),
-                             clearml_track_experiment: bool = Form(False),
-                             clearml_experiment_name: str = Form("Retrieval evaluation"),
-                             clearml_project_name: str = Form("Fasten")):
+async def evaluate_retrieval(
+    file: UploadFile = File(...),
+    index_name: str = Form(settings.elasticsearch.index_name),
+    size: int = Form(2000),
+    search_text_boost: float = Form(1),
+    search_embedding_boost: float = Form(1),
+    k: int = Form(5),
+    rerank_top_k: int = Form(0),
+    urls_in_resources: bool = Form(None),
+    questions_with_ids_and_dates: str = Form(None),
+    chunk_size: int = Form(None),
+    chunk_overlap: int = Form(None),
+    clearml_track_experiment: bool = Form(False),
+    clearml_experiment_name: str = Form("Retrieval evaluation"),
+    clearml_project_name: str = Form("Fasten"),
+):
     # Read and process reference questions and answers in JSONL
     try:
         qa_references = []
 
         file_data = await file.read()
 
-        for line in file_data.decode('utf-8').splitlines():
+        for line in file_data.decode("utf-8").splitlines():
             qa_references.append(json.loads(line))
     except json.JSONDecodeError:
-        raise HTTPException(
-            status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid JSON format.")
+        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid JSON format.")
     # Count total chunks by resource in database
     try:
-        documents = fetch_all_documents(
-            index_name=index_name,
-            es_client=es_client,
-            size=size)
-        id, counts = np.unique([resource["metadata"]["resource_id"]
-                                for resource in documents], return_counts=True)
+        documents = fetch_all_documents(index_name=index_name, es_client=es_client, size=size)
+        id, counts = np.unique([resource["metadata"]["resource_id"] for resource in documents], return_counts=True)
         resources_counts = dict(zip(id, counts))
     except Exception as e:
         logger.error(f"Error retrieving documents: {str(e)}")
-        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
-                            detail=f"Error retrieving documents: {str(e)}")
+        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Error retrieving documents: {str(e)}")
    # Evaluate retrieval
     try:
         if clearml_track_experiment:
             params = {
                 "search_text_boost": search_text_boost,
                 "search_embedding_boost": search_embedding_boost,
                 "k": k,
+                "rerank_top_k": rerank_top_k,
                 "urls_in_resources": urls_in_resources,
                 "questions_with_ids_and_dates": questions_with_ids_and_dates,
                 "chunk_size": chunk_size,
-                "chunk_overlap": chunk_overlap
+                "chunk_overlap": chunk_overlap,
             }
             unique_task_name = f"{clearml_experiment_name}_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
-            task = Task.init(project_name=clearml_project_name,
-                             task_name=unique_task_name)
+            task = Task.init(project_name=clearml_project_name, task_name=unique_task_name)
             task.connect(params)
 
         retrieval_metrics = evaluate_resources_summaries_retrieval(
@@ -77,7 +73,9 @@
             qa_references=qa_references,
             search_text_boost=search_text_boost,
             search_embedding_boost=search_embedding_boost,
-            k=k)
+            k=k,
+            rerank_top_k=rerank_top_k,
+        )
 
         # Upload metrics and close task
         if task:
@@ -90,5 +88,6 @@
 
     except Exception as e:
         logger.error(f"Error during retrieval evaluation: {str(e)}")
-        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
-                            detail=f"Error during retrieval evaluation: {str(e)}")
+        raise HTTPException(
+            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Error during retrieval evaluation: {str(e)}"
+        )
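A hypothetical invocation of the endpoint with the new parameter (base URL and file name are assumptions; the form fields mirror the signature above, and rerank_top_k's exact semantics live in search_query, which this diff does not show — 0 appears to leave reranking off):

    import requests

    # Base URL and file name are assumptions.
    with open("qa_references.jsonl", "rb") as f:
        resp = requests.post(
            "http://localhost:8000/evaluate_retrieval",
            files={"file": f},
            data={"k": 5, "rerank_top_k": 3, "search_text_boost": 1, "search_embedding_boost": 1},
        )
    print(resp.json())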
17 changes: 7 additions & 10 deletions app/routes/llm_endpoints.py
@@ -19,8 +19,9 @@
 async def answer_query(
     query: str, k: int = 5, params=None, stream: bool = False, text_boost: float = 0.25, embedding_boost: float = 4.0
 ):
-    results = search_query(query, embedding_model, es_client, k=k,
-                           text_boost=text_boost, embedding_boost=embedding_boost)
+    results = search_query(
+        query, embedding_model, es_client, k=k, text_boost=text_boost, embedding_boost=embedding_boost, rerank_top_k=0
+    )
     if not results:
         concatenated_content = "There is no context"
     else:
@@ -41,21 +42,17 @@ async def summarize(
         resources = await file.read()
         resources = json.loads(resources)
     except json.JSONDecodeError:
-        raise HTTPException(
-            status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid JSON format.")
+        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid JSON format.")
     # Process file and summarize resources
     try:
         limit = 1 if limit <= 0 else limit
         if limit:
-            resources_processed = process_resources(
-                data=resources, remove_urls=remove_urls)[:limit]
+            resources_processed = process_resources(data=resources, remove_urls=remove_urls)[:limit]
         else:
-            resources_processed = process_resources(
-                data=resources, remove_urls=remove_urls)
+            resources_processed = process_resources(data=resources, remove_urls=remove_urls)
         resources_summarized = summarize_resources(resources_processed, stream)
     except Exception as e:
-        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
-                            detail=f"Error during processing: {str(e)}")
+        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Error during processing: {str(e)}")
     # Save resources
     try:
         helpers.bulk(
8 changes: 5 additions & 3 deletions app/services/conversation.py
@@ -1,19 +1,21 @@
 import time
+from typing import List
 
 from fastapi.responses import StreamingResponse
 
 from app.services.llama_client import llm_client
 from app.config.settings import logger
+from app.data_models.search_result import SearchResult
 
 
-def process_search_output(search_results):
+def process_search_output(search_results: List[SearchResult]):
     logger.info("Processing search results")
     processed_contents = []
     resources_id = []
 
     for result in search_results:
-        content = result["content"]
-        resource_id = result["metadata"]["resource_id"]
+        content = result.content
+        resource_id = result.metadata["resource_id"]
 
         processed_contents.append(content.replace("\\", ""))
         resources_id.append(resource_id)
22 changes: 22 additions & 0 deletions app/services/reranking.py
@@ -0,0 +1,22 @@
+from FlagEmbedding import FlagReranker
+from typing import List, Tuple
+
+from app.data_models.search_result import SearchResult
+
+
+class RerankingService:
+    def __init__(self, model_name="BAAI/bge-reranker-v2-m3"):
+        self.reranker = FlagReranker(model_name, use_fp16=True)
+
+    def rerank(self, query: str, documents: List[SearchResult]) -> List[Tuple[SearchResult, float]]:
+        """Computes a score for each document in the list of documents and returns a ranked list of documents.
+
+        :param str query: user query used for ranking
+        :param List[SearchResult] documents: documents to be ranked
+        :return List[Tuple[SearchResult, float]]: list of (document, score) tuples, highest score first
+        """
+        scores = self.reranker.compute_score([[query, doc.content] for doc in documents], normalize=True)
+        print(scores)
+
+        ranked_docs = sorted(zip(documents, scores), key=lambda x: x[1], reverse=True)
+        return ranked_docs
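This PR threads rerank_top_k through search_query, but search_query's own change is outside the visible hunks. A minimal sketch of how the service could be wired in, assuming the module-level reranker_service from app/__init__.py (the control flow here is an assumption, not the PR's actual implementation):

    from app import reranker_service  # module-level singleton added in this PR

    def apply_reranking(query, search_results, rerank_top_k):
        # Hypothetical helper: rerank the retrieved chunks, keep the best rerank_top_k.
        if rerank_top_k <= 0:
            return search_results  # rerank_top_k=0 leaves the Elasticsearch ordering untouched
        ranked = reranker_service.rerank(query, search_results)  # [(SearchResult, score), ...]
        return [doc for doc, _score in ranked[:rerank_top_k]]

For reference, FlagReranker's normalize=True passes the raw relevance scores through a sigmoid so they land in [0, 1], and use_fp16=True speeds up inference at a small precision cost.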