-
-
Notifications
You must be signed in to change notification settings - Fork 273
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
b703295
commit 99a179c
Showing
10 changed files
with
452 additions
and
242 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
.PHONY: all format lint lint_diff format_diff lint_package lint_tests spell_check spell_fix help lint-fix

# Define a variable for Python and notebook files.
PYTHON_FILES=src/
MYPY_CACHE=.mypy_cache

######################
# LINTING AND FORMATTING
######################

# Per-target overrides: each lint/format flavor narrows (or widens) the file set.
lint format: PYTHON_FILES=.
lint_diff format_diff: PYTHON_FILES=$(shell git diff --name-only --diff-filter=d main | grep -E '\.py$$|\.ipynb$$')
lint_package: PYTHON_FILES=src
lint_tests: PYTHON_FILES=tests
lint_tests: MYPY_CACHE=.mypy_cache_test

lint lint_diff lint_package lint_tests:
	# Guard every step: a clean diff leaves PYTHON_FILES empty, so skip the tools
	# instead of letting them scan the whole tree (or crash on no arguments).
	[ "$(PYTHON_FILES)" = "" ] || python -m ruff check $(PYTHON_FILES)
	[ "$(PYTHON_FILES)" = "" ] || python -m ruff format $(PYTHON_FILES) --diff
	[ "$(PYTHON_FILES)" = "" ] || python -m ruff check --select I,F401 --fix $(PYTHON_FILES)
	# Single mypy run with an explicit cache dir (lint_tests uses its own cache).
	# The mkdir+mypy pair is parenthesized: without grouping, shell left-to-right
	# precedence of || and && would run mypy even when PYTHON_FILES is empty.
	[ "$(PYTHON_FILES)" = "" ] || ( mkdir -p $(MYPY_CACHE) && python -m mypy --strict $(PYTHON_FILES) --cache-dir $(MYPY_CACHE) )

format format_diff:
	ruff format $(PYTHON_FILES)
	ruff check --fix $(PYTHON_FILES)

spell_check:
	codespell --toml pyproject.toml

spell_fix:
	codespell --toml pyproject.toml -w

######################
# RUN ALL
######################

all: format lint spell_check

######################
# HELP
######################

help:
	@echo '----'
	@echo 'format       - run code formatters'
	@echo 'lint         - run linters'
	@echo 'spell_check  - run spell check'
	@echo 'all          - run all tasks'
	@echo 'lint-fix     - run lint and fix issues'

######################
# LINT-FIX TARGET
######################

lint-fix: format lint
	@echo "Linting and fixing completed successfully."
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,102 @@ | ||
import os | ||
from dataclasses import dataclass | ||
from typing import Annotated, Sequence, Optional | ||
|
||
from langchain.callbacks.base import BaseCallbackHandler | ||
from langchain_anthropic import ChatAnthropic | ||
from langchain_core.messages import SystemMessage | ||
from langchain_openai import ChatOpenAI | ||
from langgraph.checkpoint.memory import MemorySaver | ||
from langgraph.graph import START, StateGraph | ||
from langgraph.prebuilt import ToolNode, tools_condition | ||
from langgraph.graph.message import add_messages | ||
from langchain_core.messages import BaseMessage | ||
|
||
from template import TEMPLATE | ||
from tools import retriever_tool | ||
|
||
|
||
@dataclass
class MessagesState:
    """LangGraph state schema: the running conversation history."""

    # `add_messages` is a reducer annotation — LangGraph merges newly returned
    # messages into the existing sequence instead of replacing it wholesale.
    messages: Annotated[Sequence[BaseMessage], add_messages]
|
||
|
||
memory = MemorySaver() | ||
|
||
|
||
@dataclass
class ModelConfig:
    """Connection settings for one supported chat model."""

    model_name: str  # provider-specific model identifier
    api_key: str  # read from an environment variable by the caller; may be None if unset
    base_url: Optional[str] = None  # OpenAI-compatible endpoint; None -> provider default
|
||
|
||
def create_agent(callback_handler: BaseCallbackHandler, model_name: str):
    """Build and compile a ReAct-style LangGraph agent for the given model.

    Args:
        callback_handler: Streaming/token callback attached to the LLM.
        model_name: One of the keys in the supported-model table below.

    Returns:
        A compiled LangGraph graph wired with the retriever tool and an
        in-memory checkpointer.

    Raises:
        ValueError: If ``model_name`` is not a supported model.
    """
    # Table of supported models; Groq and Fireworks are reached through their
    # OpenAI-compatible endpoints via `base_url`.
    supported_models = {
        "gpt-4o-mini": ModelConfig(
            model_name="gpt-4o-mini", api_key=os.getenv("OPENAI_API_KEY")
        ),
        "gemma2-9b": ModelConfig(
            model_name="gemma2-9b-it",
            api_key=os.getenv("GROQ_API_KEY"),
            base_url="https://api.groq.com/openai/v1",
        ),
        "claude3-haiku": ModelConfig(
            model_name="claude-3-haiku-20240307", api_key=os.getenv("ANTHROPIC_API_KEY")
        ),
        "mixtral-8x22b": ModelConfig(
            model_name="accounts/fireworks/models/mixtral-8x22b-instruct",
            api_key=os.getenv("FIREWORKS_API_KEY"),
            base_url="https://api.fireworks.ai/inference/v1",
        ),
        "llama-3.1-405b": ModelConfig(
            model_name="accounts/fireworks/models/llama-v3p1-405b-instruct",
            api_key=os.getenv("FIREWORKS_API_KEY"),
            base_url="https://api.fireworks.ai/inference/v1",
        ),
    }

    config = supported_models.get(model_name)
    if config is None:
        raise ValueError(f"Unsupported model name: {model_name}")

    sys_msg = SystemMessage(
        content="""You're an AI assistant specializing in data analysis with Snowflake SQL. When providing responses, strive to exhibit friendliness and adopt a conversational tone, similar to how a friend or tutor would communicate.
        Call the tool "Database_Schema" to search for database schema details when needed to generate the SQL code.
        """
    )

    # Anthropic is the only provider needing its own client class; everything
    # else speaks the OpenAI wire protocol.
    if config.model_name == "claude-3-haiku-20240307":
        llm = ChatAnthropic(
            model=config.model_name,
            api_key=config.api_key,
            callbacks=[callback_handler],
            streaming=True,
        )
    else:
        llm = ChatOpenAI(
            model=config.model_name,
            api_key=config.api_key,
            callbacks=[callback_handler],
            streaming=True,
            base_url=config.base_url,
        )

    tools = [retriever_tool]
    llm_with_tools = llm.bind_tools(tools)

    def reasoner(state: MessagesState):
        # Prepend the system prompt on every turn, then let the tool-aware
        # model produce the next message.
        reply = llm_with_tools.invoke([sys_msg] + state.messages)
        return {"messages": [reply]}

    # Build the graph: reasoner <-> tools loop, entered at the reasoner.
    builder = StateGraph(MessagesState)
    builder.add_node("reasoner", reasoner)
    builder.add_node("tools", ToolNode(tools))
    builder.add_edge(START, "reasoner")
    builder.add_edge("tools", "reasoner")
    builder.add_conditional_edges("reasoner", tools_condition)

    return builder.compile(checkpointer=memory)
Oops, something went wrong.