Update layout and tips
SWHL committed Aug 11, 2023
1 parent ed92a00 commit a28d299
Showing 5 changed files with 67 additions and 61 deletions.
14 changes: 10 additions & 4 deletions README.md
@@ -1,6 +1,6 @@
English | [简体中文](https://github.com/RapidAI/Knowledge-QA-LLM/blob/main/docs/README_zh.md)
[简体中文](https://github.com/RapidAI/Knowledge-QA-LLM/blob/main/docs/README_zh.md) | English

## Knowledge QA LLM
## 🧐 Knowledge QA LLM
<p>
<a href=""><img src="https://img.shields.io/badge/Python->=3.8,<3.12-aff.svg"></a>
<a href=""><img src="https://img.shields.io/badge/OS-Linux%2C%20Win%2C%20Mac-pink.svg"></a>
@@ -74,8 +74,9 @@ English | [简体中文](https://github.com/RapidAI/Knowledge-QA-LLM/blob/main/d
│ └── raw_upload_files
├── knowledge_qa_llm
│ ├── __init__.py
│ ├── config.yaml # configuration file
│ ├── config.yaml # configuration file
│ ├── file_loader # Handle documents in various formats
│ ├── encoder # Extract embeddings
│   ├── llm              # LLM interface; the model is deployed separately and called via its API
│ ├── utils
│ └── vector_utils # embedding access and search
@@ -87,7 +88,12 @@ English | [简体中文](https://github.com/RapidAI/Knowledge-QA-LLM/blob/main/d
└── webui.py # UI implementation based on streamlit
```

#### Update Log
#### Change Log
- 2023-08-11 v0.0.7 update:
  - Optimize the layout: remove the plugin option and move the embedding model selector to the home page.
  - Translate the UI tips into English for easier communication.
  - Add the project logo: 🧐
  - Update the CLI module code.
- 2023-08-05 v0.0.6 update:
  - Adapt more llm_api implementations, including online LLM APIs such as ERNIE-Bot-Turbo.
  - Add a status indicator for embedding extraction.
17 changes: 10 additions & 7 deletions cli.py
@@ -1,10 +1,11 @@
# -*- encoding: utf-8 -*-
# @Author: SWHL
# @Contact: [email protected]
from knowledge_qa_llm.encoder import EncodeText
from knowledge_qa_llm.file_loader import FileLoader
from knowledge_qa_llm.llm import ChatGLM26B
from knowledge_qa_llm.llm import Qwen7B_Chat
from knowledge_qa_llm.utils import make_prompt, read_yaml
from knowledge_qa_llm.vector_utils import DBUtils, EncodeText
from knowledge_qa_llm.vector_utils import DBUtils

config = read_yaml("knowledge_qa_llm/config.yaml")

@@ -16,18 +17,20 @@
# sentences = text[0][1]

# 提取特征
embedding_model = EncodeText(config.get("encoder_model_path"))
embedding_model = EncodeText(config.get("Encoder")["m3e-small"])
# embeddings = embedding_model(sentences)

# 插入数据到数据库中
db_tools = DBUtils(config.get("vector_db_path"))
# db_tools.insert(file_path, embeddings, sentences)

llm_engine = ChatGLM26B(api_url=config.get("llm_api_url"))
llm_engine = Qwen7B_Chat(api_url=config.get("LLM_API")["Qwen7B_Chat"])

print("欢迎使用 Knowledge QA LLM,输入内容即可进行对话,stop 终止程序")
print(
"Welcom to 🧐 Knowledge QA LLM,enter the content to start the conversation, enter stop to terminate the program."
)
while True:
query = input("\n用户:")
query = input("\n😀 User:")
if query.strip() == "stop":
break

@@ -37,4 +40,4 @@

prompt = make_prompt(query, context, custom_prompt=config.get("DEFAULT_PROMPT"))
response = llm_engine(prompt, history=None)
print(response)
print(f"🤖 LLM:{response}")
10 changes: 8 additions & 2 deletions docs/README_zh.md
@@ -1,6 +1,6 @@
[English](https://github.com/RapidAI/Knowledge-QA-LLM) | 简体中文
简体中文 | [English](https://github.com/RapidAI/Knowledge-QA-LLM)

## Knowledge QA LLM
## 🧐 Knowledge QA LLM
<p>
<a href=""><img src="https://img.shields.io/badge/Python->=3.8,<3.12-aff.svg"></a>
<a href=""><img src="https://img.shields.io/badge/OS-Linux%2C%20Win%2C%20Mac-pink.svg"></a>
@@ -79,6 +79,7 @@
│   ├── __init__.py
│   ├── config.yaml # 配置文件
│   ├── file_loader # 处理各种格式的文档
│   ├── encoder # 提取特征向量
│   ├── llm # 大模型接口,大模型需要单独部署,以接口方式调用
│   ├── utils
│   └── vector_utils # embedding的存取和搜索
@@ -91,6 +92,11 @@
```

#### 更新日志
- 2023-08-11 v0.0.7 update:
- 优化布局,去掉插件选项,将提取向量模型选项放到主页部分
- 将提示语英语化,便于交流使用。
- 添加项目logo: 🧐
- 更新CLI使用代码
- 2023-08-05 v0.0.6 update:
- 适配更多模型接口,包括在线大模型接口,例如文心一言
- 添加提取特征向量的状态提示
10 changes: 5 additions & 5 deletions knowledge_qa_llm/config.yaml
@@ -1,5 +1,5 @@
title: Knowledge QA LLM
version: 0.0.6
title: 🧐 Knowledge QA LLM
version: 0.0.7

LLM_API:
Qwen7B_Chat: your_api
@@ -34,16 +34,16 @@ Parameter:
max_value: 4096
default: 1024
step: 1
tip: 输入input_ids的最大长度
tip: The maximum length of input_ids
top_p:
min_value: 0.0
max_value: 1.0
default: 0.7
step: 0.01
tip: 限制模型为仅考虑最可能的前p个标记
tip: Limit the model to consider only the most likely top-p tokens.
temperature:
min_value: 0.01
max_value: 1.0
default: 0.01
step: 0.01
tip: 控制模型输出的随机性,温度越低将导致输出更加可预测和重复,越高将更富创意和自发的输出。
tip: Control the randomness of the model output. Lower values make the output more predictable and repetitive; higher values make it more creative.
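
Each entry under `Parameter` bundles slider bounds, a default, a step, and a `tip` used as help text. The webui.py diff below reads these entries (`param.get("max_length")`, ...) and stores the chosen values in `st.session_state["params"]`, which are later forwarded to the model call as keyword arguments. The exact widget calls are collapsed in that diff, so the `st.sidebar.slider` usage below is an assumption; the config keys and the session-state pattern come from the hunks themselves:

```python
# Sketch: turning one "Parameter" entry from config.yaml into a sidebar slider
# and stashing the value for the later model call. Illustrative only.
import streamlit as st

from knowledge_qa_llm.utils import read_yaml

config = read_yaml("knowledge_qa_llm/config.yaml")
param = config.get("Parameter")

if "params" not in st.session_state:
    st.session_state["params"] = {}

top_p_cfg = param.get("top_p")
top_p = st.sidebar.slider(
    "top_p",
    min_value=top_p_cfg.get("min_value"),
    max_value=top_p_cfg.get("max_value"),
    value=top_p_cfg.get("default"),
    step=top_p_cfg.get("step"),
    help=top_p_cfg.get("tip"),
)
st.session_state["params"]["top_p"] = top_p  # later unpacked as **params_dict into model(...)
```

Because the `tip` strings surface directly in the UI as help text, this is presumably why the commit translates them into English.
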
77 changes: 34 additions & 43 deletions webui.py
@@ -23,7 +23,7 @@


def init_sidebar():
st.sidebar.markdown("### 🛶 参数设置")
st.sidebar.markdown("### 🛶 Parameter Settings")
param = config.get("Parameter")

param_max_length = param.get("max_length")
@@ -61,42 +61,35 @@ def init_sidebar():
)
st.session_state["params"]["temperature"] = temperature

st.sidebar.markdown("### 🧻 知识库")
st.sidebar.markdown("### 🧻 Knowledge Base")
uploaded_files = st.sidebar.file_uploader(
"default",
accept_multiple_files=True,
label_visibility="hidden",
help="支持多选",
help="Support for multiple selections",
)

upload_dir = config.get("upload_dir")

ENCODER_OPTIONS = config.get("Encoder")
select_encoder = st.sidebar.selectbox("🧬提取向量模型:", ENCODER_OPTIONS.keys())
tips(f"初始化{select_encoder}...")
embedding_extract = init_encoder(ENCODER_OPTIONS[select_encoder])
tips("初始化完成!")

btn_upload = st.sidebar.button("上传文档并加载数据库", use_container_width=True)
btn_upload = st.sidebar.button("Upload and load database", use_container_width=True)
if btn_upload:
time_stamp = get_timestamp()
save_dir = Path(upload_dir) / time_stamp
st.session_state["upload_dir"] = save_dir

tips("正在上传文件到平台...", icon="⏳")
tips("Uploading files to platform...", icon="⏳")
for file in uploaded_files:
bytes_data = file.getvalue()

mkdir(save_dir)
save_path = save_dir / file.name
with open(save_path, "wb") as f:
f.write(bytes_data)
tips("上传完毕!")
tips("Upload completed!")

doc_dir = st.session_state["upload_dir"]
all_doc_contents = file_loader(doc_dir)

pro_text = "正在提取特征向量..."
pro_text = "Extracting embeddings..."
batch_size = config.get("encoder_batch_size", 32)
for file_path, one_doc_contents in all_doc_contents.items():
my_bar = st.sidebar.progress(0, text=pro_text)
@@ -111,23 +104,21 @@ def init_sidebar():

my_bar.progress(
end_idx / content_nums,
f"提取{file_path}数据: [{end_idx}/{content_nums}]",
f"Extract {file_path} datas: [{end_idx}/{content_nums}]",
)
my_bar.empty()
all_embeddings = np.vstack(all_embeddings)
db_tools.insert(file_path, all_embeddings, one_doc_contents)
my_bar.empty()

shutil.rmtree(doc_dir.resolve())
tips("已经加载并存入数据库中,可以提问了!")
tips("You can now ask a question!")

had_files = db_tools.get_files()
if had_files:
st.sidebar.markdown("仓库已有文档:")
st.sidebar.markdown("Existing documents:")
st.sidebar.markdown("\n".join([f" - {v}" for v in had_files]))

return embedding_extract


def init_state():
if "history" not in st.session_state:
@@ -153,26 +144,26 @@ def predict(
logger.info(f"Using {type(model).__name__}")

query_embedding = embedding_extract(text)
with st.spinner("从文档中搜索相关内容"):
with st.spinner("Search for relevant contents from docs..."):
search_res, search_elapse = db_tools.search_local(
query_embedding, top_k=config.get("top_k")
)
if search_res is None:
bot_print("从文档中搜索相关内容为空,暂不能回答该问题")
bot_print("The results of searching from docs is empty.")
else:
context = "\n".join(sum(search_res.values(), []))
res_cxt = f"**从文档中检索到的相关内容Top5\n(相关性从高到低,耗时:{search_elapse:.5f}s):** \n"
res_cxt = f"**Find Top{search_top}\n(Scores from high to low,cost:{search_elapse:.5f}s):** \n"
bot_print(res_cxt)

for file, content in search_res.items():
content = "\n".join(content)
one_context = f"**来自文档:《{file}》** \n{content}"
one_context = f"**From:《{file}》** \n{content}"
bot_print(one_context)

logger.info(f"上下文\n{one_context}\n")
logger.info(f"Context\n{one_context}\n")

response, elapse = get_model_response(text, context, custom_prompt, model)
print_res = f"**使用模型{select_model}**\n**模型推理耗时{elapse:.5f}s**"
print_res = f"**Use{select_model}**\n**Infer model cost{elapse:.5f}s**"
bot_print(print_res)
bot_print(response)

@@ -193,14 +184,14 @@ def get_model_response(text, context, custom_prompt, model):

s_model = time.perf_counter()
prompt_msg = make_prompt(text, context, custom_prompt)
logger.info(f"最终拼接后的文本:\n{prompt_msg}\n")
logger.info(f"Final prompt: \n{prompt_msg}\n")

response = model(prompt_msg, history=None, **params_dict)
elapse = time.perf_counter() - s_model

logger.info(f"模型回答: \n{response}\n")
logger.info(f"Reponse of LLM: \n{response}\n")
if not response:
response = "抱歉,未能正确回答该问题"
response = "Sorry, I didn't answer the question correctly"
return response, elapse


@@ -222,7 +213,7 @@ def tips(txt: str, wait_time: int = 2, icon: str = "🎉"):
db_path = config.get("vector_db_path")
db_tools = DBUtils(db_path)

embedding_extract = init_sidebar()
init_sidebar()
init_state()

llm_module = importlib.import_module("knowledge_qa_llm.llm")
@@ -240,22 +231,24 @@ def tips(txt: str, wait_time: int = 2, icon: str = "🎉"):
}
)

PLUGINS_OPTIONS = {
"文档": 0,
}
TOP_OPTIONS = [5, 10, 15]
ENCODER_OPTIONS = config.get("Encoder")

menu_col1, menu_col2, menu_col3 = st.columns([1, 1, 1])
select_model = menu_col1.selectbox("🎨基础模型:", MODEL_OPTIONS.keys())
select_plugin = menu_col2.selectbox("🛠Plugin:", PLUGINS_OPTIONS.keys())
search_top = menu_col3.selectbox("🔍查找Top_K", TOP_OPTIONS)
select_model = menu_col1.selectbox("🎨Base model:", MODEL_OPTIONS.keys())
select_encoder = menu_col2.selectbox(
"🧬Extract Embedding Model:", ENCODER_OPTIONS.keys()
)
search_top = menu_col3.selectbox("🔍Search Top_K", TOP_OPTIONS)

embedding_extract = init_encoder(ENCODER_OPTIONS[select_encoder])

input_prompt_container = st.container()
with input_prompt_container:
with st.expander("💡Prompt", expanded=False):
text_area = st.empty()
input_prompt = text_area.text_area(
label="输入",
label="Input",
max_chars=500,
height=200,
label_visibility="hidden",
@@ -268,15 +261,13 @@ def tips(txt: str, wait_time: int = 2, icon: str = "🎉"):
with st.chat_message("user", avatar="😀"):
st.markdown(input_txt)

plugin_id = PLUGINS_OPTIONS[select_plugin]
llm = MODEL_OPTIONS[select_model]

if not input_prompt:
input_prompt = config.get("DEFAULT_PROMPT")

if plugin_id == 0:
predict(
input_txt,
llm,
input_prompt,
)
predict(
input_txt,
llm,
input_prompt,
)
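
The init_sidebar() hunks above also show the ingestion path: uploaded files are parsed, embedded in batches of `encoder_batch_size`, stacked with `np.vstack`, and inserted into the vector store keyed by file path. A rough sketch of that loop, with the collapsed parts approximated (the example document dict is made up, and `embedding_extract` returning a stackable 2-D array per batch is inferred from the `np.vstack` call above):

```python
# Sketch of the ingestion loop from init_sidebar() above: embed each document's
# sentences in batches and store them in the vector DB keyed by file path.
# Condensed; the real code loads the uploaded files with file_loader() and
# drives a Streamlit progress bar while extracting.
import numpy as np

from knowledge_qa_llm.encoder import EncodeText
from knowledge_qa_llm.utils import read_yaml
from knowledge_qa_llm.vector_utils import DBUtils

config = read_yaml("knowledge_qa_llm/config.yaml")
embedding_extract = EncodeText(config.get("Encoder")["m3e-small"])
db_tools = DBUtils(config.get("vector_db_path"))

batch_size = config.get("encoder_batch_size", 32)
# Example input; webui.py builds this mapping from the uploaded documents instead.
all_doc_contents = {"example.txt": ["sentence one", "sentence two", "sentence three"]}

for file_path, one_doc_contents in all_doc_contents.items():
    all_embeddings = []
    for start in range(0, len(one_doc_contents), batch_size):
        batch = one_doc_contents[start : start + batch_size]
        all_embeddings.append(embedding_extract(batch))  # assumed: one 2-D array per batch
    db_tools.insert(file_path, np.vstack(all_embeddings), one_doc_contents)
```
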
