-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathembed_and_retrieve.py
107 lines (92 loc) · 4.77 KB
/
embed_and_retrieve.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import streamlit as st
from llama_index.core import SimpleDirectoryReader, VectorStoreIndex, ServiceContext
from llama_index.vector_stores.chroma import ChromaVectorStore
from llama_index.core.storage.storage_context import StorageContext
from llama_index.llms.openai import OpenAI
from llama_index.llms.huggingface_api import HuggingFaceInferenceAPI
from llama_index.embeddings.openai import OpenAIEmbedding
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
import chromadb
import openai
from huggingface_hub.utils import HfHubHTTPError
from openai import OpenAIError
import requests
import logging
from logging.handlers import RotatingFileHandler
# Shared application logger. basicConfig installs two handlers: one streaming
# to stderr and one rotating file capped at 5 MiB with a single backup, so the
# log on disk can never grow unbounded.
logger = logging.getLogger("chat_with_your_documents")
logging.basicConfig(format="%(asctime)s - %(name)s - %(levelname)s : %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
level=logging.INFO,
handlers=[logging.StreamHandler(), RotatingFileHandler(filename="chat_with_your_documents_app.log", maxBytes=5*1024*1024, backupCount=1)])
def get_logger():
    """Return the shared application logger configured at import time.

    ``logging.getLogger`` caches loggers by name, so this yields the exact
    same object as the module-level ``logger``.
    """
    return logging.getLogger("chat_with_your_documents")
def validate_api_key(provider, api_key):
    """Check whether *api_key* is accepted by the given provider's API.

    Parameters
    ----------
    provider : str
        Either ``"OpenAI"`` or ``"HuggingFace"``.
    api_key : str
        The key/token to validate. For HuggingFace an empty key is allowed
        (anonymous, rate-limited access).

    Returns
    -------
    bool
        True when the provider accepts the key (HTTP 200), False on a
        rejected key or any network failure. Falls through (None, falsy)
        for an unrecognized provider, matching the original behavior.
    """
    if provider == "OpenAI":
        try:
            # /v1/models is the current authenticated listing endpoint
            # (/v1/engines is deprecated); a 200 proves the key is valid.
            response = requests.get(
                "https://api.openai.com/v1/models",
                headers={"Authorization": f"Bearer {api_key}"},
                timeout=10,  # don't hang the Streamlit UI on a dead endpoint
            )
            return response.status_code == 200
        except requests.RequestException:
            # requests raises RequestException on network errors —
            # openai.OpenAIError (caught before) can never occur here.
            return False
    elif provider == "HuggingFace":
        # An empty token is acceptable: the HF API allows anonymous,
        # rate-limited access.
        if not api_key:
            return True
        try:
            response = requests.get(
                "https://huggingface.co/api/whoami-v2",
                headers={"Authorization": f"Bearer {api_key}"},
                timeout=10,
            )
            return response.status_code == 200
        except requests.RequestException:
            # Narrowed from a bare `except:`, which also swallowed
            # KeyboardInterrupt/SystemExit.
            return False
def create_query_engine(file_path, provider, api_key, download_llm=False):
    """Build a llama-index query engine over the document at *file_path*.

    Embeds the document with the chosen provider's embedding model, stores
    the vectors in an in-memory ChromaDB collection, and returns a query
    engine (streaming when the LLM supports it).

    Parameters
    ----------
    file_path : str
        Path to the document to index.
    provider : str
        ``"OpenAI"`` (gpt-4o-mini + text-embedding-3-small) or anything
        else for HuggingFace (Meta-Llama-3-8B-Instruct + all-MiniLM-L6-v2).
    api_key : str
        Provider API key/token.
    download_llm : bool
        Unused; kept for interface compatibility with existing callers.

    Returns
    -------
    The query engine, or None when the provider API rejects/rate-limits
    the requests (an error is also surfaced in the Streamlit UI).
    """
    # Set up the embedding and inference models.
    if provider == "OpenAI":
        embed_model = OpenAIEmbedding(model="text-embedding-3-small", api_key=api_key)
        llm = OpenAI(model="gpt-4o-mini", api_key=api_key)
    else:
        embed_model = HuggingFaceEmbedding(model_name="sentence-transformers/all-MiniLM-L6-v2", token=api_key)
        llm = HuggingFaceInferenceAPI(model_name="meta-llama/Meta-Llama-3-8B-Instruct", token=api_key)
    documents = SimpleDirectoryReader(input_files=[file_path]).load_data()
    # Set up the vector store (ChromaDB). The collection name is fixed, so a
    # leftover collection from a previous run is dropped and recreated.
    chroma_client = chromadb.Client()
    collection_name = "document_collection"
    try:
        chroma_collection = chroma_client.create_collection(collection_name)
    except (ValueError, chromadb.db.base.UniqueConstraintError, chromadb.errors.ChromaError) as e:
        logger.warning(f"{e}: Collection already exists. Deleting and creating a new collection.")
        chroma_client.delete_collection(collection_name)
        chroma_collection = chroma_client.create_collection(collection_name)
    vector_store = ChromaVectorStore(chroma_collection=chroma_collection)
    storage_context = StorageContext.from_defaults(vector_store=vector_store)
    # Create vector store index.
    service_context = ServiceContext.from_defaults(llm=llm, embed_model=embed_model)
    index = VectorStoreIndex.from_documents(documents, storage_context=storage_context, service_context=service_context)
    # Create and return the query engine. Probe streaming support with a
    # throwaway query; fall back to a non-streaming engine when the LLM
    # raises NotImplementedError. FIX: the provider error handlers now wrap
    # BOTH attempts — previously a rate-limit error raised during the
    # streaming probe escaped uncaught and crashed the app.
    try:
        try:
            query_engine = index.as_query_engine(streaming=True, service_context=service_context)
            response = query_engine.query("What is the document about?")
            # Pull a single chunk to force the streaming path to actually run.
            for text in response.response_gen:
                if text is not None:
                    break
        except NotImplementedError as e:
            logger.warning(f"{e}: Streaming not supported. Creating query engine without streaming.")
            query_engine = index.as_query_engine(service_context=service_context)
            response = query_engine.query("What is the document about?")
    except HfHubHTTPError as e:
        logger.error(f"{e}: Most likely the rate limits of the HuggingFace API have been exceeded.")
        st.error(f"""{e}. Unable to create query engine from HuggingFace.
        Please try after sometime. If the error persists, get a new API key
        at https://hf.co/settings/tokens and try again.""")
        return None
    except OpenAIError as e:
        logger.error(f"{e}: Most likely the rate limits of the OpenAI API have been exceeded.")
        st.error(f"""{e}. Unable to create query engine from OpenAI. Please get a new API key at
        https://platform.openai.com/api-keys or https://platform.openai.com/settings/profile?tab=api-keys
        or try after sometime.""")
        return None
    return query_engine