feat: Finally got a satisfactory result
Showing 9 changed files with 275 additions and 454 deletions.
Large diffs are not rendered by default.
@@ -0,0 +1,26 @@
```python
import os

from codebase_indexer.argparser import parse_args
from codebase_indexer.constants import DEFAULT_VECTOR_DB_DIR
from codebase_indexer.rag import init_llm, init_vector_store


def cli():
    args = parse_args()

    if not args.repo_path:
        raise Exception("A repository must be specified")
    repo_path = os.path.abspath(os.path.join(os.path.curdir, args.repo_path))
    vector_db_dir = args.vector_db_dir or os.path.join(repo_path, DEFAULT_VECTOR_DB_DIR)

    retriever = init_vector_store(repo_path, vector_db_dir)
    _, qa = init_llm(retriever, args.ollama_inference_model)

    while True:
        query = input(">>> ")
        qa(query)
        print("\n")


if __name__ == "__main__":
    cli()
```
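For reference, the same entry points can also be driven non-interactively. A minimal sketch, assuming the package is importable and an Ollama server is reachable via `OLLAMA_HOST`; the paths below are illustrative, not part of the commit:

```python
from codebase_indexer.rag import init_llm, init_vector_store

# Hypothetical locations; substitute your own repository and index paths.
retriever = init_vector_store("/path/to/repo", "/path/to/repo/.llm-index/vectorstores/db")
_, qa = init_llm(retriever)

# With memory attached, the chain accepts a bare question string,
# exactly as the REPL loop above does.
qa("Where is the vector store initialised?")
```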
Empty file.
@@ -0,0 +1,56 @@
```python
from argparse import ArgumentParser
from dataclasses import dataclass

from codebase_indexer.constants import (
    DEFAULT_OLLAMA_EMBEDDINGS_MODEL,
    DEFAULT_OLLAMA_INFERENCE_MODEL,
)


def create_argparser() -> ArgumentParser:
    parser = ArgumentParser(
        description="Indexes your codebase and, given a prompt, provides information about your codebase via an LLM",
    )
    parser.add_argument("repo_path", type=str, help="The path to your repo", default="")
    parser.add_argument(
        "--ollama_inference_model",
        type=str,
        help="The LLM you want to use. Defaults to `mistral-openorca`.",
        default=DEFAULT_OLLAMA_INFERENCE_MODEL,
    )
    parser.add_argument(
        "--ollama_embeddings_model",
        type=str,
        help="The embeddings model you wish to use when creating the index in the vector database. Defaults to `tinyllama`.",
        default=DEFAULT_OLLAMA_EMBEDDINGS_MODEL,
    )
    parser.add_argument(
        "--vector_db_dir",
        type=str,
        help="The path to the vector database. Defaults to `{repo_path}/.llm-index/vectorstores/db`",
        default=None,
    )
    return parser


@dataclass
class Args:
    repo_path: str
    ollama_inference_model: str
    ollama_embeddings_model: str
    vector_db_dir: str | None


def parse_args() -> Args:
    args = create_argparser().parse_args()
    repo_path = args.repo_path
    ollama_inference_model = args.ollama_inference_model
    ollama_embeddings_model = args.ollama_embeddings_model
    vector_db_dir = args.vector_db_dir

    return Args(
        repo_path=repo_path,
        ollama_inference_model=ollama_inference_model,
        ollama_embeddings_model=ollama_embeddings_model,
        vector_db_dir=vector_db_dir,
    )
```
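To see what the parser produces, you can feed it an argv list directly; the repository path below is an example only:

```python
from codebase_indexer.argparser import create_argparser

# Example argv: only the positional repo_path is required.
args = create_argparser().parse_args(["./my-repo"])
assert args.ollama_inference_model == "mistral-openorca"  # the default
assert args.vector_db_dir is None  # cli() then falls back to {repo_path}/.llm-index/vectorstores/db
```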
@@ -0,0 +1,33 @@
```python
import os

from langchain.text_splitter import Language

OLLAMA_BASE_URL = "http://" + os.environ["OLLAMA_HOST"]

DEFAULT_VECTOR_DB_DIR = ".llm-index/vectorstores/db"

DEFAULT_OLLAMA_INFERENCE_MODEL = "mistral-openorca"
DEFAULT_OLLAMA_EMBEDDINGS_MODEL = "tinyllama"

LANGUAGES = [
    Language.CPP,
    Language.GO,
    Language.JAVA,
    Language.KOTLIN,
    Language.JS,
    Language.TS,
    Language.PHP,
    Language.PROTO,
    Language.PYTHON,
    Language.RST,
    Language.RUBY,
    Language.RUST,
    Language.SCALA,
    Language.SWIFT,
    Language.MARKDOWN,
    Language.LATEX,
    Language.HTML,
    Language.SOL,
    Language.CSHARP,
    Language.COBOL,
]
```
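Because `OLLAMA_BASE_URL` indexes `os.environ` directly, importing this module raises `KeyError` when `OLLAMA_HOST` is unset, so the variable must be exported first. A minimal sketch; the address is an assumption (Ollama's usual default port):

```python
import os

# Must be set before anything imports codebase_indexer.constants.
os.environ.setdefault("OLLAMA_HOST", "localhost:11434")  # example address

from codebase_indexer.constants import OLLAMA_BASE_URL

print(OLLAMA_BASE_URL)  # http://localhost:11434
```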
@@ -0,0 +1,49 @@
```python
from gpt4all import Embed4All
from langchain import hub
from langchain.callbacks.manager import CallbackManager
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.chains import ConversationalRetrievalChain, RetrievalQA
from langchain.chains.question_answering import load_qa_chain
from langchain.memory import ConversationSummaryMemory
from langchain_community.chat_models import ChatOllama
from langchain_community.embeddings import GPT4AllEmbeddings
from langchain_core.vectorstores import VectorStoreRetriever

from codebase_indexer.constants import DEFAULT_OLLAMA_INFERENCE_MODEL, OLLAMA_BASE_URL
from codebase_indexer.vector_store import create_index, load_index


def init_vector_store(
    repo_path: str,
    vector_db_dir: str,
) -> VectorStoreRetriever:
    embeddings_factory = lambda: GPT4AllEmbeddings(client=Embed4All)

    indexing_args = (repo_path, vector_db_dir, embeddings_factory)
    repo_path, db = load_index(*indexing_args) or create_index(*indexing_args)

    retriever = db.as_retriever(
        search_type="mmr", search_kwargs={"k": 6, "lambda_mult": 0.25}
    )

    return retriever


def init_llm(
    retriever: VectorStoreRetriever, ollama_inference_model: str | None = None
):
    callback_manager = CallbackManager([StreamingStdOutCallbackHandler()])
    llm = ChatOllama(
        base_url=OLLAMA_BASE_URL,
        model=ollama_inference_model or DEFAULT_OLLAMA_INFERENCE_MODEL,
        callback_manager=callback_manager,
        num_ctx=8192,
        num_gpu=1,
    )
    memory = ConversationSummaryMemory(
        llm=llm, return_messages=True, memory_key="chat_history"
    )
    qa = ConversationalRetrievalChain.from_llm(
        llm=llm, retriever=retriever, memory=memory
    )
    return llm, qa
```
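The two initialisers compose as sketched below; the paths and question are illustrative. Since `ConversationSummaryMemory` supplies `chat_history`, the chain's only remaining input key is `question`, and the streaming callback prints tokens to stdout as they are generated:

```python
retriever = init_vector_store("/path/to/repo", "/path/to/db")  # example paths
llm, qa = init_llm(retriever, "mistral-openorca")

result = qa({"question": "How are documents split before indexing?"})
print(result["answer"])
```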
@@ -0,0 +1,89 @@
```python
import os
from typing import Callable

from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders.git import GitLoader
from langchain_community.embeddings import GPT4AllEmbeddings
from langchain_community.vectorstores.chroma import Chroma
from langchain_core.documents import Document

from codebase_indexer.constants import LANGUAGES


def collection_name(repo_path: str) -> str:
    return repo_path.rsplit("/", 1)[0].split("/", 1)[1].replace("/", "-")


def load_docs(repo_path: str) -> list[Document]:
    def file_filter(file_path: str) -> bool:
        is_meta = file_path.lower().endswith(
            (
                ".lock",
                ".md",
                "lock.yaml",
                "lock.yml",
                "lock.json",
                ".gitkeep",
                ".gitattributes",
                ".gitignore",
                ".git",
                ".xml",
                ".log",
                ".csv",
                ".jmx",
            )
        )
        return not is_meta

    doc_pool: list[Document] = []
    splitters = {
        language.value: RecursiveCharacterTextSplitter.from_language(
            language, chunk_size=2000, chunk_overlap=200
        )
        for language in LANGUAGES
    }

    loader = GitLoader(repo_path=repo_path, branch="master", file_filter=file_filter)
    docs = loader.load()

    for language in LANGUAGES:
        docs_per_lang: list[Document] = []
        for doc in docs:
            if doc.metadata["file_type"] == f".{language.value}":
                docs_per_lang.append(doc)
        doc_pool.extend(splitters[language.value].split_documents(docs_per_lang))

    print("Loaded documents")
    return doc_pool


def create_index(
    repo_path: str,
    vector_db_dir: str,
    embeddings_factory: Callable[[], GPT4AllEmbeddings],
) -> tuple[str, Chroma]:
    docs = load_docs(repo_path)
    db = Chroma.from_documents(
        docs,
        embeddings_factory(),
        persist_directory=vector_db_dir,
        collection_name=collection_name(repo_path),
    )
    print("Indexed documents")
    return repo_path, db


def load_index(
    repo_path: str,
    vector_db_dir: str,
    embeddings_factory: Callable[[], GPT4AllEmbeddings],
) -> tuple[str, Chroma] | None:
    store = Chroma(
        collection_name(repo_path),
        embeddings_factory(),
        persist_directory=vector_db_dir,
    )
    result = store.get()
    if len(result["documents"]) == 0:
        return None
    return repo_path, store
```
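One subtlety: `collection_name` derives the Chroma collection name from the repository's parent directories, dropping both the leading root component and the repo directory itself. Traced on a hypothetical path:

```python
# "/home/user/projects/my-repo"
#   .rsplit("/", 1)[0]  -> "/home/user/projects"  (drops the repo directory)
#   .split("/", 1)[1]   -> "home/user/projects"   (drops the leading root component)
#   .replace("/", "-")  -> "home-user-projects"
print(collection_name("/home/user/projects/my-repo"))  # home-user-projects
```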