Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: collect token usage #94

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 12 additions & 3 deletions packages/framework/src/supervisor/finalizeWorkflow.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,14 @@ import { zodResponseFormat } from 'openai/helpers/zod'
import { z } from 'zod'

import { Provider } from '../models.js'
import { Message } from '../types.js'
import { Message, Usage } from '../types.js'

export async function finalizeWorkflow(provider: Provider, messages: Message[]): Promise<string> {
export type FinalizeWorkflowResult = {
response: string
usage?: Usage
}

export async function finalizeWorkflow(provider: Provider, messages: Message[]): Promise<FinalizeWorkflowResult> {
const response = await provider.completions({
messages: [
{
Expand All @@ -29,5 +34,9 @@ export async function finalizeWorkflow(provider: Provider, messages: Message[]):
if (!result) {
throw new Error('No parsed response received')
}
return result.finalAnswer

return {
response: result.finalAnswer,
usage: response.usage,
};
}
13 changes: 9 additions & 4 deletions packages/framework/src/supervisor/nextTask.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,14 @@ import { zodResponseFormat } from 'openai/helpers/zod'
import { z } from 'zod'

import { Provider } from '../models.js'
import { Message } from '../types.js'
import { Message, Usage } from '../types.js'

export async function nextTask(provider: Provider, history: Message[]): Promise<string | null> {
export type NextTaskResult = {
task: string | null
usage?: Usage
}

export async function nextTask(provider: Provider, history: Message[]): Promise<NextTaskResult> {
const response = await provider.completions({
messages: [
{
Expand Down Expand Up @@ -53,10 +58,10 @@ export async function nextTask(provider: Provider, history: Message[]): Promise<
}

if (!content.task) {
return null
return { task: null, usage: response.usage }
}

return content.task
return { task: content.task, usage: response.usage }
} catch (error) {
throw new Error('Failed to determine next task')
}
Expand Down
44 changes: 32 additions & 12 deletions packages/framework/src/supervisor/nextTick.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import { Workflow, WorkflowState } from '../workflow.js'
import type { Usage } from '../types.js'
import type { Workflow, WorkflowState } from '../workflow.js'
import { finalizeWorkflow } from './finalizeWorkflow.js'
import { nextTask } from './nextTask.js'
import { runAgent } from './runAgent.js'
Expand All @@ -12,26 +13,27 @@ export async function nextTick(workflow: Workflow, state: WorkflowState): Promis
const { status, messages } = state

/**
* When number of messages exceedes number of maximum iterations, we must force finish the workflow
* When number of messages exceeds number of maximum iterations, we must force finish the workflow
* and produce best final answer
*/
if (messages.length > workflow.maxIterations) {
const content = await finalizeWorkflow(workflow.provider, messages)
const { response, usage } = await finalizeWorkflow(workflow.provider, messages)
return {
...state,
status: 'finished',
messages: state.messages.concat({
role: 'user',
content,
content: response,
}),
usage: combineUsage(state.usage, usage),
}
}

/**
* When workflow is idle, we must get next task to work on, or finish the workflow otherwise.
*/
if (status === 'idle') {
const task = await nextTask(workflow.provider, messages)
const { task, usage } = await nextTask(workflow.provider, messages)
if (task) {
return {
...state,
Expand All @@ -42,11 +44,13 @@ export async function nextTick(workflow: Workflow, state: WorkflowState): Promis
content: task,
},
],
usage: combineUsage(state.usage, usage),
}
} else {
return {
...state,
status: 'finished',
usage: combineUsage(state.usage, usage),
}
}
}
Expand All @@ -55,12 +59,17 @@ export async function nextTick(workflow: Workflow, state: WorkflowState): Promis
* When workflow is pending, we must find best agent to work on it.
*/
if (status === 'pending') {
const selectedAgent = await selectAgent(workflow.provider, state.agentRequest, workflow.members)
const { agent, usage } = await selectAgent(
workflow.provider,
state.agentRequest,
workflow.members
)
return {
...state,
status: 'assigned',
agentStatus: 'idle',
agent: selectedAgent.role,
agent: agent.role,
usage: combineUsage(state.usage, usage),
}
}

Expand All @@ -77,6 +86,7 @@ export async function nextTick(workflow: Workflow, state: WorkflowState): Promis
role: 'assistant',
content: 'No agent found.',
}),
usage: state.usage,
}
}

Expand All @@ -101,18 +111,20 @@ export async function nextTick(workflow: Workflow, state: WorkflowState): Promis
*
* If further processing is required, we will carry `agentRequest` over to the next iteration.
*/
const [agentResponse, status] = await runAgent(agent, state.messages, state.agentRequest)
if (status === 'complete') {
const { kind, message, usage } = await runAgent(agent, state.messages, state.agentRequest)
if (kind === 'complete') {
return {
...state,
status: 'idle',
messages: state.messages.concat(state.agentRequest[0], agentResponse),
messages: state.messages.concat(state.agentRequest[0], message),
usage: combineUsage(state.usage, usage),
}
}
return {
...state,
agentStatus: status,
agentRequest: state.agentRequest.concat(agentResponse),
agentStatus: kind,
agentRequest: state.agentRequest.concat(message),
usage: combineUsage(state.usage, usage),
}
}

Expand All @@ -138,3 +150,11 @@ export async function iterate(workflow: Workflow, state: WorkflowState) {
workflow.snapshot({ prevState: state, nextState })
return nextState
}

function combineUsage(prevUsage: Usage, usage: Usage | undefined) {
return {
prompt_tokens: prevUsage.prompt_tokens + (usage?.prompt_tokens ?? 0),
completion_tokens: prevUsage.completion_tokens + (usage?.completion_tokens ?? 0),
total_tokens: prevUsage.total_tokens + (usage?.total_tokens ?? 0),
}
}
21 changes: 14 additions & 7 deletions packages/framework/src/supervisor/runAgent.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,19 @@ import { zodFunction, zodResponseFormat } from 'openai/helpers/zod'
import { z } from 'zod'

import { Agent } from '../agent.js'
import { Message } from '../types.js'
import { Message, Usage } from '../types.js'

export type RunAgentResult = {
message: Message
kind: 'step' | 'complete' | 'tool'
usage?: Usage
}

export async function runAgent(
agent: Agent,
agentContext: Message[],
agentRequest: Message[]
): Promise<[Message, 'step' | 'complete' | 'tool']> {
): Promise<RunAgentResult> {
const tools = agent.tools
? Object.entries(agent.tools).map(([name, tool]) =>
zodFunction({
Expand Down Expand Up @@ -81,7 +87,7 @@ export async function runAgent(
})

if (response.choices[0].message.tool_calls.length > 0) {
return [response.choices[0].message, 'tool']
return { message: response.choices[0].message, kind: 'tool', usage: response.usage }
}

const result = response.choices[0].message.parsed
Expand All @@ -93,11 +99,12 @@ export async function runAgent(
throw new Error(result.response.reasoning)
}

return [
{
return {
message: {
role: 'assistant',
content: result.response.result,
},
result.response.kind,
]
kind: result.response.kind,
usage: response.usage,
}
}
11 changes: 8 additions & 3 deletions packages/framework/src/supervisor/selectAgent.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,18 @@ import { z } from 'zod'

import { Agent } from '../agent.js'
import { Provider } from '../models.js'
import { Message } from '../types.js'
import { Message, Usage } from '../types.js'

export type SelectAgentResult = {
agent: Agent
usage?: Usage
}

export async function selectAgent(
provider: Provider,
agentRequest: Message[],
agents: Agent[]
): Promise<Agent> {
): Promise<SelectAgentResult> {
const response = await provider.completions({
messages: [
{
Expand Down Expand Up @@ -64,5 +69,5 @@ export async function selectAgent(
throw new Error('Invalid agent')
}

return agent
return { agent, usage: response.usage }
}
9 changes: 6 additions & 3 deletions packages/framework/src/telemetry.ts
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,12 @@ export const logger: Telemetry = ({ prevState, nextState }) => {
break
case 'finished':
logMessage(
'🎉',
'Workflow finished successfully!',
`Total messages: ${nextState.messages.length}`
"🎉",
"Workflow finished successfully!",
[
`Total messages: ${nextState.messages.length}`,
`Total tokens: ${nextState.usage.total_tokens} (input: ${nextState.usage.prompt_tokens}, output: ${nextState.usage.completion_tokens})`,
].join('\n')
)
break
case 'failed':
Expand Down
6 changes: 6 additions & 0 deletions packages/framework/src/types.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import type { ChatCompletionMessageParam } from 'openai/resources/chat/completions'
import type { CompletionUsage } from 'openai/resources/completions.mjs'

/**
* Utility type to get optional keys from T.
Expand All @@ -22,3 +23,8 @@ export type RequiredOptionals<T> = Required<OptionalProperties<T>>
*/
export type Message = ChatCompletionMessageParam
export type MessageContent = Message['content']

/**
* Usage type for completion
*/
export type Usage = CompletionUsage;
16 changes: 11 additions & 5 deletions packages/framework/src/workflow.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import s from 'dedent'
import { Agent } from './agent.js'
import { openai, Provider } from './models.js'
import { noop, Telemetry } from './telemetry.js'
import { Message } from './types.js'
import { Message, Usage } from './types.js'

type WorkflowOptions = {
description: string
Expand Down Expand Up @@ -35,9 +35,10 @@ export type Workflow = Required<WorkflowOptions>
* Base workflow
*/
type BaseWorkflowState = {
id: string
messages: Message[]
}
id: string;
messages: Message[];
usage: Usage;
};

/**
* Different states workflow is in, in between execution from agents
Expand All @@ -47,7 +48,7 @@ export type IdleWorkflowState = BaseWorkflowState & {
}

/**
* Supervisor selected the task, and is now pending assignement of an agent
* Supervisor selected the task, and is now pending assignment of an agent
*/
export type PendingWorkflowState = BaseWorkflowState & {
status: 'pending'
Expand Down Expand Up @@ -84,6 +85,11 @@ export const workflowState = (workflow: Workflow): IdleWorkflowState => {
`,
},
],
usage: {
prompt_tokens: 0,
completion_tokens: 0,
total_tokens: 0,
},
}
}

Expand Down
6 changes: 6 additions & 0 deletions packages/tools/src/vision.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import path from 'node:path'
import s from 'dedent'
import { Provider } from 'fabrice-ai/models'
import { tool } from 'fabrice-ai/tool'
import { CompletionUsage } from 'openai/resources/completions.mjs'
import { zodResponseFormat } from 'openai/helpers/zod'
import { z } from 'zod'

Expand All @@ -12,6 +13,11 @@ const encodeImage = async (imagePath: string): Promise<string> => {
return `data:image/${path.extname(imagePath).toLowerCase().replace('.', '')};base64,${imageBuffer.toString('base64')}`
}

export type CallOpenAIResult = {
text: string
usage?: CompletionUsage
}

async function callOpenAI(
provider: Provider,
prompt: string,
Expand Down