Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable messages api #581

Closed
wants to merge 19 commits into from
Closed
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions packages/inference/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,10 @@ import { textGeneration } from "@huggingface/inference";

await textGeneration({
accessToken: "hf_...",
model: "model_or_endpoint",
model: "model",
inputs: ...,
parameters: ...
parameters: ...,
endpointUrl: "custom endpoint url",
})
```

Expand Down
8 changes: 4 additions & 4 deletions packages/inference/src/HfInference.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@ type TaskWithNoAccessToken = {
) => ReturnType<Task[key]>;
};

type TaskWithNoAccessTokenNoModel = {
type TaskWithNoAccessTokenNoEndpointUrl = {
[key in keyof Task]: (
args: DistributiveOmit<Parameters<Task[key]>[0], "accessToken" | "model">,
args: DistributiveOmit<Parameters<Task[key]>[0], "accessToken" | "endpointUrl">,
options?: Parameters<Task[key]>[1]
) => ReturnType<Task[key]>;
};
Expand Down Expand Up @@ -57,12 +57,12 @@ export class HfInferenceEndpoint {
enumerable: false,
value: (params: RequestArgs, options: Options) =>
// eslint-disable-next-line @typescript-eslint/no-explicit-any
fn({ ...params, accessToken, model: endpointUrl } as any, { ...defaultOptions, ...options }),
fn({ ...params, accessToken, endpointUrl } as any, { ...defaultOptions, ...options }),
});
}
}
}

export interface HfInference extends TaskWithNoAccessToken {}

export interface HfInferenceEndpoint extends TaskWithNoAccessTokenNoModel {}
export interface HfInferenceEndpoint extends TaskWithNoAccessTokenNoEndpointUrl {}
20 changes: 12 additions & 8 deletions packages/inference/src/lib/makeRequestOptions.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import type { InferenceTask, Options, RequestArgs } from "../types";
import { omit } from "../utils/omit";
import { HF_HUB_URL } from "./getDefaultTask";
import { isUrl } from "./isUrl";

Expand All @@ -24,8 +25,7 @@ export async function makeRequestOptions(
taskHint?: InferenceTask;
}
): Promise<{ url: string; info: RequestInit }> {
// eslint-disable-next-line @typescript-eslint/no-unused-vars
const { accessToken, model: _model, ...otherArgs } = args;
const { accessToken, endpointUrl, ...otherArgs } = args;
let { model } = args;
const {
forceTask: task,
Expand All @@ -34,7 +34,7 @@ export async function makeRequestOptions(
wait_for_model,
use_cache,
dont_load_model,
...otherOptions
// ...otherOptions
} = options ?? {};

const headers: Record<string, string> = {};
Expand Down Expand Up @@ -78,10 +78,16 @@ export async function makeRequestOptions(
}

const url = (() => {
if (endpointUrl && isUrl(model)) {
throw new TypeError("Both model and endpointUrl cannot be URLs");
}
if (isUrl(model)) {
console.warn("Using a model URL is deprecated, please use the `endpointUrl` parameter instead");
return model;
}

if (endpointUrl) {
return endpointUrl;
}
if (task) {
return `${HF_INFERENCE_API_BASE_URL}/pipeline/${task}/${model}`;
}
Expand All @@ -103,19 +109,17 @@ export async function makeRequestOptions(
} else if (includeCredentials === undefined) {
credentials = "same-origin";
}

const info: RequestInit = {
headers,
method: "POST",
body: binary
? args.data
: JSON.stringify({
...otherArgs,
options: options && otherOptions,
coyotte508 marked this conversation as resolved.
Show resolved Hide resolved
...(otherArgs.model && isUrl(otherArgs.model) ? omit(otherArgs, "model") : otherArgs),
...(options && Object.keys(options).length > 0 ? { options: options } : {}),
}),
credentials,
signal: options?.signal,
};

return { url, info };
}
3 changes: 3 additions & 0 deletions packages/inference/src/tasks/custom/streamingRequest.ts
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,9 @@ export async function* streamingRequest<T>(
onChunk(value);
for (const event of events) {
if (event.data.length > 0) {
if (event.data === "[DONE]") {
return;
}
const data = JSON.parse(event.data);
if (typeof data === "object" && data !== null && "error" in data) {
throw new Error(data.error);
Expand Down
23 changes: 23 additions & 0 deletions packages/inference/src/tasks/nlp/textGenerationStream.ts
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,29 @@ export interface TextGenerationStreamOutput {
* Only available when the generation is finished
*/
details: TextGenerationStreamDetails | null;
/**
* Streamed choices; only present when the endpoint is Messages API compatible.
*/
choices?: Choice[];
}

/**
 * One entry of the optional `choices` array that a stream chunk carries when the
 * endpoint is Messages API compatible.
 */
export interface Choice {
/** Index of this choice. Presumably its position among the alternatives returned for the request — TODO confirm against the endpoint spec. */
index: number;
/** Incremental payload carried by this chunk. */
delta: {
/** Role attached to this delta (free-form string here; not restricted to a fixed set of values). */
role: string;
/** Text fragment for this chunk. Optional: guard against `undefined` before concatenating fragments. */
content?: string;
/**
 * Tool-call fragment for this chunk, when present.
 *
 * NOTE(review): declared as a single object; OpenAI-style streams send `tool_calls` as an
 * array — confirm which shape the targeted endpoints actually emit.
 */
tool_calls?: {
index: number;
id: string;
type: string;
function: {
/** Name of the function being called; optional in this declaration. */
name?: string;
/** Function arguments as a string — presumably JSON-encoded and possibly partial while streaming; verify against the endpoint. */
arguments: string;
};
};
};
/** Log-probability information, left untyped (`Record<string, unknown>`). */
logprobs?: Record<string, unknown>;
/** Reason the generation stopped — presumably only set on the last chunk of a choice; TODO confirm. */
finish_reason?: string;
}

/**
Expand Down
18 changes: 16 additions & 2 deletions packages/inference/src/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -47,15 +47,29 @@ export interface BaseArgs {
*/
accessToken?: string;
/**
* The model to use. Can be a full URL for a dedicated inference endpoint.
* The model to use.
*
* If not specified, will call huggingface.co/api/tasks to get the default model for the task.
*
* /!\ Legacy behavior allows this to be a URL, but this is deprecated and will be removed in the future.
* Use the `endpointUrl` parameter instead.
*/
model?: string;

/**
* The URL of the endpoint to use. If not specified, will call huggingface.co/api/tasks to get the default endpoint for the task.
*
* If specified, will use this URL instead of the default one.
*/
endpointUrl?: string;
}

export type RequestArgs = BaseArgs &
({ data: Blob | ArrayBuffer } | { inputs: unknown }) & {
(
| { data: Blob | ArrayBuffer }
| { inputs: unknown }
| { messages?: Array<{ role: "user" | "assistant"; content: string }> }
) & {
parameters?: Record<string, unknown>;
accessToken?: string;
};
82 changes: 71 additions & 11 deletions packages/inference/test/HfInference.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -203,34 +203,31 @@ describe.concurrent(
});

it("textGenerationStream - google/flan-t5-xxl", async () => {
const phrase = "one two three four";
const response = hf.textGenerationStream({
model: "google/flan-t5-xxl",
inputs: `repeat "${phrase}"`,
inputs: "Please answer the following question: complete one two and ____.",
});

const makeExpectedReturn = (tokenText: string, fullPhrase: string): TextGenerationStreamOutput => {
const eot = tokenText === "</s>";
const eot = tokenText === "</s>" || tokenText === null;
return {
details: null,
token: {
id: expect.any(Number),
logprob: expect.any(Number),
text: expect.stringContaining(tokenText),
special: eot,
text: expect.any(String) || null,
special: expect.any(Boolean),
},
generated_text: eot ? fullPhrase : null,
};
};

const expectedTokens = phrase.split(" ");
// eot token
expectedTokens.push("</s>");
const word = "three";
const expectedTokens = [word, "</s>"];

for await (const ret of response) {
const expectedToken = expectedTokens.shift();
assert(expectedToken);
expect(ret).toMatchObject(makeExpectedReturn(expectedToken, phrase));
expect(ret).toMatchObject(makeExpectedReturn(expectedToken, word));
}
});

Expand All @@ -244,7 +241,7 @@ describe.concurrent(
});

await expect(response.next()).rejects.toThrow(
"Input validation error: `inputs` tokens + `max_new_tokens` must be <= 4096. Given: 17 `inputs` tokens and 10000 `max_new_tokens`"
"Input validation error: `inputs` tokens + `max_new_tokens` must be <= 2048. Given: 17 `inputs` tokens and 10000 `max_new_tokens`"
);
});

Expand Down Expand Up @@ -651,6 +648,69 @@ describe.concurrent(
});
expect(generated_text).toEqual("three");
});

// Streams a chat completion through a Messages API compatible endpoint and checks
// that the concatenated `delta.content` fragments contain the expected answer.
it("textGenerationStream - OpenAI Specs", async () => {
	const ep = hf.endpoint(
		"https://api-inference.huggingface.co/models/mistralai/Mistral-7B-Instruct-v0.2/v1/chat/completions"
	);
	const stream = ep.textGenerationStream({
		model: "tgi",
		messages: [{ role: "user", content: "Complete the equation 1+1= ,just the answer" }],
		parameters: {
			max_tokens: 500,
			return_full_text: false,
			temperature: 0.0,
			seed: 0,
		},
	});
	let out = "";
	for await (const chunk of stream) {
		// `choices` and `delta.content` are both optional: skip missing fragments instead of
		// appending the literal string "undefined" to the accumulated output.
		out += chunk.choices?.[0]?.delta.content ?? "";
	}
	expect(out).toContain("The answer to the equation 1 + 1 is 2.</s>");
});
// Same streaming check against the Mistral chat-completions endpoint.
// Skipped (with a warning) when MISTRAL_KEY is not set in the environment.
it("mistral - OpenAI Specs", async () => {
	const MISTRAL_KEY = env.MISTRAL_KEY;
	if (!MISTRAL_KEY) {
		console.warn("Skipping test because MISTRAL_KEY is not set");
		return;
	}
	const hf = new HfInference(MISTRAL_KEY);
	const ep = hf.endpoint("https://api.mistral.ai/v1/chat/completions");
	const stream = ep.streamingRequest({
		model: "mistral-tiny",
		messages: [{ role: "user", content: "Complete the equation one + one = , just the answer" }],
	}) as AsyncGenerator<TextGenerationStreamOutput>;
	let out = "";
	for await (const chunk of stream) {
		// `choices` and `delta.content` are both optional: skip missing fragments instead of
		// appending the literal string "undefined" to the accumulated output.
		out += chunk.choices?.[0]?.delta.content ?? "";
	}
	expect(out).toContain("The answer to the equation one + one is two.");
});
// Same streaming check against the OpenAI chat-completions endpoint.
// Skipped (with a warning) when OPENAI_KEY is not set in the environment.
it("openai - OpenAI Specs", async () => {
	const OPENAI_KEY = env.OPENAI_KEY;
	if (!OPENAI_KEY) {
		console.warn("Skipping test because OPENAI_KEY is not set");
		return;
	}
	const hf = new HfInference(OPENAI_KEY);
	const ep = hf.endpoint("https://api.openai.com/v1/chat/completions");
	const stream = ep.streamingRequest({
		model: "gpt-3.5-turbo",
		messages: [{ role: "user", content: "Complete the equation one + one =" }],
	}) as AsyncGenerator<TextGenerationStreamOutput>;
	let out = "";
	for await (const chunk of stream) {
		// `choices` and `delta.content` are both optional: skip missing fragments instead of
		// appending the literal string "undefined" to the accumulated output.
		out += chunk.choices?.[0]?.delta.content ?? "";
	}
	expect(out).toContain("two");
});
},
TIMEOUT
);
Loading