Skip to content

Commit

Permalink
Merge pull request #45 from lgrammel/tools
Browse files Browse the repository at this point in the history
Tool improvements
  • Loading branch information
lgrammel authored Aug 12, 2023
2 parents dbb6469 + dbea4f8 commit 399e51a
Show file tree
Hide file tree
Showing 36 changed files with 474 additions and 160 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@
build
dist
node_modules
examples/*/package-lock.json
Original file line number Diff line number Diff line change
Expand Up @@ -49,13 +49,13 @@ dotenv.config();
fs.readFileSync("data/san-francisco-wikipedia.json", "utf8")
).content;

const { value: extractedInformation1 } = await extractNameAndPopulation(
const extractedInformation1 = await extractNameAndPopulation(
sanFranciscoWikipedia.slice(0, 2000)
);

console.log(extractedInformation1);

const { value: extractedInformation2 } = await extractNameAndPopulation(
const extractedInformation2 = await extractNameAndPopulation(
"Carl was a friendly robot."
);

Expand Down
41 changes: 41 additions & 0 deletions examples/basic/src/tool/execute-tool-example.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import dotenv from "dotenv";
import { Tool, executeTool } from "modelfusion";
import { z } from "zod";

dotenv.config();

(async () => {
const calculator = new Tool({
name: "calculator" as const, // mark 'as const' for type inference
description: "Execute a calculation",

inputSchema: z.object({
a: z.number().describe("The first number."),
b: z.number().describe("The second number."),
operator: z.enum(["+", "-", "*", "/"]).describe("The operator."),
}),

execute: async ({ a, b, operator }) => {
switch (operator) {
case "+":
return a + b;
case "-":
return a - b;
case "*":
return a * b;
case "/":
return a / b;
default:
throw new Error(`Unknown operator: ${operator}`);
}
},
});

const result = await executeTool(calculator, {
a: 14,
b: 12,
operator: "*" as const,
});

console.log(`Result: ${result}`);
})();
46 changes: 46 additions & 0 deletions examples/basic/src/tool/execute-tool-full-example.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
import dotenv from "dotenv";
import { Tool, executeTool } from "modelfusion";
import { z } from "zod";

dotenv.config();

(async () => {
const calculator = new Tool({
name: "calculator" as const, // mark 'as const' for type inference
description: "Execute a calculation",

inputSchema: z.object({
a: z.number().describe("The first number."),
b: z.number().describe("The second number."),
operator: z.enum(["+", "-", "*", "/"]).describe("The operator."),
}),

execute: async ({ a, b, operator }) => {
switch (operator) {
case "+":
return a + b;
case "-":
return a - b;
case "*":
return a * b;
case "/":
return a / b;
default:
throw new Error(`Unknown operator: ${operator}`);
}
},
});

const { metadata, output } = await executeTool(
calculator,
{
a: 14,
b: 12,
operator: "*" as const,
},
{ fullResponse: true }
);

console.log(`Result: ${output}`);
console.log(`Duration: ${metadata.durationInMs}ms`);
})();
File renamed without changes.
4 changes: 2 additions & 2 deletions examples/pdf-to-tweet/src/main.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ const run = new DefaultRun({
costCalculators: [new OpenAICostCalculator()],
observers: [
{
onModelCallStarted(event) {
onRunFunctionStarted(event) {
if (event.type === "text-generation-started") {
console.log(
`Generate text ${event.metadata.functionId ?? "unknown"} started.`
Expand All @@ -32,7 +32,7 @@ const run = new DefaultRun({
}
},

onModelCallFinished(event) {
onRunFunctionFinished(event) {
if (event.type === "text-generation-finished") {
console.log(
`Generate text ${event.metadata.functionId ?? "unknown"} finished.`
Expand Down
3 changes: 0 additions & 3 deletions src/composed-function/index.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,3 @@
export * from "./use-tool/NoSuchToolError.js";
export * from "./use-tool/Tool.js";
export * from "./use-tool/useTool.js";
export * from "./summarize/SummarizationFunction.js";
export * from "./summarize/summarizeRecursively.js";
export * from "./summarize/summarizeRecursivelyWithTextGenerationAndTokenSplitting.js";
1 change: 1 addition & 0 deletions src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,6 @@ export * from "./model-provider/index.js";
export * from "./prompt/index.js";
export * from "./run/index.js";
export * from "./text-chunk/index.js";
export * from "./tool/index.js";
export * from "./util/index.js";
export * from "./vector-index/index.js";
4 changes: 2 additions & 2 deletions src/model-function/Model.ts
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import { ModelInformation } from "./ModelInformation.js";
import { ModelCallObserver } from "./ModelCallObserver.js";
import { RunFunctionObserver } from "../run/RunFunctionObserver.js";

export interface ModelSettings {
observers?: Array<ModelCallObserver>;
observers?: Array<RunFunctionObserver>;
}

export interface Model<SETTINGS> {
Expand Down
17 changes: 9 additions & 8 deletions src/model-function/ModelCallEvent.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
import { IdMetadata } from "../run/IdMetadata.js";
import {
RunFunctionFinishedEventMetadata,
RunFunctionStartedEventMetadata,
} from "../run/RunFunctionEvent.js";
import { ModelInformation } from "./ModelInformation.js";
import {
TextEmbeddingFinishedEvent,
Expand All @@ -25,11 +28,8 @@ import {
TranscriptionStartedEvent,
} from "./transcribe-audio/TranscriptionEvent.js";

export type ModelCallEvent = ModelCallStartedEvent | ModelCallFinishedEvent;

export type ModelCallStartedEventMetadata = IdMetadata & {
export type ModelCallStartedEventMetadata = RunFunctionStartedEventMetadata & {
model: ModelInformation;
startEpochSeconds: number;
};

export type ModelCallStartedEvent =
Expand All @@ -40,9 +40,10 @@ export type ModelCallStartedEvent =
| TextStreamingStartedEvent
| TranscriptionStartedEvent;

export type ModelCallFinishedEventMetadata = ModelCallStartedEventMetadata & {
durationInMs: number;
};
export type ModelCallFinishedEventMetadata =
RunFunctionFinishedEventMetadata & {
model: ModelInformation;
};

export type ModelCallFinishedEvent =
| ImageGenerationFinishedEvent
Expand Down
9 changes: 0 additions & 9 deletions src/model-function/ModelCallObserver.ts

This file was deleted.

16 changes: 9 additions & 7 deletions src/model-function/SuccessfulModelCall.ts
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
import {
ModelCallFinishedEvent,
ModelCallStartedEvent,
} from "./ModelCallEvent.js";
import { RunFunctionEvent } from "../run/RunFunctionEvent.js";
import { ModelCallFinishedEvent } from "./ModelCallEvent.js";
import { ModelInformation } from "./ModelInformation.js";

export type SuccessfulModelCall = {
type:
| "image-generation"
| "json-generation"
| "json-or-text-generation"
| "text-embedding"
| "text-generation"
| "text-streaming"
Expand All @@ -18,12 +17,14 @@ export type SuccessfulModelCall = {
};

export function extractSuccessfulModelCalls(
modelCallEvents: (ModelCallFinishedEvent | ModelCallStartedEvent)[]
runFunctionEvents: RunFunctionEvent[]
) {
return modelCallEvents
return runFunctionEvents
.filter(
(event): event is ModelCallFinishedEvent & { status: "success" } =>
"status" in event && event.status === "success"
Object.keys(eventTypeToCostType).includes(event.type) &&
"status" in event &&
event.status === "success"
)
.map(
(event): SuccessfulModelCall => ({
Expand All @@ -38,6 +39,7 @@ export function extractSuccessfulModelCalls(
const eventTypeToCostType = {
"image-generation-finished": "image-generation" as const,
"json-generation-finished": "json-generation" as const,
"json-or-text-generation-finished": "json-or-text-generation" as const,
"text-embedding-finished": "text-embedding" as const,
"text-generation-finished": "text-generation" as const,
"text-streaming-finished": "text-streaming" as const,
Expand Down
12 changes: 6 additions & 6 deletions src/model-function/executeCall.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import { nanoid as createId } from "nanoid";
import { RunFunctionEventSource } from "../run/RunFunctionEventSource.js";
import { startDurationMeasurement } from "../util/DurationMeasurement.js";
import { AbortError } from "../util/api/AbortError.js";
import { runSafe } from "../util/runSafe.js";
Expand All @@ -10,7 +11,6 @@ import {
ModelCallStartedEvent,
ModelCallStartedEventMetadata,
} from "./ModelCallEvent.js";
import { ModelCallEventSource } from "./ModelCallEventSource.js";

export type CallMetadata<MODEL extends Model<unknown>> = {
callId: string;
Expand Down Expand Up @@ -79,7 +79,7 @@ export async function executeCall<
const run = options?.run;
const settings = model.settings;

const eventSource = new ModelCallEventSource({
const eventSource = new RunFunctionEventSource({
observers: [...(settings.observers ?? []), ...(run?.observers ?? [])],
errorHandler: run?.errorHandler,
});
Expand All @@ -96,7 +96,7 @@ export async function executeCall<
startEpochSeconds: durationMeasurement.startEpochSeconds,
};

eventSource.notifyModelCallStarted(getStartEvent(startMetadata, settings));
eventSource.notifyRunFunctionStarted(getStartEvent(startMetadata, settings));

const result = await runSafe(() =>
generateResponse({
Expand All @@ -113,13 +113,13 @@ export async function executeCall<

if (!result.ok) {
if (result.isAborted) {
eventSource.notifyModelCallFinished(
eventSource.notifyRunFunctionFinished(
getAbortEvent(finishMetadata, settings)
);
throw new AbortError();
}

eventSource.notifyModelCallFinished(
eventSource.notifyRunFunctionFinished(
getFailureEvent(finishMetadata, settings, result.error)
);
throw result.error;
Expand All @@ -128,7 +128,7 @@ export async function executeCall<
const response = result.output;
const output = extractOutputValue(response);

eventSource.notifyModelCallFinished(
eventSource.notifyRunFunctionFinished(
getSuccessEvent(finishMetadata, settings, response, output)
);

Expand Down
4 changes: 2 additions & 2 deletions src/model-function/generate-json/JsonGenerationEvent.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,14 @@ import {
} from "../ModelCallEvent.js";

export type JsonGenerationStartedEvent = {
type: "json-generation-started";
type: "json-generation-started" | "json-or-text-generation-started";
metadata: ModelCallStartedEventMetadata;
settings: unknown;
prompt: unknown;
};

export type JsonGenerationFinishedEvent = {
type: "json-generation-finished";
type: "json-generation-finished" | "json-or-text-generation-finished";
metadata: ModelCallFinishedEventMetadata;
settings: unknown;
prompt: unknown;
Expand Down
8 changes: 4 additions & 4 deletions src/model-function/generate-json/generateJsonOrText.ts
Original file line number Diff line number Diff line change
Expand Up @@ -135,28 +135,28 @@ export async function generateJsonOrText<
};
},
getStartEvent: (metadata, settings) => ({
type: "json-generation-started",
type: "json-or-text-generation-started",
metadata,
settings,
prompt,
}),
getAbortEvent: (metadata, settings) => ({
type: "json-generation-finished",
type: "json-or-text-generation-finished",
status: "abort",
metadata,
settings,
prompt,
}),
getFailureEvent: (metadata, settings, error) => ({
type: "json-generation-finished",
type: "json-or-text-generation-finished",
status: "failure",
metadata,
settings,
prompt,
error,
}),
getSuccessEvent: (metadata, settings, response, output) => ({
type: "json-generation-finished",
type: "json-or-text-generation-finished",
status: "success",
metadata,
settings,
Expand Down
Loading

1 comment on commit 399e51a

@vercel
Copy link

@vercel vercel bot commented on 399e51a Aug 12, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please sign in to comment.