Skip to content

Commit

Permalink
chore: fetch model list
Browse files Browse the repository at this point in the history
  • Loading branch information
appflowy committed Feb 1, 2025
1 parent dd6b285 commit 343fe1c
Show file tree
Hide file tree
Showing 16 changed files with 145 additions and 70 deletions.
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import 'dart:convert';

import 'package:appflowy/user/application/user_listener.dart';
import 'package:appflowy/user/application/user_service.dart';
import 'package:appflowy_backend/dispatch/dispatch.dart';
Expand All @@ -9,6 +11,7 @@ import 'package:bloc/bloc.dart';
import 'package:freezed_annotation/freezed_annotation.dart';

part 'settings_ai_bloc.freezed.dart';
part 'settings_ai_bloc.g.dart';

class SettingsAIBloc extends Bloc<SettingsAIEvent, SettingsAIState> {
SettingsAIBloc(
Expand Down Expand Up @@ -65,6 +68,7 @@ class SettingsAIBloc extends Bloc<SettingsAIEvent, SettingsAIState> {
},
);
_loadUserWorkspaceSetting();
_loadModelList();
},
didReceiveUserProfile: (userProfile) {
emit(state.copyWith(userProfile: userProfile));
Expand All @@ -78,7 +82,7 @@ class SettingsAIBloc extends Bloc<SettingsAIEvent, SettingsAIState> {
!(state.aiSettings?.disableSearchIndexing ?? false),
);
},
selectModel: (AIModelPB model) {
selectModel: (String model) {
_updateUserWorkspaceSetting(model: model);
},
didLoadAISetting: (UseAISettingPB settings) {
Expand All @@ -89,6 +93,14 @@ class SettingsAIBloc extends Bloc<SettingsAIEvent, SettingsAIState> {
),
);
},
didLoadAvailableModels: (String models) {
final dynamic decodedJson = jsonDecode(models);
Log.info("Available models: $decodedJson");
if (decodedJson is Map<String, dynamic>) {
final models = ModelList.fromJson(decodedJson).models;
emit(state.copyWith(availableModels: models));
}
},
refreshMember: (member) {
emit(state.copyWith(currentWorkspaceMemberRole: member.role));
},
Expand All @@ -98,7 +110,7 @@ class SettingsAIBloc extends Bloc<SettingsAIEvent, SettingsAIState> {

void _updateUserWorkspaceSetting({
bool? disableSearchIndexing,
AIModelPB? model,
String? model,
}) {
final payload = UpdateUserWorkspaceSettingPB(
workspaceId: workspaceId,
Expand Down Expand Up @@ -132,6 +144,18 @@ class SettingsAIBloc extends Bloc<SettingsAIEvent, SettingsAIState> {
});
});
}

/// Asks the backend for the list of AI models available to the current
/// workspace and feeds the (JSON string) result back into the bloc.
void _loadModelList() {
  AIEventGetAvailableModels().send().then((result) {
    result.fold(
      (config) {
        // The request is async; the bloc may have been closed in the
        // meantime — adding an event to a closed bloc would throw.
        if (!isClosed) {
          add(SettingsAIEvent.didLoadAvailableModels(config.models));
        }
      },
      (err) => Log.error(err),
    );
  });
}
}

@freezed
Expand All @@ -145,11 +169,15 @@ class SettingsAIEvent with _$SettingsAIEvent {
const factory SettingsAIEvent.refreshMember(WorkspaceMemberPB member) =
_RefreshMember;

const factory SettingsAIEvent.selectModel(AIModelPB model) = _SelectAIModel;
const factory SettingsAIEvent.selectModel(String model) = _SelectAIModel;

const factory SettingsAIEvent.didReceiveUserProfile(
UserProfilePB newUserProfile,
) = _DidReceiveUserProfile;

const factory SettingsAIEvent.didLoadAvailableModels(
String models,
) = _DidLoadAvailableModels;
}

@freezed
Expand All @@ -158,6 +186,21 @@ class SettingsAIState with _$SettingsAIState {
required UserProfilePB userProfile,
UseAISettingPB? aiSettings,
AFRolePB? currentWorkspaceMemberRole,
@Default(["default"]) List<String> availableModels,
@Default(true) bool enableSearchIndexing,
}) = _SettingsAIState;
}

/// Typed wrapper around the JSON payload delivered by the
/// `AIEventGetAvailableModels` backend event.
///
/// Serialization is delegated to json_serializable's generated helpers in
/// `settings_ai_bloc.g.dart`.
@JsonSerializable()
class ModelList {
  ModelList({required this.models});

