add support for Kokoro via Replicate #1153

Merged · 9 commits · Feb 4, 2025
Changes from 6 commits
3 changes: 2 additions & 1 deletion packages/inference/src/providers/replicate.ts
@@ -21,7 +21,8 @@ export const REPLICATE_SUPPORTED_MODEL_IDS: ProviderMapping<ReplicateId> = {
"stability-ai/sdxl:7762fd07cf82c948538e41f63f77d685e02b063e37e496e96eefd46c929f9bdc",
},
"text-to-speech": {
-    "OuteAI/OuteTTS-0.3-500M": "jbilcke/oute-tts:39a59319327b27327fa3095149c5a746e7f2aee18c75055c3368237a6503cd26",
+    "OuteAI/OuteTTS-0.3-500M": "jbilcke/oute-tts:3c645149db020c85d080e2f8cfe482a0e68189a922cde964fa9e80fb179191f3",
+    "hexgrad/Kokoro-82M": "jaaari/kokoro-82m:dfdf537ba482b029e0a761699e6f55e9162cfd159270bfe0e44857caa5f275a6",
},
"text-to-video": {
"genmo/mochi-1-preview": "genmoai/mochi-1:1944af04d098ef69bed7f9d335d102e652203f268ec4aaa2d836f6217217e460",
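The mapping above pairs a Hugging Face model id with a pinned Replicate version of the form `owner/model:sha`. A minimal sketch of how such a lookup could work — the names below (`SUPPORTED`, `resolveReplicateId`) are illustrative, not the package's actual internals; only the mapping shape mirrors the diff:

```typescript
// Illustrative sketch only: SUPPORTED and resolveReplicateId are hypothetical
// names; the task -> HF id -> "owner/model:sha" shape mirrors the diff above.
type ReplicateId = string;
type ProviderMapping = Partial<Record<string, Record<string, ReplicateId>>>;

const SUPPORTED: ProviderMapping = {
  "text-to-speech": {
    "hexgrad/Kokoro-82M":
      "jaaari/kokoro-82m:dfdf537ba482b029e0a761699e6f55e9162cfd159270bfe0e44857caa5f275a6",
  },
};

// Resolve an HF model id to its pinned Replicate version for a given task;
// undefined means the model is not supported on this provider.
function resolveReplicateId(task: string, hfModelId: string): ReplicateId | undefined {
  return SUPPORTED[task]?.[hfModelId];
}
```

Pinning the full version hash means a supported model keeps resolving to the exact Replicate container it was tested against.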
13 changes: 11 additions & 2 deletions packages/inference/src/tasks/audio/textToSpeech.ts
@@ -1,8 +1,8 @@
import type { TextToSpeechInput } from "@huggingface/tasks";
import { InferenceOutputError } from "../../lib/InferenceOutputError";
import type { BaseArgs, Options } from "../../types";
import { omit } from "../../utils/omit";
import { request } from "../custom/request";

type TextToSpeechArgs = BaseArgs & TextToSpeechInput;

interface OutputUrlTextToSpeechGeneration {
@@ -13,7 +13,16 @@ interface OutputUrlTextToSpeechGeneration {
* Recommended model: espnet/kan-bayashi_ljspeech_vits
*/
 export async function textToSpeech(args: TextToSpeechArgs, options?: Options): Promise<Blob> {
-  const res = await request<Blob | OutputUrlTextToSpeechGeneration>(args, {
+  // Replicate models expect "text" instead of "inputs"
+  const payload =
+    args.provider === "replicate"
+      ? {
+          ...omit(args, ["inputs", "parameters"]),
+          ...args.parameters,
+          text: args.inputs,
+        }
+      : args;
+  const res = await request<Blob | OutputUrlTextToSpeechGeneration>(payload, {
...options,
taskHint: "text-to-speech",
});
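The payload rewrite above can be exercised in isolation. A self-contained sketch of the same transform, with a local stand-in for the package's `omit` utility (the `TtsArgs` interface here is simplified for illustration):

```typescript
// Simplified stand-ins for the package's types and omit helper.
interface TtsArgs {
  provider?: string;
  inputs: string;
  parameters?: Record<string, unknown>;
  [key: string]: unknown;
}

// Return a shallow copy of obj without the listed keys.
function omit<T extends object>(obj: T, keys: string[]): Record<string, unknown> {
  return Object.fromEntries(Object.entries(obj).filter(([k]) => !keys.includes(k)));
}

// For Replicate, rename "inputs" to "text" and hoist "parameters" to the
// top level; every other provider receives the args unchanged.
function buildTtsPayload(args: TtsArgs): Record<string, unknown> {
  return args.provider === "replicate"
    ? {
        ...omit(args, ["inputs", "parameters"]),
        ...args.parameters,
        text: args.inputs,
      }
    : args;
}
```

Spreading `args.parameters` after the omitted base means any model-specific parameter ends up as a top-level field, which is the shape Replicate containers read.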
4 changes: 2 additions & 2 deletions packages/inference/src/types.ts
@@ -1,5 +1,4 @@
-import type { PipelineType } from "@huggingface/tasks";
-import type { ChatCompletionInput } from "@huggingface/tasks";
+import type { ChatCompletionInput, PipelineType } from "@huggingface/tasks";

/**
* HF model id, like "meta-llama/Llama-3.3-70B-Instruct"
@@ -88,6 +87,7 @@ export type RequestArgs = BaseArgs &
| { data: Blob | ArrayBuffer }
| { inputs: unknown }
| { prompt: string }
+		| { text: string }
| { audio_url: string }
| ChatCompletionInput
) & {
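For context on the `{ text: string }` addition: the `RequestArgs` union is what lets the Replicate TTS payload, which carries `text` rather than `inputs`, typecheck. A stripped-down illustration (not the full type from the package, which also intersects `BaseArgs` and has more members):

```typescript
// Stripped-down illustration of the body union from types.ts.
type RequestBody =
  | { inputs: unknown }
  | { prompt: string }
  | { text: string };

// The Replicate TTS payload and the default payload are both valid members.
const replicateTts: RequestBody = { text: "Hello from Kokoro" };
const defaultTts: RequestBody = { inputs: "Hello" };
```

Without the new member, assigning `{ text: ... }` would fail to compile, so the payload rewrite in `textToSpeech` depends on this one-line type change.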
20 changes: 15 additions & 5 deletions packages/inference/test/HfInference.spec.ts
@@ -1,11 +1,11 @@
-import { expect, it, describe, assert } from "vitest";
+import { assert, describe, expect, it } from "vitest";

import type { ChatCompletionStreamOutput } from "@huggingface/tasks";

import { chatCompletion, FAL_AI_SUPPORTED_MODEL_IDS, HfInference } from "../src";
-import "./vcr";
-import { readTestFile } from "./test-files";
 import { textToVideo } from "../src/tasks/cv/textToVideo";
+import { readTestFile } from "./test-files";
+import "./vcr";

const TIMEOUT = 60000 * 3;
const env = import.meta.env;
@@ -939,11 +939,21 @@ describe.concurrent("HfInference", () => {
expect(res).toBeInstanceOf(Blob);
});

-	it("textToSpeech OuteTTS", async () => {
+	it.skip("textToSpeech OuteTTS - usually Cold", async () => {
const res = await client.textToSpeech({
model: "OuteAI/OuteTTS-0.3-500M",
provider: "replicate",
-			inputs: "OuteTTS is a frontier TTS model for its size of 1 Billion parameters",
+			text: "OuteTTS is a frontier TTS model for its size of 1 Billion parameters",
});

expect(res).toBeInstanceOf(Blob);
});

+	it("textToSpeech Kokoro", async () => {
+		const res = await client.textToSpeech({
+			model: "hexgrad/Kokoro-82M",
+			provider: "replicate",
+			text: "Kokoro is a frontier TTS model for its size of 1 Billion parameters",
+		});
Contributor: Does it accept parameters? Would be nice to add some to make sure the payload has the proper shape.

Contributor: Yes, but it does not accept any of the parameters listed in TextToSpeechParameters (see the schema of the model here). Since each model on Replicate runs in its own Docker container, models of the same task can have different parameters (and, even more problematic, different input/output formats). In huggingface_hub, we added an argument extra_body to allow users to pass any provider- or model-specific parameters.
+		expect(res).toBeInstanceOf(Blob);
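On the extra_body point raised in the review: that argument lives in the Python huggingface_hub client, but the idea translates directly. A hypothetical TypeScript sketch of the same escape hatch — none of the names below are the actual inference API:

```typescript
// Hypothetical sketch of an extra_body-style escape hatch; the real
// extra_body argument is in the Python huggingface_hub library, and
// TtsRequest / withExtraBody are invented here for illustration.
interface TtsRequest {
  text: string;
  [extra: string]: unknown;
}

// Spread provider- or model-specific fields last, so callers can set
// whatever a particular Replicate container expects (e.g. a voice field).
function withExtraBody(
  base: { text: string },
  extraBody: Record<string, unknown> = {}
): TtsRequest {
  return { ...base, ...extraBody };
}
```

Because each Replicate model defines its own schema, an untyped pass-through like this is what lets one client method cover models whose accepted fields differ.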