diff --git a/python/src/aiconfig/editor/client/src/Editor.tsx b/python/src/aiconfig/editor/client/src/Editor.tsx index 26748e394..7fcef67d1 100644 --- a/python/src/aiconfig/editor/client/src/Editor.tsx +++ b/python/src/aiconfig/editor/client/src/Editor.tsx @@ -2,7 +2,7 @@ import EditorContainer, { AIConfigCallbacks, } from "./components/EditorContainer"; import { Flex, Loader, MantineProvider } from "@mantine/core"; -import { AIConfig, Prompt } from "aiconfig"; +import { AIConfig, ModelMetadata, Prompt } from "aiconfig"; import { useCallback, useEffect, useMemo, useState } from "react"; import { ufetch } from "ufetch"; import { ROUTE_TABLE } from "./utils/api"; @@ -50,6 +50,12 @@ export default function Editor() { [] ); + const deletePrompt = useCallback(async (promptName: string) => { + return await ufetch.post(ROUTE_TABLE.DELETE_PROMPT, { + prompt_name: promptName, + }); + }, []); + const runPrompt = useCallback(async (promptName: string) => { return await ufetch.post(ROUTE_TABLE.RUN_PROMPT, { prompt_name: promptName, @@ -66,12 +72,24 @@ export default function Editor() { [] ); + const updateModel = useCallback( + async (_promptName?: string, _modelData?: string | ModelMetadata) => { + // return await ufetch.post(ROUTE_TABLE.UPDATE_MODEL, + // prompt_name: promptName, + // model_data: modelData, + // }); + }, + [] + ); + const callbacks: AIConfigCallbacks = useMemo( () => ({ addPrompt, + deletePrompt, getModels, runPrompt, save, + updateModel, updatePrompt, }), [save, getModels, addPrompt, runPrompt] diff --git a/python/src/aiconfig/editor/client/src/components/EditorContainer.tsx b/python/src/aiconfig/editor/client/src/components/EditorContainer.tsx index 2536e9795..c3c2d12d6 100644 --- a/python/src/aiconfig/editor/client/src/components/EditorContainer.tsx +++ b/python/src/aiconfig/editor/client/src/components/EditorContainer.tsx @@ -1,7 +1,14 @@ import PromptContainer from "./prompt/PromptContainer"; -import { Container, Group, Button, createStyles, Stack } from "@mantine/core"; +import { + Container, + Group, + Button, + createStyles, + Stack, + Flex, +} from "@mantine/core"; import { showNotification } from "@mantine/notifications"; -import { AIConfig, Prompt, PromptInput } from "aiconfig"; +import { AIConfig, ModelMetadata, Prompt, PromptInput } from "aiconfig"; import { useCallback, useMemo, useReducer, useRef, useState } from "react"; import aiconfigReducer, { AIConfigReducerAction } from "./aiconfigReducer"; import { @@ -11,8 +18,12 @@ import { clientPromptToAIConfigPrompt, } from "../shared/types"; import AddPromptButton from "./prompt/AddPromptButton"; -import { getDefaultNewPromptName } from "../utils/aiconfigStateUtils"; +import { + getDefaultNewPromptName, + getPrompt, +} from "../utils/aiconfigStateUtils"; import { debounce, uniqueId } from "lodash"; +import PromptMenuButton from "./prompt/PromptMenuButton"; type Props = { aiconfig: AIConfig; @@ -25,9 +36,14 @@ export type AIConfigCallbacks = { prompt: Prompt, index: number ) => Promise<{ aiconfig: AIConfig }>; + deletePrompt: (promptName: string) => Promise; getModels: (search: string) => Promise; runPrompt: (promptName: string) => Promise; save: (aiconfig: AIConfig) => Promise; + updateModel: ( + promptName?: string, + modelData?: string | ModelMetadata + ) => Promise; updatePrompt: ( promptName: string, promptData: Prompt @@ -80,7 +96,7 @@ export default function EditorContainer({ const onSave = useCallback(async () => { setIsSaving(true); try { - await callbacks.save(clientConfigToAIConfig(aiconfigState)); + await callbacks.save(clientConfigToAIConfig(stateRef.current)); } catch (err: any) { showNotification({ title: "Error saving", @@ -90,7 +106,7 @@ export default function EditorContainer({ } finally { setIsSaving(false); } - }, [aiconfigState, callbacks.save]); + }, [callbacks.save]); const debouncedUpdatePrompt = useMemo( () => @@ -103,10 +119,10 @@ export default function EditorContainer({ ); const onChangePromptInput = useCallback( - async (promptIndex: number, newPromptInput: PromptInput) => { + async (promptId: string, newPromptInput: PromptInput) => { const action: AIConfigReducerAction = { type: "UPDATE_PROMPT_INPUT", - index: promptIndex, + id: promptId, input: newPromptInput, }; @@ -114,7 +130,7 @@ export default function EditorContainer({ try { const prompt = clientPromptToAIConfigPrompt( - aiconfigState.prompts[promptIndex] + getPrompt(stateRef.current, promptId)! ); const serverConfigRes = await debouncedUpdatePrompt(prompt.name, { ...prompt, @@ -138,10 +154,10 @@ export default function EditorContainer({ ); const onChangePromptName = useCallback( - async (promptIndex: number, newName: string) => { + async (promptId: string, newName: string) => { const action: AIConfigReducerAction = { type: "UPDATE_PROMPT_NAME", - index: promptIndex, + id: promptId, name: newName, }; @@ -150,23 +166,68 @@ export default function EditorContainer({ [dispatch] ); + const debouncedUpdateModel = useMemo( + () => + debounce( + (promptName?: string, modelMetadata?: string | ModelMetadata) => + callbacks.updateModel(promptName, modelMetadata), + 250 + ), + [callbacks.updateModel] + ); + const onUpdatePromptModelSettings = useCallback( - async (promptIndex: number, newModelSettings: any) => { + async (promptId: string, newModelSettings: any) => { dispatch({ type: "UPDATE_PROMPT_MODEL_SETTINGS", - index: promptIndex, + id: promptId, modelSettings: newModelSettings, }); - // TODO: Call server-side endpoint to update model settings + // TODO: Call server-side endpoint to update model }, [dispatch] ); + const onUpdatePromptModel = useCallback( + async (promptId: string, newModel?: string) => { + dispatch({ + type: "UPDATE_PROMPT_MODEL", + id: promptId, + modelName: newModel, + }); + + try { + const prompt = clientPromptToAIConfigPrompt( + getPrompt(stateRef.current, promptId)! + ); + const currentModel = prompt.metadata?.model; + let modelData: string | ModelMetadata | undefined = newModel; + if (newModel && currentModel && typeof currentModel !== "string") { + modelData = { + ...currentModel, + name: newModel, + }; + } + + await debouncedUpdateModel(prompt.name, modelData); + + // TODO: Consolidate + } catch (err: any) { + showNotification({ + title: "Error updating prompt model", + message: err.message, + color: "red", + }); + } + }, + [dispatch, debouncedUpdateModel] + ); + const onUpdatePromptParameters = useCallback( - async (promptIndex: number, newParameters: any) => { + async (promptId: string, newParameters: any) => { dispatch({ type: "UPDATE_PROMPT_PARAMETERS", - index: promptIndex, + id: promptId, parameters: newParameters, }); // TODO: Call server-side endpoint to update prompt parameters @@ -223,9 +284,30 @@ export default function EditorContainer({ [callbacks.addPrompt, dispatch] ); + const onDeletePrompt = useCallback( + async (promptId: string) => { + dispatch({ + type: "DELETE_PROMPT", + id: promptId, + }); + + try { + const prompt = getPrompt(stateRef.current, promptId)!; + await callbacks.deletePrompt(prompt.name); + } catch (err: any) { + showNotification({ + title: "Error deleting prompt", + message: err.message, + color: "red", + }); + } + }, + [callbacks.deletePrompt, dispatch] + ); + const onRunPrompt = useCallback( - async (promptIndex: number) => { - const promptName = aiconfigState.prompts[promptIndex].name; + async (promptId: string) => { + const promptName = getPrompt(stateRef.current, promptId)!.name; try { await callbacks.runPrompt(promptName); } catch (err: any) { @@ -241,8 +323,6 @@ export default function EditorContainer({ const { classes } = useStyles(); - // TODO: Implement editor context for callbacks, readonly state, etc. - return ( <> @@ -259,16 +339,23 @@ export default function EditorContainer({ {aiconfigState.prompts.map((prompt: ClientPrompt, i: number) => { return ( - + + onDeletePrompt(prompt._ui.id)} + /> + +
ClientPrompt ): ClientAIConfig { return { ...state, - prompts: state.prompts.map((prompt, i) => - i === index ? replacerFn(prompt) : prompt + prompts: state.prompts.map((prompt) => + prompt._ui.id === id ? replacerFn(prompt) : prompt ), }; } function reduceReplaceInput( state: ClientAIConfig, - index: number, + id: string, replacerFn: (input: PromptInput) => PromptInput ): ClientAIConfig { - return reduceReplacePrompt(state, index, (prompt) => ({ + return reduceReplacePrompt(state, id, (prompt) => ({ ...prompt, input: replacerFn(prompt.input), })); @@ -94,22 +107,32 @@ function reduceConsolidateAIConfig( action: MutateAIConfigAction, responseConfig: AIConfig ): ClientAIConfig { + // Make sure prompt structure is properly updated. Client input and metadata takes precedence + // since it may have been updated by the user while the request was in flight + const consolidatePrompt = (statePrompt: ClientPrompt) => { + const responsePrompt = responseConfig.prompts.find( + (resPrompt) => resPrompt.name === statePrompt.name + ); + return { + ...responsePrompt, + ...statePrompt, + metadata: { + ...responsePrompt!.metadata, + ...statePrompt.metadata, + }, + } as ClientPrompt; + }; + switch (action.type) { - case "ADD_PROMPT_AT_INDEX": + case "ADD_PROMPT_AT_INDEX": { + return reduceReplacePrompt( + state, + action.prompt._ui.id, + consolidatePrompt + ); + } case "UPDATE_PROMPT_INPUT": { - // Make sure prompt structure is properly updated. Client input and metadata takes precedence - // since it may have been updated by the user while the request was in flight - return reduceReplacePrompt(state, action.index, (prompt) => { - const responsePrompt = responseConfig.prompts[action.index]; - return { - ...responsePrompt, - ...prompt, - metadata: { - ...responsePrompt.metadata, - ...prompt.metadata, - }, - } as ClientPrompt; - }); + return reduceReplacePrompt(state, action.id, consolidatePrompt); } default: { return state; @@ -125,17 +148,37 @@ export default function aiconfigReducer( case "ADD_PROMPT_AT_INDEX": { return reduceInsertPromptAtIndex(state, action.index, action.prompt); } + case "DELETE_PROMPT": { + return { + ...state, + prompts: state.prompts.filter((prompt) => prompt._ui.id !== action.id), + }; + } case "UPDATE_PROMPT_INPUT": { - return reduceReplaceInput(state, action.index, () => action.input); + return reduceReplaceInput(state, action.id, () => action.input); } case "UPDATE_PROMPT_NAME": { - return reduceReplacePrompt(state, action.index, (prompt) => ({ + return reduceReplacePrompt(state, action.id, (prompt) => ({ ...prompt, name: action.name, })); } + case "UPDATE_PROMPT_MODEL": { + return reduceReplacePrompt(state, action.id, (prompt) => ({ + ...prompt, + metadata: { + ...prompt.metadata, + model: action.modelName + ? { + name: action.modelName, + // TODO: Consolidate settings based on schema union + } + : undefined, + }, + })); + } case "UPDATE_PROMPT_MODEL_SETTINGS": { - return reduceReplacePrompt(state, action.index, (prompt) => ({ + return reduceReplacePrompt(state, action.id, (prompt) => ({ ...prompt, metadata: { ...prompt.metadata, @@ -152,7 +195,7 @@ export default function aiconfigReducer( })); } case "UPDATE_PROMPT_PARAMETERS": { - return reduceReplacePrompt(state, action.index, (prompt) => ({ + return reduceReplacePrompt(state, action.id, (prompt) => ({ ...prompt, metadata: { ...prompt.metadata, diff --git a/python/src/aiconfig/editor/client/src/components/prompt/AddPromptButton.tsx b/python/src/aiconfig/editor/client/src/components/prompt/AddPromptButton.tsx index cd18f0a07..d15380eb6 100644 --- a/python/src/aiconfig/editor/client/src/components/prompt/AddPromptButton.tsx +++ b/python/src/aiconfig/editor/client/src/components/prompt/AddPromptButton.tsx @@ -1,7 +1,7 @@ import { ActionIcon, Menu, TextInput } from "@mantine/core"; import { IconPlus, IconSearch, IconTextCaption } from "@tabler/icons-react"; -import { memo, useCallback, useEffect, useState } from "react"; -import { showNotification } from "@mantine/notifications"; +import { memo, useCallback, useState } from "react"; +import useLoadModels from "../../hooks/useLoadModels"; type Props = { addPrompt: (prompt: string) => void; @@ -41,30 +41,14 @@ function ModelMenuItems({ export default memo(function AddPromptButton({ addPrompt, getModels }: Props) { const [modelSearch, setModelSearch] = useState(""); - const [models, setModels] = useState([]); const [isOpen, setIsOpen] = useState(false); - const loadModels = useCallback(async (modelSearch: string) => { - try { - const models = await getModels(modelSearch); - setModels(models); - } catch (err: any) { - showNotification({ - title: "Error loading models", - message: err?.message, - color: "red", - }); - } - }, []); - const onAddPrompt = useCallback((model: string) => { addPrompt(model); setIsOpen(false); }, []); - useEffect(() => { - loadModels(modelSearch); - }, [loadModels, modelSearch]); + const models = useLoadModels(modelSearch, getModels); return ( Promise; + onSetModel: (model?: string) => void; + defaultConfigModelName?: string; +}; + +export default memo(function ModelSelector({ + prompt, + getModels, + onSetModel, + defaultConfigModelName, +}: Props) { + const [selectedModel, setSelectedModel] = useState( + getPromptModelName(prompt, defaultConfigModelName) + ); + const [showAll, setShowAll] = useState(true); + const [autocompleteSearch, setAutocompleteSearch] = useState( + getPromptModelName(prompt, defaultConfigModelName) + ); + + const models = useLoadModels(showAll ? "" : autocompleteSearch, getModels); + + const onSelectModel = (model?: string) => { + setSelectedModel(model); + onSetModel(model); + }; + + return ( + { + onSelectModel(undefined); + setShowAll(true); + setAutocompleteSearch(""); + }} + > + + + ) : null + } + filter={(searchValue: string, item: AutocompleteItem) => { + if (showAll) { + return true; + } + + const modelName: string = item.value; + return modelName + .toLocaleLowerCase() + .includes(searchValue.toLocaleLowerCase().trim()); + }} + data={models} + value={autocompleteSearch} + onChange={(value: string) => { + setAutocompleteSearch(value); + setShowAll(false); + onSelectModel(value); + models.some((model) => { + if (model === value) { + setShowAll(true); + return true; + } + }); + }} + /> + ); +}); diff --git a/python/src/aiconfig/editor/client/src/components/prompt/PromptContainer.tsx b/python/src/aiconfig/editor/client/src/components/prompt/PromptContainer.tsx index de38a03ed..2f4acb7f5 100644 --- a/python/src/aiconfig/editor/client/src/components/prompt/PromptContainer.tsx +++ b/python/src/aiconfig/editor/client/src/components/prompt/PromptContainer.tsx @@ -2,25 +2,27 @@ import PromptActionBar from "./PromptActionBar"; import PromptInputRenderer from "./prompt_input/PromptInputRenderer"; import PromptOutputsRenderer from "./prompt_outputs/PromptOutputsRenderer"; import { ClientPrompt } from "../../shared/types"; -import { getPromptModelName, getPromptSchema } from "../../utils/promptUtils"; -import { Flex, Card, Text, createStyles } from "@mantine/core"; +import { getPromptSchema } from "../../utils/promptUtils"; +import { Flex, Card, createStyles } from "@mantine/core"; import { PromptInput as AIConfigPromptInput } from "aiconfig"; import { memo, useCallback } from "react"; import { ParametersArray } from "../ParametersRenderer"; import PromptOutputBar from "./PromptOutputBar"; import PromptName from "./PromptName"; +import ModelSelector from "./ModelSelector"; type Props = { - index: number; prompt: ClientPrompt; + getModels: (search: string) => Promise; onChangePromptInput: ( - promptIndex: number, + promptId: string, newPromptInput: AIConfigPromptInput ) => void; - onChangePromptName: (promptIndex: number, newName: string) => void; - onRunPrompt(promptIndex: number): Promise; - onUpdateModelSettings: (promptIndex: number, newModelSettings: any) => void; - onUpdateParameters: (promptIndex: number, newParameters: any) => void; + onChangePromptName: (promptId: string, newName: string) => void; + onRunPrompt(promptId: string): Promise; + onUpdateModel: (promptId: string, newModel?: string) => void; + onUpdateModelSettings: (ppromptId: string, newModelSettings: any) => void; + onUpdateParameters: (promptId: string, newParameters: any) => void; defaultConfigModelName?: string; }; @@ -41,27 +43,30 @@ const useStyles = createStyles((theme) => ({ export default memo(function PromptContainer({ prompt, - index, + getModels, onChangePromptInput, onChangePromptName, defaultConfigModelName, onRunPrompt, + onUpdateModel, onUpdateModelSettings, onUpdateParameters, }: Props) { + const promptId = prompt._ui.id; const onChangeInput = useCallback( - (newInput: AIConfigPromptInput) => onChangePromptInput(index, newInput), - [index, onChangePromptInput] + (newInput: AIConfigPromptInput) => onChangePromptInput(promptId, newInput), + [promptId, onChangePromptInput] ); const onChangeName = useCallback( - (newName: string) => onChangePromptName(index, newName), - [index, onChangePromptName] + (newName: string) => onChangePromptName(promptId, newName), + [promptId, onChangePromptName] ); const updateModelSettings = useCallback( - (newModelSettings: any) => onUpdateModelSettings(index, newModelSettings), - [index, onUpdateModelSettings] + (newModelSettings: any) => + onUpdateModelSettings(promptId, newModelSettings), + [promptId, onUpdateModelSettings] ); const updateParameters = useCallback( @@ -77,19 +82,24 @@ export default memo(function PromptContainer({ newParameters[key] = val; } - onUpdateParameters(index, newParameters); + onUpdateParameters(promptId, newParameters); }, - [index, onUpdateParameters] + [promptId, onUpdateParameters] ); const runPrompt = useCallback( - async () => await onRunPrompt(index), - [index, onRunPrompt] + async () => await onRunPrompt(promptId), + [promptId, onRunPrompt] + ); + + const updateModel = useCallback( + (model?: string) => onUpdateModel(promptId, model), + [promptId, onUpdateModel] ); // TODO: When adding support for custom PromptContainers, implement a PromptContainerRenderer which - // will take in the index and callback and render the appropriate PromptContainer with new memoized - // callback and not having to pass index down to PromptContainer + // will take in the promptId and callback and render the appropriate PromptContainer with new memoized + // callback and not having to pass promptId down to PromptContainer const promptSchema = getPromptSchema(prompt, defaultConfigModelName); const inputSchema = promptSchema?.input; @@ -97,12 +107,17 @@ export default memo(function PromptContainer({ const { classes } = useStyles(); return ( - + - {getPromptModelName(prompt, defaultConfigModelName)} + ({ + promptMenuButton: { + marginLeft: -8, + }, +})); + +export default memo(function PromptMenuButton({ + promptId, + onDeletePrompt, +}: { + promptId: string; + onDeletePrompt: (id: string) => void; +}) { + const { classes } = useStyles(); + + return ( + + + + + + + } + color="red" + onClick={() => onDeletePrompt(promptId)} + > + Delete Prompt + + + + ); +}); diff --git a/python/src/aiconfig/editor/client/src/hooks/useLoadModels.ts b/python/src/aiconfig/editor/client/src/hooks/useLoadModels.ts new file mode 100644 index 000000000..0da3b8a5d --- /dev/null +++ b/python/src/aiconfig/editor/client/src/hooks/useLoadModels.ts @@ -0,0 +1,31 @@ +import { useCallback, useEffect, useState } from "react"; +import { showNotification } from "@mantine/notifications"; + +export default function useLoadModels( + modelSearch: string, + getModels: (search: string) => Promise +) { + const [models, setModels] = useState([]); + + const loadModels = useCallback( + async (modelSearch: string) => { + try { + const models = await getModels(modelSearch); + setModels(models); + } catch (err: any) { + showNotification({ + title: "Error loading models", + message: err?.message, + color: "red", + }); + } + }, + [getModels] + ); + + useEffect(() => { + loadModels(modelSearch); + }, [loadModels, modelSearch]); + + return models; +} diff --git a/python/src/aiconfig/editor/client/src/utils/aiconfigStateUtils.ts b/python/src/aiconfig/editor/client/src/utils/aiconfigStateUtils.ts index 3e25d0a4b..56b5641bf 100644 --- a/python/src/aiconfig/editor/client/src/utils/aiconfigStateUtils.ts +++ b/python/src/aiconfig/editor/client/src/utils/aiconfigStateUtils.ts @@ -1,4 +1,5 @@ import { AIConfig } from "aiconfig"; +import { ClientAIConfig, ClientPrompt } from "../shared/types"; export function getDefaultNewPromptName(aiconfig: AIConfig): string { const existingNames = aiconfig.prompts.map((prompt) => prompt.name); @@ -8,3 +9,10 @@ export function getDefaultNewPromptName(aiconfig: AIConfig): string { } return `prompt_${i}`; } + +export function getPrompt( + aiconfig: ClientAIConfig, + id: string +): ClientPrompt | undefined { + return aiconfig.prompts.find((prompt) => prompt._ui.id === id); +} diff --git a/python/src/aiconfig/editor/client/src/utils/api.ts b/python/src/aiconfig/editor/client/src/utils/api.ts index 9959e039f..44c7e49b1 100644 --- a/python/src/aiconfig/editor/client/src/utils/api.ts +++ b/python/src/aiconfig/editor/client/src/utils/api.ts @@ -10,9 +10,11 @@ const API_ENDPOINT = `${HOST_ENDPOINT}/api`; export const ROUTE_TABLE = { ADD_PROMPT: urlJoin(API_ENDPOINT, "/add_prompt"), + DELETE_PROMPT: urlJoin(API_ENDPOINT, "/delete_prompt"), SAVE: urlJoin(API_ENDPOINT, "/save"), LOAD: urlJoin(API_ENDPOINT, "/load"), LIST_MODELS: urlJoin(API_ENDPOINT, "/list_models"), RUN_PROMPT: urlJoin(API_ENDPOINT, "/run"), + UPDATE_MODEL: urlJoin(API_ENDPOINT, "/update_model"), UPDATE_PROMPT: urlJoin(API_ENDPOINT, "/update_prompt"), };