Skip to content

Commit

Permalink
Added cost counting for anthropic client
Browse files Browse the repository at this point in the history
  • Loading branch information
rusiaaman committed Feb 2, 2025
1 parent 950b045 commit e1646cc
Showing 1 changed file with 82 additions and 10 deletions.
92 changes: 82 additions & 10 deletions src/wcgw_cli/anthropic_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,17 @@
from anthropic.types import (
ImageBlockParam,
MessageParam,
ModelParam,
TextBlockParam,
ToolParam,
ToolResultBlockParam,
ToolUseBlockParam,
)
from dotenv import load_dotenv
from pydantic import BaseModel
from typer import Typer

from wcgw.client.common import discard_input
from wcgw.client.common import CostData, discard_input
from wcgw.client.memory import load_memory
from wcgw.client.tools import (
DoneFlag,
Expand All @@ -47,6 +49,14 @@
WriteIfEmpty,
)


class Config(BaseModel):
model: ModelParam
cost_limit: float
cost_file: dict[ModelParam, CostData]
cost_unit: str = "$"


History = list[MessageParam]


Expand Down Expand Up @@ -150,7 +160,51 @@ def loop(
first_message = ""
waiting_for_assistant = history[-1]["role"] != "assistant"

limit = 1
config = Config(
model="claude-3-5-sonnet-20241022",
cost_limit=0.1,
cost_unit="$",
cost_file={
# Claude 3.5 Haiku
"claude-3-5-haiku-latest": CostData(
cost_per_1m_input_tokens=0.80, cost_per_1m_output_tokens=4
),
"claude-3-5-haiku-20241022": CostData(
cost_per_1m_input_tokens=0.80, cost_per_1m_output_tokens=4
),
# Claude 3.5 Sonnet
"claude-3-5-sonnet-latest": CostData(
cost_per_1m_input_tokens=3.0, cost_per_1m_output_tokens=15.0
),
"claude-3-5-sonnet-20241022": CostData(
cost_per_1m_input_tokens=3.0, cost_per_1m_output_tokens=15.0
),
"claude-3-5-sonnet-20240620": CostData(
cost_per_1m_input_tokens=3.0, cost_per_1m_output_tokens=15.0
),
# Claude 3 Opus
"claude-3-opus-latest": CostData(
cost_per_1m_input_tokens=15.0, cost_per_1m_output_tokens=75.0
),
"claude-3-opus-20240229": CostData(
cost_per_1m_input_tokens=15.0, cost_per_1m_output_tokens=75.0
),
# Legacy Models
"claude-3-haiku-20240307": CostData(
cost_per_1m_input_tokens=0.25, cost_per_1m_output_tokens=1.25
),
"claude-2.1": CostData(
cost_per_1m_input_tokens=8.0, cost_per_1m_output_tokens=24.0
),
"claude-2.0": CostData(
cost_per_1m_input_tokens=8.0, cost_per_1m_output_tokens=24.0
),
},
)

if limit is not None:
config.cost_limit = limit
limit = config.cost_limit

tools = [
ToolParam(
Expand Down Expand Up @@ -321,9 +375,15 @@ def loop(
while True:
if cost > limit:
system_console.print(
f"\nCost limit exceeded. Current cost: {cost}, input tokens: {input_toks}, output tokens: {output_toks}"
f"\nCost limit exceeded. Current cost: {config.cost_unit}{cost:.4f}, "
f"input tokens: {input_toks}"
f"output tokens: {output_toks}"
)
break
else:
system_console.print(
f"\nTotal cost: {config.cost_unit}{cost:.4f}, input tokens: {input_toks}, output tokens: {output_toks}"
)

if not waiting_for_assistant:
if first_message:
Expand All @@ -335,13 +395,8 @@ def loop(
history.append(parse_user_message_special(msg))
else:
waiting_for_assistant = False

cost_, input_toks_ = 0, 0
cost += cost_
input_toks += input_toks_

stream = client.messages.stream(
model="claude-3-5-sonnet-20241022",
model=config.model,
messages=history,
tools=tools,
max_tokens=8096,
Expand All @@ -361,7 +416,24 @@ def loop(
with stream as stream_:
for chunk in stream_:
type_ = chunk.type
if type_ in {"message_start", "message_stop"}:
if type_ == "message_start":
message_start = chunk.message
# Update cost based on token usage from the API response
input_tokens = message_start.usage.input_tokens
input_toks += input_tokens
cost += (
input_tokens
* config.cost_file[config.model].cost_per_1m_input_tokens
) / 1_000_000
elif type_ == "message_stop":
message_stop = chunk.message
# Update cost based on output tokens
output_tokens = message_stop.usage.output_tokens
output_toks += output_tokens
cost += (
output_tokens
* config.cost_file[config.model].cost_per_1m_output_tokens
) / 1_000_000
continue
elif type_ == "content_block_start" and hasattr(
chunk, "content_block"
Expand Down

0 comments on commit e1646cc

Please sign in to comment.