[editor] Model Selector #662

Merged 1 commit on Dec 29, 2023
13 changes: 12 additions & 1 deletion python/src/aiconfig/editor/client/src/Editor.tsx
@@ -2,7 +2,7 @@ import EditorContainer, {
AIConfigCallbacks,
} from "./components/EditorContainer";
import { Flex, Loader, MantineProvider } from "@mantine/core";
import { AIConfig, Prompt } from "aiconfig";
import { AIConfig, ModelMetadata, Prompt } from "aiconfig";
import { useCallback, useEffect, useMemo, useState } from "react";
import { ufetch } from "ufetch";
import { ROUTE_TABLE } from "./utils/api";
@@ -66,12 +66,23 @@ export default function Editor() {
[]
);

const updateModel = useCallback(
async (_promptName?: string, _modelData?: string | ModelMetadata) => {
// return await ufetch.post(ROUTE_TABLE.UPDATE_MODEL, {
// prompt_name: promptName,
// model_data: modelData,
// });
Comment on lines +71 to +74
Contributor: Are you commenting this out to add it later once the update_model() API is built on the backend?

Member Author: Yep
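
(For context: once the backend route exists, the un-commented call would presumably mirror the other callbacks in this file. ROUTE_TABLE.UPDATE_MODEL and the request body keys are taken from the commented-out code above; the rest is a sketch, not part of this PR:)

const updateModel = useCallback(
  async (promptName?: string, modelData?: string | ModelMetadata) => {
    // Assumed endpoint -- not implemented in this PR
    return await ufetch.post(ROUTE_TABLE.UPDATE_MODEL, {
      prompt_name: promptName,
      model_data: modelData,
    });
  },
  []
);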

},
[]
);

const callbacks: AIConfigCallbacks = useMemo(
() => ({
addPrompt,
getModels,
runPrompt,
save,
updateModel,
updatePrompt,
}),
[save, getModels, addPrompt, runPrompt]
python/src/aiconfig/editor/client/src/components/EditorContainer.tsx
@@ -1,7 +1,7 @@
import PromptContainer from "./prompt/PromptContainer";
import { Container, Group, Button, createStyles, Stack } from "@mantine/core";
import { showNotification } from "@mantine/notifications";
import { AIConfig, Prompt, PromptInput } from "aiconfig";
import { AIConfig, ModelMetadata, Prompt, PromptInput } from "aiconfig";
import { useCallback, useMemo, useReducer, useRef, useState } from "react";
import aiconfigReducer, { AIConfigReducerAction } from "./aiconfigReducer";
import {
@@ -28,6 +28,10 @@ export type AIConfigCallbacks = {
getModels: (search: string) => Promise<string[]>;
runPrompt: (promptName: string) => Promise<void>;
save: (aiconfig: AIConfig) => Promise<void>;
updateModel: (
promptName?: string,
modelData?: string | ModelMetadata
) => Promise<void /*{ aiconfig: AIConfig }*/>;
Contributor: Why not return aiconfig?

Member Author: It will once the endpoint is implemented -- right now there's no updated aiconfig to return.
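
(Per the commented-out return type above, the eventual callback signature would presumably become:)

updateModel: (
  promptName?: string,
  modelData?: string | ModelMetadata
) => Promise<{ aiconfig: AIConfig }>;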

updatePrompt: (
promptName: string,
promptData: Prompt
@@ -150,18 +154,63 @@ export default function EditorContainer({
[dispatch]
);

const debouncedUpdateModel = useMemo(
() =>
debounce(
(promptName?: string, modelMetadata?: string | ModelMetadata) =>
callbacks.updateModel(promptName, modelMetadata),
250
Contributor: Should we define this as a const somewhere? I see it's also 250 for addPrompt

Member Author: I can pull it out to a const for now in a separate PR, but it could technically be dependent on where it's used
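
(Pulled out, it would be a one-line change; the constant name below is hypothetical, not from this PR:)

// Hypothetical shared constant; call sites could still override it
const UPDATE_DEBOUNCE_MS = 250;

const debouncedUpdateModel = useMemo(
  () =>
    debounce(
      (promptName?: string, modelMetadata?: string | ModelMetadata) =>
        callbacks.updateModel(promptName, modelMetadata),
      UPDATE_DEBOUNCE_MS
    ),
  [callbacks.updateModel]
);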

),
[callbacks.updateModel]
);

const onUpdatePromptModelSettings = useCallback(
async (promptIndex: number, newModelSettings: any) => {
dispatch({
type: "UPDATE_PROMPT_MODEL_SETTINGS",
index: promptIndex,
modelSettings: newModelSettings,
});
// TODO: Call server-side endpoint to update model settings
// TODO: Call server-side endpoint to update model
Member Author (@rholinshead, Dec 29, 2023): Not setting this up just yet (unlike the model name below), since there is a higher potential for bugs in managing the settings object and I don't want to add friction for whomever is adding the server-side endpoint. I'll complete this once the endpoint is implemented.

},
[dispatch]
);

const onUpdatePromptModel = useCallback(
async (promptIndex: number, newModel?: string) => {
dispatch({
type: "UPDATE_PROMPT_MODEL",
index: promptIndex,
modelName: newModel,
});

try {
const prompt = clientPromptToAIConfigPrompt(
aiconfigState.prompts[promptIndex]
);
const currentModel = prompt.metadata?.model;
let modelData: string | ModelMetadata | undefined = newModel;
if (newModel && currentModel && typeof currentModel !== "string") {
Contributor: I think with the new SDK you won't have to do individual checks, because we're passing model_name and settings as two separate args; hopefully this makes it easier for you.

Member Author: Ah, ok - for the updated config it seems like we can't support the following scenario: change the model name for a prompt without setting any settings (i.e. set model to the string name instead of {name: <name>, settings: {}}). This is another case where we have multiple ways of representing the same thing in aiconfig, and it might be annoying that setting the model name and clearing the settings always leaves settings: {} written by the editor, instead of just model: name in the config.

On top of that, if we change the model and it already had settings set, we would ideally want to keep only the settings that are supported by the new model (e.g. say temperature and topK are used by both models), as defined in the prompt schema associated with the model provider, and drop anything not supported. That might need to wait until post-MVP, though. For now we can either clear the settings or keep them entirely... I was doing the former here, but open to opinions.
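
(For readers following along, the two equivalent shapes under discussion are roughly:)

// Shorthand: just the model name
const modelAsName: string | ModelMetadata = "gpt-4";
// Expanded: what the editor would always write out after a settings update
const modelAsMetadata: string | ModelMetadata = { name: "gpt-4", settings: {} };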

modelData = {
...currentModel,
name: newModel,
};
}

await debouncedUpdateModel(prompt.name, modelData);

// TODO: Consolidate
} catch (err: any) {
showNotification({
title: "Error updating prompt model",
Contributor: Would it be possible to read the HTTP message that we return and display it? Can be done in a future PR.

message: err.message,
Contributor: Oh I guess it's done here?

color: "red",
});
}
},
[dispatch, debouncedUpdateModel]
);

const onUpdatePromptParameters = useCallback(
async (promptIndex: number, newParameters: any) => {
dispatch({
@@ -241,8 +290,6 @@ export default function EditorContainer({

const { classes } = useStyles();

// TODO: Implement editor context for callbacks, readonly state, etc.

return (
<>
<Container maw="80rem">
@@ -262,9 +309,11 @@
<PromptContainer
index={i}
prompt={prompt}
getModels={callbacks.getModels}
onChangePromptInput={onChangePromptInput}
onChangePromptName={onChangePromptName}
onRunPrompt={onRunPrompt}
onUpdateModel={onUpdatePromptModel}
onUpdateModelSettings={onUpdatePromptModelSettings}
Comment on lines +316 to 317
Contributor: Oh ok cool, so we are keeping them as separate actions?

Member Author: Ya, we should keep them separate since there's different logic for them (sketched just below):

  • update settings --> keep the same model name but a new model settings object
  • update model --> change the model name and either keep the settings or clear them
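
(Roughly, the two separate dispatches; the index and example values are placeholders:)

// Update settings: same model name, new settings object
dispatch({
  type: "UPDATE_PROMPT_MODEL_SETTINGS",
  index: promptIndex,
  modelSettings: { temperature: 0.7 }, // placeholder
});

// Update model: new name; settings cleared (this PR) or carried over
dispatch({
  type: "UPDATE_PROMPT_MODEL",
  index: promptIndex,
  modelName: "gpt-4", // placeholder
});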

onUpdateParameters={onUpdatePromptParameters}
defaultConfigModelName={aiconfigState.metadata.default_model}
python/src/aiconfig/editor/client/src/components/aiconfigReducer.ts
@@ -10,6 +10,7 @@ export type MutateAIConfigAction =
| AddPromptAction
| UpdatePromptInputAction
| UpdatePromptNameAction
| UpdatePromptModelAction
| UpdatePromptModelSettingsAction
| UpdatePromptParametersAction;

@@ -37,6 +38,12 @@ export type UpdatePromptNameAction = {
name: string;
};

export type UpdatePromptModelAction = {
Contributor (@rossdanlm, Dec 29, 2023): Is this being used anywhere in this PR? Does the dispatch type being set to "onUpdatePromptModel" mean that it's automatically mapped to this action type?

Member Author: When we dispatch the "UPDATE_PROMPT_MODEL" action, that action is implicitly typed as this UpdatePromptModelAction type. We add it to MutateAIConfigAction above, and that does two things:

  1. Makes it part of AIConfigReducerAction via the union for that type, which means our reducer below will require a case "UPDATE_PROMPT_MODEL" to handle it and ensure the state is updated -- this is the automatic mapping you're referring to
  2. Allows us to dispatch CONSOLIDATE_AICONFIG after we get the response, in order to update the client state with the response state (not done here because we don't have an aiconfig response to do that with just yet)

Contributor: Nice, thanks for the explanation!

type: "UPDATE_PROMPT_MODEL";
index: number;
modelName?: string;
};

export type UpdatePromptModelSettingsAction = {
type: "UPDATE_PROMPT_MODEL_SETTINGS";
index: number;
@@ -134,6 +141,20 @@ export default function aiconfigReducer(
name: action.name,
}));
}
case "UPDATE_PROMPT_MODEL": {
return reduceReplacePrompt(state, action.index, (prompt) => ({
...prompt,
metadata: {
...prompt.metadata,
model: action.modelName
? {
name: action.modelName,
// TODO: Consolidate settings based on schema union
}
: undefined,
Comment on lines +149 to +154
Contributor: Hopefully we won't need to do this logic with the updated SDK

Member Author: I'll revisit this once the endpoint is implemented. We still need some UI state change here (it happens before the request is sent and before the response is returned, to keep the UI responsive). Probably what will need to be done is:

  • this action sets just the name
  • the aiconfig response comes back with full metadata, with settings either cleared or updated to match what the schema supports
  • from the aiconfig response, we consolidate the UI state's model metadata with the updated settings
},
}));
}
case "UPDATE_PROMPT_MODEL_SETTINGS": {
return reduceReplacePrompt(state, action.index, (prompt) => ({
...prompt,
python/src/aiconfig/editor/client/src/components/prompt/AddPromptButton.tsx
@@ -1,7 +1,7 @@
import { ActionIcon, Menu, TextInput } from "@mantine/core";
import { IconPlus, IconSearch, IconTextCaption } from "@tabler/icons-react";
import { memo, useCallback, useEffect, useState } from "react";
import { showNotification } from "@mantine/notifications";
import { memo, useCallback, useState } from "react";
import useLoadModels from "../../hooks/useLoadModels";
Contributor: Nice, much cleaner


type Props = {
addPrompt: (prompt: string) => void;
@@ -41,30 +41,14 @@ function ModelMenuItems({

export default memo(function AddPromptButton({ addPrompt, getModels }: Props) {
const [modelSearch, setModelSearch] = useState("");
Contributor (@rossdanlm, Dec 29, 2023): Can you clarify what the modelSearch string is supposed to represent?

Member Author: There's a search input and modelSearch is the text of that input

const [models, setModels] = useState<string[]>([]);
const [isOpen, setIsOpen] = useState(false);

const loadModels = useCallback(async (modelSearch: string) => {
try {
const models = await getModels(modelSearch);
setModels(models);
} catch (err: any) {
showNotification({
title: "Error loading models",
message: err?.message,
color: "red",
});
}
}, []);

const onAddPrompt = useCallback((model: string) => {
addPrompt(model);
setIsOpen(false);
}, []);

useEffect(() => {
loadModels(modelSearch);
}, [loadModels, modelSearch]);
const models = useLoadModels(modelSearch, getModels);

return (
<Menu
Expand Down
python/src/aiconfig/editor/client/src/components/prompt/ModelSelector.tsx
@@ -0,0 +1,82 @@
import { Autocomplete, AutocompleteItem, Button } from "@mantine/core";
import { memo, useState } from "react";
import { getPromptModelName } from "../../utils/promptUtils";
import { Prompt } from "aiconfig";
import useLoadModels from "../../hooks/useLoadModels";
import { IconX } from "@tabler/icons-react";

type Props = {
prompt: Prompt;
getModels: (search: string) => Promise<string[]>;
onSetModel: (model?: string) => void;
defaultConfigModelName?: string;
};

export default memo(function ModelSelector({
prompt,
getModels,
onSetModel,
defaultConfigModelName,
}: Props) {
const [selectedModel, setSelectedModel] = useState<string | undefined>(
getPromptModelName(prompt, defaultConfigModelName)
);
const [showAll, setShowAll] = useState(true);
const [autocompleteSearch, setAutocompleteSearch] = useState(
getPromptModelName(prompt, defaultConfigModelName)
);

const models = useLoadModels(showAll ? "" : autocompleteSearch, getModels);

const onSelectModel = (model?: string) => {
setSelectedModel(model);
onSetModel(model);
};

return (
<Autocomplete
placeholder="Select model"
limit={100}
maxDropdownHeight={200}
rightSection={
selectedModel ? (
<Button
size="xs"
variant="subtle"
mr={10}
onClick={() => {
onSelectModel(undefined);
setShowAll(true);
setAutocompleteSearch("");
}}
>
<IconX size={10} />
</Button>
) : null
}
filter={(searchValue: string, item: AutocompleteItem) => {
if (showAll) {
return true;
}

const modelName: string = item.value;
return modelName
.toLocaleLowerCase()
.includes(searchValue.toLocaleLowerCase().trim());
Comment on lines +57 to +65
Contributor: Alright, so what's the difference between this vs. the functionality in the getModels callback:

const getModels = useCallback(async (search: string) => {
// For now, rely on caching and handle client-side search filtering
// We will use server-side search filtering for Gradio
const res = await ufetch.get(ROUTE_TABLE.LIST_MODELS);
const models = res.data;
if (search && search.length > 0) {
return models.filter((model: string) => model.indexOf(search) >= 0);
}
return models;
}, []);

Member Author: This is specific to this Autocomplete UI component, because it needs to handle the search/filtering of the provided data on its own
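
(In other words, getModels supplies the data, while the Autocomplete filter prop only decides which already-loaded items the dropdown displays. A minimal contrast, reusing this file's types:)

// getModels: fetches the model list (cached server round-trip)
const allModels = await getModels("");
// filter prop: purely client-side predicate over the loaded items
const filterItem = (searchValue: string, item: AutocompleteItem) =>
  item.value.toLocaleLowerCase().includes(searchValue.toLocaleLowerCase().trim());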

}}
data={models}
value={autocompleteSearch}
onChange={(value: string) => {
setAutocompleteSearch(value);
setShowAll(false);
onSelectModel(value);
models.some((model) => {
if (model === value) {
setShowAll(true);
Contributor (@rossdanlm, Dec 29, 2023): Ok, just to clarify, this onChange only occurs when a value is selected from the autocomplete options? Does it also occur when the user changes the search string? If it occurs when the user changes the text and the text matches a model name, we wouldn't want setShowAll set to true, right? I noticed that even when we enter gpt-4 in your test plan, at the bottom there are still some model names that don't match, like dall-e-3.

[Screenshot: 2023-12-29 at 10:14]

Member Author (@rholinshead, Dec 29, 2023): This is mainly copied from our model selector in Workbooks and is sort of a workaround to match our desired behaviour. By default, the component will only ever show results in the dropdown that match the text in the input box. So, if the aiconfig has model set to gpt-4 when you load it, the dropdown would only ever show models containing the gpt-4 substring. This isn't what we want, though, since users may not know that there are other non-gpt-4 models to use. The behaviour we want is:

  • if the user is typing to 'search' for a model, then the dropdown should filter on that search substring
  • otherwise, if the model is selected from the list / defaulted from the config, then we want to show all the model options to change to

Contributor: Oh, I generally assumed that the user would just click the x button on the right to see everything, hmmmm. I don't think it's a blocker, but it feels a bit strange to me that auto-completions don't remain filtered after being selected. We can do this as a P2 someday.

return true;
}
});
}}
/>
);
});
python/src/aiconfig/editor/client/src/components/prompt/PromptContainer.tsx
@@ -2,23 +2,26 @@ import PromptActionBar from "./PromptActionBar";
import PromptInputRenderer from "./prompt_input/PromptInputRenderer";
import PromptOutputsRenderer from "./prompt_outputs/PromptOutputsRenderer";
import { ClientPrompt } from "../../shared/types";
import { getPromptModelName, getPromptSchema } from "../../utils/promptUtils";
import { Flex, Card, Text, createStyles } from "@mantine/core";
import { getPromptSchema } from "../../utils/promptUtils";
import { Flex, Card, createStyles } from "@mantine/core";
import { PromptInput as AIConfigPromptInput } from "aiconfig";
import { memo, useCallback } from "react";
import { ParametersArray } from "../ParametersRenderer";
import PromptOutputBar from "./PromptOutputBar";
import PromptName from "./PromptName";
import ModelSelector from "./ModelSelector";

type Props = {
index: number;
prompt: ClientPrompt;
getModels: (search: string) => Promise<string[]>;
onChangePromptInput: (
promptIndex: number,
newPromptInput: AIConfigPromptInput
) => void;
onChangePromptName: (promptIndex: number, newName: string) => void;
onRunPrompt(promptIndex: number): Promise<void>;
onUpdateModel: (promptIndex: number, newModel?: string) => void;
onUpdateModelSettings: (promptIndex: number, newModelSettings: any) => void;
onUpdateParameters: (promptIndex: number, newParameters: any) => void;
defaultConfigModelName?: string;
@@ -42,10 +45,12 @@ const useStyles = createStyles((theme) => ({
export default memo(function PromptContainer({
prompt,
index,
getModels,
onChangePromptInput,
onChangePromptName,
defaultConfigModelName,
onRunPrompt,
onUpdateModel,
onUpdateModelSettings,
onUpdateParameters,
}: Props) {
@@ -87,6 +92,11 @@ export default memo(function PromptContainer({
[index, onRunPrompt]
);

const updateModel = useCallback(
(model?: string) => onUpdateModel(index, model),
[index, onUpdateModel]
);

// TODO: When adding support for custom PromptContainers, implement a PromptContainerRenderer which
// will take in the index and callback and render the appropriate PromptContainer with new memoized
// callback and not having to pass index down to PromptContainer
@@ -102,7 +112,12 @@
<Flex direction="column">
<Flex justify="space-between" mb="0.5em">
<PromptName name={prompt.name} onUpdate={onChangeName} />
<Text>{getPromptModelName(prompt, defaultConfigModelName)}</Text>
<ModelSelector
getModels={getModels}
prompt={prompt}
onSetModel={updateModel}
defaultConfigModelName={defaultConfigModelName}
/>
</Flex>
<PromptInputRenderer
input={prompt.input}