Skip to content

Commit

Permalink
Use agents
Browse files Browse the repository at this point in the history
  • Loading branch information
kaarthik108 committed Oct 13, 2024
1 parent b703295 commit 99a179c
Show file tree
Hide file tree
Showing 10 changed files with 452 additions and 242 deletions.
27 changes: 20 additions & 7 deletions .github/workflows/lint.yml
Original file line number Diff line number Diff line change
# Reconstructed post-commit workflow: the scraped diff interleaved deleted
# lines (setup-python@v2, a stray `python-version: 3.9` inside the cache
# step, the old black-only install/lint steps) which would be invalid YAML.
on:
  push:
    branches:
      - main
      - prod
  pull_request:
    branches:
      - main
      - prod

jobs:
  lint:
    name: Lint and Format Code
    runs-on: ubuntu-latest

    steps:
      - name: Check out repository
        uses: actions/checkout@v3

      - name: Set up Python
        uses: actions/setup-python@v4
        with:
          python-version: "3.9"

      # Cache pip downloads keyed on requirements.txt so unchanged
      # dependency sets skip the network on subsequent runs.
      - name: Cache pip dependencies
        uses: actions/cache@v3
        with:
          path: ~/.cache/pip
          key: ${{ runner.os }}-pip-${{ hashFiles('**/requirements.txt') }}
          restore-keys: |
            ${{ runner.os }}-pip-

      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          pip install -r requirements.txt
          pip install black ruff mypy codespell

      - name: Run Formatting and Linting
        run: |
          make format
          make lint
57 changes: 57 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
.PHONY: all format lint lint_diff format_diff lint_package lint_tests spell_check spell_fix help lint-fix

# Define a variable for Python and notebook files.
PYTHON_FILES=src/
MYPY_CACHE=.mypy_cache

######################
# LINTING AND FORMATTING
######################

# Target-specific overrides: each lint/format flavor scopes PYTHON_FILES.
lint format: PYTHON_FILES=.
lint_diff format_diff: PYTHON_FILES=$(shell git diff --name-only --diff-filter=d main | grep -E '\.py$$|\.ipynb$$')
lint_package: PYTHON_FILES=src
lint_tests: PYTHON_FILES=tests
lint_tests: MYPY_CACHE=.mypy_cache_test

# Fixes vs. previous version:
#  - ruff check now honors $(PYTHON_FILES) instead of always scanning `.`
#  - mypy runs once (it previously ran twice: with and without --cache-dir)
#  - the mkdir/mypy pair is grouped with { ...; } so an empty PYTHON_FILES
#    skips BOTH commands (shell `||` otherwise binds to mkdir alone and
#    `&&` would still run mypy on an empty file list)
lint lint_diff lint_package lint_tests:
	[ "$(PYTHON_FILES)" = "" ] || python -m ruff check $(PYTHON_FILES)
	[ "$(PYTHON_FILES)" = "" ] || python -m ruff format $(PYTHON_FILES) --diff
	[ "$(PYTHON_FILES)" = "" ] || python -m ruff check --select I,F401 --fix $(PYTHON_FILES)
	[ "$(PYTHON_FILES)" = "" ] || { mkdir -p $(MYPY_CACHE) && python -m mypy --strict $(PYTHON_FILES) --cache-dir $(MYPY_CACHE); }

format format_diff:
	ruff format $(PYTHON_FILES)
	ruff check --fix $(PYTHON_FILES)

spell_check:
	codespell --toml pyproject.toml

spell_fix:
	codespell --toml pyproject.toml -w

######################
# RUN ALL
######################

all: format lint spell_check

######################
# HELP
######################

help:
	@echo '----'
	@echo 'format       - run code formatters'
	@echo 'format_diff  - format only files changed vs main'
	@echo 'lint         - run linters'
	@echo 'lint_diff    - lint only files changed vs main'
	@echo 'lint_package - lint the src package'
	@echo 'lint_tests   - lint the tests directory'
	@echo 'spell_check  - run spell check'
	@echo 'spell_fix    - fix spelling issues'
	@echo 'all          - run all tasks'
	@echo 'lint-fix     - run lint and fix issues'

######################
# LINT-FIX TARGET
######################

lint-fix: format lint
	@echo "Linting and fixing completed successfully."
102 changes: 102 additions & 0 deletions agent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
import os
from dataclasses import dataclass
from typing import Annotated, Sequence, Optional

from langchain.callbacks.base import BaseCallbackHandler
from langchain_anthropic import ChatAnthropic
from langchain_core.messages import SystemMessage
from langchain_openai import ChatOpenAI
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import START, StateGraph
from langgraph.prebuilt import ToolNode, tools_condition
from langgraph.graph.message import add_messages
from langchain_core.messages import BaseMessage

