-
Notifications
You must be signed in to change notification settings - Fork 15
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
4 changed files
with
316 additions
and
10 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,81 @@ | ||
from fastapi import FastAPI, HTTPException | ||
from fastapi.middleware.cors import CORSMiddleware # Import CORSMiddleware | ||
from pydantic import BaseModel | ||
from typing import List | ||
import os | ||
from pinecone import Pinecone | ||
from langchain_community.llms import OpenAI | ||
from langchain_community.vectorstores import Pinecone as PineconeVectorStore | ||
from langchain_openai import OpenAIEmbeddings # Updated import | ||
from langchain.chains import RetrievalQA | ||
|
||
# Initialize FastAPI app | ||
app = FastAPI() | ||
|
||
# Configure CORS middleware -> this allows api to be called with react frontend | ||
app.add_middleware( | ||
CORSMiddleware, | ||
allow_origins=["*"], # For development, allow all origins. In production, restrict this. | ||
allow_credentials=True, | ||
allow_methods=["*"], | ||
allow_headers=["*"], | ||
) | ||
|
||
# set/load environment variables -> edit this for production | ||
os.environ["OPENAI_API_KEY"] = "" | ||
os.environ["PINECONE_API_KEY"] = "" | ||
|
||
|
||
# Initialize Pinecone client | ||
pc = Pinecone(api_key=os.getenv("PINECONE_API_KEY")) | ||
|
||
# Set up components | ||
INDEX_NAME = "alloraproduction" | ||
embeddings = OpenAIEmbeddings(openai_api_key=os.getenv("OPENAI_API_KEY"), | ||
model="text-embedding-3-large") | ||
|
||
# Connect to existing Pinecone index | ||
try: | ||
# Get Pinecone index object | ||
index = pc.Index(INDEX_NAME) | ||
|
||
# Create vector store directly with index | ||
vectorstore = PineconeVectorStore( | ||
index=index, | ||
embedding=embeddings, | ||
text_key="text" # Match your index's text field key | ||
) | ||
except Exception as e: | ||
raise RuntimeError(f"Error connecting to Pinecone index: {str(e)}") | ||
|
||
# Create retrieval QA chain | ||
qa = RetrievalQA.from_chain_type( | ||
llm=OpenAI(temperature=0, openai_api_key=os.getenv("OPENAI_API_KEY")), | ||
chain_type="stuff", | ||
retriever=vectorstore.as_retriever(), | ||
return_source_documents=True | ||
) | ||
|
||
# Request/Response models | ||
class ChatRequest(BaseModel): | ||
message: str | ||
|
||
class ChatResponse(BaseModel): | ||
response: str | ||
sources: List[str] | ||
|
||
@app.post("/chat", response_model=ChatResponse) | ||
async def chat_endpoint(request: ChatRequest): | ||
try: | ||
result = qa.invoke({"query": request.message}) | ||
sources = list(set([doc.metadata.get("source", "") for doc in result["source_documents"]])) | ||
return { | ||
"response": result["result"], | ||
"sources": sources | ||
} | ||
except Exception as e: | ||
raise HTTPException(status_code=500, detail=str(e)) | ||
|
||
if __name__ == "__main__": | ||
import uvicorn | ||
uvicorn.run(app, host="0.0.0.0", port=8000) # may need to edit this for production |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
|
||
import React, { useState } from "react"; | ||
import ChatComponent from "./chatbutton1"; // Adjust the path as needed | ||
|
||
function AiButton() { | ||
// State to control whether the ChatComponent is displayed | ||
const [showChat, setShowChat] = useState(false); | ||
|
||
// Toggle function to open or close the chat | ||
const toggleChat = () => { | ||
setShowChat((prev) => !prev); | ||
}; | ||
|
||
return ( | ||
<div style={{ maxWidth: "800px", margin: "0 auto", padding: "20px" }}> | ||
{/* Render the "Ask AI" button if the chat is not shown */} | ||
{!showChat && ( | ||
<button | ||
onClick={toggleChat} | ||
style={{ | ||
padding: "10px 20px", | ||
fontSize: "16px", | ||
cursor: "pointer", | ||
marginBottom: "20px", | ||
backgroundColor: "#007bff", // Blue background | ||
color: "#fff", // White text | ||
border: "none", // Remove default border | ||
borderRadius: "5px", // Rounded corners | ||
transition: "background-color 0.3s ease", // Smooth hover effect | ||
}} | ||
onMouseOver={(e) => (e.target.style.backgroundColor = "#0056b3")} // Darker blue on hover | ||
onMouseOut={(e) => (e.target.style.backgroundColor = "#007bff")} // Revert on mouse out | ||
> | ||
Ask AI | ||
</button> | ||
)} | ||
{/* Render the ChatComponent when showChat is true, passing the onClose prop */} | ||
{showChat && <ChatComponent onClose={toggleChat} />} | ||
</div> | ||
); | ||
} | ||
|
||
export default AiButton; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,172 @@ | ||
import React, { useState, useRef, useEffect } from "react"; | ||
|
||
function ChatComponent({ onClose }) { | ||
// holds the current user input and the chat history. | ||
const [inputMessage, setInputMessage] = useState(""); | ||
const [chatHistory, setChatHistory] = useState([]); | ||
|
||
// this references the chat history container. | ||
const chatContainerRef = useRef(null); | ||
|
||
// gives an auto-scroll effect to the bottom whenever the chat history changes. | ||
useEffect(() => { | ||
if (chatContainerRef.current) { | ||
chatContainerRef.current.scrollTop = chatContainerRef.current.scrollHeight; | ||
} | ||
}, [chatHistory]); | ||
|
||
// this is handler for form submission. | ||
const handleSubmit = async (e) => { | ||
e.preventDefault(); | ||
|
||
// Add user's message to the chat history. | ||
const newUserEntry = { sender: "user", text: inputMessage }; | ||
setChatHistory((prev) => [...prev, newUserEntry]); | ||
|
||
try { | ||
// Send user's message to the FastAPI backend. | ||
const response = await fetch("http://localhost:8000/chat", { // Update the URL during production | ||
method: "POST", | ||
headers: { | ||
"Content-Type": "application/json", | ||
}, | ||
body: JSON.stringify({ message: inputMessage }), | ||
}); | ||
console.log("API went through"); | ||
|
||
if (!response.ok) { | ||
throw new Error(`Server error: ${response.statusText}`); | ||
} | ||
|
||
// Parse the JSON response. | ||
const data = await response.json(); | ||
|
||
|
||
// Add the assistant's response to the chat history. | ||
const newBotEntry = { | ||
sender: "bot", | ||
text: data.response, | ||
sources: data.sources, | ||
}; | ||
setChatHistory((prev) => [...prev, newBotEntry]); | ||
} catch (error) { | ||
console.error("Error fetching chat response:", error); | ||
// display an error message in the UI. | ||
const errorEntry = { | ||
sender: "bot", text: "Sorry, something went wrong." | ||
}; | ||
setChatHistory((prev) => [...prev, errorEntry]); | ||
} | ||
|
||
// Clear the input field. | ||
setInputMessage(""); | ||
}; | ||
|
||
return ( | ||
<div | ||
className="chat-container" | ||
style={{ | ||
maxWidth: "600px", | ||
margin: "0 auto", | ||
backgroundColor: "#000", | ||
color: "#fff", // Set text color to white for visibility | ||
padding: "20px", | ||
borderRadius: "10px", // Rounded corners nicer UI | ||
}} | ||
> | ||
{/* Header with title and close button */} | ||
<div | ||
style={{ | ||
display: "flex", | ||
justifyContent: "space-between", | ||
alignItems: "center", | ||
marginBottom: "10px", | ||
}} | ||
> | ||
<h2 style={{ color: "#fff", margin: 0 }}>Chat with our AI</h2> | ||
<button | ||
onClick={onClose} | ||
style={{ | ||
background: "transparent", | ||
border: "none", | ||
color: "#fff", | ||
cursor: "pointer", | ||
fontSize: "16px", | ||
}} | ||
aria-label="Close Chat" | ||
> | ||
❌ | ||
</button> | ||
</div> | ||
|
||
<div | ||
className="chat-history" | ||
ref={chatContainerRef} | ||
style={{ | ||
border: "1px solid #ccc", | ||
padding: "10px", | ||
height: "300px", // Fixed height | ||
overflowY: "scroll", // Enable vertical scrolling | ||
backgroundColor: "#1e1e1e", // Darker background for chat history | ||
}} | ||
> | ||
{chatHistory.map((entry, index) => ( | ||
<div | ||
key={index} | ||
style={{ | ||
textAlign: entry.sender === "user" ? "right" : "left", | ||
margin: "10px 0", | ||
}} | ||
> | ||
<div | ||
style={{ | ||
display: "inline-block", | ||
background: entry.sender === "user" ? "#4caf50" : "#333", | ||
color: "#fff", | ||
padding: "10px", | ||
borderRadius: "10px", | ||
}} | ||
> | ||
<p style={{ margin: 0 }}>{entry.text}</p> | ||
{entry.sources && entry.sources.length > 0 } | ||
</div> | ||
</div> | ||
))} | ||
</div> | ||
<form onSubmit={handleSubmit} style={{ marginTop: "10px" }}> | ||
<input | ||
type="text" | ||
value={inputMessage} | ||
onChange={(e) => setInputMessage(e.target.value)} | ||
placeholder="Type your message..." | ||
required | ||
style={{ | ||
width: "80%", | ||
padding: "10px", | ||
backgroundColor: "#333", | ||
color: "#fff", | ||
border: "1px solid #555", | ||
borderRadius: "5px", | ||
}} | ||
/> | ||
<button | ||
type="submit" | ||
style={{ | ||
width: "18%", | ||
padding: "10px", | ||
marginLeft: "2%", | ||
backgroundColor: "#4caf50", | ||
color: "#fff", | ||
border: "none", | ||
borderRadius: "5px", | ||
cursor: "pointer", | ||
}} | ||
> | ||
Send | ||
</button> | ||
</form> | ||
</div> | ||
); | ||
} | ||
|
||
export default ChatComponent; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters