-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathretrieval.py
61 lines (53 loc) · 2.25 KB
/
retrieval.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import torch
from sentence_transformers import util
import textwrap
class Retriever:
    """Query-time retrieval over a pre-built embedding index.

    Wraps an EmbeddingIndex (anything exposing ``embedding_model`` and
    ``collection`` attributes) and returns the stored chunks most similar
    to a free-text query.
    """

    def __init__(self, embedding_index, similarity_method="dot"):
        """
        embedding_index: an instance of EmbeddingIndex
            that has ``embedding_model`` & ``collection`` attributes.
        similarity_method: 'dot' or 'cosine' (kept for manual scoring;
            the vector store itself fixes the metric at index time).
        """
        self.embedding_model = embedding_index.embedding_model
        self.collection = embedding_index.collection
        self.similarity_method = similarity_method

    def retrieve_relevant_chunks(self, query: str, top_k: int = 3):
        """Return the ``top_k`` most similar (document, metadata) pairs.

        The similarity measure is whatever the collection was built with
        (e.g. Chroma decides dot vs. cosine at index time); this method
        only encodes the query and delegates the search.
        """
        # Embed the query text; encode() returns a batch, take the single row.
        query_vector = self.embedding_model.encode([query])[0]

        # Vector-store lookup — similarity scoring happens inside the store.
        hits = self.collection.query(
            query_embeddings=[query_vector.tolist()],
            n_results=top_k,
        )

        # Pair each returned document with its metadata, preserving rank order.
        return list(zip(hits["documents"][0], hits["metadatas"][0]))
def build_context(retrieved_docs: list[tuple[str, dict]]) -> str:
    """Combine retrieved (text, metadata) pairs into one context string.

    Each entry becomes a numbered block of the form
    ``[N] <text>\\n(Source: <source> - Chunk ID: <chunk_id>)``; blocks are
    separated by a blank line.
    """
    blocks = [
        f"[{rank}] {text}\n(Source: {meta['source']} - Chunk ID: {meta['chunk_id']})"
        for rank, (text, meta) in enumerate(retrieved_docs, start=1)
    ]
    return "\n\n".join(blocks)
def pretty_print_doc(doc_text: str, width=80):
    """Wrap a document's text for readable console display.

    Breaks ``doc_text`` into lines of at most ``width`` columns and joins
    them with newlines (equivalent to ``textwrap.fill``).
    """
    wrapped_lines = textwrap.wrap(doc_text, width=width)
    return "\n".join(wrapped_lines)
def score_chunks_manual(query_embedding: torch.Tensor, chunk_embeddings: torch.Tensor, method="dot"):
    """Manually score chunk embeddings against a query embedding (demo only).

    method: 'dot' uses raw dot products; any other value falls back to
    cosine similarity. Returns the 1-D score row for the single query.
    """
    # Pick the sentence-transformers scorer, then take row 0 (one query).
    scorer = util.dot_score if method == "dot" else util.cos_sim
    return scorer(query_embedding, chunk_embeddings)[0]