diff --git a/.gitignore b/.gitignore index 7e2875831..ab29d7636 100644 --- a/.gitignore +++ b/.gitignore @@ -3,3 +3,4 @@ build dist node_modules +examples/*/package-lock.json \ No newline at end of file diff --git a/examples/basic/src/recipes/information-extraction-openai-chat-functions.ts b/examples/basic/src/recipes/information-extraction-openai-chat-functions.ts index a3f07b4d1..2c0a3f7a1 100644 --- a/examples/basic/src/recipes/information-extraction-openai-chat-functions.ts +++ b/examples/basic/src/recipes/information-extraction-openai-chat-functions.ts @@ -49,13 +49,13 @@ dotenv.config(); fs.readFileSync("data/san-francisco-wikipedia.json", "utf8") ).content; - const { value: extractedInformation1 } = await extractNameAndPopulation( + const extractedInformation1 = await extractNameAndPopulation( sanFranciscoWikipedia.slice(0, 2000) ); console.log(extractedInformation1); - const { value: extractedInformation2 } = await extractNameAndPopulation( + const extractedInformation2 = await extractNameAndPopulation( "Carl was a friendly robot." ); diff --git a/examples/basic/src/tool/execute-tool-example.ts b/examples/basic/src/tool/execute-tool-example.ts new file mode 100644 index 000000000..6d4859369 --- /dev/null +++ b/examples/basic/src/tool/execute-tool-example.ts @@ -0,0 +1,41 @@ +import dotenv from "dotenv"; +import { Tool, executeTool } from "modelfusion"; +import { z } from "zod"; + +dotenv.config(); + +(async () => { + const calculator = new Tool({ + name: "calculator" as const, // mark 'as const' for type inference + description: "Execute a calculation", + + inputSchema: z.object({ + a: z.number().describe("The first number."), + b: z.number().describe("The second number."), + operator: z.enum(["+", "-", "*", "/"]).describe("The operator."), + }), + + execute: async ({ a, b, operator }) => { + switch (operator) { + case "+": + return a + b; + case "-": + return a - b; + case "*": + return a * b; + case "/": + return a / b; + default: + throw new Error(`Unknown operator: ${operator}`); + } + }, + }); + + const result = await executeTool(calculator, { + a: 14, + b: 12, + operator: "*" as const, + }); + + console.log(`Result: ${result}`); +})(); diff --git a/examples/basic/src/tool/execute-tool-full-example.ts b/examples/basic/src/tool/execute-tool-full-example.ts new file mode 100644 index 000000000..44832e4f2 --- /dev/null +++ b/examples/basic/src/tool/execute-tool-full-example.ts @@ -0,0 +1,46 @@ +import dotenv from "dotenv"; +import { Tool, executeTool } from "modelfusion"; +import { z } from "zod"; + +dotenv.config(); + +(async () => { + const calculator = new Tool({ + name: "calculator" as const, // mark 'as const' for type inference + description: "Execute a calculation", + + inputSchema: z.object({ + a: z.number().describe("The first number."), + b: z.number().describe("The second number."), + operator: z.enum(["+", "-", "*", "/"]).describe("The operator."), + }), + + execute: async ({ a, b, operator }) => { + switch (operator) { + case "+": + return a + b; + case "-": + return a - b; + case "*": + return a * b; + case "/": + return a / b; + default: + throw new Error(`Unknown operator: ${operator}`); + } + }, + }); + + const { metadata, output } = await executeTool( + calculator, + { + a: 14, + b: 12, + operator: "*" as const, + }, + { fullResponse: true } + ); + + console.log(`Result: ${output}`); + console.log(`Duration: ${metadata.durationInMs}ms`); +})(); diff --git a/examples/basic/src/composed-function/use-tool-example.ts b/examples/basic/src/tool/use-tool-example.ts similarity index 100% rename from examples/basic/src/composed-function/use-tool-example.ts rename to examples/basic/src/tool/use-tool-example.ts diff --git a/examples/basic/src/composed-function/use-tool-or-generate-text-example.ts b/examples/basic/src/tool/use-tool-or-generate-text-example.ts similarity index 100% rename from examples/basic/src/composed-function/use-tool-or-generate-text-example.ts rename to examples/basic/src/tool/use-tool-or-generate-text-example.ts diff --git a/examples/basic/src/composed-function/use-tool-or-generate-text-with-custom-prompt-example.ts b/examples/basic/src/tool/use-tool-or-generate-text-with-custom-prompt-example.ts similarity index 100% rename from examples/basic/src/composed-function/use-tool-or-generate-text-with-custom-prompt-example.ts rename to examples/basic/src/tool/use-tool-or-generate-text-with-custom-prompt-example.ts diff --git a/examples/pdf-to-tweet/src/main.ts b/examples/pdf-to-tweet/src/main.ts index 71a2fbd61..bd5ef49d3 100644 --- a/examples/pdf-to-tweet/src/main.ts +++ b/examples/pdf-to-tweet/src/main.ts @@ -20,7 +20,7 @@ const run = new DefaultRun({ costCalculators: [new OpenAICostCalculator()], observers: [ { - onModelCallStarted(event) { + onRunFunctionStarted(event) { if (event.type === "text-generation-started") { console.log( `Generate text ${event.metadata.functionId ?? "unknown"} started.` @@ -32,7 +32,7 @@ const run = new DefaultRun({ } }, - onModelCallFinished(event) { + onRunFunctionFinished(event) { if (event.type === "text-generation-finished") { console.log( `Generate text ${event.metadata.functionId ?? "unknown"} finished.` diff --git a/src/composed-function/index.ts b/src/composed-function/index.ts index 2c018a1ad..11366b3a7 100644 --- a/src/composed-function/index.ts +++ b/src/composed-function/index.ts @@ -1,6 +1,3 @@ -export * from "./use-tool/NoSuchToolError.js"; -export * from "./use-tool/Tool.js"; -export * from "./use-tool/useTool.js"; export * from "./summarize/SummarizationFunction.js"; export * from "./summarize/summarizeRecursively.js"; export * from "./summarize/summarizeRecursivelyWithTextGenerationAndTokenSplitting.js"; diff --git a/src/index.ts b/src/index.ts index af957c9e4..a1b91b105 100644 --- a/src/index.ts +++ b/src/index.ts @@ -5,5 +5,6 @@ export * from "./model-provider/index.js"; export * from "./prompt/index.js"; export * from "./run/index.js"; export * from "./text-chunk/index.js"; +export * from "./tool/index.js"; export * from "./util/index.js"; export * from "./vector-index/index.js"; diff --git a/src/model-function/Model.ts b/src/model-function/Model.ts index bae6205fc..4570ea3b9 100644 --- a/src/model-function/Model.ts +++ b/src/model-function/Model.ts @@ -1,8 +1,8 @@ import { ModelInformation } from "./ModelInformation.js"; -import { ModelCallObserver } from "./ModelCallObserver.js"; +import { RunFunctionObserver } from "../run/RunFunctionObserver.js"; export interface ModelSettings { - observers?: Array; + observers?: Array; } export interface Model { diff --git a/src/model-function/ModelCallEvent.ts b/src/model-function/ModelCallEvent.ts index 20e2f5b1a..d163d091f 100644 --- a/src/model-function/ModelCallEvent.ts +++ b/src/model-function/ModelCallEvent.ts @@ -1,4 +1,7 @@ -import { IdMetadata } from "../run/IdMetadata.js"; +import { + RunFunctionFinishedEventMetadata, + RunFunctionStartedEventMetadata, +} from "../run/RunFunctionEvent.js"; import { ModelInformation } from "./ModelInformation.js"; import { TextEmbeddingFinishedEvent, @@ -25,11 +28,8 @@ import { TranscriptionStartedEvent, } from "./transcribe-audio/TranscriptionEvent.js"; -export type ModelCallEvent = ModelCallStartedEvent | ModelCallFinishedEvent; - -export type ModelCallStartedEventMetadata = IdMetadata & { +export type ModelCallStartedEventMetadata = RunFunctionStartedEventMetadata & { model: ModelInformation; - startEpochSeconds: number; }; export type ModelCallStartedEvent = @@ -40,9 +40,10 @@ export type ModelCallStartedEvent = | TextStreamingStartedEvent | TranscriptionStartedEvent; -export type ModelCallFinishedEventMetadata = ModelCallStartedEventMetadata & { - durationInMs: number; -}; +export type ModelCallFinishedEventMetadata = + RunFunctionFinishedEventMetadata & { + model: ModelInformation; + }; export type ModelCallFinishedEvent = | ImageGenerationFinishedEvent diff --git a/src/model-function/ModelCallObserver.ts b/src/model-function/ModelCallObserver.ts deleted file mode 100644 index e1be9626f..000000000 --- a/src/model-function/ModelCallObserver.ts +++ /dev/null @@ -1,9 +0,0 @@ -import { - ModelCallFinishedEvent, - ModelCallStartedEvent, -} from "./ModelCallEvent.js"; - -export type ModelCallObserver = { - onModelCallStarted?: (event: ModelCallStartedEvent) => void; - onModelCallFinished?: (event: ModelCallFinishedEvent) => void; -}; diff --git a/src/model-function/SuccessfulModelCall.ts b/src/model-function/SuccessfulModelCall.ts index 2aab387f3..82e50b6d3 100644 --- a/src/model-function/SuccessfulModelCall.ts +++ b/src/model-function/SuccessfulModelCall.ts @@ -1,13 +1,12 @@ -import { - ModelCallFinishedEvent, - ModelCallStartedEvent, -} from "./ModelCallEvent.js"; +import { RunFunctionEvent } from "../run/RunFunctionEvent.js"; +import { ModelCallFinishedEvent } from "./ModelCallEvent.js"; import { ModelInformation } from "./ModelInformation.js"; export type SuccessfulModelCall = { type: | "image-generation" | "json-generation" + | "json-or-text-generation" | "text-embedding" | "text-generation" | "text-streaming" @@ -18,12 +17,14 @@ export type SuccessfulModelCall = { }; export function extractSuccessfulModelCalls( - modelCallEvents: (ModelCallFinishedEvent | ModelCallStartedEvent)[] + runFunctionEvents: RunFunctionEvent[] ) { - return modelCallEvents + return runFunctionEvents .filter( (event): event is ModelCallFinishedEvent & { status: "success" } => - "status" in event && event.status === "success" + Object.keys(eventTypeToCostType).includes(event.type) && + "status" in event && + event.status === "success" ) .map( (event): SuccessfulModelCall => ({ @@ -38,6 +39,7 @@ export function extractSuccessfulModelCalls( const eventTypeToCostType = { "image-generation-finished": "image-generation" as const, "json-generation-finished": "json-generation" as const, + "json-or-text-generation-finished": "json-or-text-generation" as const, "text-embedding-finished": "text-embedding" as const, "text-generation-finished": "text-generation" as const, "text-streaming-finished": "text-streaming" as const, diff --git a/src/model-function/executeCall.ts b/src/model-function/executeCall.ts index 1665105d5..7c3593187 100644 --- a/src/model-function/executeCall.ts +++ b/src/model-function/executeCall.ts @@ -1,4 +1,5 @@ import { nanoid as createId } from "nanoid"; +import { RunFunctionEventSource } from "../run/RunFunctionEventSource.js"; import { startDurationMeasurement } from "../util/DurationMeasurement.js"; import { AbortError } from "../util/api/AbortError.js"; import { runSafe } from "../util/runSafe.js"; @@ -10,7 +11,6 @@ import { ModelCallStartedEvent, ModelCallStartedEventMetadata, } from "./ModelCallEvent.js"; -import { ModelCallEventSource } from "./ModelCallEventSource.js"; export type CallMetadata> = { callId: string; @@ -79,7 +79,7 @@ export async function executeCall< const run = options?.run; const settings = model.settings; - const eventSource = new ModelCallEventSource({ + const eventSource = new RunFunctionEventSource({ observers: [...(settings.observers ?? []), ...(run?.observers ?? [])], errorHandler: run?.errorHandler, }); @@ -96,7 +96,7 @@ export async function executeCall< startEpochSeconds: durationMeasurement.startEpochSeconds, }; - eventSource.notifyModelCallStarted(getStartEvent(startMetadata, settings)); + eventSource.notifyRunFunctionStarted(getStartEvent(startMetadata, settings)); const result = await runSafe(() => generateResponse({ @@ -113,13 +113,13 @@ export async function executeCall< if (!result.ok) { if (result.isAborted) { - eventSource.notifyModelCallFinished( + eventSource.notifyRunFunctionFinished( getAbortEvent(finishMetadata, settings) ); throw new AbortError(); } - eventSource.notifyModelCallFinished( + eventSource.notifyRunFunctionFinished( getFailureEvent(finishMetadata, settings, result.error) ); throw result.error; @@ -128,7 +128,7 @@ export async function executeCall< const response = result.output; const output = extractOutputValue(response); - eventSource.notifyModelCallFinished( + eventSource.notifyRunFunctionFinished( getSuccessEvent(finishMetadata, settings, response, output) ); diff --git a/src/model-function/generate-json/JsonGenerationEvent.ts b/src/model-function/generate-json/JsonGenerationEvent.ts index 4f4bafdaf..24fad3fe7 100644 --- a/src/model-function/generate-json/JsonGenerationEvent.ts +++ b/src/model-function/generate-json/JsonGenerationEvent.ts @@ -4,14 +4,14 @@ import { } from "../ModelCallEvent.js"; export type JsonGenerationStartedEvent = { - type: "json-generation-started"; + type: "json-generation-started" | "json-or-text-generation-started"; metadata: ModelCallStartedEventMetadata; settings: unknown; prompt: unknown; }; export type JsonGenerationFinishedEvent = { - type: "json-generation-finished"; + type: "json-generation-finished" | "json-or-text-generation-finished"; metadata: ModelCallFinishedEventMetadata; settings: unknown; prompt: unknown; diff --git a/src/model-function/generate-json/generateJsonOrText.ts b/src/model-function/generate-json/generateJsonOrText.ts index b42108c5e..296479beb 100644 --- a/src/model-function/generate-json/generateJsonOrText.ts +++ b/src/model-function/generate-json/generateJsonOrText.ts @@ -135,20 +135,20 @@ export async function generateJsonOrText< }; }, getStartEvent: (metadata, settings) => ({ - type: "json-generation-started", + type: "json-or-text-generation-started", metadata, settings, prompt, }), getAbortEvent: (metadata, settings) => ({ - type: "json-generation-finished", + type: "json-or-text-generation-finished", status: "abort", metadata, settings, prompt, }), getFailureEvent: (metadata, settings, error) => ({ - type: "json-generation-finished", + type: "json-or-text-generation-finished", status: "failure", metadata, settings, @@ -156,7 +156,7 @@ export async function generateJsonOrText< error, }), getSuccessEvent: (metadata, settings, response, output) => ({ - type: "json-generation-finished", + type: "json-or-text-generation-finished", status: "success", metadata, settings, diff --git a/src/model-function/generate-text/streamText.ts b/src/model-function/generate-text/streamText.ts index 941d5a596..608d99bfb 100644 --- a/src/model-function/generate-text/streamText.ts +++ b/src/model-function/generate-text/streamText.ts @@ -1,15 +1,16 @@ import { nanoid as createId } from "nanoid"; +import { RunFunctionEventSource } from "../../run/RunFunctionEventSource.js"; import { startDurationMeasurement } from "../../util/DurationMeasurement.js"; import { AbortError } from "../../util/api/AbortError.js"; import { runSafe } from "../../util/runSafe.js"; import { FunctionOptions } from "../FunctionOptions.js"; -import { ModelCallEventSource } from "../ModelCallEventSource.js"; import { CallMetadata } from "../executeCall.js"; import { DeltaEvent } from "./DeltaEvent.js"; import { TextGenerationModel, TextGenerationModelSettings, } from "./TextGenerationModel.js"; +import { TextStreamingFinishedEvent } from "./TextStreamingEvent.js"; import { extractTextDeltas } from "./extractTextDeltas.js"; export async function streamText< @@ -91,7 +92,7 @@ export async function streamText< const run = options?.run; const settings = model.settings; - const eventSource = new ModelCallEventSource({ + const eventSource = new RunFunctionEventSource({ observers: [...(settings.observers ?? []), ...(run?.observers ?? [])], errorHandler: run?.errorHandler, }); @@ -108,7 +109,7 @@ export async function streamText< startEpochSeconds: durationMeasurement.startEpochSeconds, }; - eventSource.notifyModelCallStarted({ + eventSource.notifyRunFunctionStarted({ type: "text-streaming-started", metadata: startMetadata, settings, @@ -129,7 +130,7 @@ export async function streamText< durationInMs: durationMeasurement.durationInMs, }; - eventSource.notifyModelCallFinished({ + eventSource.notifyRunFunctionFinished({ type: "text-streaming-finished", status: "success", metadata: finishMetadata, @@ -137,7 +138,7 @@ export async function streamText< prompt, response: lastFullDelta, generatedText: fullText, - }); + } as TextStreamingFinishedEvent); }, onError: (error) => { const finishMetadata = { @@ -145,7 +146,7 @@ export async function streamText< durationInMs: durationMeasurement.durationInMs, }; - eventSource.notifyModelCallFinished( + eventSource.notifyRunFunctionFinished( error instanceof AbortError ? { type: "text-streaming-finished", @@ -174,7 +175,7 @@ export async function streamText< }; if (result.isAborted) { - eventSource.notifyModelCallFinished({ + eventSource.notifyRunFunctionFinished({ type: "text-streaming-finished", status: "abort", metadata: finishMetadata, @@ -184,7 +185,7 @@ export async function streamText< throw new AbortError(); } - eventSource.notifyModelCallFinished({ + eventSource.notifyRunFunctionFinished({ type: "text-streaming-finished", status: "failure", metadata: finishMetadata, diff --git a/src/model-function/index.ts b/src/model-function/index.ts index 64b44864d..705836d6f 100644 --- a/src/model-function/index.ts +++ b/src/model-function/index.ts @@ -1,7 +1,7 @@ export * from "./FunctionOptions.js"; export * from "./Model.js"; export * from "./ModelCallEvent.js"; -export * from "./ModelCallObserver.js"; +export * from "../run/RunFunctionObserver.js"; export * from "./ModelInformation.js"; export * from "./SuccessfulModelCall.js"; export * from "./embed-text/TextEmbeddingEvent.js"; diff --git a/src/model-provider/openai/chat/OpenAIChatPrompt.ts b/src/model-provider/openai/chat/OpenAIChatPrompt.ts index 5258b65a0..33e780239 100644 --- a/src/model-provider/openai/chat/OpenAIChatPrompt.ts +++ b/src/model-provider/openai/chat/OpenAIChatPrompt.ts @@ -1,10 +1,10 @@ import SecureJSON from "secure-json-parse"; import z from "zod"; import { zodToJsonSchema } from "zod-to-json-schema"; -import { Tool } from "../../../composed-function/use-tool/Tool.js"; import { GenerateJsonPrompt } from "../../../model-function/generate-json/GenerateJsonModel.js"; import { GenerateJsonOrTextPrompt } from "../../../model-function/generate-json/GenerateJsonOrTextModel.js"; import { SchemaDefinition } from "../../../model-function/generate-json/SchemaDefinition.js"; +import { Tool } from "../../../tool/Tool.js"; import { OpenAIChatMessage } from "./OpenAIChatMessage.js"; import { OpenAIChatResponse } from "./OpenAIChatModel.js"; diff --git a/src/run/ConsoleLogger.ts b/src/run/ConsoleLogger.ts index 4c9d7b5fd..1ea0565b3 100644 --- a/src/run/ConsoleLogger.ts +++ b/src/run/ConsoleLogger.ts @@ -1,15 +1,15 @@ import { - ModelCallFinishedEvent, - ModelCallStartedEvent, -} from "../model-function/ModelCallEvent.js"; -import { ModelCallObserver } from "../model-function/ModelCallObserver.js"; + RunFunctionFinishedEvent, + RunFunctionStartedEvent, +} from "./RunFunctionEvent.js"; +import { RunFunctionObserver } from "./RunFunctionObserver.js"; -export class ConsoleLogger implements ModelCallObserver { - onModelCallStarted(event: ModelCallStartedEvent) { +export class ConsoleLogger implements RunFunctionObserver { + onRunFunctionStarted(event: RunFunctionStartedEvent) { console.log(JSON.stringify(event, null, 2)); } - onModelCallFinished(event: ModelCallFinishedEvent) { + onRunFunctionFinished(event: RunFunctionFinishedEvent) { console.log(JSON.stringify(event, null, 2)); } } diff --git a/src/run/DefaultRun.ts b/src/run/DefaultRun.ts index ffd298551..4ea875840 100644 --- a/src/run/DefaultRun.ts +++ b/src/run/DefaultRun.ts @@ -1,16 +1,14 @@ import { nanoid as createId } from "nanoid"; +import { CostCalculator } from "../cost/CostCalculator.js"; +import { calculateCost } from "../cost/calculateCost.js"; import { SuccessfulModelCall, extractSuccessfulModelCalls, } from "../model-function/SuccessfulModelCall.js"; -import { - ModelCallFinishedEvent, - ModelCallStartedEvent, -} from "../model-function/ModelCallEvent.js"; -import { ModelCallObserver } from "../model-function/ModelCallObserver.js"; import { Run } from "./Run.js"; -import { CostCalculator } from "../cost/CostCalculator.js"; -import { calculateCost } from "../cost/calculateCost.js"; +import { RunFunctionEvent } from "./RunFunctionEvent.js"; +import { RunFunctionObserver } from "./RunFunctionObserver.js"; + export class DefaultRun implements Run { readonly runId: string; readonly sessionId?: string; @@ -19,10 +17,9 @@ export class DefaultRun implements Run { readonly abortSignal?: AbortSignal; readonly costCalculators: CostCalculator[]; - readonly modelCallEvents: (ModelCallFinishedEvent | ModelCallStartedEvent)[] = - []; + readonly events: RunFunctionEvent[] = []; - readonly observers?: ModelCallObserver[]; + readonly observers?: RunFunctionObserver[]; constructor({ runId = createId(), @@ -36,7 +33,7 @@ export class DefaultRun implements Run { sessionId?: string; userId?: string; abortSignal?: AbortSignal; - observers?: ModelCallObserver[]; + observers?: RunFunctionObserver[]; costCalculators?: CostCalculator[]; } = {}) { this.runId = runId; @@ -47,11 +44,11 @@ export class DefaultRun implements Run { this.observers = [ { - onModelCallStarted: (event) => { - this.modelCallEvents.push(event); + onRunFunctionStarted: (event) => { + this.events.push(event); }, - onModelCallFinished: (event) => { - this.modelCallEvents.push(event); + onRunFunctionFinished: (event) => { + this.events.push(event); }, }, ...(observers ?? []), @@ -59,7 +56,7 @@ export class DefaultRun implements Run { } get successfulModelCalls(): Array { - return extractSuccessfulModelCalls(this.modelCallEvents); + return extractSuccessfulModelCalls(this.events); } calculateCost() { diff --git a/src/run/Run.ts b/src/run/Run.ts index b57013b73..40b2e60da 100644 --- a/src/run/Run.ts +++ b/src/run/Run.ts @@ -1,5 +1,5 @@ import { ErrorHandler } from "../util/ErrorHandler.js"; -import { ModelCallObserver } from "../model-function/ModelCallObserver.js"; +import { RunFunctionObserver } from "./RunFunctionObserver.js"; export interface Run { /** @@ -27,7 +27,7 @@ export interface Run { */ abortSignal?: AbortSignal; - observers?: ModelCallObserver[]; + observers?: RunFunctionObserver[]; errorHandler?: ErrorHandler; } diff --git a/src/run/RunFunction.ts b/src/run/RunFunction.ts index 778f81964..3abe90072 100644 --- a/src/run/RunFunction.ts +++ b/src/run/RunFunction.ts @@ -1,4 +1,3 @@ -import { SafeResult } from "../util/SafeResult.js"; import { Run } from "./Run.js"; /** @@ -12,10 +11,3 @@ export type RunFunction = ( run?: Run; } ) => PromiseLike; - -export type SafeRunFunction = ( - input: INPUT, - options?: { - run?: Run; - } -) => PromiseLike>; diff --git a/src/run/RunFunctionEvent.ts b/src/run/RunFunctionEvent.ts new file mode 100644 index 000000000..962b5672c --- /dev/null +++ b/src/run/RunFunctionEvent.ts @@ -0,0 +1,30 @@ +import { + ExecuteToolFinishedEvent, + ExecuteToolStartedEvent, +} from "tool/ExecuteToolEvent.js"; +import { + ModelCallFinishedEvent, + ModelCallStartedEvent, +} from "../model-function/ModelCallEvent.js"; +import { IdMetadata } from "./IdMetadata.js"; + +export type RunFunctionEvent = + | RunFunctionStartedEvent + | RunFunctionFinishedEvent; + +export type RunFunctionStartedEventMetadata = IdMetadata & { + startEpochSeconds: number; +}; + +export type RunFunctionStartedEvent = + | ModelCallStartedEvent + | ExecuteToolStartedEvent; + +export type RunFunctionFinishedEventMetadata = + RunFunctionStartedEventMetadata & { + durationInMs: number; + }; + +export type RunFunctionFinishedEvent = + | ModelCallFinishedEvent + | ExecuteToolFinishedEvent; diff --git a/src/model-function/ModelCallEventSource.ts b/src/run/RunFunctionEventSource.ts similarity index 54% rename from src/model-function/ModelCallEventSource.ts rename to src/run/RunFunctionEventSource.ts index 4e8284a63..3e2eeb7ae 100644 --- a/src/model-function/ModelCallEventSource.ts +++ b/src/run/RunFunctionEventSource.ts @@ -1,39 +1,39 @@ import { ErrorHandler } from "../util/ErrorHandler.js"; import { - ModelCallFinishedEvent, - ModelCallStartedEvent, -} from "./ModelCallEvent.js"; -import { ModelCallObserver } from "./ModelCallObserver.js"; + RunFunctionFinishedEvent, + RunFunctionStartedEvent, +} from "./RunFunctionEvent.js"; +import { RunFunctionObserver } from "./RunFunctionObserver.js"; -export class ModelCallEventSource { - readonly observers: ModelCallObserver[]; +export class RunFunctionEventSource { + readonly observers: RunFunctionObserver[]; readonly errorHandler: ErrorHandler; constructor({ observers, errorHandler, }: { - observers: ModelCallObserver[]; + observers: RunFunctionObserver[]; errorHandler?: ErrorHandler; }) { this.observers = observers; this.errorHandler = errorHandler ?? ((error) => console.error(error)); } - notifyModelCallStarted(event: ModelCallStartedEvent) { + notifyRunFunctionStarted(event: RunFunctionStartedEvent) { for (const observer of this.observers) { try { - observer.onModelCallStarted?.(event); + observer.onRunFunctionStarted?.(event); } catch (error) { this.errorHandler(error); } } } - notifyModelCallFinished(event: ModelCallFinishedEvent) { + notifyRunFunctionFinished(event: RunFunctionFinishedEvent) { for (const observer of this.observers) { try { - observer.onModelCallFinished?.(event); + observer.onRunFunctionFinished?.(event); } catch (error) { this.errorHandler(error); } diff --git a/src/run/RunFunctionObserver.ts b/src/run/RunFunctionObserver.ts new file mode 100644 index 000000000..d7ea366dc --- /dev/null +++ b/src/run/RunFunctionObserver.ts @@ -0,0 +1,9 @@ +import { + RunFunctionFinishedEvent, + RunFunctionStartedEvent, +} from "./RunFunctionEvent.js"; + +export type RunFunctionObserver = { + onRunFunctionStarted?: (event: RunFunctionStartedEvent) => void; + onRunFunctionFinished?: (event: RunFunctionFinishedEvent) => void; +}; diff --git a/src/run/index.ts b/src/run/index.ts index b9f0045fa..500569290 100644 --- a/src/run/index.ts +++ b/src/run/index.ts @@ -3,4 +3,7 @@ export * from "./DefaultRun.js"; export * from "./IdMetadata.js"; export * from "./Run.js"; export * from "./RunFunction.js"; +export * from "./RunFunctionEvent.js"; +export * from "./RunFunctionObserver.js"; +export * from "./RunFunctionEventSource.js"; export * from "./Vector.js"; diff --git a/src/tool/ExecuteToolEvent.ts b/src/tool/ExecuteToolEvent.ts new file mode 100644 index 000000000..10ba06f29 --- /dev/null +++ b/src/tool/ExecuteToolEvent.ts @@ -0,0 +1,26 @@ +import { + RunFunctionFinishedEventMetadata, + RunFunctionStartedEventMetadata, +} from "../run/RunFunctionEvent.js"; +import { Tool } from "./Tool.js"; + +export type ExecuteToolStartedEvent = { + type: "execute-tool-started"; + metadata: RunFunctionStartedEventMetadata; + tool: Tool; + input: unknown; +}; + +export type ExecuteToolFinishedEvent = { + type: "execute-tool-finished"; + metadata: RunFunctionFinishedEventMetadata; + tool: Tool; + input: unknown; +} & ( + | { + status: "success"; + output: unknown; + } + | { status: "failure"; error: unknown } + | { status: "abort" } +); diff --git a/src/composed-function/use-tool/NoSuchToolError.ts b/src/tool/NoSuchToolError.ts similarity index 100% rename from src/composed-function/use-tool/NoSuchToolError.ts rename to src/tool/NoSuchToolError.ts diff --git a/src/composed-function/use-tool/Tool.ts b/src/tool/Tool.ts similarity index 67% rename from src/composed-function/use-tool/Tool.ts rename to src/tool/Tool.ts index 71cca7ee6..4d8ac1657 100644 --- a/src/composed-function/use-tool/Tool.ts +++ b/src/tool/Tool.ts @@ -1,21 +1,25 @@ import { z } from "zod"; -import { SchemaDefinition } from "../../model-function/generate-json/SchemaDefinition.js"; +import { SchemaDefinition } from "../model-function/generate-json/SchemaDefinition.js"; +import { RunFunction } from "../run/RunFunction.js"; export class Tool { readonly name: NAME; readonly description: string; readonly inputSchema: z.ZodSchema; - readonly execute: (input: INPUT) => PromiseLike; + readonly outputSchema?: z.ZodSchema; + readonly execute: RunFunction; constructor(options: { name: NAME; description: string; inputSchema: z.ZodSchema; + outputSchema?: z.ZodSchema; execute(input: INPUT): Promise; }) { this.name = options.name; this.description = options.description; this.inputSchema = options.inputSchema; + this.outputSchema = options.outputSchema; this.execute = options.execute; } diff --git a/src/tool/ToolExecutionError.ts b/src/tool/ToolExecutionError.ts new file mode 100644 index 000000000..b979a00e6 --- /dev/null +++ b/src/tool/ToolExecutionError.ts @@ -0,0 +1,24 @@ +export class ToolExecutionError extends Error { + readonly toolName: string; + readonly input: unknown; + readonly cause: unknown; + + constructor({ + message = "unknown error", + toolName, + input, + cause, + }: { + toolName: string; + input: unknown; + message: string | undefined; + cause: unknown | undefined; + }) { + super(`Error executing tool ${toolName}: ${message}`); + + this.name = "ToolExecutionError"; + this.toolName = toolName; + this.input = input; + this.cause = cause; + } +} diff --git a/src/tool/executeTool.ts b/src/tool/executeTool.ts new file mode 100644 index 000000000..694a6da7b --- /dev/null +++ b/src/tool/executeTool.ts @@ -0,0 +1,130 @@ +import { nanoid as createId } from "nanoid"; +import { FunctionOptions } from "../model-function/FunctionOptions.js"; +import { RunFunctionEventSource } from "../run/RunFunctionEventSource.js"; +import { startDurationMeasurement } from "../util/DurationMeasurement.js"; +import { AbortError } from "../util/api/AbortError.js"; +import { runSafe } from "../util/runSafe.js"; +import { Tool } from "./Tool.js"; +import { ToolExecutionError } from "./ToolExecutionError.js"; + +export type ExecuteToolMetadata = { + callId: string; + runId?: string; + sessionId?: string; + userId?: string; + functionId?: string; + startEpochSeconds: number; + durationInMs: number; +}; + +export async function executeTool( + tool: Tool, + input: INPUT, + options: FunctionOptions & { + fullResponse: true; + } +): Promise<{ + output: OUTPUT; + metadata: ExecuteToolMetadata; +}>; +export async function executeTool( + tool: Tool, + input: INPUT, + options?: FunctionOptions & { + fullResponse?: false; + } +): Promise; +export async function executeTool( + tool: Tool, + input: INPUT, + options?: FunctionOptions & { + fullResponse?: boolean; + } +): Promise< + | OUTPUT + | { + output: OUTPUT; + metadata: ExecuteToolMetadata; + } +> { + const run = options?.run; + + const eventSource = new RunFunctionEventSource({ + observers: run?.observers ?? [], + errorHandler: run?.errorHandler, + }); + + const durationMeasurement = startDurationMeasurement(); + + const startMetadata = { + callId: `call-${createId()}`, + runId: run?.runId, + sessionId: run?.sessionId, + userId: run?.userId, + functionId: options?.functionId, + startEpochSeconds: durationMeasurement.startEpochSeconds, + }; + + eventSource.notifyRunFunctionStarted({ + type: "execute-tool-started", + metadata: startMetadata, + tool: tool as Tool, + input, + }); + + const result = await runSafe(() => tool.execute(input, options)); + + const finishMetadata = { + ...startMetadata, + durationInMs: durationMeasurement.durationInMs, + }; + + if (!result.ok) { + if (result.isAborted) { + eventSource.notifyRunFunctionFinished({ + type: "execute-tool-finished", + status: "abort", + metadata: finishMetadata, + tool: tool as Tool, + input, + }); + + throw new AbortError(); + } + + eventSource.notifyRunFunctionFinished({ + type: "execute-tool-finished", + status: "failure", + metadata: finishMetadata, + tool: tool as Tool, + input, + error: result.error, + }); + + throw new ToolExecutionError({ + toolName: tool.name, + input, + cause: result.error, + // eslint-disable-next-line @typescript-eslint/no-explicit-any + message: (result.error as any)?.message, + }); + } + + const output = result.output; + + eventSource.notifyRunFunctionFinished({ + type: "execute-tool-finished", + status: "success", + metadata: finishMetadata, + tool: tool as Tool, + input, + output, + }); + + return options?.fullResponse === true + ? { + output, + metadata: finishMetadata, + } + : output; +} diff --git a/src/tool/index.ts b/src/tool/index.ts new file mode 100644 index 000000000..cf5a654db --- /dev/null +++ b/src/tool/index.ts @@ -0,0 +1,6 @@ +export * from "./NoSuchToolError.js"; +export * from "./Tool.js"; +export * from "./ToolExecutionError.js"; +export * from "./executeTool.js"; +export * from "./useTool.js"; +export * from "./useToolOrGenerateText.js"; diff --git a/src/tool/useTool.ts b/src/tool/useTool.ts new file mode 100644 index 000000000..e84461fc2 --- /dev/null +++ b/src/tool/useTool.ts @@ -0,0 +1,64 @@ +import { FunctionOptions } from "../model-function/FunctionOptions.js"; +import { + GenerateJsonModel, + GenerateJsonModelSettings, + GenerateJsonPrompt, +} from "../model-function/generate-json/GenerateJsonModel.js"; +import { generateJson } from "../model-function/generate-json/generateJson.js"; +import { Tool } from "./Tool.js"; +import { executeTool } from "./executeTool.js"; + +// In this file, using 'any' is required to allow for flexibility in the inputs. The actual types are +// retrieved through lookups such as TOOL["name"], such that any does not affect any client. +/* eslint-disable @typescript-eslint/no-explicit-any */ + +/** + * `useTool` uses `generateJson` to generate parameters for a tool and then executes the tool with the parameters. + * + * @returns The result contains the name of the tool (`tool` property), + * the parameters (`parameters` property, typed), + * and the result of the tool execution (`result` property, typed). + */ +export async function useTool< + PROMPT, + RESPONSE, + SETTINGS extends GenerateJsonModelSettings, + TOOL extends Tool, +>( + model: GenerateJsonModel, + tool: TOOL, + prompt: (tool: TOOL) => PROMPT & GenerateJsonPrompt, + options?: FunctionOptions +): Promise<{ + tool: TOOL["name"]; + parameters: TOOL["inputSchema"]; + result: Awaited>; +}> { + const { value } = await generateJson< + TOOL["inputSchema"], + PROMPT, + RESPONSE, + TOOL["name"], + SETTINGS + >( + model, + { + name: tool.name, + description: tool.description, + schema: tool.inputSchema, + }, + () => prompt(tool), + { + ...(options ?? {}), + fullResponse: true, + } + ); + + return { + tool: tool.name, + parameters: value, + result: await executeTool(tool, value, { + run: options?.run, + }), + }; +} diff --git a/src/composed-function/use-tool/useTool.ts b/src/tool/useToolOrGenerateText.ts similarity index 57% rename from src/composed-function/use-tool/useTool.ts rename to src/tool/useToolOrGenerateText.ts index 0d4a1fb64..423221703 100644 --- a/src/composed-function/use-tool/useTool.ts +++ b/src/tool/useToolOrGenerateText.ts @@ -1,72 +1,18 @@ -import { FunctionOptions } from "../../model-function/FunctionOptions.js"; -import { - GenerateJsonModel, - GenerateJsonModelSettings, - GenerateJsonPrompt, -} from "../../model-function/generate-json/GenerateJsonModel.js"; +import { FunctionOptions } from "../model-function/FunctionOptions.js"; import { GenerateJsonOrTextModel, GenerateJsonOrTextModelSettings, GenerateJsonOrTextPrompt, -} from "../../model-function/generate-json/GenerateJsonOrTextModel.js"; -import { generateJson } from "../../model-function/generate-json/generateJson.js"; -import { generateJsonOrText } from "../../model-function/generate-json/generateJsonOrText.js"; +} from "../model-function/generate-json/GenerateJsonOrTextModel.js"; +import { generateJsonOrText } from "../model-function/generate-json/generateJsonOrText.js"; import { NoSuchToolError } from "./NoSuchToolError.js"; import { Tool } from "./Tool.js"; +import { executeTool } from "./executeTool.js"; // In this file, using 'any' is required to allow for flexibility in the inputs. The actual types are // retrieved through lookups such as TOOL["name"], such that any does not affect any client. /* eslint-disable @typescript-eslint/no-explicit-any */ -/** - * `useTool` uses `generateJson` to generate parameters for a tool and then executes the tool with the parameters. - * - * @returns The result contains the name of the tool (`tool` property), - * the parameters (`parameters` property, typed), - * and the result of the tool execution (`result` property, typed). - */ -export async function useTool< - PROMPT, - RESPONSE, - SETTINGS extends GenerateJsonModelSettings, - TOOL extends Tool, ->( - model: GenerateJsonModel, - tool: TOOL, - prompt: (tool: TOOL) => PROMPT & GenerateJsonPrompt, - options?: FunctionOptions -): Promise<{ - tool: TOOL["name"]; - parameters: TOOL["inputSchema"]; - result: Awaited>; -}> { - const { value } = await generateJson< - TOOL["inputSchema"], - PROMPT, - RESPONSE, - TOOL["name"], - SETTINGS - >( - model, - { - name: tool.name, - description: tool.description, - schema: tool.inputSchema, - }, - () => prompt(tool), - { - ...(options ?? {}), - fullResponse: true, - } - ); - - return { - tool: tool.name, - parameters: value, - result: await tool.execute(value), - }; -} - // [ { name: "n", ... } | { ... } ] type ToolArray[]> = T; @@ -131,7 +77,9 @@ export async function useToolOrGenerateText< const toolParameters = modelResponse.value; - const result = await tool.execute(toolParameters); + const result = await executeTool(tool, toolParameters, { + run: options?.run, + }); return { tool: schema as keyof ToToolMap,