-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
69 lines (42 loc) · 2.03 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import os

from dotenv import load_dotenv

from modules.loader import Loader
from modules.models import (
    Router,
    DocGrader,
    Generator,
    HallucinationGrader,
    AnswerGrader,
    QuestionRewriter,
)
from modules.nodes import Nodes
from modules.graph import Graph
# Read variables from a local .env file into os.environ before the loader and
# model objects below are created (they may read configuration from there).
load_dotenv()
loader = Loader()


class NodeHelpers:
    """Bundle of the retriever, models and citation table handed to ``Nodes``.

    Every attribute is a class attribute, so each one is built exactly once,
    when this class body executes at import time, and the same objects are
    shared by every ``NodeHelpers`` instance.
    """

    # Retriever obtained from the module-level ``Loader`` instance.
    retriever = loader.get_retriever()

    # Model objects produced by the ``get_model()`` factories in
    # ``modules.models``; see that module for their exact types/contracts.
    question_router = Router.get_model()
    retrieval_grader = DocGrader.get_model()
    rag_chain = Generator.get_model()
    hallucination_grader = HallucinationGrader.get_model()
    answer_grader = AnswerGrader.get_model()
    question_rewriter = QuestionRewriter.get_model()

    # Maps a document's source path to the public URL of the lecture it was
    # taken from. Keys are built with os.path.join so they use the running
    # platform's separator: on Windows they are byte-identical to the old
    # hard-coded "data\\..." keys, while on POSIX they become "data/...",
    # where the old keys could never match. NOTE(review): assumes the loader
    # reports sources as "data<sep><file name>" -- confirm in modules.loader.
    cite_mapper = {
        os.path.join("data", "L1-Introduction.pdf.md"):
            "https://stanford-cs324.github.io/winter2022/lectures/introduction/",
        os.path.join("data", "L2-CAPABILITIES.pdf.md"):
            "https://stanford-cs324.github.io/winter2022/lectures/capabilities/",
        os.path.join("data", "L3-Harms.pdf.md"):
            "https://stanford-cs324.github.io/winter2022/lectures/harms-1/",
        os.path.join("data", "L4-Harms-II.pdf.md"):
            "https://stanford-cs324.github.io/winter2022/lectures/harms-2/",
    }
nodes = Nodes(NodeHelpers())
app = Graph.create(nodes)

# Interactive session: read a question, stream it through the graph, then
# print the final answer and the metadata of the supporting documents.
# Typing "exit" (or sending EOF / Ctrl-C at the prompt) ends the session.
# Running transcript passed to the graph on every turn; the first element is
# a header line, agent replies are appended after it.
message_history = ["the messages till now are given below: \n\n"]
while True:
    try:
        inp = input("Enter question: ")
    except (EOFError, KeyboardInterrupt):
        # End of input / interrupt at the prompt: leave cleanly, like "exit".
        print()
        break
    # Check before invoking the graph: the old `while inp != "exit"` form only
    # tested at the top of the next iteration, so "exit" itself was still sent
    # through the whole pipeline as a question.
    if inp == "exit":
        break

    inputs = {"question": inp, "message_history": message_history}

    # Reset per turn so an empty stream cannot reuse the previous answer.
    value = None
    for output in app.stream(inputs):
        # Each streamed item maps a node name to that node's output; report
        # which node ran together with the counters kept on ``nodes``.
        for key, value in output.items():
            print(
                f"Node '{key}': \n"
                f"REGENERATED : {nodes.REGENERATION_COUNT}\n"
                f"RETREVIAL : {nodes.RERETREVIAL_COUNT}"
            )
        print("\n---\n")

    if value is None:
        # No node output to report for this question.
        continue

    # Final generation: the output of the last node that ran.
    print(value["generation"])

    # Record the agent's reply, numbered by its position in the history
    # (counted before the append, header line included).
    history = value["message_history"]
    history.append(f"msg number [{len(history)}] agent : {value['generation']}")
    message_history = history

    for doc in value["documents"]:
        # doc.metadata["source"] could be mapped to a lecture URL through
        # NodeHelpers.cite_mapper if citations are wanted here.
        print(doc.metadata)