Skip to content

Commit

Permalink
Model Selector
Browse files Browse the repository at this point in the history
  • Loading branch information
Ryan Holinshead committed Dec 29, 2023
1 parent c88b9c9 commit 434a192
Show file tree
Hide file tree
Showing 8 changed files with 218 additions and 27 deletions.
13 changes: 12 additions & 1 deletion python/src/aiconfig/editor/client/src/Editor.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ import EditorContainer, {
AIConfigCallbacks,
} from "./components/EditorContainer";
import { Flex, Loader, MantineProvider } from "@mantine/core";
import { AIConfig, Prompt } from "aiconfig";
import { AIConfig, ModelMetadata, Prompt } from "aiconfig";
import { useCallback, useEffect, useMemo, useState } from "react";
import { ufetch } from "ufetch";
import { ROUTE_TABLE } from "./utils/api";
Expand Down Expand Up @@ -66,12 +66,23 @@ export default function Editor() {
[]
);

const updateModel = useCallback(
async (_promptName?: string, _modelData?: string | ModelMetadata) => {
// return await ufetch.post(ROUTE_TABLE.UPDATE_MODEL,
// prompt_name: promptName,
// model_data: modelData,
// });
},
[]
);

const callbacks: AIConfigCallbacks = useMemo(
() => ({
addPrompt,
getModels,
runPrompt,
save,
updateModel,
updatePrompt,
}),
[save, getModels, addPrompt, runPrompt]
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import PromptContainer from "./prompt/PromptContainer";
import { Container, Group, Button, createStyles, Stack } from "@mantine/core";
import { showNotification } from "@mantine/notifications";
import { AIConfig, Prompt, PromptInput } from "aiconfig";
import { AIConfig, ModelMetadata, Prompt, PromptInput } from "aiconfig";
import { useCallback, useMemo, useReducer, useRef, useState } from "react";
import aiconfigReducer, { AIConfigReducerAction } from "./aiconfigReducer";
import {
Expand All @@ -28,6 +28,10 @@ export type AIConfigCallbacks = {
getModels: (search: string) => Promise<string[]>;
runPrompt: (promptName: string) => Promise<void>;
save: (aiconfig: AIConfig) => Promise<void>;
updateModel: (
promptName?: string,
modelData?: string | ModelMetadata
) => Promise<void /*{ aiconfig: AIConfig }*/>;
updatePrompt: (
promptName: string,
promptData: Prompt
Expand Down Expand Up @@ -150,14 +154,59 @@ export default function EditorContainer({
[dispatch]
);

const debouncedUpdateModel = useMemo(
() =>
debounce(
(promptName?: string, modelMetadata?: string | ModelMetadata) =>
callbacks.updateModel(promptName, modelMetadata),
250
),
[callbacks.updatePrompt]
);

const onUpdatePromptModelSettings = useCallback(
async (promptIndex: number, newModelSettings: any) => {
dispatch({
type: "UPDATE_PROMPT_MODEL_SETTINGS",
index: promptIndex,
modelSettings: newModelSettings,
});
// TODO: Call server-side endpoint to update model settings
// TODO: Call server-side endpoint to update model
},
[dispatch]
);

const onUpdatePromptModel = useCallback(
async (promptIndex: number, newModel?: string) => {
dispatch({
type: "UPDATE_PROMPT_MODEL",
index: promptIndex,
modelName: newModel,
});

try {
const prompt = clientPromptToAIConfigPrompt(
aiconfigState.prompts[promptIndex]
);
const currentModel = prompt.metadata?.model;
let modelData: string | ModelMetadata | undefined = newModel;
if (newModel && currentModel && typeof currentModel !== "string") {
modelData = {
...currentModel,
name: newModel,
};
}

await debouncedUpdateModel(prompt.name, modelData);

// TODO: Consolidate
} catch (err: any) {
showNotification({
title: "Error updating prompt model",
message: err.message,
color: "red",
});
}
},
[dispatch]
);
Expand Down Expand Up @@ -241,8 +290,6 @@ export default function EditorContainer({

const { classes } = useStyles();

// TODO: Implement editor context for callbacks, readonly state, etc.

return (
<>
<Container maw="80rem">
Expand All @@ -262,9 +309,11 @@ export default function EditorContainer({
<PromptContainer
index={i}
prompt={prompt}
getModels={callbacks.getModels}
onChangePromptInput={onChangePromptInput}
onChangePromptName={onChangePromptName}
onRunPrompt={onRunPrompt}
onUpdateModel={onUpdatePromptModel}
onUpdateModelSettings={onUpdatePromptModelSettings}
onUpdateParameters={onUpdatePromptParameters}
defaultConfigModelName={aiconfigState.metadata.default_model}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ export type MutateAIConfigAction =
| AddPromptAction
| UpdatePromptInputAction
| UpdatePromptNameAction
| UpdatePromptModelAction
| UpdatePromptModelSettingsAction
| UpdatePromptParametersAction;

Expand Down Expand Up @@ -37,6 +38,12 @@ export type UpdatePromptNameAction = {
name: string;
};

export type UpdatePromptModelAction = {
type: "UPDATE_PROMPT_MODEL";
index: number;
modelName?: string;
};

export type UpdatePromptModelSettingsAction = {
type: "UPDATE_PROMPT_MODEL_SETTINGS";
index: number;
Expand Down Expand Up @@ -134,6 +141,20 @@ export default function aiconfigReducer(
name: action.name,
}));
}
case "UPDATE_PROMPT_MODEL": {
return reduceReplacePrompt(state, action.index, (prompt) => ({
...prompt,
metadata: {
...prompt.metadata,
model: action.modelName
? {
name: action.modelName,
// TODO: Consolidate settings based on schema union
}
: undefined,
},
}));
}
case "UPDATE_PROMPT_MODEL_SETTINGS": {
return reduceReplacePrompt(state, action.index, (prompt) => ({
...prompt,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import { ActionIcon, Menu, TextInput } from "@mantine/core";
import { IconPlus, IconSearch, IconTextCaption } from "@tabler/icons-react";
import { memo, useCallback, useEffect, useState } from "react";
import { showNotification } from "@mantine/notifications";
import { memo, useCallback, useState } from "react";
import useLoadModels from "../../hooks/useLoadModels";

type Props = {
addPrompt: (prompt: string) => void;
Expand Down Expand Up @@ -41,30 +41,14 @@ function ModelMenuItems({

export default memo(function AddPromptButton({ addPrompt, getModels }: Props) {
const [modelSearch, setModelSearch] = useState("");
const [models, setModels] = useState<string[]>([]);
const [isOpen, setIsOpen] = useState(false);

const loadModels = useCallback(async (modelSearch: string) => {
try {
const models = await getModels(modelSearch);
setModels(models);
} catch (err: any) {
showNotification({
title: "Error loading models",
message: err?.message,
color: "red",
});
}
}, []);

const onAddPrompt = useCallback((model: string) => {
addPrompt(model);
setIsOpen(false);
}, []);

useEffect(() => {
loadModels(modelSearch);
}, [loadModels, modelSearch]);
const models = useLoadModels(modelSearch, getModels);

return (
<Menu
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
import { Autocomplete, AutocompleteItem, Button } from "@mantine/core";
import { memo, useState } from "react";
import { getPromptModelName } from "../../utils/promptUtils";
import { Prompt } from "aiconfig";
import useLoadModels from "../../hooks/useLoadModels";
import { IconX } from "@tabler/icons-react";

type Props = {
prompt: Prompt;
getModels: (search: string) => Promise<string[]>;
onSetModel: (model?: string) => void;
defaultConfigModelName?: string;
};

export default memo(function ModelSelector({
prompt,
getModels,
onSetModel,
defaultConfigModelName,
}: Props) {
const [selectedModel, setSelectedModel] = useState<string | undefined>(
getPromptModelName(prompt, defaultConfigModelName)
);
const [showAll, setShowAll] = useState(true);
const [autocompleteSearch, setAutocompleteSearch] = useState(
getPromptModelName(prompt, defaultConfigModelName)
);

const models = useLoadModels(showAll ? "" : autocompleteSearch, getModels);

const onSelectModel = (model?: string) => {
setSelectedModel(model);
onSetModel(model);
};

return (
<Autocomplete
placeholder="Select model"
limit={100}
maxDropdownHeight={200}
rightSection={
selectedModel ? (
<Button
size="xs"
variant="subtle"
mr={10}
onClick={() => {
onSelectModel(undefined);
setShowAll(true);
setAutocompleteSearch("");
}}
>
<IconX size={10} />
</Button>
) : null
}
filter={(searchValue: string, item: AutocompleteItem) => {
if (showAll) {
return true;
}

const modelName: string = item.value;
return modelName
.toLocaleLowerCase()
.includes(searchValue.toLocaleLowerCase().trim());
}}
data={models}
value={autocompleteSearch}
onChange={(value: string) => {
setAutocompleteSearch(value);
setShowAll(false);
onSelectModel(value);
models.some((model) => {
if (model === value) {
setShowAll(true);
return true;
}
});
}}
/>
);
});
Original file line number Diff line number Diff line change
Expand Up @@ -2,23 +2,26 @@ import PromptActionBar from "./PromptActionBar";
import PromptInputRenderer from "./prompt_input/PromptInputRenderer";
import PromptOutputsRenderer from "./prompt_outputs/PromptOutputsRenderer";
import { ClientPrompt } from "../../shared/types";
import { getPromptModelName, getPromptSchema } from "../../utils/promptUtils";
import { Flex, Card, Text, createStyles } from "@mantine/core";
import { getPromptSchema } from "../../utils/promptUtils";
import { Flex, Card, createStyles } from "@mantine/core";
import { PromptInput as AIConfigPromptInput } from "aiconfig";
import { memo, useCallback } from "react";
import { ParametersArray } from "../ParametersRenderer";
import PromptOutputBar from "./PromptOutputBar";
import PromptName from "./PromptName";
import ModelSelector from "./ModelSelector";

type Props = {
index: number;
prompt: ClientPrompt;
getModels: (search: string) => Promise<string[]>;
onChangePromptInput: (
promptIndex: number,
newPromptInput: AIConfigPromptInput
) => void;
onChangePromptName: (promptIndex: number, newName: string) => void;
onRunPrompt(promptIndex: number): Promise<void>;
onUpdateModel: (promptIndex: number, newModel?: string) => void;
onUpdateModelSettings: (promptIndex: number, newModelSettings: any) => void;
onUpdateParameters: (promptIndex: number, newParameters: any) => void;
defaultConfigModelName?: string;
Expand All @@ -42,10 +45,12 @@ const useStyles = createStyles((theme) => ({
export default memo(function PromptContainer({
prompt,
index,
getModels,
onChangePromptInput,
onChangePromptName,
defaultConfigModelName,
onRunPrompt,
onUpdateModel,
onUpdateModelSettings,
onUpdateParameters,
}: Props) {
Expand Down Expand Up @@ -87,6 +92,11 @@ export default memo(function PromptContainer({
[index, onRunPrompt]
);

const updateModel = useCallback(
(model?: string) => onUpdateModel(index, model),
[index, onUpdateModel]
);

// TODO: When adding support for custom PromptContainers, implement a PromptContainerRenderer which
// will take in the index and callback and render the appropriate PromptContainer with new memoized
// callback and not having to pass index down to PromptContainer
Expand All @@ -102,7 +112,12 @@ export default memo(function PromptContainer({
<Flex direction="column">
<Flex justify="space-between" mb="0.5em">
<PromptName name={prompt.name} onUpdate={onChangeName} />
<Text>{getPromptModelName(prompt, defaultConfigModelName)}</Text>
<ModelSelector
getModels={getModels}
prompt={prompt}
onSetModel={updateModel}
defaultConfigModelName={defaultConfigModelName}
/>
</Flex>
<PromptInputRenderer
input={prompt.input}
Expand Down
28 changes: 28 additions & 0 deletions python/src/aiconfig/editor/client/src/hooks/useLoadModels.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import { useCallback, useEffect, useState } from "react";
import { showNotification } from "@mantine/notifications";

export default function useLoadModels(
modelSearch: string,
getModels: (search: string) => Promise<string[]>
) {
const [models, setModels] = useState<string[]>([]);

const loadModels = useCallback(async (modelSearch: string) => {
try {
const models = await getModels(modelSearch);
setModels(models);
} catch (err: any) {
showNotification({
title: "Error loading models",
message: err?.message,
color: "red",
});
}
}, []);

useEffect(() => {
loadModels(modelSearch);
}, [loadModels, modelSearch]);

return models;
}
Loading

0 comments on commit 434a192

Please sign in to comment.