from template import TEMPLATE
from tools import retriever_tool


@dataclass
class MessagesState:
    """Graph state for the agent: the running conversation history."""

    # `add_messages` is the LangGraph reducer annotation: messages returned
    # by a node are merged into this sequence rather than replacing it.
    messages: Annotated[Sequence[BaseMessage], add_messages]


memory = MemorySaver()


@dataclass
class ModelConfig:
    """Connection settings for one supported chat model."""

    # Provider-side model identifier (e.g. "gpt-4o-mini").
    model_name: str
    # API key read from the provider's environment variable; may be None
    # if the variable is unset — create_agent is where this is consumed.
    api_key: str
    # OpenAI-compatible endpoint override (Groq/Fireworks); None means the
    # client library's default endpoint.
    base_url: Optional[str] = None


def create_agent(callback_handler: BaseCallbackHandler, model_name: str):
    """Build and compile a tool-calling LangGraph agent for ``model_name``.

    Args:
        callback_handler: Streaming callback attached to the LLM client.
        model_name: One of the keys of ``model_configurations`` below
            (e.g. ``"gpt-4o-mini"``, ``"claude3-haiku"``).

    Returns:
        A compiled LangGraph graph (reasoner + tools loop) checkpointed
        with the module-level ``memory`` saver.

    Raises:
        ValueError: If ``model_name`` is unsupported, or the provider's
            API-key environment variable is not set.
    """
    model_configurations = {
        "gpt-4o-mini": ModelConfig(
            model_name="gpt-4o-mini", api_key=os.getenv("OPENAI_API_KEY")
        ),
        "gemma2-9b": ModelConfig(
            model_name="gemma2-9b-it",
            api_key=os.getenv("GROQ_API_KEY"),
            base_url="https://api.groq.com/openai/v1",
        ),
        "claude3-haiku": ModelConfig(
            model_name="claude-3-haiku-20240307", api_key=os.getenv("ANTHROPIC_API_KEY")
        ),
        "mixtral-8x22b": ModelConfig(
            model_name="accounts/fireworks/models/mixtral-8x22b-instruct",
            api_key=os.getenv("FIREWORKS_API_KEY"),
            base_url="https://api.fireworks.ai/inference/v1",
        ),
        "llama-3.1-405b": ModelConfig(
            model_name="accounts/fireworks/models/llama-v3p1-405b-instruct",
            api_key=os.getenv("FIREWORKS_API_KEY"),
            base_url="https://api.fireworks.ai/inference/v1",
        ),
    }
    config = model_configurations.get(model_name)
    if not config:
        raise ValueError(f"Unsupported model name: {model_name}")
    # Fail fast with an actionable error instead of an opaque provider-side
    # authentication failure on the first request.
    if not config.api_key:
        raise ValueError(
            f"Missing API key for model '{model_name}': "
            "set the provider's environment variable."
        )

    sys_msg = SystemMessage(
        content="""You're an AI assistant specializing in data analysis with Snowflake SQL. When providing responses, strive to exhibit friendliness and adopt a conversational tone, similar to how a friend or tutor would communicate.
        Call the tool "Database_Schema" to search for database schema details when needed to generate the SQL code.
        """
    )

    # Dispatch on the configuration key rather than the hard-coded provider
    # model string so adding another Anthropic entry only needs a dict entry.
    if model_name == "claude3-haiku":
        llm = ChatAnthropic(
            model=config.model_name,
            api_key=config.api_key,
            callbacks=[callback_handler],
            streaming=True,
        )
    else:
        # Groq and Fireworks expose OpenAI-compatible endpoints via base_url.
        llm = ChatOpenAI(
            model=config.model_name,
            api_key=config.api_key,
            callbacks=[callback_handler],
            streaming=True,
            base_url=config.base_url,
        )

    tools = [retriever_tool]

    llm_with_tools = llm.bind_tools(tools)

    def reasoner(state: MessagesState):
        # Prepend the system message on every turn; add_messages merges the
        # new AI message into the conversation state.
        return {"messages": [llm_with_tools.invoke([sys_msg] + state.messages)]}

    # Build the graph: reasoner -> (tools if a tool call was made) -> reasoner.
    builder = StateGraph(MessagesState)
    builder.add_node("reasoner", reasoner)
    builder.add_node("tools", ToolNode(tools))
    builder.add_edge(START, "reasoner")
    builder.add_conditional_edges("reasoner", tools_condition)
    builder.add_edge("tools", "reasoner")

    react_graph = builder.compile(checkpointer=memory)

    return react_graph
Loading

0 comments on commit 99a179c

Please sign in to comment.