(WIP) Dynamic inference provider mapping #1173

Merged: 25 commits, Feb 6, 2025
Changes from 9 commits
1 change: 1 addition & 0 deletions packages/hub/src/lib/list-models.ts
@@ -25,6 +25,7 @@ export const MODEL_EXPANDABLE_KEYS = [
"downloadsAllTime",
"gated",
"gitalyUid",
"inferenceProviderMapping",
"lastModified",
"library_name",
"likes",
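Side note: once this key lands, consumers can expand it when listing models. A minimal sketch, assuming listModels accepts the new key through additionalFields (the logged shape is illustrative):

import { listModels } from "@huggingface/hub";

// Expand the new field for each listed model; break after the first
// result just to inspect the shape.
for await (const model of listModels({ additionalFields: ["inferenceProviderMapping"] })) {
  console.log(model.name, model.inferenceProviderMapping);
  break;
}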
5 changes: 4 additions & 1 deletion packages/hub/src/types/api/api-model.ts
@@ -1,4 +1,4 @@
- import type { ModelLibraryKey, TransformersInfo } from "@huggingface/tasks";
+ import type { ModelLibraryKey, TransformersInfo, WidgetType } from "@huggingface/tasks";
import type { License, PipelineType } from "../public";

export interface ApiModelInfo {
@@ -18,6 +18,9 @@ export interface ApiModelInfo {
downloadsAllTime: number;
files: string[];
gitalyUid: string;
+ inferenceProviderMapping: Partial<
+   Record<string, { providerId: string; status: "prod" | "staging"; task: WidgetType }>
+ >;
lastAuthor: { email: string; user?: string };
lastModified: string; // convert to date
library_name?: ModelLibraryKey;
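For reference, a hedged sketch of reading the new field through modelInfo, mirroring the call this PR adds in makeRequestOptions.ts (the model name and the example response are illustrative, not real data):

import { modelInfo } from "@huggingface/hub";

const info = await modelInfo({
  name: "meta-llama/Llama-3.1-8B-Instruct", // hypothetical model
  additionalFields: ["inferenceProviderMapping"],
});

// Per the typing above, something like:
// { together: { providerId: "...", status: "prod", task: "conversational" } }
console.log(info.inferenceProviderMapping);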
1 change: 1 addition & 0 deletions packages/inference/package.json
@@ -52,6 +52,7 @@
"check": "tsc"
},
"dependencies": {
"@huggingface/hub": "workspace:^",
"@huggingface/tasks": "workspace:^"
},
"devDependencies": {
3 changes: 3 additions & 0 deletions packages/inference/pnpm-lock.yaml

Some generated files are not rendered by default.

40 changes: 31 additions & 9 deletions packages/inference/src/lib/makeRequestOptions.ts
@@ -1,3 +1,4 @@
+ import { modelInfo } from "@huggingface/hub";
import type { WidgetType } from "@huggingface/tasks";
import { HF_HUB_URL } from "../config";
import { FAL_AI_API_BASE_URL, FAL_AI_SUPPORTED_MODEL_IDS } from "../providers/fal-ai";
@@ -53,13 +54,13 @@ export async function makeRequestOptions(
let model: string;
if (!maybeModel) {
if (taskHint) {
- model = mapModel({ model: await loadDefaultModel(taskHint), provider, taskHint, chatCompletion });
+ model = await mapModel({ model: await loadDefaultModel(taskHint), provider, taskHint, chatCompletion });
} else {
throw new Error("No model provided, and no default model found for this task");
/// TODO : change error message ^
}
} else {
- model = mapModel({ model: maybeModel, provider, taskHint, chatCompletion });
+ model = await mapModel({ model: maybeModel, provider, taskHint, chatCompletion });
}

/// If accessToken is passed, it should take precedence over includeCredentials
@@ -153,12 +154,12 @@
return { url, info };
}

- function mapModel(params: {
+ async function mapModel(params: {
model: string;
provider: InferenceProvider;
taskHint: InferenceTask | undefined;
chatCompletion: boolean | undefined;
- }): string {
+ }): Promise<string> {
if (params.provider === "hf-inference") {
return params.model;
}
@@ -167,7 +168,29 @@ function mapModel(params: {
}
const task: WidgetType =
params.taskHint === "text-generation" && params.chatCompletion ? "conversational" : params.taskHint;
- const model = (() => {
+
+ // TODO: cache this call
+ const info = await modelInfo({ name: params.model, additionalFields: ["inferenceProviderMapping"] });
+
+ const inferenceProviderMapping = info.inferenceProviderMapping[params.provider];
+ // If provider listed => takes precedence over hard-coded mapping
+ if (inferenceProviderMapping) {
+   if (inferenceProviderMapping.task !== task) {
+     throw new Error(
+       `Model ${params.model} is not supported for task ${task} and provider ${params.provider}. Supported task: ${inferenceProviderMapping.task}.`
+     );
+   }
+   if (inferenceProviderMapping.status === "staging") {
+     console.warn(
+       `Model ${params.model} is in staging for provider ${params.provider}. Use it only for test purposes.`
+     );
+   }
+   // TODO: how is it handled server-side if model has multiple tasks (e.g. `text-generation` + `conversational`)?
+   return inferenceProviderMapping.providerId;
+ }
+
+ // Otherwise, default to hard-coded mapping
+ const modelFromMapping = (() => {
switch (params.provider) {
case "fal-ai":
return FAL_AI_SUPPORTED_MODEL_IDS[task]?.[params.model];
@@ -179,11 +202,10 @@
return TOGETHER_SUPPORTED_MODEL_IDS[task]?.[params.model];
}
})();

- if (!model) {
-   throw new Error(`Model ${params.model} is not supported for task ${task} and provider ${params.provider}`);
+ if (!modelFromMapping) {
+   throw new Error(`Model ${params.model} is not supported for task ${task} and provider ${params.provider}.`);
  }
- return model;
+ return modelFromMapping;
}

function makeUrl(params: {
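On the `// TODO: cache this call` above: a minimal memoization sketch, not part of this PR. The helper name, Map-based cache, and absence of TTL/eviction are all assumptions:

import { modelInfo } from "@huggingface/hub";
import type { WidgetType } from "@huggingface/tasks";

type InferenceProviderMapping = Partial<
  Record<string, { providerId: string; status: "prod" | "staging"; task: WidgetType }>
>;

// Hypothetical cache keyed by model name. Storing the in-flight Promise
// means concurrent lookups for the same model share one Hub request;
// entries never expire in this sketch.
const mappingCache = new Map<string, Promise<InferenceProviderMapping>>();

function getInferenceProviderMapping(model: string): Promise<InferenceProviderMapping> {
  let cached = mappingCache.get(model);
  if (!cached) {
    cached = modelInfo({ name: model, additionalFields: ["inferenceProviderMapping"] }).then(
      (info) => info.inferenceProviderMapping
    );
    mappingCache.set(model, cached);
  }
  return cached;
}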
2 changes: 1 addition & 1 deletion packages/inference/src/providers/sambanova.ts
@@ -7,7 +7,7 @@ type SambanovaId = string;
export const SAMBANOVA_SUPPORTED_MODEL_IDS: ProviderMapping<SambanovaId> = {
/** Chat completion / conversational */
conversational: {
"allenai/Llama-3.1-Tulu-3-405B":"Llama-3.1-Tulu-3-405B",
"allenai/Llama-3.1-Tulu-3-405B": "Llama-3.1-Tulu-3-405B",
"deepseek-ai/DeepSeek-R1-Distill-Llama-70B": "DeepSeek-R1-Distill-Llama-70B",
"Qwen/Qwen2.5-Coder-32B-Instruct": "Qwen2.5-Coder-32B-Instruct",
"Qwen/Qwen2.5-72B-Instruct": "Qwen2.5-72B-Instruct",
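End to end, a call like the following exercises the new resolution order: the Hub-provided inferenceProviderMapping is consulted first, with hard-coded tables such as SAMBANOVA_SUPPORTED_MODEL_IDS as the fallback. A hedged sketch; client setup may differ from the API at this commit:

import { HfInference } from "@huggingface/inference";

const client = new HfInference(process.env.HF_TOKEN);

// "Qwen/Qwen2.5-72B-Instruct" appears in the sambanova table above, so
// the hard-coded fallback resolves it even without a Hub-side entry.
const out = await client.chatCompletion({
  model: "Qwen/Qwen2.5-72B-Instruct",
  provider: "sambanova",
  messages: [{ role: "user", content: "Hello!" }],
});
console.log(out.choices[0].message);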