From 72fecdc2ec4331e685f781197253da855375e016 Mon Sep 17 00:00:00 2001
From: Ryan Holinshead <>
Date: Tue, 26 Dec 2023 16:44:24 -0500
Subject: [PATCH] [editor] Add Prompt Button
---
python/requirements.txt | 2 +-
.../src/aiconfig/editor/client/src/Editor.tsx | 32 +++++-
.../client/src/components/EditorContainer.tsx | 108 +++++++++++++++---
.../client/src/components/aiconfigReducer.ts | 73 +++++++++++-
.../src/components/prompt/AddPromptButton.tsx | 100 ++++++++++++++++
.../client/src/utils/aiconfigStateUtils.ts | 10 ++
.../aiconfig/editor/client/src/utils/api.ts | 3 +
python/src/aiconfig/editor/server/server.py | 1 +
python/src/aiconfig/scripts/aiconfig_cli.py | 2 +-
9 files changed, 311 insertions(+), 20 deletions(-)
create mode 100644 python/src/aiconfig/editor/client/src/components/prompt/AddPromptButton.tsx
create mode 100644 python/src/aiconfig/editor/client/src/utils/aiconfigStateUtils.ts
diff --git a/python/requirements.txt b/python/requirements.txt
index 11b48e5a5..0e3b7ffde 100644
--- a/python/requirements.txt
+++ b/python/requirements.txt
@@ -6,7 +6,7 @@ flask[async]
google-generativeai
huggingface_hub
hypothesis==6.91.0
-lastmile-utils==0.0.13
+lastmile-utils==0.0.14
mock
nest_asyncio
nltk
diff --git a/python/src/aiconfig/editor/client/src/Editor.tsx b/python/src/aiconfig/editor/client/src/Editor.tsx
index 3fdd9e75f..d5cda9575 100644
--- a/python/src/aiconfig/editor/client/src/Editor.tsx
+++ b/python/src/aiconfig/editor/client/src/Editor.tsx
@@ -1,7 +1,7 @@
import EditorContainer from "./components/EditorContainer";
import { ClientAIConfig } from "./shared/types";
import { Flex, Loader } from "@mantine/core";
-import { AIConfig } from "aiconfig";
+import { AIConfig, Prompt } from "aiconfig";
import { useCallback, useEffect, useState } from "react";
import { ufetch } from "ufetch";
import { ROUTE_TABLE } from "./utils/api";
@@ -20,13 +20,34 @@ export default function Editor() {
}, [loadConfig]);
const onSave = useCallback(async (aiconfig: AIConfig) => {
- const res = await ufetch.post(`/api/aiconfig/save`, {
+ const res = await ufetch.post(ROUTE_TABLE.SAVE, {
// path: file path,
aiconfig,
});
return res;
}, []);
+ const getModels = useCallback(async (search: string) => {
+ // For now, rely on caching and handle client-side search filtering
+ // We will use server-side search filtering for Gradio
+ const res = await ufetch.get(ROUTE_TABLE.LIST_MODELS);
+ const models = res.data;
+ if (search && search.length > 0) {
+ return models.filter((model: string) => model.indexOf(search) >= 0);
+ }
+ return models;
+ }, []);
+
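+ // Posts the new prompt to the server's /api/add_prompt endpoint; the response carries the updated AIConfig used to reconcile client state.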
+ const addPrompt = useCallback(
+ async (promptName: string, promptData: Prompt) => {
+ return await ufetch.post(ROUTE_TABLE.ADD_PROMPT, {
+ prompt_name: promptName,
+ prompt_data: promptData,
+ });
+ },
+ []
+ );
+
return (
{!aiconfig ? (
@@ -34,7 +55,12 @@ export default function Editor() {
) : (
-        <EditorContainer aiconfig={aiconfig} onSave={onSave} />
+        <EditorContainer
+          aiconfig={aiconfig}
+          addPrompt={addPrompt}
+          getModels={getModels}
+          onSave={onSave}
+        />
)}
);
diff --git a/python/src/aiconfig/editor/client/src/components/EditorContainer.tsx b/python/src/aiconfig/editor/client/src/components/EditorContainer.tsx
index 0545f1b57..6c958d2f1 100644
--- a/python/src/aiconfig/editor/client/src/components/EditorContainer.tsx
+++ b/python/src/aiconfig/editor/client/src/components/EditorContainer.tsx
@@ -1,17 +1,45 @@
import PromptContainer from "./prompt/PromptContainer";
-import { Container, Group, Button, createStyles } from "@mantine/core";
+import { Container, Group, Button, createStyles, Stack } from "@mantine/core";
import { showNotification } from "@mantine/notifications";
-import { AIConfig, PromptInput } from "aiconfig";
-import { useCallback, useReducer, useState } from "react";
-import aiconfigReducer from "./aiconfigReducer";
+import { AIConfig, Prompt, PromptInput } from "aiconfig";
+import { useCallback, useReducer, useRef, useState } from "react";
+import aiconfigReducer, { AIConfigReducerAction } from "./aiconfigReducer";
import { ClientAIConfig, clientConfigToAIConfig } from "../shared/types";
+import AddPromptButton from "./prompt/AddPromptButton";
+import { getDefaultNewPromptName } from "../utils/aiconfigStateUtils";
type Props = {
aiconfig: ClientAIConfig;
+ addPrompt: (
+ promptName: string,
+ prompt: Prompt
+ ) => Promise<{ aiconfig: AIConfig }>;
onSave: (aiconfig: AIConfig) => Promise<void>;
+ getModels: (search: string) => Promise<string[]>;
};
const useStyles = createStyles((theme) => ({
+ addPromptRow: {
+ borderRadius: "4px",
+ display: "inline-block",
+ bottom: -24,
+ left: -40,
+ "&:hover": {
+ backgroundColor:
+ theme.colorScheme === "light"
+ ? theme.colors.gray[1]
+ : "rgba(255, 255, 255, 0.1)",
+ },
+ [theme.fn.smallerThan("sm")]: {
+ marginLeft: "0",
+ display: "block",
+ position: "static",
+ bottom: -10,
+ left: 0,
+ height: 28,
+ margin: "10px 0",
+ },
+ },
promptsContainer: {
[theme.fn.smallerThan("sm")]: {
padding: "0 0 200px 0",
@@ -22,7 +50,9 @@ const useStyles = createStyles((theme) => ({
export default function EditorContainer({
aiconfig: initialAIConfig,
+ addPrompt,
onSave,
+ getModels,
}: Props) {
const [isSaving, setIsSaving] = useState(false);
const [aiconfigState, dispatch] = useReducer(
@@ -30,6 +60,9 @@ export default function EditorContainer({
initialAIConfig
);
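+ // Mirror the latest reducer state in a ref so async callbacks (e.g. onAddPrompt) read fresh state rather than a stale closure.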
+ const stateRef = useRef(aiconfigState);
+ stateRef.current = aiconfigState;
+
const save = useCallback(async () => {
setIsSaving(true);
try {
@@ -81,6 +114,43 @@ export default function EditorContainer({
[dispatch]
);
+ const onAddPrompt = useCallback(
+ async (promptIndex: number, model: string) => {
+ const promptName = getDefaultNewPromptName(stateRef.current as AIConfig);
+ const newPrompt: Prompt = {
+ name: promptName,
+ input: "", // TODO: Can we use schema to get input structure, string vs object?
+ metadata: {
+ model,
+ },
+ };
+
+ const action: AIConfigReducerAction = {
+ type: "ADD_PROMPT_AT_INDEX",
+ index: promptIndex,
+ prompt: newPrompt,
+ };
+
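+ // Optimistically insert the prompt into local state, then reconcile with the server-confirmed config via CONSOLIDATE_AICONFIG.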
+ dispatch(action);
+
+ try {
+ const serverConfigRes = await addPrompt(promptName, newPrompt);
+ dispatch({
+ type: "CONSOLIDATE_AICONFIG",
+ action,
+ config: serverConfigRes.aiconfig,
+ });
+ } catch (err: any) {
+ showNotification({
+ title: "Error adding prompt to config",
+ message: err.message,
+ color: "red",
+ });
+ }
+ },
+ [addPrompt, dispatch]
+ );
+
const { classes } = useStyles();
// TODO: Implement editor context for callbacks, readonly state, etc.
@@ -100,15 +170,27 @@ export default function EditorContainer({
{aiconfigState.prompts.map((prompt: any, i: number) => {
return (
-              <PromptContainer index={i} prompt={prompt} />
+              <Stack key={prompt.name}>
+                <PromptContainer index={i} prompt={prompt} />
+                <div className={classes.addPromptRow}>
+                  <AddPromptButton
+                    getModels={getModels}
+                    addPrompt={(model: string) =>
+                      onAddPrompt(
+                        i + 1 /* insert below current prompt index */,
+                        model
+                      )
+                    }
+                  />
+                </div>
+              </Stack>
);
})}
diff --git a/python/src/aiconfig/editor/client/src/components/aiconfigReducer.ts b/python/src/aiconfig/editor/client/src/components/aiconfigReducer.ts
index 94e05dc7e..80f05c841 100644
--- a/python/src/aiconfig/editor/client/src/components/aiconfigReducer.ts
+++ b/python/src/aiconfig/editor/client/src/components/aiconfigReducer.ts
@@ -1,12 +1,29 @@
import { ClientAIConfig, ClientPrompt } from "../shared/types";
import { getPromptModelName } from "../utils/promptUtils";
-import { AIConfig, JSONObject, PromptInput } from "aiconfig";
+import { AIConfig, JSONObject, Prompt, PromptInput } from "aiconfig";
-type AIConfigReducerAction =
+export type AIConfigReducerAction =
+ | MutateAIConfigAction
+ | ConsolidateAIConfigAction;
+
+export type MutateAIConfigAction =
+ | AddPromptAction
| UpdatePromptInputAction
| UpdatePromptModelSettingsAction
| UpdatePromptParametersAction;
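+// Dispatched after a mutation's server request resolves; carries the server-returned config to merge into client state.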
+export type ConsolidateAIConfigAction = {
+ type: "CONSOLIDATE_AICONFIG";
+ action: MutateAIConfigAction;
+ config: AIConfig;
+};
+
+export type AddPromptAction = {
+ type: "ADD_PROMPT_AT_INDEX";
+ index: number;
+ prompt: Prompt;
+};
+
export type UpdatePromptInputAction = {
type: "UPDATE_PROMPT_INPUT";
index: number;
@@ -50,11 +67,60 @@ function reduceReplaceInput(
}));
}
+function reduceInsertPromptAtIndex(
+ state: ClientAIConfig,
+ index: number,
+ prompt: ClientPrompt
+): ClientAIConfig {
+ return {
+ ...state,
+ prompts: [
+ ...state.prompts.slice(0, index),
+ prompt,
+ ...state.prompts.slice(index),
+ ],
+ };
+}
+
+function reduceConsolidateAIConfig(
+ state: ClientAIConfig,
+ action: MutateAIConfigAction,
+ responseConfig: AIConfig
+): ClientAIConfig {
+ switch (action.type) {
+ case "ADD_PROMPT_AT_INDEX": {
+ // Make sure prompt structure is properly updated. Client input and metadata take precedence
+ // since they may have been updated by the user while the request was in flight.
+ return reduceReplacePrompt(state, action.index, (prompt) => {
+ const responsePrompt = responseConfig.prompts[action.index];
+ return {
+ ...responsePrompt,
+ ...prompt,
+ metadata: {
+ ...responsePrompt.metadata,
+ ...prompt.metadata,
+ },
+ } as ClientPrompt;
+ });
+ }
+ default: {
+ return state;
+ }
+ }
+}
+
export default function aiconfigReducer(
state: ClientAIConfig,
action: AIConfigReducerAction
): ClientAIConfig {
switch (action.type) {
+ case "ADD_PROMPT_AT_INDEX": {
+ return reduceInsertPromptAtIndex(
+ state,
+ action.index,
+ action.prompt as ClientPrompt
+ );
+ }
case "UPDATE_PROMPT_INPUT": {
return reduceReplaceInput(state, action.index, () => action.input);
}
@@ -84,5 +150,8 @@ export default function aiconfigReducer(
},
}));
}
+ case "CONSOLIDATE_AICONFIG": {
+ return reduceConsolidateAIConfig(state, action.action, action.config);
+ }
}
}
diff --git a/python/src/aiconfig/editor/client/src/components/prompt/AddPromptButton.tsx b/python/src/aiconfig/editor/client/src/components/prompt/AddPromptButton.tsx
new file mode 100644
index 000000000..cd18f0a07
--- /dev/null
+++ b/python/src/aiconfig/editor/client/src/components/prompt/AddPromptButton.tsx
@@ -0,0 +1,100 @@
+import { ActionIcon, Menu, TextInput } from "@mantine/core";
+import { IconPlus, IconSearch, IconTextCaption } from "@tabler/icons-react";
+import { memo, useCallback, useEffect, useState } from "react";
+import { showNotification } from "@mantine/notifications";
+
+type Props = {
+ addPrompt: (prompt: string) => void;
+ getModels: (search: string) => Promise<string[]>;
+};
+
+function ModelMenuItems({
+ models,
+ onSelectModel,
+ collapseLimit,
+}: {
+ models: string[];
+ onSelectModel: (model: string) => void;
+ collapseLimit: number;
+}) {
+ const [isCollapsed, setIsCollapsed] = useState(models.length > collapseLimit);
+
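+ // Render only the first collapseLimit models until the user expands the list via the trailing "..." item.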
+ const displayModels = isCollapsed ? models.slice(0, collapseLimit) : models;
+
+  return (
+    <>
+      {displayModels.map((model) => (
+        <Menu.Item
+          key={model}
+          icon={<IconTextCaption size="1rem" />}
+          onClick={() => onSelectModel(model)}
+        >
+          {model}
+        </Menu.Item>
+      ))}
+      {isCollapsed && (
+        <Menu.Item onClick={() => setIsCollapsed(false)}>...</Menu.Item>
+      )}
+    </>
+  );
+}
+
+export default memo(function AddPromptButton({ addPrompt, getModels }: Props) {
+ const [modelSearch, setModelSearch] = useState("");
+ const [models, setModels] = useState<string[]>([]);
+ const [isOpen, setIsOpen] = useState(false);
+
+ const loadModels = useCallback(async (modelSearch: string) => {
+ try {
+ const models = await getModels(modelSearch);
+ setModels(models);
+ } catch (err: any) {
+ showNotification({
+ title: "Error loading models",
+ message: err?.message,
+ color: "red",
+ });
+ }
+ }, [getModels]);
+
+ const onAddPrompt = useCallback((model: string) => {
+ addPrompt(model);
+ setIsOpen(false);
+ }, [addPrompt]);
+
+ useEffect(() => {
+ loadModels(modelSearch);
+ }, [loadModels, modelSearch]);
+
+  return (
+    <Menu position="bottom" opened={isOpen} onChange={setIsOpen}>
+      <Menu.Target>
+        <ActionIcon>
+          <IconPlus size="1rem" />
+        </ActionIcon>
+      </Menu.Target>
+      <Menu.Dropdown>
+        <TextInput
+          icon={<IconSearch size="1rem" />}
+          placeholder="Search models"
+          value={modelSearch}
+          onChange={(event) => setModelSearch(event.currentTarget.value)}
+        />
+        <ModelMenuItems
+          models={models}
+          collapseLimit={5}
+          onSelectModel={onAddPrompt}
+        />
+      </Menu.Dropdown>
+    </Menu>
+  );
+});
diff --git a/python/src/aiconfig/editor/client/src/utils/aiconfigStateUtils.ts b/python/src/aiconfig/editor/client/src/utils/aiconfigStateUtils.ts
new file mode 100644
index 000000000..95107095c
--- /dev/null
+++ b/python/src/aiconfig/editor/client/src/utils/aiconfigStateUtils.ts
@@ -0,0 +1,10 @@
+import { AIConfig } from "aiconfig";
+
+export function getDefaultNewPromptName(aiconfig: AIConfig): string {
+ const existingNames = aiconfig.prompts.map((prompt) => prompt.name);
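+ // Start from "Prompt {N+1}" for N existing prompts and increment until the name does not collide with an existing prompt.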
+ let i = existingNames.length + 1;
+ while (existingNames.includes(`Prompt ${i}`)) {
+ i++;
+ }
+ return `Prompt ${i}`;
+}
diff --git a/python/src/aiconfig/editor/client/src/utils/api.ts b/python/src/aiconfig/editor/client/src/utils/api.ts
index d763b0012..110213b23 100644
--- a/python/src/aiconfig/editor/client/src/utils/api.ts
+++ b/python/src/aiconfig/editor/client/src/utils/api.ts
@@ -9,5 +9,8 @@ const HOST_ENDPOINT =
const API_ENDPOINT = `${HOST_ENDPOINT}/api`;
export const ROUTE_TABLE = {
+ ADD_PROMPT: urlJoin(API_ENDPOINT, "/add_prompt"),
+ SAVE: urlJoin(API_ENDPOINT, "/save"),
LOAD: urlJoin(API_ENDPOINT, "/load"),
+ LIST_MODELS: urlJoin(API_ENDPOINT, "/list_models"),
};
diff --git a/python/src/aiconfig/editor/server/server.py b/python/src/aiconfig/editor/server/server.py
index d39a84d15..1ef0576f5 100644
--- a/python/src/aiconfig/editor/server/server.py
+++ b/python/src/aiconfig/editor/server/server.py
@@ -182,6 +182,7 @@ def _get_op_args():
@app.route("/api/add_prompt", methods=["POST"])
def add_prompt() -> FlaskResponse:
method_name = MethodName("add_prompt")
+ # TODO: Support specifying index
signature: dict[str, Type[Any]] = {"prompt_name": str, "prompt_data": Prompt}
state = get_server_state(app)
diff --git a/python/src/aiconfig/scripts/aiconfig_cli.py b/python/src/aiconfig/scripts/aiconfig_cli.py
index 1296f8564..e625c68a0 100644
--- a/python/src/aiconfig/scripts/aiconfig_cli.py
+++ b/python/src/aiconfig/scripts/aiconfig_cli.py
@@ -34,7 +34,7 @@ async def main(argv: list[str]) -> int:
def run_subcommand(argv: list[str]) -> Result[str, str]:
LOGGER.info("Running subcommand")
subparser_record_types = {"edit": EditServerConfig}
- main_parser = core_utils.argparsify(AIConfigCLIConfig, subparser_rs=subparser_record_types)
+ main_parser = core_utils.argparsify(AIConfigCLIConfig, subparser_record_types=subparser_record_types)
res_cli_config = core_utils.parse_args(main_parser, argv[1:], AIConfigCLIConfig)
res_cli_config.and_then(_process_cli_config)