Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[editor] HuggingFaceImage2TextTransformerPromptSchema #848

Merged
merged 3 commits into from
Jan 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,10 @@ import {
DEBOUNCE_MS,
SERVER_HEARTBEAT_INTERVAL_MS,
} from "../utils/constants";
import { getPromptModelName } from "../utils/promptUtils";
import {
getDefaultPromptInputForModel,
getPromptModelName,
} from "../utils/promptUtils";
import { IconDeviceFloppy } from "@tabler/icons-react";
import CopyButton from "./CopyButton";

Expand Down Expand Up @@ -498,7 +501,7 @@ export default function EditorContainer({

const newPrompt: Prompt = {
name: promptName,
input: "", // TODO: Can we use schema to get input structure, string vs object?
input: getDefaultPromptInputForModel(model),
metadata: {
model,
},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,65 @@ type Props = {
onCancel?: () => void;
};

async function uploadFile(_file: File) {
// TODO: Implement
return {
url: "https://s3.amazonaws.com/files.uploads.lastmileai.com/uploads/cldxsqbel0000qs8owp8mkd0z/2023_12_1_21_23_24/942/Screenshot 2023-11-28 at 11.11.25 AM.png",
};
/**
 * Make a file name usable inside an S3 file URI.
 *
 * S3 file URIs cannot contain the '+' character, so every occurrence is
 * swapped for '_'. All other characters are returned untouched.
 */
function sanitizeFileName(name: string) {
  const plusSigns = /\+/g;
  const sanitized = name.replace(plusSigns, "_");
  return sanitized;
}

/**
 * Build an underscore-separated timestamp of the form
 * `year_month_day_hours_minutes_seconds` (no zero padding, month is 1-12).
 *
 * Every component is read in UTC so that the date part and the time part
 * always describe the same instant; mixing local-time date getters with UTC
 * time getters yields a mismatched pair around midnight in non-UTC zones.
 *
 * @param date - Instant to format. Defaults to the current time.
 * @returns The formatted UTC timestamp string.
 */
export function getTodayDateString(date: Date = new Date()): string {
  const dateString = `${date.getUTCFullYear()}_${
    date.getUTCMonth() + 1
  }_${date.getUTCDate()}`;
  const timeString = `${date.getUTCHours()}_${date.getUTCMinutes()}_${date.getUTCSeconds()}`;
  return `${dateString}_${timeString}`;
}

// TODO: Make this configurable for external deployments
/**
 * Upload a file to the public S3 bucket with a multipart form POST and return
 * the URL it can be fetched from.
 *
 * The object key is `uploads/<timestamp>/<random number>/<sanitized file name>`.
 *
 * @param file - The browser File to upload; its name and MIME type are sent along.
 * @returns An object whose `url` points at the uploaded object.
 * @throws Error when S3 does not answer with HTTP 201 (the requested
 *   `success_action_status`). Network failures from `fetch` propagate as-is.
 */
async function uploadFile(file: File): Promise<{ url: string }> {
  const randomPath = Math.round(Math.random() * 10000);
  // TODO: Add back once CORS is resolved
  // const policyResponse = await fetch(
  //   "https://lastmileai.dev/api/upload/publicpolicy"
  // );
  // const policy = await policyResponse.json();
  const uploadUrl = "https://s3.amazonaws.com/lastmileai.aiconfig.public/";
  const uploadKey = `uploads/${getTodayDateString()}/${randomPath}/${sanitizeFileName(
    file.name
  )}`;

  const formData = new FormData();
  formData.append("key", uploadKey);
  formData.append("acl", "public-read");
  formData.append("Content-Type", file.type);
  // formData.append("AWSAccessKeyId", policy.AWSAccessKeyId);
  formData.append("success_action_status", "201");
  // formData.append("Policy", policy.s3Policy);
  // formData.append("Signature", policy.s3Signature);
  formData.append("file", file);

  // See this about changing to use XMLHTTPRequest to show upload progress as well
  // https://medium.com/@cpatarun/tracking-file-upload-progress-to-amazon-s3-from-the-browser-71be6712c63d
  const rawRes = await fetch(uploadUrl, {
    method: "POST",
    mode: "cors",
    cache: "no-cache",
    body: formData,
    headers: {
      // NOTE(review): presumably sent empty to keep the request anonymous — confirm.
      Authorization: "",
    },
  });

  // 201 is the status requested via success_action_status; anything else is a failure.
  if (rawRes.status !== 201) {
    throw new Error(
      `Error uploading to S3! (HTTP ${rawRes.status} ${rawRes.statusText})`
    );
  }

  // Don't really need to parse the XML S3 response, just reuse the key that was passed.
  // The key is stored raw, so percent-encode each path segment for the URL:
  // otherwise characters such as '#', '?', '%' or spaces in the file name
  // would produce a link that does not resolve to the uploaded object.
  const encodedKey = uploadKey
    .split("/")
    .map((segment) => encodeURIComponent(segment))
    .join("/");
  return { url: `${uploadUrl}${encodedKey}` };
}

/**
 * Format the MIME types accepted by an attachments schema as a single
 * comma-separated string (e.g. "image/png, image/jpeg").
 *
 * @param schema - Attachments schema whose `items.mime_types` are listed.
 * @returns The MIME types joined with ", ".
 */
function getSupportedFileTypes(schema: PromptInputObjectAttachmentsSchema) {
  return schema.items.mime_types.join(", ");
}

export default memo(function AttachmentUploader({
Expand Down Expand Up @@ -66,7 +116,7 @@ Props) {
const attachments: Attachment[] = uploads.map((upload) => {
return {
data: upload.url,
mimeType: upload.mimeType,
mime_type: upload.mimeType,
};
});

Expand Down Expand Up @@ -112,10 +162,10 @@ Props) {
}}
// TODO: Get these from schema,
// maxSize={MAX_IMAGE_FILE_SIZE}
// accept={}
accept={schema.items.mime_types}
>
{fileList.length > 0 ? (
`${fileList.length} File(s) to Upload`
`${fileList.length} File(s) Uploading...`
) : (
<div>
<Title order={4}>Upload File</Title>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,13 +68,21 @@ export default memo(function PromptInputAttachmentsSchemaRenderer({
addedAttachments: InputAttachment[],
index: number
) => {
const updatedAttachments = attachments.reduce((acc, attachment, i) => {
if (i === index) {
// Replace the attachment at the given index with the new attachments
return [...acc, ...addedAttachments];
}
return [...acc, attachment];
}, [] as InputAttachment[]);
let updatedAttachments;

// Add to end of attachments
if (index > attachments.length) {
updatedAttachments = [...attachments, ...addedAttachments];
} else {
updatedAttachments = attachments.reduce((acc, attachment, i) => {
if (i === index) {
// Replace the attachment at the given index with the new attachments
return [...acc, ...addedAttachments];
}
return [...acc, attachment];
}, [] as InputAttachment[]);
}

onChangeAttachments(updatedAttachments);
};

Expand Down Expand Up @@ -106,18 +114,20 @@ export default memo(function PromptInputAttachmentsSchemaRenderer({
{/* If the schema supports more attachments, add another attachment input below existing attachments. This
* will allow the user to add more attachments without 'clearing' existing ones.
*/}
{schema.max_items == null ||
(numAttachments < schema.max_items && (
<EditableAttachmentRenderer
schema={schema}
onUpdateAttachment={(attachment) =>
onUpdateAttachment(attachment, numAttachments + 1)
}
onAddAttachments={(addedAttachments) =>
onAddAttachments(addedAttachments, numAttachments + 1)
}
/>
))}
{(schema.max_items == null || numAttachments < schema.max_items) && (
<EditableAttachmentRenderer
// Use key to force re-mount if a new attachment is 'attached' (will render above as a result)
// So that this 'new' attachment input is mounted with fresh state
key={numAttachments}
schema={schema}
onUpdateAttachment={(attachment) =>
onUpdateAttachment(attachment, numAttachments + 1)
}
onAddAttachments={(addedAttachments) =>
onAddAttachments(addedAttachments, numAttachments + 1)
}
/>
)}
</>
);
});
Original file line number Diff line number Diff line change
Expand Up @@ -42,16 +42,18 @@ function SchemaRenderer({ input, schema, onChangeInput }: SchemaRendererProps) {

return (
<Flex direction="column">
<DataRenderer
schema={dataSchema}
data={data}
onChangeData={onChangeData}
/>
{dataSchema && (
<DataRenderer
schema={dataSchema}
data={data}
onChangeData={onChangeData}
/>
)}
{attachmentsSchema && (
<AttachmentsRenderer
schema={attachmentsSchema}
onChangeAttachments={onChangeAttachments}
attachments={attachments}
attachments={attachments ?? []}
/>
)}
{/* <JSONRenderer properties={restProperties} data={restData}/> */}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,8 @@ export const HuggingFaceAutomaticSpeechRecognitionPromptSchema: PromptSchema = {
// tokenizer, feature_extractor, device, decoder
input: {
type: "object",
required: ["data"],
required: ["attachments"],
properties: {
data: {
type: "string",
},
attachments: {
type: "array",
items: {
Expand All @@ -25,6 +22,7 @@ export const HuggingFaceAutomaticSpeechRecognitionPromptSchema: PromptSchema = {
},
},
},
max_items: 1,
},
},
},
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import { PromptSchema } from "../../utils/promptUtils";

/**
 * Prompt schema for the HuggingFace image-to-text transformer.
 *
 * The prompt input is an object carrying exactly one attachment (a PNG image);
 * there is no free-text `data` property, so `attachments` is the required field.
 */
export const HuggingFaceImage2TextTransformerPromptSchema: PromptSchema = {
  input: {
    type: "object",
    // Must name a declared property: the only input property is `attachments`.
    required: ["attachments"],
    properties: {
      attachments: {
        type: "array",
        items: {
          type: "attachment",
          required: ["data"],
          mime_types: ["image/png"],
          properties: {
            data: {
              type: "string",
            },
          },
        },
        max_items: 1,
      },
    },
  },
};
24 changes: 22 additions & 2 deletions python/src/aiconfig/editor/client/src/utils/promptUtils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import { HuggingFaceText2ImageDiffusorPromptSchema } from "../shared/prompt_sche
import { HuggingFaceTextGenerationTransformerPromptSchema } from "../shared/prompt_schemas/HuggingFaceTextGenerationTransformerPromptSchema";
import { HuggingFaceAutomaticSpeechRecognitionPromptSchema } from "../shared/prompt_schemas/HuggingFaceAutomaticSpeechRecognitionPromptSchema";
import { HuggingFaceTextSummarizationTransformerPromptSchema } from "../shared/prompt_schemas/HuggingFaceTextSummarizationTransformerPromptSchema";
import { HuggingFaceImage2TextTransformerPromptSchema } from "../shared/prompt_schemas/HuggingFaceImage2TextTransformerPromptSchema";

/**
* Get the name of the model for the specified prompt. The name will either be specified in the prompt's
Expand Down Expand Up @@ -94,7 +95,8 @@ export const PROMPT_SCHEMAS: Record<string, PromptSchema> = {
HuggingFaceTextSummarizationTransformerPromptSchema,
HuggingFaceTextTranslationTransformer:
HuggingFaceTextGenerationTransformerPromptSchema,
// Image2Text: HuggingFaceImage2TextTransformerPromptSchema,
HuggingFaceImage2TextTransformer:
HuggingFaceImage2TextTransformerPromptSchema,
// Text2Speech: HuggingFaceText2SpeechTransformerPromptSchema,
};

Expand Down Expand Up @@ -131,7 +133,7 @@ export type PromptInputObjectAttachmentsSchema = {
export type PromptInputObjectSchema = {
type: "object";
properties: {
data: PromptInputObjectDataSchema;
data?: PromptInputObjectDataSchema;
attachments?: PromptInputObjectAttachmentsSchema;
} & Record<string, JSONObject>;
required?: string[] | undefined;
Expand Down Expand Up @@ -180,3 +182,21 @@ function isTextInputModality(prompt: Prompt): boolean {
/**
 * Decide whether parameters are supported for the given prompt.
 *
 * @param prompt - The prompt to inspect.
 * @returns `true` when the prompt's metadata already has parameters set
 *   (non-null), otherwise the result of `isTextInputModality(prompt)`.
 */
export function checkParametersSupported(prompt: Prompt): boolean {
  const existingParameters = prompt.metadata?.parameters;
  if (existingParameters != null) {
    return true;
  }
  return isTextInputModality(prompt);
}

/**
 * Produce the empty starting `input` value for a prompt that uses `model`,
 * shaped according to the model's entry in `PROMPT_SCHEMAS`.
 *
 * - No schema registered for the model, or a string input schema: `""`.
 * - Object input schema with a `data` property: `{ data: "" }` when that
 *   property is a string, otherwise `{ data: {} }`.
 * - Object input schema without a `data` property: `{}`.
 *
 * @param model - Model name used as the key into `PROMPT_SCHEMAS`.
 */
export function getDefaultPromptInputForModel(model: string) {
  const modelSchema = PROMPT_SCHEMAS[model];
  if (!modelSchema) {
    return "";
  }

  const inputSchema = modelSchema.input;
  if (inputSchema.type === "string") {
    return "";
  }

  const dataSchema = inputSchema.properties.data;
  if (!dataSchema) {
    return {};
  }

  return {
    data: dataSchema.type === "string" ? "" : {},
  };
}