-
Notifications
You must be signed in to change notification settings - Fork 16
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: refactor main flow to state machine (#44)
Replaced recursion with a single flat function (state machine). Easier to debug, and easier to log. In a follow-up PR, I will create a beautiful logger for each state. For now, it's just logging the state. data:image/s3,"s3://crabby-images/5f512/5f512f78e0d239255fa6ea6f213c6f9ddb1f3f45" alt="CleanShot 2024-12-09 at 01 34 31@2x" ``` stateDiagram-v2 [*] --> start start --> pending: Get next task pending --> assigned: Select agent assigned --> tool: Calls tools assigned --> step: Working on a task step --> complete: Finished processing assigned --> failed: No agent found complete --> start: Continue processing tool --> assigned: Run tools step --> step: Continue processing start --> finished: No more tasks failed --> finished: Attempt recovery or finish finished --> [*] ``` We may use this in the docs later too!
- Loading branch information
Showing
16 changed files
with
285 additions
and
219 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,132 @@ | ||
import { Workflow, WorkflowState } from '../workflow.js' | ||
import { finalizeWorkflow } from './finalizeWorkflow.js' | ||
import { nextTask } from './nextTask.js' | ||
import { runAgent } from './runAgent.js' | ||
import { runTools } from './runTools.js' | ||
import { selectAgent } from './selectAgent.js' | ||
|
||
/** | ||
* Performs single iteration over Workflow and produces its next state. | ||
*/ | ||
export async function nextTick(workflow: Workflow, state: WorkflowState): Promise<WorkflowState> { | ||
const { status, messages } = state | ||
|
||
/** | ||
* When number of messages exceedes number of maximum iterations, we must force finish the workflow | ||
* and produce best final answer | ||
*/ | ||
if (messages.length > workflow.maxIterations) { | ||
const content = await finalizeWorkflow(workflow.provider, messages) | ||
return { | ||
...state, | ||
status: 'finished', | ||
messages: state.messages.concat({ | ||
role: 'user', | ||
content, | ||
}), | ||
} | ||
} | ||
|
||
/** | ||
* When workflow is idle, we must get next task to work on, or finish the workflow otherwise. | ||
*/ | ||
if (status === 'idle') { | ||
const task = await nextTask(workflow.provider, messages) | ||
if (task) { | ||
return { | ||
...state, | ||
status: 'pending', | ||
agentRequest: [ | ||
{ | ||
role: 'user', | ||
content: task, | ||
}, | ||
], | ||
} | ||
} else { | ||
return { | ||
...state, | ||
status: 'finished', | ||
} | ||
} | ||
} | ||
|
||
/** | ||
* When workflow is pending, we must find best agent to work on it. | ||
*/ | ||
if (status === 'pending') { | ||
const selectedAgent = await selectAgent(workflow.provider, state.agentRequest, workflow.members) | ||
return { | ||
...state, | ||
status: 'assigned', | ||
agentStatus: 'idle', | ||
agent: selectedAgent.role, | ||
} | ||
} | ||
|
||
/** | ||
* When workflow is running, we must call assigned agent to continue working on it. | ||
*/ | ||
if (status === 'assigned') { | ||
const agent = workflow.members.find((member) => member.role === state.agent) | ||
if (!agent) { | ||
return { | ||
id: state.id, | ||
status: 'failed', | ||
messages: state.messages.concat({ | ||
role: 'assistant', | ||
content: 'No agent found.', | ||
}), | ||
} | ||
} | ||
|
||
/** | ||
* When agentStatus is `tool`, an agent is waiting for the tools results. | ||
* We must run all the tools in order to proceed to the next step. | ||
*/ | ||
if (state.agentStatus === 'tool') { | ||
const toolsResponse = await runTools(agent, state.agentRequest!) | ||
return { | ||
...state, | ||
agentStatus: 'step', | ||
agentRequest: state.agentRequest.concat(toolsResponse), | ||
} | ||
} | ||
|
||
/** | ||
* When agent finishes running, it will return status to indicate whether it finished processing. | ||
* | ||
* If it finished processing, we will append its final answer to the context, as well as | ||
* first message from `agentRequest`, which holds the actual task, excluding middle-steps. | ||
* | ||
* If further processing is required, we will carry `agentRequest` over to the next iteration. | ||
*/ | ||
const [agentResponse, status] = await runAgent(agent, state.agentRequest) | ||
if (status === 'complete') { | ||
const agentFinalAnswer = agentResponse.at(-1)! | ||
return { | ||
...state, | ||
status: 'idle', | ||
messages: state.messages.concat(state.agentRequest[0], agentFinalAnswer), | ||
} | ||
} | ||
return { | ||
...state, | ||
agentStatus: status, | ||
agentRequest: state.agentRequest.concat(agentResponse), | ||
} | ||
} | ||
|
||
/** | ||
* When workflow fails due to unexpected error, we must attempt recovering or finish the workflow | ||
* otherwise. | ||
*/ | ||
if (status === 'failed') { | ||
return { | ||
...state, | ||
status: 'finished', | ||
} | ||
} | ||
|
||
return state | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
import type { ParsedChatCompletionMessage } from 'openai/resources/beta/chat/completions.mjs' | ||
import { ChatCompletionToolMessageParam } from 'openai/resources/index.mjs' | ||
|
||
import { Agent } from '../agent.js' | ||
import { Message } from '../types.js' | ||
|
||
/** | ||
* Asserts that given message requests tool calls | ||
*/ | ||
function isToolCallRequest(message?: Message): message is ParsedChatCompletionMessage<any> { | ||
return message ? 'tool_calls' in message : false | ||
} | ||
|
||
export async function runTools( | ||
agent: Agent, | ||
agentRequest: Message[] | ||
): Promise<ChatCompletionToolMessageParam[]> { | ||
// tbd: find cleaner way to do this | ||
const messages = Array.from(agentRequest) | ||
const toolCallRequest = messages.pop() | ||
|
||
if (!isToolCallRequest(toolCallRequest)) { | ||
throw new Error('Invalid tool request') | ||
} | ||
|
||
const toolResults = await Promise.all( | ||
toolCallRequest.tool_calls.map(async (toolCall) => { | ||
if (toolCall.type !== 'function') { | ||
throw new Error('Tool call is not a function') | ||
} | ||
|
||
const tool = agent.tools ? agent.tools[toolCall.function.name] : null | ||
if (!tool) { | ||
throw new Error(`Unknown tool: ${toolCall.function.name}`) | ||
} | ||
|
||
const content = await tool.execute(toolCall.function.parsed_arguments, { | ||
provider: agent.provider, | ||
messages, | ||
}) | ||
return { | ||
role: 'tool' as const, | ||
tool_call_id: toolCall.id, | ||
content: JSON.stringify(content), | ||
} | ||
}) | ||
) | ||
|
||
return toolResults | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.