  /// Decodes a [ModelList] from already-parsed JSON.
  factory ModelList.fromJson(Map<String, dynamic> json) =>
      _$ModelListFromJson(json);

  /// Names of the models the workspace may choose from.
  final List<String> models;

  /// Encodes this object back to a JSON map.
  Map<String, dynamic> toJson() => _$ModelListToJson(this);
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@ import 'package:appflowy/generated/locale_keys.g.dart';
import 'package:appflowy/workspace/application/settings/ai/settings_ai_bloc.dart';
import 'package:appflowy/workspace/presentation/settings/shared/af_dropdown_menu_entry.dart';
import 'package:appflowy/workspace/presentation/settings/shared/settings_dropdown.dart';
import 'package:appflowy_backend/log.dart';
import 'package:appflowy_backend/protobuf/flowy-user/protobuf.dart';
import 'package:easy_localization/easy_localization.dart';
import 'package:flowy_infra_ui/style_widget/text.dart';
import 'package:flutter_bloc/flutter_bloc.dart';
Expand All @@ -30,18 +28,18 @@ class AIModelSelection extends StatelessWidget {
),
const Spacer(),
Flexible(
child: SettingsDropdown<AIModelPB>(
child: SettingsDropdown<String>(
key: const Key('_AIModelSelection'),
onChanged: (model) => context
.read<SettingsAIBloc>()
.add(SettingsAIEvent.selectModel(model)),
selectedOption: state.userProfile.aiModel,
options: _availableModels
options: state.availableModels
.map(
(format) => buildDropdownMenuEntry<AIModelPB>(
(model) => buildDropdownMenuEntry<String>(
context,
value: format,
label: _titleForAIModel(format),
value: model,
label: model,
),
)
.toList(),
Expand All @@ -54,29 +52,3 @@ class AIModelSelection extends StatelessWidget {
);
}
}

// Static, hard-coded list of AI models offered in the model-selection
// dropdown. NOTE(review): consider `const`/`final` — a mutable top-level
// list can be reassigned from anywhere in the library.
List<AIModelPB> _availableModels = [
  AIModelPB.DefaultModel,
  AIModelPB.Claude3Opus,
  AIModelPB.Claude3Sonnet,
  AIModelPB.GPT4oMini,
  AIModelPB.GPT4o,
];

/// Maps an [AIModelPB] value to the human-readable label shown in the
/// model-selection dropdown; unrecognized values are logged and rendered
/// as "Default".
String _titleForAIModel(AIModelPB model) {
  if (model == AIModelPB.DefaultModel) {
    return "Default";
  }
  if (model == AIModelPB.Claude3Opus) {
    return "Claude 3 Opus";
  }
  if (model == AIModelPB.Claude3Sonnet) {
    return "Claude 3 Sonnet";
  }
  if (model == AIModelPB.GPT4oMini) {
    return "GPT-4o-mini";
  }
  if (model == AIModelPB.GPT4o) {
    return "GPT-4o";
  }
  // Unknown enum value (e.g. a newer backend): fall back gracefully.
  Log.error("Unknown AI model: $model, fallback to default");
  return "Default";
}
24 changes: 12 additions & 12 deletions frontend/rust-lib/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions frontend/rust-lib/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -103,8 +103,8 @@ dashmap = "6.0.1"
# Run the script.add_workspace_members:
# scripts/tool/update_client_api_rev.sh new_rev_id
# ⚠️⚠️⚠️️
client-api = { git = "https://github.com/AppFlowy-IO/AppFlowy-Cloud", rev = "4a26572a4e43714def9b362d444c640fdf1bc0d9" }
client-api-entity = { git = "https://github.com/AppFlowy-IO/AppFlowy-Cloud", rev = "4a26572a4e43714def9b362d444c640fdf1bc0d9" }
client-api = { git = "https://github.com/AppFlowy-IO/AppFlowy-Cloud", rev = "18b1386bc2d16851d4b5f42d28f23b8c333d02db" }
client-api-entity = { git = "https://github.com/AppFlowy-IO/AppFlowy-Cloud", rev = "18b1386bc2d16851d4b5f42d28f23b8c333d02db" }

[profile.dev]
opt-level = 0
Expand Down
4 changes: 3 additions & 1 deletion frontend/rust-lib/flowy-ai-pub/src/cloud.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use bytes::Bytes;
pub use client_api::entity::ai_dto::{
AppFlowyOfflineAI, CompleteTextParams, CompletionMetadata, CompletionType, CreateChatContext,
LLMModel, LocalAIConfig, ModelInfo, OutputContent, OutputLayout, RelatedQuestion,
LLMModel, LocalAIConfig, ModelInfo, ModelList, OutputContent, OutputLayout, RelatedQuestion,
RepeatedRelatedQuestion, ResponseFormat, StringOrMessage,
};
pub use client_api::entity::billing_dto::SubscriptionPlan;
Expand Down Expand Up @@ -119,4 +119,6 @@ pub trait ChatCloudService: Send + Sync + 'static {
chat_id: &str,
params: UpdateChatParams,
) -> Result<(), FlowyError>;

async fn get_available_models(&self, workspace_id: &str) -> Result<ModelList, FlowyError>;
}
11 changes: 10 additions & 1 deletion frontend/rust-lib/flowy-ai/src/ai_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use std::collections::HashMap;

use appflowy_plugin::manager::PluginManager;
use dashmap::DashMap;
use flowy_ai_pub::cloud::{ChatCloudService, ChatSettings, UpdateChatParams};
use flowy_ai_pub::cloud::{ChatCloudService, ChatSettings, ModelList, UpdateChatParams};
use flowy_error::{FlowyError, FlowyResult};
use flowy_sqlite::kv::KVStorePreferences;
use flowy_sqlite::DBConnection;
Expand Down Expand Up @@ -241,6 +241,15 @@ impl AIManager {
Ok(())
}

/// Fetches the AI models available to the user's current workspace from
/// the cloud service.
pub async fn get_available_models(&self) -> FlowyResult<ModelList> {
  // Resolve the active workspace first; this can fail if no user session
  // is available.
  let workspace_id = self.user_service.workspace_id()?;
  let models = self
    .cloud_service_wm
    .get_available_models(&workspace_id)
    .await?;
  Ok(models)
}

pub async fn get_or_create_chat_instance(&self, chat_id: &str) -> Result<Arc<Chat>, FlowyError> {
let chat = self.chats.get(chat_id).as_deref().cloned();
match chat {
Expand Down
6 changes: 6 additions & 0 deletions frontend/rust-lib/flowy-ai/src/entities.rs
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,12 @@ pub struct ChatMessageListPB {
pub total: i64,
}

/// Protobuf payload for `AIEvent::GetAvailableModels`.
///
/// `models` is a JSON string (a serialized `ModelList`) rather than a
/// structured message; the Dart side decodes it with `jsonDecode`.
#[derive(Default, ProtoBuf, Validate, Clone, Debug)]
pub struct ModelConfigPB {
#[pb(index = 1)]
pub models: String,
}

impl From<RepeatedChatMessage> for ChatMessageListPB {
fn from(repeated_chat_message: RepeatedChatMessage) -> Self {
let messages = repeated_chat_message
Expand Down
9 changes: 9 additions & 0 deletions frontend/rust-lib/flowy-ai/src/event_handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,15 @@ pub(crate) async fn regenerate_response_handler(
Ok(())
}

/// Event handler for `AIEvent::GetAvailableModels`: returns the workspace's
/// available models as a JSON string wrapped in `ModelConfigPB`.
#[tracing::instrument(level = "debug", skip_all, err)]
pub(crate) async fn get_available_model_list_handler(
  ai_manager: AFPluginState<Weak<AIManager>>,
) -> DataResult<ModelConfigPB, FlowyError> {
  let manager = upgrade_ai_manager(ai_manager)?;
  // Serialize the typed ModelList to JSON so it fits in the string-typed
  // protobuf field.
  let model_list = manager.get_available_models().await?;
  let models = serde_json::to_string(&model_list)?;
  data_result_ok(ModelConfigPB { models })
}

#[tracing::instrument(level = "debug", skip_all, err)]
pub(crate) async fn load_prev_message_handler(
data: AFPluginData<LoadPrevChatMessagePB>,
Expand Down
7 changes: 7 additions & 0 deletions frontend/rust-lib/flowy-ai/src/event_map.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,10 @@ pub fn init(ai_manager: Weak<AIManager>) -> AFPlugin {
.event(AIEvent::GetChatSettings, get_chat_settings_handler)
.event(AIEvent::UpdateChatSettings, update_chat_settings_handler)
.event(AIEvent::RegenerateResponse, regenerate_response_handler)
.event(
AIEvent::GetAvailableModels,
get_available_model_list_handler,
)
}

#[derive(Clone, Copy, PartialEq, Eq, Debug, Display, Hash, ProtoBuf_Enum, Flowy_Event)]
Expand Down Expand Up @@ -154,4 +158,7 @@ pub enum AIEvent {

#[event(input = "RegenerateResponsePB")]
RegenerateResponse = 27,

#[event(output = "ModelConfigPB")]
GetAvailableModels = 28,
}
Loading

0 comments on commit 343fe1c

Please sign in to comment.