
Commit
Extend endpointOai.ts to allow usage of extra sampling parameters (huggingface#1032)

* Extend endpointOai.ts to allow usage of extra sampling parameters when calling vLLM as an OpenAI-compatible server

* refactor: prettier endpointOai.ts

* Fix: Corrected type imports in endpointOai.ts

* Simplifies code a bit and adds `extraBody` to the OpenAI endpoint

* Update zod schema to allow any type in extraBody

---------

Co-authored-by: Nathan Sarrazin <[email protected]>
taeminlee and nsarrazin authored May 8, 2024
1 parent 8bc9f67 commit 0a87cce
Showing 1 changed file with 34 additions and 25 deletions.
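
For context: this change lets an OpenAI-type endpoint forward arbitrary extra fields to an OpenAI-compatible server such as vLLM, which accepts sampling parameters (e.g. `top_k`) that the official OpenAI API does not. Below is a hedged sketch of how `extraBody` might then appear in chat-ui's `MODELS` configuration; the model name, `baseURL`, and parameter values are illustrative assumptions, not part of the commit:

```json
{
  "name": "my-vllm-model",
  "endpoints": [
    {
      "type": "openai",
      "baseURL": "http://localhost:8000/v1",
      "extraBody": { "top_k": 50, "repetition_penalty": 1.2 }
    }
  ]
}
```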
src/lib/server/endpoints/openai/endpointOai.ts (34 additions, 25 deletions)
@@ -1,6 +1,8 @@
 import { z } from "zod";
 import { openAICompletionToTextGenerationStream } from "./openAICompletionToTextGenerationStream";
 import { openAIChatToTextGenerationStream } from "./openAIChatToTextGenerationStream";
+import type { CompletionCreateParamsStreaming } from "openai/resources/completions";
+import type { ChatCompletionCreateParamsStreaming } from "openai/resources/chat/completions";
 import { buildPrompt } from "$lib/buildPrompt";
 import { env } from "$env/dynamic/private";
 import type { Endpoint } from "../endpoints";
@@ -16,12 +18,13 @@ export const endpointOAIParametersSchema = z.object({
 		.default("chat_completions"),
 	defaultHeaders: z.record(z.string()).optional(),
 	defaultQuery: z.record(z.string()).optional(),
+	extraBody: z.record(z.any()).optional(),
 });

 export async function endpointOai(
 	input: z.input<typeof endpointOAIParametersSchema>
 ): Promise<Endpoint> {
-	const { baseURL, apiKey, completion, model, defaultHeaders, defaultQuery } =
+	const { baseURL, apiKey, completion, model, defaultHeaders, defaultQuery, extraBody } =
 		endpointOAIParametersSchema.parse(input);
 	let OpenAI;
 	try {
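
Since the schema uses `z.record(z.any())`, `extraBody` validates any string-keyed object regardless of value types. A minimal standalone sketch (not from the commit) of what the new field accepts:

```ts
import { z } from "zod";

const extraBody = z.record(z.any()).optional();

// Mixed value types all validate: numbers, booleans, arrays.
extraBody.parse({ top_k: 50, use_beam_search: true, stop: ["\n"] });

// The field is optional, so omitting it entirely is also valid.
extraBody.parse(undefined);
```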
@@ -47,19 +50,22 @@ export async function endpointOai(
 			});

 			const parameters = { ...model.parameters, ...generateSettings };
+			const body: CompletionCreateParamsStreaming = {
+				model: model.id ?? model.name,
+				prompt,
+				stream: true,
+				max_tokens: parameters?.max_new_tokens,
+				stop: parameters?.stop,
+				temperature: parameters?.temperature,
+				top_p: parameters?.top_p,
+				frequency_penalty: parameters?.repetition_penalty,
+			};

-			return openAICompletionToTextGenerationStream(
-				await openai.completions.create({
-					model: model.id ?? model.name,
-					prompt,
-					stream: true,
-					max_tokens: parameters?.max_new_tokens,
-					stop: parameters?.stop,
-					temperature: parameters?.temperature,
-					top_p: parameters?.top_p,
-					frequency_penalty: parameters?.repetition_penalty,
-				})
-			);
+			const openAICompletion = await openai.completions.create(body, {
+				body: { ...body, ...extraBody },
+			});

+			return openAICompletionToTextGenerationStream(openAICompletion);
 		};
 	} else if (completion === "chat_completions") {
 		return async ({ messages, preprompt, generateSettings }) => {
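
In the completions branch above, the two-argument form of `create()` does the real work: the first argument keeps the typed, SDK-validated parameters, while the second is openai-node's per-request options object, whose `body` field replaces the JSON body actually sent. Spreading `extraBody` last lets it add to, or override, the typed fields. A minimal sketch of the pattern in isolation, with illustrative values:

```ts
// Assumes an `openai` client from the openai npm package (v4+) and a
// server that tolerates non-standard fields, e.g. vLLM's OpenAI-compatible API.
const body = {
	model: "my-model",
	messages: [{ role: "user" as const, content: "Hello" }],
	stream: true as const,
};

// The options `body` replaces the outgoing JSON body, so `top_k` rides
// along even though the SDK's types do not know about it.
const stream = await openai.chat.completions.create(body, {
	body: { ...body, top_k: 50 },
});
```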
@@ -77,19 +83,22 @@ }
 			}

 			const parameters = { ...model.parameters, ...generateSettings };
+			const body: ChatCompletionCreateParamsStreaming = {
+				model: model.id ?? model.name,
+				messages: messagesOpenAI,
+				stream: true,
+				max_tokens: parameters?.max_new_tokens,
+				stop: parameters?.stop,
+				temperature: parameters?.temperature,
+				top_p: parameters?.top_p,
+				frequency_penalty: parameters?.repetition_penalty,
+			};
+
+			const openChatAICompletion = await openai.chat.completions.create(body, {
+				body: { ...body, ...extraBody },
+			});

-			return openAIChatToTextGenerationStream(
-				await openai.chat.completions.create({
-					model: model.id ?? model.name,
-					messages: messagesOpenAI,
-					stream: true,
-					max_tokens: parameters?.max_new_tokens,
-					stop: parameters?.stop,
-					temperature: parameters?.temperature,
-					top_p: parameters?.top_p,
-					frequency_penalty: parameters?.repetition_penalty,
-				})
-			);
+			return openAIChatToTextGenerationStream(openChatAICompletion);
 		};
 	} else {
 		throw new Error("Invalid completion type");
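
Putting the chat branch together: with the hypothetical `extraBody` of `{ "top_k": 50 }` from the configuration sketch above, the merged JSON body sent to the server would look roughly like the object below. Only `top_k` comes from `extraBody`; the rest are the typed parameters, and every concrete value here is illustrative:

```ts
// Roughly what goes over the wire; vLLM's OpenAI-compatible server reads
// top_k, while the official OpenAI API defines no such parameter.
const wireBody = {
	model: "my-model",
	messages: [{ role: "user", content: "Hello" }],
	stream: true,
	max_tokens: 1024,
	temperature: 0.7,
	top_p: 0.95,
	frequency_penalty: 1.0,
	top_k: 50,
};
```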
