Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Refactor run handling to use run object instead of runId #531

Merged
merged 10 commits into from
Jan 10, 2025
141 changes: 47 additions & 94 deletions control-plane/src/modules/auth/auth.test.ts

Large diffs are not rendered by default.

6 changes: 4 additions & 2 deletions control-plane/src/modules/data.ts
Original file line number Diff line number Diff line change
Expand Up @@ -304,7 +304,9 @@ export const runs = pgTable(
.notNull(),
status: text("status", {
enum: ["pending", "running", "paused", "done", "failed"],
}).default("pending"),
})
.default("pending")
.notNull(),
failure_reason: text("failure_reason"),
debug: boolean("debug").notNull().default(false),
attached_functions: json("attached_functions").$type<string[]>().notNull().default([]),
Expand All @@ -327,7 +329,7 @@ export const runs = pgTable(
interactive: boolean("interactive").default(true).notNull(),
enable_summarization: boolean("enable_summarization").default(false).notNull(),
enable_result_grounding: boolean("enable_result_grounding").default(false).notNull(),
auth_context: json("auth_context"),
auth_context: json("auth_context").$type<unknown>(),
context: json("context"),
},
table => ({
Expand Down
41 changes: 23 additions & 18 deletions control-plane/src/modules/router.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,30 +2,26 @@ import { initServer } from "@ts-rest/fastify";
import { generateOpenApi } from "@ts-rest/open-api";
import fs from "fs";
import path from "path";
import { ulid } from "ulid";
import util from "util";
import { BadRequestError } from "../utilities/errors";
import { deleteAgent, getAgent, listAgents, upsertAgent, validateSchema } from "./agents";
import { authRouter } from "./auth/router";
import { getBlobData } from "./blobs";
import { contract } from "./contract";
import * as data from "./data";
import { integrationsRouter } from "./integrations/router";
import { jobsRouter } from "./jobs/router";
import { machineRouter } from "./machines/router";
import * as management from "./management";
import { buildModel } from "./models";
import * as events from "./observability/events";
import {
assertMessageOfType,
editHumanMessage,
getRunMessagesForDisplayWithPolling,
} from "./runs/messages";
import { addMessageAndResume, assertRunReady } from "./runs";
import { runsRouter } from "./runs/router";
import { machineRouter } from "./machines/router";
import { authRouter } from "./auth/router";
import { ulid } from "ulid";
import { getBlobData } from "./blobs";
import { posthog } from "./posthog";
import { BadRequestError } from "../utilities/errors";
import { upsertAgent, getAgent, deleteAgent, listAgents, validateSchema } from "./agents";
import { jobsRouter } from "./jobs/router";
import { buildModel } from "./models";
import { getServiceDefinitions, getStandardLibraryToolsMeta } from "./service-definitions";
import { integrationsRouter } from "./integrations/router";
import { addMessageAndResume, assertRunReady, getRun } from "./runs";
import { editHumanMessage, getRunMessagesForDisplayWithPolling } from "./runs/messages";
import { runsRouter } from "./runs/router";
import { getServerStats } from "./server-stats";
import { getServiceDefinitions, getStandardLibraryToolsMeta } from "./service-definitions";

const readFile = util.promisify(fs.readFile);

Expand Down Expand Up @@ -281,7 +277,16 @@ export const router = initServer().router(contract, {
const { clusterId, runId, messageId } = request.params;
const { message } = request.body;

await assertRunReady({ clusterId, runId });
const run = await getRun({ clusterId, runId });
await assertRunReady({
clusterId,
run: {
id: run.id,
status: run.status,
interactive: run.interactive,
clusterId: run.clusterId,
},
});

const auth = request.request.getAuth();
await auth.canManage({ run: { clusterId, runId } });
Expand Down
11 changes: 11 additions & 0 deletions control-plane/src/modules/runs/agent/agent.ai.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ import { z } from "zod";
import { redisClient } from "../../redis";
import { AgentTool } from "./tool";
import { assertMessageOfType } from "../messages";
import { ChatIdentifiers } from "../../models/routing";

if (process.env.CI) {
jest.retryTimes(3);
Expand All @@ -17,6 +18,16 @@ describe("Agent", () => {
const run = {
id: "test",
clusterId: "test",
modelIdentifier: "claude-3-5-sonnet" as ChatIdentifiers,
resultSchema: null,
debug: false,
attachedFunctions: null,
status: "pending",
systemPrompt: null,
testMocks: {},
test: false,
reasoningTraces: false,
enableResultGrounding: false,
};

const tools = [
Expand Down
17 changes: 15 additions & 2 deletions control-plane/src/modules/runs/agent/agent.ts
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
import { START, StateGraph } from "@langchain/langgraph";
import { type Run } from "../";
import { MODEL_CALL_NODE_NAME, handleModelCall } from "./nodes/model-call";
import { TOOL_CALL_NODE_NAME, handleToolCalls } from "./nodes/tool-call";
import { RunGraphState, createStateGraphChannels } from "./state";
import { PostStepSave, postModelEdge, postStartEdge, postToolEdge } from "./nodes/edges";
import { AgentMessage } from "../messages";
import { buildMockModel, buildModel } from "../../models";
import { AgentTool } from "./tool";
import { ChatIdentifiers } from "../../models/routing";

export type ReleventToolLookup = (state: RunGraphState) => Promise<AgentTool[]>;

Expand All @@ -23,7 +23,20 @@ export const createRunGraph = async ({
getTool,
mockModelResponses,
}: {
run: Run;
run: {
id: string;
clusterId: string;
modelIdentifier: ChatIdentifiers | null;
resultSchema: unknown | null;
debug: boolean;
attachedFunctions: string[] | null;
status: string;
systemPrompt: string | null;
testMocks: Record<string, { output: Record<string, unknown> }> | null;
test: boolean;
reasoningTraces: boolean;
enableResultGrounding: boolean;
};
additionalContext?: string;
allAvailableTools?: string[];
postStepSave: PostStepSave;
Expand Down
20 changes: 20 additions & 0 deletions control-plane/src/modules/runs/agent/nodes/edges.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,16 @@ describe("postStartEdge", () => {
run: {
id: "test",
clusterId: "test",
modelIdentifier: null,
resultSchema: null,
debug: false,
attachedFunctions: null,
status: "running",
systemPrompt: null,
testMocks: {},
test: false,
reasoningTraces: false,
enableResultGrounding: false,
},
status: "running",
messages: [],
Expand Down Expand Up @@ -228,6 +238,16 @@ describe("postModelEdge", () => {
run: {
id: "test",
clusterId: "test",
modelIdentifier: null,
resultSchema: null,
debug: false,
attachedFunctions: null,
status: "running",
systemPrompt: null,
testMocks: {},
test: false,
reasoningTraces: false,
enableResultGrounding: false,
},
status: "running",
messages: [],
Expand Down
9 changes: 9 additions & 0 deletions control-plane/src/modules/runs/agent/nodes/model-call.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,16 @@ describe("handleModelCall", () => {
const run = {
id: "test-run",
clusterId: "test-cluster",
modelIdentifier: null,
resultSchema: null,
debug: false,
attachedFunctions: null,
status: "running",
systemPrompt: null,
testMocks: {},
test: false,
reasoningTraces: true,
enableResultGrounding: false,
};
const state: RunGraphState = {
messages: [
Expand Down
62 changes: 36 additions & 26 deletions control-plane/src/modules/runs/agent/nodes/model-output.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,34 +10,44 @@ describe("buildModelSchema", () => {
let resultSchema: JsonSchema7ObjectType | undefined;

beforeEach(() => {
state = {
messages: [
{
id: ulid(),
clusterId: "test-cluster",
runId: "test-run",
data: {
message: "What are your capabilities?",
state = {
messages: [
{
id: ulid(),
clusterId: "test-cluster",
runId: "test-run",
data: {
message: "What are your capabilities?",
},
type: "human",
},
type: "human",
],
waitingJobs: [],
allAvailableTools: [],
run: {
id: "test-run",
clusterId: "test-cluster",
modelIdentifier: null,
resultSchema: null,
debug: false,
attachedFunctions: null,
status: "running",
systemPrompt: null,
testMocks: {},
test: false,
reasoningTraces: false,
enableResultGrounding: false,
},
],
waitingJobs: [],
allAvailableTools: [],
run: {
id: "test-run",
clusterId: "test-cluster",
},
additionalContext: "",
status: "running",
};
relevantSchemas = [
{ name: "localTool1"},
{ name: "localTool2"},
{ name: "globalTool1"},
{ name: "globalTool2"},
] as AgentTool[],
resultSchema = undefined;
additionalContext: "",
status: "running",
};
(relevantSchemas = [
{ name: "localTool1" },
{ name: "localTool2" },
{ name: "globalTool1" },
{ name: "globalTool2" },
] as AgentTool[]),
(resultSchema = undefined);
});

it("returns a schema with 'message' when resultSchema is not provided", () => {
Expand Down
10 changes: 10 additions & 0 deletions control-plane/src/modules/runs/agent/nodes/tool-call.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,16 @@ describe("handleToolCalls", () => {
const run = {
id: "test",
clusterId: "test",
modelIdentifier: null,
resultSchema: null,
debug: false,
attachedFunctions: null,
status: "running",
systemPrompt: null,
testMocks: {},
test: false,
reasoningTraces: false,
enableResultGrounding: false,
};

const toolHandler = jest.fn();
Expand Down
18 changes: 15 additions & 3 deletions control-plane/src/modules/runs/agent/nodes/tool-call.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@ import { logger } from "../../../observability/logger";
import { addAttributes, withSpan } from "../../../observability/tracer";
import { trackCustomerTelemetry } from "../../../track-customer-telemetry";
import { AgentMessage, assertMessageOfType } from "../../messages";
import { Run } from "../../";
import { ToolFetcher } from "../agent";
import { RunGraphState } from "../state";
import { SpecialResultTypes, parseFunctionResponse } from "../tools/functions";
import { AgentTool, AgentToolInputError } from "../tool";
import { ChatIdentifiers } from "../../../models/routing";

export const TOOL_CALL_NODE_NAME = "action";

Expand Down Expand Up @@ -88,7 +88,13 @@ const _handleToolCalls = async (

const handleToolCall = (
toolCall: Required<AgentMessage["data"]>["invocations"][number],
run: Run,
run: {
id: string;
clusterId: string;
modelIdentifier: ChatIdentifiers | null;
resultSchema: unknown | null;
debug: boolean;
},
getTool: ToolFetcher
) =>
withSpan("run.toolCall", () => _handleToolCall(toolCall, run, getTool), {
Expand All @@ -100,7 +106,13 @@ const handleToolCall = (

const _handleToolCall = async (
toolCall: Required<AgentMessage["data"]>["invocations"][number],
run: Run,
run: {
id: string;
clusterId: string;
modelIdentifier: ChatIdentifiers | null;
resultSchema: unknown | null;
debug: boolean;
},
getTool: ToolFetcher
): Promise<Partial<RunGraphState>> => {
logger.info("Executing tool call");
Expand Down
11 changes: 9 additions & 2 deletions control-plane/src/modules/runs/agent/run.test.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import { Run } from "../";
import { buildMockTools, findRelevantTools, formatJobsContext } from "./run";
import { upsertServiceDefinition } from "../../service-definitions";
import { createOwner } from "../../test/util";
Expand Down Expand Up @@ -30,6 +29,14 @@ describe("findRelevantTools", () => {
clusterId: owner.clusterId,
status: "running" as const,
attachedFunctions: ["testService_someFunction"],
modelIdentifier: null,
resultSchema: null,
debug: false,
systemPrompt: null,
testMocks: {},
test: false,
reasoningTraces: false,
enableResultGrounding: false,
};

const tools = await findRelevantTools({
Expand Down Expand Up @@ -65,7 +72,7 @@ const mockTargetSchema = JSON.stringify({
});

describe("buildMockTools", () => {
let run: Run;
let run: any;

const service = "testService";
beforeAll(async () => {
Expand Down
Loading
Loading