Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

🦺 Validator util for hf/inference #385

Draft
wants to merge 8 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions packages/inference/src/lib/InferenceOutputError.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
export class InferenceOutputError extends TypeError {
constructor(message: string) {
constructor(err: unknown) {
super(
`Invalid inference output: ${message}. Use the 'request' method with the same parameters to do a custom call with no type checking.`
`Invalid inference output: ${
err instanceof Error ? err.message : String(err)
}. Use the 'request' method with the same parameters to do a custom call with no type checking.`,
err instanceof Error ? { cause: err } : undefined
);
this.name = "InferenceOutputError";
}
Expand Down
152 changes: 152 additions & 0 deletions packages/inference/src/lib/validateOutput.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
/* eslint-disable @typescript-eslint/no-explicit-any */
/**
* Heavily inspired by zod
*
* Re-created to avoid extra dependencies
*/

import { InferenceOutputError } from "./InferenceOutputError";

export const z = {
array<T extends { parse: (value: any) => any }>(
coyotte508 marked this conversation as resolved.
Show resolved Hide resolved
items: T
): { parse: (value: any) => ReturnType<T["parse"]>[]; toString: () => string } {
return {
parse: (value: unknown) => {
if (!Array.isArray(value)) {
throw new Error("Expected " + this.toString());
}
try {
return value.map((val) => items.parse(val));
} catch (err) {
throw new Error("Expected " + this.toString(), { cause: err });
}
},
toString(): string {
return `Array<${items.toString()}>`;
},
};
},
first<T extends { parse: (value: any) => any }>(
items: T
): { parse: (value: any) => ReturnType<T["parse"]>; toString: () => string } {
return {
parse: (value: unknown) => {
if (!Array.isArray(value) || value.length === 0) {
throw new Error("Expected " + this.toString());
}
try {
return items.parse(value[0]);
} catch (err) {
throw new Error("Expected " + this.toString(), { cause: err });
}
},
toString(): string {
return `[${items.toString()}]`;
},
};
},
or: <T extends { parse: (value: any) => any }[]>(
...items: T
): { parse: (value: any) => ReturnType<T[number]["parse"]>; toString(): string } => ({
parse: (value: unknown): ReturnType<T[number]["parse"]> => {
const errors: Error[] = [];
for (const item of items) {
try {
return item.parse(value);
} catch (err) {
errors.push(err as Error);
}
}
throw new Error("Expected " + items.map((item) => item.toString()).join(" | "), { cause: errors });
},
toString(): string {
return items.map((item) => item.toString()).join(" | ");
},
}),
object<T extends Record<string, { parse: (value: any) => any }>>(
item: T
): { parse: (value: any) => { [key in keyof T]: ReturnType<T[key]["parse"]> }; toString: () => string } {
return {
parse: (value: unknown) => {
if (typeof value !== "object" || value === null || Array.isArray(value)) {
throw new Error("Expected " + this.toString());
}
return Object.fromEntries(Object.entries(item).map(([key, val]) => [key, val.parse((value as any)[key])])) as {
[key in keyof T]: ReturnType<T[key]["parse"]>;
};
},
toString(): string {
return `{${Object.entries(item)
.map(([key, val]) => `${key}: ${val.toString()}`)
.join(", ")}}`;
},
};
},
string(): { parse: (value: any) => string; toString: () => string } {
return {
parse: (value: unknown): string => {
if (typeof value !== "string") {
throw new Error("Expected " + this.toString());
}
return value;
},
toString(): string {
return "string";
},
};
},
number(): { parse: (value: any) => number; toString: () => string } {
return {
parse: (value: unknown): number => {
if (typeof value !== "number") {
throw new Error("Expected " + this.toString());
}
return value;
},
toString(): string {
return "number";
},
};
},
blob(): { parse: (value: any) => Blob; toString: () => string } {
return {
parse: (value: unknown): Blob => {
if (!(value instanceof Blob)) {
throw new Error("Expected " + this.toString());
}
return value;
},
toString(): string {
return "Blob";
},
};
},
optional<T extends { parse: (value: any) => any }>(
item: T
): { parse: (value: any) => ReturnType<T["parse"]> | undefined; toString: () => string } {
return {
parse: (value: unknown): ReturnType<T["parse"]> | undefined => {
if (value === undefined) {
return undefined;
}
try {
return item.parse(value);
} catch (err) {
throw new Error("Expected " + this.toString(), { cause: err });
}
},
toString(): string {
return `${item.toString()} | undefined`;
},
};
},
};

export function validateOutput<T>(value: unknown, schema: { parse: (value: any) => T }): T {
try {
return schema.parse(value);
} catch (err) {
throw new InferenceOutputError(err);
}
}
9 changes: 2 additions & 7 deletions packages/inference/src/tasks/audio/audioClassification.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { InferenceOutputError } from "../../lib/InferenceOutputError";
import { validateOutput, z } from "../../lib/validateOutput";
import type { BaseArgs, Options } from "../../types";
import { request } from "../custom/request";

Expand Down Expand Up @@ -35,10 +35,5 @@ export async function audioClassification(
...options,
taskHint: "audio-classification",
});
const isValidOutput =
Array.isArray(res) && res.every((x) => typeof x.label === "string" && typeof x.score === "number");
if (!isValidOutput) {
throw new InferenceOutputError("Expected Array<{label: string, score: number}>");
}
return res;
return validateOutput(res, z.array(z.object({ label: z.string(), score: z.number() })));
}
21 changes: 11 additions & 10 deletions packages/inference/src/tasks/audio/audioToAudio.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { InferenceOutputError } from "../../lib/InferenceOutputError";
import { validateOutput, z } from "../../lib/validateOutput";
import type { BaseArgs, Options } from "../../types";
import { request } from "../custom/request";

