
Merge pull request #8 from fastenhealth/reranking
Reranking
dgbaenar authored Sep 16, 2024
2 parents 8b7e6ae + 5d87acc commit cb79f5a
Showing 17 changed files with 251 additions and 280 deletions.
2 changes: 2 additions & 0 deletions app/__init__.py
@@ -2,10 +2,12 @@
 from app.config.elasticsearch_config import create_index_if_not_exists
 from app.models.sentence_transformer import get_sentence_transformer
 from app.config.settings import settings
+from app.services.reranking import RerankingService


 embedding_model = get_sentence_transformer()
 es_client = create_index_if_not_exists(settings.elasticsearch.index_name)
+reranker_service = RerankingService()


 def create_app():
2 changes: 1 addition & 1 deletion app/config/settings.py
@@ -47,4 +47,4 @@ def __init__(self):
 # Logging setup
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
-logging.getLogger("elastic_transport").setLevel(logging.WARNING)
\ No newline at end of file
+logging.getLogger("elastic_transport").setLevel(logging.WARNING)
Empty file added app/data_models/__init__.py
8 changes: 8 additions & 0 deletions app/data_models/search_result.py
@@ -0,0 +1,8 @@
+from dataclasses import dataclass
+
+
+@dataclass
+class SearchResult:
+    score: float
+    content: str
+    metadata: dict
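
Downstream code now reads these fields by attribute access instead of dict lookups (see app/services/conversation.py below). A minimal sketch, with hypothetical values:

    from app.data_models.search_result import SearchResult

    result = SearchResult(score=0.87, content="example chunk", metadata={"resource_id": "abc123"})
    # before this change: result["content"], result["metadata"]["resource_id"]
    content = result.content
    resource_id = result.metadata["resource_id"]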
44 changes: 20 additions & 24 deletions app/evaluation/retrieval/retrieval_metrics.py
@@ -13,7 +13,8 @@ def evaluate_resources_summaries_retrieval(
     qa_references: list[dict],
     search_text_boost: int = 1,
     search_embedding_boost: int = 1,
-    k: int = 5
+    k: int = 5,
+    rerank_top_k: int = 0,
 ) -> dict:
     # Initialize counters and sums for metrics
     total_questions = 0
@@ -29,8 +30,7 @@
         reference_resource_id = response["custom_id"]
         content = response["response"]["body"]["choices"][0]["message"]["content"]

-        questions_and_answers = json.loads(
-            content)["questions_and_answers"]
+        questions_and_answers = json.loads(content)["questions_and_answers"]

         if len(questions_and_answers) > 0:
             # Sample one random question per resource_id to evaluate
@@ -42,13 +42,15 @@
             total_questions += 1

             # Query question
-            search_results = search_query(question,
-                                          embedding_model,
-                                          es_client,
-                                          k=k,
-                                          text_boost=search_text_boost,
-                                          embedding_boost=search_embedding_boost)
-
+            search_results = search_query(
+                question,
+                embedding_model,
+                es_client,
+                k=k,
+                text_boost=search_text_boost,
+                embedding_boost=search_embedding_boost,
+                rerank_top_k=rerank_top_k,
+            )
             # Evaluate if any returned chunk belongs to the correct resource_id
             found = False
             rank = 0
@@ -59,19 +61,18 @@

             if search_results != {"detail": "Not Found"}:
                 for i, result in enumerate(search_results):
-                    if result["metadata"]["resource_id"] == reference_resource_id:
+                    if result.metadata["resource_id"] == reference_resource_id:
                         if not found:
                             total_contexts_found += 1
                             rank = i + 1
                             reciprocal_rank_sum += 1 / rank
                             found = True
                         retrieved_relevant_chunks += 1
             elif search_results == {"detail": "Not Found"}:
-                search_results = {}
+                search_results = []

             # Calculate precision and recall for this specific question
-            precision = retrieved_relevant_chunks / \
-                len(search_results) if len(search_results) > 0 else 0
+            precision = retrieved_relevant_chunks / len(search_results) if len(search_results) > 0 else 0
             recall = retrieved_relevant_chunks / relevant_chunks if relevant_chunks > 0 else 0

             precision_sum += precision
@@ -81,16 +82,11 @@
             position_sum += rank

     # Calculate final metrics
-    retrieval_accuracy = round(
-        total_contexts_found / total_questions, 3) if total_questions > 0 else 0
-    average_position = round(
-        position_sum / total_contexts_found, 3) if total_contexts_found > 0 else 0
-    mrr = round(reciprocal_rank_sum / total_questions,
-                3) if total_questions > 0 else 0
-    average_precision = round(
-        precision_sum / total_questions, 3) if total_questions > 0 else 0
-    average_recall = round(recall_sum / total_questions,
-                           3) if total_questions > 0 else 0
+    retrieval_accuracy = round(total_contexts_found / total_questions, 3) if total_questions > 0 else 0
+    average_position = round(position_sum / total_contexts_found, 3) if total_contexts_found > 0 else 0
+    mrr = round(reciprocal_rank_sum / total_questions, 3) if total_questions > 0 else 0
+    average_precision = round(precision_sum / total_questions, 3) if total_questions > 0 else 0
+    average_recall = round(recall_sum / total_questions, 3) if total_questions > 0 else 0

     return {
         # The percentage of questions for which the system successfully retrieved at least one relevant chunk.
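
For reference, a small worked example of how the aggregate metrics above behave, using hypothetical ranks (None meaning no relevant chunk was retrieved for that question):

    # 3 questions; the first relevant chunk appears at rank 1, at rank 3,
    # and not at all for the third question.
    ranks = [1, 3, None]

    total_questions = len(ranks)
    found = [r for r in ranks if r is not None]

    retrieval_accuracy = round(len(found) / total_questions, 3)   # 0.667
    average_position = round(sum(found) / len(found), 3)          # 2.0
    mrr = round(sum(1 / r for r in found) / total_questions, 3)   # (1 + 1/3) / 3 = 0.444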
23 changes: 7 additions & 16 deletions app/routes/database_endpoints.py
@@ -23,8 +23,7 @@ async def bulk_load(file: UploadFile = File(...), text_key: str = Form(...)):
         json_data = csv_to_dict(data)

     else:
-        raise HTTPException(
-            status_code=400, detail="Unsupported file format. Only JSON and CSV are supported.")
+        raise HTTPException(status_code=400, detail="Unsupported file format. Only JSON and CSV are supported.")

     try:
         helpers.bulk(
@@ -43,33 +42,25 @@ async def bulk_load(file: UploadFile = File(...), text_key: str = Form(...)):
 @router.delete("/delete_all_documents")
 async def delete_all_documents(index_name: str):
     try:
-        es_client.delete_by_query(index=index_name, body={
-            "query": {"match_all": {}}})
+        es_client.delete_by_query(index=index_name, body={"query": {"match_all": {}}})
         logger.info(f"All documents deleted from index '{index_name}'")
         return {"status": "success", "message": f"All documents deleted from index '{index_name}'"}
     except Exception as e:
         logger.error(f"Failed to delete documents: {str(e)}")
-        raise HTTPException(
-            status_code=500, detail=f"Failed to delete documents: {str(e)}")
+        raise HTTPException(status_code=500, detail=f"Failed to delete documents: {str(e)}")


 @router.get("/get_all_documents")
-async def get_all_documents(index_name: str = settings.elasticsearch.index_name,
-                            size: int = 2000):
+async def get_all_documents(index_name: str = settings.elasticsearch.index_name, size: int = 2000):
     try:
-        documents = fetch_all_documents(
-            index_name=index_name,
-            es_client=es_client,
-            size=size)
+        documents = fetch_all_documents(index_name=index_name, es_client=es_client, size=size)
         return documents
     except Exception as e:
         logger.error(f"Error retrieving documents: {str(e)}")
-        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
-                            detail=f"Error retrieving documents: {str(e)}")
+        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Error retrieving documents: {str(e)}")


 @router.get("/search")
 async def search_documents(query: str, k: int = 5, text_boost: float = 0.25, embedding_boost: float = 4.0):
-    results = search_query(query, embedding_model, es_client, k=k,
-                           text_boost=text_boost, embedding_boost=embedding_boost)
+    results = search_query(query, embedding_model, es_client, k=k, text_boost=text_boost, embedding_boost=embedding_boost)
     return results
61 changes: 30 additions & 31 deletions app/routes/evaluation_endpoints.py
@@ -15,59 +15,55 @@


 @router.post("/evaluate_retrieval")
-async def evaluate_retrieval(file: UploadFile = File(...),
-                             index_name: str = Form(
-                                 settings.elasticsearch.index_name),
-                             size: int = Form(2000),
-                             search_text_boost: float = Form(1),
-                             search_embedding_boost: float = Form(1),
-                             k: int = Form(5),
-                             urls_in_resources: bool = Form(None),
-                             questions_with_ids_and_dates: str = Form(None),
-                             chunk_size: int = Form(None),
-                             chunk_overlap: int = Form(None),
-                             clearml_track_experiment: bool = Form(False),
-                             clearml_experiment_name: str = Form("Retrieval evaluation"),
-                             clearml_project_name: str = Form("Fasten")):
+async def evaluate_retrieval(
+    file: UploadFile = File(...),
+    index_name: str = Form(settings.elasticsearch.index_name),
+    size: int = Form(2000),
+    search_text_boost: float = Form(1),
+    search_embedding_boost: float = Form(1),
+    k: int = Form(5),
+    rerank_top_k: int = Form(0),
+    urls_in_resources: bool = Form(None),
+    questions_with_ids_and_dates: str = Form(None),
+    chunk_size: int = Form(None),
+    chunk_overlap: int = Form(None),
+    clearml_track_experiment: bool = Form(False),
+    clearml_experiment_name: str = Form("Retrieval evaluation"),
+    clearml_project_name: str = Form("Fasten"),
+):
     # Read and process reference questions and answers in JSONL
     try:
         qa_references = []

         file_data = await file.read()

-        for line in file_data.decode('utf-8').splitlines():
+        for line in file_data.decode("utf-8").splitlines():
             qa_references.append(json.loads(line))
     except json.JSONDecodeError:
-        raise HTTPException(
-            status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid JSON format.")
+        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid JSON format.")
     # Count total chunks by resource in database
     try:
-        documents = fetch_all_documents(
-            index_name=index_name,
-            es_client=es_client,
-            size=size)
-        id, counts = np.unique([resource["metadata"]["resource_id"]
-                                for resource in documents], return_counts=True)
+        documents = fetch_all_documents(index_name=index_name, es_client=es_client, size=size)
+        id, counts = np.unique([resource["metadata"]["resource_id"] for resource in documents], return_counts=True)
         resources_counts = dict(zip(id, counts))
     except Exception as e:
         logger.error(f"Error retrieving documents: {str(e)}")
-        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
-                            detail=f"Error retrieving documents: {str(e)}")
+        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Error retrieving documents: {str(e)}")
     # Evaluate retrieval
     try:
         if clearml_track_experiment:
             params = {
                 "search_text_boost": search_text_boost,
                 "search_embedding_boost": search_embedding_boost,
                 "k": k,
+                "rerank_top_k": rerank_top_k,
                 "urls_in_resources": urls_in_resources,
                 "questions_with_ids_and_dates": questions_with_ids_and_dates,
                 "chunk_size": chunk_size,
-                "chunk_overlap": chunk_overlap
+                "chunk_overlap": chunk_overlap,
             }
             unique_task_name = f"{clearml_experiment_name}_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
-            task = Task.init(project_name=clearml_project_name,
-                             task_name=unique_task_name)
+            task = Task.init(project_name=clearml_project_name, task_name=unique_task_name)
             task.connect(params)

         retrieval_metrics = evaluate_resources_summaries_retrieval(
@@ -77,7 +73,9 @@ async def evaluate_retrieval(file: UploadFile = File(...),
             qa_references=qa_references,
             search_text_boost=search_text_boost,
             search_embedding_boost=search_embedding_boost,
-            k=k)
+            k=k,
+            rerank_top_k=rerank_top_k,
+        )

         # Upload metrics and close task
         if task:
@@ -90,5 +88,6 @@ async def evaluate_retrieval(file: UploadFile = File(...),

     except Exception as e:
         logger.error(f"Error during retrieval evaluation: {str(e)}")
-        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
-                            detail=f"Error during retrieval evaluation: {str(e)}")
+        raise HTTPException(
+            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Error during retrieval evaluation: {str(e)}"
+        )
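
A hypothetical client call exercising the new rerank_top_k parameter (assumes the service runs locally on port 8000 and the router is mounted at the root; adjust the URL and path to your deployment):

    import requests

    with open("qa_references.jsonl", "rb") as f:
        response = requests.post(
            "http://localhost:8000/evaluate_retrieval",
            files={"file": f},
            data={"k": 5, "rerank_top_k": 3, "search_text_boost": 1, "search_embedding_boost": 1},
        )
    print(response.json())  # retrieval accuracy, MRR, average precision/recall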
17 changes: 7 additions & 10 deletions app/routes/llm_endpoints.py
@@ -19,8 +19,9 @@
 async def answer_query(
     query: str, k: int = 5, params=None, stream: bool = False, text_boost: float = 0.25, embedding_boost: float = 4.0
 ):
-    results = search_query(query, embedding_model, es_client, k=k,
-                           text_boost=text_boost, embedding_boost=embedding_boost)
+    results = search_query(
+        query, embedding_model, es_client, k=k, text_boost=text_boost, embedding_boost=embedding_boost, rerank_top_k=0
+    )
     if not results:
         concatenated_content = "There is no context"
     else:
@@ -41,21 +42,17 @@ async def summarize(
         resources = await file.read()
         resources = json.loads(resources)
     except json.JSONDecodeError:
-        raise HTTPException(
-            status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid JSON format.")
+        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid JSON format.")
     # Process file and summarize resources
     try:
         limit = 1 if limit <= 0 else limit
         if limit:
-            resources_processed = process_resources(
-                data=resources, remove_urls=remove_urls)[:limit]
+            resources_processed = process_resources(data=resources, remove_urls=remove_urls)[:limit]
         else:
-            resources_processed = process_resources(
-                data=resources, remove_urls=remove_urls)
+            resources_processed = process_resources(data=resources, remove_urls=remove_urls)
         resources_summarized = summarize_resources(resources_processed, stream)
     except Exception as e:
-        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
-                            detail=f"Error during processing: {str(e)}")
+        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Error during processing: {str(e)}")
     # Save resources
     try:
         helpers.bulk(
8 changes: 5 additions & 3 deletions app/services/conversation.py
@@ -1,19 +1,21 @@
 import time
+from typing import List

 from fastapi.responses import StreamingResponse

 from app.services.llama_client import llm_client
 from app.config.settings import logger
+from app.data_models.search_result import SearchResult


-def process_search_output(search_results):
+def process_search_output(search_results: List[SearchResult]):
     logger.info("Processing search results")
     processed_contents = []
     resources_id = []

     for result in search_results:
-        content = result["content"]
-        resource_id = result["metadata"]["resource_id"]
+        content = result.content
+        resource_id = result.metadata["resource_id"]

         processed_contents.append(content.replace("\\", ""))
         resources_id.append(resource_id)
22 changes: 22 additions & 0 deletions app/services/reranking.py
@@ -0,0 +1,22 @@
+from FlagEmbedding import FlagReranker
+from typing import List, Tuple
+
+from app.data_models.search_result import SearchResult
+
+
+class RerankingService:
+    def __init__(self, model_name="BAAI/bge-reranker-v2-m3"):
+        self.reranker = FlagReranker(model_name, use_fp16=True)
+
+    def rerank(self, query: str, documents: List[SearchResult]) -> List[Tuple[SearchResult, float]]:
+        """Computes a relevance score for each document and returns the documents ranked by score.
+        :param str query: user query used for ranking
+        :param List[SearchResult] documents: documents to be ranked
+        :return List[Tuple[SearchResult, float]]: list of (document, score) tuples, highest score first
+        """
+        scores = self.reranker.compute_score([[query, doc.content] for doc in documents], normalize=True)
+
+        ranked_docs = sorted(zip(documents, scores), key=lambda x: x[1], reverse=True)
+        return ranked_docs
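
A minimal usage sketch for the service (hypothetical data; assumes the FlagEmbedding package is installed and the model weights can be downloaded):

    from app.data_models.search_result import SearchResult
    from app.services.reranking import RerankingService

    reranker = RerankingService()  # loads BAAI/bge-reranker-v2-m3 with fp16

    hits = [
        SearchResult(score=1.2, content="Patient has type 2 diabetes.", metadata={"resource_id": "a1"}),
        SearchResult(score=0.9, content="Blood pressure is within normal range.", metadata={"resource_id": "b2"}),
    ]

    ranked = reranker.rerank("Does the patient have diabetes?", hits)
    best_doc, best_score = ranked[0]  # highest normalized score first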