Major improvements on chat completion. #198

Merged
merged 3 commits into from
Jan 20, 2025
38 changes: 33 additions & 5 deletions CHANGELOG.md
@@ -5,24 +5,52 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## Unreleased
## [0.1.0-alpha.20] - 2025-01-19

### Upgrade Guide

- There were substantial changes to the codebase search engines, which were rewritten to improve the quality of the search results. You will need to run the `update_index` command with `reset_all=True` to update the indexes.

### Added

- Show comment on the issue when the agent is replanning the tasks.
- Repositories file paths are now being indexed too, allowing the agent to search for file paths. **You must run command `update_index` with `reset_all=True` to update the indexes to include file paths.**
- Repositories file paths are now being indexed too, allowing the agent to search for file paths.
- New option `reset_all` added to the `update_index` command to allow resetting indexes from all branches, not only the default branch.

### Changed

- Improved `CodebaseQAAgent` response, even when tool calls are not being called. Added web search tool to the agent to allow it to search for answers when the codebase doesn't have the information.
- Changed models url paths on Chat API from `api/v1/chat/models` -> `api/v1/models` to be more consistent with OpenAI API. **This is a breaking change, as it will affect all clients using the Chat API.**
- Rewritten lexical search engine to only have one index, allowing the agent to search for answers in multiple repositories at once.
- Rewritten `CodebaseSearchAgent` to improve the quality of search results by using techniques such as: generating multiple queries, rephrasing those queries and compressing documents with listwise reranking.
- Simplified `CodebaseQAAgent` to use `CodebaseSearchAgent` to search for answers instead of binding tools to the agent.
- Changed models url paths on Chat API from `api/v1/chat/models` -> `api/v1/models` to be more consistent with OpenAI API.
- Default embedding model changed to `text-embedding-3-large` to improve the quality of the search results. The dimension of stored vectors was increased from 1536 to 2000.

### Fixed

- Issues with images were not being processed correctly, leading to errors interpreting the image from the description.
- Chat API was not returning the correct response structure when not on streaming mode.
- Codebase retriever used for non-scoped indexes on `CodebaseQAAgent` was returning duplicate documents, from different branches. Now it's filtering always by repository default branch.
- Codebase retriever used for non-scoped indexes on `CodebaseSearchAgent` was returning duplicate documents, from different refs. Now it's filtering always by repository default branch.

### Chore

- Updated dependencies:
- `django` from 5.1.4 to 5.1.5
- `duckduckgo-search` from 7.2.0 to 7.2.1
- `httpx` from 0.27.2 to 0.28.1
- `langchain-anthropic` from 0.3.2 to 0.3.3
- `langchain-openai` from 0.3.0 to 0.3.1
- `langchain-text-splitters` from 0.3.4 to 0.3.5
- `langgraph` from 0.2.61 to 0.2.64
- `langgraph-checkpoint-postgres` from 2.0.9 to 2.0.13
- `langsmith` from 0.2.10 to 0.2.11
- `psycopg` from 3.2.3 to 3.2.4
- `pyopenssl` from 24.3 to 25
- `pydantic` from 2.10.4 to 2.10.5
- `pytest-asyncio` from 0.25.1 to 0.25.2
- `python-gitlab` from 5.3.0 to 5.3.1
- `ruff` from 0.8.7 to 0.9.2
- `sentry-sdk` from 2.19.2 to 2.20
- `watchfiles` from 1.0.3 to 1.0.4

## [0.1.0-alpha.19] - 2025-01-10

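The changelog above moves the Chat API models endpoint from `api/v1/chat/models` to `api/v1/models`, which breaks clients that still use the old path. As a minimal sketch of a client-side migration, assuming a hypothetical helper named `migrate_models_url` and a made-up host (neither is part of DAIV), a stored URL could be rewritten like this:

```python
from urllib.parse import urlsplit, urlunsplit

OLD_MODELS_PATH = "/api/v1/chat/models"  # path before this release (per the changelog)
NEW_MODELS_PATH = "/api/v1/models"  # path after this release (per the changelog)


def migrate_models_url(url: str) -> str:
    """Rewrite an old-style models URL to the new path; leave other URLs untouched.

    Hypothetical helper for illustration only.
    """
    parts = urlsplit(url)
    # Ignore a trailing slash when checking for the old path.
    path = parts.path.rstrip("/")
    if path.endswith(OLD_MODELS_PATH):
        prefix = path[: -len(OLD_MODELS_PATH)]
        parts = parts._replace(path=prefix + NEW_MODELS_PATH)
    return urlunsplit(parts)
```

URLs that do not end in the old models path, such as other Chat API endpoints, pass through unchanged, so the helper can be applied to any configured endpoint.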
3 changes: 2 additions & 1 deletion daiv/automation/agents/base.py
@@ -5,14 +5,15 @@
 from enum import StrEnum
 from typing import TYPE_CHECKING, Generic, TypeVar, cast

-from langchain.chat_models.base import BaseChatModel, _attempt_infer_model_provider, init_chat_model
+from langchain.chat_models.base import _attempt_infer_model_provider, init_chat_model
 from langchain_community.callbacks import OpenAICallbackHandler
 from langchain_core.runnables import Runnable, RunnableConfig
 from pydantic import BaseModel

 from automation.conf import settings

 if TYPE_CHECKING:
+    from langchain_core.language_models.chat_models import BaseChatModel
     from langchain_core.messages import BaseMessage
     from langchain_openai.chat_models import ChatOpenAI
     from langgraph.checkpoint.postgres import PostgresSaver
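The `base.py` change moves the `BaseChatModel` import under `TYPE_CHECKING`, so the name is only imported for static analysis. The pattern can be sketched with a standard-library module standing in for langchain. The `total` function is hypothetical, and it is an assumption that the real module also uses `from __future__ import annotations`, since the top of that file is not shown in the hunk.

```python
from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Evaluated only by type checkers; this import never runs at runtime.
    from decimal import Decimal


def total(values: list[Decimal]) -> Decimal:
    """Sum a non-empty list of values.

    The annotations stay unevaluated strings at runtime because of the
    __future__ import, so the missing runtime import of Decimal is harmless.
    """
    return sum(values[1:], start=values[0])
```

The trade-off is that a name imported this way can only appear in annotations; using it in executable code (for example in an `isinstance` check) would raise `NameError` at runtime.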
73 changes: 32 additions & 41 deletions daiv/automation/agents/codebase_qa/agent.py
@@ -2,15 +2,14 @@

 from langgraph.graph import END, START, StateGraph
 from langgraph.graph.state import CompiledStateGraph
-from langgraph.prebuilt import ToolNode, tools_condition

 from automation.agents import BaseAgent
 from automation.conf import settings
 from automation.tools.repository import SearchCodeSnippetsTool
-from automation.tools.toolkits import WebSearchToolkit
+from automation.tools.web_search import WebSearchTool
 from codebase.indexes import CodebaseIndex

-from .prompts import system, system_query_or_respond
+from .prompts import system
 from .state import OverallState

 logger = logging.getLogger("daiv.agents")
@@ -21,64 +20,56 @@ class CodebaseQAAgent(BaseAgent[CompiledStateGraph]):
     Agent to answer questions about the codebase.
     """

-    model_name = settings.GENERIC_COST_EFFICIENT_MODEL_NAME
+    model_name = settings.CODING_COST_EFFICIENT_MODEL_NAME

     def __init__(self, index: CodebaseIndex):
         self.index = index
-        self.tools = [SearchCodeSnippetsTool(api_wrapper=index)] + WebSearchToolkit.create_instance().get_tools()
         super().__init__()

+    def get_model_kwargs(self) -> dict:
+        kwargs = super().get_model_kwargs()
+        kwargs["temperature"] = 0.3
+        return kwargs
+
     def compile(self) -> CompiledStateGraph:
         workflow = StateGraph(OverallState)

         # Add nodes
-        workflow.add_node("query_or_respond", self.query_or_respond)
-        workflow.add_node("tools", ToolNode(self.tools))
+        workflow.add_node("retrieve", self.retrieve)
         workflow.add_node("generate", self.generate)

         # Add edges
-        workflow.add_edge(START, "query_or_respond")
-        workflow.add_conditional_edges("query_or_respond", tools_condition, {END: END, "tools": "tools"})
-        workflow.add_edge("tools", "generate")
+        workflow.add_edge(START, "retrieve")
+        workflow.add_edge("retrieve", "generate")
         workflow.add_edge("generate", END)

         return workflow.compile()

-    def query_or_respond(self, state: OverallState):
+    def retrieve(self, state: OverallState):
         """
-        Generate tool call for retrieval or respond.
+        Retrieve the context.
         """
-        llm_with_tools = self.model.bind_tools(self.tools)
-        response = llm_with_tools.invoke([system_query_or_respond] + state["messages"])
-        return {"messages": [response]}
+        context = SearchCodeSnippetsTool(api_wrapper=self.index).invoke({
+            "query": state["messages"][-1].content,
+            "intent": "Searching the codebase.",
+        })
+        if not context:
+            context = WebSearchTool().invoke({
+                "query": state["messages"][-1].content,
+                "intent": "No code snippets found, searching the web.",
+            })
+        return {"context": context}

     def generate(self, state: OverallState):
         """
         Generate answer.
         """
-        tool_messages = []
-        for message in reversed(state["messages"]):
-            if message.type == "tool":
-                tool_messages.append(message)
-
-        if tool_messages:
-            docs_content = "\n\n".join(doc.content for doc in tool_messages[::-1])
-
-            conversation_messages = [
-                message
-                for message in state["messages"]
-                if message.type in ("human", "system") or (message.type == "ai" and not message.tool_calls)
-            ]
-            prompt = [
-                system.format(
-                    context=docs_content,
-                    codebase_client=self.index.repo_client.client_slug,
-                    codebase_url=self.index.repo_client.codebase_url,
-                )
-            ] + conversation_messages
-
-            response = self.model.invoke(prompt)
-        else:
-            response = state["messages"][-1]
-
-        return {"messages": [response]}
+        prompt = [
+            system.format(
+                context=state["context"],
+                codebase_client=self.index.repo_client.client_slug,
+                codebase_url=self.index.repo_client.codebase_url,
+            )
+        ] + state["messages"]
+
+        return {"messages": [self.model.invoke(prompt)]}
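The new `retrieve` node queries the codebase first and falls back to a web search only when the codebase returns nothing, replacing the earlier tool-calling loop with a fixed retrieve-then-generate path. That control flow can be sketched in plain Python. `retrieve_context`, `search_code` and `search_web` are hypothetical stand-ins for the node and its two tools, not DAIV or LangChain APIs.

```python
from collections.abc import Callable


def retrieve_context(
    question: str,
    search_code: Callable[[str], str],
    search_web: Callable[[str], str],
) -> str:
    """Return codebase context, falling back to the web when it is empty."""
    context = search_code(question)
    if not context:
        # Only reached when the codebase search produced an empty result.
        context = search_web(question)
    return context
```

Because the fallback is decided by plain truthiness rather than by the model, the web search runs at most once per question and never runs when code snippets were found.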
2 changes: 1 addition & 1 deletion daiv/automation/agents/codebase_qa/state.py
@@ -2,4 +2,4 @@


 class OverallState(MessagesState):
-    pass
+    context: str
160 changes: 41 additions & 119 deletions daiv/automation/agents/codebase_search/agent.py
@@ -1,148 +1,70 @@
 from __future__ import annotations

 import logging
-from typing import cast
+from typing import TYPE_CHECKING

-from langgraph.constants import Send
-from langgraph.graph import END, START, StateGraph
-from langgraph.graph.state import CompiledStateGraph
+from langchain.retrievers.document_compressors import LLMListwiseRerank
+from langchain_core.documents import Document
+from langchain_core.runnables import Runnable, RunnableLambda, RunnablePassthrough

 from automation.agents import BaseAgent
-from automation.agents.codebase_search.schemas import GradeDocumentsOutput, ImprovedQueryOutput
 from automation.conf import settings
-from codebase.indexes import CodebaseIndex
+from automation.retrievers import MultiQueryRephraseRetriever
+
+if TYPE_CHECKING:
+    from collections.abc import Sequence

-from .prompts import grade_human, grade_system, re_write_human, re_write_system
-from .state import GradeDocumentState, OverallState
+    from langchain_core.retrievers import BaseRetriever

 logger = logging.getLogger("daiv.agents")


-class CodebaseSearchAgent(BaseAgent[CompiledStateGraph]):
+class CodebaseSearchAgent(BaseAgent[Runnable[str, list[Document]]]):
     """
     Agent to search for code snippets in the codebase.
     """

-    def __init__(self, index: CodebaseIndex, source_repo_id: str | None = None, source_ref: str | None = None):
-        super().__init__()
-        self.index = index
-        self.source_repo_id = source_repo_id
-        self.source_ref = source_ref
-
-    def compile(self) -> CompiledStateGraph:
-        workflow = StateGraph(OverallState)
-
-        # Add nodes
-        workflow.add_node("retrieve", self.retrieve)
-        workflow.add_node("grade_document", self.grade_document)
-        workflow.add_node("post_grade_document", self.post_grade_document)
-        workflow.add_node("transform_query", self.transform_query)
-
-        # Add edges
-        workflow.add_edge(START, "retrieve")
-        workflow.add_conditional_edges(
-            "retrieve", self.should_grade_documents, ["transform_query", "grade_document", END]
-        )
-        workflow.add_edge("grade_document", "post_grade_document")
-        workflow.add_conditional_edges("post_grade_document", self.should_transform_query, ["transform_query", END])
-        workflow.add_edge("transform_query", "retrieve")
+    model_name = settings.CODING_COST_EFFICIENT_MODEL_NAME

-        return workflow.compile()
+    def __init__(self, retriever: BaseRetriever, *args, **kwargs):
+        self.retriever = retriever
+        super().__init__(*args, **kwargs)

-    def retrieve(self, state: OverallState):
+    def compile(self) -> Runnable:
         """
-        Retrieve documents from the codebase index.
+        Compile the agent into a Runnable.

-        Args:
-            state (GraphState): The current state of the graph.
+        Returns:
+            Runnable: The compiled agent
         """
-        if self.source_repo_id and self.source_ref:
-            # we need to update the index before retrieving the documents
-            # because the codebase search agent needs to search for the codebase changes
-            # and we need to make sure the index is updated before the agent starts retrieving the documents
-            self.index.update(self.source_repo_id, self.source_ref)
-
-        if self.source_repo_id and self.source_ref:
-            documents = self.index.search(self.source_repo_id, self.source_ref, state["query"])
-        else:
-            documents = self.index.search_all(state["query"])
-
-        return {"documents": documents, "iterations": state.get("iterations", 0) + 1}
+        return {
+            "query": RunnablePassthrough(),
+            "documents": MultiQueryRephraseRetriever.from_llm(self.retriever, llm=self.model),
+        } | RunnableLambda(
+            lambda inputs: self._compress_documents(inputs["documents"], inputs["query"]), name="compress_documents"
+        )

-    def grade_document(self, state: GradeDocumentState):
+    def get_model_kwargs(self) -> dict:
         """
-        Grade the relevance of the retrieved document to the query.
+        Get the model kwargs with a redefined temperature to make the model more creative.

-        Args:
-            state (GraphState): The current state of the graph.
-        """
-        grader_agent = self.model.with_structured_output(GradeDocumentsOutput, method="json_schema")
-
-        messages = [
-            grade_system,
-            grade_human.format(
-                query=state["query"], query_intent=state["query_intent"], document=state["document"].page_content
-            ),
-        ]
-        response = cast("GradeDocumentsOutput", grader_agent.invoke(messages))
-
-        if response.binary_score:
-            logger.info("[grade_document] Document '%s' is relevant to the query", state["document"].metadata["source"])
-            return {"documents": []}
-        return {"documents": [state["document"]]}
-
-    def post_grade_document(self, state: OverallState):
-        """
-        Post-process the grade of the document.
+        Returns:
+            dict: The model kwargs
         """
-        return {"documents": []}
+        kwargs = super().get_model_kwargs()
+        kwargs["temperature"] = 0.5
+        return kwargs

-    def transform_query(self, state: OverallState):
+    def _compress_documents(self, documents: list[Document], query: str) -> Sequence[Document]:
         """
-        Transform the query to improve retrieval.
+        Compress the documents using a listwise reranker.

         Args:
-            state (GraphState): The current state of the graph.
-        """
-        messages = [re_write_system, re_write_human.format(query=state["query"], query_intent=state["query_intent"])]
-
-        query_rewriter = self.model.with_structured_output(ImprovedQueryOutput, method="json_schema")
-        response = cast(
-            "ImprovedQueryOutput", query_rewriter.invoke(messages, config={"configurable": {"temperature": 0.7}})
-        )
-
-        logger.info("[transform_query] Query '%s' improved to '%s'", state["query"], response.query)
-
-        return {"query": response.query}
+            documents (Sequence[Document]): The documents to compress
+            query (str): The search query string

-    def should_grade_documents(self, state: OverallState):
+        Returns:
+            Sequence[Document]: The compressed documents
         """
-        Check if we should transform the query.
-        """
-        if not state["documents"]:
-            if state["iterations"] < settings.CODEBASE_SEARCH_MAX_TRANSFORMATIONS:
-                logger.info("[should_grade_documents] No documents retrieved. Moving to transform_query state.")
-                return "transform_query"
-            else:
-                logger.info("[should_grade_documents] No documents retrieved. Ending the process.")
-                return END
-
-        logger.info("[should_grade_documents] Documents retrieved. Moving to grade_documents state.")
-        return [
-            Send(
-                "grade_document", {"document": document, "query": state["query"], "query_intent": state["query_intent"]}
-            )
-            for document in state["documents"]
-        ]
-
-    def should_transform_query(self, state: OverallState):
-        """
-        Check if we should transform the query.
-        """
-        logger.info(
-            "[should_transform_query] %d documents found (iteration: %d)", len(state["documents"]), state["iterations"]
-        )
-        if not state["documents"] and state["iterations"] < settings.CODEBASE_SEARCH_MAX_TRANSFORMATIONS:
-            logger.info("[should_transform_query] No relevant documents found. Moving to transform_query state.")
-            return "transform_query"
-        if not state["documents"]:
-            logger.info("[should_transform_query] No relevant documents found. Ending the process.")
-            return END
+        reranker = LLMListwiseRerank.from_llm(llm=self.model, top_n=5)
+        return reranker.compress_documents(documents, query)
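The rewritten agent expands one query into several, retrieves documents for each, and keeps only the top-ranked results. A rough plain-Python sketch of that shape follows, assuming hypothetical `rephrase`, `retrieve` and `score` callables in place of the LLM-backed `MultiQueryRephraseRetriever` and `LLMListwiseRerank`. One deliberate simplification: a listwise reranker asks the model to rank the whole candidate list at once, whereas this sketch scores each document independently.

```python
from collections.abc import Callable, Iterable


def search(
    query: str,
    rephrase: Callable[[str], Iterable[str]],
    retrieve: Callable[[str], Iterable[str]],
    score: Callable[[str, str], float],
    top_n: int = 5,
) -> list[str]:
    """Multi-query retrieval followed by a top-n cut (illustrative only)."""
    # Fan out: the original query plus its rephrasings.
    queries = [query, *rephrase(query)]

    # Retrieve per query and de-duplicate while keeping first-seen order.
    seen: dict[str, None] = {}
    for q in queries:
        for doc in retrieve(q):
            seen.setdefault(doc, None)

    # Rank against the ORIGINAL query and keep the best top_n documents.
    ranked = sorted(seen, key=lambda doc: score(query, doc), reverse=True)
    return ranked[:top_n]
```

Ranking against the original query rather than the rephrasings mirrors the diff, where `RunnablePassthrough()` carries the unmodified query through to `_compress_documents`.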
2 changes: 1 addition & 1 deletion daiv/automation/agents/issue_addressor/agent.py
@@ -6,6 +6,7 @@
 from langchain_core.prompts import ChatPromptTemplate, HumanMessagePromptTemplate
 from langgraph.graph import END, START, StateGraph
 from langgraph.graph.state import CompiledStateGraph
+from langgraph.store.base import BaseStore  # noqa: TC002
 from langgraph.store.memory import InMemoryStore

 from automation.agents import BaseAgent
@@ -33,7 +34,6 @@

 if TYPE_CHECKING:
     from langchain_core.runnables import RunnableConfig
-    from langgraph.store.base import BaseStore

     from codebase.base import FileChange
     from codebase.clients import AllRepoClient