db.py
"""
This module prepares the database for the chatbot.
It loads the documents from the directory and splits them into chunks
Chunks are then stored in the chroma database for easier vector search functionality
"""
import os
import shutil
import nltk
from dotenv import load_dotenv
from langchain.schema import Document
from langchain_community.document_loaders import DirectoryLoader
from langchain_community.vectorstores.chroma import Chroma
from langchain_google_genai.embeddings import GoogleGenerativeAIEmbeddings
# import spacy
# from langchain_community.embeddings import GooglePalmEmbeddings
# from langchain.text_splitter import RecursiveCharacterTextSplitter

DATA_PATH = "data/pdfs"
CHROMA_DB_PATH = "database/chroma"
MAX_CHUNK_SIZE = 1500

load_dotenv()
api_key = os.getenv("GOOGLE_API_KEY")
# nlp = spacy.load("en_core_web_sm")
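
# nltk.sent_tokenize below needs NLTK's Punkt sentence-tokenizer data. This guard
# fetches it once if it is missing (a convenience sketch: it assumes network access
# on the first run, and newer NLTK releases name the resource "punkt_tab" rather
# than "punkt").
try:
    nltk.data.find("tokenizers/punkt")
except LookupError:
    nltk.download("punkt", quiet=True)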

def save_to_chroma_db(chunks: list[Document]):
    """
    Saves the chunks to the Chroma database.

    Args:
        chunks: A list of Document chunks to embed and persist.
    """
    # Rebuild the database from scratch on every run.
    if os.path.exists(CHROMA_DB_PATH):
        shutil.rmtree(CHROMA_DB_PATH)
    db = Chroma.from_documents(
        chunks,
        GoogleGenerativeAIEmbeddings(google_api_key=api_key, model="models/embedding-001"),
        persist_directory=CHROMA_DB_PATH,
    )
    db.persist()
    print(f"Saved {len(chunks)} chunks to {CHROMA_DB_PATH}.")

def load_documents():
    """
    Loads all documents from the data directory.

    Returns:
        list[Document]: A list of documents.
    """
    loader = DirectoryLoader(DATA_PATH, glob="*.pdf")
    documents = loader.load()
    return documents
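
# Note: DirectoryLoader parses files with UnstructuredFileLoader by default, so
# loading PDFs this way assumes the "unstructured" package (with its PDF extras)
# is installed.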

def split_texts(documents: list[Document]):
    """
    Splits the documents into chunks based on sentences, considering complexity.

    Args:
        documents (list[Document]): A list of documents.

    Returns:
        chunks: A list of Document chunks.
    """
    chunks = []
    for document in documents:
        text = document.page_content
        sentences = nltk.sent_tokenize(text)
        current_chunk = []
        chunk_size = 1  # Default to 1 sentence for short factual content
        for sentence in sentences:
            # Analyze sentence complexity
            is_complex = is_sentence_complex(sentence)
            # Update the target number of sentences per chunk based on complexity
            if is_complex:
                chunk_size = min(chunk_size + 1, 5)
            else:
                chunk_size = min(chunk_size + 1, 3)
            # Keep adding sentences while the chunk is below its sentence budget
            # and adding this sentence stays within MAX_CHUNK_SIZE characters.
            if len(current_chunk) == 0 or (
                len(current_chunk) < chunk_size
                and sum(len(s) for s in current_chunk) + len(sentence) <= MAX_CHUNK_SIZE
            ):
                current_chunk.append(sentence)
            else:
                # Create a new Document object for the finished chunk
                chunk_document = Document(
                    page_content=" ".join(current_chunk), metadata=document.metadata
                )
                chunks.append(chunk_document)
                current_chunk = [sentence]
        # Add the last chunk if any sentences remain
        if current_chunk:
            chunk_document = Document(
                page_content=" ".join(current_chunk), metadata=document.metadata
            )
            chunks.append(chunk_document)
    print(f"Split {len(documents)} documents into {len(chunks)} chunks.")
    # print(chunks[15].page_content)
    # print(chunks[15].metadata)
    return chunks
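
# Quick illustration (hypothetical in-memory input, not part of the pipeline):
# chunks are kept near or under MAX_CHUNK_SIZE characters (a single very long
# sentence can still exceed it) and each chunk carries the metadata of the
# document it came from.
#   docs = [Document(page_content="First fact. Second fact.", metadata={"source": "demo"})]
#   for chunk in split_texts(docs):
#       print(chunk.metadata, len(chunk.page_content))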

def is_sentence_complex(sentence):
    """
    Function to check sentence complexity.

    Args:
        sentence: A sentence to check

    Returns:
        bool: True if the sentence is complex, False otherwise
    """
    # TODO: implement richer complexity checks here based on factors like:
    # - Word count
    # - Presence of subordinate clauses
    # - Lexical density (ratio of unique words to total words)
    # doc = nlp(sentence)
    # word_count = len(doc)
    # has_subordinate_clause = any(
    #     token.dep_ == "mark" for token in doc
    # )  # Check for subordinating conjunctions
    # Placeholder heuristic: treat sentences longer than 20 words as complex.
    return len(sentence.split()) > 20
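
# One possible direction for the TODO above (a sketch, not the project's chosen
# implementation): combine word count with lexical density.
# def is_sentence_complex(sentence):
#     words = [w.lower() for w in sentence.split()]
#     if not words:
#         return False
#     lexical_density = len(set(words)) / len(words)
#     # Long sentences, or mid-length ones with very varied vocabulary, count as complex.
#     return len(words) > 20 or (len(words) > 10 and lexical_density > 0.8)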

def main():
    """
    Main function to prepare the database.
    """
    documents = load_documents()
    chunks = split_texts(documents)
    save_to_chroma_db(chunks)


if __name__ == "__main__":
    main()
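
# Typical usage (assuming PDFs are placed under data/pdfs and GOOGLE_API_KEY is
# set in a .env file):
#   python db.py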