Expand Down Expand Up @@ -37,13 +37,14 @@ export async function audioToAudio(args: AudioToAudioArgs, options?: Options): P
...options,
taskHint: "audio-to-audio",
});
const isValidOutput =
Array.isArray(res) &&
res.every(
(x) => typeof x.label === "string" && typeof x.blob === "string" && typeof x["content-type"] === "string"
);
if (!isValidOutput) {
throw new InferenceOutputError("Expected Array<{label: string, blob: string, content-type: string}>");
}
return res;
return validateOutput(
res,
z.array(
z.object({
label: z.string(),
blob: z.string(),
"content-type": z.string(),
})
)
);
}
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { InferenceOutputError } from "../../lib/InferenceOutputError";
import { validateOutput, z } from "../../lib/validateOutput";
import type { BaseArgs, Options } from "../../types";
import { request } from "../custom/request";

Expand Down Expand Up @@ -28,9 +28,5 @@ export async function automaticSpeechRecognition(
...options,
taskHint: "automatic-speech-recognition",
});
const isValidOutput = typeof res?.text === "string";
if (!isValidOutput) {
throw new InferenceOutputError("Expected {text: string}");
}
return res;
return validateOutput(res, z.object({ text: z.string() }));
}
8 changes: 2 additions & 6 deletions packages/inference/src/tasks/audio/textToSpeech.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { InferenceOutputError } from "../../lib/InferenceOutputError";
import { validateOutput, z } from "../../lib/validateOutput";
import type { BaseArgs, Options } from "../../types";
import { request } from "../custom/request";

Expand All @@ -20,9 +20,5 @@ export async function textToSpeech(args: TextToSpeechArgs, options?: Options): P
...options,
taskHint: "text-to-speech",
});
const isValidOutput = res && res instanceof Blob;
if (!isValidOutput) {
throw new InferenceOutputError("Expected Blob");
}
return res;
return validateOutput(res, z.blob());
}
9 changes: 2 additions & 7 deletions packages/inference/src/tasks/cv/imageClassification.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { InferenceOutputError } from "../../lib/InferenceOutputError";
import { validateOutput, z } from "../../lib/validateOutput";
import type { BaseArgs, Options } from "../../types";
import { request } from "../custom/request";

Expand Down Expand Up @@ -34,10 +34,5 @@ export async function imageClassification(
...options,
taskHint: "image-classification",
});
const isValidOutput =
Array.isArray(res) && res.every((x) => typeof x.label === "string" && typeof x.score === "number");
if (!isValidOutput) {
throw new InferenceOutputError("Expected Array<{label: string, score: number}>");
}
return res;
return validateOutput(res, z.array(z.object({ label: z.string(), score: z.number() })));
}
10 changes: 2 additions & 8 deletions packages/inference/src/tasks/cv/imageSegmentation.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { InferenceOutputError } from "../../lib/InferenceOutputError";
import { validateOutput, z } from "../../lib/validateOutput";
import type { BaseArgs, Options } from "../../types";
import { request } from "../custom/request";

