Add support for llama.cpp server (promptfoo#94)
typpo authored Jul 31, 2023
1 parent 8c044b5 commit 690729c
Showing 4 changed files with 125 additions and 1 deletion.
1 change: 1 addition & 0 deletions .gitignore
@@ -4,3 +4,4 @@ scratch
 .DS_Store
 node_modules
 dist
+.aider*
6 changes: 5 additions & 1 deletion src/providers.ts
@@ -4,6 +4,7 @@ import { OpenAiCompletionProvider, OpenAiChatCompletionProvider } from './provid
 import { AnthropicCompletionProvider } from './providers/anthropic';
 import { ReplicateProvider } from './providers/replicate';
 import { LocalAiCompletionProvider, LocalAiChatProvider } from './providers/localai';
+import { LlamaProvider } from './providers/llama';
 import { ScriptCompletionProvider } from './providers/scriptCompletion';
 import {
   AzureOpenAiChatCompletionProvider,
@@ -133,7 +134,10 @@ export async function loadApiProvider(
     return new ReplicateProvider(modelName, undefined, context?.config);
   }

-  if (providerPath?.startsWith('localai:')) {
+  if (providerPath === 'llama' || providerPath.startsWith('llama:')) {
+    const modelName = providerPath.split(':')[1];
+    return new LlamaProvider(modelName, context?.config);
+  } else if (providerPath?.startsWith('localai:')) {
     const options = providerPath.split(':');
     const modelType = options[1];
     const modelName = options[2];
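
For context, a minimal usage sketch of the new provider path handling (not part of the diff). The model name here is hypothetical: with this provider it is only used as a label in id()/toString(), since the actual model is whichever one the llama.cpp server was started with. It assumes the caller imports loadApiProvider from src/providers.ts.

// Sketch only: 'llama-2-7b-chat' is a made-up label, not a value promptfoo requires.
import { loadApiProvider } from './providers';

async function example() {
  const provider = await loadApiProvider('llama:llama-2-7b-chat');
  console.log(provider.id()); // "llama:llama-2-7b-chat"
}
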
97 changes: 97 additions & 0 deletions src/providers/llama.ts
@@ -0,0 +1,97 @@
+import { fetchJsonWithCache } from '../cache';
+import { REQUEST_TIMEOUT_MS } from './shared';
+
+import type { ApiProvider, ProviderResponse } from '../types.js';
+
+interface LlamaCompletionOptions {
+  n_predict?: number;
+  temperature?: number;
+  top_k?: number;
+  top_p?: number;
+  n_keep?: number;
+  stop?: string[];
+  repeat_penalty?: number;
+  repeat_last_n?: number;
+  penalize_nl?: boolean;
+  presence_penalty?: number;
+  frequency_penalty?: number;
+  mirostat?: boolean;
+  mirostat_tau?: number;
+  mirostat_eta?: number;
+  seed?: number;
+  ignore_eos?: boolean;
+  logit_bias?: Record<string, number>;
+}
+
+export class LlamaProvider implements ApiProvider {
+  modelName: string;
+  apiBaseUrl: string;
+  options?: LlamaCompletionOptions;
+
+  constructor(modelName: string, options?: LlamaCompletionOptions) {
+    this.modelName = modelName;
+    this.apiBaseUrl = 'http://localhost:8080';
+    this.options = options;
+  }
+
+  id(): string {
+    return `llama:${this.modelName}`;
+  }
+
+  toString(): string {
+    return `[Llama Provider ${this.modelName}]`;
+  }
+
+  async callApi(prompt: string, options?: LlamaCompletionOptions): Promise<ProviderResponse> {
+    options = Object.assign({}, this.options, options);
+    const body = {
+      prompt,
+      n_predict: options?.n_predict || 512,
+      temperature: options?.temperature,
+      top_k: options?.top_k,
+      top_p: options?.top_p,
+      n_keep: options?.n_keep,
+      stop: options?.stop,
+      repeat_penalty: options?.repeat_penalty,
+      repeat_last_n: options?.repeat_last_n,
+      penalize_nl: options?.penalize_nl,
+      presence_penalty: options?.presence_penalty,
+      frequency_penalty: options?.frequency_penalty,
+      mirostat: options?.mirostat,
+      mirostat_tau: options?.mirostat_tau,
+      mirostat_eta: options?.mirostat_eta,
+      seed: options?.seed,
+      ignore_eos: options?.ignore_eos,
+      logit_bias: options?.logit_bias,
+    };
+
+    let response;
+    try {
+      response = await fetchJsonWithCache(
+        `${this.apiBaseUrl}/completion`,
+        {
+          method: 'POST',
+          headers: {
+            'Content-Type': 'application/json',
+          },
+          body: JSON.stringify(body),
+        },
+        REQUEST_TIMEOUT_MS,
+      );
+    } catch (err) {
+      return {
+        error: `API call error: ${String(err)}`,
+      };
+    }
+
+    try {
+      return {
+        output: response.data.content,
+      };
+    } catch (err) {
+      return {
+        error: `API response error: ${String(err)}: ${JSON.stringify(response.data)}`,
+      };
+    }
+  }
+}
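
For reference, a usage sketch of the provider added above, outside promptfoo's provider loading. It assumes a llama.cpp server is already running at the hard-coded http://localhost:8080 base URL; the model label, prompt, and option values are illustrative. Note that options passed to callApi are merged over the constructor options via Object.assign, so per-call values win.

import { LlamaProvider } from './providers/llama';

async function run() {
  // Constructor options act as defaults for every call.
  const provider = new LlamaProvider('llama-2-7b-chat', {
    temperature: 0.2,
    n_predict: 256,
  });

  // Per-call options override the defaults (Object.assign in callApi).
  const result = await provider.callApi('Write a haiku about unit tests.', {
    n_predict: 64,
    stop: ['\n\n'],
  });

  if (result.error) {
    console.error(result.error);
  } else {
    console.log(result.output); // content field returned by the /completion endpoint
  }
}
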
22 changes: 22 additions & 0 deletions test/providers.test.ts
@@ -1,7 +1,9 @@
 import fetch from 'node-fetch';
+import { fetchJsonWithCache } from '../src/cache';
 
 import { OpenAiCompletionProvider, OpenAiChatCompletionProvider } from '../src/providers/openai';
 import { AnthropicCompletionProvider } from '../src/providers/anthropic';
+import { LlamaProvider } from '../src/providers/llama';
 
 import { disableCache, enableCache } from '../src/cache.js';
 import { loadApiProvider, loadApiProviders } from '../src/providers.js';
@@ -150,6 +152,21 @@ describe('providers', () => {
     expect(result.tokenUsage).toEqual({});
   });
 
+  test('LlamaProvider callApi', async () => {
+    const mockResponse = {
+      json: jest.fn().mockResolvedValue({
+        content: 'Test output',
+      }),
+    };
+    (fetch as unknown as jest.Mock).mockResolvedValue(mockResponse);
+
+    const provider = new LlamaProvider('llama.cpp');
+    const result = await provider.callApi('Test prompt');
+
+    expect(fetch).toHaveBeenCalledTimes(1);
+    expect(result.output).toBe('Test output');
+  });
+
   test('loadApiProvider with openai:chat', async () => {
     const provider = await loadApiProvider('openai:chat');
     expect(provider).toBeInstanceOf(OpenAiChatCompletionProvider);
@@ -190,6 +207,11 @@ describe('providers', () => {
     expect(provider).toBeInstanceOf(AnthropicCompletionProvider);
   });
 
+  test('loadApiProvider with llama:modelName', async () => {
+    const provider = await loadApiProvider('llama');
+    expect(provider).toBeInstanceOf(LlamaProvider);
+  });
+
   test('loadApiProvider with RawProviderConfig', async () => {
     const rawProviderConfig = {
       'openai:chat': {
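
The callApi test above mocks fetch to return { content: 'Test output' }, which matches the response shape the provider reads (response.data.content). For reference, a rough sketch of the equivalent raw request the provider issues, minus caching and the REQUEST_TIMEOUT_MS handling; the prompt is illustrative and the endpoint/base URL come from the provider code above.

import fetch from 'node-fetch';

async function rawCompletion(prompt: string): Promise<string> {
  // Same endpoint and body shape as LlamaProvider.callApi (n_predict defaults to 512 there).
  const res = await fetch('http://localhost:8080/completion', {
    method: 'POST',
    headers: { 'Content-Type': 'application/json' },
    body: JSON.stringify({ prompt, n_predict: 512 }),
  });
  const data = (await res.json()) as { content: string };
  return data.content;
}
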
