Integrated gemini and deepseek chat models. (#204)
* Integrated gemini and deepseek chat models.

* Fixed unittests.

* Updated changelog.
srtab authored Jan 24, 2025
1 parent 91bf145 commit d2a2acd
Showing 15 changed files with 444 additions and 66 deletions.
10 changes: 6 additions & 4 deletions CHANGELOG.md
@@ -7,13 +7,15 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

 ## Unreleased

-### Changed
-
-- Performance improvements and cleaner use of compression retrievers on `CodebaseSearchAgent`.
-
 ### Added

 - Codebase search now allows configuring how many results are returned by the search.
+- Added support for Google GenAI and DeepSeek chat models.
+- Added `search_documents` command to search for documents in the codebase, useful for debugging and testing.
+
+### Changed
+
+- Performance improvements and cleaner use of compression retrievers on `CodebaseSearchAgent`.

 ### Fixed
49 changes: 39 additions & 10 deletions daiv/automation/agents/base.py
@@ -22,6 +22,8 @@
 class ModelProvider(StrEnum):
     ANTHROPIC = "anthropic"
     OPENAI = "openai"
+    DEEPSEEK = "deepseek"
+    GOOGLE_GENAI = "google_genai"


 T = TypeVar("T", bound=Runnable)
@@ -81,11 +83,23 @@ def get_model_kwargs(self) -> dict:
             "model_kwargs": {},
         }

-        if BaseAgent.get_model_provider(self.model_name) == ModelProvider.ANTHROPIC:
-            kwargs["model_kwargs"]["extra_headers"] = {
-                "anthropic-beta": "prompt-caching-2024-07-31,max-tokens-3-5-sonnet-2024-07-15"
-            }
-            kwargs["max_tokens"] = str(self.get_max_token_value())
+        model_provider = BaseAgent.get_model_provider(self.model_name)
+
+        if model_provider == ModelProvider.ANTHROPIC:
+            # As stated in docs: https://docs.anthropic.com/en/api/rate-limits#updated-rate-limits
+            # the OTPM is calculated based on the max_tokens. We need to use a fair value to avoid rate limiting.
+            # If needed, we can increase this value using the configurable field.
+            kwargs["max_tokens"] = "2048"
+            kwargs["model_kwargs"]["extra_headers"] = {"anthropic-beta": "prompt-caching-2024-07-31"}
+        elif model_provider == ModelProvider.DEEPSEEK:
+            assert settings.DEEPSEEK_API_KEY is not None, "DEEPSEEK_API_KEY is not set"
+
+            kwargs["model_provider"] = "openai"
+            kwargs["base_url"] = settings.DEEPSEEK_API_BASE
+            kwargs["api_key"] = settings.DEEPSEEK_API_KEY
+        elif model_provider == ModelProvider.GOOGLE_GENAI:
+            # otherwise it will be inferred as google_vertexai
+            kwargs["model_provider"] = "google_genai"
         return kwargs

     def get_config(self) -> RunnableConfig:
@@ -128,15 +142,20 @@ def get_max_token_value(self) -> int:
         match BaseAgent.get_model_provider(self.model_name):
             case ModelProvider.ANTHROPIC:
-                # As stated in docs: https://docs.anthropic.com/en/api/rate-limits#updated-rate-limits
-                # the OTPM is calculated based on the max_tokens. We need to use a fair value to avoid rate limiting.
-                # If needed, we can increase this value using the configurable field.
-                return 2048
+                return 8192

             case ModelProvider.OPENAI:
                 _, encoding_model = cast("ChatOpenAI", self.model)._get_encoding_model()
                 return encoding_model.max_token_value

+            case ModelProvider.DEEPSEEK:
+                # As stated in docs: https://api-docs.deepseek.com/quick_start/pricing
+                return 8192
+
+            case ModelProvider.GOOGLE_GENAI:
+                # As stated in docs: https://ai.google.dev/gemini-api/docs/models/gemini#gemini-2.0-flash
+                return 8192
+
             case _:
                 raise ValueError(f"Unknown provider for model {self.model_name}")

@@ -151,7 +170,17 @@ def get_model_provider(model_name: str) -> ModelProvider:
         Returns:
             ModelProvider: The model provider
         """
-        return _attempt_infer_model_provider(model_name)
+        if model_name.startswith("deepseek"):
+            model_provider = ModelProvider.DEEPSEEK
+        elif model_name.startswith("gemini"):
+            model_provider = ModelProvider.GOOGLE_GENAI
+        else:
+            model_provider = cast("ModelProvider | None", _attempt_infer_model_provider(model_name))
+
+        if model_provider is None:
+            raise ValueError(f"Unknown provider for model {model_name}")
+
+        return model_provider


 class Usage(BaseModel):
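
The net effect of the `base.py` changes: DeepSeek requests ride on the OpenAI-compatible client with a custom base URL, while `gemini-*` names are pinned to the `google_genai` provider so LangChain's name-based inference doesn't resolve them to Vertex AI. A minimal sketch of the resulting behavior, assuming the kwargs built by `get_model_kwargs` are ultimately handed to LangChain's `init_chat_model` (the API key below is a placeholder for `settings.DEEPSEEK_API_KEY`):

from langchain.chat_models import init_chat_model

# DeepSeek exposes an OpenAI-compatible API, so the model is constructed
# through the "openai" provider with a custom base URL and key.
deepseek = init_chat_model(
    "deepseek-chat",
    model_provider="openai",
    base_url="https://api.deepseek.com/v1",
    api_key="sk-...",  # placeholder
)

# Gemini must be pinned explicitly; inference from the name alone
# would otherwise pick "google_vertexai".
gemini = init_chat_model("gemini-2.0-flash-exp", model_provider="google_genai")
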
12 changes: 12 additions & 0 deletions daiv/automation/agents/constants.py
@@ -0,0 +1,12 @@
+from enum import StrEnum
+
+
+class ModelName(StrEnum):
+    CLAUDE_3_5_SONNET_20241022 = "claude-3-5-sonnet-20241022"
+    CLAUDE_3_5_HAIKU_20241022 = "claude-3-5-haiku-20241022"
+    GPT_4O_2024_11_20 = "gpt-4o-2024-11-20"
+    GPT_4O_MINI_2024_07_18 = "gpt-4o-mini-2024-07-18"
+    GEMINI_2_0_FLASH_EXP = "gemini-2.0-flash-exp"
+    GEMINI_2_0_FLASH_THINKING_EXP_01_21 = "gemini-2.0-flash-thinking-exp-01-21"
+    DEEPSEEK_CHAT = "deepseek-chat"
+    DEEPSEEK_REASONER = "deepseek-reasoner"
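
Since `ModelName` subclasses `StrEnum`, every member is itself a `str`, which is what lets `conf.py` (below) swap its plain-string defaults for enum members without breaking comparisons, serialization, or the `startswith` checks in `get_model_provider`. A quick standard-library illustration:

from enum import StrEnum

class ModelName(StrEnum):
    DEEPSEEK_CHAT = "deepseek-chat"

# StrEnum members behave as real strings.
assert ModelName.DEEPSEEK_CHAT == "deepseek-chat"
assert f"{ModelName.DEEPSEEK_CHAT}".startswith("deepseek")
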
2 changes: 1 addition & 1 deletion daiv/automation/agents/error_log_evaluator/agent.py
@@ -29,4 +29,4 @@ class ErrorLogEvaluatorAgent(BaseAgent[Runnable[ErrorLogEvaluatorInput, ErrorLogEvaluatorOutput]]):

     def compile(self) -> Runnable:
         prompt = ChatPromptTemplate.from_messages([system, human])
-        return prompt | self.model.with_structured_output(ErrorLogEvaluatorOutput, method="json_schema")
+        return prompt | self.model.with_structured_output(ErrorLogEvaluatorOutput)
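
Dropping the explicit `method="json_schema"` (a change repeated across the agents below) is what makes these chains provider-agnostic: `json_schema` is a structured-output mode that not every LangChain integration implements, while the default lets each provider pick a strategy it supports, typically tool/function calling. A rough sketch with a hypothetical schema:

from langchain.chat_models import init_chat_model
from pydantic import BaseModel

class Verdict(BaseModel):  # hypothetical output schema for illustration
    is_same_error: bool

# With `method` unset, the Gemini integration chooses its own binding
# strategy; an explicit method="json_schema" is only safe on models
# whose integration implements that mode.
model = init_chat_model("gemini-2.0-flash-exp", model_provider="google_genai")
structured = model.with_structured_output(Verdict)
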
2 changes: 1 addition & 1 deletion daiv/automation/agents/image_url_extractor/agent.py
@@ -40,6 +40,6 @@ def compile(self) -> Runnable:
         prompt = ChatPromptTemplate.from_messages([system, human])
         return (
             prompt
-            | self.model.with_structured_output(ImageURLExtractorOutput, method="json_schema")
+            | self.model.with_structured_output(ImageURLExtractorOutput)
             | RunnableLambda(_post_process, name="post_process_extracted_images")
         )
2 changes: 1 addition & 1 deletion daiv/automation/agents/issue_addressor/agent.py
@@ -227,7 +227,7 @@ def human_feedback(self, state: OverallState):
             dict: The state of the agent to update.
         """

-        human_feedback_evaluator = self.model.with_structured_output(HumanFeedbackResponse, method="json_schema")
+        human_feedback_evaluator = self.model.with_structured_output(HumanFeedbackResponse)
         result = cast(
             "HumanFeedbackResponse",
             human_feedback_evaluator.invoke(
4 changes: 1 addition & 3 deletions daiv/automation/agents/pr_describer/agent.py
@@ -28,9 +28,7 @@ def compile(self) -> Runnable:
         prompt = ChatPromptTemplate.from_messages([system, human]).partial(
             branch_name_convention=None, extra_details={}
         )
-        return prompt | self.model.with_structured_output(
-            PullRequestDescriberOutput, method="json_schema"
-        ).with_fallbacks([
+        return prompt | self.model.with_structured_output(PullRequestDescriberOutput).with_fallbacks([
             self.get_model(model=settings.CODING_COST_EFFICIENT_MODEL_NAME).with_structured_output(
                 PullRequestDescriberOutput
             )
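
The describer keeps its fallback chain: if the performant model errors while producing structured output, the cost-efficient model is invoked with the same schema. Roughly, with a hypothetical stand-in for `PullRequestDescriberOutput`:

from langchain.chat_models import init_chat_model
from pydantic import BaseModel

class Description(BaseModel):  # hypothetical schema
    title: str
    summary: str

primary = init_chat_model("claude-3-5-sonnet-20241022").with_structured_output(Description)
fallback = init_chat_model("claude-3-5-haiku-20241022").with_structured_output(Description)

# with_fallbacks() invokes the fallback runnable only if the primary raises.
chain = primary.with_fallbacks([fallback])
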
4 changes: 1 addition & 3 deletions daiv/automation/agents/prebuilt.py
@@ -152,9 +152,7 @@ def respond(self, state: AgentState):
         except ValidationError:
             logger.warning("[ReAcT] Error structuring output with tool args. Fallback to llm with_structured_output.")

-            llm_with_structured_output = self.model.with_structured_output(
-                self.with_structured_output, method="json_schema"
-            )
+            llm_with_structured_output = self.model.with_structured_output(self.with_structured_output)
             response = cast(
                 "BaseModel",
                 llm_with_structured_output.invoke(
2 changes: 1 addition & 1 deletion daiv/automation/agents/snippet_replacer/agent.py
@@ -52,7 +52,7 @@ def _route(self, input_data: SnippetReplacerInput) -> Runnable:
             Runnable: The appropriate method
         """
         if settings.SNIPPET_REPLACER_STRATEGY == "llm" and self.validate_max_token_not_exceeded(input_data):
-            return self._prompt | self.model.with_structured_output(SnippetReplacerOutput, method="json_schema")
+            return self._prompt | self.model.with_structured_output(SnippetReplacerOutput)
         return RunnableLambda(self._replace_content_snippet)

     def _replace_content_snippet(self, input_data: SnippetReplacerInput) -> SnippetReplacerOutput | str:
37 changes: 25 additions & 12 deletions daiv/automation/conf.py
@@ -3,31 +3,44 @@
 from pydantic import Field
 from pydantic_settings import BaseSettings, SettingsConfigDict

+from automation.agents.constants import ModelName
+

 class AutomationSettings(BaseSettings):
     model_config = SettingsConfigDict(secrets_dir="/run/secrets", env_prefix="AUTOMATION_")

+    # DeepSeek settings
+    DEEPSEEK_API_BASE: str = Field(
+        default="https://api.deepseek.com/v1", description="Base URL for the DeepSeek API", alias="DEEPSEEK_API_BASE"
+    )
+    DEEPSEEK_API_KEY: str | None = Field(
+        default=None, description="API key for the DeepSeek API", alias="DEEPSEEK_API_KEY"
+    )
+
     # Agent settings
     RECURSION_LIMIT: int = Field(default=50, description="Default recursion limit for the agent")
-    PLANING_PERFORMANT_MODEL_NAME: str = Field(
-        default="claude-3-5-sonnet-20241022", description="Model name to be used to plan tasks with high performance."
+    PLANING_PERFORMANT_MODEL_NAME: ModelName = Field(
+        default=ModelName.CLAUDE_3_5_SONNET_20241022,
+        description="Model name to be used to plan tasks with high performance.",
     )
-    CODING_PERFORMANT_MODEL_NAME: str = Field(
-        default="claude-3-5-sonnet-20241022", description="Model name to be used to code with high performance."
+    CODING_PERFORMANT_MODEL_NAME: ModelName = Field(
+        default=ModelName.CLAUDE_3_5_SONNET_20241022, description="Model name to be used to code with high performance."
     )
-    CODING_COST_EFFICIENT_MODEL_NAME: str = Field(
-        default="claude-3-5-haiku-20241022", description="Model name to be used to code with cost efficiency."
+    CODING_COST_EFFICIENT_MODEL_NAME: ModelName = Field(
+        default=ModelName.CLAUDE_3_5_HAIKU_20241022, description="Model name to be used to code with cost efficiency."
     )
-    GENERIC_PERFORMANT_MODEL_NAME: str = Field(
-        default="gpt-4o-2024-11-20", description="Model name to be used for generic tasks with high performance."
+    GENERIC_PERFORMANT_MODEL_NAME: ModelName = Field(
+        default=ModelName.GPT_4O_2024_11_20,
+        description="Model name to be used for generic tasks with high performance.",
     )
-    GENERIC_COST_EFFICIENT_MODEL_NAME: str = Field(
-        default="gpt-4o-mini-2024-07-18", description="Model name to be used for generic tasks with cost efficiency."
+    GENERIC_COST_EFFICIENT_MODEL_NAME: ModelName = Field(
+        default=ModelName.GPT_4O_MINI_2024_07_18,
+        description="Model name to be used for generic tasks with cost efficiency.",
     )

     # Snippet replacer settings
-    SNIPPET_REPLACER_MODEL_NAME: str = Field(
-        default="claude-3-5-haiku-20241022", description="Model name to be used for snippet replacer."
+    SNIPPET_REPLACER_MODEL_NAME: ModelName = Field(
+        default=ModelName.CLAUDE_3_5_HAIKU_20241022, description="Model name to be used for snippet replacer."
     )
     SNIPPET_REPLACER_STRATEGY: Literal["llm", "find_and_replace"] = Field(
         default="find_and_replace",
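
Typing these fields as `ModelName` means environment overrides (prefix `AUTOMATION_`) are validated by pydantic at startup rather than failing later inside provider inference. A sketch, assuming the import path implied by `daiv/automation/conf.py`:

import os

os.environ["AUTOMATION_CODING_PERFORMANT_MODEL_NAME"] = "deepseek-chat"

from automation.conf import AutomationSettings  # assumed import path

settings = AutomationSettings()
assert settings.CODING_PERFORMANT_MODEL_NAME == "deepseek-chat"

# A typo such as "deepseek-cht" would now raise a pydantic ValidationError
# here, instead of surfacing as "Unknown provider" deep inside BaseAgent.
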
41 changes: 41 additions & 0 deletions daiv/codebase/management/commands/search_documents.py
@@ -0,0 +1,41 @@
+import logging
+
+from django.core.management.base import BaseCommand
+
+from automation.agents.codebase_search.agent import CodebaseSearchAgent
+from codebase.clients import RepoClient
+from codebase.indexes import CodebaseIndex
+
+logger = logging.getLogger("daiv.indexes")
+
+
+class Command(BaseCommand):
+    help = "Search documents in the index."
+
+    def add_arguments(self, parser):
+        parser.add_argument("--repo-id", type=str, help="Update a specific repository by namespace, slug or id.")
+        parser.add_argument(
+            "--ref",
+            type=str,
+            help="Update a specific reference of the repository. If not provided, the default branch will be used.",
+        )
+        parser.add_argument("--show-content", action="store_true", help="Show the content of the documents.")
+        parser.add_argument("query", type=str, help="The query to search for.")
+
+    def handle(self, *args, **options):
+        repo_client = RepoClient.create_instance()
+        indexer = CodebaseIndex(repo_client=repo_client)
+
+        namespace = None
+
+        if options["repo_id"]:
+            namespace = indexer._get_codebase_namespace(options["repo_id"], options["ref"])
+
+        codebase_search = CodebaseSearchAgent(indexer.as_retriever(namespace))
+
+        for doc in codebase_search.agent.invoke(options["query"]):
+            self.stdout.write("-" * 100)
+            self.stdout.write(f"{doc.metadata['repo_id']}[{doc.metadata['ref']}]: {doc.metadata['source']}\n")
+            if options["show_content"]:
+                self.stdout.write(doc.page_content)
+            self.stdout.write("-" * 100)
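
An invocation would look like `python manage.py search_documents --repo-id my-group/my-repo --show-content "http retry logic"` (repository and query values here are invented), printing each matching document's repo, ref, and source path, plus its content when `--show-content` is set.
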
9 changes: 9 additions & 0 deletions docker/local/app/config.secrets.example.env
@@ -1,5 +1,14 @@
 # Codebase
 CODEBASE_GITLAB_AUTH_TOKEN=

+# Models
+OPENAI_API_KEY=
+ANTHROPIC_API_KEY=
+DEEPSEEK_API_KEY=
+GOOGLE_API_KEY=
+
 # Monitoring
 LANGCHAIN_API_KEY=
+
+# WebSearch
+TAVILY_API_KEY=
1 change: 1 addition & 0 deletions pyproject.toml
@@ -32,6 +32,7 @@ dependencies = [
     "langchain==0.3.14",
     "langchain-anthropic==0.3.3",
     "langchain-community==0.3.14",
+    "langchain-google-genai==2.0.9",
     "langchain-openai==0.3.1",
     "langchain-text-splitters==0.3.5",
     "langgraph==0.2.64",
4 changes: 2 additions & 2 deletions tests/automation/agents/test_base.py
@@ -73,10 +73,10 @@ def test_get_max_token_value(self):
             mock_provider.return_value = ModelProvider.ANTHROPIC

             agent = ConcreteAgent(model_name="claude-3-5-sonnet-20240229")
-            assert agent.get_max_token_value() == 2048
+            assert agent.get_max_token_value() == 8192

             agent = ConcreteAgent(model_name="claude-3-opus-20240229")
-            assert agent.get_max_token_value() == 2048
+            assert agent.get_max_token_value() == 8192

     def test_get_config(self):
         agent = ConcreteAgent(run_name="TestAgent")