Skip to content

Commit

Permalink
feat: Support for OpenAI assistants API (promptfoo#283)
Browse files Browse the repository at this point in the history
  • Loading branch information
typpo authored Nov 19, 2023
1 parent d0f3d6c commit 3f3208d
Show file tree
Hide file tree
Showing 6 changed files with 240 additions and 2 deletions.
41 changes: 41 additions & 0 deletions package-lock.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions package.json
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@
"js-yaml": "^4.1.0",
"node-fetch": "^2.6.7",
"nunjucks": "^3.2.4",
"openai": "^4.19.0",
"opener": "^1.5.2",
"replicate": "^0.12.3",
"rouge": "^1.0.3",
Expand Down
3 changes: 2 additions & 1 deletion src/cache.ts
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,9 @@ export async function fetchWithCache(
options: RequestInit = {},
timeout: number,
format: 'json' | 'text' = 'json',
bust: boolean = false,
): Promise<{ data: any; cached: boolean }> {
if (!enabled) {
if (!enabled || bust) {
const resp = await fetchWithRetries(url, options, timeout);
return {
cached: false,
Expand Down
4 changes: 4 additions & 0 deletions src/providers.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import path from 'path';

import {
OpenAiAssistantProvider,
OpenAiCompletionProvider,
OpenAiChatCompletionProvider,
OpenAiEmbeddingProvider,
Expand Down Expand Up @@ -127,6 +128,8 @@ export async function loadApiProvider(
return new OpenAiChatCompletionProvider(modelType, providerOptions);
} else if (OpenAiCompletionProvider.OPENAI_COMPLETION_MODELS.includes(modelType)) {
return new OpenAiCompletionProvider(modelType, providerOptions);
} else if (modelType === 'assistant') {
return new OpenAiAssistantProvider(modelName, providerOptions);
} else {
throw new Error(
`Unknown OpenAI model type: ${modelType}. Use one of the following providers: openai:chat:<model name>, openai:completion:<model name>`,
Expand Down Expand Up @@ -237,6 +240,7 @@ export async function loadApiProvider(
export default {
OpenAiCompletionProvider,
OpenAiChatCompletionProvider,
OpenAiAssistantProvider,
AnthropicCompletionProvider,
ReplicateProvider,
LocalAiCompletionProvider,
Expand Down
182 changes: 182 additions & 0 deletions src/providers/openai.ts
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import OpenAI from 'openai';

import logger from '../logger';
import { fetchWithCache } from '../cache';
import { REQUEST_TIMEOUT_MS, parseChatPrompt } from './shared';
Expand Down Expand Up @@ -32,6 +34,17 @@ interface OpenAiCompletionOptions {
organization?: string;
}

/**
 * Convert an exception thrown by the OpenAI SDK into a provider-style
 * error result instead of letting it propagate.
 *
 * SDK `APIError`s are formatted with their type and message; anything
 * else is stringified as-is.
 */
function failApiCall(err: any) {
  const detail =
    err instanceof OpenAI.APIError ? `${err.type} ${err.message}` : String(err);
  return {
    error: `API error: ${detail}`,
  };
}

class OpenAiGenericProvider implements ApiProvider {
modelName: string;

Expand Down Expand Up @@ -342,6 +355,175 @@ export class OpenAiChatCompletionProvider extends OpenAiGenericProvider {
}
}

// One content element of an assistant message. Only `text`-typed entries
// carry a `value`; other content types omit it.
// NOTE(review): presumably mirrors the OpenAI assistants message schema —
// confirm against the SDK types before extending.
interface AssistantMessagesResponseDataContent {
  type: string;
  text?: {
    value: string;
  };
}

// Shape of a list-messages response: an array of messages, each with a
// role and optional content blocks.
// NOTE(review): not referenced anywhere in this diff — may be used
// elsewhere in the file, or dead; verify before removing.
interface AssistantMessagesResponseData {
  data: {
    role: string;
    content?: AssistantMessagesResponseDataContent[];
  }[];
}

// Provider config for an OpenAI assistant. Each field, when set, overrides
// the corresponding value baked into the assistant (see `callApi`, where
// they are passed to `threads.createAndRun`).
interface OpenAiAssistantOptions {
  // Model override, mapped to the `model` request field.
  modelName?: string;
  // System-instructions override for the run.
  instructions?: string;
  // Tool set override; typed directly off the SDK request params.
  tools?: OpenAI.Beta.Threads.ThreadCreateAndRunParams['tools'];
  // NOTE(review): typed as an array, but OpenAI metadata is documented as a
  // key/value map (up to 16 pairs) — confirm and consider `Record<string, string>`.
  metadata?: object[];
}

/**
 * Title-case a string: the first character of each whitespace-delimited
 * word is upper-cased and the rest of the word lower-cased.
 *
 * Used to render message roles (e.g. `assistant` -> `Assistant`) in the
 * provider's text output.
 *
 * @param str - Input string; may be empty.
 * @returns The title-cased string.
 */
function toTitleCase(str: string) {
  // `slice(1)` replaces the deprecated `substr(1)`; behavior is identical here.
  return str.replace(/\w\S*/g, (txt) => txt.charAt(0).toUpperCase() + txt.slice(1).toLowerCase());
}

/**
 * Provider that runs a prompt against a pre-created OpenAI assistant
 * (beta Assistants API): it creates a thread+run, polls until the run
 * finishes, then flattens the run's steps (messages, tool calls, code
 * interpreter activity) into a single text output.
 */
export class OpenAiAssistantProvider extends OpenAiGenericProvider {
  // ID of an assistant created ahead of time in the OpenAI dashboard/API.
  assistantId: string;
  // Per-run overrides (model, instructions, tools, metadata).
  assistantConfig: OpenAiAssistantOptions;

  /**
   * @param assistantId - The OpenAI assistant ID to run prompts against.
   * @param options - Provider options; only `config` is consumed here.
   *   NOTE(review): `options.id` and `options.env` are accepted but not
   *   forwarded (`super` is called with `{}`), so env overrides are
   *   dropped — confirm whether this is intentional.
   */
  constructor(
    assistantId: string,
    options: { config?: OpenAiAssistantOptions; id?: string; env?: EnvOverrides } = {},
  ) {
    super(assistantId, {});
    this.assistantConfig = options.config || {};
    this.assistantId = assistantId;
  }

  /**
   * Run `prompt` through the assistant and return the transcript of the
   * resulting run as a single string.
   *
   * @param prompt - Either plain text (sent as a single user message) or a
   *   chat-formatted prompt parsed by `parseChatPrompt`.
   * @returns A `ProviderResponse` with `output` on success or `error` on
   *   any API/run failure.
   * @throws If no API key is configured.
   */
  async callApi(prompt: string): Promise<ProviderResponse> {
    if (!this.getApiKey()) {
      throw new Error(
        'OpenAI API key is not set. Set the OPENAI_API_KEY environment variable or add `apiKey` to the provider config.',
      );
    }

    // NOTE(review): no `apiKey` is passed here, so the SDK presumably
    // falls back to the OPENAI_API_KEY env var — a config-provided
    // `apiKey` (checked above via getApiKey()) would be ignored; confirm.
    const openai = new OpenAI({
      maxRetries: 3,
      timeout: REQUEST_TIMEOUT_MS,
    });

    // Accept either a chat-formatted prompt or raw text (wrapped as one
    // user message).
    const messages = parseChatPrompt(prompt, [
      { role: 'user', content: prompt },
    ]) as OpenAI.Beta.Threads.ThreadCreateParams.Message[];
    // Unset config fields are sent as `undefined` so the assistant's own
    // stored settings apply.
    const body: OpenAI.Beta.Threads.ThreadCreateAndRunParams = {
      assistant_id: this.assistantId,
      model: this.assistantConfig.modelName || undefined,
      instructions: this.assistantConfig.instructions || undefined,
      tools: this.assistantConfig.tools || undefined,
      metadata: this.assistantConfig.metadata || undefined,
      thread: {
        messages,
      },
    };

    logger.debug(`Calling OpenAI API, creating thread run: ${JSON.stringify(body)}`);
    let run;
    try {
      run = await openai.beta.threads.createAndRun(body);
    } catch (err) {
      return failApiCall(err);
    }

    logger.debug(`\tOpenAI thread run API response: ${JSON.stringify(run)}`);

    // Poll once a second until the run leaves its transient states.
    // NOTE(review): no upper bound on iterations — a run stuck in
    // `queued`/`in_progress` would poll forever; consider a deadline.
    while (run.status === 'in_progress' || run.status === 'queued') {
      await new Promise((resolve) => setTimeout(resolve, 1000));

      logger.debug(`Calling OpenAI API, getting thread run ${run.id} status`);
      try {
        run = await openai.beta.threads.runs.retrieve(run.thread_id, run.id);
      } catch (err) {
        return failApiCall(err);
      }
      logger.debug(`\tOpenAI thread run API response: ${JSON.stringify(run)}`);
    }

    // Any terminal state other than `completed` (failed, cancelled,
    // expired, requires_action, ...) is reported as an error.
    if (run.status !== 'completed') {
      if (run.last_error) {
        return {
          error: `Thread run failed: ${run.last_error.message}`,
        };
      }
      return {
        error: `Thread run failed: ${run.status}`,
      };
    }

    // Get run steps (oldest first, so the transcript reads in order).
    logger.debug(`Calling OpenAI API, getting thread run steps for ${run.thread_id}`);
    let steps;
    try {
      steps = await openai.beta.threads.runs.steps.list(run.thread_id, run.id, {
        order: 'asc',
      });
    } catch (err) {
      return failApiCall(err);
    }
    logger.debug(`\tOpenAI thread run steps API response: ${JSON.stringify(steps)}`);

    // Render each step as one or more text blocks; blocks are joined with
    // blank lines at the end.
    const outputBlocks = [];
    for (const step of steps.data) {
      if (step.step_details.type === 'message_creation') {
        // Message steps only reference a message ID; fetch the body.
        logger.debug(`Calling OpenAI API, getting message ${step.id}`);
        let message;
        try {
          message = await openai.beta.threads.messages.retrieve(
            run.thread_id,
            step.step_details.message_creation.message_id,
          );
        } catch (err) {
          return failApiCall(err);
        }
        logger.debug(`\tOpenAI thread run step message API response: ${JSON.stringify(message)}`);

        // Non-text content (e.g. images) is rendered as a placeholder tag.
        const content = message.content
          .map((content) =>
            content.type === 'text' ? content.text.value : `<${content.type} output>`,
          )
          .join('\n');
        outputBlocks.push(`[${toTitleCase(message.role)}] ${content}`);
      } else if (step.step_details.type === 'tool_calls') {
        for (const toolCall of step.step_details.tool_calls) {
          if (toolCall.type === 'function') {
            outputBlocks.push(
              `[Call function ${toolCall.function.name} with arguments ${toolCall.function.arguments}]`,
            );
            outputBlocks.push(`[Function output: ${toolCall.function.output}]`);
          } else if (toolCall.type === 'retrieval') {
            outputBlocks.push(`[Ran retrieval]`);
          } else if (toolCall.type === 'code_interpreter') {
            // Only `logs` outputs have inline text; others get a placeholder.
            const output = toolCall.code_interpreter.outputs
              .map((output) => (output.type === 'logs' ? output.logs : `<${output.type} output>`))
              .join('\n');
            outputBlocks.push(`[Code interpreter input]`);
            outputBlocks.push(toolCall.code_interpreter.input);
            outputBlocks.push(`[Code interpreter output]`);
            outputBlocks.push(output);
          } else {
            // Future-proofing: unknown tool call types are surfaced, not dropped.
            outputBlocks.push(`[Unknown tool call type: ${(toolCall as any).type}]`);
          }
        }
      } else {
        // Future-proofing: unknown step types are surfaced, not dropped.
        outputBlocks.push(`[Unknown step type: ${(step.step_details as any).type}]`);
      }
    }

    return {
      output: outputBlocks.join('\n\n').trim(),
      // Token accounting left disabled by the original author — the run
      // object here does not expose a `data.usage` field to read from.
      /*
      tokenUsage: {
        total: data.usage.total_tokens,
        prompt: data.usage.prompt_tokens,
        completion: data.usage.completion_tokens,
      },
      */
    };
  }
}

export const DefaultEmbeddingProvider = new OpenAiEmbeddingProvider('text-embedding-ada-002');
export const DefaultGradingProvider = new OpenAiChatCompletionProvider('gpt-4-0613');
export const DefaultSuggestionsProvider = new OpenAiChatCompletionProvider('gpt-4-0613');
11 changes: 10 additions & 1 deletion test/providers.test.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
import fetch from 'node-fetch';

import { OpenAiCompletionProvider, OpenAiChatCompletionProvider } from '../src/providers/openai';
import {
OpenAiAssistantProvider,
OpenAiCompletionProvider,
OpenAiChatCompletionProvider,
} from '../src/providers/openai';
import { AnthropicCompletionProvider } from '../src/providers/anthropic';
import { LlamaProvider } from '../src/providers/llama';

Expand Down Expand Up @@ -374,6 +378,11 @@ describe('providers', () => {
expect(provider).toBeInstanceOf(OpenAiCompletionProvider);
});

test('loadApiProvider with openai:assistant', async () => {
const provider = await loadApiProvider('openai:assistant:foobar');
expect(provider).toBeInstanceOf(OpenAiAssistantProvider);
});

test('loadApiProvider with openai:chat:modelName', async () => {
const provider = await loadApiProvider('openai:chat:gpt-3.5-turbo');
expect(provider).toBeInstanceOf(OpenAiChatCompletionProvider);
Expand Down

0 comments on commit 3f3208d

Please sign in to comment.