Skip to content

Commit

Permalink
feat: refactor main flow to state machine (#44)
Browse files Browse the repository at this point in the history
Replaced recursion with a single flat function (state machine). Easier
to debug, and easier to log.
In a follow-up PR, I will create a beautiful logger for each state. For
now, it's just logging the state.

![CleanShot 2024-12-09 at 01 34
31@2x](https://github.com/user-attachments/assets/02bdbdca-5a78-48da-82fa-82b667425dc9)

```mermaid
stateDiagram-v2
    [*] --> start
    start --> pending: Get next task
    pending --> assigned: Select agent
    assigned --> tool: Calls tools
    assigned --> step: Working on a task
    step --> complete: Finished processing
    assigned --> failed: No agent found
    complete --> start: Continue processing
    tool --> assigned: Run tools
    step --> step: Continue processing
    start --> finished: No more tasks
    failed --> finished: Attempt recovery or finish
    finished --> [*]
```

We may use this in the docs later too!
  • Loading branch information
grabbou authored Dec 9, 2024
1 parent 1f147fa commit 32b5806
Show file tree
Hide file tree
Showing 16 changed files with 285 additions and 219 deletions.
Binary file modified bun.lockb
Binary file not shown.
5 changes: 3 additions & 2 deletions example/src/surprise_trip.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import { agent } from '@dead-simple-ai-agent/framework/agent'
import { teamwork } from '@dead-simple-ai-agent/framework/teamwork'
import { logger } from '@dead-simple-ai-agent/framework/telemetry/console'
import { logger } from '@dead-simple-ai-agent/framework/telemetry'
import { workflow } from '@dead-simple-ai-agent/framework/workflow'

import { lookupWikipedia } from '../tools.js'
Expand Down Expand Up @@ -72,7 +72,8 @@ const researchTripWorkflow = workflow({
Comprehensive day-by-day itinerary for the trip to Wrocław, Poland.
Ensure the itinerary integrates flights, hotel information, and all planned activities and dining experiences.
`,
telemetry: logger,
// Uncomment to see the workflow state in the console
// snapshot: logger,
})

const result = await teamwork(researchTripWorkflow)
Expand Down
3 changes: 0 additions & 3 deletions packages/framework/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,6 @@
},
"./models/*": {
"bun": "./src/models/*.ts"
},
"./telemetry/*": {
"bun": "./src/telemetry/*.ts"
}
},
"type": "module",
Expand Down
3 changes: 2 additions & 1 deletion packages/framework/src/supervisor/nextTask.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,13 @@ import { z } from 'zod'
import { Provider } from '../models/openai.js'
import { Message } from '../types.js'

export async function getNextTask(provider: Provider, history: Message[]): Promise<string | null> {
export async function nextTask(provider: Provider, history: Message[]): Promise<string | null> {
const response = await provider.completions({
messages: [
{
role: 'system',
// tbd: handle subsequent failures
// tbd: include max iterations in system prompt
content: s`
You are a planner that breaks down complex workflows into smaller, actionable steps.
Your job is to determine the next task that needs to be done based on the original workflow and what has been completed so far.
Expand Down
132 changes: 132 additions & 0 deletions packages/framework/src/supervisor/nextTick.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
import { Workflow, WorkflowState } from '../workflow.js'
import { finalizeWorkflow } from './finalizeWorkflow.js'
import { nextTask } from './nextTask.js'
import { runAgent } from './runAgent.js'
import { runTools } from './runTools.js'
import { selectAgent } from './selectAgent.js'

/**
 * Performs a single iteration over the workflow and produces its next state.
 *
 * The transition is chosen from `state.status`:
 * - `finished`: the state is returned unchanged.
 * - `idle`: `nextTask` is asked for the next task; the workflow becomes `pending`
 *   with that task as the first `agentRequest` message, or `finished` when no
 *   task is returned.
 * - `pending`: `selectAgent` picks a member and the workflow becomes `assigned`.
 * - `assigned`: either the requested tools are run (`agentStatus === 'tool'`) or
 *   the assigned agent is called; see the inline comments below.
 * - `failed`: the workflow becomes `finished`.
 * - any other status: the state is returned unchanged.
 *
 * Before any of the above (except `finished`), a workflow whose message count
 * exceeds `workflow.maxIterations` is force-finished with the answer produced by
 * `finalizeWorkflow`.
 *
 * New states are built here with object spread and `concat`; this function does
 * not assign to `state` or its arrays.
 *
 * @param workflow Workflow definition; `provider`, `members` and `maxIterations` are read.
 * @param state Current workflow state.
 * @returns The next workflow state.
 */
export async function nextTick(workflow: Workflow, state: WorkflowState): Promise<WorkflowState> {
  const { status, messages } = state

  /**
   * A finished workflow is terminal. Returning early keeps repeated ticks from
   * re-running finalization (and appending another final answer) once the
   * number of messages is over the limit.
   */
  if (status === 'finished') {
    return state
  }

  /**
   * When the number of messages exceeds the maximum number of iterations, we must
   * force finish the workflow and produce the best final answer.
   */
  if (messages.length > workflow.maxIterations) {
    const content = await finalizeWorkflow(workflow.provider, messages)
    return {
      ...state,
      status: 'finished',
      messages: messages.concat({
        role: 'user',
        content,
      }),
    }
  }

  /**
   * When workflow is idle, we must get the next task to work on, or finish the
   * workflow otherwise.
   */
  if (status === 'idle') {
    const task = await nextTask(workflow.provider, messages)
    if (task) {
      return {
        ...state,
        status: 'pending',
        agentRequest: [
          {
            role: 'user',
            content: task,
          },
        ],
      }
    } else {
      return {
        ...state,
        status: 'finished',
      }
    }
  }

  /**
   * When workflow is pending, we must find the best agent to work on it.
   */
  if (status === 'pending') {
    const selectedAgent = await selectAgent(workflow.provider, state.agentRequest, workflow.members)
    return {
      ...state,
      status: 'assigned',
      agentStatus: 'idle',
      agent: selectedAgent.role,
    }
  }

  /**
   * When workflow is assigned, we must call the assigned agent to continue working on it.
   */
  if (status === 'assigned') {
    const agent = workflow.members.find((member) => member.role === state.agent)
    if (!agent) {
      // Only `id`, `status` and `messages` are carried over into the failed state.
      return {
        id: state.id,
        status: 'failed',
        messages: messages.concat({
          role: 'assistant',
          content: 'No agent found.',
        }),
      }
    }

    /**
     * When agentStatus is `tool`, an agent is waiting for the tools results.
     * We must run all the tools in order to proceed to the next step.
     */
    if (state.agentStatus === 'tool') {
      const toolsResponse = await runTools(agent, state.agentRequest)
      return {
        ...state,
        agentStatus: 'step',
        agentRequest: state.agentRequest.concat(toolsResponse),
      }
    }

    /**
     * When agent finishes running, it will return status to indicate whether it finished processing.
     *
     * If it finished processing, we will append its final answer to the context, as well as
     * first message from `agentRequest`, which holds the actual task, excluding middle-steps.
     *
     * If further processing is required, we will carry `agentRequest` over to the next iteration.
     */
    const [agentResponse, agentStatus] = await runAgent(agent, state.agentRequest)
    if (agentStatus === 'complete') {
      const agentFinalAnswer = agentResponse.at(-1)!
      return {
        ...state,
        status: 'idle',
        messages: messages.concat(state.agentRequest[0], agentFinalAnswer),
      }
    }
    return {
      ...state,
      agentStatus,
      agentRequest: state.agentRequest.concat(agentResponse),
    }
  }

  /**
   * When workflow fails due to unexpected error, we must attempt recovering or finish the workflow
   * otherwise. No recovery is implemented here: the workflow is simply finished.
   */
  if (status === 'failed') {
    return {
      ...state,
      status: 'finished',
    }
  }

  return state
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,13 @@ import s from 'dedent'
import { zodFunction, zodResponseFormat } from 'openai/helpers/zod'
import { z } from 'zod'

import { Agent } from './agent.js'
import { Message } from './types.js'
import { Agent } from '../agent.js'
import { Message } from '../types.js'

export async function executeTaskWithAgent(
export async function runAgent(
agent: Agent,
messages: Message[],
team: Agent[]
): Promise<string> {
agentRequest: Message[]
): Promise<[Message[], 'step' | 'complete' | 'tool']> {
const tools = agent.tools
? Object.entries(agent.tools).map(([name, tool]) =>
zodFunction({
Expand All @@ -21,7 +20,6 @@ export async function executeTaskWithAgent(
: []

const response = await agent.provider.completions({
// tbd: verify the prompt
messages: [
{
role: 'system',
Expand All @@ -36,7 +34,7 @@ export async function executeTaskWithAgent(
Only ask question to the user if you cannot complete the task without their input.
`,
},
...messages,
...agentRequest,
],
tools: tools.length > 0 ? tools : undefined,
response_format: zodResponseFormat(
Expand All @@ -58,62 +56,23 @@ export async function executeTaskWithAgent(
'task_result'
),
})
if (response.choices[0].message.tool_calls.length > 0) {
const toolResults = await Promise.all(
response.choices[0].message.tool_calls.map(async (toolCall) => {
if (toolCall.type !== 'function') {
throw new Error('Tool call is not a function')
}

const tool = agent.tools ? agent.tools[toolCall.function.name] : null
if (!tool) {
throw new Error(`Unknown tool: ${toolCall.function.name}`)
}

const content = await tool.execute(toolCall.function.parsed_arguments, {
provider: agent.provider,
messages,
})
return {
role: 'tool' as const,
tool_call_id: toolCall.id,
content: JSON.stringify(content),
}
})
)

return executeTaskWithAgent(
agent,
[...messages, response.choices[0].message, ...toolResults],
team
)
if (response.choices[0].message.tool_calls.length > 0) {
return [[response.choices[0].message], 'tool']
}

// tbd: verify shape of response
const result = response.choices[0].message.parsed
if (!result) {
throw new Error('No parsed response received')
}

if (result.response.kind === 'step') {
console.log('🚀 Step:', result.response.name)
return executeTaskWithAgent(
agent,
[
...messages,
{
role: 'assistant',
content: result.response.result,
},
],
team
)
}

if (result.response.kind === 'complete') {
return result.response.result
}

// tbd: check if this is reachable
throw new Error('Illegal state')
return [
[
{
role: 'assistant',
content: result.response.result,
},
],
result.response.kind,
]
}
50 changes: 50 additions & 0 deletions packages/framework/src/supervisor/runTools.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import type { ParsedChatCompletionMessage } from 'openai/resources/beta/chat/completions.mjs'
import { ChatCompletionToolMessageParam } from 'openai/resources/index.mjs'

import { Agent } from '../agent.js'
import { Message } from '../types.js'

/**
 * Type guard that checks whether the given message requests tool calls.
 *
 * A message qualifies only when its `tool_calls` property holds an array.
 * Checking key presence alone is not enough: a message with
 * `tool_calls: undefined` would pass the guard and then fail as soon as the
 * caller iterates over the calls.
 *
 * @param message Message to inspect; may be `undefined` (e.g. the last element of an empty list).
 * @returns `true` when `message` carries an array of tool calls (possibly empty), `false` otherwise.
 */
function isToolCallRequest(message?: Message): message is ParsedChatCompletionMessage<any> {
  return message != null && 'tool_calls' in message && Array.isArray(message.tool_calls)
}

/**
 * Executes every tool call requested by the last message of `agentRequest`.
 *
 * The last message must be a tool call request. The messages preceding it are
 * passed to each tool, together with the agent's provider. All tool calls are
 * started together and awaited with `Promise.all`.
 *
 * @param agent Agent whose `tools` are looked up by function name and whose `provider` is handed to each tool.
 * @param agentRequest Conversation whose last message holds the tool calls to run.
 * @returns One `tool` message per tool call, with the JSON-serialized tool output as its content.
 * @throws Error when the last message is not a tool call request, when a tool call
 * is not of type `function`, or when the agent has no tool with the requested name.
 */
export async function runTools(
  agent: Agent,
  agentRequest: Message[]
): Promise<ChatCompletionToolMessageParam[]> {
  const lastMessage = agentRequest[agentRequest.length - 1]
  if (!isToolCallRequest(lastMessage)) {
    throw new Error('Invalid tool request')
  }

  // Tools receive the conversation without the tool call request itself
  const precedingMessages = agentRequest.slice(0, -1)

  const pendingResults = lastMessage.tool_calls.map(async (call) => {
    if (call.type !== 'function') {
      throw new Error('Tool call is not a function')
    }

    const requestedTool = agent.tools ? agent.tools[call.function.name] : null
    if (!requestedTool) {
      throw new Error(`Unknown tool: ${call.function.name}`)
    }

    const output = await requestedTool.execute(call.function.parsed_arguments, {
      provider: agent.provider,
      messages: precedingMessages,
    })

    return {
      role: 'tool' as const,
      tool_call_id: call.id,
      content: JSON.stringify(output),
    }
  })

  return Promise.all(pendingResults)
}
5 changes: 3 additions & 2 deletions packages/framework/src/supervisor/selectAgent.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,11 @@ import { z } from 'zod'

import { Agent } from '../agent.js'
import { Provider } from '../models/openai.js'
import { Message } from '../types.js'

export async function selectAgent(
provider: Provider,
task: string,
agentRequest: Message[],
agents: Agent[]
): Promise<Agent> {
const response = await provider.completions({
Expand All @@ -29,7 +30,7 @@ export async function selectAgent(
role: 'user',
content: s`
Here is the task:
<task>${task}</task>
<task>${agentRequest.map((request) => request.content).join(',')}</task>
Here are the available agents:
<agents>
Expand Down
Loading

0 comments on commit 32b5806

Please sign in to comment.