From 72fecdc2ec4331e685f781197253da855375e016 Mon Sep 17 00:00:00 2001 From: Ryan Holinshead <> Date: Tue, 26 Dec 2023 16:44:24 -0500 Subject: [PATCH] [editor] Add Prompt Button --- python/requirements.txt | 2 +- .../src/aiconfig/editor/client/src/Editor.tsx | 32 +++++- .../client/src/components/EditorContainer.tsx | 108 +++++++++++++++--- .../client/src/components/aiconfigReducer.ts | 73 +++++++++++- .../src/components/prompt/AddPromptButton.tsx | 100 ++++++++++++++++ .../client/src/utils/aiconfigStateUtils.ts | 10 ++ .../aiconfig/editor/client/src/utils/api.ts | 3 + python/src/aiconfig/editor/server/server.py | 1 + python/src/aiconfig/scripts/aiconfig_cli.py | 2 +- 9 files changed, 311 insertions(+), 20 deletions(-) create mode 100644 python/src/aiconfig/editor/client/src/components/prompt/AddPromptButton.tsx create mode 100644 python/src/aiconfig/editor/client/src/utils/aiconfigStateUtils.ts diff --git a/python/requirements.txt b/python/requirements.txt index 11b48e5a5..0e3b7ffde 100644 --- a/python/requirements.txt +++ b/python/requirements.txt @@ -6,7 +6,7 @@ flask[async] google-generativeai huggingface_hub hypothesis==6.91.0 -lastmile-utils==0.0.13 +lastmile-utils==0.0.14 mock nest_asyncio nltk diff --git a/python/src/aiconfig/editor/client/src/Editor.tsx b/python/src/aiconfig/editor/client/src/Editor.tsx index 3fdd9e75f..d5cda9575 100644 --- a/python/src/aiconfig/editor/client/src/Editor.tsx +++ b/python/src/aiconfig/editor/client/src/Editor.tsx @@ -1,7 +1,7 @@ import EditorContainer from "./components/EditorContainer"; import { ClientAIConfig } from "./shared/types"; import { Flex, Loader } from "@mantine/core"; -import { AIConfig } from "aiconfig"; +import { AIConfig, Prompt } from "aiconfig"; import { useCallback, useEffect, useState } from "react"; import { ufetch } from "ufetch"; import { ROUTE_TABLE } from "./utils/api"; @@ -20,13 +20,34 @@ export default function Editor() { }, [loadConfig]); const onSave = useCallback(async (aiconfig: AIConfig) => { - const res = await ufetch.post(`/api/aiconfig/save`, { + const res = await ufetch.post(ROUTE_TABLE.SAVE, { // path: file path, aiconfig, }); return res; }, []); + const getModels = useCallback(async (search: string) => { + // For now, rely on caching and handle client-side search filtering + // We will use server-side search filtering for Gradio + const res = await ufetch.get(ROUTE_TABLE.LIST_MODELS); + const models = res.data; + if (search && search.length > 0) { + return models.filter((model: string) => model.indexOf(search) >= 0); + } + return models; + }, []); + + const addPrompt = useCallback( + async (promptName: string, promptData: Prompt) => { + return await ufetch.post(ROUTE_TABLE.ADD_PROMPT, { + prompt_name: promptName, + prompt_data: promptData, + }); + }, + [] + ); + return (
{!aiconfig ? ( @@ -34,7 +55,12 @@ export default function Editor() { ) : ( - + )}
); diff --git a/python/src/aiconfig/editor/client/src/components/EditorContainer.tsx b/python/src/aiconfig/editor/client/src/components/EditorContainer.tsx index 0545f1b57..6c958d2f1 100644 --- a/python/src/aiconfig/editor/client/src/components/EditorContainer.tsx +++ b/python/src/aiconfig/editor/client/src/components/EditorContainer.tsx @@ -1,17 +1,45 @@ import PromptContainer from "./prompt/PromptContainer"; -import { Container, Group, Button, createStyles } from "@mantine/core"; +import { Container, Group, Button, createStyles, Stack } from "@mantine/core"; import { showNotification } from "@mantine/notifications"; -import { AIConfig, PromptInput } from "aiconfig"; -import { useCallback, useReducer, useState } from "react"; -import aiconfigReducer from "./aiconfigReducer"; +import { AIConfig, Prompt, PromptInput } from "aiconfig"; +import { useCallback, useReducer, useRef, useState } from "react"; +import aiconfigReducer, { AIConfigReducerAction } from "./aiconfigReducer"; import { ClientAIConfig, clientConfigToAIConfig } from "../shared/types"; +import AddPromptButton from "./prompt/AddPromptButton"; +import { getDefaultNewPromptName } from "../utils/aiconfigStateUtils"; type Props = { aiconfig: ClientAIConfig; + addPrompt: ( + promptName: string, + prompt: Prompt + ) => Promise<{ aiconfig: AIConfig }>; onSave: (aiconfig: AIConfig) => Promise; + getModels: (search: string) => Promise; }; const useStyles = createStyles((theme) => ({ + addPromptRow: { + borderRadius: "4px", + display: "inline-block", + bottom: -24, + left: -40, + "&:hover": { + backgroundColor: + theme.colorScheme === "light" + ? theme.colors.gray[1] + : "rgba(255, 255, 255, 0.1)", + }, + [theme.fn.smallerThan("sm")]: { + marginLeft: "0", + display: "block", + position: "static", + bottom: -10, + left: 0, + height: 28, + margin: "10px 0", + }, + }, promptsContainer: { [theme.fn.smallerThan("sm")]: { padding: "0 0 200px 0", @@ -22,7 +50,9 @@ const useStyles = createStyles((theme) => ({ export default function EditorContainer({ aiconfig: initialAIConfig, + addPrompt, onSave, + getModels, }: Props) { const [isSaving, setIsSaving] = useState(false); const [aiconfigState, dispatch] = useReducer( @@ -30,6 +60,9 @@ export default function EditorContainer({ initialAIConfig ); + const stateRef = useRef(aiconfigState); + stateRef.current = aiconfigState; + const save = useCallback(async () => { setIsSaving(true); try { @@ -81,6 +114,43 @@ export default function EditorContainer({ [dispatch] ); + const onAddPrompt = useCallback( + async (promptIndex: number, model: string) => { + const promptName = getDefaultNewPromptName(stateRef.current as AIConfig); + const newPrompt: Prompt = { + name: promptName, + input: "", // TODO: Can we use schema to get input structure, string vs object? + metadata: { + model, + }, + }; + + const action: AIConfigReducerAction = { + type: "ADD_PROMPT_AT_INDEX", + index: promptIndex, + prompt: newPrompt, + }; + + dispatch(action); + + try { + const serverConfigRes = await addPrompt(promptName, newPrompt); + dispatch({ + type: "CONSOLIDATE_AICONFIG", + action, + config: serverConfigRes.aiconfig, + }); + } catch (err: any) { + showNotification({ + title: "Error adding prompt to config", + message: err.message, + color: "red", + }); + } + }, + [addPrompt, dispatch] + ); + const { classes } = useStyles(); // TODO: Implement editor context for callbacks, readonly state, etc. @@ -100,15 +170,27 @@ export default function EditorContainer({ {aiconfigState.prompts.map((prompt: any, i: number) => { return ( - + + +
+ + onAddPrompt( + i + 1 /* insert below current prompt index */, + model + ) + } + /> +
+
); })}
diff --git a/python/src/aiconfig/editor/client/src/components/aiconfigReducer.ts b/python/src/aiconfig/editor/client/src/components/aiconfigReducer.ts index 94e05dc7e..80f05c841 100644 --- a/python/src/aiconfig/editor/client/src/components/aiconfigReducer.ts +++ b/python/src/aiconfig/editor/client/src/components/aiconfigReducer.ts @@ -1,12 +1,29 @@ import { ClientAIConfig, ClientPrompt } from "../shared/types"; import { getPromptModelName } from "../utils/promptUtils"; -import { AIConfig, JSONObject, PromptInput } from "aiconfig"; +import { AIConfig, JSONObject, Prompt, PromptInput } from "aiconfig"; -type AIConfigReducerAction = +export type AIConfigReducerAction = + | MutateAIConfigAction + | ConsolidateAIConfigAction; + +export type MutateAIConfigAction = + | AddPromptAction | UpdatePromptInputAction | UpdatePromptModelSettingsAction | UpdatePromptParametersAction; +export type ConsolidateAIConfigAction = { + type: "CONSOLIDATE_AICONFIG"; + action: MutateAIConfigAction; + config: AIConfig; +}; + +export type AddPromptAction = { + type: "ADD_PROMPT_AT_INDEX"; + index: number; + prompt: Prompt; +}; + export type UpdatePromptInputAction = { type: "UPDATE_PROMPT_INPUT"; index: number; @@ -50,11 +67,60 @@ function reduceReplaceInput( })); } +function reduceInsertPromptAtIndex( + state: ClientAIConfig, + index: number, + prompt: ClientPrompt +): ClientAIConfig { + return { + ...state, + prompts: [ + ...state.prompts.slice(0, index), + prompt, + ...state.prompts.slice(index), + ], + }; +} + +function reduceConsolidateAIConfig( + state: ClientAIConfig, + action: MutateAIConfigAction, + responseConfig: AIConfig +): ClientAIConfig { + switch (action.type) { + case "ADD_PROMPT_AT_INDEX": { + // Make sure prompt structure is properly updated. Client input and metadata takes precedence + // since it may have been updated by the user while the request was in flight + return reduceReplacePrompt(state, action.index, (prompt) => { + const responsePrompt = responseConfig.prompts[action.index]; + return { + ...responsePrompt, + ...prompt, + metadata: { + ...responsePrompt.metadata, + ...prompt.metadata, + }, + } as ClientPrompt; + }); + } + default: { + return state; + } + } +} + export default function aiconfigReducer( state: ClientAIConfig, action: AIConfigReducerAction ): ClientAIConfig { switch (action.type) { + case "ADD_PROMPT_AT_INDEX": { + return reduceInsertPromptAtIndex( + state, + action.index, + action.prompt as ClientPrompt + ); + } case "UPDATE_PROMPT_INPUT": { return reduceReplaceInput(state, action.index, () => action.input); } @@ -84,5 +150,8 @@ export default function aiconfigReducer( }, })); } + case "CONSOLIDATE_AICONFIG": { + return reduceConsolidateAIConfig(state, action.action, action.config); + } } } diff --git a/python/src/aiconfig/editor/client/src/components/prompt/AddPromptButton.tsx b/python/src/aiconfig/editor/client/src/components/prompt/AddPromptButton.tsx new file mode 100644 index 000000000..cd18f0a07 --- /dev/null +++ b/python/src/aiconfig/editor/client/src/components/prompt/AddPromptButton.tsx @@ -0,0 +1,100 @@ +import { ActionIcon, Menu, TextInput } from "@mantine/core"; +import { IconPlus, IconSearch, IconTextCaption } from "@tabler/icons-react"; +import { memo, useCallback, useEffect, useState } from "react"; +import { showNotification } from "@mantine/notifications"; + +type Props = { + addPrompt: (prompt: string) => void; + getModels: (search: string) => Promise; +}; + +function ModelMenuItems({ + models, + onSelectModel, + collapseLimit, +}: { + models: string[]; + onSelectModel: (model: string) => void; + collapseLimit: number; +}) { + const [isCollapsed, setIsCollapsed] = useState(models.length > collapseLimit); + + const displayModels = isCollapsed ? models.slice(0, collapseLimit) : models; + + return ( + <> + {displayModels.map((model) => ( + } + onClick={() => onSelectModel(model)} + > + {model} + + ))} + {isCollapsed && ( + setIsCollapsed(false)}>... + )} + + ); +} + +export default memo(function AddPromptButton({ addPrompt, getModels }: Props) { + const [modelSearch, setModelSearch] = useState(""); + const [models, setModels] = useState([]); + const [isOpen, setIsOpen] = useState(false); + + const loadModels = useCallback(async (modelSearch: string) => { + try { + const models = await getModels(modelSearch); + setModels(models); + } catch (err: any) { + showNotification({ + title: "Error loading models", + message: err?.message, + color: "red", + }); + } + }, []); + + const onAddPrompt = useCallback((model: string) => { + addPrompt(model); + setIsOpen(false); + }, []); + + useEffect(() => { + loadModels(modelSearch); + }, [loadModels, modelSearch]); + + return ( + + + + + + + + + } + placeholder="Search" + value={modelSearch} + onChange={(e) => setModelSearch(e.currentTarget.value)} + /> + + + }>Add New Model + + + ); +}); diff --git a/python/src/aiconfig/editor/client/src/utils/aiconfigStateUtils.ts b/python/src/aiconfig/editor/client/src/utils/aiconfigStateUtils.ts new file mode 100644 index 000000000..95107095c --- /dev/null +++ b/python/src/aiconfig/editor/client/src/utils/aiconfigStateUtils.ts @@ -0,0 +1,10 @@ +import { AIConfig } from "aiconfig"; + +export function getDefaultNewPromptName(aiconfig: AIConfig): string { + const existingNames = aiconfig.prompts.map((prompt) => prompt.name); + let i = existingNames.length + 1; + while (existingNames.includes(`Prompt ${i}`)) { + i++; + } + return `Prompt ${i}`; +} diff --git a/python/src/aiconfig/editor/client/src/utils/api.ts b/python/src/aiconfig/editor/client/src/utils/api.ts index d763b0012..110213b23 100644 --- a/python/src/aiconfig/editor/client/src/utils/api.ts +++ b/python/src/aiconfig/editor/client/src/utils/api.ts @@ -9,5 +9,8 @@ const HOST_ENDPOINT = const API_ENDPOINT = `${HOST_ENDPOINT}/api`; export const ROUTE_TABLE = { + ADD_PROMPT: urlJoin(API_ENDPOINT, "/add_prompt"), + SAVE: urlJoin(API_ENDPOINT, "/save"), LOAD: urlJoin(API_ENDPOINT, "/load"), + LIST_MODELS: urlJoin(API_ENDPOINT, "/list_models"), }; diff --git a/python/src/aiconfig/editor/server/server.py b/python/src/aiconfig/editor/server/server.py index d39a84d15..1ef0576f5 100644 --- a/python/src/aiconfig/editor/server/server.py +++ b/python/src/aiconfig/editor/server/server.py @@ -182,6 +182,7 @@ def _get_op_args(): @app.route("/api/add_prompt", methods=["POST"]) def add_prompt() -> FlaskResponse: method_name = MethodName("add_prompt") + # TODO: Support specifying index signature: dict[str, Type[Any]] = {"prompt_name": str, "prompt_data": Prompt} state = get_server_state(app) diff --git a/python/src/aiconfig/scripts/aiconfig_cli.py b/python/src/aiconfig/scripts/aiconfig_cli.py index 1296f8564..e625c68a0 100644 --- a/python/src/aiconfig/scripts/aiconfig_cli.py +++ b/python/src/aiconfig/scripts/aiconfig_cli.py @@ -34,7 +34,7 @@ async def main(argv: list[str]) -> int: def run_subcommand(argv: list[str]) -> Result[str, str]: LOGGER.info("Running subcommand") subparser_record_types = {"edit": EditServerConfig} - main_parser = core_utils.argparsify(AIConfigCLIConfig, subparser_rs=subparser_record_types) + main_parser = core_utils.argparsify(AIConfigCLIConfig, subparser_record_types=subparser_record_types) res_cli_config = core_utils.parse_args(main_parser, argv[1:], AIConfigCLIConfig) res_cli_config.and_then(_process_cli_config)