From 3f3208dd2c92f36ad0597fb6f7baf39e8852a30f Mon Sep 17 00:00:00 2001 From: Ian Webster Date: Sun, 19 Nov 2023 13:20:21 -0800 Subject: [PATCH] feat: Support for OpenAI assistants API (#283) --- package-lock.json | 41 +++++++++ package.json | 1 + src/cache.ts | 3 +- src/providers.ts | 4 + src/providers/openai.ts | 182 ++++++++++++++++++++++++++++++++++++++++ test/providers.test.ts | 11 ++- 6 files changed, 240 insertions(+), 2 deletions(-) diff --git a/package-lock.json b/package-lock.json index d4ea59401273..acf117e720f1 100644 --- a/package-lock.json +++ b/package-lock.json @@ -30,6 +30,7 @@ "js-yaml": "^4.1.0", "node-fetch": "^2.6.7", "nunjucks": "^3.2.4", + "openai": "^4.19.0", "opener": "^1.5.2", "replicate": "^0.12.3", "rouge": "^1.0.3", @@ -4992,6 +4993,41 @@ "url": "https://github.com/sponsors/sindresorhus" } }, + "node_modules/openai": { + "version": "4.19.0", + "resolved": "https://registry.npmjs.org/openai/-/openai-4.19.0.tgz", + "integrity": "sha512-cJbl0noZyAaXVKBTMMq6X5BAvP1pm2rWYDBnZes99NL+Zh5/4NmlAwyuhTZEru5SqGGZIoiYKeMPXy4bm9DI0w==", + "dependencies": { + "@types/node": "^18.11.18", + "@types/node-fetch": "^2.6.4", + "abort-controller": "^3.0.0", + "agentkeepalive": "^4.2.1", + "digest-fetch": "^1.3.0", + "form-data-encoder": "1.7.2", + "formdata-node": "^4.3.2", + "node-fetch": "^2.6.7", + "web-streams-polyfill": "^3.2.1" + }, + "bin": { + "openai": "bin/cli" + } + }, + "node_modules/openai/node_modules/@types/node": { + "version": "18.18.10", + "resolved": "https://registry.npmjs.org/@types/node/-/node-18.18.10.tgz", + "integrity": "sha512-luANqZxPmjTll8bduz4ACs/lNTCLuWssCyjqTY9yLdsv1xnViQp3ISKwsEWOIecO13JWUqjVdig/Vjjc09o8uA==", + "dependencies": { + "undici-types": "~5.26.4" + } + }, + "node_modules/openai/node_modules/web-streams-polyfill": { + "version": "3.2.1", + "resolved": "https://registry.npmjs.org/web-streams-polyfill/-/web-streams-polyfill-3.2.1.tgz", + "integrity": "sha512-e0MO3wdXWKrLbL0DgGnUV7WHVuw9OUvL4hjgnPkIeEvESk74gAITi5G606JtZPp39cd8HA9VQzCIvA49LpPN5Q==", + "engines": { + "node": ">= 8" + } + }, "node_modules/opener": { "version": "1.5.2", "resolved": "https://registry.npmjs.org/opener/-/opener-1.5.2.tgz", @@ -6252,6 +6288,11 @@ "node": ">=14.17" } }, + "node_modules/undici-types": { + "version": "5.26.5", + "resolved": "https://registry.npmjs.org/undici-types/-/undici-types-5.26.5.tgz", + "integrity": "sha512-JlCMO+ehdEIKqlFxk6IfVoAUVmgz7cU7zD/h9XZ0qzeosSHmUJVOzSQvvYSYWXkFXC+IfLKSIffhv0sVZup6pA==" + }, "node_modules/unpipe": { "version": "1.0.0", "resolved": "https://registry.npmjs.org/unpipe/-/unpipe-1.0.0.tgz", diff --git a/package.json b/package.json index b282fb4b5c5a..85566d6af503 100644 --- a/package.json +++ b/package.json @@ -86,6 +86,7 @@ "js-yaml": "^4.1.0", "node-fetch": "^2.6.7", "nunjucks": "^3.2.4", + "openai": "^4.19.0", "opener": "^1.5.2", "replicate": "^0.12.3", "rouge": "^1.0.3", diff --git a/src/cache.ts b/src/cache.ts index 46fad614c54b..be0af0c6f095 100644 --- a/src/cache.ts +++ b/src/cache.ts @@ -50,8 +50,9 @@ export async function fetchWithCache( options: RequestInit = {}, timeout: number, format: 'json' | 'text' = 'json', + bust: boolean = false, ): Promise<{ data: any; cached: boolean }> { - if (!enabled) { + if (!enabled || bust) { const resp = await fetchWithRetries(url, options, timeout); return { cached: false, diff --git a/src/providers.ts b/src/providers.ts index 480c9a248a35..5bdd539af69a 100644 --- a/src/providers.ts +++ b/src/providers.ts @@ -1,6 +1,7 @@ import path from 'path'; import { + OpenAiAssistantProvider, OpenAiCompletionProvider, OpenAiChatCompletionProvider, OpenAiEmbeddingProvider, @@ -127,6 +128,8 @@ export async function loadApiProvider( return new OpenAiChatCompletionProvider(modelType, providerOptions); } else if (OpenAiCompletionProvider.OPENAI_COMPLETION_MODELS.includes(modelType)) { return new OpenAiCompletionProvider(modelType, providerOptions); + } else if (modelType === 'assistant') { + return new OpenAiAssistantProvider(modelName, providerOptions); } else { throw new Error( `Unknown OpenAI model type: ${modelType}. Use one of the following providers: openai:chat:, openai:completion:`, @@ -237,6 +240,7 @@ export async function loadApiProvider( export default { OpenAiCompletionProvider, OpenAiChatCompletionProvider, + OpenAiAssistantProvider, AnthropicCompletionProvider, ReplicateProvider, LocalAiCompletionProvider, diff --git a/src/providers/openai.ts b/src/providers/openai.ts index 3801d4724a8d..0accd9e709fa 100644 --- a/src/providers/openai.ts +++ b/src/providers/openai.ts @@ -1,3 +1,5 @@ +import OpenAI from 'openai'; + import logger from '../logger'; import { fetchWithCache } from '../cache'; import { REQUEST_TIMEOUT_MS, parseChatPrompt } from './shared'; @@ -32,6 +34,17 @@ interface OpenAiCompletionOptions { organization?: string; } +function failApiCall(err: any) { + if (err instanceof OpenAI.APIError) { + return { + error: `API error: ${err.type} ${err.message}`, + }; + } + return { + error: `API error: ${String(err)}`, + }; +} + class OpenAiGenericProvider implements ApiProvider { modelName: string; @@ -342,6 +355,175 @@ export class OpenAiChatCompletionProvider extends OpenAiGenericProvider { } } +interface AssistantMessagesResponseDataContent { + type: string; + text?: { + value: string; + }; +} + +interface AssistantMessagesResponseData { + data: { + role: string; + content?: AssistantMessagesResponseDataContent[]; + }[]; +} + +interface OpenAiAssistantOptions { + modelName?: string; + instructions?: string; + tools?: OpenAI.Beta.Threads.ThreadCreateAndRunParams['tools']; + metadata?: object[]; +} + +function toTitleCase(str: string) { + return str.replace(/\w\S*/g, (txt) => txt.charAt(0).toUpperCase() + txt.substr(1).toLowerCase()); +} + +export class OpenAiAssistantProvider extends OpenAiGenericProvider { + assistantId: string; + assistantConfig: OpenAiAssistantOptions; + + constructor( + assistantId: string, + options: { config?: OpenAiAssistantOptions; id?: string; env?: EnvOverrides } = {}, + ) { + super(assistantId, {}); + this.assistantConfig = options.config || {}; + this.assistantId = assistantId; + } + + async callApi(prompt: string): Promise { + if (!this.getApiKey()) { + throw new Error( + 'OpenAI API key is not set. Set the OPENAI_API_KEY environment variable or add `apiKey` to the provider config.', + ); + } + + const openai = new OpenAI({ + maxRetries: 3, + timeout: REQUEST_TIMEOUT_MS, + }); + + const messages = parseChatPrompt(prompt, [ + { role: 'user', content: prompt }, + ]) as OpenAI.Beta.Threads.ThreadCreateParams.Message[]; + const body: OpenAI.Beta.Threads.ThreadCreateAndRunParams = { + assistant_id: this.assistantId, + model: this.assistantConfig.modelName || undefined, + instructions: this.assistantConfig.instructions || undefined, + tools: this.assistantConfig.tools || undefined, + metadata: this.assistantConfig.metadata || undefined, + thread: { + messages, + }, + }; + + logger.debug(`Calling OpenAI API, creating thread run: ${JSON.stringify(body)}`); + let run; + try { + run = await openai.beta.threads.createAndRun(body); + } catch (err) { + return failApiCall(err); + } + + logger.debug(`\tOpenAI thread run API response: ${JSON.stringify(run)}`); + + while (run.status === 'in_progress' || run.status === 'queued') { + await new Promise((resolve) => setTimeout(resolve, 1000)); + + logger.debug(`Calling OpenAI API, getting thread run ${run.id} status`); + try { + run = await openai.beta.threads.runs.retrieve(run.thread_id, run.id); + } catch (err) { + return failApiCall(err); + } + logger.debug(`\tOpenAI thread run API response: ${JSON.stringify(run)}`); + } + + if (run.status !== 'completed') { + if (run.last_error) { + return { + error: `Thread run failed: ${run.last_error.message}`, + }; + } + return { + error: `Thread run failed: ${run.status}`, + }; + } + + // Get run steps + logger.debug(`Calling OpenAI API, getting thread run steps for ${run.thread_id}`); + let steps; + try { + steps = await openai.beta.threads.runs.steps.list(run.thread_id, run.id, { + order: 'asc', + }); + } catch (err) { + return failApiCall(err); + } + logger.debug(`\tOpenAI thread run steps API response: ${JSON.stringify(steps)}`); + + const outputBlocks = []; + for (const step of steps.data) { + if (step.step_details.type === 'message_creation') { + logger.debug(`Calling OpenAI API, getting message ${step.id}`); + let message; + try { + message = await openai.beta.threads.messages.retrieve( + run.thread_id, + step.step_details.message_creation.message_id, + ); + } catch (err) { + return failApiCall(err); + } + logger.debug(`\tOpenAI thread run step message API response: ${JSON.stringify(message)}`); + + const content = message.content + .map((content) => + content.type === 'text' ? content.text.value : `<${content.type} output>`, + ) + .join('\n'); + outputBlocks.push(`[${toTitleCase(message.role)}] ${content}`); + } else if (step.step_details.type === 'tool_calls') { + for (const toolCall of step.step_details.tool_calls) { + if (toolCall.type === 'function') { + outputBlocks.push( + `[Call function ${toolCall.function.name} with arguments ${toolCall.function.arguments}]`, + ); + outputBlocks.push(`[Function output: ${toolCall.function.output}]`); + } else if (toolCall.type === 'retrieval') { + outputBlocks.push(`[Ran retrieval]`); + } else if (toolCall.type === 'code_interpreter') { + const output = toolCall.code_interpreter.outputs + .map((output) => (output.type === 'logs' ? output.logs : `<${output.type} output>`)) + .join('\n'); + outputBlocks.push(`[Code interpreter input]`); + outputBlocks.push(toolCall.code_interpreter.input); + outputBlocks.push(`[Code interpreter output]`); + outputBlocks.push(output); + } else { + outputBlocks.push(`[Unknown tool call type: ${(toolCall as any).type}]`); + } + } + } else { + outputBlocks.push(`[Unknown step type: ${(step.step_details as any).type}]`); + } + } + + return { + output: outputBlocks.join('\n\n').trim(), + /* + tokenUsage: { + total: data.usage.total_tokens, + prompt: data.usage.prompt_tokens, + completion: data.usage.completion_tokens, + }, + */ + }; + } +} + export const DefaultEmbeddingProvider = new OpenAiEmbeddingProvider('text-embedding-ada-002'); export const DefaultGradingProvider = new OpenAiChatCompletionProvider('gpt-4-0613'); export const DefaultSuggestionsProvider = new OpenAiChatCompletionProvider('gpt-4-0613'); diff --git a/test/providers.test.ts b/test/providers.test.ts index 3322ecf8388a..64774b6e0245 100644 --- a/test/providers.test.ts +++ b/test/providers.test.ts @@ -1,6 +1,10 @@ import fetch from 'node-fetch'; -import { OpenAiCompletionProvider, OpenAiChatCompletionProvider } from '../src/providers/openai'; +import { + OpenAiAssistantProvider, + OpenAiCompletionProvider, + OpenAiChatCompletionProvider, +} from '../src/providers/openai'; import { AnthropicCompletionProvider } from '../src/providers/anthropic'; import { LlamaProvider } from '../src/providers/llama'; @@ -374,6 +378,11 @@ describe('providers', () => { expect(provider).toBeInstanceOf(OpenAiCompletionProvider); }); + test('loadApiProvider with openai:assistant', async () => { + const provider = await loadApiProvider('openai:assistant:foobar'); + expect(provider).toBeInstanceOf(OpenAiAssistantProvider); + }); + test('loadApiProvider with openai:chat:modelName', async () => { const provider = await loadApiProvider('openai:chat:gpt-3.5-turbo'); expect(provider).toBeInstanceOf(OpenAiChatCompletionProvider);