From 5b5fc424500d53ba4e3a1d7853c397d339b672c2 Mon Sep 17 00:00:00 2001 From: Nadeesha Cabral Date: Fri, 10 Jan 2025 00:32:04 +1100 Subject: [PATCH 1/9] fix: Refactor run handling to use run object instead of runId --- control-plane/src/modules/router.ts | 43 ++-- control-plane/src/modules/runs/index.test.ts | 75 +++---- control-plane/src/modules/runs/index.ts | 221 +++++++++---------- control-plane/src/modules/runs/router.ts | 20 +- 4 files changed, 167 insertions(+), 192 deletions(-) diff --git a/control-plane/src/modules/router.ts b/control-plane/src/modules/router.ts index 4601274b..6be66219 100644 --- a/control-plane/src/modules/router.ts +++ b/control-plane/src/modules/router.ts @@ -2,38 +2,34 @@ import { initServer } from "@ts-rest/fastify"; import { generateOpenApi } from "@ts-rest/open-api"; import fs from "fs"; import path from "path"; +import { ulid } from "ulid"; import util from "util"; -import { contract } from "./contract"; -import * as data from "./data"; -import * as management from "./management"; -import * as events from "./observability/events"; -import { - assertMessageOfType, - editHumanMessage, - getRunMessagesForDisplayWithPolling, -} from "./runs/messages"; -import { addMessageAndResume, assertRunReady } from "./runs"; -import { runsRouter } from "./runs/router"; -import { machineRouter } from "./machines/router"; +import { BadRequestError } from "../utilities/errors"; +import { deleteAgent, getAgent, listAgents, upsertAgent, validateSchema } from "./agents"; import { authRouter } from "./auth/router"; -import { ulid } from "ulid"; import { getBlobData } from "./blobs"; -import { posthog } from "./posthog"; -import { BadRequestError } from "../utilities/errors"; -import { upsertAgent, getAgent, deleteAgent, listAgents, validateSchema } from "./agents"; +import { contract } from "./contract"; +import * as data from "./data"; +import { integrationsRouter } from "./integrations/router"; +import { jobsRouter } from "./jobs/router"; import { createClusterKnowledgeArtifact, - getKnowledge, - upsertKnowledgeArtifact, deleteKnowledgeArtifact, - getKnowledgeArtifact, getAllKnowledgeArtifacts, + getKnowledge, + getKnowledgeArtifact, + upsertKnowledgeArtifact, } from "./knowledge/knowledgebase"; -import { jobsRouter } from "./jobs/router"; +import { machineRouter } from "./machines/router"; +import * as management from "./management"; import { buildModel } from "./models"; -import { getServiceDefinitions, getStandardLibraryToolsMeta } from "./service-definitions"; -import { integrationsRouter } from "./integrations/router"; +import * as events from "./observability/events"; +import { posthog } from "./posthog"; +import { addMessageAndResume, assertRunReady, getRun } from "./runs"; +import { editHumanMessage, getRunMessagesForDisplayWithPolling } from "./runs/messages"; +import { runsRouter } from "./runs/router"; import { getServerStats } from "./server-stats"; +import { getServiceDefinitions, getStandardLibraryToolsMeta } from "./service-definitions"; const readFile = util.promisify(fs.readFile); @@ -289,7 +285,8 @@ export const router = initServer().router(contract, { const { clusterId, runId, messageId } = request.params; const { message } = request.body; - await assertRunReady({ clusterId, runId }); + const run = await getRun({ clusterId, runId }); + await assertRunReady({ clusterId, run }); const auth = request.request.getAuth(); await auth.canManage({ run: { clusterId, runId } }); diff --git a/control-plane/src/modules/runs/index.test.ts b/control-plane/src/modules/runs/index.test.ts index 7962ec91..5b121d85 100644 --- a/control-plane/src/modules/runs/index.test.ts +++ b/control-plane/src/modules/runs/index.test.ts @@ -17,9 +17,9 @@ describe("assertRunReady", () => { await expect( assertRunReady({ - runId: run.id, + run, clusterId: owner.clusterId, - }), + }) ).resolves.not.toThrow(); }); @@ -35,9 +35,9 @@ describe("assertRunReady", () => { await expect( assertRunReady({ - runId: run.id, + run, clusterId: owner.clusterId, - }), + }) ).rejects.toThrow(RunBusyError); }); @@ -49,9 +49,9 @@ describe("assertRunReady", () => { await expect( assertRunReady({ - runId: run.id, + run, clusterId: owner.clusterId, - }), + }) ).rejects.toThrow(BadRequestError); }); @@ -87,42 +87,39 @@ describe("assertRunReady", () => { await expect( assertRunReady({ - runId: run.id, + run, clusterId: owner.clusterId, - }), + }) ).resolves.not.toThrow(); }); - it.each(["human", "template"] as const)( - "should throw if last message is %s", - async (type) => { - const run = await createRun({ - clusterId: owner.clusterId, - }); + it.each(["human", "template"] as const)("should throw if last message is %s", async type => { + const run = await createRun({ + clusterId: owner.clusterId, + }); - await insertRunMessage({ - id: ulid(), - data: { - message: "Some request", - }, - type, + await insertRunMessage({ + id: ulid(), + data: { + message: "Some request", + }, + type, + clusterId: owner.clusterId, + runId: run.id, + }); + + await updateRun({ + ...run, + status: "done", + }); + + await expect( + assertRunReady({ + run, clusterId: owner.clusterId, - runId: run.id, - }); - - await updateRun({ - ...run, - status: "done", - }); - - await expect( - assertRunReady({ - runId: run.id, - clusterId: owner.clusterId, - }), - ).rejects.toThrow(RunBusyError); - }, - ); + }) + ).rejects.toThrow(RunBusyError); + }); it.each([ { @@ -153,7 +150,7 @@ describe("assertRunReady", () => { }, type: "template" as const, }, - ])("messages should throw unless AI with no tool calls", async (message) => { + ])("messages should throw unless AI with no tool calls", async message => { const run = await createRun({ clusterId: owner.clusterId, }); @@ -192,9 +189,9 @@ describe("assertRunReady", () => { await expect( assertRunReady({ - runId: run.id, + run, clusterId: owner.clusterId, - }), + }) ).rejects.toThrow(RunBusyError); }); }); diff --git a/control-plane/src/modules/runs/index.ts b/control-plane/src/modules/runs/index.ts index b9cbd0c8..984d7c5c 100644 --- a/control-plane/src/modules/runs/index.ts +++ b/control-plane/src/modules/runs/index.ts @@ -12,26 +12,13 @@ import { import { ulid } from "ulid"; import { env } from "../../utilities/env"; import { BadRequestError, NotFoundError, RunBusyError } from "../../utilities/errors"; -import { - clusters, - db, - jobs, - RunMessageMetadata, - runMessages, - runTags, - runs, -} from "../data"; +import { clusters, db, jobs, RunMessageMetadata, runMessages, runTags, runs } from "../data"; import { ChatIdentifiers } from "../models/routing"; import { logger } from "../observability/logger"; import { injectTraceContext } from "../observability/tracer"; import { sqs } from "../sqs"; import { trackCustomerTelemetry } from "../track-customer-telemetry"; -import { - getRunMessages, - hasInvocations, - insertRunMessage, - lastAgentMessage, -} from "./messages"; +import { getRunMessages, hasInvocations, insertRunMessage, lastAgentMessage } from "./messages"; import { getRunTags } from "./tags"; export { start, stop } from "./queues"; @@ -109,85 +96,69 @@ export const createRun = async ({ context?: unknown; enableResultGrounding?: boolean; }): Promise => { - let run: Run | undefined = undefined; - - await db.transaction(async tx => { - // TODO: Resolve the run debug value via a subquery and remove the transaction: https://github.com/inferablehq/inferable/issues/389 - const [debugQuery] = await tx - .select({ - debug: clusters.debug, - }) - .from(clusters) - .where(eq(clusters.id, clusterId)); - - const result = await tx - .insert(runs) - .values([ - { - id: runId ?? ulid(), - cluster_id: clusterId, - status: "pending", - user_id: userId ?? "SYSTEM", - ...(name ? { name } : {}), - debug: debugQuery.debug, - system_prompt: systemPrompt, - test, - test_mocks: testMocks, - reasoning_traces: reasoningTraces, - interactive: interactive, - enable_summarization: enableSummarization, - on_status_change: onStatusChange, - result_schema: resultSchema, - attached_functions: attachedFunctions, - agent_id: agentId, - agent_version: agentVersion, - model_identifier: modelIdentifier, - auth_context: authContext, - context: context, - enable_result_grounding: enableResultGrounding, - }, - ]) - .returning({ - id: runs.id, - name: runs.name, - clusterId: runs.cluster_id, - systemPrompt: runs.system_prompt, - status: runs.status, - debug: runs.debug, - test: runs.test, - attachedFunctions: runs.attached_functions, - modelIdentifier: runs.model_identifier, - authContext: runs.auth_context, - context: runs.context, - interactive: runs.interactive, - enableResultGrounding: runs.enable_result_grounding, - }); - - run = result[0]; - - if (!!run && tags) { - await tx.insert(runTags).values( - Object.entries(tags).map(([key, value]) => ({ - cluster_id: clusterId, - workflow_id: run!.id, - key, - value, - })) - ); - } - }); + // Insert the run with a subquery for debug value + const [run] = await db + .insert(runs) + .values({ + id: runId ?? ulid(), + cluster_id: clusterId, + status: "pending", + user_id: userId ?? "SYSTEM", + ...(name ? { name } : {}), + debug: sql`(SELECT debug FROM ${clusters} WHERE id = ${clusterId})`, + system_prompt: systemPrompt, + test, + test_mocks: testMocks, + reasoning_traces: reasoningTraces, + interactive: interactive, + enable_summarization: enableSummarization, + on_status_change: onStatusChange, + result_schema: resultSchema, + attached_functions: attachedFunctions, + agent_id: agentId, + agent_version: agentVersion, + model_identifier: modelIdentifier, + auth_context: authContext, + context: context, + enable_result_grounding: enableResultGrounding, + }) + .returning({ + id: runs.id, + name: runs.name, + clusterId: runs.cluster_id, + systemPrompt: runs.system_prompt, + status: runs.status, + debug: runs.debug, + test: runs.test, + attachedFunctions: runs.attached_functions, + modelIdentifier: runs.model_identifier, + authContext: runs.auth_context, + context: runs.context, + interactive: runs.interactive, + enableResultGrounding: runs.enable_result_grounding, + }); if (!run) { throw new Error("Failed to create run"); } + // Insert tags if provided + if (tags) { + await db.insert(runTags).values( + Object.entries(tags).map(([key, value]) => ({ + cluster_id: clusterId, + workflow_id: run.id, + key, + value, + })) + ); + } + return run; }; export const deleteRun = async ({ clusterId, runId }: { clusterId: string; runId: string }) => { - await db - .delete(runs) - .where(and(eq(runs.cluster_id, clusterId), eq(runs.id, runId))); + await db.delete(runs).where(and(eq(runs.cluster_id, clusterId), eq(runs.id, runId))); }; export const updateRun = async (run: Run): Promise => { @@ -204,12 +175,7 @@ export const updateRun = async (run: Run): Promise => { feedback_comment: run.feedbackComment, feedback_score: run.feedbackScore, }) - .where( - and( - eq(runs.cluster_id, run.clusterId), - eq(runs.id, run.id) - ) - ) + .where(and(eq(runs.cluster_id, run.clusterId), eq(runs.id, run.id))) .returning({ id: runs.id, name: runs.name, @@ -264,12 +230,7 @@ export const getRun = async ({ clusterId, runId }: { clusterId: string; runId: s enableResultGrounding: runs.enable_result_grounding, }) .from(runs) - .where( - and( - eq(runs.cluster_id, clusterId), - eq(runs.id, runId) - ) - ); + .where(and(eq(runs.cluster_id, clusterId), eq(runs.id, runId))); return run; }; @@ -300,7 +261,7 @@ export const getClusterRuns = async ({ debug: runs.debug, test: runs.test, agentId: runs.agent_id, - agentVersion : runs.agent_version, + agentVersion: runs.agent_version, feedbackScore: runs.feedback_score, modelIdentifier: runs.model_identifier, authContext: runs.auth_context, @@ -322,13 +283,7 @@ export const getClusterRuns = async ({ return result; }; -export const getRunDetails = async ({ - clusterId, - runId, -}: { - clusterId: string; - runId: string; -}) => { +export const getRunDetails = async ({ clusterId, runId }: { clusterId: string; runId: string }) => { const [[run], agentMessage, tags] = await Promise.all([ db .select({ @@ -364,33 +319,33 @@ export const getRunDetails = async ({ }; }; -export const addMessageAndResume = async ({ +export const addMessageAndResumeWithRun = async ({ userId, id, clusterId, - runId, message, type, metadata, skipAssert, + run, }: { userId?: string; id: string; clusterId: string; - runId: string; message: string; type: "human" | "template" | "supervisor"; metadata?: RunMessageMetadata; skipAssert?: boolean; + run: Run; }) => { if (!skipAssert) { - await assertRunReady({ clusterId, runId }); + await assertRunReady({ clusterId, run }); } await insertRunMessage({ userId, clusterId, - runId, + runId: run.id, data: { message, }, @@ -400,11 +355,44 @@ export const addMessageAndResume = async ({ }); // TODO: Move run name generation to event sourcing (pg-listen) https://github.com/inferablehq/inferable/issues/390 - await generateRunName(await getRun({ clusterId, runId }), message); + await generateRunName(run, message); await resumeRun({ clusterId, - id: runId, + id: run.id, + }); +}; + +export const addMessageAndResume = async ({ + userId, + id, + clusterId, + runId, + message, + type, + metadata, + skipAssert, +}: { + userId?: string; + id: string; + clusterId: string; + runId: string; + message: string; + type: "human" | "template" | "supervisor"; + metadata?: RunMessageMetadata; + skipAssert?: boolean; +}) => { + const run = await getRun({ clusterId, runId }); + + return addMessageAndResumeWithRun({ + userId, + id, + clusterId, + run, + message, + type, + metadata, + skipAssert, }); }; @@ -520,11 +508,11 @@ export const createRunWithMessage = async ({ enableResultGrounding, }); - await addMessageAndResume({ + await addMessageAndResumeWithRun({ id: ulid(), userId, clusterId, - runId: run.id, + run, message, type, metadata: messageMetadata, @@ -543,8 +531,9 @@ export const getClusterBackgroundRun = (clusterId: string) => { return `${clusterId}BACKGROUND`; }; -export const assertRunReady = async (input: { runId: string; clusterId: string }) => { - const run = await getRun(input); +export const assertRunReady = async (input: { clusterId: string; run: Run }) => { + const run = input.run; + if (!run) { throw new NotFoundError("Run not found"); } diff --git a/control-plane/src/modules/runs/router.ts b/control-plane/src/modules/runs/router.ts index c3156af6..3a6f1170 100644 --- a/control-plane/src/modules/runs/router.ts +++ b/control-plane/src/modules/runs/router.ts @@ -8,17 +8,11 @@ import { contract } from "../contract"; import { getJobReferences } from "../jobs/jobs"; import * as events from "../observability/events"; import { posthog } from "../posthog"; -import { - RunOptions, - getAgent, - mergeAgentOptions, - validateSchema, -} from "../agents"; +import { RunOptions, getAgent, mergeAgentOptions, validateSchema } from "../agents"; import { normalizeFunctionReference } from "../service-definitions"; import { timeline } from "../timeline"; import { getRunsByTag } from "./tags"; import { - addMessageAndResume, createRetry, createRun, deleteRun, @@ -26,6 +20,7 @@ import { getAgentMetrics, getRunDetails, updateRun, + addMessageAndResumeWithRun, } from "./"; export const runsRouter = initServer().router( @@ -128,8 +123,6 @@ export const runsRouter = initServer().router( const merged = mergeAgentOptions(runOptions, agent); - - if (merged.error) { return merged.error; } @@ -182,11 +175,11 @@ export const runsRouter = initServer().router( }); if (runOptions.initialPrompt) { - await addMessageAndResume({ + await addMessageAndResumeWithRun({ id: ulid(), userId: auth.entityId, clusterId, - runId: run.id, + run, message: runOptions.initialPrompt, type: body.agentId ? "template" : "human", metadata: runOptions.messageMetadata, @@ -302,7 +295,7 @@ export const runsRouter = initServer().router( // Custom auth can only access their own Runs if (auth.type === "custom") { - userId = auth.entityId + userId = auth.entityId; } if (tags) { @@ -324,7 +317,7 @@ export const runsRouter = initServer().router( value, limit, agentId, - userId + userId, }); return { @@ -333,7 +326,6 @@ export const runsRouter = initServer().router( }; } - const result = await getClusterRuns({ clusterId, userId, From 658fa44f0d903da70e09c8dcb929452a76d2cc1e Mon Sep 17 00:00:00 2001 From: Nadeesha Cabral Date: Fri, 10 Jan 2025 00:36:43 +1100 Subject: [PATCH 2/9] fix: Minor changes 2025-01-09-23:20:58 --- control-plane/src/modules/runs/index.test.ts | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/control-plane/src/modules/runs/index.test.ts b/control-plane/src/modules/runs/index.test.ts index 5b121d85..c3e08225 100644 --- a/control-plane/src/modules/runs/index.test.ts +++ b/control-plane/src/modules/runs/index.test.ts @@ -182,14 +182,14 @@ describe("assertRunReady", () => { clusterId: owner.clusterId, }); - await updateRun({ + const updatedRun = await updateRun({ ...run, status: "done", }); await expect( assertRunReady({ - run, + run: updatedRun, clusterId: owner.clusterId, }) ).rejects.toThrow(RunBusyError); From 3e169a741e0588765df09f102437743322919ff3 Mon Sep 17 00:00:00 2001 From: Nadeesha Cabral Date: Fri, 10 Jan 2025 11:37:16 +1100 Subject: [PATCH 3/9] fix: Minor changes 2025-01-09-23:20:58 --- control-plane/src/modules/auth/auth.test.ts | 141 ++++++------------ control-plane/src/modules/data.ts | 6 +- control-plane/src/modules/router.ts | 10 +- .../src/modules/runs/agent/agent.ai.test.ts | 11 ++ control-plane/src/modules/runs/agent/agent.ts | 17 ++- .../modules/runs/agent/nodes/edges.test.ts | 20 +++ .../runs/agent/nodes/model-call.test.ts | 11 +- .../runs/agent/nodes/model-output.test.ts | 62 ++++---- .../runs/agent/nodes/tool-call.test.ts | 10 ++ .../src/modules/runs/agent/nodes/tool-call.ts | 18 ++- .../src/modules/runs/agent/run.test.ts | 11 +- control-plane/src/modules/runs/agent/run.ts | 63 +++++++- control-plane/src/modules/runs/agent/state.ts | 32 +++- .../agent/tools/cluster-internal-tools.ts | 12 +- .../runs/agent/tools/knowledge-artifacts.ts | 19 +-- control-plane/src/modules/runs/index.test.ts | 10 +- control-plane/src/modules/runs/index.ts | 108 +++++++------- control-plane/src/modules/runs/notify.ts | 11 +- control-plane/src/modules/runs/router.ts | 5 +- .../src/modules/runs/summarization.ts | 9 +- 20 files changed, 364 insertions(+), 222 deletions(-) diff --git a/control-plane/src/modules/auth/auth.test.ts b/control-plane/src/modules/auth/auth.test.ts index 6d84ea46..3f59d761 100644 --- a/control-plane/src/modules/auth/auth.test.ts +++ b/control-plane/src/modules/auth/auth.test.ts @@ -1,14 +1,10 @@ import { ulid } from "ulid"; import { createOwner } from "../test/util"; -import { Run, createRun, createRunWithMessage } from "../runs"; +import { createRun, createRunWithMessage } from "../runs"; import * as customAuth from "./custom"; import * as clusterAuth from "./cluster"; import * as clerkAuth from "./clerk"; -import { - Auth, - extractAuthState, - extractCustomAuthState, -} from "./auth"; +import { Auth, extractAuthState, extractCustomAuthState } from "./auth"; import { redisClient } from "../redis"; jest.mock("../../utilities/env"); @@ -21,10 +17,7 @@ const mockClusterAuth = { // Mocking custom auth verification const mockCustomAuth = { - verify: jest.spyOn( - customAuth, - "verify", - ), + verify: jest.spyOn(customAuth, "verify"), }; const mockClerkAuth = { @@ -102,7 +95,7 @@ describe("extractAuthState", () => { const result = await extractAuthState(""); await expect( - result!.canAccess({ cluster: { clusterId: "cluster_1" } }), + result!.canAccess({ cluster: { clusterId: "cluster_1" } }) ).resolves.toBeDefined(); }); @@ -116,7 +109,7 @@ describe("extractAuthState", () => { const result = await extractAuthState(""); await expect( - result!.canAccess({ cluster: { clusterId: "cluster_2" } }), + result!.canAccess({ cluster: { clusterId: "cluster_2" } }) ).rejects.toThrow(); }); }); @@ -139,7 +132,7 @@ describe("extractAuthState", () => { clusterId: "cluster_1", runId: "run_1", }, - }), + }) ).resolves.toBeDefined(); }); }); @@ -165,7 +158,7 @@ describe("extractAuthState", () => { clusterId: "cluster_1", runId: "run_1", }, - }), + }) ).resolves.toBeDefined(); }); }); @@ -183,7 +176,7 @@ describe("extractAuthState", () => { const result = await extractAuthState(""); await expect( - result!.canManage({ cluster: { clusterId: "cluster_1" } }), + result!.canManage({ cluster: { clusterId: "cluster_1" } }) ).rejects.toThrow(); }); }); @@ -205,7 +198,7 @@ describe("extractAuthState", () => { expect(() => result!.canManage({ agent: { clusterId: "cluster_1", agentId: "template_1" }, - }), + }) ).toBeDefined(); }); }); @@ -307,13 +300,11 @@ describe("extractAuthState", () => { await expect( result?.canAccess({ cluster: { clusterId: owner.clusterId }, - }), + }) ).resolves.toBeDefined(); // Incorrect cluster ID - await expect( - result?.canAccess({ cluster: { clusterId: "cluster_2" } }), - ).rejects.toThrow(); + await expect(result?.canAccess({ cluster: { clusterId: "cluster_2" } })).rejects.toThrow(); }); describe("runs", () => { @@ -321,8 +312,8 @@ describe("extractAuthState", () => { let owner2: Awaited>; let owner1AuthState: Auth | undefined; let owner2AuthState: Auth | undefined; - let run1: Run; - let run2: Run; + let run1: any; + let run2: any; const organizationId = Math.random().toString(); beforeAll(async () => { @@ -384,7 +375,7 @@ describe("extractAuthState", () => { runId: run1.id, clusterId: run1.clusterId, }, - }), + }) ).resolves.toBeDefined(); await expect( @@ -393,7 +384,7 @@ describe("extractAuthState", () => { runId: run2.id, clusterId: run2.clusterId, }, - }), + }) ).rejects.toThrow(); await expect( @@ -402,7 +393,7 @@ describe("extractAuthState", () => { runId: run2.id, clusterId: run2.clusterId, }, - }), + }) ).resolves.toBeDefined(); await expect( @@ -411,7 +402,7 @@ describe("extractAuthState", () => { runId: run1.id, clusterId: run1.clusterId, }, - }), + }) ).rejects.toThrow(); }); @@ -440,7 +431,7 @@ describe("extractAuthState", () => { runId: run1.id, clusterId: run1.clusterId, }, - }), + }) ).resolves.toBeDefined(); await expect( @@ -449,7 +440,7 @@ describe("extractAuthState", () => { runId: run2.id, clusterId: run2.clusterId, }, - }), + }) ).resolves.toBeDefined(); }); @@ -478,7 +469,7 @@ describe("extractAuthState", () => { runId: run1.id, clusterId: run1.clusterId, }, - }), + }) ).rejects.toThrow(); await expect( @@ -487,7 +478,7 @@ describe("extractAuthState", () => { runId: run2.id, clusterId: run2.clusterId, }, - }), + }) ).rejects.toThrow(); }); }); @@ -538,10 +529,11 @@ describe("extractCustomAuthState", () => { userId: "someValue", }); - await expect(extractCustomAuthState("abc123", owner.clusterId)).rejects.toThrow("Custom auth is not enabled for this cluster"); + await expect(extractCustomAuthState("abc123", owner.clusterId)).rejects.toThrow( + "Custom auth is not enabled for this cluster" + ); }); - describe("isUser", () => { it("should throw", async () => { mockCustomAuth.verify.mockResolvedValue({ @@ -561,12 +553,9 @@ describe("extractCustomAuthState", () => { userId: "someValue", }); - const result = await extractCustomAuthState( - "abc123", - owner.clusterId, - ); + const result = await extractCustomAuthState("abc123", owner.clusterId); await expect( - result!.canAccess({ cluster: { clusterId: owner.clusterId } }), + result!.canAccess({ cluster: { clusterId: owner.clusterId } }) ).resolves.toBeDefined(); }); @@ -575,12 +564,9 @@ describe("extractCustomAuthState", () => { userId: "someValue", }); - const result = await extractCustomAuthState( - "abc123", - owner.clusterId, - ); + const result = await extractCustomAuthState("abc123", owner.clusterId); await expect( - result!.canAccess({ cluster: { clusterId: "some_other_cluster" } }), + result!.canAccess({ cluster: { clusterId: "some_other_cluster" } }) ).rejects.toThrow(); }); }); @@ -597,17 +583,14 @@ describe("extractCustomAuthState", () => { userId: "custom:someValue", }); - const result = await extractCustomAuthState( - "abc123", - owner.clusterId, - ); + const result = await extractCustomAuthState("abc123", owner.clusterId); await expect( result!.canAccess({ run: { clusterId: owner.clusterId, runId: run.id, }, - }), + }) ).resolves.toBeDefined(); }); @@ -620,17 +603,14 @@ describe("extractCustomAuthState", () => { clusterId: owner.clusterId, }); - const result = await extractCustomAuthState( - "def456", - owner.clusterId, - ); + const result = await extractCustomAuthState("def456", owner.clusterId); await expect( result!.canAccess({ run: { clusterId: owner.clusterId, runId: run.id, }, - }), + }) ).rejects.toThrow(); }); @@ -643,17 +623,14 @@ describe("extractCustomAuthState", () => { clusterId: owner.clusterId, }); - const result = await extractCustomAuthState( - "abc123", - owner.clusterId, - ); + const result = await extractCustomAuthState("abc123", owner.clusterId); await expect( result!.canAccess({ run: { clusterId: owner.clusterId, runId: run.id, }, - }), + }) ).rejects.toThrow(); }); }); @@ -672,10 +649,7 @@ describe("extractCustomAuthState", () => { userId: "someValue", }); - const result = await extractCustomAuthState( - "abc123", - owner.clusterId, - ); + const result = await extractCustomAuthState("abc123", owner.clusterId); await expect( result!.canManage({ @@ -683,7 +657,7 @@ describe("extractCustomAuthState", () => { clusterId: owner.clusterId, runId: run.id, }, - }), + }) ).resolves.toBeDefined(); }); @@ -696,10 +670,7 @@ describe("extractCustomAuthState", () => { userId: "someValue", }); - const result = await extractCustomAuthState( - "def456", - owner.clusterId, - ); + const result = await extractCustomAuthState("def456", owner.clusterId); await expect( result!.canManage({ @@ -707,7 +678,7 @@ describe("extractCustomAuthState", () => { clusterId: owner.clusterId, runId: run.id, }, - }), + }) ).rejects.toThrow(); }); it("should throw if run is not customer authenticated", async () => { @@ -719,10 +690,7 @@ describe("extractCustomAuthState", () => { userId: "someValue", }); - const result = await extractCustomAuthState( - "abc123", - owner.clusterId, - ); + const result = await extractCustomAuthState("abc123", owner.clusterId); await expect( result!.canManage({ @@ -730,7 +698,7 @@ describe("extractCustomAuthState", () => { clusterId: owner.clusterId, runId: run.id, }, - }), + }) ).rejects.toThrow(); }); }); @@ -741,12 +709,9 @@ describe("extractCustomAuthState", () => { userId: "someValue", }); - const result = await extractCustomAuthState( - "abc123", - owner.clusterId, - ); + const result = await extractCustomAuthState("abc123", owner.clusterId); await expect( - result!.canManage({ cluster: { clusterId: owner.clusterId } }), + result!.canManage({ cluster: { clusterId: owner.clusterId } }) ).rejects.toThrow(); }); }); @@ -758,15 +723,12 @@ describe("extractCustomAuthState", () => { userId: "someValue", }); - const result = await extractCustomAuthState( - "abc123", - owner.clusterId, - ); + const result = await extractCustomAuthState("abc123", owner.clusterId); await expect( result!.canManage({ agent: { clusterId: "cluster_1", agentId: "template_1" }, - }), + }) ).rejects.toThrow(); }); }); @@ -780,10 +742,7 @@ describe("extractCustomAuthState", () => { userId: "someValue", }); - const result = await extractCustomAuthState( - "abc123", - owner.clusterId, - ); + const result = await extractCustomAuthState("abc123", owner.clusterId); expect(() => result!.canCreate({ cluster: true })).toThrow(); }); }); @@ -795,10 +754,7 @@ describe("extractCustomAuthState", () => { userId: "someValue", }); - const result = await extractCustomAuthState( - "abc123", - owner.clusterId, - ); + const result = await extractCustomAuthState("abc123", owner.clusterId); expect(() => result!.canCreate({ agent: true })).toThrow(); }); }); @@ -810,10 +766,7 @@ describe("extractCustomAuthState", () => { userId: "someValue", }); - const result = await extractCustomAuthState( - "abc123", - owner.clusterId, - ); + const result = await extractCustomAuthState("abc123", owner.clusterId); expect(() => result!.canCreate({ run: true })).toBeDefined(); }); }); diff --git a/control-plane/src/modules/data.ts b/control-plane/src/modules/data.ts index 0e66e5d5..0cf6fedc 100644 --- a/control-plane/src/modules/data.ts +++ b/control-plane/src/modules/data.ts @@ -304,7 +304,9 @@ export const runs = pgTable( .notNull(), status: text("status", { enum: ["pending", "running", "paused", "done", "failed"], - }).default("pending"), + }) + .default("pending") + .notNull(), failure_reason: text("failure_reason"), debug: boolean("debug").notNull().default(false), attached_functions: json("attached_functions").$type().notNull().default([]), @@ -327,7 +329,7 @@ export const runs = pgTable( interactive: boolean("interactive").default(true).notNull(), enable_summarization: boolean("enable_summarization").default(false).notNull(), enable_result_grounding: boolean("enable_result_grounding").default(false).notNull(), - auth_context: json("auth_context"), + auth_context: json("auth_context").$type(), context: json("context"), }, table => ({ diff --git a/control-plane/src/modules/router.ts b/control-plane/src/modules/router.ts index 6be66219..63838a61 100644 --- a/control-plane/src/modules/router.ts +++ b/control-plane/src/modules/router.ts @@ -286,7 +286,15 @@ export const router = initServer().router(contract, { const { message } = request.body; const run = await getRun({ clusterId, runId }); - await assertRunReady({ clusterId, run }); + await assertRunReady({ + clusterId, + run: { + id: run.id, + status: run.status, + interactive: run.interactive, + clusterId: run.clusterId, + }, + }); const auth = request.request.getAuth(); await auth.canManage({ run: { clusterId, runId } }); diff --git a/control-plane/src/modules/runs/agent/agent.ai.test.ts b/control-plane/src/modules/runs/agent/agent.ai.test.ts index 63c261b4..7f72a13d 100644 --- a/control-plane/src/modules/runs/agent/agent.ai.test.ts +++ b/control-plane/src/modules/runs/agent/agent.ai.test.ts @@ -3,6 +3,7 @@ import { z } from "zod"; import { redisClient } from "../../redis"; import { AgentTool } from "./tool"; import { assertMessageOfType } from "../messages"; +import { ChatIdentifiers } from "../../models/routing"; if (process.env.CI) { jest.retryTimes(3); @@ -17,6 +18,16 @@ describe("Agent", () => { const run = { id: "test", clusterId: "test", + modelIdentifier: "claude-3-5-sonnet" as ChatIdentifiers, + resultSchema: null, + debug: false, + attachedFunctions: null, + status: "pending", + systemPrompt: null, + testMocks: {}, + test: false, + reasoningTraces: false, + enableResultGrounding: false, }; const tools = [ diff --git a/control-plane/src/modules/runs/agent/agent.ts b/control-plane/src/modules/runs/agent/agent.ts index f41aae12..34051c83 100644 --- a/control-plane/src/modules/runs/agent/agent.ts +++ b/control-plane/src/modules/runs/agent/agent.ts @@ -1,5 +1,4 @@ import { START, StateGraph } from "@langchain/langgraph"; -import { type Run } from "../"; import { MODEL_CALL_NODE_NAME, handleModelCall } from "./nodes/model-call"; import { TOOL_CALL_NODE_NAME, handleToolCalls } from "./nodes/tool-call"; import { RunGraphState, createStateGraphChannels } from "./state"; @@ -7,6 +6,7 @@ import { PostStepSave, postModelEdge, postStartEdge, postToolEdge } from "./node import { AgentMessage } from "../messages"; import { buildMockModel, buildModel } from "../../models"; import { AgentTool, AgentToolV2 } from "./tool"; +import { ChatIdentifiers } from "../../models/routing"; export type ReleventToolLookup = (state: RunGraphState) => Promise<(AgentTool | AgentToolV2)[]>; @@ -23,7 +23,20 @@ export const createRunGraph = async ({ getTool, mockModelResponses, }: { - run: Run; + run: { + id: string; + clusterId: string; + modelIdentifier: ChatIdentifiers | null; + resultSchema: unknown | null; + debug: boolean; + attachedFunctions: string[] | null; + status: string; + systemPrompt: string | null; + testMocks: Record }> | null; + test: boolean; + reasoningTraces: boolean; + enableResultGrounding: boolean; + }; additionalContext?: string; allAvailableTools?: string[]; postStepSave: PostStepSave; diff --git a/control-plane/src/modules/runs/agent/nodes/edges.test.ts b/control-plane/src/modules/runs/agent/nodes/edges.test.ts index 25e34074..fb7fb96c 100644 --- a/control-plane/src/modules/runs/agent/nodes/edges.test.ts +++ b/control-plane/src/modules/runs/agent/nodes/edges.test.ts @@ -10,6 +10,16 @@ describe("postStartEdge", () => { run: { id: "test", clusterId: "test", + modelIdentifier: null, + resultSchema: null, + debug: false, + attachedFunctions: null, + status: "running", + systemPrompt: null, + testMocks: {}, + test: false, + reasoningTraces: false, + enableResultGrounding: false, }, status: "running", messages: [], @@ -228,6 +238,16 @@ describe("postModelEdge", () => { run: { id: "test", clusterId: "test", + modelIdentifier: null, + resultSchema: null, + debug: false, + attachedFunctions: null, + status: "running", + systemPrompt: null, + testMocks: {}, + test: false, + reasoningTraces: false, + enableResultGrounding: false, }, status: "running", messages: [], diff --git a/control-plane/src/modules/runs/agent/nodes/model-call.test.ts b/control-plane/src/modules/runs/agent/nodes/model-call.test.ts index 29bf71d8..293af3a6 100644 --- a/control-plane/src/modules/runs/agent/nodes/model-call.test.ts +++ b/control-plane/src/modules/runs/agent/nodes/model-call.test.ts @@ -11,7 +11,16 @@ describe("handleModelCall", () => { const run = { id: "test-run", clusterId: "test-cluster", - reasoningTraces: true, + modelIdentifier: null, + resultSchema: null, + debug: false, + attachedFunctions: null, + status: "running", + systemPrompt: null, + testMocks: {}, + test: false, + reasoningTraces: false, + enableResultGrounding: false, }; const state: RunGraphState = { messages: [ diff --git a/control-plane/src/modules/runs/agent/nodes/model-output.test.ts b/control-plane/src/modules/runs/agent/nodes/model-output.test.ts index 4dd775d5..ee602ebb 100644 --- a/control-plane/src/modules/runs/agent/nodes/model-output.test.ts +++ b/control-plane/src/modules/runs/agent/nodes/model-output.test.ts @@ -10,34 +10,44 @@ describe("buildModelSchema", () => { let resultSchema: JsonSchema7ObjectType | undefined; beforeEach(() => { - state = { - messages: [ - { - id: ulid(), - clusterId: "test-cluster", - runId: "test-run", - data: { - message: "What are your capabilities?", + state = { + messages: [ + { + id: ulid(), + clusterId: "test-cluster", + runId: "test-run", + data: { + message: "What are your capabilities?", + }, + type: "human", }, - type: "human", + ], + waitingJobs: [], + allAvailableTools: [], + run: { + id: "test-run", + clusterId: "test-cluster", + modelIdentifier: null, + resultSchema: null, + debug: false, + attachedFunctions: null, + status: "running", + systemPrompt: null, + testMocks: {}, + test: false, + reasoningTraces: false, + enableResultGrounding: false, }, - ], - waitingJobs: [], - allAvailableTools: [], - run: { - id: "test-run", - clusterId: "test-cluster", - }, - additionalContext: "", - status: "running", - }; - relevantSchemas = [ - { name: "localTool1"}, - { name: "localTool2"}, - { name: "globalTool1"}, - { name: "globalTool2"}, - ] as AgentTool[], - resultSchema = undefined; + additionalContext: "", + status: "running", + }; + (relevantSchemas = [ + { name: "localTool1" }, + { name: "localTool2" }, + { name: "globalTool1" }, + { name: "globalTool2" }, + ] as AgentTool[]), + (resultSchema = undefined); }); it("returns a schema with 'message' when resultSchema is not provided", () => { diff --git a/control-plane/src/modules/runs/agent/nodes/tool-call.test.ts b/control-plane/src/modules/runs/agent/nodes/tool-call.test.ts index 29bf8159..c3e1a49f 100644 --- a/control-plane/src/modules/runs/agent/nodes/tool-call.test.ts +++ b/control-plane/src/modules/runs/agent/nodes/tool-call.test.ts @@ -12,6 +12,16 @@ describe("handleToolCalls", () => { const run = { id: "test", clusterId: "test", + modelIdentifier: null, + resultSchema: null, + debug: false, + attachedFunctions: null, + status: "running", + systemPrompt: null, + testMocks: {}, + test: false, + reasoningTraces: false, + enableResultGrounding: false, }; const toolHandler = jest.fn(); diff --git a/control-plane/src/modules/runs/agent/nodes/tool-call.ts b/control-plane/src/modules/runs/agent/nodes/tool-call.ts index df78c669..4e242c99 100644 --- a/control-plane/src/modules/runs/agent/nodes/tool-call.ts +++ b/control-plane/src/modules/runs/agent/nodes/tool-call.ts @@ -5,11 +5,11 @@ import { logger } from "../../../observability/logger"; import { addAttributes, withSpan } from "../../../observability/tracer"; import { trackCustomerTelemetry } from "../../../track-customer-telemetry"; import { AgentMessage, assertMessageOfType } from "../../messages"; -import { Run } from "../../"; import { ToolFetcher } from "../agent"; import { RunGraphState } from "../state"; import { SpecialResultTypes, parseFunctionResponse } from "../tools/functions"; import { AgentTool, AgentToolInputError, AgentToolV2 } from "../tool"; +import { ChatIdentifiers } from "../../../models/routing"; export const TOOL_CALL_NODE_NAME = "action"; @@ -88,7 +88,13 @@ const _handleToolCalls = async ( const handleToolCall = ( toolCall: Required["invocations"][number], - run: Run, + run: { + id: string; + clusterId: string; + modelIdentifier: ChatIdentifiers | null; + resultSchema: unknown | null; + debug: boolean; + }, getTool: ToolFetcher ) => withSpan("run.toolCall", () => _handleToolCall(toolCall, run, getTool), { @@ -100,7 +106,13 @@ const handleToolCall = ( const _handleToolCall = async ( toolCall: Required["invocations"][number], - run: Run, + run: { + id: string; + clusterId: string; + modelIdentifier: ChatIdentifiers | null; + resultSchema: unknown | null; + debug: boolean; + }, getTool: ToolFetcher ): Promise> => { logger.info("Executing tool call"); diff --git a/control-plane/src/modules/runs/agent/run.test.ts b/control-plane/src/modules/runs/agent/run.test.ts index b54b25e3..36759b7e 100644 --- a/control-plane/src/modules/runs/agent/run.test.ts +++ b/control-plane/src/modules/runs/agent/run.test.ts @@ -1,4 +1,3 @@ -import { Run } from "../"; import { buildMockTools, findRelevantTools, formatJobsContext } from "./run"; import { upsertServiceDefinition } from "../../service-definitions"; import { createOwner } from "../../test/util"; @@ -30,6 +29,14 @@ describe("findRelevantTools", () => { clusterId: owner.clusterId, status: "running" as const, attachedFunctions: ["testService_someFunction"], + modelIdentifier: null, + resultSchema: null, + debug: false, + systemPrompt: null, + testMocks: {}, + test: false, + reasoningTraces: false, + enableResultGrounding: false, }; const tools = await findRelevantTools({ @@ -65,7 +72,7 @@ const mockTargetSchema = JSON.stringify({ }); describe("buildMockTools", () => { - let run: Run; + let run: any; const service = "testService"; beforeAll(async () => { diff --git a/control-plane/src/modules/runs/agent/run.ts b/control-plane/src/modules/runs/agent/run.ts index 115afea8..a963403e 100644 --- a/control-plane/src/modules/runs/agent/run.ts +++ b/control-plane/src/modules/runs/agent/run.ts @@ -1,5 +1,5 @@ import { z } from "zod"; -import { Run, getWaitingJobIds, updateRun } from "../"; +import { getWaitingJobIds, updateRun } from "../"; import { env } from "../../../utilities/env"; import { NotFoundError } from "../../../utilities/errors"; import { getClusterContextText } from "../../cluster"; @@ -13,6 +13,8 @@ import { embeddableServiceFunction, getServiceDefinitions, serviceFunctionEmbeddingId, + ServiceDefinition, + ServiceDefinitionFunction, } from "../../service-definitions"; import { getRunMessages, insertRunMessage } from "../messages"; import { notifyNewMessage, notifyStatusChange } from "../notify"; @@ -24,11 +26,31 @@ import { getClusterInternalTools } from "./tools/cluster-internal-tools"; import { buildAbstractServiceFunctionTool, buildServiceFunctionTool } from "./tools/functions"; import { buildMockFunctionTool } from "./tools/mock-function"; import { stdlib } from "./tools/stdlib"; +import { ChatIdentifiers } from "../../models/routing"; /** * Run a Run from the most recent saved state **/ -export const processRun = async (run: Run, tags?: Record) => { +export const processRun = async ( + run: { + id: string; + clusterId: string; + modelIdentifier: ChatIdentifiers | null; + resultSchema: unknown | null; + debug: boolean; + attachedFunctions: string[] | null; + status: string; + systemPrompt: string | null; + testMocks: Record }> | null; + test: boolean; + reasoningTraces: boolean; + enableResultGrounding: boolean; + onStatusChange: string | null; + authContext: unknown | null; + context: unknown | null; + }, + tags?: Record +) => { logger.info("Processing Run"); // Parallelize fetching additional context and service definitions @@ -55,8 +77,8 @@ export const processRun = async (run: Run, tags?: Record) => { allAvailableTools.push(...attachedFunctions); - serviceDefinitions.flatMap(service => - (service.definition?.functions ?? []).forEach(f => { + serviceDefinitions.flatMap((service: { service: string; definition: ServiceDefinition }) => + (service.definition?.functions ?? []).forEach((f: ServiceDefinitionFunction) => { // Do not attach additional tools if `attachedFunctions` is provided if (attachedFunctions.length > 0 || f.config?.private) { return; @@ -206,7 +228,14 @@ export const processRun = async (run: Run, tags?: Record) => { const waitingJobs = parsedOutput.data.waitingJobs; await notifyStatusChange({ - run, + run: { + id: run.id, + clusterId: run.clusterId, + onStatusChange: run.onStatusChange, + status: parsedOutput.data.status, + authContext: run.authContext, + context: run.context, + }, status: parsedOutput.data.status, result: parsedOutput.data.result, }); @@ -288,7 +317,16 @@ export const formatJobsContext = ( return `\n${jobEntries}\n`; }; -async function findRelatedFunctionTools(run: Run, search: string) { +async function findRelatedFunctionTools( + run: { + id: string; + clusterId: string; + modelIdentifier: ChatIdentifiers | null; + resultSchema: unknown | null; + debug: boolean; + }, + search: string +) { const flags = await flagsmith?.getIdentityFlags(run.clusterId, { clusterId: run.clusterId, }); @@ -380,7 +418,11 @@ async function findRelatedFunctionTools(run: Run, search: string) { return selectedTools; } -const buildAdditionalContext = async (run: Run) => { +const buildAdditionalContext = async (run: { + id: string; + clusterId: string; + systemPrompt: string | null; +}) => { let context = ""; context += await getClusterContextText(run.clusterId); @@ -466,7 +508,12 @@ export const findRelevantTools = async (state: RunGraphState) => { return tools; }; -export const buildMockTools = async (run: Run) => { +export const buildMockTools = async (run: { + id: string; + clusterId: string; + testMocks: Record }> | null; + test: boolean; +}) => { const mocks: Record = {}; if (!run.testMocks || Object.keys(run.testMocks).length === 0) { return mocks; diff --git a/control-plane/src/modules/runs/agent/state.ts b/control-plane/src/modules/runs/agent/state.ts index 5820e093..011facb0 100644 --- a/control-plane/src/modules/runs/agent/state.ts +++ b/control-plane/src/modules/runs/agent/state.ts @@ -2,7 +2,7 @@ import { StateGraphArgs } from "@langchain/langgraph"; import { InferSelectModel } from "drizzle-orm"; import { UnifiedMessage } from "../../contract"; import { RunMessageMetadata, runs } from "../../data"; -import { Run } from "../"; +import { ChatIdentifiers } from "../../models/routing"; export type RunGraphStateMessage = UnifiedMessage & { persisted?: true; @@ -15,7 +15,20 @@ export type RunGraphState = { status: InferSelectModel["status"]; messages: RunGraphStateMessage[]; additionalContext?: string; - run: Run; + run: { + id: string; + clusterId: string; + modelIdentifier: ChatIdentifiers | null; + resultSchema: unknown | null; + debug: boolean; + attachedFunctions: string[] | null; + status: string; + systemPrompt: string | null; + testMocks: Record }> | null; + test: boolean; + reasoningTraces: boolean; + enableResultGrounding: boolean; + }; waitingJobs: string[]; allAvailableTools: string[]; // eslint-disable-next-line @typescript-eslint/no-explicit-any @@ -27,7 +40,20 @@ export const createStateGraphChannels = ({ additionalContext, allAvailableTools, }: { - run: Run; + run: { + id: string; + clusterId: string; + modelIdentifier: ChatIdentifiers | null; + resultSchema: unknown | null; + debug: boolean; + attachedFunctions: string[] | null; + status: string; + systemPrompt: string | null; + testMocks: Record }> | null; + test: boolean; + reasoningTraces: boolean; + enableResultGrounding: boolean; + }; allAvailableTools: string[]; additionalContext?: string; }): StateGraphArgs["channels"] => { diff --git a/control-plane/src/modules/runs/agent/tools/cluster-internal-tools.ts b/control-plane/src/modules/runs/agent/tools/cluster-internal-tools.ts index 0f6937db..7fecf3e4 100644 --- a/control-plane/src/modules/runs/agent/tools/cluster-internal-tools.ts +++ b/control-plane/src/modules/runs/agent/tools/cluster-internal-tools.ts @@ -1,12 +1,10 @@ -import { Run } from "../../"; -import { - buildAccessKnowledgeArtifacts, - ACCESS_KNOWLEDGE_ARTIFACTS_TOOL_NAME, -} from "./knowledge-artifacts"; import { createCache } from "../../../../utilities/cache"; import { getClusterDetails } from "../../../management"; -import { buildCurrentDateTimeTool, CURRENT_DATE_TIME_TOOL_NAME } from "./date-time"; import { AgentTool, AgentToolV2 } from "../tool"; +import { + ACCESS_KNOWLEDGE_ARTIFACTS_TOOL_NAME, + buildAccessKnowledgeArtifacts, +} from "./knowledge-artifacts"; import { stdlib } from "./stdlib"; const clusterSettingsCache = createCache<{ @@ -14,7 +12,7 @@ const clusterSettingsCache = createCache<{ }>(Symbol("clusterSettings")); export type InternalToolBuilder = ( - run: Run, + run: any, toolCallId: string ) => AgentTool | Promise | AgentToolV2; // TODO: Standardize on AgentToolV2 diff --git a/control-plane/src/modules/runs/agent/tools/knowledge-artifacts.ts b/control-plane/src/modules/runs/agent/tools/knowledge-artifacts.ts index 48d34087..a05e236d 100644 --- a/control-plane/src/modules/runs/agent/tools/knowledge-artifacts.ts +++ b/control-plane/src/modules/runs/agent/tools/knowledge-artifacts.ts @@ -1,25 +1,18 @@ import { z } from "zod"; import { logger } from "../../../observability/logger"; import { getKnowledge } from "../../../knowledge/knowledgebase"; -import { Run } from "../../"; import * as events from "../../../observability/events"; import { getAllUniqueTags } from "../../../embeddings/embeddings"; import { AgentTool } from "../tool"; export const ACCESS_KNOWLEDGE_ARTIFACTS_TOOL_NAME = "accessKnowledgeArtifacts"; -export const buildAccessKnowledgeArtifacts = async ( - run: Run, -): Promise => { - const tags = await getAllUniqueTags( - run.clusterId, - "knowledgebase-artifact", - ); +export const buildAccessKnowledgeArtifacts = async (run: any): Promise => { + const tags = await getAllUniqueTags(run.clusterId, "knowledgebase-artifact"); return new AgentTool({ name: ACCESS_KNOWLEDGE_ARTIFACTS_TOOL_NAME, - description: - "Retrieves relevant knowledge artifacts based on a given query.", + description: "Retrieves relevant knowledge artifacts based on a given query.", schema: z.object({ query: z.string().describe("The query to search for knowledge artifacts"), tag: @@ -27,7 +20,7 @@ export const buildAccessKnowledgeArtifacts = async ( ? z .enum(tags as [string, ...string[]]) .describe( - "The tag to filter the knowledge artifacts by. If not provided, all artifacts are returned.", + "The tag to filter the knowledge artifacts by. If not provided, all artifacts are returned." ) .optional() : z.undefined().optional(), @@ -48,8 +41,8 @@ export const buildAccessKnowledgeArtifacts = async ( workflowId: run.id, meta: { artifacts: artifacts.map( - (artifact) => - `# ${artifact.title} (${Math.round(artifact.similarity * 100)}%) \n ${artifact.data} \n\n`, + artifact => + `# ${artifact.title} (${Math.round(artifact.similarity * 100)}%) \n ${artifact.data} \n\n` ), }, }); diff --git a/control-plane/src/modules/runs/index.test.ts b/control-plane/src/modules/runs/index.test.ts index c3e08225..ff27c2c7 100644 --- a/control-plane/src/modules/runs/index.test.ts +++ b/control-plane/src/modules/runs/index.test.ts @@ -183,13 +183,19 @@ describe("assertRunReady", () => { }); const updatedRun = await updateRun({ - ...run, + id: run.id, + clusterId: owner.clusterId, status: "done", }); await expect( assertRunReady({ - run: updatedRun, + run: { + id: updatedRun.id, + status: updatedRun.status, + interactive: updatedRun.interactive, + clusterId: owner.clusterId, + }, clusterId: owner.clusterId, }) ).rejects.toThrow(RunBusyError); diff --git a/control-plane/src/modules/runs/index.ts b/control-plane/src/modules/runs/index.ts index 984d7c5c..b571003b 100644 --- a/control-plane/src/modules/runs/index.ts +++ b/control-plane/src/modules/runs/index.ts @@ -1,14 +1,4 @@ -import { - and, - countDistinct, - desc, - eq, - inArray, - InferSelectModel, - isNull, - or, - sql, -} from "drizzle-orm"; +import { and, countDistinct, desc, eq, inArray, isNull, or, sql } from "drizzle-orm"; import { ulid } from "ulid"; import { env } from "../../utilities/env"; import { BadRequestError, NotFoundError, RunBusyError } from "../../utilities/errors"; @@ -23,31 +13,6 @@ import { getRunTags } from "./tags"; export { start, stop } from "./queues"; -export type Run = { - id: string; - clusterId: string; - status?: "pending" | "running" | "paused" | "done" | "failed" | null; - name?: string | null; - agentId?: string | null; - systemPrompt?: string | null; - failureReason?: string | null; - debug?: boolean; - test?: boolean; - testMocks?: InferSelectModel["test_mocks"]; - feedbackComment?: string | null; - feedbackScore?: number | null; - resultSchema?: unknown | null; - attachedFunctions?: string[] | null; - onStatusChange?: string | null; - interactive?: boolean; - reasoningTraces?: boolean; - enableSummarization?: boolean; - modelIdentifier?: ChatIdentifiers | null; - authContext?: unknown | null; - context?: unknown | null; - enableResultGrounding?: boolean; -}; - export const createRun = async ({ runId, userId, @@ -95,7 +60,7 @@ export const createRun = async ({ authContext?: unknown; context?: unknown; enableResultGrounding?: boolean; -}): Promise => { +}) => { // Insert the run with a subquery for debug value const [run] = await db .insert(runs) @@ -136,6 +101,7 @@ export const createRun = async ({ context: runs.context, interactive: runs.interactive, enableResultGrounding: runs.enable_result_grounding, + agentId: runs.agent_id, }); if (!run) { @@ -161,20 +127,30 @@ export const deleteRun = async ({ clusterId, runId }: { clusterId: string; runId await db.delete(runs).where(and(eq(runs.cluster_id, clusterId), eq(runs.id, runId))); }; -export const updateRun = async (run: Run): Promise => { +export const updateRun = async (run: { + id: string; + clusterId: string; + name?: string; + status?: "pending" | "running" | "paused" | "done" | "failed" | null; + failureReason?: string; + feedbackComment?: string; + feedbackScore?: number; +}) => { + const updateSet = { + name: run.name ?? undefined, + status: run.status ?? undefined, + failure_reason: run.failureReason ?? undefined, + feedback_comment: run.feedbackComment ?? undefined, + feedback_score: run.feedbackScore ?? undefined, + }; + if (run.status && run.status !== "failed") { - run.failureReason = null; + delete updateSet.failure_reason; } const [updated] = await db .update(runs) - .set({ - name: !run.name ? undefined : run.name, - status: run.status, - failure_reason: run.failureReason, - feedback_comment: run.feedbackComment, - feedback_score: run.feedbackScore, - }) + .set(updateSet) .where(and(eq(runs.cluster_id, run.clusterId), eq(runs.id, run.id))) .returning({ id: runs.id, @@ -186,6 +162,8 @@ export const updateRun = async (run: Run): Promise => { attachedFunctions: runs.attached_functions, authContext: runs.auth_context, context: runs.context, + interactive: runs.interactive, + enableResultGrounding: runs.enable_result_grounding, }); // Send telemetry event if feedback was updated @@ -336,7 +314,13 @@ export const addMessageAndResumeWithRun = async ({ type: "human" | "template" | "supervisor"; metadata?: RunMessageMetadata; skipAssert?: boolean; - run: Run; + run: { + id: string; + name: string | null; + status: string; + interactive: boolean; + clusterId: string; + }; }) => { if (!skipAssert) { await assertRunReady({ clusterId, run }); @@ -355,7 +339,14 @@ export const addMessageAndResumeWithRun = async ({ }); // TODO: Move run name generation to event sourcing (pg-listen) https://github.com/inferablehq/inferable/issues/390 - await generateRunName(run, message); + await generateRunName( + { + id: run.id, + clusterId: run.clusterId, + name: run.name, + }, + message + ); await resumeRun({ clusterId, @@ -396,7 +387,7 @@ export const addMessageAndResume = async ({ }); }; -export const resumeRun = async (input: Pick) => { +export const resumeRun = async (input: { id: string; clusterId: string }) => { if (env.NODE_ENV === "test") { logger.warn("Skipping run resume. NODE_ENV is set to 'test'."); return; @@ -424,7 +415,14 @@ export const resumeRun = async (input: Pick) => { }); }; -export const generateRunName = async (run: Run, content: string) => { +export const generateRunName = async ( + run: { + id: string; + clusterId: string; + name: string | null; + }, + content: string +) => { if (env.NODE_ENV === "test") { logger.warn("Skipping run resume. NODE_ENV is set to 'test'."); return; @@ -531,7 +529,15 @@ export const getClusterBackgroundRun = (clusterId: string) => { return `${clusterId}BACKGROUND`; }; -export const assertRunReady = async (input: { clusterId: string; run: Run }) => { +export const assertRunReady = async (input: { + clusterId: string; + run: { + id: string; + status: string; + interactive: boolean; + clusterId: string; + }; +}) => { const run = input.run; if (!run) { diff --git a/control-plane/src/modules/runs/notify.ts b/control-plane/src/modules/runs/notify.ts index b256aad3..fe33cbf6 100644 --- a/control-plane/src/modules/runs/notify.ts +++ b/control-plane/src/modules/runs/notify.ts @@ -3,7 +3,7 @@ import * as jobs from "../jobs/jobs"; import { logger } from "../observability/logger"; import { packer } from "../packer"; import { getRunTags } from "./tags"; -import { getClusterBackgroundRun, Run } from "./"; +import { getClusterBackgroundRun } from "./"; import { runMessages } from "../data"; import * as slack from "../integrations/slack"; import * as email from "../email"; @@ -47,7 +47,14 @@ export const notifyStatusChange = async ({ status, result, }: { - run: Run; + run: { + id: string; + clusterId: string; + onStatusChange: string | null; + status: string; + authContext: unknown; + context: unknown; + }; status: string; result?: unknown; }) => { diff --git a/control-plane/src/modules/runs/router.ts b/control-plane/src/modules/runs/router.ts index 3a6f1170..c20c6554 100644 --- a/control-plane/src/modules/runs/router.ts +++ b/control-plane/src/modules/runs/router.ts @@ -249,8 +249,9 @@ export const runsRouter = initServer().router( await updateRun({ id: runId, clusterId, - feedbackComment: comment, - feedbackScore: score, + feedbackComment: comment ?? undefined, + feedbackScore: score ?? undefined, + status: null, }); events.write({ diff --git a/control-plane/src/modules/runs/summarization.ts b/control-plane/src/modules/runs/summarization.ts index b92f288e..9d503f8f 100644 --- a/control-plane/src/modules/runs/summarization.ts +++ b/control-plane/src/modules/runs/summarization.ts @@ -1,5 +1,4 @@ import { z } from "zod"; -import { Run } from "./"; import { logger } from "../observability/logger"; import { RetryableError } from "../../utilities/errors"; import { buildModel } from "../models"; @@ -7,8 +6,12 @@ import { zodToJsonSchema } from "zod-to-json-schema"; export const generateTitle = async ( message: string, - run: Run, - words: number = 10, + run: { + id: string; + clusterId: string; + name: string | null; + }, + words: number = 10 ): Promise<{ summary: string }> => { const system = [ `You are a title generation assistant that is capable of succintly summarizing a set of messages in a single sentence. The title should be no more than ${words} words. Generate title for the following messages. Use identifying information such as names, dates, and locations if necessary. Good examples: From 843552c49d5d7e5272e50c374347cf696b0fe1e9 Mon Sep 17 00:00:00 2001 From: Nadeesha Cabral Date: Fri, 10 Jan 2025 00:10:03 +1100 Subject: [PATCH 4/9] fix: Refactor CSS class handling for message component (#528) --- app/components/chat/human-message.tsx | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/app/components/chat/human-message.tsx b/app/components/chat/human-message.tsx index 0f948f14..6546661a 100644 --- a/app/components/chat/human-message.tsx +++ b/app/components/chat/human-message.tsx @@ -1,5 +1,5 @@ import { client } from "@/client/client"; -import { createErrorToast } from "@/lib/utils"; +import { cn, createErrorToast } from "@/lib/utils"; import { useAuth } from "@clerk/nextjs"; import { formatRelative } from "date-fns"; import { Mail, PencilLineIcon, RefreshCcw, Slack, User } from "lucide-react"; @@ -139,9 +139,10 @@ export function HumanMessage({ return (
From a4a3a94bc24c988ace070048278c36c8b45f1531 Mon Sep 17 00:00:00 2001 From: Nadeesha Cabral Date: Fri, 10 Jan 2025 08:26:52 +1100 Subject: [PATCH 5/9] fix: Remove knowledge artifact related code and components (#530) * fix: Remove knowledge artifact related code and components * fix: Minor changes 2025-01-09-22:23:43 * fix: Minor changes 2025-01-09-23:20:58 * fix: Minor changes 2025-01-09-23:20:58 * fix: Minor changes 2025-01-09-23:20:58 * fix: Minor changes 2025-01-09-23:20:58 --- .../knowledge/[artifactId]/edit/page.tsx | 80 ------ .../clusters/[clusterId]/knowledge/layout.tsx | 270 ------------------ .../[clusterId]/knowledge/new/page.tsx | 54 ---- .../clusters/[clusterId]/knowledge/page.tsx | 13 - .../[clusterId]/settings/details/page.tsx | 104 ++----- app/client/contract.ts | 117 -------- app/components/KnowledgeArtifactForm.tsx | 149 ---------- cli/src/client/contract.ts | 96 ------- control-plane/src/modules/contract.ts | 117 -------- .../modules/knowledge/knowledgebase.test.ts | 69 ----- .../src/modules/knowledge/knowledgebase.ts | 134 --------- .../src/modules/observability/events.ts | 1 - control-plane/src/modules/router.ts | 120 -------- control-plane/src/modules/runs/agent/agent.ts | 6 +- .../modules/runs/agent/nodes/model-output.ts | 4 +- .../modules/runs/agent/nodes/system-prompt.ts | 7 +- .../src/modules/runs/agent/nodes/tool-call.ts | 4 +- control-plane/src/modules/runs/agent/run.ts | 37 +-- control-plane/src/modules/runs/agent/tool.ts | 64 +---- .../modules/runs/agent/tools/calculator.ts | 55 ++-- .../agent/tools/cluster-internal-tools.ts | 49 ---- .../src/modules/runs/agent/tools/date-time.ts | 31 +- .../src/modules/runs/agent/tools/get-url.ts | 107 +++---- .../runs/agent/tools/knowledge-artifacts.ts | 68 ----- .../src/modules/runs/agent/tools/stdlib.ts | 41 +-- 25 files changed, 163 insertions(+), 1634 deletions(-) delete mode 100644 app/app/clusters/[clusterId]/knowledge/[artifactId]/edit/page.tsx delete mode 100644 app/app/clusters/[clusterId]/knowledge/layout.tsx delete mode 100644 app/app/clusters/[clusterId]/knowledge/new/page.tsx delete mode 100644 app/app/clusters/[clusterId]/knowledge/page.tsx delete mode 100644 app/components/KnowledgeArtifactForm.tsx delete mode 100644 control-plane/src/modules/knowledge/knowledgebase.test.ts delete mode 100644 control-plane/src/modules/knowledge/knowledgebase.ts delete mode 100644 control-plane/src/modules/runs/agent/tools/cluster-internal-tools.ts delete mode 100644 control-plane/src/modules/runs/agent/tools/knowledge-artifacts.ts diff --git a/app/app/clusters/[clusterId]/knowledge/[artifactId]/edit/page.tsx b/app/app/clusters/[clusterId]/knowledge/[artifactId]/edit/page.tsx deleted file mode 100644 index 5258d486..00000000 --- a/app/app/clusters/[clusterId]/knowledge/[artifactId]/edit/page.tsx +++ /dev/null @@ -1,80 +0,0 @@ -"use client"; - -import { useState, useEffect } from "react"; -import { useAuth } from "@clerk/nextjs"; -import { client } from "@/client/client"; -import toast from "react-hot-toast"; -import { useParams, useRouter } from "next/navigation"; -import { - KnowledgeArtifactForm, - KnowledgeArtifact, -} from "@/components/KnowledgeArtifactForm"; -import { createErrorToast } from "@/lib/utils"; - -export default function EditKnowledgeArtifact() { - const { getToken } = useAuth(); - const params = useParams(); - const router = useRouter(); - const clusterId = params?.clusterId as string; - const artifactId = params?.artifactId as string; - const [artifact, setArtifact] = useState(null); - - useEffect(() => { - const fetchArtifact = async () => { - try { - const token = await getToken(); - const response = await client.getKnowledgeArtifact({ - params: { clusterId, artifactId }, - headers: { authorization: token as string }, - }); - - if (response.status === 200) { - setArtifact(response.body); - } else { - createErrorToast(response, "Artifact not found"); - router.push(`/clusters/${clusterId}/knowledge`); - } - } catch (error) { - console.error("Error fetching artifact:", error); - toast.error("Failed to fetch knowledge artifact"); - } - }; - - fetchArtifact(); - }, [getToken, clusterId, artifactId, router]); - - const handleUpdate = async (updatedArtifact: KnowledgeArtifact) => { - try { - const token = await getToken(); - - const response = await client.upsertKnowledgeArtifact({ - params: { clusterId, artifactId }, - headers: { authorization: token as string }, - body: updatedArtifact, - }); - - if (response.status === 200) { - toast.success("Knowledge artifact updated successfully"); - } else { - createErrorToast(response, "Failed to update knowledge artifact"); - } - } catch (error) { - console.error("Error updating artifact:", error); - toast.error("Failed to update knowledge artifact"); - } - }; - - if (!artifact) { - return null; - } - - return ( - router.push(`/clusters/${clusterId}/knowledge`)} - submitButtonText="Update Artifact" - editing={true} - /> - ); -} diff --git a/app/app/clusters/[clusterId]/knowledge/layout.tsx b/app/app/clusters/[clusterId]/knowledge/layout.tsx deleted file mode 100644 index c6700e61..00000000 --- a/app/app/clusters/[clusterId]/knowledge/layout.tsx +++ /dev/null @@ -1,270 +0,0 @@ -"use client"; - -import { client } from "@/client/client"; -import { Button } from "@/components/ui/button"; -import { - Dialog, - DialogContent, - DialogHeader, - DialogTitle, - DialogTrigger, -} from "@/components/ui/dialog"; -import { Input } from "@/components/ui/input"; -import { ScrollArea } from "@/components/ui/scroll-area"; -import { Textarea } from "@/components/ui/textarea"; -import { useAuth } from "@clerk/nextjs"; -import { truncate } from "lodash"; -import { GlobeIcon, PlusIcon, TrashIcon, UploadIcon } from "lucide-react"; -import Link from "next/link"; -import { useParams, useRouter } from "next/navigation"; -import { useCallback, useEffect, useState } from "react"; -import toast from "react-hot-toast"; - -type KnowledgeArtifact = { - id: string; - data: string; - tags: string[]; - title: string; - similarity?: number; -}; - -export default function KnowledgeLayout({ - children, -}: { - children: React.ReactNode; -}) { - const { getToken } = useAuth(); - const params = useParams(); - const router = useRouter(); - const clusterId = params?.clusterId as string; - const [artifacts, setArtifacts] = useState([]); - const [searchQuery, setSearchQuery] = useState(""); - const [isUploadDialogOpen, setIsUploadDialogOpen] = useState(false); - const [bulkUploadJson, setBulkUploadJson] = useState(""); - - const fetchArtifacts = useCallback(async () => { - try { - const token = await getToken(); - const response = await client.listKnowledgeArtifacts({ - params: { clusterId }, - headers: { authorization: token as string }, - query: { query: searchQuery, limit: 20 }, - }); - - if (response.status === 200) { - setArtifacts(response.body); - } - } catch (error) { - console.error("Error fetching artifacts:", error); - toast.error("Failed to fetch knowledge artifacts"); - } - }, [getToken, clusterId, searchQuery]); - - useEffect(() => { - fetchArtifacts(); - }, [fetchArtifacts]); - - const handleSearch = () => { - fetchArtifacts(); - }; - - const handleDelete = async (id: string) => { - try { - const token = await getToken(); - const response = await client.deleteKnowledgeArtifact({ - params: { clusterId, artifactId: id }, - headers: { authorization: token as string }, - }); - - if (response.status === 204) { - toast.success("Knowledge artifact deleted successfully"); - fetchArtifacts(); - } - } catch (error) { - console.error("Error deleting artifact:", error); - toast.error("Failed to delete knowledge artifact"); - } - }; - - const handleBulkUpload = async () => { - try { - const artifacts = JSON.parse(bulkUploadJson); - if (!Array.isArray(artifacts)) { - throw new Error("Invalid JSON format. Expected an array of artifacts."); - } - - const loading = toast.loading("Uploading knowledge artifacts..."); - - const token = await getToken(); - let successCount = 0; - let failCount = 0; - - for (const artifact of artifacts) { - try { - const response = await client.upsertKnowledgeArtifact({ - params: { clusterId, artifactId: artifact.id }, - headers: { authorization: token as string }, - body: artifact, - }); - - if (response.status === 201) { - successCount++; - } else { - failCount++; - } - } catch (error) { - console.error("Error creating artifact:", error); - failCount++; - } - } - - toast.remove(loading); - toast.success( - `Bulk upload completed. Success: ${successCount}, Failed: ${failCount}`, - ); - setIsUploadDialogOpen(false); - fetchArtifacts(); - } catch (error) { - console.error("Error parsing JSON:", error); - toast.error( - "Failed to parse JSON. Please check the format and try again.", - ); - } - }; - - return ( -
-
-
-
- - - -
- - - - - - - - Bulk Upload Knowledge Artifacts - -
-

- Paste your JSON array of knowledge artifacts below. Note - that items with the same ID will be overwritten. -

-