Expand Down Expand Up @@ -38,11 +38,5 @@ export async function imageSegmentation(
...options,
taskHint: "image-segmentation",
});
const isValidOutput =
Array.isArray(res) &&
res.every((x) => typeof x.label === "string" && typeof x.mask === "string" && typeof x.score === "number");
if (!isValidOutput) {
throw new InferenceOutputError("Expected Array<{label: string, mask: string, score: number}>");
}
return res;
return validateOutput(res, z.array(z.object({ label: z.string(), mask: z.string(), score: z.number() })));
}
8 changes: 2 additions & 6 deletions packages/inference/src/tasks/cv/imageToImage.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import { InferenceOutputError } from "../../lib/InferenceOutputError";
import type { BaseArgs, Options, RequestArgs } from "../../types";
import { request } from "../custom/request";
import { base64FromBytes } from "../../../../shared";
import { validateOutput, z } from "../../lib/validateOutput";

export type ImageToImageArgs = BaseArgs & {
/**
Expand Down Expand Up @@ -78,9 +78,5 @@ export async function imageToImage(args: ImageToImageArgs, options?: Options): P
...options,
taskHint: "image-to-image",
});
const isValidOutput = res && res instanceof Blob;
if (!isValidOutput) {
throw new InferenceOutputError("Expected Blob");
}
return res;
return validateOutput(res, z.blob());
}
18 changes: 6 additions & 12 deletions packages/inference/src/tasks/cv/imageToText.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { InferenceOutputError } from "../../lib/InferenceOutputError";
import { validateOutput, z } from "../../lib/validateOutput";
import type { BaseArgs, Options } from "../../types";
import { request } from "../custom/request";

Expand All @@ -20,16 +20,10 @@ export interface ImageToTextOutput {
* This task reads some image input and outputs the text caption.
*/
export async function imageToText(args: ImageToTextArgs, options?: Options): Promise<ImageToTextOutput> {
const res = (
await request<[ImageToTextOutput]>(args, {
...options,
taskHint: "image-to-text",
})
)?.[0];
const res = await request<[ImageToTextOutput]>(args, {
...options,
taskHint: "image-to-text",
});

if (typeof res?.generated_text !== "string") {
throw new InferenceOutputError("Expected {generated_text: string}");
}

return res;
return validateOutput(res, z.first(z.object({ generated_text: z.string() })));
}
30 changes: 12 additions & 18 deletions packages/inference/src/tasks/cv/objectDetection.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import { request } from "../custom/request";
import type { BaseArgs, Options } from "../../types";
import { InferenceOutputError } from "../../lib/InferenceOutputError";
import { validateOutput, z } from "../../lib/validateOutput";

export type ObjectDetectionArgs = BaseArgs & {
/**
Expand Down Expand Up @@ -41,21 +41,15 @@ export async function objectDetection(args: ObjectDetectionArgs, options?: Optio
...options,
taskHint: "object-detection",
});
const isValidOutput =
Array.isArray(res) &&
res.every(
(x) =>
typeof x.label === "string" &&
typeof x.score === "number" &&
typeof x.box.xmin === "number" &&
typeof x.box.ymin === "number" &&
typeof x.box.xmax === "number" &&
typeof x.box.ymax === "number"
);
if (!isValidOutput) {
throw new InferenceOutputError(
"Expected Array<{label:string; score:number; box:{xmin:number; ymin:number; xmax:number; ymax:number}}>"
);
}
return res;

return validateOutput(
res,
z.array(
z.object({
label: z.string(),
score: z.number(),
box: z.object({ xmin: z.number(), ymin: z.number(), xmax: z.number(), ymax: z.number() }),
})
)
);
}
8 changes: 2 additions & 6 deletions packages/inference/src/tasks/cv/textToImage.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { InferenceOutputError } from "../../lib/InferenceOutputError";
import { validateOutput, z } from "../../lib/validateOutput";
import type { BaseArgs, Options } from "../../types";
import { request } from "../custom/request";

Expand Down Expand Up @@ -43,9 +43,5 @@ export async function textToImage(args: TextToImageArgs, options?: Options): Pro
...options,
taskHint: "text-to-image",
});
const isValidOutput = res && res instanceof Blob;
if (!isValidOutput) {
throw new InferenceOutputError("Expected Blob");
}
return res;
return validateOutput(res, z.blob());
}
Loading
Loading