-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathRAGmodel.py
81 lines (69 loc) · 2.5 KB
/
RAGmodel.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware # Import CORSMiddleware
from pydantic import BaseModel
from typing import List
import os
from pinecone import Pinecone
from langchain_community.llms import OpenAI
from langchain_community.vectorstores import Pinecone as PineconeVectorStore
from langchain_openai import OpenAIEmbeddings # Updated import
from langchain.chains import RetrievalQA
# Initialize FastAPI app
app = FastAPI()

# Configure CORS middleware -> this allows api to be called with react frontend
# NOTE(review): allow_origins=["*"] combined with allow_credentials=True is
# rejected by browsers for credentialed requests — restrict origins in production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # For development, allow all origins. In production, restrict this.
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# set/load environment variables -> edit this for production
# Initialize Pinecone client (reads PINECONE_API_KEY from the environment;
# getenv returns None when unset, which surfaces as an auth error at call time).
pc = Pinecone(api_key=os.getenv("PINECONE_API_KEY"))

# Set up components
INDEX_NAME = "alloraproduction"  # name of the pre-existing Pinecone index
embeddings = OpenAIEmbeddings(openai_api_key=os.getenv("OPENAI_API_KEY"),
                              model="text-embedding-3-large")

# Connect to existing Pinecone index
try:
    # Get Pinecone index object
    index = pc.Index(INDEX_NAME)
    # Create vector store directly with index
    vectorstore = PineconeVectorStore(
        index=index,
        embedding=embeddings,
        text_key="text"  # Match your index's text field key
    )
except Exception as e:
    # Fail fast at import time: the API is unusable without the vector store.
    raise RuntimeError(f"Error connecting to Pinecone index: {str(e)}")

# Create retrieval QA chain. chain_type="stuff" injects the retrieved documents
# directly into the LLM prompt; return_source_documents=True lets the /chat
# endpoint report which documents backed each answer.
qa = RetrievalQA.from_chain_type(
    llm=OpenAI(temperature=0, openai_api_key=os.getenv("OPENAI_API_KEY")),
    chain_type="stuff",
    retriever=vectorstore.as_retriever(),
    return_source_documents=True
)
# Request/Response models
class ChatRequest(BaseModel):
    """Request body for POST /chat."""
    # The user's chat message / question to run through the RAG pipeline.
    message: str
class ChatResponse(BaseModel):
    """Response body for POST /chat."""
    # Generated answer text from the QA chain.
    response: str
    # Deduplicated source identifiers of the documents backing the answer.
    sources: List[str]
@app.get("/")
async def root():
    """Health-check endpoint.

    Returns:
        A small JSON object confirming the service is running.
    """
    # Bug fix: the original `return {"Ok"}` was a *set* literal, which FastAPI
    # serializes as the JSON array ["Ok"]. Return a proper JSON object instead.
    return {"status": "Ok"}
@app.post("/chat", response_model=ChatResponse)
async def chat_endpoint(request: ChatRequest):
    """Answer a chat message via the retrieval-QA chain.

    Runs the user's message through the RAG pipeline and returns the
    generated answer together with the deduplicated list of source
    identifiers. Any failure in the pipeline is surfaced as HTTP 500.
    """
    try:
        outcome = qa.invoke({"query": request.message})
        # Deduplicate source identifiers with a set comprehension;
        # documents lacking a "source" key contribute an empty string.
        unique_sources = {
            doc.metadata.get("source", "")
            for doc in outcome["source_documents"]
        }
        return {
            "response": outcome["result"],
            "sources": list(unique_sources),
        }
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc))
if __name__ == "__main__":
    # Launch the ASGI server directly when this file is executed as a script.
    import uvicorn

    listen_port = int(os.getenv("PORT", 8000))  # may need to edit this for production
    uvicorn.run(app, host="0.0.0.0", port=listen_port)