-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathsummarize.py
77 lines (57 loc) · 2.42 KB
/
summarize.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import re
import streamlit as st
from utils.document_util import extract_from_pdf, extract_from_scanned_pdf, tag_documents
from llm.llm_chain import LLM
from vectorstore.weaviate_store import WeaviateStore
def main():
    """Render the ClauseSense page and summarize user-chosen clauses of a PDF.

    Flow, as driven by the Streamlit widgets:

    1. The sidebar slider picks how many clause inputs (1-5) are shown.
    2. Once a PDF is uploaded, a "Submit" button and one text input per
       clause are rendered, each input in its own column.
    3. On submit, the PDF text is extracted and tagged, indexed into a
       temporary Weaviate class, and each distinct non-blank clause is
       searched (BM25) and summarized by the LLM chain.
    4. Each column then shows the summary for the clause typed in it.

    The temporary Weaviate class is always deleted once it has been created,
    even if indexing, search, or summarization raises.
    """
    st.set_page_config(page_title="ClauseSense", layout="wide", page_icon="♦")
    st.markdown(
        "<h1 style='text-align: center; font-weight:bold; "
        "font-family:comic sans ms; padding-top: 0rem;'> "
        "Legal Document Summarization</h1>",
        unsafe_allow_html=True,
    )

    clause_count = st.sidebar.slider(label="Number of clause", min_value=1, max_value=5)

    # Upload document.
    file = st.file_uploader(label="Upload your document (pdf)", type="pdf")
    if not file:
        return

    button = st.button("Submit")
    llm = LLM()
    llm_chain = llm.get_chain()

    st.subheader("Enter clauses that you want to summarize 👇")
    cols = st.columns(clause_count)
    # Keep one entry per column (positional), rather than keying a dict by the
    # typed text: blank or duplicate inputs would otherwise collapse into a
    # single key and leave fewer results than columns.
    queries = [
        col.text_input(label=f"clause {i+1}", value="", key=i)
        for i, col in enumerate(cols)
    ]

    if not button:
        return

    st.divider()

    # Distinct, non-blank clauses in first-seen order, so each one is searched
    # and summarized only once and blank inputs are never sent as a query.
    unique_queries = list(dict.fromkeys(q for q in queries if q.strip()))
    if not unique_queries:
        st.warning("Please enter at least one clause to summarize.")
        return

    summaries = {}
    with st.spinner("Generating summaries..."):
        data = extract_from_scanned_pdf(file)
        tagged_documents = tag_documents(data)

        weaviatestore = WeaviateStore()
        class_name = weaviatestore.generate_class_name()
        weaviatestore.create_class_obj(class_name=class_name)
        try:
            weaviatestore.add_documents(
                class_name=class_name, tagged_documents=tagged_documents
            )
            for query in unique_queries:
                relevant_documents = weaviatestore.bm25_search_weaviate(
                    query=query, class_name=class_name
                )
                print(relevant_documents)
                summaries[query] = llm.summarize(
                    relevant_documents=relevant_documents,
                    query=query,
                    chain=llm_chain,
                    class_name=class_name,
                )
        finally:
            # Clean up the per-request class even when a step above fails, so
            # failed runs do not leave it behind in the vector store.
            weaviatestore.delete_class(class_name=class_name)
            print(f"{class_name} successfully deleted!")

    # Per-column result view.
    for i, (col, query) in enumerate(zip(cols, queries)):
        if query not in summaries:
            col.info("No clause entered.")
            continue
        col.subheader(query)
        text = "".join(summaries[query])
        # Offset the key so it cannot clash with the text_input keys (0..4).
        col.text_area("Summary", text, key=i + 10, height=400)
# Build the page only when this file is executed as a script (for example
# via `streamlit run`), not when it is imported as a module.
if __name__ == "__main__":
    main()