Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[editor][2/n] Add list_model_parsers endpoint #796

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions python/src/aiconfig/editor/client/src/Editor.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import { AIConfig, InferenceSettings, JSONObject, Prompt } from "aiconfig";
import { useCallback, useEffect, useMemo, useState } from "react";
import { ufetch } from "ufetch";
import { ROUTE_TABLE } from "./utils/api";
import { Model } from "./shared/types";

export default function Editor() {
const [aiconfig, setAiConfig] = useState<AIConfig | undefined>();
Expand All @@ -28,16 +29,16 @@ export default function Editor() {
return res;
}, []);

const getModels = useCallback(async (search: string) => {
const getModels = useCallback(async (search?: string) => {
// For now, rely on caching and handle client-side search filtering
// We will use server-side search filtering for Gradio
const res = await ufetch.get(ROUTE_TABLE.LIST_MODELS);
const models = res.data;
if (search && search.length > 0) {
const lowerCaseSearch = search.toLowerCase();
return models.filter(
(model: string) =>
model.toLocaleLowerCase().indexOf(lowerCaseSearch) >= 0
(model: Model) =>
model.id.toLocaleLowerCase().indexOf(lowerCaseSearch) >= 0
);
}
return models;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import {
import aiconfigReducer, { AIConfigReducerAction } from "./aiconfigReducer";
import {
ClientPrompt,
Model,
aiConfigToClientConfig,
clientConfigToAIConfig,
clientPromptToAIConfigPrompt,
Expand Down Expand Up @@ -63,7 +64,7 @@ export type AIConfigCallbacks = {
index: number
) => Promise<{ aiconfig: AIConfig }>;
deletePrompt: (promptName: string) => Promise<void>;
getModels: (search: string) => Promise<string[]>;
getModels: (search: string) => Promise<Model[]>;
getServerStatus?: () => Promise<{ status: "OK" | "ERROR" }>;
runPrompt: (promptName: string) => Promise<{ aiconfig: AIConfig }>;
save: (aiconfig: AIConfig) => Promise<void>;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,19 +8,20 @@ import {
import { IconPlus, IconSearch, IconTextCaption } from "@tabler/icons-react";
import { memo, useCallback, useState } from "react";
import useLoadModels from "../../hooks/useLoadModels";
import { Model } from "../../shared/types";

type Props = {
addPrompt: (prompt: string) => void;
getModels: (search: string) => Promise<string[]>;
getModels: (search: string) => Promise<Model[]>;
};

function ModelMenuItems({
models,
onSelectModel,
collapseLimit,
}: {
models: string[];
onSelectModel: (model: string) => void;
models: Model[];
onSelectModel: (model: Model) => void;
collapseLimit: number;
}) {
const [isCollapsed, setIsCollapsed] = useState(models.length > collapseLimit);
Expand All @@ -31,11 +32,11 @@ function ModelMenuItems({
<ScrollArea mah={300} style={{ overflowY: "auto" }}>
{displayModels.map((model) => (
<Menu.Item
key={model}
key={model.id}
icon={<IconTextCaption size="16" />}
onClick={() => onSelectModel(model)}
>
{model}
{model.id}
</Menu.Item>
))}
{isCollapsed && (
Expand All @@ -50,8 +51,8 @@ export default memo(function AddPromptButton({ addPrompt, getModels }: Props) {
const [isOpen, setIsOpen] = useState(false);

const onAddPrompt = useCallback(
(model: string) => {
addPrompt(model);
(model: Model) => {
addPrompt(model.id);
setIsOpen(false);
},
[addPrompt]
Expand Down
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
import { Autocomplete, AutocompleteItem, Button } from "@mantine/core";
import { memo, useState } from "react";
import { memo, useMemo, useState } from "react";
import { getPromptModelName } from "../../utils/promptUtils";
import { Prompt } from "aiconfig";
import useLoadModels from "../../hooks/useLoadModels";
import { IconX } from "@tabler/icons-react";
import { Model } from "../../shared/types";

type Props = {
prompt: Prompt;
getModels: (search: string) => Promise<string[]>;
getModels: (search: string) => Promise<Model[]>;
onSetModel: (model?: string) => void;
defaultConfigModelName?: string;
};
Expand All @@ -31,6 +32,8 @@ export default memo(function ModelSelector({
getModels
);

const modelNames = useMemo(() => models.map((model) => model.id), [models]);

const onSelectModel = (model?: string) => {
setSelectedModel(model);
onSetModel(model);
Expand Down Expand Up @@ -70,13 +73,13 @@ export default memo(function ModelSelector({
.toLocaleLowerCase()
.includes(searchValue.toLocaleLowerCase().trim());
}}
data={models}
data={modelNames}
value={autocompleteSearch}
onChange={(value: string) => {
setAutocompleteSearch(value);
setShowAll(false);
onSelectModel(value);
models.some((model) => {
modelNames.some((model) => {
if (model === value) {
setShowAll(true);
return true;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import PromptActionBar from "./PromptActionBar";
import PromptInputRenderer from "./prompt_input/PromptInputRenderer";
import PromptOutputsRenderer from "./prompt_outputs/PromptOutputsRenderer";
import { ClientPrompt } from "../../shared/types";
import { ClientPrompt, Model } from "../../shared/types";
import { getPromptSchema } from "../../utils/promptUtils";
import { Flex, Card } from "@mantine/core";
import { PromptInput as AIConfigPromptInput, JSONObject } from "aiconfig";
Expand All @@ -12,7 +12,7 @@ import ModelSelector from "./ModelSelector";

type Props = {
prompt: ClientPrompt;
getModels: (search: string) => Promise<string[]>;
getModels: (search: string) => Promise<Model[]>;
onChangePromptInput: (
promptId: string,
newPromptInput: AIConfigPromptInput
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import { useCallback, useEffect, useState } from "react";
import { showNotification } from "@mantine/notifications";
import { Model } from "../shared/types";

export default function useLoadModels(
modelSearch: string,
getModels: (search: string) => Promise<string[]>
getModels: (search: string) => Promise<Model[]>
) {
const [models, setModels] = useState<string[]>([]);
const [models, setModels] = useState<Model[]>([]);

const loadModels = useCallback(
async (modelSearch: string) => {
Expand Down
5 changes: 5 additions & 0 deletions python/src/aiconfig/editor/client/src/shared/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,11 @@ export type EditorFile = {
disabled?: boolean;
};

/**
 * A model entry returned by the editor server's `/api/list_models` endpoint.
 */
export type Model = {
  /** Model name; used as the display label and React key in model menus. */
  id: string;
  /** Id of the model parser registered for this model on the server. */
  parser_id: string;
};

export type ClientPrompt = Prompt & {
_ui: {
id: string;
Expand Down
11 changes: 10 additions & 1 deletion python/src/aiconfig/editor/server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,16 @@ def server_status() -> FlaskResponse:

@app.route("/api/list_models", methods=["GET"])
def list_models() -> FlaskResponse:
    """Return all registered models as JSON: {"data": [{"id": ..., "parser_id": ...}, ...]}.

    Each entry pairs a model name ("id") with the id of the model parser
    registered for it, as reported by ModelParserRegistry.display_parsers().
    """
    model_parser_map = ModelParserRegistry.display_parsers()
    out: list[core_utils.JSONDict] = [
        {"id": model_name, "parser_id": model_parser}
        for model_name, model_parser in model_parser_map.items()
    ]
    json_obj = core_utils.JSONObject({"data": core_utils.JSONList(out)})
    # Flask treats a (body, status) tuple as a response with that status code.
    return FlaskResponse((json_obj, 200))


@app.route("/api/list_model_parsers", methods=["GET"])
def list_model_parsers() -> FlaskResponse:
    """Return the ids of all registered model parsers as JSON: {"data": [...]}."""
    parser_ids: list[str] = list(ModelParserRegistry.display_parsers().values())
    payload = core_utils.JSONObject({"data": core_utils.JSONList(parser_ids)})
    return FlaskResponse((payload, 200))

Expand Down