diff --git a/python/src/aiconfig/editor/client/src/Editor.tsx b/python/src/aiconfig/editor/client/src/Editor.tsx
index 2a077096d..cb7529bd3 100644
--- a/python/src/aiconfig/editor/client/src/Editor.tsx
+++ b/python/src/aiconfig/editor/client/src/Editor.tsx
@@ -6,6 +6,7 @@ import { AIConfig, InferenceSettings, JSONObject, Prompt } from "aiconfig";
 import { useCallback, useEffect, useMemo, useState } from "react";
 import { ufetch } from "ufetch";
 import { ROUTE_TABLE } from "./utils/api";
+import { Model } from "./shared/types";
 
 export default function Editor() {
   const [aiconfig, setAiConfig] = useState<AIConfig | undefined>();
@@ -28,7 +29,7 @@ export default function Editor() {
     return res;
   }, []);
 
-  const getModels = useCallback(async (search: string) => {
+  const getModels = useCallback(async (search?: string) => {
     // For now, rely on caching and handle client-side search filtering
     // We will use server-side search filtering for Gradio
     const res = await ufetch.get(ROUTE_TABLE.LIST_MODELS);
@@ -36,8 +37,8 @@
     if (search && search.length > 0) {
       const lowerCaseSearch = search.toLowerCase();
       return models.filter(
-        (model: string) =>
-          model.toLocaleLowerCase().indexOf(lowerCaseSearch) >= 0
+        (model: Model) =>
+          model.id.toLocaleLowerCase().indexOf(lowerCaseSearch) >= 0
       );
     }
     return models;
diff --git a/python/src/aiconfig/editor/client/src/components/EditorContainer.tsx b/python/src/aiconfig/editor/client/src/components/EditorContainer.tsx
index 73488081d..271d9ba9d 100644
--- a/python/src/aiconfig/editor/client/src/components/EditorContainer.tsx
+++ b/python/src/aiconfig/editor/client/src/components/EditorContainer.tsx
@@ -28,6 +28,7 @@ import {
 import aiconfigReducer, { AIConfigReducerAction } from "./aiconfigReducer";
 import {
   ClientPrompt,
+  Model,
   aiConfigToClientConfig,
   clientConfigToAIConfig,
   clientPromptToAIConfigPrompt,
@@ -63,7 +64,7 @@ export type AIConfigCallbacks = {
     index: number
   ) => Promise<{ aiconfig: AIConfig }>;
   deletePrompt: (promptName: string) => Promise<void>;
-  getModels: (search: string) => Promise<string[]>;
+  getModels: (search: string) => Promise<Model[]>;
   getServerStatus?: () => Promise<{ status: "OK" | "ERROR" }>;
   runPrompt: (promptName: string) => Promise<{ aiconfig: AIConfig }>;
   save: (aiconfig: AIConfig) => Promise<void>;
diff --git a/python/src/aiconfig/editor/client/src/components/prompt/AddPromptButton.tsx b/python/src/aiconfig/editor/client/src/components/prompt/AddPromptButton.tsx
index f104c1c46..b5c11fb4f 100644
--- a/python/src/aiconfig/editor/client/src/components/prompt/AddPromptButton.tsx
+++ b/python/src/aiconfig/editor/client/src/components/prompt/AddPromptButton.tsx
@@ -8,10 +8,11 @@
 import { IconPlus, IconSearch, IconTextCaption } from "@tabler/icons-react";
 import { memo, useCallback, useState } from "react";
 import useLoadModels from "../../hooks/useLoadModels";
+import { Model } from "../../shared/types";
 
 type Props = {
   addPrompt: (prompt: string) => void;
-  getModels: (search: string) => Promise<string[]>;
+  getModels: (search: string) => Promise<Model[]>;
 };
 
 function ModelMenuItems({
@@ -19,8 +20,8 @@
   models,
   onSelectModel,
   collapseLimit,
 }: {
-  models: string[];
-  onSelectModel: (model: string) => void;
+  models: Model[];
+  onSelectModel: (model: Model) => void;
   collapseLimit: number;
 }) {
   const [isCollapsed, setIsCollapsed] = useState(models.length > collapseLimit);
@@ -31,11 +32,11 @@
       {displayModels.map((model) => (
         <Menu.Item
           icon={<IconTextCaption size="1rem" />}
           onClick={() => onSelectModel(model)}
         >
-          {model}
+          {model.id}
         </Menu.Item>
       ))}
       {isCollapsed && (
@@ -50,8 +51,8 @@
 export default memo(function AddPromptButton({ addPrompt, getModels }: Props) {
   const [isOpen, setIsOpen] = useState(false);
   const onAddPrompt = useCallback(
-    (model: string) => {
-      addPrompt(model);
+    (model: Model) => {
+      addPrompt(model.id);
       setIsOpen(false);
     },
     [addPrompt]
diff --git a/python/src/aiconfig/editor/client/src/components/prompt/ModelSelector.tsx b/python/src/aiconfig/editor/client/src/components/prompt/ModelSelector.tsx
index 143f392b6..16f445a1f 100644
--- a/python/src/aiconfig/editor/client/src/components/prompt/ModelSelector.tsx
+++ b/python/src/aiconfig/editor/client/src/components/prompt/ModelSelector.tsx
@@ -1,13 +1,14 @@
 import { Autocomplete, AutocompleteItem, Button } from "@mantine/core";
-import { memo, useState } from "react";
+import { memo, useMemo, useState } from "react";
 import { getPromptModelName } from "../../utils/promptUtils";
 import { Prompt } from "aiconfig";
 import useLoadModels from "../../hooks/useLoadModels";
 import { IconX } from "@tabler/icons-react";
+import { Model } from "../../shared/types";
 
 type Props = {
   prompt: Prompt;
-  getModels: (search: string) => Promise<string[]>;
+  getModels: (search: string) => Promise<Model[]>;
   onSetModel: (model?: string) => void;
   defaultConfigModelName?: string;
 };
@@ -31,6 +32,8 @@ export default memo(function ModelSelector({
     getModels
   );
+  const modelNames = useMemo(() => models.map((model) => model.id), [models]);
+
   const onSelectModel = (model?: string) => {
     setSelectedModel(model);
     onSetModel(model);
   };
@@ -70,13 +73,13 @@ export default memo(function ModelSelector({
             .toLocaleLowerCase()
             .includes(searchValue.toLocaleLowerCase().trim());
         }}
-        data={models}
+        data={modelNames}
         value={autocompleteSearch}
         onChange={(value: string) => {
           setAutocompleteSearch(value);
           setShowAll(false);
           onSelectModel(value);
-          models.some((model) => {
+          modelNames.some((model) => {
             if (model === value) {
               setShowAll(true);
               return true;
diff --git a/python/src/aiconfig/editor/client/src/components/prompt/PromptContainer.tsx b/python/src/aiconfig/editor/client/src/components/prompt/PromptContainer.tsx
index c68be9b8d..8e84191c5 100644
--- a/python/src/aiconfig/editor/client/src/components/prompt/PromptContainer.tsx
+++ b/python/src/aiconfig/editor/client/src/components/prompt/PromptContainer.tsx
@@ -1,7 +1,7 @@
 import PromptActionBar from "./PromptActionBar";
 import PromptInputRenderer from "./prompt_input/PromptInputRenderer";
 import PromptOutputsRenderer from "./prompt_outputs/PromptOutputsRenderer";
-import { ClientPrompt } from "../../shared/types";
+import { ClientPrompt, Model } from "../../shared/types";
 import { getPromptSchema } from "../../utils/promptUtils";
 import { Flex, Card } from "@mantine/core";
 import { PromptInput as AIConfigPromptInput, JSONObject } from "aiconfig";
@@ -12,7 +12,7 @@ import ModelSelector from "./ModelSelector";
 
 type Props = {
   prompt: ClientPrompt;
-  getModels: (search: string) => Promise<string[]>;
+  getModels: (search: string) => Promise<Model[]>;
   onChangePromptInput: (
     promptId: string,
     newPromptInput: AIConfigPromptInput
diff --git a/python/src/aiconfig/editor/client/src/hooks/useLoadModels.ts b/python/src/aiconfig/editor/client/src/hooks/useLoadModels.ts
index 557a898d2..dfbf5ef2d 100644
--- a/python/src/aiconfig/editor/client/src/hooks/useLoadModels.ts
+++ b/python/src/aiconfig/editor/client/src/hooks/useLoadModels.ts
@@ -1,11 +1,12 @@
 import { useCallback, useEffect, useState } from "react";
 import { showNotification } from "@mantine/notifications";
+import { Model } from "../shared/types";
 
 export default function useLoadModels(
   modelSearch: string,
-  getModels: (search: string) => Promise<string[]>
+  getModels: (search: string) => Promise<Model[]>
 ) {
-  const [models, setModels] = useState<string[]>([]);
+  const [models, setModels] = useState<Model[]>([]);
 
   const loadModels = useCallback(
     async (modelSearch: string) => {
diff --git a/python/src/aiconfig/editor/client/src/shared/types.ts b/python/src/aiconfig/editor/client/src/shared/types.ts
index 69069fa08..1bdae13ff 100644
--- a/python/src/aiconfig/editor/client/src/shared/types.ts
+++ b/python/src/aiconfig/editor/client/src/shared/types.ts
@@ -9,6 +9,11 @@ export type EditorFile = {
   disabled?: boolean;
 };
 
+export type Model = {
+  id: string;
+  parser_id: string;
+};
+
 export type ClientPrompt = Prompt & {
   _ui: {
     id: string;
diff --git a/python/src/aiconfig/editor/server/server.py b/python/src/aiconfig/editor/server/server.py
index 06fe07b1c..217ad9f08 100644
--- a/python/src/aiconfig/editor/server/server.py
+++ b/python/src/aiconfig/editor/server/server.py
@@ -91,7 +91,16 @@ def server_status() -> FlaskResponse:
 
 
 @app.route("/api/list_models", methods=["GET"])
 def list_models() -> FlaskResponse:
-    out: list[str] = ModelParserRegistry.parser_ids()  # type: ignore
+    model_parser_map = ModelParserRegistry.display_parsers()
+    out: list[core_utils.JSONDict] = [{"id": model_name, "parser_id": model_parser} for model_name, model_parser in model_parser_map.items()]
+    json_obj = core_utils.JSONObject({"data": core_utils.JSONList(out)})
+    return FlaskResponse((json_obj, 200))
+
+
+@app.route("/api/list_model_parsers", methods=["GET"])
+def list_model_parsers() -> FlaskResponse:
+    model_parser_map = ModelParserRegistry.display_parsers()
+    out: list[str] = list(model_parser_map.values())
     json_obj = core_utils.JSONObject({"data": core_utils.JSONList(out)})
     return FlaskResponse((json_obj, 200))
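For reference, a minimal sketch of how a client might consume the reshaped /api/list_models response, assuming the endpoint wraps the list in a "data" key as server.py does above. The plain fetch call and localhost URL are stand-ins for the editor's ufetch/ROUTE_TABLE plumbing and are not defined by this change:

```ts
// Sketch only: mirrors the Model shape added in shared/types.ts and the
// client-side filtering in Editor.tsx. The host/port below is an assumption.
type Model = {
  id: string;
  parser_id: string;
};

async function getModels(search?: string): Promise<Model[]> {
  const res = await fetch("http://localhost:8080/api/list_models");
  // Server response shape: {"data": [{"id": ..., "parser_id": ...}, ...]}
  const { data } = (await res.json()) as { data: Model[] };
  if (search && search.length > 0) {
    const lowerCaseSearch = search.toLowerCase();
    // Each entry is now an object rather than a plain string, so filter on its id
    return data.filter(
      (model) => model.id.toLocaleLowerCase().indexOf(lowerCaseSearch) >= 0
    );
  }
  return data;
}
```

Components that still need plain model names (for example the Mantine Autocomplete's data prop) can derive them with models.map((model) => model.id), which is what the new useMemo in ModelSelector.tsx does.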