Merge pull request #7 from fastenhealth/openai_summaries
Openai summaries
dgbaenar authored Sep 10, 2024
2 parents ff4e664 + fbca494 commit 8b7e6ae
Showing 27 changed files with 778 additions and 330 deletions.
27 changes: 27 additions & 0 deletions app/__init__.py
@@ -0,0 +1,27 @@
from fastapi import FastAPI
from app.config.elasticsearch_config import create_index_if_not_exists
from app.models.sentence_transformer import get_sentence_transformer
from app.config.settings import settings


embedding_model = get_sentence_transformer()
es_client = create_index_if_not_exists(settings.elasticsearch.index_name)


def create_app():
    app = FastAPI()

    from app.routes.database_endpoints import router as database_router
    from app.routes.llm_endpoints import router as llm_router
    from app.routes.openai_endpoints import router as openai_router
    from app.routes.evaluation_endpoints import router as evaluation_router

    app.include_router(database_router, prefix="/database")
    app.include_router(llm_router, prefix="/generation")
    app.include_router(openai_router, prefix="/openai")
    app.include_router(evaluation_router, prefix="/evaluation")

    return app


app = create_app()
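
Usage note (not part of the diff): create_app() runs at import time and binds the result to app, so any ASGI server can serve it. A minimal sketch, assuming uvicorn is available, which this PR does not itself add:

# run.py - illustrative only; assumes uvicorn is installed and the repo root is importable
import uvicorn

from app import app

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)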
21 changes: 6 additions & 15 deletions app/config/elasticsearch_config.py
@@ -1,13 +1,12 @@
 from elasticsearch import Elasticsearch

-from config.settings import settings, logger
+from app.config.settings import settings, logger


 def get_es_client():
     return Elasticsearch(
         hosts=[settings.elasticsearch.host],
-        basic_auth=(settings.elasticsearch.user,
-                    settings.elasticsearch.password),
+        basic_auth=(settings.elasticsearch.user, settings.elasticsearch.password),
         max_retries=10,
     )

@@ -16,16 +15,9 @@ def get_mapping():
     return {
         "mappings": {
             "properties": {
-                "content": {
-                    "type": "text"
-                },
-                "embedding": {
-                    "type": "dense_vector",
-                    "dims": 384
-                },
-                "metadata": {
-                    "type": "object"
-                }
+                "content": {"type": "text"},
+                "embedding": {"type": "dense_vector", "dims": 384},
+                "metadata": {"type": "object"},
             }
         }
     }
@@ -34,7 +26,6 @@ def get_mapping():
 def create_index_if_not_exists(index_name):
     es_client = get_es_client()
     if not es_client.indices.exists(index=index_name):
-        es_client.indices.create(index=index_name,
-                                 body=get_mapping())
+        es_client.indices.create(index=index_name, body=get_mapping())
         logger.info(f"Index '{index_name}' created.")
     return es_client
@@ -8,10 +8,10 @@ Provide detailed, helpful, and accurate responses, and include references where
 If information is not available, politely inform the user that you cannot provide an answer.
 <|end|>
 <|user|>
-Context information is below.\n
----------------------\n
-{context}\n
----------------------\n
+Context information is below.
+---------------------
+{context}
+---------------------
 Given the context information (if there is any),
 this is my message: {query}
 <|assistant|>
@@ -10,10 +10,10 @@ If information is not available, politely inform the user that you cannot provide

 <|eot_id|><|start_header_id|>user<|end_header_id|>

-Context information is below.\n
----------------------\n
-{context}\n
----------------------\n
+Context information is below.
+---------------------
+{context}
+---------------------
 Given the context information (if there is any),
 this is my message: {query}
 <|eot_id|><|start_header_id|>assistant<|end_header_id|>
@@ -6,8 +6,8 @@ attributes. Prioritize essential data for efficient storage
 and retrieval, and omit any unnecessary details.
 <|end|>
 <|user|>
-This is the resource:\n
----------------------\n
-{query}\n
----------------------\n
+This is the resource:
+---------------------
+{query}
+---------------------
 <|assistant|>
@@ -8,8 +8,8 @@ and retrieval, and omit any unnecessary details.

 <|eot_id|><|start_header_id|>user<|end_header_id|>

-This is the resource:\n
----------------------\n
-{query}\n
----------------------\n
+This is the resource:
+---------------------
+{query}
+---------------------
 <|eot_id|><|start_header_id|>assistant<|end_header_id|>
5 changes: 5 additions & 0 deletions app/config/prompts/summaries_openai_system_prompt.txt
@@ -0,0 +1,5 @@
You will receive a single FHIR resource. Summarize the key information
from the resource in a clear, concise paragraph of plain text,
ideally up to 800 characters. The output should be human-readable and
understandable, not in JSON or other structured formats. Focus on the most
relevant attributes and omit unnecessary details.
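
A minimal sketch of how this system prompt might be wired to the OpenAI chat API; the client setup and model name are assumptions, since the calling endpoint is not shown in this excerpt:

from openai import OpenAI

from app.config.settings import settings

client = OpenAI()  # reads OPENAI_API_KEY from the environment


def summarize_resource(fhir_resource_json: str) -> str:
    response = client.chat.completions.create(
        model="gpt-4o-mini",  # assumed model; the diff does not pin one
        messages=[
            {"role": "system", "content": settings.model.summaries_openai_system_prompt},
            {"role": "user", "content": fhir_resource_json},
        ],
    )
    return response.choices[0].message.content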
19 changes: 13 additions & 6 deletions app/config/settings.py
@@ -15,29 +15,36 @@ def __init__(self):
         # Base dir
         base_dir = os.path.dirname(os.path.abspath(__file__))
         # Embedding model
-        self.embedding_model_name = os.getenv(
-            "EMBEDDING_MODEL_NAME", "all-MiniLM-L6-v2")
+        self.embedding_model_name = os.getenv("EMBEDDING_MODEL_NAME", "all-MiniLM-L6-v2")
         # LLM host
         self.llm_host = os.getenv("LLAMA_HOST", "http://localhost:8090")
         # Conversation prompts
         self.conversation_model_prompt = self.load_prompt(
-            os.path.join(base_dir, "prompts/conversation_model_prompt_Phi-3.5-instruct.txt"))
+            os.path.join(base_dir, "prompts/conversation_model_prompt_Phi-3.5-instruct.txt")
+        )
         # Summaries prompts
         self.summaries_model_prompt = self.load_prompt(
-            os.path.join(base_dir, "prompts/summaries_model_prompt_Phi-3.5-instruct.txt"))
+            os.path.join(base_dir, "prompts/summaries_model_prompt_Phi-3.5-instruct.txt")
+        )
+        # Summaries openai system prompt
+        self.summaries_openai_system_prompt = self.load_prompt(
+            os.path.join(base_dir, "prompts/summaries_openai_system_prompt.txt")
+        )

     def load_prompt(self, file_path: str) -> str:
-        with open(file_path, 'r') as file:
-            return file.read().strip()
+        with open(file_path, "r") as file:
+            return file.read().strip().replace("\n", " ")


 class Settings:
     def __init__(self):
         self.elasticsearch = ElasticsearchSettings()
         self.model = ModelsSettings()


 settings = Settings()

 # Logging setup
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
+logging.getLogger("elastic_transport").setLevel(logging.WARNING)
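
Worth noting about the new load_prompt: prompts are now newline-flattened at load time, so the multi-line prompt files above arrive as single space-joined strings. A small illustration:

from app.config.settings import settings

# load_prompt() strips the file and replaces newlines with spaces,
# so every loaded prompt is a one-line string.
prompt = settings.model.summaries_openai_system_prompt
assert "\n" not in prompt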
23 changes: 5 additions & 18 deletions app/db/index_documents.py
@@ -1,7 +1,4 @@
-def bulk_load_fhir_data(data: list[dict],
-                        text_key: str,
-                        embedding_model,
-                        index_name):
+def bulk_load_fhir_data(data: list[dict], text_key: str, embedding_model, index_name):
     """
     Function to load in bulk mode a FHIR data
     """
@@ -10,11 +7,8 @@ def bulk_load_fhir_data(data: list[dict],
         resource_type = value.get("resource_type")
         resource = value.get(text_key)
         embedding = embedding_model.encode(resource)

-        metadata = {
-            "resource_id": resource_id,
-            "resource_type": resource_type
-        }
+        metadata = {"resource_id": resource_id, "resource_type": resource_type}

         if "tokens_evaluated" in value:
             metadata["tokens_evaluated"] = value["tokens_evaluated"]
@@ -24,12 +18,5 @@ def bulk_load_fhir_data(data: list[dict],
             metadata["prompt_ms"] = value["prompt_ms"]
         if "predicted_ms" in value:
             metadata["predicted_ms"] = value["predicted_ms"]

-        yield {
-            "_index": index_name,
-            "_source": {
-                "content": resource,
-                "embedding": embedding,
-                "metadata": metadata
-            }
-        }
+        yield {"_index": index_name, "_source": {"content": resource, "embedding": embedding, "metadata": metadata}}
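
Since bulk_load_fhir_data yields one action dict per document, its natural consumer is Elasticsearch's bulk helper. A minimal sketch; the data-loading step and the text_key value are assumptions, as the diff only defines the generator:

from elasticsearch.helpers import bulk

from app import embedding_model, es_client
from app.config.settings import settings
from app.db.index_documents import bulk_load_fhir_data

fhir_data = load_flattened_resources()  # hypothetical helper returning list[dict]
bulk(
    es_client,
    bulk_load_fhir_data(
        data=fhir_data,
        text_key="resource",  # assumed key; use whichever field holds the summary text
        embedding_model=embedding_model,
        index_name=settings.elasticsearch.index_name,
    ),
)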
Empty file added app/evaluation/__init__.py
108 changes: 108 additions & 0 deletions app/evaluation/retrieval/retrieval_metrics.py
@@ -0,0 +1,108 @@
import json
import random

from tqdm import tqdm

from app.services.search_documents import search_query


def evaluate_resources_summaries_retrieval(
    es_client,
    embedding_model,
    resource_chunk_counts: dict,
    qa_references: list[dict],
    search_text_boost: int = 1,
    search_embedding_boost: int = 1,
    k: int = 5,
) -> dict:
    # Initialize counters and sums for metrics
    total_questions = 0
    total_contexts_found = 0
    position_sum = 0
    reciprocal_rank_sum = 0
    precision_sum = 0
    recall_sum = 0

    # Iterate over the OpenAI responses
    for response in tqdm(qa_references, total=len(qa_references), desc="Calculating retrieval metrics"):
        # Get content and id of openai responses
        reference_resource_id = response["custom_id"]
        content = response["response"]["body"]["choices"][0]["message"]["content"]

        questions_and_answers = json.loads(content)["questions_and_answers"]

        if len(questions_and_answers) > 0:
            # Sample one random question per resource_id to evaluate
            questions_and_answers = [random.choice(questions_and_answers)]

            for qa in questions_and_answers:
                if isinstance(qa, dict) and "question" in qa:
                    question = qa["question"]
                    total_questions += 1

                    # Query question
                    search_results = search_query(
                        question,
                        embedding_model,
                        es_client,
                        k=k,
                        text_boost=search_text_boost,
                        embedding_boost=search_embedding_boost,
                    )

                    # Evaluate if any returned chunk belongs to the correct resource_id
                    found = False
                    rank = 0
                    retrieved_relevant_chunks = 0

                    # Get the total number of relevant chunks for this resource_id
                    relevant_chunks = resource_chunk_counts[reference_resource_id]

                    if search_results != {"detail": "Not Found"}:
                        for i, result in enumerate(search_results):
                            if result["metadata"]["resource_id"] == reference_resource_id:
                                if not found:
                                    total_contexts_found += 1
                                    rank = i + 1
                                    reciprocal_rank_sum += 1 / rank
                                    found = True
                                retrieved_relevant_chunks += 1
                    elif search_results == {"detail": "Not Found"}:
                        search_results = {}

                    # Calculate precision and recall for this specific question
                    precision = retrieved_relevant_chunks / len(search_results) if len(search_results) > 0 else 0
                    recall = retrieved_relevant_chunks / relevant_chunks if relevant_chunks > 0 else 0

                    precision_sum += precision
                    recall_sum += recall

                    if found:
                        position_sum += rank

    # Calculate final metrics
    retrieval_accuracy = round(total_contexts_found / total_questions, 3) if total_questions > 0 else 0
    average_position = round(position_sum / total_contexts_found, 3) if total_contexts_found > 0 else 0
    mrr = round(reciprocal_rank_sum / total_questions, 3) if total_questions > 0 else 0
    average_precision = round(precision_sum / total_questions, 3) if total_questions > 0 else 0
    average_recall = round(recall_sum / total_questions, 3) if total_questions > 0 else 0

    return {
        # The percentage of questions for which the system successfully retrieved at least one relevant chunk.
        "Retrieval Accuracy": retrieval_accuracy,
        "Average Position": average_position,
        "MRR": mrr,
        # Precision = Number of relevant chunks returned / Total number of chunks returned
        "Average Precision": average_precision,
        # Recall = Number of relevant chunks returned / Total number of relevant chunks that exist
        "Average Recall": average_recall,
        # Others
        "Total Questions": total_questions,
        "Total contexts found": total_contexts_found,
        "Total positions sum": position_sum,
    }
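
For reference, a minimal sketch of invoking this evaluator; the file paths and loading code are assumptions (the diff does not show how qa_references or resource_chunk_counts are produced), but qa_references matches the OpenAI batch-output shape the function indexes into:

import json

from app import embedding_model, es_client
from app.evaluation.retrieval.retrieval_metrics import evaluate_resources_summaries_retrieval

with open("qa_references.jsonl") as f:  # hypothetical path
    qa_references = [json.loads(line) for line in f]
with open("resource_chunk_counts.json") as f:  # hypothetical path
    resource_chunk_counts = json.load(f)

metrics = evaluate_resources_summaries_retrieval(
    es_client,
    embedding_model,
    resource_chunk_counts,
    qa_references,
    search_text_boost=1,
    search_embedding_boost=1,
    k=5,
)
print(metrics)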