diff --git a/codebase_indexer/api/models.py b/codebase_indexer/api/models.py index 1671b79..499a8b0 100644 --- a/codebase_indexer/api/models.py +++ b/codebase_indexer/api/models.py @@ -10,5 +10,4 @@ class Meta(BaseModel): ollama_inference_model: str | None = None -Command = Literal["test", "search", "review"] -Context = Literal["from", "infer"] +Command = Literal["test", "search", "review", "new_conversation"] diff --git a/codebase_indexer/cli/__init__.py b/codebase_indexer/cli/__init__.py index 8817b73..8add2c4 100644 --- a/codebase_indexer/cli/__init__.py +++ b/codebase_indexer/cli/__init__.py @@ -4,21 +4,14 @@ from git import Repo from langchain_core.callbacks import StreamingStdOutCallbackHandler -from codebase_indexer.api.models import Command, Context from codebase_indexer.cli.argparser import Args -from codebase_indexer.constants import ( - COMMANDS, - CONTEXTS, - DEFAULT_OLLAMA_INFERENCE_MODEL, - DEFAULT_VECTOR_DB_DIR, -) -from codebase_indexer.rag import ( - RAG, - create_llm, - create_memory, - create_retriever, - init_vector_store, -) +from codebase_indexer.constants import (DEFAULT_OLLAMA_INFERENCE_MODEL, + DEFAULT_VECTOR_DB_DIR) +from codebase_indexer.rag import (RAG, create_llm, create_memory, + create_retriever, init_vector_store) +from codebase_indexer.rag.agents import create_tools +from codebase_indexer.rag.chains import create_search_request_removal_chain +from codebase_indexer.utils import strip_generated_text def cli(args: Args): @@ -50,65 +43,51 @@ def cli(args: Args): } ) memory = create_memory(llm) - prev_search: list[str] = [] + sources: list[str] = [] while True: question = input(">>> ") - command: Command | None = None - contexts: dict[Context, list[str]] = {} - for command_iter in COMMANDS: - if question.startswith(f"/{command_iter}"): - command = command_iter - question = question.replace(f"/{command}", "").strip() - break - # Very ugly, but I can't be asked to implement a proper DSL... - for context in CONTEXTS: - if question.startswith(f"@{context}"): - question = question.replace(f"@{context}", "") - argument = None - if context == "from": - argument, question = question.split(" ", 1) - argument = [*contexts.get(context, []), argument] - elif context == "infer": - argument = prev_search + commands = RAG.extract_commands(llm, question) + sources = RAG.extract_sources_to_search_in(llm, question, sources) - if argument is not None: - contexts.update({context: argument}) - question = question.strip() + for command in commands: + if command == "new_conversation": + sources = [] + continue - system_prompt, qa_chain_factory, custom_retriever_factory = RAG.from_command( - command - ) + system_prompt, qa_chain_factory, custom_retriever_factory = ( + RAG.from_command(command) + ) - try: - retriever = create_retriever(db) - sources = contexts.get("from", []) + contexts.get("infer", []) - if len(sources) > 0: - retriever.search_kwargs.update( - dict(filter={"source": {"$in": sources}}) - ) - prev_search = [] - if command == "search": - if custom_retriever_factory is not None and db.embeddings is not None: - retriever.search_kwargs.update(dict(k=20)) - retriever = custom_retriever_factory(db.embeddings, retriever) - docs = retriever.get_relevant_documents( - question, callbacks=llm.callbacks + try: + retriever = create_retriever(db) + if command == "search": + if ( + custom_retriever_factory is not None + and db.embeddings is not None + ): + embeddings_retriever = custom_retriever_factory( + db.embeddings, retriever + ) + sources = RAG.search(embeddings_retriever, question) + for source in sources: + print(f"\t- {source}") + + sources.extend(sources) + question = RAG.remove_search_request(llm, question) + + else: + retriever = RAG.filter_on_sources(retriever, sources) + qa = qa_chain_factory(llm, retriever, memory, system_prompt) + qa.invoke( + { + "question": question, + "chat_history": memory.chat_memory, + }, ) + except KeyboardInterrupt: + continue + finally: + print("\n") - sources = [doc.metadata["source"] for doc in docs] - for source in sources: - print(f"\t- {source}") - prev_search.extend(sources) - else: - qa = qa_chain_factory(llm, retriever, memory, system_prompt) - qa.invoke( - { - "question": question, - "chat_history": memory.chat_memory, - }, - ) - except KeyboardInterrupt: - continue - finally: - print("\n") + RAG.cycle_sources_buffer(sources) diff --git a/codebase_indexer/constants.py b/codebase_indexer/constants.py index 7f7cd57..a3d4261 100644 --- a/codebase_indexer/constants.py +++ b/codebase_indexer/constants.py @@ -2,17 +2,16 @@ from langchain.text_splitter import Language -from codebase_indexer.api.models import Command, Context +from codebase_indexer.api.models import Command OLLAMA_BASE_URL = "http://" + os.environ["OLLAMA_HOST"] DEFAULT_VECTOR_DB_DIR = os.path.join( os.path.expanduser("~"), ".codebase-indexer/vectorstores/db" ) -DEFAULT_OLLAMA_INFERENCE_MODEL = "mistral-openorca" +DEFAULT_OLLAMA_INFERENCE_MODEL = "mistral-openorca:7b" -COMMANDS: list[Command] = ["test", "search", "review"] -CONTEXTS: list[Context] = ["from", "infer"] +COMMANDS: list[Command] = ["test", "search", "review", "new_conversation"] LANGUAGES = [ Language.CPP, @@ -59,3 +58,5 @@ Language.CSHARP: ["cs", "cshtml"], Language.COBOL: ["cbl"], } + +MAX_SOURCES_WINDOW = 10 diff --git a/codebase_indexer/rag/__init__.py b/codebase_indexer/rag/__init__.py index d35fb76..9d087b3 100644 --- a/codebase_indexer/rag/__init__.py +++ b/codebase_indexer/rag/__init__.py @@ -1,4 +1,4 @@ -from dataclasses import astuple, dataclass, field +from dataclasses import dataclass from os import PathLike from typing import Any, TypedDict @@ -10,9 +10,7 @@ from langchain.prompts import PromptTemplate from langchain.retrievers import ContextualCompressionRetriever from langchain.retrievers.document_compressors import ( - DocumentCompressorPipeline, - EmbeddingsFilter, -) + DocumentCompressorPipeline, EmbeddingsFilter) from langchain.schema import BaseRetriever from langchain_community.chat_models import ChatOllama from langchain_community.document_transformers import EmbeddingsRedundantFilter @@ -21,13 +19,16 @@ from langchain_core.vectorstores import VectorStore, VectorStoreRetriever from codebase_indexer.api.models import Command -from codebase_indexer.constants import DEFAULT_OLLAMA_INFERENCE_MODEL, OLLAMA_BASE_URL +from codebase_indexer.constants import (DEFAULT_OLLAMA_INFERENCE_MODEL, + MAX_SOURCES_WINDOW, OLLAMA_BASE_URL) +from codebase_indexer.rag.agents import create_tools +from codebase_indexer.rag.chains import create_search_request_removal_chain from codebase_indexer.rag.prompts import ( DEFAULT_CONVERSATIONAL_RETRIEVAL_CHAIN_PROMPT, REVIEW_CONVERSATIONAL_RETRIEVAL_CHAIN_PROMPT, - TEST_CONVERSATIONAL_RETRIEVAL_CHAIN_PROMPT, -) + TEST_CONVERSATIONAL_RETRIEVAL_CHAIN_PROMPT) from codebase_indexer.rag.vector_store import create_index, load_index +from codebase_indexer.utils import strip_generated_text def init_vector_store( @@ -73,7 +74,7 @@ def create_llm(*, params: OllamaLLMParams) -> ChatOllama: def create_retriever(db: VectorStore): retriever_kwargs: dict[str, str | dict[str, Any]] = dict( - search_type="mmr", search_kwargs=dict(k=5, fetch_k=1000) + search_type="mmr", search_kwargs=dict(k=8, fetch_k=1000) ) return db.as_retriever(**retriever_kwargs) @@ -109,6 +110,8 @@ def create_conversational_retrieval_chain( def create_contextual_retriever( embeddings: Embeddings, retriever: VectorStoreRetriever ): + initial_args = dict(**retriever.search_kwargs) + retriever.search_kwargs.update(dict(k=20)) redundant_filter = EmbeddingsRedundantFilter(embeddings=embeddings) relevant_filter = EmbeddingsFilter(embeddings=embeddings, similarity_threshold=0.8) pipeline_compressor = DocumentCompressorPipeline( @@ -117,6 +120,7 @@ def create_contextual_retriever( compression_retriever = ContextualCompressionRetriever( base_compressor=pipeline_compressor, base_retriever=retriever ) + retriever.search_kwargs = initial_args return compression_retriever @@ -125,6 +129,79 @@ class RAG: def create(): return None, create_conversational_retrieval_chain, None + @staticmethod + def extract_commands(llm: ChatOllama, question: str) -> list[Command]: + command_extraction_tool, _ = create_tools(llm) + + extracted_commands = [ + *filter( + lambda x: x.lower() != "n/a", # type: ignore + map( + lambda x: x.strip(), + ( + strip_generated_text( + command_extraction_tool.invoke({"question": question}).get( + "text", "" + ) + ) + or "" + ).split(", "), + ), + ) + ] + if "search" in extracted_commands: + extracted_commands.remove("search") + return ["search", *extracted_commands] # type: ignore + return extracted_commands # type: ignore + + @staticmethod + def extract_sources_to_search_in( + llm: ChatOllama, question: str, sources: list[str] | None = None + ) -> list[str]: + _, file_path_extraction_tool = create_tools(llm) + sources = sources or [] + + file_paths = strip_generated_text( + file_path_extraction_tool.invoke({"question": question}).get("text", "") + ) + if file_paths.lower() != "n/a": + sources.extend([*map(lambda x: x.strip(), file_paths.split(","))]) + + return sources + + @staticmethod + def filter_on_sources( + retriever: VectorStoreRetriever, sources: list[str] + ) -> VectorStoreRetriever: + if len(sources) > 0: + retriever.search_kwargs.update(dict(filter={"source": {"$in": sources}})) + return retriever + + @staticmethod + def cycle_sources_buffer(sources: list[str]) -> list[str]: + if len(sources) <= MAX_SOURCES_WINDOW: + return sources + return sources[len(sources) - MAX_SOURCES_WINDOW : MAX_SOURCES_WINDOW] + + @staticmethod + def search(retriever: ContextualCompressionRetriever, question: str) -> list[str]: + docs = retriever.get_relevant_documents(question) + + sources = [doc.metadata["source"] for doc in docs] + return sources + + @staticmethod + def remove_search_request(llm: ChatOllama, question: str) -> str: + question = ( + strip_generated_text( + create_search_request_removal_chain(llm) + .invoke({"question": question}) + .get("text", question) + ) + or question + ) + return question + @classmethod def from_command(cls, command: Command | None): if command == "test": diff --git a/codebase_indexer/rag/agents.py b/codebase_indexer/rag/agents.py new file mode 100644 index 0000000..757976c --- /dev/null +++ b/codebase_indexer/rag/agents.py @@ -0,0 +1,73 @@ +from langchain import hub +from langchain.agents import AgentExecutor, Tool, create_structured_chat_agent +from langchain.chains import LLMChain +from langchain.prompts import PromptTemplate +from langchain_community.chat_models import ChatOllama + +from codebase_indexer.constants import COMMANDS + + +def create_tools(llm: ChatOllama): + command_prompt = PromptTemplate( + input_variables=["question"], + template=f"""Only answer the question with the most suitable commands, out of: {[*COMMANDS]}. + + If only one command suits the question, the answer should looks like this: "command_1". + If multiple commands suit the question, the answer should looks like this: "command_1, command_2, command_3". + If no command suits the question, simply answer with "N/A". + + Question: {'{question}'} + Answer: """, + ) + command_chain = LLMChain(llm=llm, prompt=command_prompt) + + file_path_extractor_prompt = PromptTemplate( + input_variables=["question"], + template=f"""Only answer with the file paths provided in the question, in CSV format. If you can't find a file path in the question, answer with "N/A". + + A file path is any string that looks like this: path/to/file + + EXAMPLE START + Question: Can you review my implementation from those files: path/to/file/a.js, path/to/file/b.rs and path/to/file/c.py + Answer: path/to/file/a.js, path/to/file/b.rs, path/to/file/c.py + EXAMPLE END + + EXAMPLE START + Question: How do I fix this bug? + Answer: N/A + EXAMPLE END + + Question: {'{question}'} + Answer: """, + ) + file_path_extractor_chain = LLMChain(llm=llm, prompt=file_path_extractor_prompt) + + command_tool = Tool( + name="Command classification", + func=command_chain.invoke, + return_direct=True, + description=f"Useful to figure out what command out of the following relates best to the question: {COMMANDS}.", + ) + file_path_extractor_tool = Tool( + name="File path", + func=file_path_extractor_chain.invoke, + return_direct=True, + description="Useful to extract file paths from a question, if there are any.", + ) + return [command_tool, file_path_extractor_tool] + + +def create_agent(llm: ChatOllama, tools: list[Tool]): + prompt = hub.pull("hwchase17/structured-chat-agent") + agent = create_structured_chat_agent( + llm=llm, + tools=tools, + prompt=prompt, + ) + agent_executor = AgentExecutor( + agent=agent, # type: ignore + tools=tools, + max_iterations=3, + verbose=False, + ) + return agent_executor diff --git a/codebase_indexer/rag/chains.py b/codebase_indexer/rag/chains.py new file mode 100644 index 0000000..6208945 --- /dev/null +++ b/codebase_indexer/rag/chains.py @@ -0,0 +1,12 @@ +from langchain.chains import LLMChain +from langchain_community.chat_models import ChatOllama + +from codebase_indexer.rag.prompts import QUERY_EXPANSION_PROMPT, SEARCH_REQUEST_REMOVAL + + +def create_query_expansion_chain(llm: ChatOllama): + return LLMChain(llm=llm, prompt=QUERY_EXPANSION_PROMPT) + + +def create_search_request_removal_chain(llm: ChatOllama): + return LLMChain(llm=llm, prompt=SEARCH_REQUEST_REMOVAL) diff --git a/codebase_indexer/rag/prompts/__init__.py b/codebase_indexer/rag/prompts/__init__.py index 7849a13..45e8947 100644 --- a/codebase_indexer/rag/prompts/__init__.py +++ b/codebase_indexer/rag/prompts/__init__.py @@ -26,3 +26,12 @@ def read_prompt_template(file_path: str) -> str: input_variables=["summary", "new_lines"], template=read_prompt_template("./conversation_summary_memory.txt"), ) + +QUERY_EXPANSION_PROMPT = PromptTemplate( + input_variables=["question"], template=read_prompt_template("./query_expansion.txt") +) + +SEARCH_REQUEST_REMOVAL = PromptTemplate( + input_variables=["question"], + template=read_prompt_template("./search_request_removal.txt"), +) diff --git a/codebase_indexer/rag/prompts/query_expansion.txt b/codebase_indexer/rag/prompts/query_expansion.txt new file mode 100644 index 0000000..d37ef4d --- /dev/null +++ b/codebase_indexer/rag/prompts/query_expansion.txt @@ -0,0 +1,5 @@ +Rephrase the following question such that it is more clear, more detailed and with no typography errors. Always assume you know the context and that the context is related to code or programming. + +Initial question: {question} +Rephrased question: + diff --git a/codebase_indexer/rag/prompts/search_request_removal.txt b/codebase_indexer/rag/prompts/search_request_removal.txt new file mode 100644 index 0000000..e25c74d --- /dev/null +++ b/codebase_indexer/rag/prompts/search_request_removal.txt @@ -0,0 +1,5 @@ +If the following question is not related to searching for code in a particular place, keep the question as is. Otherwise, rephrase the question so that it doesn't ask about searching for code or files in a particular place. + +Initial question: {question} +Rephrased question: + diff --git a/codebase_indexer/utils.py b/codebase_indexer/utils.py new file mode 100644 index 0000000..5eb9673 --- /dev/null +++ b/codebase_indexer/utils.py @@ -0,0 +1,2 @@ +def strip_generated_text(s: str) -> str: + return s.replace("<|im_end|>", "").replace("'", "").replace('"', "").strip() diff --git a/main.py b/main.py index f80b9bb..3f5c7a0 100644 --- a/main.py +++ b/main.py @@ -19,7 +19,6 @@ ) from codebase_indexer.rag import ( CodebaseIndexer, - OllamaLLMParams, create_llm, create_memory, init_vector_store,