+ ${this.networkSearchConfig.visible.value
+ ? html`
+
+ ${PublishIcon()}
+
Toggle Network Search
+
+ `
+ : nothing}
void) | undefined = undefined;
diff --git a/blocksuite/blocks/src/root-block/widgets/ai-panel/type.ts b/blocksuite/blocks/src/root-block/widgets/ai-panel/type.ts
index a9961c27ba67b..7be30f3928702 100644
--- a/blocksuite/blocks/src/root-block/widgets/ai-panel/type.ts
+++ b/blocksuite/blocks/src/root-block/widgets/ai-panel/type.ts
@@ -2,6 +2,7 @@ import type {
AIError,
AIItemGroupConfig,
} from '@blocksuite/affine-components/ai-item';
+import type { Signal } from '@preact/signals-core';
import type { nothing, TemplateResult } from 'lit';
export interface CopyConfig {
@@ -28,6 +29,12 @@ export interface AIPanelGeneratingConfig {
stages?: string[];
}
+export interface AINetworkSearchConfig {
+ visible: Signal;
+ enabled: Signal;
+ setEnabled: (state: boolean) => void;
+}
+
export interface AffineAIPanelWidgetConfig {
answerRenderer: (
answer: string,
@@ -44,10 +51,10 @@ export interface AffineAIPanelWidgetConfig {
finishStateConfig: AIPanelAnswerConfig;
generatingStateConfig: AIPanelGeneratingConfig;
errorStateConfig: AIPanelErrorConfig;
+ networkSearchConfig: AINetworkSearchConfig;
hideCallback?: () => void;
discardCallback?: () => void;
inputCallback?: (input: string) => void;
-
copy?: CopyConfig;
}
diff --git a/packages/backend/server/.env.example b/packages/backend/server/.env.example
index 2a56e69bf6960..9c740f049243a 100644
--- a/packages/backend/server/.env.example
+++ b/packages/backend/server/.env.example
@@ -2,6 +2,7 @@
# REDIS_SERVER_HOST=localhost
# COPILOT_FAL_API_KEY=YOUR_KEY
# COPILOT_OPENAI_API_KEY=YOUR_KEY
+# COPILOT_PERPLEXITY_API_KEY=YOUR_KEY
# MAILER_HOST=127.0.0.1
# MAILER_PORT=1025
diff --git a/packages/backend/server/package.json b/packages/backend/server/package.json
index 196277b3f1c1a..9533ebc4ecf03 100644
--- a/packages/backend/server/package.json
+++ b/packages/backend/server/package.json
@@ -58,6 +58,7 @@
"@socket.io/redis-adapter": "^8.3.0",
"cookie-parser": "^1.4.7",
"dotenv": "^16.4.7",
+ "eventsource-parser": "^3.0.0",
"express": "^4.21.2",
"fast-xml-parser": "^4.5.0",
"get-stream": "^9.0.1",
diff --git a/packages/backend/server/src/config/affine.env.ts b/packages/backend/server/src/config/affine.env.ts
index 05ff3dc60bd55..c72f5e617f57b 100644
--- a/packages/backend/server/src/config/affine.env.ts
+++ b/packages/backend/server/src/config/affine.env.ts
@@ -28,6 +28,7 @@ AFFiNE.ENV_MAP = {
CAPTCHA_TURNSTILE_SECRET: ['plugins.captcha.turnstile.secret', 'string'],
COPILOT_OPENAI_API_KEY: 'plugins.copilot.openai.apiKey',
COPILOT_FAL_API_KEY: 'plugins.copilot.fal.apiKey',
+ COPILOT_PERPLEXITY_API_KEY: 'plugins.copilot.perplexity.apiKey',
COPILOT_UNSPLASH_API_KEY: 'plugins.copilot.unsplashKey',
REDIS_SERVER_HOST: 'redis.host',
REDIS_SERVER_PORT: ['redis.port', 'int'],
diff --git a/packages/backend/server/src/plugins/copilot/config.ts b/packages/backend/server/src/plugins/copilot/config.ts
index 516334d7ce64b..b7c54ca37ed8d 100644
--- a/packages/backend/server/src/plugins/copilot/config.ts
+++ b/packages/backend/server/src/plugins/copilot/config.ts
@@ -3,10 +3,12 @@ import type { ClientOptions as OpenAIClientOptions } from 'openai';
import { defineStartupConfig, ModuleConfig } from '../../base/config';
import { StorageConfig } from '../../base/storage/config';
import type { FalConfig } from './providers/fal';
+import { PerplexityConfig } from './providers/perplexity';
export interface CopilotStartupConfigurations {
openai?: OpenAIClientOptions;
fal?: FalConfig;
+ perplexity?: PerplexityConfig;
test?: never;
unsplashKey?: string;
storage: StorageConfig;
diff --git a/packages/backend/server/src/plugins/copilot/index.ts b/packages/backend/server/src/plugins/copilot/index.ts
index 01bbeb6b46e43..2940044c4ad2d 100644
--- a/packages/backend/server/src/plugins/copilot/index.ts
+++ b/packages/backend/server/src/plugins/copilot/index.ts
@@ -13,6 +13,7 @@ import {
CopilotProviderService,
FalProvider,
OpenAIProvider,
+ PerplexityProvider,
registerCopilotProvider,
} from './providers';
import {
@@ -26,6 +27,7 @@ import { CopilotWorkflowExecutors, CopilotWorkflowService } from './workflow';
registerCopilotProvider(FalProvider);
registerCopilotProvider(OpenAIProvider);
+registerCopilotProvider(PerplexityProvider);
@Plugin({
name: 'copilot',
diff --git a/packages/backend/server/src/plugins/copilot/prompt/prompts.ts b/packages/backend/server/src/plugins/copilot/prompt/prompts.ts
index f2e613f160205..ebc1effbe03bc 100644
--- a/packages/backend/server/src/plugins/copilot/prompt/prompts.ts
+++ b/packages/backend/server/src/plugins/copilot/prompt/prompts.ts
@@ -952,6 +952,11 @@ const chat: Prompt[] = [
},
],
},
+ {
+ name: 'Search With AFFiNE AI',
+ model: 'llama-3.1-sonar-small-128k-online',
+ messages: [],
+ },
// use for believer plan
{
name: 'Chat With AFFiNE AI - Believer',
diff --git a/packages/backend/server/src/plugins/copilot/providers/index.ts b/packages/backend/server/src/plugins/copilot/providers/index.ts
index a72cb307f57e0..1e89774d0c9ef 100644
--- a/packages/backend/server/src/plugins/copilot/providers/index.ts
+++ b/packages/backend/server/src/plugins/copilot/providers/index.ts
@@ -124,9 +124,7 @@ export class CopilotProviderService {
if (!this.cachedProviders.has(provider)) {
this.cachedProviders.set(provider, this.create(provider));
}
-
- // eslint-disable-next-line @typescript-eslint/no-non-null-assertion
- return this.cachedProviders.get(provider)!;
+ return this.cachedProviders.get(provider) as CopilotProvider;
}
async getProviderByCapability(
@@ -196,3 +194,4 @@ export class CopilotProviderService {
export { FalProvider } from './fal';
export { OpenAIProvider } from './openai';
+export { PerplexityProvider } from './perplexity';
diff --git a/packages/backend/server/src/plugins/copilot/providers/perplexity.ts b/packages/backend/server/src/plugins/copilot/providers/perplexity.ts
new file mode 100644
index 0000000000000..f107a7501859d
--- /dev/null
+++ b/packages/backend/server/src/plugins/copilot/providers/perplexity.ts
@@ -0,0 +1,209 @@
+import assert from 'node:assert';
+
+import { EventSourceParserStream } from 'eventsource-parser/stream';
+
+import {
+ CopilotPromptInvalid,
+ CopilotProviderSideError,
+ metrics,
+} from '../../../base';
+import {
+ CopilotCapability,
+ CopilotChatOptions,
+ CopilotProviderType,
+ CopilotTextToTextProvider,
+ PromptMessage,
+} from '../types';
+
+export type PerplexityConfig = {
+ apiKey: string;
+};
+
+interface Message {
+ role: 'assistant';
+ content: string;
+}
+
+interface Choice {
+ message: Message;
+ delta: Message;
+ finish_reason: 'stop' | null;
+}
+
+interface PerplexityData {
+ citations: string[];
+ choices: Choice[];
+}
+
+export function injectCitations(content: string, citations: string[]) {
+ // Match [[n]] and [n] patterns
+ // Not match if they're already part of a formatted citation
+ const regex = /(?:\[\[(\d+)\]\]|\[(\d+)\])(?!\]?\([^)]*\))/g;
+ return content.replace(regex, (_, g1, g2) => {
+ const index = parseInt(g1 || g2) - 1;
+ if (index >= 0 && index < citations.length) {
+ return `[[${g1 || g2}](${citations[index]})]`;
+ }
+ return _;
+ });
+}
+
+export class PerplexityProvider implements CopilotTextToTextProvider {
+ static readonly type = CopilotProviderType.Perplexity;
+
+ static readonly capabilities = [CopilotCapability.TextToText];
+
+ static assetsConfig(config: PerplexityConfig) {
+ return !!config.apiKey;
+ }
+
+ constructor(private readonly config: PerplexityConfig) {
+ assert(PerplexityProvider.assetsConfig(config));
+ }
+
+ readonly availableModels = [
+ 'llama-3.1-sonar-small-128k-online',
+ 'llama-3.1-sonar-large-128k-online',
+ 'llama-3.1-sonar-huge-128k-online',
+ ];
+
+ get type(): CopilotProviderType {
+ return PerplexityProvider.type;
+ }
+
+ getCapabilities(): CopilotCapability[] {
+ return PerplexityProvider.capabilities;
+ }
+
+ async isModelAvailable(model: string): Promise {
+ return this.availableModels.includes(model);
+ }
+
+ async generateText(
+ messages: PromptMessage[],
+ model: string = 'llama-3.1-sonar-small-128k-online',
+ options: CopilotChatOptions = {}
+ ): Promise {
+ await this.checkParams({ messages, model, options });
+ try {
+ metrics.ai.counter('chat_text_calls').add(1, { model });
+ const sMessages = messages
+ .map(({ content, role }) => ({ content, role }))
+ .filter(({ content }) => typeof content === 'string');
+
+ const params = {
+ method: 'POST',
+ headers: {
+ Authorization: `Bearer ${this.config.apiKey}`,
+ 'Content-Type': 'application/json',
+ },
+ body: JSON.stringify({
+ model,
+ messages: sMessages,
+ max_tokens: options.maxTokens || 4096,
+ }),
+ };
+ const response = await fetch(
+ 'https://api.perplexity.ai/chat/completions',
+ params
+ );
+ const json: PerplexityData = await response.json();
+ return injectCitations(json.choices[0].message.content, json.citations);
+ } catch (e: any) {
+ metrics.ai.counter('chat_text_errors').add(1, { model });
+ throw this.handleError(e);
+ }
+ }
+
+ async *generateTextStream(
+ messages: PromptMessage[],
+ model: string = 'llama-3.1-sonar-small-128k-online',
+ options: CopilotChatOptions = {}
+ ): AsyncIterable {
+ await this.checkParams({ messages, model, options });
+ try {
+ metrics.ai.counter('chat_text_stream_calls').add(1, { model });
+ const sMessages = messages
+ .map(({ content, role }) => ({ content, role }))
+ .filter(({ content }) => typeof content === 'string');
+
+ const params = {
+ method: 'POST',
+ headers: {
+ Authorization: `Bearer ${this.config.apiKey}`,
+ 'Content-Type': 'application/json',
+ },
+ body: JSON.stringify({
+ model,
+ messages: sMessages,
+ max_tokens: options.maxTokens || 4096,
+ stream: true,
+ }),
+ };
+ const response = await fetch(
+ 'https://api.perplexity.ai/chat/completions',
+ params
+ );
+ if (response.body) {
+ const eventStream = response.body
+ .pipeThrough(new TextDecoderStream())
+ .pipeThrough(new EventSourceParserStream())
+ .pipeThrough(
+ new TransformStream({
+ transform(chunk, controller) {
+ if (options.signal?.aborted) {
+ controller.enqueue(null);
+ return;
+ }
+ const data = JSON.parse(chunk.data) as PerplexityData | null;
+ if (data?.choices?.[0]?.delta?.content) {
+ const content = injectCitations(
+ data.choices[0].delta.content,
+ data.citations || []
+ );
+ controller.enqueue(content);
+ }
+ },
+ flush(controller) {
+ controller.enqueue(null);
+ },
+ })
+ );
+
+ const reader = eventStream.getReader();
+ while (true) {
+ const { done, value } = await reader.read();
+ if (done) break;
+ yield value;
+ }
+ } else {
+ const result = await this.generateText(messages, model, options);
+ yield result;
+ }
+ } catch (e) {
+ metrics.ai.counter('chat_text_stream_errors').add(1, { model });
+ throw e;
+ }
+ }
+
+ protected async checkParams({
+ model,
+ }: {
+ messages?: PromptMessage[];
+ embeddings?: string[];
+ model: string;
+ options: CopilotChatOptions;
+ }) {
+ if (!(await this.isModelAvailable(model))) {
+ throw new CopilotPromptInvalid(`Invalid model: ${model}`);
+ }
+ }
+
+ private handleError(e: any) {
+ return new CopilotProviderSideError({
+ provider: this.type,
+ kind: 'unexpected_response',
+ message: e?.message || 'Unexpected perplexity response',
+ });
+ }
+}
diff --git a/packages/backend/server/src/plugins/copilot/resolver.ts b/packages/backend/server/src/plugins/copilot/resolver.ts
index 9d0429b678656..a4dad5040b2bf 100644
--- a/packages/backend/server/src/plugins/copilot/resolver.ts
+++ b/packages/backend/server/src/plugins/copilot/resolver.ts
@@ -22,6 +22,7 @@ import GraphQLUpload from 'graphql-upload/GraphQLUpload.mjs';
import {
CallMetric,
CopilotFailedToCreateMessage,
+ CopilotSessionNotFound,
FileUpload,
RequestMutex,
Throttle,
@@ -62,6 +63,17 @@ class CreateChatSessionInput {
promptName!: string;
}
+@InputType()
+class UpdateChatSessionInput {
+ @Field(() => String)
+ sessionId!: string;
+
+ @Field(() => String, {
+ description: 'The prompt name to use for the session',
+ })
+ promptName!: string;
+}
+
@InputType()
class ForkChatSessionInput {
@Field(() => String)
@@ -372,6 +384,41 @@ export class CopilotResolver {
});
}
+ @Mutation(() => String, {
+ description: 'Update a chat session',
+ })
+ @CallMetric('ai', 'chat_session_update')
+ async updateCopilotSession(
+ @CurrentUser() user: CurrentUser,
+ @Args({ name: 'options', type: () => UpdateChatSessionInput })
+ options: UpdateChatSessionInput
+ ) {
+ const session = await this.chatSession.get(options.sessionId);
+ if (!session) {
+ throw new CopilotSessionNotFound();
+ }
+ const { workspaceId, docId, parentSessionId } = session.config;
+ await this.permissions.checkCloudPagePermission(
+ workspaceId,
+ docId,
+ user.id
+ );
+ const lockFlag = `${COPILOT_LOCKER}:session:${user.id}:${workspaceId}`;
+ await using lock = await this.mutex.acquire(lockFlag);
+ if (!lock) {
+ return new TooManyRequest('Server is busy');
+ }
+
+ await this.chatSession.checkQuota(user.id);
+ return await this.chatSession.update({
+ ...options,
+ parentSessionId,
+ workspaceId,
+ docId,
+ userId: user.id,
+ });
+ }
+
@Mutation(() => String, {
description: 'Create a chat session',
})
diff --git a/packages/backend/server/src/plugins/copilot/session.ts b/packages/backend/server/src/plugins/copilot/session.ts
index 3e6414828f7f9..c6175bdb43fbd 100644
--- a/packages/backend/server/src/plugins/copilot/session.ts
+++ b/packages/backend/server/src/plugins/copilot/session.ts
@@ -23,6 +23,7 @@ import {
ChatSessionForkOptions,
ChatSessionOptions,
ChatSessionState,
+ ChatSessionUpdateOptions,
getTokenEncoder,
ListHistoriesOptions,
PromptMessage,
@@ -46,13 +47,22 @@ export class ChatSession implements AsyncDisposable {
get config() {
const {
sessionId,
+ parentSessionId,
userId,
workspaceId,
docId,
prompt: { name: promptName, config: promptConfig },
} = this.state;
- return { sessionId, userId, workspaceId, docId, promptName, promptConfig };
+ return {
+ sessionId,
+ parentSessionId,
+ userId,
+ workspaceId,
+ docId,
+ promptName,
+ promptConfig,
+ };
}
get stashMessages() {
@@ -198,6 +208,19 @@ export class ChatSessionService {
private readonly prompt: PromptService
) {}
+ private async haveSession(sessionId: string, userId: string) {
+ return await this.db.$transaction(async tx => {
+ return await tx.aiSession
+ .count({
+ where: {
+ id: sessionId,
+ userId,
+ },
+ })
+ .then(c => c > 0);
+ });
+ }
+
private async setSession(state: ChatSessionState): Promise {
return await this.db.$transaction(async tx => {
let sessionId = state.sessionId;
@@ -226,15 +249,7 @@ export class ChatSessionService {
if (id) sessionId = id;
}
- const haveSession = await tx.aiSession
- .count({
- where: {
- id: sessionId,
- userId: state.userId,
- },
- })
- .then(c => c > 0);
-
+ const haveSession = await this.haveSession(sessionId, state.userId);
if (haveSession) {
// message will only exists when setSession call by session.save
if (state.messages.length) {
@@ -280,6 +295,22 @@ export class ChatSessionService {
});
}
+ private async updateSession(state: ChatSessionState): Promise {
+ return await this.db.$transaction(async tx => {
+ let sessionId = state.sessionId;
+ const haveSession = await this.haveSession(sessionId, state.userId);
+ if (haveSession) {
+ await tx.aiSession.update({
+ where: { id: sessionId },
+ data: {
+ promptName: state.prompt.name,
+ },
+ });
+ }
+ return sessionId;
+ });
+ }
+
private async getSession(
sessionId: string
): Promise {
@@ -570,6 +601,19 @@ export class ChatSessionService {
});
}
+ async update(options: ChatSessionUpdateOptions): Promise {
+ const prompt = await this.prompt.get(options.promptName);
+ if (!prompt) {
+ this.logger.error(`Prompt not found: ${options.promptName}`);
+ throw new CopilotPromptNotFound({ name: options.promptName });
+ }
+ return await this.updateSession({
+ ...options,
+ prompt,
+ messages: [],
+ });
+ }
+
async fork(options: ChatSessionForkOptions): Promise {
const state = await this.getSession(options.sessionId);
if (!state) {
diff --git a/packages/backend/server/src/plugins/copilot/types.ts b/packages/backend/server/src/plugins/copilot/types.ts
index 679377f16a3a2..25aa8a77af601 100644
--- a/packages/backend/server/src/plugins/copilot/types.ts
+++ b/packages/backend/server/src/plugins/copilot/types.ts
@@ -123,6 +123,15 @@ export interface ChatSessionOptions {
promptName: string;
}
+export interface ChatSessionUpdateOptions {
+ sessionId: string;
+ parentSessionId: string | null;
+ userId: string;
+ workspaceId: string;
+ docId: string;
+ promptName: string;
+}
+
export interface ChatSessionForkOptions
extends Omit {
sessionId: string;
@@ -154,6 +163,7 @@ export type ListHistoriesOptions = {
export enum CopilotProviderType {
FAL = 'fal',
OpenAI = 'openai',
+ Perplexity = 'perplexity',
// only for test
Test = 'test',
}
diff --git a/packages/backend/server/src/schema.gql b/packages/backend/server/src/schema.gql
index 3ec9dee256c99..a5c168609aa98 100644
--- a/packages/backend/server/src/schema.gql
+++ b/packages/backend/server/src/schema.gql
@@ -551,6 +551,9 @@ type Mutation {
"""Update a copilot prompt"""
updateCopilotPrompt(messages: [CopilotPromptMessageInput!]!, name: String!): CopilotPromptType!
+
+ """Update a chat session"""
+ updateCopilotSession(options: UpdateChatSessionInput!): String!
updateProfile(input: UpdateUserInput!): UserType!
"""update server runtime configurable setting"""
@@ -865,6 +868,12 @@ type UnsupportedSubscriptionPlanDataType {
plan: String!
}
+input UpdateChatSessionInput {
+ """The prompt name to use for the session"""
+ promptName: String!
+ sessionId: String!
+}
+
input UpdateUserInput {
"""User name"""
name: String
diff --git a/packages/backend/server/tests/copilot-provider.spec.ts b/packages/backend/server/tests/copilot-provider.spec.ts
index c22493f987da1..e17d8b4b86fff 100644
--- a/packages/backend/server/tests/copilot-provider.spec.ts
+++ b/packages/backend/server/tests/copilot-provider.spec.ts
@@ -13,6 +13,7 @@ import {
CopilotProviderService,
FalProvider,
OpenAIProvider,
+ PerplexityProvider,
registerCopilotProvider,
unregisterCopilotProvider,
} from '../src/plugins/copilot/providers';
@@ -47,8 +48,10 @@ const test = ava as TestFn;
const isCopilotConfigured =
!!process.env.COPILOT_OPENAI_API_KEY &&
!!process.env.COPILOT_FAL_API_KEY &&
+ !!process.env.COPILOT_PERPLEXITY_API_KEY &&
process.env.COPILOT_OPENAI_API_KEY !== '1' &&
- process.env.COPILOT_FAL_API_KEY !== '1';
+ process.env.COPILOT_FAL_API_KEY !== '1' &&
+ process.env.COPILOT_PERPLEXITY_API_KEY !== '1';
const runIfCopilotConfigured = test.macro(
async (
t,
@@ -75,6 +78,9 @@ test.serial.before(async t => {
fal: {
apiKey: process.env.COPILOT_FAL_API_KEY,
},
+ perplexity: {
+ apiKey: process.env.COPILOT_PERPLEXITY_API_KEY,
+ },
},
},
}),
@@ -111,6 +117,7 @@ test.serial.before(async t => {
registerCopilotProvider(OpenAIProvider);
registerCopilotProvider(FalProvider);
+ registerCopilotProvider(PerplexityProvider);
for (const name of await prompt.listNames()) {
await prompt.delete(name);
@@ -124,6 +131,7 @@ test.serial.before(async t => {
test.after(async _ => {
unregisterCopilotProvider(OpenAIProvider.type);
unregisterCopilotProvider(FalProvider.type);
+ unregisterCopilotProvider(PerplexityProvider.type);
});
test.after(async t => {
@@ -152,7 +160,6 @@ const checkMDList = (text: string) => {
return false;
}
- // eslint-disable-next-line @typescript-eslint/no-non-null-asserted-optional-chain
const currentIndent = line.match(/^( *)/)?.[0].length!;
if (Number.isNaN(currentIndent) || currentIndent % 2 !== 0) {
return false;
@@ -282,6 +289,8 @@ const actions = [
'Make it longer',
'Make it shorter',
'Continue writing',
+ 'Chat With AFFiNE AI',
+ 'Search With AFFiNE AI',
],
messages: [{ role: 'user' as const, content: TestAssets.SSOT }],
verifier: (t: ExecutionContext, result: string) => {
diff --git a/packages/backend/server/tests/copilot.e2e.ts b/packages/backend/server/tests/copilot.e2e.ts
index 49262127d9a98..71a523b31460a 100644
--- a/packages/backend/server/tests/copilot.e2e.ts
+++ b/packages/backend/server/tests/copilot.e2e.ts
@@ -16,6 +16,7 @@ import {
CopilotProviderService,
FalProvider,
OpenAIProvider,
+ PerplexityProvider,
registerCopilotProvider,
unregisterCopilotProvider,
} from '../src/plugins/copilot/providers';
@@ -41,6 +42,7 @@ import {
sse2array,
textToEventStream,
unsplashSearch,
+ updateCopilotSession,
} from './utils/copilot';
const test = ava as TestFn<{
@@ -63,6 +65,9 @@ test.beforeEach(async t => {
fal: {
apiKey: '1',
},
+ perplexity: {
+ apiKey: '1',
+ },
unsplashKey: process.env.UNSPLASH_ACCESS_KEY || '1',
},
},
@@ -91,6 +96,7 @@ test.beforeEach(async t => {
unregisterCopilotProvider(OpenAIProvider.type);
unregisterCopilotProvider(FalProvider.type);
+ unregisterCopilotProvider(PerplexityProvider.type);
registerCopilotProvider(MockCopilotTestProvider);
await prompt.set(promptName, 'test', [
@@ -156,6 +162,85 @@ test('should create session correctly', async t => {
}
});
+test('should update session correctly', async t => {
+ const { app } = t.context;
+
+ const assertUpdateSession = async (
+ sessionId: string,
+ error: string,
+ asserter = async (x: any) => {
+ t.truthy(await x, error);
+ }
+ ) => {
+ await asserter(updateCopilotSession(app, token, sessionId, promptName));
+ };
+
+ {
+ const { id: workspaceId } = await createWorkspace(app, token);
+ const docId = randomUUID();
+ const sessionId = await createCopilotSession(
+ app,
+ token,
+ workspaceId,
+ docId,
+ promptName
+ );
+ await assertUpdateSession(
+ sessionId,
+ 'should be able to update session with cloud workspace that user can access'
+ );
+ }
+
+ {
+ const sessionId = await createCopilotSession(
+ app,
+ token,
+ randomUUID(),
+ randomUUID(),
+ promptName
+ );
+ await assertUpdateSession(
+ sessionId,
+ 'should be able to update session with local workspace'
+ );
+ }
+
+ {
+ const aToken = (await signUp(app, 'test', 'test@affine.pro', '123456'))
+ .token.token;
+ const { id: workspaceId } = await createWorkspace(app, aToken);
+ const inviteId = await inviteUser(
+ app,
+ aToken,
+ workspaceId,
+ 'darksky@affine.pro'
+ );
+ await acceptInviteById(app, workspaceId, inviteId, false);
+ const sessionId = await createCopilotSession(
+ app,
+ token,
+ workspaceId,
+ randomUUID(),
+ promptName
+ );
+ await assertUpdateSession(
+ sessionId,
+ 'should able to update session after user have permission'
+ );
+ }
+
+ {
+ const sessionId = '123456';
+ await assertUpdateSession(sessionId, '', async x => {
+ await t.throwsAsync(
+ x,
+ { instanceOf: Error },
+ 'should not able to update invalid session id'
+ );
+ });
+ }
+});
+
test('should fork session correctly', async t => {
const { app } = t.context;
diff --git a/packages/backend/server/tests/copilot.spec.ts b/packages/backend/server/tests/copilot.spec.ts
index 9372fc893a368..f4f0774f5ba22 100644
--- a/packages/backend/server/tests/copilot.spec.ts
+++ b/packages/backend/server/tests/copilot.spec.ts
@@ -68,7 +68,10 @@ test.beforeEach(async t => {
apiKey: process.env.COPILOT_OPENAI_API_KEY ?? '1',
},
fal: {
- apiKey: '1',
+ apiKey: process.env.COPILOT_FAL_API_KEY ?? '1',
+ },
+ perplexity: {
+ apiKey: process.env.COPILOT_PERPLEXITY_API_KEY ?? '1',
},
},
},
@@ -274,6 +277,46 @@ test('should be able to manage chat session', async t => {
}
});
+test('should be able to update chat session', async t => {
+ const { prompt, session } = t.context;
+
+ // Set up a prompt to be used in the session
+ await prompt.set('prompt', 'model', [
+ { role: 'system', content: 'hello {{word}}' },
+ ]);
+
+ const commonParams = {
+ docId: 'test',
+ workspaceId: 'test',
+ parentSessionId: null,
+ userId,
+ };
+
+ // Create a session
+ const sessionId = await session.create({
+ promptName: 'prompt',
+ ...commonParams,
+ });
+ t.truthy(sessionId, 'should create session');
+
+ // Update the session
+ const updatedSessionId = await session.update({
+ sessionId,
+ promptName: 'Search With AFFiNE AI',
+ ...commonParams,
+ });
+ t.is(updatedSessionId, sessionId, 'should update session with same id');
+
+ // Verify the session was updated
+ const updatedSession = await session.get(sessionId);
+ t.truthy(updatedSession, 'should retrieve updated session');
+ t.is(
+ updatedSession?.config.promptName,
+ 'Search With AFFiNE AI',
+ 'should have updated prompt name'
+ );
+});
+
test('should be able to fork chat session', async t => {
const { auth, prompt, session } = t.context;
diff --git a/packages/backend/server/tests/perplexity.spec.ts b/packages/backend/server/tests/perplexity.spec.ts
new file mode 100644
index 0000000000000..7c534399942eb
--- /dev/null
+++ b/packages/backend/server/tests/perplexity.spec.ts
@@ -0,0 +1,70 @@
+///
+
+import { ReadableStreamDefaultReader } from 'node:stream/web';
+
+import ava, { TestFn } from 'ava';
+import Sinon from 'sinon';
+
+import {
+ injectCitations,
+ PerplexityProvider,
+} from '../src/plugins/copilot/providers/perplexity';
+
+const test = ava as TestFn<{
+ provider: PerplexityProvider;
+ mockReader: Sinon.SinonStubbedInstance<
+ ReadableStreamDefaultReader
+ >;
+ loggerSpy: Sinon.SinonSpy;
+}>;
+
+test.beforeEach(t => {
+ const provider = new PerplexityProvider({ apiKey: 'test-api-key' });
+ const mockReader = {
+ read: Sinon.stub(),
+ releaseLock: Sinon.stub(),
+ cancel: Sinon.stub(),
+ } as unknown as Sinon.SinonStubbedInstance<
+ ReadableStreamDefaultReader
+ >;
+
+ // Spy on the logger
+ const loggerSpy = Sinon.spy(provider['logger'], 'warn');
+
+ t.context = { provider, mockReader, loggerSpy };
+});
+
+test.afterEach.always(t => {
+ // Restore the original logger method
+ t.context.loggerSpy.restore();
+});
+
+test('injectCitations should replace citation placeholders with URLs', t => {
+ const content =
+ 'This is [a] test sentence with citations [1] and [[2]] and [3].';
+ const citations = [
+ 'https://example.com/citation1',
+ 'https://example.com/citation2',
+ ];
+
+ const expected =
+ 'This is [a] test sentence with citations [[1](https://example.com/citation1)] and [[2](https://example.com/citation2)] and [3].';
+ const result = injectCitations(content, citations);
+
+ t.is(result, expected);
+});
+
+test('injectCitations should not replace citation already with URLs', t => {
+ const content =
+ 'Test sentence with citations [1](https://example.com/citationx) and [[2]](https://example.com/citationy) and [[3](https://example.com/citationz)].';
+ const citations = [
+ 'https://example.com/citation1',
+ 'https://example.com/citation2',
+ 'https://example.com/citation3',
+ ];
+
+ const expected = content;
+ const result = injectCitations(content, citations);
+
+ t.is(result, expected);
+});
diff --git a/packages/backend/server/tests/utils/copilot.ts b/packages/backend/server/tests/utils/copilot.ts
index 07bce549a2535..e88971ce30e70 100644
--- a/packages/backend/server/tests/utils/copilot.ts
+++ b/packages/backend/server/tests/utils/copilot.ts
@@ -184,6 +184,31 @@ export async function createCopilotSession(
return res.body.data.createCopilotSession;
}
+export async function updateCopilotSession(
+ app: INestApplication,
+ userToken: string,
+ sessionId: string,
+ promptName: string
+): Promise {
+ const res = await request(app.getHttpServer())
+ .post(gql)
+ .auth(userToken, { type: 'bearer' })
+ .set({ 'x-request-id': 'test', 'x-operation-name': 'test' })
+ .send({
+ query: `
+ mutation updateCopilotSession($options: UpdateChatSessionInput!) {
+ updateCopilotSession(options: $options)
+ }
+ `,
+ variables: { options: { sessionId, promptName } },
+ })
+ .expect(200);
+
+ handleGraphQLError(res);
+
+ return res.body.data.updateCopilotSession;
+}
+
export async function forkCopilotSession(
app: INestApplication,
userToken: string,
diff --git a/packages/frontend/core/src/blocksuite/presets/ai/ai-panel.ts b/packages/frontend/core/src/blocksuite/presets/ai/ai-panel.ts
index 9e77c66c1339f..b232b8fe99164 100644
--- a/packages/frontend/core/src/blocksuite/presets/ai/ai-panel.ts
+++ b/packages/frontend/core/src/blocksuite/presets/ai/ai-panel.ts
@@ -1,3 +1,4 @@
+import { AINetworkSearchService } from '@affine/core/modules/ai-button/services/network-search';
import type { EditorHost } from '@blocksuite/affine/block-std';
import {
type AffineAIPanelWidget,
@@ -9,6 +10,7 @@ import {
NoteDisplayMode,
} from '@blocksuite/affine/blocks';
import { assertExists, Bound } from '@blocksuite/affine/global/utils';
+import type { FrameworkProvider } from '@toeverything/infra';
import type { TemplateResult } from 'lit';
import { createTextRenderer, insertFromMarkdown } from '../_common';
@@ -287,14 +289,21 @@ export function buildCopyConfig(panel: AffineAIPanelWidget) {
}
export function buildAIPanelConfig(
- panel: AffineAIPanelWidget
+ panel: AffineAIPanelWidget,
+ framework: FrameworkProvider
): AffineAIPanelWidgetConfig {
const ctx = new AIContext();
+ const searchService = framework.get(AINetworkSearchService);
return {
answerRenderer: createTextRenderer(panel.host, { maxHeight: 320 }),
finishStateConfig: buildFinishConfig(panel, 'chat', ctx),
generatingStateConfig: buildGeneratingConfig(),
errorStateConfig: buildErrorConfig(panel),
copy: buildCopyConfig(panel),
+ networkSearchConfig: {
+ visible: searchService.visible,
+ enabled: searchService.enabled,
+ setEnabled: searchService.setEnabled,
+ },
};
}
diff --git a/packages/frontend/core/src/blocksuite/presets/ai/ai-spec.ts b/packages/frontend/core/src/blocksuite/presets/ai/ai-spec.ts
index b672cf1e3ab30..6ad115b1ebfa1 100644
--- a/packages/frontend/core/src/blocksuite/presets/ai/ai-spec.ts
+++ b/packages/frontend/core/src/blocksuite/presets/ai/ai-spec.ts
@@ -23,6 +23,7 @@ import {
} from '@blocksuite/affine/blocks';
import { assertInstanceOf } from '@blocksuite/affine/global/utils';
import type { ExtensionType } from '@blocksuite/affine/store';
+import type { FrameworkProvider } from '@toeverything/infra';
import { literal, unsafeStatic } from 'lit/static-html.js';
import { buildAIPanelConfig } from './ai-panel';
@@ -36,96 +37,110 @@ import { setupImageToolbarAIEntry } from './entries/image-toolbar/setup-image-to
import { setupSlashMenuAIEntry } from './entries/slash-menu/setup-slash-menu';
import { setupSpaceAIEntry } from './entries/space/setup-space';
-class AIPageRootWatcher extends BlockServiceWatcher {
- static override readonly flavour = 'affine:page';
-
- override mounted() {
- super.mounted();
- this.blockService.specSlots.widgetConnected.on(view => {
- if (view.component instanceof AffineAIPanelWidget) {
- view.component.style.width = '630px';
- view.component.config = buildAIPanelConfig(view.component);
- setupSpaceAIEntry(view.component);
- }
-
- if (view.component instanceof AffineFormatBarWidget) {
- setupFormatBarAIEntry(view.component);
- }
-
- if (view.component instanceof AffineSlashMenuWidget) {
- setupSlashMenuAIEntry(view.component);
- }
- });
+function getAIPageRootWatcher(framework: FrameworkProvider) {
+ class AIPageRootWatcher extends BlockServiceWatcher {
+ static override readonly flavour = 'affine:page';
+
+ override mounted() {
+ super.mounted();
+ this.blockService.specSlots.widgetConnected.on(view => {
+ if (view.component instanceof AffineAIPanelWidget) {
+ view.component.style.width = '630px';
+ view.component.config = buildAIPanelConfig(view.component, framework);
+ setupSpaceAIEntry(view.component);
+ }
+
+ if (view.component instanceof AffineFormatBarWidget) {
+ setupFormatBarAIEntry(view.component);
+ }
+
+ if (view.component instanceof AffineSlashMenuWidget) {
+ setupSlashMenuAIEntry(view.component);
+ }
+ });
+ }
}
+ return AIPageRootWatcher;
}
-export const AIPageRootBlockSpec: ExtensionType[] = [
- ...PageRootBlockSpec,
- AIPageRootWatcher,
- {
- setup: di => {
- di.override(WidgetViewMapIdentifier('affine:page'), () => {
- return {
- ...pageRootWidgetViewMap,
- [AFFINE_AI_PANEL_WIDGET]: literal`${unsafeStatic(
- AFFINE_AI_PANEL_WIDGET
- )}`,
- };
- });
+export function createAIPageRootBlockSpec(
+ framework: FrameworkProvider
+): ExtensionType[] {
+ return [
+ ...PageRootBlockSpec,
+ getAIPageRootWatcher(framework),
+ {
+ setup: di => {
+ di.override(WidgetViewMapIdentifier('affine:page'), () => {
+ return {
+ ...pageRootWidgetViewMap,
+ [AFFINE_AI_PANEL_WIDGET]: literal`${unsafeStatic(
+ AFFINE_AI_PANEL_WIDGET
+ )}`,
+ };
+ });
+ },
},
- },
-];
-
-class AIEdgelessRootWatcher extends BlockServiceWatcher {
- static override readonly flavour = 'affine:page';
-
- override mounted() {
- super.mounted();
- this.blockService.specSlots.widgetConnected.on(view => {
- if (view.component instanceof AffineAIPanelWidget) {
- view.component.style.width = '430px';
- view.component.config = buildAIPanelConfig(view.component);
- setupSpaceAIEntry(view.component);
- }
-
- if (view.component instanceof EdgelessCopilotWidget) {
- setupEdgelessCopilot(view.component);
- }
-
- if (view.component instanceof EdgelessElementToolbarWidget) {
- setupEdgelessElementToolbarAIEntry(view.component);
- }
-
- if (view.component instanceof AffineFormatBarWidget) {
- setupFormatBarAIEntry(view.component);
- }
+ ];
+}
- if (view.component instanceof AffineSlashMenuWidget) {
- setupSlashMenuAIEntry(view.component);
- }
- });
+function getAIEdgelessRootWatcher(framework: FrameworkProvider) {
+ class AIEdgelessRootWatcher extends BlockServiceWatcher {
+ static override readonly flavour = 'affine:page';
+
+ override mounted() {
+ super.mounted();
+ this.blockService.specSlots.widgetConnected.on(view => {
+ if (view.component instanceof AffineAIPanelWidget) {
+ view.component.style.width = '430px';
+ view.component.config = buildAIPanelConfig(view.component, framework);
+ setupSpaceAIEntry(view.component);
+ }
+
+ if (view.component instanceof EdgelessCopilotWidget) {
+ setupEdgelessCopilot(view.component);
+ }
+
+ if (view.component instanceof EdgelessElementToolbarWidget) {
+ setupEdgelessElementToolbarAIEntry(view.component);
+ }
+
+ if (view.component instanceof AffineFormatBarWidget) {
+ setupFormatBarAIEntry(view.component);
+ }
+
+ if (view.component instanceof AffineSlashMenuWidget) {
+ setupSlashMenuAIEntry(view.component);
+ }
+ });
+ }
}
+ return AIEdgelessRootWatcher;
}
-export const AIEdgelessRootBlockSpec: ExtensionType[] = [
- ...EdgelessRootBlockSpec,
- AIEdgelessRootWatcher,
- {
- setup: di => {
- di.override(WidgetViewMapIdentifier('affine:page'), () => {
- return {
- ...edgelessRootWidgetViewMap,
- [AFFINE_EDGELESS_COPILOT_WIDGET]: literal`${unsafeStatic(
- AFFINE_EDGELESS_COPILOT_WIDGET
- )}`,
- [AFFINE_AI_PANEL_WIDGET]: literal`${unsafeStatic(
- AFFINE_AI_PANEL_WIDGET
- )}`,
- };
- });
+export function createAIEdgelessRootBlockSpec(
+ framework: FrameworkProvider
+): ExtensionType[] {
+ return [
+ ...EdgelessRootBlockSpec,
+ getAIEdgelessRootWatcher(framework),
+ {
+ setup: di => {
+ di.override(WidgetViewMapIdentifier('affine:page'), () => {
+ return {
+ ...edgelessRootWidgetViewMap,
+ [AFFINE_EDGELESS_COPILOT_WIDGET]: literal`${unsafeStatic(
+ AFFINE_EDGELESS_COPILOT_WIDGET
+ )}`,
+ [AFFINE_AI_PANEL_WIDGET]: literal`${unsafeStatic(
+ AFFINE_AI_PANEL_WIDGET
+ )}`,
+ };
+ });
+ },
},
- },
-];
+ ];
+}
class AIParagraphBlockWatcher extends BlockServiceWatcher {
static override readonly flavour = 'affine:paragraph';
diff --git a/packages/frontend/core/src/blocksuite/presets/ai/chat-panel/chat-panel-input.ts b/packages/frontend/core/src/blocksuite/presets/ai/chat-panel/chat-panel-input.ts
index 7a80e414bb762..e4d6ba2da1989 100644
--- a/packages/frontend/core/src/blocksuite/presets/ai/chat-panel/chat-panel-input.ts
+++ b/packages/frontend/core/src/blocksuite/presets/ai/chat-panel/chat-panel-input.ts
@@ -1,6 +1,14 @@
+import { stopPropagation } from '@affine/core/utils';
import type { EditorHost } from '@blocksuite/affine/block-std';
import { type AIError, openFileOrFiles } from '@blocksuite/affine/blocks';
-import { assertExists, WithDisposable } from '@blocksuite/affine/global/utils';
+import {
+ assertExists,
+ SignalWatcher,
+ WithDisposable,
+} from '@blocksuite/affine/global/utils';
+import { unsafeCSSVarV2 } from '@blocksuite/affine-shared/theme';
+import { ImageIcon, PublishIcon } from '@blocksuite/icons/lit';
+import type { Signal } from '@preact/signals-core';
import { css, html, LitElement, nothing } from 'lit';
import { property, query, state } from 'lit/decorators.js';
import { repeat } from 'lit/directives/repeat.js';
@@ -10,7 +18,6 @@ import {
ChatClearIcon,
ChatSendIcon,
CloseIcon,
- ImageIcon,
} from '../_common/icons';
import { AIProvider } from '../provider';
import { reportResponse } from '../utils/action-reporter';
@@ -24,7 +31,13 @@ function getFirstTwoLines(text: string) {
return lines.slice(0, 2);
}
-export class ChatPanelInput extends WithDisposable(LitElement) {
+export interface AINetworkSearchConfig {
+ visible: Signal;
+ enabled: Signal;
+ setEnabled: (state: boolean) => void;
+}
+
+export class ChatPanelInput extends SignalWatcher(WithDisposable(LitElement)) {
static override styles = css`
.chat-panel-input {
display: flex;
@@ -104,10 +117,28 @@ export class ChatPanelInput extends WithDisposable(LitElement) {
margin-left: auto;
}
- .image-upload {
+ .image-upload,
+ .chat-network-search {
display: flex;
justify-content: center;
align-items: center;
+ svg {
+ width: 20px;
+ height: 20px;
+ color: ${unsafeCSSVarV2('icon/primary')};
+ }
+ }
+ .chat-network-search[data-active='true'] svg {
+ color: ${unsafeCSSVarV2('icon/activated')};
+ }
+
+ .image-upload[aria-disabled='true'],
+ .chat-network-search[aria-disabled='true'] {
+ cursor: not-allowed;
+ }
+ .image-upload[aria-disabled='true'] svg,
+ .chat-network-search[aria-disabled='true'] svg {
+ color: var(--affine-text-disable-color) !important;
}
}
@@ -235,6 +266,9 @@ export class ChatPanelInput extends WithDisposable(LitElement) {
@property({ attribute: false })
accessor cleanupHistories!: () => Promise;
+ @property({ attribute: false })
+ accessor networkSearchConfig!: AINetworkSearchConfig;
+
private _addImages(images: File[]) {
const oldImages = this.chatContextValue.images;
this.updateContext({
@@ -296,6 +330,23 @@ export class ChatPanelInput extends WithDisposable(LitElement) {
`;
}
+ private readonly _toggleNetworkSearch = (e: MouseEvent) => {
+ e.preventDefault();
+ e.stopPropagation();
+
+ const enable = this.networkSearchConfig.enabled.value;
+ this.networkSearchConfig.setEnabled(!enable);
+ };
+
+ private readonly _uploadImageFiles = async (_e: MouseEvent) => {
+ const images = await openFileOrFiles({
+ acceptType: 'Images',
+ multiple: true,
+ });
+ if (!images) return;
+ this._addImages(images);
+ };
+
override connectedCallback() {
super.connectedCallback();
@@ -305,7 +356,7 @@ export class ChatPanelInput extends WithDisposable(LitElement) {
if (this.host === host) {
context && this.updateContext(context);
await this.updateComplete;
- await this.send(input);
+ input && (await this.send(input));
}
}
)
@@ -316,7 +367,9 @@ export class ChatPanelInput extends WithDisposable(LitElement) {
const { images, status } = this.chatContextValue;
const hasImages = images.length > 0;
const maxHeight = hasImages ? 272 + 2 : 200 + 2;
-
+ const networkDisabled = !!this.chatContextValue.images.length;
+ const networkActive = !!this.networkSearchConfig.enabled.value;
+ const uploadDisabled = networkActive && !networkDisabled;
return html`