Skip to content

Commit

Permalink
[editor][1/n] Update list_models to return {id, parser_id} Model objects
Browse files Browse the repository at this point in the history
  • Loading branch information
Ryan Holinshead committed Jan 6, 2024
1 parent 5b28259 commit 19578d9
Show file tree
Hide file tree
Showing 8 changed files with 33 additions and 20 deletions.
7 changes: 4 additions & 3 deletions python/src/aiconfig/editor/client/src/Editor.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import { AIConfig, InferenceSettings, JSONObject, Prompt } from "aiconfig";
import { useCallback, useEffect, useMemo, useState } from "react";
import { ufetch } from "ufetch";
import { ROUTE_TABLE } from "./utils/api";
import { Model } from "./shared/types";

export default function Editor() {
const [aiconfig, setAiConfig] = useState<AIConfig | undefined>();
Expand All @@ -28,16 +29,16 @@ export default function Editor() {
return res;
}, []);

const getModels = useCallback(async (search: string) => {
const getModels = useCallback(async (search?: string) => {
// For now, rely on caching and handle client-side search filtering
// We will use server-side search filtering for Gradio
const res = await ufetch.get(ROUTE_TABLE.LIST_MODELS);
const models = res.data;
if (search && search.length > 0) {
const lowerCaseSearch = search.toLowerCase();
return models.filter(
(model: string) =>
model.toLocaleLowerCase().indexOf(lowerCaseSearch) >= 0
(model: Model) =>
model.id.toLocaleLowerCase().indexOf(lowerCaseSearch) >= 0
);
}
return models;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import {
import aiconfigReducer, { AIConfigReducerAction } from "./aiconfigReducer";
import {
ClientPrompt,
Model,
aiConfigToClientConfig,
clientConfigToAIConfig,
clientPromptToAIConfigPrompt,
Expand Down Expand Up @@ -63,7 +64,7 @@ export type AIConfigCallbacks = {
index: number
) => Promise<{ aiconfig: AIConfig }>;
deletePrompt: (promptName: string) => Promise<void>;
getModels: (search: string) => Promise<string[]>;
getModels: (search: string) => Promise<Model[]>;
getServerStatus?: () => Promise<{ status: "OK" | "ERROR" }>;
runPrompt: (promptName: string) => Promise<{ aiconfig: AIConfig }>;
save: (aiconfig: AIConfig) => Promise<void>;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,19 +8,20 @@ import {
import { IconPlus, IconSearch, IconTextCaption } from "@tabler/icons-react";
import { memo, useCallback, useState } from "react";
import useLoadModels from "../../hooks/useLoadModels";
import { Model } from "../../shared/types";

type Props = {
addPrompt: (prompt: string) => void;
getModels: (search: string) => Promise<string[]>;
getModels: (search: string) => Promise<Model[]>;
};

function ModelMenuItems({
models,
onSelectModel,
collapseLimit,
}: {
models: string[];
onSelectModel: (model: string) => void;
models: Model[];
onSelectModel: (model: Model) => void;
collapseLimit: number;
}) {
const [isCollapsed, setIsCollapsed] = useState(models.length > collapseLimit);
Expand All @@ -31,11 +32,11 @@ function ModelMenuItems({
<ScrollArea mah={300} style={{ overflowY: "auto" }}>
{displayModels.map((model) => (
<Menu.Item
key={model}
key={model.id}
icon={<IconTextCaption size="16" />}
onClick={() => onSelectModel(model)}
>
{model}
{model.id}
</Menu.Item>
))}
{isCollapsed && (
Expand All @@ -50,8 +51,8 @@ export default memo(function AddPromptButton({ addPrompt, getModels }: Props) {
const [isOpen, setIsOpen] = useState(false);

const onAddPrompt = useCallback(
(model: string) => {
addPrompt(model);
(model: Model) => {
addPrompt(model.id);
setIsOpen(false);
},
[addPrompt]
Expand Down
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
import { Autocomplete, AutocompleteItem, Button } from "@mantine/core";
import { memo, useState } from "react";
import { memo, useMemo, useState } from "react";
import { getPromptModelName } from "../../utils/promptUtils";
import { Prompt } from "aiconfig";
import useLoadModels from "../../hooks/useLoadModels";
import { IconX } from "@tabler/icons-react";
import { Model } from "../../shared/types";

type Props = {
prompt: Prompt;
getModels: (search: string) => Promise<string[]>;
getModels: (search: string) => Promise<Model[]>;
onSetModel: (model?: string) => void;
defaultConfigModelName?: string;
};
Expand All @@ -31,6 +32,8 @@ export default memo(function ModelSelector({
getModels
);

const modelNames = useMemo(() => models.map((model) => model.id), [models]);

const onSelectModel = (model?: string) => {
setSelectedModel(model);
onSetModel(model);
Expand Down Expand Up @@ -70,13 +73,13 @@ export default memo(function ModelSelector({
.toLocaleLowerCase()
.includes(searchValue.toLocaleLowerCase().trim());
}}
data={models}
data={modelNames}
value={autocompleteSearch}
onChange={(value: string) => {
setAutocompleteSearch(value);
setShowAll(false);
onSelectModel(value);
models.some((model) => {
modelNames.some((model) => {
if (model === value) {
setShowAll(true);
return true;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import PromptActionBar from "./PromptActionBar";
import PromptInputRenderer from "./prompt_input/PromptInputRenderer";
import PromptOutputsRenderer from "./prompt_outputs/PromptOutputsRenderer";
import { ClientPrompt } from "../../shared/types";
import { ClientPrompt, Model } from "../../shared/types";
import { getPromptSchema } from "../../utils/promptUtils";
import { Flex, Card } from "@mantine/core";
import { PromptInput as AIConfigPromptInput, JSONObject } from "aiconfig";
Expand All @@ -12,7 +12,7 @@ import ModelSelector from "./ModelSelector";

type Props = {
prompt: ClientPrompt;
getModels: (search: string) => Promise<string[]>;
getModels: (search: string) => Promise<Model[]>;
onChangePromptInput: (
promptId: string,
newPromptInput: AIConfigPromptInput
Expand Down
5 changes: 3 additions & 2 deletions python/src/aiconfig/editor/client/src/hooks/useLoadModels.ts
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import { useCallback, useEffect, useState } from "react";
import { showNotification } from "@mantine/notifications";
import { Model } from "../shared/types";

export default function useLoadModels(
modelSearch: string,
getModels: (search: string) => Promise<string[]>
getModels: (search: string) => Promise<Model[]>
) {
const [models, setModels] = useState<string[]>([]);
const [models, setModels] = useState<Model[]>([]);

const loadModels = useCallback(
async (modelSearch: string) => {
Expand Down
5 changes: 5 additions & 0 deletions python/src/aiconfig/editor/client/src/shared/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,11 @@ export type EditorFile = {
disabled?: boolean;
};

export type Model = {
id: string;
parser_id: string;
};

export type ClientPrompt = Prompt & {
_ui: {
id: string;
Expand Down
3 changes: 2 additions & 1 deletion python/src/aiconfig/editor/server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,8 @@ def server_status() -> FlaskResponse:

@app.route("/api/list_models", methods=["GET"])
def list_models() -> FlaskResponse:
out: list[str] = ModelParserRegistry.parser_ids() # type: ignore
model_parser_map = ModelParserRegistry.display_parsers()
out: list[core_utils.JSONDict] = [{"id": model_name, "parser_id": model_parser} for model_name, model_parser in model_parser_map.items()]
json_obj = core_utils.JSONObject({"data": core_utils.JSONList(out)})
return FlaskResponse((json_obj, 200))

Expand Down

0 comments on commit 19578d9

Please sign in to comment.