diff --git a/README.md b/README.md
index e4a5a8af1b2..636e40a805f 100644
--- a/README.md
+++ b/README.md
@@ -618,6 +618,24 @@ MODELS=`[
 ```
 
+##### LangServe
+
+LangChain applications that are deployed using LangServe can be called with the following config:
+
+```
+MODELS=`[
+//...
+  {
+    "name": "summarization-chain", //model-name
+    "endpoints" : [{
+      "type": "langserve",
+      "url" : "http://127.0.0.1:8100",
+    }]
+  },
+]`
+
+```
+
 ### Custom endpoint authorization
 
 #### Basic and Bearer
 
diff --git a/src/lib/server/endpoints/endpoints.ts b/src/lib/server/endpoints/endpoints.ts
index fd6dfbdeacd..4e28a4b5a91 100644
--- a/src/lib/server/endpoints/endpoints.ts
+++ b/src/lib/server/endpoints/endpoints.ts
@@ -17,6 +17,9 @@ import endpointCloudflare, {
 	endpointCloudflareParametersSchema,
 } from "./cloudflare/endpointCloudflare";
 import { endpointCohere, endpointCohereParametersSchema } from "./cohere/endpointCohere";
+import endpointLangserve, {
+	endpointLangserveParametersSchema,
+} from "./langserve/endpointLangserve";
 
 // parameters passed when generating text
 export interface EndpointParameters {
@@ -48,6 +51,7 @@ export const endpoints = {
 	vertex: endpointVertex,
 	cloudflare: endpointCloudflare,
 	cohere: endpointCohere,
+	langserve: endpointLangserve,
 };
 
 export const endpointSchema = z.discriminatedUnion("type", [
@@ -60,5 +64,6 @@ export const endpointSchema = z.discriminatedUnion("type", [
 	endpointVertexParametersSchema,
 	endpointCloudflareParametersSchema,
 	endpointCohereParametersSchema,
+	endpointLangserveParametersSchema,
 ]);
 export default endpoints;
diff --git a/src/lib/server/endpoints/langserve/endpointLangserve.ts b/src/lib/server/endpoints/langserve/endpointLangserve.ts
new file mode 100644
index 00000000000..2c5a475c967
--- /dev/null
+++ b/src/lib/server/endpoints/langserve/endpointLangserve.ts
@@ -0,0 +1,128 @@
+import { buildPrompt } from "$lib/buildPrompt";
+import { z } from "zod";
+import type { Endpoint } from "../endpoints";
+import type { TextGenerationStreamOutput } from "@huggingface/inference";
+
+export const endpointLangserveParametersSchema = z.object({
+	weight: z.number().int().positive().default(1),
+	model: z.any(),
+	type: z.literal("langserve"),
+	url: z.string().url(),
+});
+
+export function endpointLangserve(
+	input: z.input<typeof endpointLangserveParametersSchema>
+): Endpoint {
+	const { url, model } = endpointLangserveParametersSchema.parse(input);
+
+	return async ({ messages, preprompt, continueMessage }) => {
+		const prompt = await buildPrompt({
+			messages,
+			continueMessage,
+			preprompt,
+			model,
+		});
+
+		const r = await fetch(`${url}/stream`, {
+			method: "POST",
+			headers: {
+				"Content-Type": "application/json",
+			},
+			body: JSON.stringify({
+				input: { text: prompt },
+			}),
+		});
+
+		if (!r.ok) {
+			throw new Error(`Failed to generate text: ${await r.text()}`);
+		}
+
+		const encoder = new TextDecoderStream();
+		const reader = r.body?.pipeThrough(encoder).getReader();
+
+		return (async function* () {
+			let stop = false;
+			let generatedText = "";
+			let tokenId = 0;
+			let accumulatedData = ""; // Buffer to accumulate data chunks
+
+			while (!stop) {
+				// Read the next chunk from the stream
+				const out = (await reader?.read()) ?? { done: false, value: undefined };
+
+				// If it's done, we cancel
+				if (out.done) {
+					reader?.cancel();
+					return;
+				}
+
+				if (!out.value) {
+					return;
+				}
+
+				// Accumulate the data chunk
+				accumulatedData += out.value;
+				// Keep read data to check event type
+				const eventData = out.value;
+
+				// Process each complete JSON object in the accumulated data
+				while (accumulatedData.includes("\n")) {
+					// Assuming each JSON object ends with a newline
+					const endIndex = accumulatedData.indexOf("\n");
+					let jsonString = accumulatedData.substring(0, endIndex).trim();
+
+					// Remove the processed part from the buffer
+					accumulatedData = accumulatedData.substring(endIndex + 1);
+
+					// Stopping with end event
+					if (eventData.startsWith("event: end")) {
+						stop = true;
+						yield {
+							token: {
+								id: tokenId++,
+								text: "",
+								logprob: 0,
+								special: true,
+							},
+							generated_text: generatedText,
+							details: null,
+						} satisfies TextGenerationStreamOutput;
+						reader?.cancel();
+						continue;
+					}
+
+					if (eventData.startsWith("event: data") && jsonString.startsWith("data: ")) {
+						jsonString = jsonString.slice(6);
+						let data = null;
+
+						// Handle the parsed data
+						try {
+							data = JSON.parse(jsonString);
+						} catch (e) {
+							console.error("Failed to parse JSON", e);
+							console.error("Problematic JSON string:", jsonString);
+							continue; // Skip this iteration and try the next chunk
+						}
+						// Assuming content within data is a plain string
+						if (data) {
+							generatedText += data;
+							const output: TextGenerationStreamOutput = {
+								token: {
+									id: tokenId++,
+									text: data,
+									logprob: 0,
+									special: false,
+								},
+								generated_text: null,
+								details: null,
+							};
+							yield output;
+						}
+					}
+				}
+			}
+		})();
+	};
+}
+
+export default endpointLangserve;
diff --git a/src/lib/server/models.ts b/src/lib/server/models.ts
index 1991fc0c54f..f48d25288b3 100644
--- a/src/lib/server/models.ts
+++ b/src/lib/server/models.ts
@@ -177,6 +177,8 @@ const addEndpoint = (m: Awaited<ReturnType<typeof processModel>>) => ({
 				return await endpoints.cloudflare(args);
 			case "cohere":
 				return await endpoints.cohere(args);
+			case "langserve":
+				return await endpoints.langserve(args);
 			default:
 				// for legacy reason
 				return endpoints.tgi(args);
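
For reviewers: a minimal standalone sketch (not part of the diff) of the wire format the new endpoint relies on. LangServe's `/stream` route answers a POSTed `{"input": ...}` body with server-sent events; each `event: data` event carries a JSON-encoded chunk on its `data:` line, and an `event: end` event closes the stream, which is what the parser in `endpointLangserve.ts` looks for. The function name, the `{ text: ... }` input shape (hard-coded as `input: { text: prompt }` in this PR), and the sample text are illustrative assumptions, not chat-ui code.

```ts
// Sketch: stream tokens from a LangServe app using the conventions assumed above.
// Requires Node 18+ (global fetch and TextDecoderStream). Assumes the chain is
// mounted at the base URL, accepts { text: string } as input, and streams plain
// string chunks (the same assumption the endpoint makes).
async function streamFromLangserve(baseUrl: string, text: string): Promise<string> {
	const res = await fetch(`${baseUrl}/stream`, {
		method: "POST",
		headers: { "Content-Type": "application/json" },
		body: JSON.stringify({ input: { text } }),
	});
	if (!res.ok || !res.body) {
		throw new Error(`LangServe request failed: ${res.status} ${await res.text()}`);
	}

	const reader = res.body.pipeThrough(new TextDecoderStream()).getReader();
	let buffer = "";
	let generated = "";

	for (;;) {
		const { done, value } = await reader.read();
		if (done) {
			break;
		}
		// Normalize CRLF so event boundaries are always "\n\n"
		buffer += value.replace(/\r\n/g, "\n");

		// Server-sent events are separated by a blank line; only handle complete events
		let boundary = buffer.indexOf("\n\n");
		while (boundary !== -1) {
			const event = buffer.slice(0, boundary);
			buffer = buffer.slice(boundary + 2);
			boundary = buffer.indexOf("\n\n");

			if (event.startsWith("event: end")) {
				return generated;
			}
			if (event.startsWith("event: data")) {
				const dataLine = event.split("\n").find((line) => line.startsWith("data: "));
				if (dataLine) {
					// Each chunk is JSON-encoded; for plain LLM chains it is a string
					generated += JSON.parse(dataLine.slice("data: ".length));
				}
			}
		}
	}
	return generated;
}

// Example run against the summarization chain from the README config above
streamFromLangserve("http://127.0.0.1:8100", "LangServe exposes LangChain runnables over HTTP.")
	.then((summary) => console.log(summary))
	.catch(console.error);
```

If a chain expects a different input key than `text`, the request body in this sketch and the hard-coded `input: { text: prompt }` in `endpointLangserve.ts` are the two places that would need to change.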