Commit 814c56e

feat: Finally got a satisfactory result

j4ndrw committed Jan 16, 2024
1 parent 65c7d42 commit 814c56e
Showing 9 changed files with 275 additions and 454 deletions.
8 changes: 1 addition & 7 deletions Pipfile
@@ -7,16 +7,10 @@ name = "pypi"
langchain = "*"
chromadb = "*"
langchain-community = "*"
tiktoken = "*"
esprima = "*"
tqdm = "*"
gitpython = "*"
psutil = "*"
pip = "*"
install = "*"
gpt4all = "*"
fastembed = "*"
sentence-transformers = "*"
langchainhub = "*"

[dev-packages]

216 changes: 21 additions & 195 deletions Pipfile.lock

Large diffs are not rendered by default.

26 changes: 26 additions & 0 deletions __main__.py
@@ -0,0 +1,26 @@
import os

from codebase_indexer.argparser import parse_args
from codebase_indexer.constants import DEFAULT_VECTOR_DB_DIR
from codebase_indexer.rag import init_llm, init_vector_store


def cli():
    args = parse_args()
    if not args.repo_path:
        raise Exception("A repository must be specified")

    repo_path = os.path.abspath(os.path.join(os.path.curdir, args.repo_path))
    vector_db_dir = args.vector_db_dir or os.path.join(repo_path, DEFAULT_VECTOR_DB_DIR)

    retriever = init_vector_store(repo_path, vector_db_dir)
    _, qa = init_llm(retriever, args.ollama_inference_model)

    # Simple REPL: each prompt is answered by the conversational retrieval chain.
    while True:
        query = input(">>> ")
        qa(query)
        print("\n")


if __name__ == "__main__":
    cli()
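
For reference, a minimal programmatic sketch of the same flow (hypothetical paths; assumes a local Ollama server, since constants.py reads OLLAMA_HOST at import time):

import os

# Assumption: a local Ollama server; constants.py requires OLLAMA_HOST to be
# set before any codebase_indexer module is imported.
os.environ.setdefault("OLLAMA_HOST", "localhost:11434")

from codebase_indexer.rag import init_llm, init_vector_store

retriever = init_vector_store("/path/to/repo", "/path/to/repo/.llm-index/vectorstores/db")
_, qa = init_llm(retriever)  # None falls back to DEFAULT_OLLAMA_INFERENCE_MODEL
qa("Give me an overview of this codebase.")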
Empty file added codebase_indexer/__init__.py
56 changes: 56 additions & 0 deletions codebase_indexer/argparser.py
@@ -0,0 +1,56 @@
from argparse import ArgumentParser
from dataclasses import dataclass

from codebase_indexer.constants import (
    DEFAULT_OLLAMA_EMBEDDINGS_MODEL,
    DEFAULT_OLLAMA_INFERENCE_MODEL,
)


def create_argparser() -> ArgumentParser:
    parser = ArgumentParser(
        description="Indexes your codebase and, given a prompt, provides information about it via an LLM",
    )
    parser.add_argument("repo_path", type=str, help="The path to your repo", default="")
    parser.add_argument(
        "--ollama_inference_model",
        type=str,
        help="The LLM you want to use. Defaults to `mistral-openorca`.",
        default=DEFAULT_OLLAMA_INFERENCE_MODEL,
    )
    parser.add_argument(
        "--ollama_embeddings_model",
        type=str,
        help="The embeddings model you wish to use when creating the index in the vector database. Defaults to `tinyllama`.",
        default=DEFAULT_OLLAMA_EMBEDDINGS_MODEL,
    )
    parser.add_argument(
        "--vector_db_dir",
        type=str,
        help="The path to the vector database. Defaults to `{repo_path}/.llm-index/vectorstores/db`.",
        default=None,
    )
    return parser


@dataclass
class Args:
    repo_path: str
    ollama_inference_model: str
    ollama_embeddings_model: str
    vector_db_dir: str | None


def parse_args() -> Args:
    args = create_argparser().parse_args()
    return Args(
        repo_path=args.repo_path,
        ollama_inference_model=args.ollama_inference_model,
        ollama_embeddings_model=args.ollama_embeddings_model,
        vector_db_dir=args.vector_db_dir,
    )
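
For illustration, a hypothetical command line and the Args it parses into:

# $ python . ../my-repo --ollama_inference_model llama2
# parse_args() then returns:
#   Args(
#       repo_path="../my-repo",
#       ollama_inference_model="llama2",
#       ollama_embeddings_model="tinyllama",  # default
#       vector_db_dir=None,                   # resolved in cli() to {repo_path}/.llm-index/vectorstores/db
#   )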
33 changes: 33 additions & 0 deletions codebase_indexer/constants.py
@@ -0,0 +1,33 @@
import os

from langchain.text_splitter import Language

OLLAMA_BASE_URL = "http://" + os.environ["OLLAMA_HOST"]

DEFAULT_VECTOR_DB_DIR = ".llm-index/vectorstores/db"

DEFAULT_OLLAMA_INFERENCE_MODEL = "mistral-openorca"
DEFAULT_OLLAMA_EMBEDDINGS_MODEL = "tinyllama"

LANGUAGES = [
    Language.CPP,
    Language.GO,
    Language.JAVA,
    Language.KOTLIN,
    Language.JS,
    Language.TS,
    Language.PHP,
    Language.PROTO,
    Language.PYTHON,
    Language.RST,
    Language.RUBY,
    Language.RUST,
    Language.SCALA,
    Language.SWIFT,
    Language.MARKDOWN,
    Language.LATEX,
    Language.HTML,
    Language.SOL,
    Language.CSHARP,
    Language.COBOL,
]
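
Each entry maps to a language-aware splitter in vector_store.py; a small sketch of what one of them does (illustrative, for Python):

from langchain.text_splitter import Language, RecursiveCharacterTextSplitter

splitter = RecursiveCharacterTextSplitter.from_language(
    Language.PYTHON, chunk_size=2000, chunk_overlap=200
)
# Prefers splitting on class/def boundaries before falling back to blank
# lines, so chunks tend to align with whole definitions.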
49 changes: 49 additions & 0 deletions codebase_indexer/rag.py
@@ -0,0 +1,49 @@
from gpt4all import Embed4All
from langchain.callbacks.manager import CallbackManager
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.chains import ConversationalRetrievalChain
from langchain.memory import ConversationSummaryMemory
from langchain_community.chat_models import ChatOllama
from langchain_community.embeddings import GPT4AllEmbeddings
from langchain_core.vectorstores import VectorStoreRetriever

from codebase_indexer.constants import DEFAULT_OLLAMA_INFERENCE_MODEL, OLLAMA_BASE_URL
from codebase_indexer.vector_store import create_index, load_index


def init_vector_store(
    repo_path: str,
    vector_db_dir: str,
) -> VectorStoreRetriever:
    embeddings_factory = lambda: GPT4AllEmbeddings(client=Embed4All)

    indexing_args = (repo_path, vector_db_dir, embeddings_factory)
    # Reuse a persisted index when one exists; otherwise build it from scratch.
    repo_path, db = load_index(*indexing_args) or create_index(*indexing_args)

    retriever = db.as_retriever(
        search_type="mmr", search_kwargs={"k": 6, "lambda_mult": 0.25}
    )

    return retriever


def init_llm(
    retriever: VectorStoreRetriever, ollama_inference_model: str | None = None
):
    # Stream tokens to stdout as they are generated.
    callback_manager = CallbackManager([StreamingStdOutCallbackHandler()])
    llm = ChatOllama(
        base_url=OLLAMA_BASE_URL,
        model=ollama_inference_model or DEFAULT_OLLAMA_INFERENCE_MODEL,
        callback_manager=callback_manager,
        num_ctx=8192,
        num_gpu=1,
    )
    # Summarise the running conversation with the same model.
    memory = ConversationSummaryMemory(
        llm=llm, return_messages=True, memory_key="chat_history"
    )
    qa = ConversationalRetrievalChain.from_llm(
        llm=llm, retriever=retriever, memory=memory
    )
    return llm, qa
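
A hedged usage sketch of the returned chain: it is callable with a single question, and ConversationSummaryMemory folds each exchange into a running summary produced by the same llm:

_, qa = init_llm(retriever)  # retriever from init_vector_store
result = qa("What does the indexer do?")
# result["answer"] holds the response; the exchange is summarised into
# memory and supplied as chat_history on the next call.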
89 changes: 89 additions & 0 deletions codebase_indexer/vector_store.py
@@ -0,0 +1,89 @@
import os
from typing import Callable

from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders.git import GitLoader
from langchain_community.embeddings import GPT4AllEmbeddings
from langchain_community.vectorstores.chroma import Chroma
from langchain_core.documents import Document

from codebase_indexer.constants import LANGUAGES


def collection_name(repo_path: str) -> str:
    # Derive a Chroma collection name from the repo's parent directories,
    # e.g. "/home/user/projects/my-repo" -> "home-user-projects".
    return repo_path.rsplit("/", 1)[0].split("/", 1)[1].replace("/", "-")


def load_docs(repo_path: str) -> list[Document]:
    def file_filter(file_path: str) -> bool:
        # Skip lockfiles, VCS metadata and other files that add noise to the index.
        is_meta = file_path.lower().endswith(
            (
                ".lock",
                ".md",
                "lock.yaml",
                "lock.yml",
                "lock.json",
                ".gitkeep",
                ".gitattributes",
                ".gitignore",
                ".git",
                ".xml",
                ".log",
                ".csv",
                ".jmx",
            )
        )
        return not is_meta

    doc_pool: list[Document] = []
    # One language-aware splitter per supported language.
    splitters = {
        language.value: RecursiveCharacterTextSplitter.from_language(
            language, chunk_size=2000, chunk_overlap=200
        )
        for language in LANGUAGES
    }

    loader = GitLoader(repo_path=repo_path, branch="master", file_filter=file_filter)
    docs = loader.load()

    for language in LANGUAGES:
        docs_per_lang: list[Document] = []
        for doc in docs:
            # GitLoader stores the file extension (e.g. ".py") in metadata["file_type"].
            if doc.metadata["file_type"] == f".{language.value}":
                docs_per_lang.append(doc)
        doc_pool.extend(splitters[language.value].split_documents(docs_per_lang))

    print("Loaded documents")
    return doc_pool


def create_index(
    repo_path: str,
    vector_db_dir: str,
    embeddings_factory: Callable[[], GPT4AllEmbeddings],
) -> tuple[str, Chroma]:
    docs = load_docs(repo_path)
    db = Chroma.from_documents(
        docs,
        embeddings_factory(),
        persist_directory=vector_db_dir,
        collection_name=collection_name(repo_path),
    )
    print("Indexed documents")
    return repo_path, db


def load_index(
    repo_path: str,
    vector_db_dir: str,
    embeddings_factory: Callable[[], GPT4AllEmbeddings],
) -> tuple[str, Chroma] | None:
    store = Chroma(
        collection_name(repo_path),
        embeddings_factory(),
        persist_directory=vector_db_dir,
    )
    result = store.get()
    # An empty collection means there is nothing to reuse.
    if len(result["documents"]) == 0:
        return None
    return repo_path, store
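
Since the slicing in collection_name is easy to misread, a worked example on a hypothetical path:

# repo_path = "/home/user/projects/my-repo"
# rsplit("/", 1)[0]  -> "/home/user/projects"  (drop the repo directory itself)
# split("/", 1)[1]   -> "home/user/projects"   (drop the empty leading segment)
# replace("/", "-")  -> "home-user-projects"   (dash-separated collection name)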