-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathapp_api.py
executable file
Β·173 lines (136 loc) Β· 4.54 KB
/
app_api.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
import os
from fastapi.responses import FileResponse
from libs.find_sources import get_query_sources
from libs.common import project_path, load_project_env
from fastapi import FastAPI, HTTPException, Header
from pydantic import BaseModel
from fastapi.middleware.cors import CORSMiddleware
import libs.config as config
from graphrag.cli.query import run_local_search, run_global_search, run_drift_search
from dotenv import load_dotenv
load_dotenv()
app = FastAPI(
title="GraphRAG WebUI API",
version=config.app_version,
terms_of_service="https://github.com/TheodoreNiu/graphrag_webui",
license_info={
"name": "MIT",
"url": "https://opensource.org/licenses/MIT",
},
)
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
class Item(BaseModel):
query: str
project_name: str
community_level: int = 2
dynamic_community_selection: bool = False
query_source: bool = False
user_cache: bool = False
context_data: bool = False
local_search_cache = {}
local_search_cache_limit = 20
def get_local_search_cache(item: Item):
if not item.user_cache:
return None
if item.query in local_search_cache:
return local_search_cache[item.query]
return None
def set_local_search_cache(item: Item, result: any):
if not item.user_cache:
return
if len(local_search_cache) >= local_search_cache_limit:
local_search_cache.pop(list(local_search_cache.keys())[0])
local_search_cache[item.query] = result
def check_api_key(project_name: str, api_key: str):
load_project_env(project_name)
if os.getenv("API_KEY") and os.getenv("API_KEY") != api_key:
raise Exception("Invalid api-key")
# -----------------------------------------------------------------
@app.post("/api/local_search")
def local_search(item: Item, api_key: str = Header(...)):
try:
check_api_key(item.project_name, api_key)
cached_result = get_local_search_cache(item)
if cached_result:
return cached_result
(response, context_data) = run_local_search(
root_dir=project_path(item.project_name),
query=item.query,
community_level=int(item.community_level),
response_type="Multiple Paragraphs",
streaming=False,
config_filepath=None,
data_dir=None,
)
result = {
"message": "ok",
"response": response,
"query": item.query,
}
if item.query_source:
result["sources"] = get_query_sources(item.project_name, context_data)
if item.context_data:
result["context_data"] = context_data
set_local_search_cache(item, result)
return result
except Exception as e:
return {
"error": str(e),
}
# -----------------------------------------------------------------
@app.post("/api/global_search")
def global_search(item: Item, api_key: str = Header(...)):
try:
check_api_key(item.project_name, api_key)
(response, context_data) = run_global_search(
root_dir=project_path(item.project_name),
query=item.query,
community_level=int(item.community_level),
response_type="Multiple Paragraphs",
dynamic_community_selection=bool(item.dynamic_community_selection),
streaming=False,
config_filepath=None,
data_dir=None,
)
result = {
"message": "ok",
"response": response,
"query": item.query,
}
if item.context_data:
result["context_data"] = context_data
return result
except Exception as e:
return {
"error": str(e),
}
@app.post("/api/drift_search")
def drift_search(item: Item, api_key: str = Header(...)):
try:
check_api_key(item.project_name, api_key)
(response, context_data) = run_drift_search(
root_dir=project_path(item.project_name),
query=item.query,
community_level=int(item.community_level),
streaming=False,
config_filepath=None,
data_dir=None,
)
result = {
"message": "ok",
"response": response,
"query": item.query,
}
if item.context_data:
result["context_data"] = context_data
return result
except Exception as e:
return {
"error": str(e),
}