feat: add tools for extracting commands and file paths from questions
j4ndrw committed Jan 28, 2024
1 parent fef78a9 commit 4baddff
Showing 11 changed files with 244 additions and 83 deletions.
3 changes: 1 addition & 2 deletions codebase_indexer/api/models.py
@@ -10,5 +10,4 @@ class Meta(BaseModel):
ollama_inference_model: str | None = None


Command = Literal["test", "search", "review"]
Context = Literal["from", "infer"]
Command = Literal["test", "search", "review", "new_conversation"]
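
For context, a minimal sketch of how the widened Command literal is consumed; the describe function is illustrative, not part of this commit:

from codebase_indexer.api.models import Command

def describe(command: Command) -> str:
    # A type checker rejects any string outside the Literal set,
    # so adding "new_conversation" here is what makes it dispatchable.
    if command == "new_conversation":
        return "reset the conversation and source buffer"
    return f"run the '{command}' workflow"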
115 changes: 47 additions & 68 deletions codebase_indexer/cli/__init__.py
@@ -4,21 +4,14 @@
from git import Repo
from langchain_core.callbacks import StreamingStdOutCallbackHandler

from codebase_indexer.api.models import Command, Context
from codebase_indexer.cli.argparser import Args
from codebase_indexer.constants import (
COMMANDS,
CONTEXTS,
DEFAULT_OLLAMA_INFERENCE_MODEL,
DEFAULT_VECTOR_DB_DIR,
)
from codebase_indexer.rag import (
RAG,
create_llm,
create_memory,
create_retriever,
init_vector_store,
)
from codebase_indexer.constants import (DEFAULT_OLLAMA_INFERENCE_MODEL,
DEFAULT_VECTOR_DB_DIR)
from codebase_indexer.rag import (RAG, create_llm, create_memory,
create_retriever, init_vector_store)
from codebase_indexer.rag.agents import create_tools
from codebase_indexer.rag.chains import create_search_request_removal_chain
from codebase_indexer.utils import strip_generated_text


def cli(args: Args):
@@ -50,65 +43,51 @@ def cli(args: Args):
}
)
memory = create_memory(llm)
prev_search: list[str] = []
sources: list[str] = []
while True:
question = input(">>> ")
command: Command | None = None
contexts: dict[Context, list[str]] = {}
for command_iter in COMMANDS:
if question.startswith(f"/{command_iter}"):
command = command_iter
question = question.replace(f"/{command}", "").strip()
break

        # Very ugly, but I can't be bothered to implement a proper DSL...
for context in CONTEXTS:
if question.startswith(f"@{context}"):
question = question.replace(f"@{context}", "")
argument = None
if context == "from":
argument, question = question.split(" ", 1)
argument = [*contexts.get(context, []), argument]
elif context == "infer":
argument = prev_search
commands = RAG.extract_commands(llm, question)
sources = RAG.extract_sources_to_search_in(llm, question, sources)

if argument is not None:
contexts.update({context: argument})
question = question.strip()
for command in commands:
if command == "new_conversation":
sources = []
continue

system_prompt, qa_chain_factory, custom_retriever_factory = RAG.from_command(
command
)
system_prompt, qa_chain_factory, custom_retriever_factory = (
RAG.from_command(command)
)

try:
retriever = create_retriever(db)
sources = contexts.get("from", []) + contexts.get("infer", [])
if len(sources) > 0:
retriever.search_kwargs.update(
dict(filter={"source": {"$in": sources}})
)
prev_search = []
if command == "search":
if custom_retriever_factory is not None and db.embeddings is not None:
retriever.search_kwargs.update(dict(k=20))
retriever = custom_retriever_factory(db.embeddings, retriever)
docs = retriever.get_relevant_documents(
question, callbacks=llm.callbacks
try:
retriever = create_retriever(db)
if command == "search":
if (
custom_retriever_factory is not None
and db.embeddings is not None
):
embeddings_retriever = custom_retriever_factory(
db.embeddings, retriever
)
                            found_sources = RAG.search(embeddings_retriever, question)
                            for source in found_sources:
                                print(f"\t- {source}")

                            sources.extend(found_sources)
question = RAG.remove_search_request(llm, question)

else:
retriever = RAG.filter_on_sources(retriever, sources)
qa = qa_chain_factory(llm, retriever, memory, system_prompt)
qa.invoke(
{
"question": question,
"chat_history": memory.chat_memory,
},
)
except KeyboardInterrupt:
continue
finally:
print("\n")

sources = [doc.metadata["source"] for doc in docs]
for source in sources:
print(f"\t- {source}")
prev_search.extend(sources)
else:
qa = qa_chain_factory(llm, retriever, memory, system_prompt)
qa.invoke(
{
"question": question,
"chat_history": memory.chat_memory,
},
)
except KeyboardInterrupt:
continue
finally:
print("\n")
        sources = RAG.cycle_sources_buffer(sources)
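
In effect, the old /command and @context DSL is replaced by LLM-driven extraction; an illustrative before/after of the same request (the path is hypothetical):

# Before this commit (explicit markers parsed by the CLI):
#   >>> /search @from path/to/module.py where is the retry logic?
# After this commit (plain language; commands and paths are LLM-extracted):
#   >>> Search for the retry logic in path/to/module.py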
9 changes: 5 additions & 4 deletions codebase_indexer/constants.py
@@ -2,17 +2,16 @@

from langchain.text_splitter import Language

from codebase_indexer.api.models import Command, Context
from codebase_indexer.api.models import Command

OLLAMA_BASE_URL = "http://" + os.environ["OLLAMA_HOST"]
DEFAULT_VECTOR_DB_DIR = os.path.join(
os.path.expanduser("~"), ".codebase-indexer/vectorstores/db"
)

DEFAULT_OLLAMA_INFERENCE_MODEL = "mistral-openorca"
DEFAULT_OLLAMA_INFERENCE_MODEL = "mistral-openorca:7b"

COMMANDS: list[Command] = ["test", "search", "review"]
CONTEXTS: list[Context] = ["from", "infer"]
COMMANDS: list[Command] = ["test", "search", "review", "new_conversation"]

LANGUAGES = [
Language.CPP,
@@ -59,3 +58,5 @@
Language.CSHARP: ["cs", "cshtml"],
Language.COBOL: ["cbl"],
}

MAX_SOURCES_WINDOW = 10
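
MAX_SOURCES_WINDOW caps the rolling buffer of sources carried across questions; a minimal sketch of the intended windowing (keep_recent is a hypothetical name mirroring RAG.cycle_sources_buffer below):

MAX_SOURCES_WINDOW = 10

def keep_recent(sources: list[str]) -> list[str]:
    # Keep only the last MAX_SOURCES_WINDOW entries once the buffer overflows.
    if len(sources) <= MAX_SOURCES_WINDOW:
        return sources
    return sources[len(sources) - MAX_SOURCES_WINDOW :]

assert keep_recent([f"src/f{i}.py" for i in range(25)]) == [f"src/f{i}.py" for i in range(15, 25)]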
93 changes: 85 additions & 8 deletions codebase_indexer/rag/__init__.py
@@ -1,4 +1,4 @@
from dataclasses import astuple, dataclass, field
from dataclasses import dataclass
from os import PathLike
from typing import Any, TypedDict

@@ -10,9 +10,7 @@
from langchain.prompts import PromptTemplate
from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import (
DocumentCompressorPipeline,
EmbeddingsFilter,
)
DocumentCompressorPipeline, EmbeddingsFilter)
from langchain.schema import BaseRetriever
from langchain_community.chat_models import ChatOllama
from langchain_community.document_transformers import EmbeddingsRedundantFilter
@@ -21,13 +19,16 @@
from langchain_core.vectorstores import VectorStore, VectorStoreRetriever

from codebase_indexer.api.models import Command
from codebase_indexer.constants import DEFAULT_OLLAMA_INFERENCE_MODEL, OLLAMA_BASE_URL
from codebase_indexer.constants import (DEFAULT_OLLAMA_INFERENCE_MODEL,
MAX_SOURCES_WINDOW, OLLAMA_BASE_URL)
from codebase_indexer.rag.agents import create_tools
from codebase_indexer.rag.chains import create_search_request_removal_chain
from codebase_indexer.rag.prompts import (
DEFAULT_CONVERSATIONAL_RETRIEVAL_CHAIN_PROMPT,
REVIEW_CONVERSATIONAL_RETRIEVAL_CHAIN_PROMPT,
TEST_CONVERSATIONAL_RETRIEVAL_CHAIN_PROMPT,
)
TEST_CONVERSATIONAL_RETRIEVAL_CHAIN_PROMPT)
from codebase_indexer.rag.vector_store import create_index, load_index
from codebase_indexer.utils import strip_generated_text


def init_vector_store(
@@ -73,7 +74,7 @@ def create_llm(*, params: OllamaLLMParams) -> ChatOllama:

def create_retriever(db: VectorStore):
retriever_kwargs: dict[str, str | dict[str, Any]] = dict(
search_type="mmr", search_kwargs=dict(k=5, fetch_k=1000)
search_type="mmr", search_kwargs=dict(k=8, fetch_k=1000)
)
return db.as_retriever(**retriever_kwargs)

@@ -109,6 +110,8 @@ def create_conversational_retrieval_chain(
def create_contextual_retriever(
embeddings: Embeddings, retriever: VectorStoreRetriever
):
initial_args = dict(**retriever.search_kwargs)
retriever.search_kwargs.update(dict(k=20))
redundant_filter = EmbeddingsRedundantFilter(embeddings=embeddings)
relevant_filter = EmbeddingsFilter(embeddings=embeddings, similarity_threshold=0.8)
pipeline_compressor = DocumentCompressorPipeline(
@@ -117,6 +120,7 @@ def create_contextual_retriever(
compression_retriever = ContextualCompressionRetriever(
base_compressor=pipeline_compressor, base_retriever=retriever
)
retriever.search_kwargs = initial_args
return compression_retriever


@@ -125,6 +129,79 @@ class RAG:
def create():
return None, create_conversational_retrieval_chain, None

@staticmethod
def extract_commands(llm: ChatOllama, question: str) -> list[Command]:
command_extraction_tool, _ = create_tools(llm)

        raw_commands = (
            strip_generated_text(
                command_extraction_tool.invoke({"question": question}).get("text", "")
            )
            or ""
        )
        extracted_commands = [
            command
            for command in (part.strip() for part in raw_commands.split(", "))
            if command.lower() != "n/a"
        ]
if "search" in extracted_commands:
extracted_commands.remove("search")
return ["search", *extracted_commands] # type: ignore
return extracted_commands # type: ignore

@staticmethod
def extract_sources_to_search_in(
llm: ChatOllama, question: str, sources: list[str] | None = None
) -> list[str]:
_, file_path_extraction_tool = create_tools(llm)
sources = sources or []

file_paths = strip_generated_text(
file_path_extraction_tool.invoke({"question": question}).get("text", "")
)
if file_paths.lower() != "n/a":
sources.extend([*map(lambda x: x.strip(), file_paths.split(","))])

return sources

@staticmethod
def filter_on_sources(
retriever: VectorStoreRetriever, sources: list[str]
) -> VectorStoreRetriever:
if len(sources) > 0:
retriever.search_kwargs.update(dict(filter={"source": {"$in": sources}}))
return retriever

@staticmethod
def cycle_sources_buffer(sources: list[str]) -> list[str]:
if len(sources) <= MAX_SOURCES_WINDOW:
return sources
        return sources[len(sources) - MAX_SOURCES_WINDOW :]

@staticmethod
def search(retriever: ContextualCompressionRetriever, question: str) -> list[str]:
docs = retriever.get_relevant_documents(question)

sources = [doc.metadata["source"] for doc in docs]
return sources

@staticmethod
def remove_search_request(llm: ChatOllama, question: str) -> str:
question = (
strip_generated_text(
create_search_request_removal_chain(llm)
.invoke({"question": question})
.get("text", question)
)
or question
)
return question

@classmethod
def from_command(cls, command: Command | None):
if command == "test":
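
Taken together, these statics form the extraction pipeline the CLI drives; a minimal usage sketch, assuming llm (ChatOllama) and db (VectorStore) are initialized as in cli/__init__.py, with example outputs that will vary by model:

# A minimal sketch; the inline "e.g." values are illustrative, not guaranteed.
question = "Search for the retry logic in path/to/client.py"

commands = RAG.extract_commands(llm, question)             # e.g. ["search"]
sources = RAG.extract_sources_to_search_in(llm, question)  # e.g. ["path/to/client.py"]

retriever = RAG.filter_on_sources(create_retriever(db), sources)
question = RAG.remove_search_request(llm, question)        # e.g. "Explain the retry logic"
sources = RAG.cycle_sources_buffer(sources)                # cap at MAX_SOURCES_WINDOW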
73 changes: 73 additions & 0 deletions codebase_indexer/rag/agents.py
@@ -0,0 +1,73 @@
from langchain import hub
from langchain.agents import AgentExecutor, Tool, create_structured_chat_agent
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
from langchain_community.chat_models import ChatOllama

from codebase_indexer.constants import COMMANDS


def create_tools(llm: ChatOllama):
command_prompt = PromptTemplate(
input_variables=["question"],
template=f"""Only answer the question with the most suitable commands, out of: {[*COMMANDS]}.
If only one command suits the question, the answer should look like this: "command_1".
If multiple commands suit the question, the answer should look like this: "command_1, command_2, command_3".
If no command suits the question, simply answer with "N/A".
Question: {'{question}'}
Answer: """,
)
command_chain = LLMChain(llm=llm, prompt=command_prompt)

file_path_extractor_prompt = PromptTemplate(
input_variables=["question"],
template=f"""Only answer with the file paths provided in the question, in CSV format. If you can't find a file path in the question, answer with "N/A".
A file path is any string that looks like this: path/to/file
EXAMPLE START
Question: Can you review my implementation from those files: path/to/file/a.js, path/to/file/b.rs and path/to/file/c.py
Answer: path/to/file/a.js, path/to/file/b.rs, path/to/file/c.py
EXAMPLE END
EXAMPLE START
Question: How do I fix this bug?
Answer: N/A
EXAMPLE END
Question: {'{question}'}
Answer: """,
)
file_path_extractor_chain = LLMChain(llm=llm, prompt=file_path_extractor_prompt)

command_tool = Tool(
name="Command classification",
func=command_chain.invoke,
return_direct=True,
description=f"Useful to figure out what command out of the following relates best to the question: {COMMANDS}.",
)
file_path_extractor_tool = Tool(
name="File path",
func=file_path_extractor_chain.invoke,
return_direct=True,
description="Useful to extract file paths from a question, if there are any.",
)
return [command_tool, file_path_extractor_tool]


def create_agent(llm: ChatOllama, tools: list[Tool]):
prompt = hub.pull("hwchase17/structured-chat-agent")
agent = create_structured_chat_agent(
llm=llm,
tools=tools,
prompt=prompt,
)
agent_executor = AgentExecutor(
agent=agent, # type: ignore
tools=tools,
max_iterations=3,
verbose=False,
)
return agent_executor
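
A minimal sketch of exercising the two tools directly, assuming a reachable Ollama server; the generated text varies by model:

from langchain_community.chat_models import ChatOllama

from codebase_indexer.constants import DEFAULT_OLLAMA_INFERENCE_MODEL, OLLAMA_BASE_URL
from codebase_indexer.rag.agents import create_tools

llm = ChatOllama(base_url=OLLAMA_BASE_URL, model=DEFAULT_OLLAMA_INFERENCE_MODEL)
command_tool, file_path_tool = create_tools(llm)

# Each tool wraps an LLMChain, so invoke returns the chain's output dict.
result = command_tool.invoke({"question": "Review a.py and write tests for it"})
print(result.get("text"))  # e.g. "review, test" (model-dependent)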
12 changes: 12 additions & 0 deletions codebase_indexer/rag/chains.py
@@ -0,0 +1,12 @@
from langchain.chains import LLMChain
from langchain_community.chat_models import ChatOllama

from codebase_indexer.rag.prompts import QUERY_EXPANSION_PROMPT, SEARCH_REQUEST_REMOVAL


def create_query_expansion_chain(llm: ChatOllama):
return LLMChain(llm=llm, prompt=QUERY_EXPANSION_PROMPT)


def create_search_request_removal_chain(llm: ChatOllama):
return LLMChain(llm=llm, prompt=SEARCH_REQUEST_REMOVAL)
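
A minimal usage sketch, assuming llm is a ChatOllama instance and that SEARCH_REQUEST_REMOVAL prompts the model to strip the search phrasing from a question:

chain = create_search_request_removal_chain(llm)
result = chain.invoke({"question": "Search for the config loader and explain how it works"})
print(result.get("text"))  # e.g. "Explain how the config loader works" (model-dependent)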