diff --git a/.github/actions/copilot-test/action.yml b/.github/actions/copilot-test/action.yml index a9c95250621f4..cb75daafec050 100644 --- a/.github/actions/copilot-test/action.yml +++ b/.github/actions/copilot-test/action.yml @@ -26,6 +26,7 @@ runs: DEV_SERVER_URL: http://localhost:8080 COPILOT_OPENAI_API_KEY: ${{ inputs.openai-key }} COPILOT_FAL_API_KEY: ${{ inputs.fal-key }} + COPILOT_PERPLEXITY_API_KEY: ${{ inputs.perplexity-key }} - name: Upload test results if: ${{ failure() }} diff --git a/.github/actions/deploy/deploy.mjs b/.github/actions/deploy/deploy.mjs index 3059feabbf0fb..e5a35a18bf070 100644 --- a/.github/actions/deploy/deploy.mjs +++ b/.github/actions/deploy/deploy.mjs @@ -17,6 +17,7 @@ const { METRICS_CUSTOMER_IO_TOKEN, COPILOT_OPENAI_API_KEY, COPILOT_FAL_API_KEY, + COPILOT_PERPLEXITY_API_KEY, COPILOT_UNSPLASH_API_KEY, MAILER_SENDER, MAILER_USER, @@ -147,6 +148,7 @@ const createHelmCommand = ({ isDryRun }) => { `--set graphql.app.copilot.enabled=true`, `--set-string graphql.app.copilot.openai.key="${COPILOT_OPENAI_API_KEY}"`, `--set-string graphql.app.copilot.fal.key="${COPILOT_FAL_API_KEY}"`, + `--set-string graphql.app.copilot.perplexity.key="${COPILOT_PERPLEXITY_API_KEY}"`, `--set-string graphql.app.copilot.unsplash.key="${COPILOT_UNSPLASH_API_KEY}"`, `--set-string graphql.app.mailer.sender="${MAILER_SENDER}"`, `--set-string graphql.app.mailer.user="${MAILER_USER}"`, diff --git a/.github/helm/affine/charts/graphql/templates/deployment.yaml b/.github/helm/affine/charts/graphql/templates/deployment.yaml index 36a04270c53c8..aaa79e0dcfe52 100644 --- a/.github/helm/affine/charts/graphql/templates/deployment.yaml +++ b/.github/helm/affine/charts/graphql/templates/deployment.yaml @@ -157,6 +157,11 @@ spec: secretKeyRef: name: "{{ .Values.app.copilot.secretName }}" key: falSecret + - name: COPILOT_PERPLEXITY_API_KEY + valueFrom: + secretKeyRef: + name: "{{ .Values.app.copilot.secretName }}" + key: perplexitySecret - name: COPILOT_UNSPLASH_API_KEY valueFrom: secretKeyRef: diff --git a/.github/workflows/build-test.yml b/.github/workflows/build-test.yml index 0dae039b8298e..4ee194c913007 100644 --- a/.github/workflows/build-test.yml +++ b/.github/workflows/build-test.yml @@ -531,6 +531,7 @@ jobs: CARGO_TARGET_DIR: '${{ github.workspace }}/target' COPILOT_OPENAI_API_KEY: ${{ secrets.COPILOT_OPENAI_API_KEY }} COPILOT_FAL_API_KEY: ${{ secrets.COPILOT_FAL_API_KEY }} + COPILOT_PERPLEXITY_API_KEY: ${{ secrets.COPILOT_PERPLEXITY_API_KEY }} - name: Upload server test coverage results if: ${{ steps.check-blocksuite-update.outputs.skip != 'true' || steps.apifilter.outputs.changed == 'true' }} @@ -619,6 +620,7 @@ jobs: script: yarn affine @affine-test/affine-cloud-copilot e2e --forbid-only --shard=${{ matrix.shardIndex }}/${{ matrix.shardTotal }} openai-key: ${{ secrets.COPILOT_OPENAI_API_KEY }} fal-key: ${{ secrets.COPILOT_FAL_API_KEY }} + perplexity-key: ${{ secrets.COPILOT_PERPLEXITY_API_KEY }} server-e2e-test: name: ${{ matrix.tests.name }} @@ -703,6 +705,7 @@ jobs: DEV_SERVER_URL: http://localhost:8080 COPILOT_OPENAI_API_KEY: 1 COPILOT_FAL_API_KEY: 1 + COPILOT_PERPLEXITY_API_KEY: 1 - name: Upload test results if: ${{ failure() }} diff --git a/.github/workflows/copilot-test.yml b/.github/workflows/copilot-test.yml index e91942173be87..5ce2300f331d6 100644 --- a/.github/workflows/copilot-test.yml +++ b/.github/workflows/copilot-test.yml @@ -84,6 +84,7 @@ jobs: CARGO_TARGET_DIR: '${{ github.workspace }}/target' COPILOT_OPENAI_API_KEY: ${{ secrets.COPILOT_OPENAI_API_KEY }} COPILOT_FAL_API_KEY: ${{ secrets.COPILOT_FAL_API_KEY }} + COPILOT_PERPLEXITY_API_KEY: ${{ secrets.COPILOT_PERPLEXITY_API_KEY }} - name: Upload server test coverage results uses: codecov/codecov-action@v5 @@ -147,6 +148,7 @@ jobs: script: yarn affine @affine-test/affine-cloud-copilot e2e --forbid-only --shard=${{ matrix.shardIndex }}/${{ matrix.shardTotal }} openai-key: ${{ secrets.COPILOT_OPENAI_API_KEY }} fal-key: ${{ secrets.COPILOT_FAL_API_KEY }} + perplexity-key: ${{ secrets.COPILOT_PERPLEXITY_API_KEY }} test-done: needs: diff --git a/.github/workflows/deploy.yml b/.github/workflows/deploy.yml index 7198f08179878..95b3067851d9c 100644 --- a/.github/workflows/deploy.yml +++ b/.github/workflows/deploy.yml @@ -98,6 +98,7 @@ jobs: CAPTCHA_TURNSTILE_SECRET: ${{ secrets.CAPTCHA_TURNSTILE_SECRET }} COPILOT_OPENAI_API_KEY: ${{ secrets.COPILOT_OPENAI_API_KEY }} COPILOT_FAL_API_KEY: ${{ secrets.COPILOT_FAL_API_KEY }} + COPILOT_PERPLEXITY_API_KEY: ${{ secrets.COPILOT_PERPLEXITY_API_KEY }} COPILOT_UNSPLASH_API_KEY: ${{ secrets.COPILOT_UNSPLASH_API_KEY }} METRICS_CUSTOMER_IO_TOKEN: ${{ secrets.METRICS_CUSTOMER_IO_TOKEN }} MAILER_SENDER: ${{ secrets.OAUTH_EMAIL_SENDER }} diff --git a/blocksuite/blocks/src/root-block/widgets/ai-panel/ai-panel.ts b/blocksuite/blocks/src/root-block/widgets/ai-panel/ai-panel.ts index d6ff0d1033db3..028fa6be31831 100644 --- a/blocksuite/blocks/src/root-block/widgets/ai-panel/ai-panel.ts +++ b/blocksuite/blocks/src/root-block/widgets/ai-panel/ai-panel.ts @@ -63,7 +63,7 @@ export class AffineAIPanelWidget extends WidgetComponent { box-sizing: border-box; width: 100%; height: fit-content; - padding: 8px 0; + padding: 10px 0; } .ai-panel-container:not(:has(ai-panel-generating)) { @@ -474,6 +474,7 @@ export class AffineAIPanelWidget extends WidgetComponent { .onBlur=${this.discard} .onFinish=${this._inputFinish} .onInput=${this.onInput} + .networkSearchConfig=${config.networkSearchConfig} >`, ], [ diff --git a/blocksuite/blocks/src/root-block/widgets/ai-panel/components/state/input.ts b/blocksuite/blocks/src/root-block/widgets/ai-panel/components/state/input.ts index e9c6dbcba6bb5..b72526a398b88 100644 --- a/blocksuite/blocks/src/root-block/widgets/ai-panel/components/state/input.ts +++ b/blocksuite/blocks/src/root-block/widgets/ai-panel/components/state/input.ts @@ -1,11 +1,14 @@ import { AIStarIcon } from '@blocksuite/affine-components/icons'; +import { unsafeCSSVarV2 } from '@blocksuite/affine-shared/theme'; import { stopPropagation } from '@blocksuite/affine-shared/utils'; -import { WithDisposable } from '@blocksuite/global/utils'; -import { SendIcon } from '@blocksuite/icons/lit'; +import { SignalWatcher, WithDisposable } from '@blocksuite/global/utils'; +import { PublishIcon, SendIcon } from '@blocksuite/icons/lit'; import { css, html, LitElement, nothing } from 'lit'; import { property, query, state } from 'lit/decorators.js'; -export class AIPanelInput extends WithDisposable(LitElement) { +import type { AINetworkSearchConfig } from '../../type'; + +export class AIPanelInput extends SignalWatcher(WithDisposable(LitElement)) { static override styles = css` :host { width: 100%; @@ -20,8 +23,9 @@ export class AIPanelInput extends WithDisposable(LitElement) { background: var(--affine-background-overlay-panel-color); } - .icon { + .star { display: flex; + padding: 2px; align-items: center; } @@ -66,22 +70,36 @@ export class AIPanelInput extends WithDisposable(LitElement) { display: flex; align-items: center; padding: 2px; - gap: 10px; + gap: 4px; border-radius: 4px; - background: var(--affine-black-10, rgba(0, 0, 0, 0.1)); - + background: ${unsafeCSSVarV2('icon/disable')}; svg { - width: 16px; - height: 16px; - color: var(--affine-pure-white, #fff); + width: 20px; + height: 20px; + color: ${unsafeCSSVarV2('button/pureWhiteText')}; } } .arrow[data-active] { - background: var(--affine-brand-color, #1e96eb); + background: ${unsafeCSSVarV2('icon/activated')}; } .arrow[data-active]:hover { cursor: pointer; } + .network { + display: flex; + align-items: center; + padding: 2px; + gap: 4px; + cursor: pointer; + svg { + width: 20px; + height: 20px; + color: ${unsafeCSSVarV2('icon/primary')}; + } + } + .network[data-active='true'] svg { + color: ${unsafeCSSVarV2('icon/activated')}; + } `; private readonly _onInput = () => { @@ -101,12 +119,14 @@ export class AIPanelInput extends WithDisposable(LitElement) { private readonly _onKeyDown = (e: KeyboardEvent) => { if (e.key === 'Enter' && !e.shiftKey && !e.isComposing) { - e.preventDefault(); - this._sendToAI(); + this._sendToAI(e); } }; - private readonly _sendToAI = () => { + private readonly _sendToAI = (e: MouseEvent | KeyboardEvent) => { + e.preventDefault(); + e.stopPropagation(); + const value = this.textarea.value.trim(); if (value.length === 0) return; @@ -114,9 +134,17 @@ export class AIPanelInput extends WithDisposable(LitElement) { this.remove(); }; + private readonly _toggleNetworkSearch = (e: MouseEvent) => { + e.preventDefault(); + e.stopPropagation(); + + const enable = this.networkSearchConfig.enabled.value; + this.networkSearchConfig.setEnabled(!enable); + }; + override render() { return html`
-
${AIStarIcon}
+
${AIStarIcon}
+ ${this.networkSearchConfig.visible.value + ? html` +
+ ${PublishIcon()} + Toggle Network Search +
+ ` + : nothing}
void) | undefined = undefined; diff --git a/blocksuite/blocks/src/root-block/widgets/ai-panel/type.ts b/blocksuite/blocks/src/root-block/widgets/ai-panel/type.ts index a9961c27ba67b..7be30f3928702 100644 --- a/blocksuite/blocks/src/root-block/widgets/ai-panel/type.ts +++ b/blocksuite/blocks/src/root-block/widgets/ai-panel/type.ts @@ -2,6 +2,7 @@ import type { AIError, AIItemGroupConfig, } from '@blocksuite/affine-components/ai-item'; +import type { Signal } from '@preact/signals-core'; import type { nothing, TemplateResult } from 'lit'; export interface CopyConfig { @@ -28,6 +29,12 @@ export interface AIPanelGeneratingConfig { stages?: string[]; } +export interface AINetworkSearchConfig { + visible: Signal; + enabled: Signal; + setEnabled: (state: boolean) => void; +} + export interface AffineAIPanelWidgetConfig { answerRenderer: ( answer: string, @@ -44,10 +51,10 @@ export interface AffineAIPanelWidgetConfig { finishStateConfig: AIPanelAnswerConfig; generatingStateConfig: AIPanelGeneratingConfig; errorStateConfig: AIPanelErrorConfig; + networkSearchConfig: AINetworkSearchConfig; hideCallback?: () => void; discardCallback?: () => void; inputCallback?: (input: string) => void; - copy?: CopyConfig; } diff --git a/packages/backend/server/.env.example b/packages/backend/server/.env.example index 2a56e69bf6960..9c740f049243a 100644 --- a/packages/backend/server/.env.example +++ b/packages/backend/server/.env.example @@ -2,6 +2,7 @@ # REDIS_SERVER_HOST=localhost # COPILOT_FAL_API_KEY=YOUR_KEY # COPILOT_OPENAI_API_KEY=YOUR_KEY +# COPILOT_PERPLEXITY_API_KEY=YOUR_KEY # MAILER_HOST=127.0.0.1 # MAILER_PORT=1025 diff --git a/packages/backend/server/package.json b/packages/backend/server/package.json index 196277b3f1c1a..9533ebc4ecf03 100644 --- a/packages/backend/server/package.json +++ b/packages/backend/server/package.json @@ -58,6 +58,7 @@ "@socket.io/redis-adapter": "^8.3.0", "cookie-parser": "^1.4.7", "dotenv": "^16.4.7", + "eventsource-parser": "^3.0.0", "express": "^4.21.2", "fast-xml-parser": "^4.5.0", "get-stream": "^9.0.1", diff --git a/packages/backend/server/src/config/affine.env.ts b/packages/backend/server/src/config/affine.env.ts index 05ff3dc60bd55..c72f5e617f57b 100644 --- a/packages/backend/server/src/config/affine.env.ts +++ b/packages/backend/server/src/config/affine.env.ts @@ -28,6 +28,7 @@ AFFiNE.ENV_MAP = { CAPTCHA_TURNSTILE_SECRET: ['plugins.captcha.turnstile.secret', 'string'], COPILOT_OPENAI_API_KEY: 'plugins.copilot.openai.apiKey', COPILOT_FAL_API_KEY: 'plugins.copilot.fal.apiKey', + COPILOT_PERPLEXITY_API_KEY: 'plugins.copilot.perplexity.apiKey', COPILOT_UNSPLASH_API_KEY: 'plugins.copilot.unsplashKey', REDIS_SERVER_HOST: 'redis.host', REDIS_SERVER_PORT: ['redis.port', 'int'], diff --git a/packages/backend/server/src/plugins/copilot/config.ts b/packages/backend/server/src/plugins/copilot/config.ts index 516334d7ce64b..b7c54ca37ed8d 100644 --- a/packages/backend/server/src/plugins/copilot/config.ts +++ b/packages/backend/server/src/plugins/copilot/config.ts @@ -3,10 +3,12 @@ import type { ClientOptions as OpenAIClientOptions } from 'openai'; import { defineStartupConfig, ModuleConfig } from '../../base/config'; import { StorageConfig } from '../../base/storage/config'; import type { FalConfig } from './providers/fal'; +import { PerplexityConfig } from './providers/perplexity'; export interface CopilotStartupConfigurations { openai?: OpenAIClientOptions; fal?: FalConfig; + perplexity?: PerplexityConfig; test?: never; unsplashKey?: string; storage: StorageConfig; diff --git a/packages/backend/server/src/plugins/copilot/index.ts b/packages/backend/server/src/plugins/copilot/index.ts index 01bbeb6b46e43..2940044c4ad2d 100644 --- a/packages/backend/server/src/plugins/copilot/index.ts +++ b/packages/backend/server/src/plugins/copilot/index.ts @@ -13,6 +13,7 @@ import { CopilotProviderService, FalProvider, OpenAIProvider, + PerplexityProvider, registerCopilotProvider, } from './providers'; import { @@ -26,6 +27,7 @@ import { CopilotWorkflowExecutors, CopilotWorkflowService } from './workflow'; registerCopilotProvider(FalProvider); registerCopilotProvider(OpenAIProvider); +registerCopilotProvider(PerplexityProvider); @Plugin({ name: 'copilot', diff --git a/packages/backend/server/src/plugins/copilot/prompt/prompts.ts b/packages/backend/server/src/plugins/copilot/prompt/prompts.ts index f2e613f160205..ebc1effbe03bc 100644 --- a/packages/backend/server/src/plugins/copilot/prompt/prompts.ts +++ b/packages/backend/server/src/plugins/copilot/prompt/prompts.ts @@ -952,6 +952,11 @@ const chat: Prompt[] = [ }, ], }, + { + name: 'Search With AFFiNE AI', + model: 'llama-3.1-sonar-small-128k-online', + messages: [], + }, // use for believer plan { name: 'Chat With AFFiNE AI - Believer', diff --git a/packages/backend/server/src/plugins/copilot/providers/index.ts b/packages/backend/server/src/plugins/copilot/providers/index.ts index a72cb307f57e0..1e89774d0c9ef 100644 --- a/packages/backend/server/src/plugins/copilot/providers/index.ts +++ b/packages/backend/server/src/plugins/copilot/providers/index.ts @@ -124,9 +124,7 @@ export class CopilotProviderService { if (!this.cachedProviders.has(provider)) { this.cachedProviders.set(provider, this.create(provider)); } - - // eslint-disable-next-line @typescript-eslint/no-non-null-assertion - return this.cachedProviders.get(provider)!; + return this.cachedProviders.get(provider) as CopilotProvider; } async getProviderByCapability( @@ -196,3 +194,4 @@ export class CopilotProviderService { export { FalProvider } from './fal'; export { OpenAIProvider } from './openai'; +export { PerplexityProvider } from './perplexity'; diff --git a/packages/backend/server/src/plugins/copilot/providers/perplexity.ts b/packages/backend/server/src/plugins/copilot/providers/perplexity.ts new file mode 100644 index 0000000000000..f107a7501859d --- /dev/null +++ b/packages/backend/server/src/plugins/copilot/providers/perplexity.ts @@ -0,0 +1,209 @@ +import assert from 'node:assert'; + +import { EventSourceParserStream } from 'eventsource-parser/stream'; + +import { + CopilotPromptInvalid, + CopilotProviderSideError, + metrics, +} from '../../../base'; +import { + CopilotCapability, + CopilotChatOptions, + CopilotProviderType, + CopilotTextToTextProvider, + PromptMessage, +} from '../types'; + +export type PerplexityConfig = { + apiKey: string; +}; + +interface Message { + role: 'assistant'; + content: string; +} + +interface Choice { + message: Message; + delta: Message; + finish_reason: 'stop' | null; +} + +interface PerplexityData { + citations: string[]; + choices: Choice[]; +} + +export function injectCitations(content: string, citations: string[]) { + // Match [[n]] and [n] patterns + // Not match if they're already part of a formatted citation + const regex = /(?:\[\[(\d+)\]\]|\[(\d+)\])(?!\]?\([^)]*\))/g; + return content.replace(regex, (_, g1, g2) => { + const index = parseInt(g1 || g2) - 1; + if (index >= 0 && index < citations.length) { + return `[[${g1 || g2}](${citations[index]})]`; + } + return _; + }); +} + +export class PerplexityProvider implements CopilotTextToTextProvider { + static readonly type = CopilotProviderType.Perplexity; + + static readonly capabilities = [CopilotCapability.TextToText]; + + static assetsConfig(config: PerplexityConfig) { + return !!config.apiKey; + } + + constructor(private readonly config: PerplexityConfig) { + assert(PerplexityProvider.assetsConfig(config)); + } + + readonly availableModels = [ + 'llama-3.1-sonar-small-128k-online', + 'llama-3.1-sonar-large-128k-online', + 'llama-3.1-sonar-huge-128k-online', + ]; + + get type(): CopilotProviderType { + return PerplexityProvider.type; + } + + getCapabilities(): CopilotCapability[] { + return PerplexityProvider.capabilities; + } + + async isModelAvailable(model: string): Promise { + return this.availableModels.includes(model); + } + + async generateText( + messages: PromptMessage[], + model: string = 'llama-3.1-sonar-small-128k-online', + options: CopilotChatOptions = {} + ): Promise { + await this.checkParams({ messages, model, options }); + try { + metrics.ai.counter('chat_text_calls').add(1, { model }); + const sMessages = messages + .map(({ content, role }) => ({ content, role })) + .filter(({ content }) => typeof content === 'string'); + + const params = { + method: 'POST', + headers: { + Authorization: `Bearer ${this.config.apiKey}`, + 'Content-Type': 'application/json', + }, + body: JSON.stringify({ + model, + messages: sMessages, + max_tokens: options.maxTokens || 4096, + }), + }; + const response = await fetch( + 'https://api.perplexity.ai/chat/completions', + params + ); + const json: PerplexityData = await response.json(); + return injectCitations(json.choices[0].message.content, json.citations); + } catch (e: any) { + metrics.ai.counter('chat_text_errors').add(1, { model }); + throw this.handleError(e); + } + } + + async *generateTextStream( + messages: PromptMessage[], + model: string = 'llama-3.1-sonar-small-128k-online', + options: CopilotChatOptions = {} + ): AsyncIterable { + await this.checkParams({ messages, model, options }); + try { + metrics.ai.counter('chat_text_stream_calls').add(1, { model }); + const sMessages = messages + .map(({ content, role }) => ({ content, role })) + .filter(({ content }) => typeof content === 'string'); + + const params = { + method: 'POST', + headers: { + Authorization: `Bearer ${this.config.apiKey}`, + 'Content-Type': 'application/json', + }, + body: JSON.stringify({ + model, + messages: sMessages, + max_tokens: options.maxTokens || 4096, + stream: true, + }), + }; + const response = await fetch( + 'https://api.perplexity.ai/chat/completions', + params + ); + if (response.body) { + const eventStream = response.body + .pipeThrough(new TextDecoderStream()) + .pipeThrough(new EventSourceParserStream()) + .pipeThrough( + new TransformStream({ + transform(chunk, controller) { + if (options.signal?.aborted) { + controller.enqueue(null); + return; + } + const data = JSON.parse(chunk.data) as PerplexityData | null; + if (data?.choices?.[0]?.delta?.content) { + const content = injectCitations( + data.choices[0].delta.content, + data.citations || [] + ); + controller.enqueue(content); + } + }, + flush(controller) { + controller.enqueue(null); + }, + }) + ); + + const reader = eventStream.getReader(); + while (true) { + const { done, value } = await reader.read(); + if (done) break; + yield value; + } + } else { + const result = await this.generateText(messages, model, options); + yield result; + } + } catch (e) { + metrics.ai.counter('chat_text_stream_errors').add(1, { model }); + throw e; + } + } + + protected async checkParams({ + model, + }: { + messages?: PromptMessage[]; + embeddings?: string[]; + model: string; + options: CopilotChatOptions; + }) { + if (!(await this.isModelAvailable(model))) { + throw new CopilotPromptInvalid(`Invalid model: ${model}`); + } + } + + private handleError(e: any) { + return new CopilotProviderSideError({ + provider: this.type, + kind: 'unexpected_response', + message: e?.message || 'Unexpected perplexity response', + }); + } +} diff --git a/packages/backend/server/src/plugins/copilot/resolver.ts b/packages/backend/server/src/plugins/copilot/resolver.ts index 9d0429b678656..a4dad5040b2bf 100644 --- a/packages/backend/server/src/plugins/copilot/resolver.ts +++ b/packages/backend/server/src/plugins/copilot/resolver.ts @@ -22,6 +22,7 @@ import GraphQLUpload from 'graphql-upload/GraphQLUpload.mjs'; import { CallMetric, CopilotFailedToCreateMessage, + CopilotSessionNotFound, FileUpload, RequestMutex, Throttle, @@ -62,6 +63,17 @@ class CreateChatSessionInput { promptName!: string; } +@InputType() +class UpdateChatSessionInput { + @Field(() => String) + sessionId!: string; + + @Field(() => String, { + description: 'The prompt name to use for the session', + }) + promptName!: string; +} + @InputType() class ForkChatSessionInput { @Field(() => String) @@ -372,6 +384,41 @@ export class CopilotResolver { }); } + @Mutation(() => String, { + description: 'Update a chat session', + }) + @CallMetric('ai', 'chat_session_update') + async updateCopilotSession( + @CurrentUser() user: CurrentUser, + @Args({ name: 'options', type: () => UpdateChatSessionInput }) + options: UpdateChatSessionInput + ) { + const session = await this.chatSession.get(options.sessionId); + if (!session) { + throw new CopilotSessionNotFound(); + } + const { workspaceId, docId, parentSessionId } = session.config; + await this.permissions.checkCloudPagePermission( + workspaceId, + docId, + user.id + ); + const lockFlag = `${COPILOT_LOCKER}:session:${user.id}:${workspaceId}`; + await using lock = await this.mutex.acquire(lockFlag); + if (!lock) { + return new TooManyRequest('Server is busy'); + } + + await this.chatSession.checkQuota(user.id); + return await this.chatSession.update({ + ...options, + parentSessionId, + workspaceId, + docId, + userId: user.id, + }); + } + @Mutation(() => String, { description: 'Create a chat session', }) diff --git a/packages/backend/server/src/plugins/copilot/session.ts b/packages/backend/server/src/plugins/copilot/session.ts index 3e6414828f7f9..c6175bdb43fbd 100644 --- a/packages/backend/server/src/plugins/copilot/session.ts +++ b/packages/backend/server/src/plugins/copilot/session.ts @@ -23,6 +23,7 @@ import { ChatSessionForkOptions, ChatSessionOptions, ChatSessionState, + ChatSessionUpdateOptions, getTokenEncoder, ListHistoriesOptions, PromptMessage, @@ -46,13 +47,22 @@ export class ChatSession implements AsyncDisposable { get config() { const { sessionId, + parentSessionId, userId, workspaceId, docId, prompt: { name: promptName, config: promptConfig }, } = this.state; - return { sessionId, userId, workspaceId, docId, promptName, promptConfig }; + return { + sessionId, + parentSessionId, + userId, + workspaceId, + docId, + promptName, + promptConfig, + }; } get stashMessages() { @@ -198,6 +208,19 @@ export class ChatSessionService { private readonly prompt: PromptService ) {} + private async haveSession(sessionId: string, userId: string) { + return await this.db.$transaction(async tx => { + return await tx.aiSession + .count({ + where: { + id: sessionId, + userId, + }, + }) + .then(c => c > 0); + }); + } + private async setSession(state: ChatSessionState): Promise { return await this.db.$transaction(async tx => { let sessionId = state.sessionId; @@ -226,15 +249,7 @@ export class ChatSessionService { if (id) sessionId = id; } - const haveSession = await tx.aiSession - .count({ - where: { - id: sessionId, - userId: state.userId, - }, - }) - .then(c => c > 0); - + const haveSession = await this.haveSession(sessionId, state.userId); if (haveSession) { // message will only exists when setSession call by session.save if (state.messages.length) { @@ -280,6 +295,22 @@ export class ChatSessionService { }); } + private async updateSession(state: ChatSessionState): Promise { + return await this.db.$transaction(async tx => { + let sessionId = state.sessionId; + const haveSession = await this.haveSession(sessionId, state.userId); + if (haveSession) { + await tx.aiSession.update({ + where: { id: sessionId }, + data: { + promptName: state.prompt.name, + }, + }); + } + return sessionId; + }); + } + private async getSession( sessionId: string ): Promise { @@ -570,6 +601,19 @@ export class ChatSessionService { }); } + async update(options: ChatSessionUpdateOptions): Promise { + const prompt = await this.prompt.get(options.promptName); + if (!prompt) { + this.logger.error(`Prompt not found: ${options.promptName}`); + throw new CopilotPromptNotFound({ name: options.promptName }); + } + return await this.updateSession({ + ...options, + prompt, + messages: [], + }); + } + async fork(options: ChatSessionForkOptions): Promise { const state = await this.getSession(options.sessionId); if (!state) { diff --git a/packages/backend/server/src/plugins/copilot/types.ts b/packages/backend/server/src/plugins/copilot/types.ts index 679377f16a3a2..25aa8a77af601 100644 --- a/packages/backend/server/src/plugins/copilot/types.ts +++ b/packages/backend/server/src/plugins/copilot/types.ts @@ -123,6 +123,15 @@ export interface ChatSessionOptions { promptName: string; } +export interface ChatSessionUpdateOptions { + sessionId: string; + parentSessionId: string | null; + userId: string; + workspaceId: string; + docId: string; + promptName: string; +} + export interface ChatSessionForkOptions extends Omit { sessionId: string; @@ -154,6 +163,7 @@ export type ListHistoriesOptions = { export enum CopilotProviderType { FAL = 'fal', OpenAI = 'openai', + Perplexity = 'perplexity', // only for test Test = 'test', } diff --git a/packages/backend/server/src/schema.gql b/packages/backend/server/src/schema.gql index 3ec9dee256c99..a5c168609aa98 100644 --- a/packages/backend/server/src/schema.gql +++ b/packages/backend/server/src/schema.gql @@ -551,6 +551,9 @@ type Mutation { """Update a copilot prompt""" updateCopilotPrompt(messages: [CopilotPromptMessageInput!]!, name: String!): CopilotPromptType! + + """Update a chat session""" + updateCopilotSession(options: UpdateChatSessionInput!): String! updateProfile(input: UpdateUserInput!): UserType! """update server runtime configurable setting""" @@ -865,6 +868,12 @@ type UnsupportedSubscriptionPlanDataType { plan: String! } +input UpdateChatSessionInput { + """The prompt name to use for the session""" + promptName: String! + sessionId: String! +} + input UpdateUserInput { """User name""" name: String diff --git a/packages/backend/server/tests/copilot-provider.spec.ts b/packages/backend/server/tests/copilot-provider.spec.ts index c22493f987da1..e17d8b4b86fff 100644 --- a/packages/backend/server/tests/copilot-provider.spec.ts +++ b/packages/backend/server/tests/copilot-provider.spec.ts @@ -13,6 +13,7 @@ import { CopilotProviderService, FalProvider, OpenAIProvider, + PerplexityProvider, registerCopilotProvider, unregisterCopilotProvider, } from '../src/plugins/copilot/providers'; @@ -47,8 +48,10 @@ const test = ava as TestFn; const isCopilotConfigured = !!process.env.COPILOT_OPENAI_API_KEY && !!process.env.COPILOT_FAL_API_KEY && + !!process.env.COPILOT_PERPLEXITY_API_KEY && process.env.COPILOT_OPENAI_API_KEY !== '1' && - process.env.COPILOT_FAL_API_KEY !== '1'; + process.env.COPILOT_FAL_API_KEY !== '1' && + process.env.COPILOT_PERPLEXITY_API_KEY !== '1'; const runIfCopilotConfigured = test.macro( async ( t, @@ -75,6 +78,9 @@ test.serial.before(async t => { fal: { apiKey: process.env.COPILOT_FAL_API_KEY, }, + perplexity: { + apiKey: process.env.COPILOT_PERPLEXITY_API_KEY, + }, }, }, }), @@ -111,6 +117,7 @@ test.serial.before(async t => { registerCopilotProvider(OpenAIProvider); registerCopilotProvider(FalProvider); + registerCopilotProvider(PerplexityProvider); for (const name of await prompt.listNames()) { await prompt.delete(name); @@ -124,6 +131,7 @@ test.serial.before(async t => { test.after(async _ => { unregisterCopilotProvider(OpenAIProvider.type); unregisterCopilotProvider(FalProvider.type); + unregisterCopilotProvider(PerplexityProvider.type); }); test.after(async t => { @@ -152,7 +160,6 @@ const checkMDList = (text: string) => { return false; } - // eslint-disable-next-line @typescript-eslint/no-non-null-asserted-optional-chain const currentIndent = line.match(/^( *)/)?.[0].length!; if (Number.isNaN(currentIndent) || currentIndent % 2 !== 0) { return false; @@ -282,6 +289,8 @@ const actions = [ 'Make it longer', 'Make it shorter', 'Continue writing', + 'Chat With AFFiNE AI', + 'Search With AFFiNE AI', ], messages: [{ role: 'user' as const, content: TestAssets.SSOT }], verifier: (t: ExecutionContext, result: string) => { diff --git a/packages/backend/server/tests/copilot.e2e.ts b/packages/backend/server/tests/copilot.e2e.ts index 49262127d9a98..71a523b31460a 100644 --- a/packages/backend/server/tests/copilot.e2e.ts +++ b/packages/backend/server/tests/copilot.e2e.ts @@ -16,6 +16,7 @@ import { CopilotProviderService, FalProvider, OpenAIProvider, + PerplexityProvider, registerCopilotProvider, unregisterCopilotProvider, } from '../src/plugins/copilot/providers'; @@ -41,6 +42,7 @@ import { sse2array, textToEventStream, unsplashSearch, + updateCopilotSession, } from './utils/copilot'; const test = ava as TestFn<{ @@ -63,6 +65,9 @@ test.beforeEach(async t => { fal: { apiKey: '1', }, + perplexity: { + apiKey: '1', + }, unsplashKey: process.env.UNSPLASH_ACCESS_KEY || '1', }, }, @@ -91,6 +96,7 @@ test.beforeEach(async t => { unregisterCopilotProvider(OpenAIProvider.type); unregisterCopilotProvider(FalProvider.type); + unregisterCopilotProvider(PerplexityProvider.type); registerCopilotProvider(MockCopilotTestProvider); await prompt.set(promptName, 'test', [ @@ -156,6 +162,85 @@ test('should create session correctly', async t => { } }); +test('should update session correctly', async t => { + const { app } = t.context; + + const assertUpdateSession = async ( + sessionId: string, + error: string, + asserter = async (x: any) => { + t.truthy(await x, error); + } + ) => { + await asserter(updateCopilotSession(app, token, sessionId, promptName)); + }; + + { + const { id: workspaceId } = await createWorkspace(app, token); + const docId = randomUUID(); + const sessionId = await createCopilotSession( + app, + token, + workspaceId, + docId, + promptName + ); + await assertUpdateSession( + sessionId, + 'should be able to update session with cloud workspace that user can access' + ); + } + + { + const sessionId = await createCopilotSession( + app, + token, + randomUUID(), + randomUUID(), + promptName + ); + await assertUpdateSession( + sessionId, + 'should be able to update session with local workspace' + ); + } + + { + const aToken = (await signUp(app, 'test', 'test@affine.pro', '123456')) + .token.token; + const { id: workspaceId } = await createWorkspace(app, aToken); + const inviteId = await inviteUser( + app, + aToken, + workspaceId, + 'darksky@affine.pro' + ); + await acceptInviteById(app, workspaceId, inviteId, false); + const sessionId = await createCopilotSession( + app, + token, + workspaceId, + randomUUID(), + promptName + ); + await assertUpdateSession( + sessionId, + 'should able to update session after user have permission' + ); + } + + { + const sessionId = '123456'; + await assertUpdateSession(sessionId, '', async x => { + await t.throwsAsync( + x, + { instanceOf: Error }, + 'should not able to update invalid session id' + ); + }); + } +}); + test('should fork session correctly', async t => { const { app } = t.context; diff --git a/packages/backend/server/tests/copilot.spec.ts b/packages/backend/server/tests/copilot.spec.ts index 9372fc893a368..f4f0774f5ba22 100644 --- a/packages/backend/server/tests/copilot.spec.ts +++ b/packages/backend/server/tests/copilot.spec.ts @@ -68,7 +68,10 @@ test.beforeEach(async t => { apiKey: process.env.COPILOT_OPENAI_API_KEY ?? '1', }, fal: { - apiKey: '1', + apiKey: process.env.COPILOT_FAL_API_KEY ?? '1', + }, + perplexity: { + apiKey: process.env.COPILOT_PERPLEXITY_API_KEY ?? '1', }, }, }, @@ -274,6 +277,46 @@ test('should be able to manage chat session', async t => { } }); +test('should be able to update chat session', async t => { + const { prompt, session } = t.context; + + // Set up a prompt to be used in the session + await prompt.set('prompt', 'model', [ + { role: 'system', content: 'hello {{word}}' }, + ]); + + const commonParams = { + docId: 'test', + workspaceId: 'test', + parentSessionId: null, + userId, + }; + + // Create a session + const sessionId = await session.create({ + promptName: 'prompt', + ...commonParams, + }); + t.truthy(sessionId, 'should create session'); + + // Update the session + const updatedSessionId = await session.update({ + sessionId, + promptName: 'Search With AFFiNE AI', + ...commonParams, + }); + t.is(updatedSessionId, sessionId, 'should update session with same id'); + + // Verify the session was updated + const updatedSession = await session.get(sessionId); + t.truthy(updatedSession, 'should retrieve updated session'); + t.is( + updatedSession?.config.promptName, + 'Search With AFFiNE AI', + 'should have updated prompt name' + ); +}); + test('should be able to fork chat session', async t => { const { auth, prompt, session } = t.context; diff --git a/packages/backend/server/tests/perplexity.spec.ts b/packages/backend/server/tests/perplexity.spec.ts new file mode 100644 index 0000000000000..7c534399942eb --- /dev/null +++ b/packages/backend/server/tests/perplexity.spec.ts @@ -0,0 +1,70 @@ +/// + +import { ReadableStreamDefaultReader } from 'node:stream/web'; + +import ava, { TestFn } from 'ava'; +import Sinon from 'sinon'; + +import { + injectCitations, + PerplexityProvider, +} from '../src/plugins/copilot/providers/perplexity'; + +const test = ava as TestFn<{ + provider: PerplexityProvider; + mockReader: Sinon.SinonStubbedInstance< + ReadableStreamDefaultReader + >; + loggerSpy: Sinon.SinonSpy; +}>; + +test.beforeEach(t => { + const provider = new PerplexityProvider({ apiKey: 'test-api-key' }); + const mockReader = { + read: Sinon.stub(), + releaseLock: Sinon.stub(), + cancel: Sinon.stub(), + } as unknown as Sinon.SinonStubbedInstance< + ReadableStreamDefaultReader + >; + + // Spy on the logger + const loggerSpy = Sinon.spy(provider['logger'], 'warn'); + + t.context = { provider, mockReader, loggerSpy }; +}); + +test.afterEach.always(t => { + // Restore the original logger method + t.context.loggerSpy.restore(); +}); + +test('injectCitations should replace citation placeholders with URLs', t => { + const content = + 'This is [a] test sentence with citations [1] and [[2]] and [3].'; + const citations = [ + 'https://example.com/citation1', + 'https://example.com/citation2', + ]; + + const expected = + 'This is [a] test sentence with citations [[1](https://example.com/citation1)] and [[2](https://example.com/citation2)] and [3].'; + const result = injectCitations(content, citations); + + t.is(result, expected); +}); + +test('injectCitations should not replace citation already with URLs', t => { + const content = + 'Test sentence with citations [1](https://example.com/citationx) and [[2]](https://example.com/citationy) and [[3](https://example.com/citationz)].'; + const citations = [ + 'https://example.com/citation1', + 'https://example.com/citation2', + 'https://example.com/citation3', + ]; + + const expected = content; + const result = injectCitations(content, citations); + + t.is(result, expected); +}); diff --git a/packages/backend/server/tests/utils/copilot.ts b/packages/backend/server/tests/utils/copilot.ts index 07bce549a2535..e88971ce30e70 100644 --- a/packages/backend/server/tests/utils/copilot.ts +++ b/packages/backend/server/tests/utils/copilot.ts @@ -184,6 +184,31 @@ export async function createCopilotSession( return res.body.data.createCopilotSession; } +export async function updateCopilotSession( + app: INestApplication, + userToken: string, + sessionId: string, + promptName: string +): Promise { + const res = await request(app.getHttpServer()) + .post(gql) + .auth(userToken, { type: 'bearer' }) + .set({ 'x-request-id': 'test', 'x-operation-name': 'test' }) + .send({ + query: ` + mutation updateCopilotSession($options: UpdateChatSessionInput!) { + updateCopilotSession(options: $options) + } + `, + variables: { options: { sessionId, promptName } }, + }) + .expect(200); + + handleGraphQLError(res); + + return res.body.data.updateCopilotSession; +} + export async function forkCopilotSession( app: INestApplication, userToken: string, diff --git a/packages/frontend/core/src/blocksuite/presets/ai/ai-panel.ts b/packages/frontend/core/src/blocksuite/presets/ai/ai-panel.ts index 9e77c66c1339f..b232b8fe99164 100644 --- a/packages/frontend/core/src/blocksuite/presets/ai/ai-panel.ts +++ b/packages/frontend/core/src/blocksuite/presets/ai/ai-panel.ts @@ -1,3 +1,4 @@ +import { AINetworkSearchService } from '@affine/core/modules/ai-button/services/network-search'; import type { EditorHost } from '@blocksuite/affine/block-std'; import { type AffineAIPanelWidget, @@ -9,6 +10,7 @@ import { NoteDisplayMode, } from '@blocksuite/affine/blocks'; import { assertExists, Bound } from '@blocksuite/affine/global/utils'; +import type { FrameworkProvider } from '@toeverything/infra'; import type { TemplateResult } from 'lit'; import { createTextRenderer, insertFromMarkdown } from '../_common'; @@ -287,14 +289,21 @@ export function buildCopyConfig(panel: AffineAIPanelWidget) { } export function buildAIPanelConfig( - panel: AffineAIPanelWidget + panel: AffineAIPanelWidget, + framework: FrameworkProvider ): AffineAIPanelWidgetConfig { const ctx = new AIContext(); + const searchService = framework.get(AINetworkSearchService); return { answerRenderer: createTextRenderer(panel.host, { maxHeight: 320 }), finishStateConfig: buildFinishConfig(panel, 'chat', ctx), generatingStateConfig: buildGeneratingConfig(), errorStateConfig: buildErrorConfig(panel), copy: buildCopyConfig(panel), + networkSearchConfig: { + visible: searchService.visible, + enabled: searchService.enabled, + setEnabled: searchService.setEnabled, + }, }; } diff --git a/packages/frontend/core/src/blocksuite/presets/ai/ai-spec.ts b/packages/frontend/core/src/blocksuite/presets/ai/ai-spec.ts index b672cf1e3ab30..6ad115b1ebfa1 100644 --- a/packages/frontend/core/src/blocksuite/presets/ai/ai-spec.ts +++ b/packages/frontend/core/src/blocksuite/presets/ai/ai-spec.ts @@ -23,6 +23,7 @@ import { } from '@blocksuite/affine/blocks'; import { assertInstanceOf } from '@blocksuite/affine/global/utils'; import type { ExtensionType } from '@blocksuite/affine/store'; +import type { FrameworkProvider } from '@toeverything/infra'; import { literal, unsafeStatic } from 'lit/static-html.js'; import { buildAIPanelConfig } from './ai-panel'; @@ -36,96 +37,110 @@ import { setupImageToolbarAIEntry } from './entries/image-toolbar/setup-image-to import { setupSlashMenuAIEntry } from './entries/slash-menu/setup-slash-menu'; import { setupSpaceAIEntry } from './entries/space/setup-space'; -class AIPageRootWatcher extends BlockServiceWatcher { - static override readonly flavour = 'affine:page'; - - override mounted() { - super.mounted(); - this.blockService.specSlots.widgetConnected.on(view => { - if (view.component instanceof AffineAIPanelWidget) { - view.component.style.width = '630px'; - view.component.config = buildAIPanelConfig(view.component); - setupSpaceAIEntry(view.component); - } - - if (view.component instanceof AffineFormatBarWidget) { - setupFormatBarAIEntry(view.component); - } - - if (view.component instanceof AffineSlashMenuWidget) { - setupSlashMenuAIEntry(view.component); - } - }); +function getAIPageRootWatcher(framework: FrameworkProvider) { + class AIPageRootWatcher extends BlockServiceWatcher { + static override readonly flavour = 'affine:page'; + + override mounted() { + super.mounted(); + this.blockService.specSlots.widgetConnected.on(view => { + if (view.component instanceof AffineAIPanelWidget) { + view.component.style.width = '630px'; + view.component.config = buildAIPanelConfig(view.component, framework); + setupSpaceAIEntry(view.component); + } + + if (view.component instanceof AffineFormatBarWidget) { + setupFormatBarAIEntry(view.component); + } + + if (view.component instanceof AffineSlashMenuWidget) { + setupSlashMenuAIEntry(view.component); + } + }); + } } + return AIPageRootWatcher; } -export const AIPageRootBlockSpec: ExtensionType[] = [ - ...PageRootBlockSpec, - AIPageRootWatcher, - { - setup: di => { - di.override(WidgetViewMapIdentifier('affine:page'), () => { - return { - ...pageRootWidgetViewMap, - [AFFINE_AI_PANEL_WIDGET]: literal`${unsafeStatic( - AFFINE_AI_PANEL_WIDGET - )}`, - }; - }); +export function createAIPageRootBlockSpec( + framework: FrameworkProvider +): ExtensionType[] { + return [ + ...PageRootBlockSpec, + getAIPageRootWatcher(framework), + { + setup: di => { + di.override(WidgetViewMapIdentifier('affine:page'), () => { + return { + ...pageRootWidgetViewMap, + [AFFINE_AI_PANEL_WIDGET]: literal`${unsafeStatic( + AFFINE_AI_PANEL_WIDGET + )}`, + }; + }); + }, }, - }, -]; - -class AIEdgelessRootWatcher extends BlockServiceWatcher { - static override readonly flavour = 'affine:page'; - - override mounted() { - super.mounted(); - this.blockService.specSlots.widgetConnected.on(view => { - if (view.component instanceof AffineAIPanelWidget) { - view.component.style.width = '430px'; - view.component.config = buildAIPanelConfig(view.component); - setupSpaceAIEntry(view.component); - } - - if (view.component instanceof EdgelessCopilotWidget) { - setupEdgelessCopilot(view.component); - } - - if (view.component instanceof EdgelessElementToolbarWidget) { - setupEdgelessElementToolbarAIEntry(view.component); - } - - if (view.component instanceof AffineFormatBarWidget) { - setupFormatBarAIEntry(view.component); - } + ]; +} - if (view.component instanceof AffineSlashMenuWidget) { - setupSlashMenuAIEntry(view.component); - } - }); +function getAIEdgelessRootWatcher(framework: FrameworkProvider) { + class AIEdgelessRootWatcher extends BlockServiceWatcher { + static override readonly flavour = 'affine:page'; + + override mounted() { + super.mounted(); + this.blockService.specSlots.widgetConnected.on(view => { + if (view.component instanceof AffineAIPanelWidget) { + view.component.style.width = '430px'; + view.component.config = buildAIPanelConfig(view.component, framework); + setupSpaceAIEntry(view.component); + } + + if (view.component instanceof EdgelessCopilotWidget) { + setupEdgelessCopilot(view.component); + } + + if (view.component instanceof EdgelessElementToolbarWidget) { + setupEdgelessElementToolbarAIEntry(view.component); + } + + if (view.component instanceof AffineFormatBarWidget) { + setupFormatBarAIEntry(view.component); + } + + if (view.component instanceof AffineSlashMenuWidget) { + setupSlashMenuAIEntry(view.component); + } + }); + } } + return AIEdgelessRootWatcher; } -export const AIEdgelessRootBlockSpec: ExtensionType[] = [ - ...EdgelessRootBlockSpec, - AIEdgelessRootWatcher, - { - setup: di => { - di.override(WidgetViewMapIdentifier('affine:page'), () => { - return { - ...edgelessRootWidgetViewMap, - [AFFINE_EDGELESS_COPILOT_WIDGET]: literal`${unsafeStatic( - AFFINE_EDGELESS_COPILOT_WIDGET - )}`, - [AFFINE_AI_PANEL_WIDGET]: literal`${unsafeStatic( - AFFINE_AI_PANEL_WIDGET - )}`, - }; - }); +export function createAIEdgelessRootBlockSpec( + framework: FrameworkProvider +): ExtensionType[] { + return [ + ...EdgelessRootBlockSpec, + getAIEdgelessRootWatcher(framework), + { + setup: di => { + di.override(WidgetViewMapIdentifier('affine:page'), () => { + return { + ...edgelessRootWidgetViewMap, + [AFFINE_EDGELESS_COPILOT_WIDGET]: literal`${unsafeStatic( + AFFINE_EDGELESS_COPILOT_WIDGET + )}`, + [AFFINE_AI_PANEL_WIDGET]: literal`${unsafeStatic( + AFFINE_AI_PANEL_WIDGET + )}`, + }; + }); + }, }, - }, -]; + ]; +} class AIParagraphBlockWatcher extends BlockServiceWatcher { static override readonly flavour = 'affine:paragraph'; diff --git a/packages/frontend/core/src/blocksuite/presets/ai/chat-panel/chat-panel-input.ts b/packages/frontend/core/src/blocksuite/presets/ai/chat-panel/chat-panel-input.ts index 7a80e414bb762..e4d6ba2da1989 100644 --- a/packages/frontend/core/src/blocksuite/presets/ai/chat-panel/chat-panel-input.ts +++ b/packages/frontend/core/src/blocksuite/presets/ai/chat-panel/chat-panel-input.ts @@ -1,6 +1,14 @@ +import { stopPropagation } from '@affine/core/utils'; import type { EditorHost } from '@blocksuite/affine/block-std'; import { type AIError, openFileOrFiles } from '@blocksuite/affine/blocks'; -import { assertExists, WithDisposable } from '@blocksuite/affine/global/utils'; +import { + assertExists, + SignalWatcher, + WithDisposable, +} from '@blocksuite/affine/global/utils'; +import { unsafeCSSVarV2 } from '@blocksuite/affine-shared/theme'; +import { ImageIcon, PublishIcon } from '@blocksuite/icons/lit'; +import type { Signal } from '@preact/signals-core'; import { css, html, LitElement, nothing } from 'lit'; import { property, query, state } from 'lit/decorators.js'; import { repeat } from 'lit/directives/repeat.js'; @@ -10,7 +18,6 @@ import { ChatClearIcon, ChatSendIcon, CloseIcon, - ImageIcon, } from '../_common/icons'; import { AIProvider } from '../provider'; import { reportResponse } from '../utils/action-reporter'; @@ -24,7 +31,13 @@ function getFirstTwoLines(text: string) { return lines.slice(0, 2); } -export class ChatPanelInput extends WithDisposable(LitElement) { +export interface AINetworkSearchConfig { + visible: Signal; + enabled: Signal; + setEnabled: (state: boolean) => void; +} + +export class ChatPanelInput extends SignalWatcher(WithDisposable(LitElement)) { static override styles = css` .chat-panel-input { display: flex; @@ -104,10 +117,28 @@ export class ChatPanelInput extends WithDisposable(LitElement) { margin-left: auto; } - .image-upload { + .image-upload, + .chat-network-search { display: flex; justify-content: center; align-items: center; + svg { + width: 20px; + height: 20px; + color: ${unsafeCSSVarV2('icon/primary')}; + } + } + .chat-network-search[data-active='true'] svg { + color: ${unsafeCSSVarV2('icon/activated')}; + } + + .image-upload[aria-disabled='true'], + .chat-network-search[aria-disabled='true'] { + cursor: not-allowed; + } + .image-upload[aria-disabled='true'] svg, + .chat-network-search[aria-disabled='true'] svg { + color: var(--affine-text-disable-color) !important; } } @@ -235,6 +266,9 @@ export class ChatPanelInput extends WithDisposable(LitElement) { @property({ attribute: false }) accessor cleanupHistories!: () => Promise; + @property({ attribute: false }) + accessor networkSearchConfig!: AINetworkSearchConfig; + private _addImages(images: File[]) { const oldImages = this.chatContextValue.images; this.updateContext({ @@ -296,6 +330,23 @@ export class ChatPanelInput extends WithDisposable(LitElement) { `; } + private readonly _toggleNetworkSearch = (e: MouseEvent) => { + e.preventDefault(); + e.stopPropagation(); + + const enable = this.networkSearchConfig.enabled.value; + this.networkSearchConfig.setEnabled(!enable); + }; + + private readonly _uploadImageFiles = async (_e: MouseEvent) => { + const images = await openFileOrFiles({ + acceptType: 'Images', + multiple: true, + }); + if (!images) return; + this._addImages(images); + }; + override connectedCallback() { super.connectedCallback(); @@ -305,7 +356,7 @@ export class ChatPanelInput extends WithDisposable(LitElement) { if (this.host === host) { context && this.updateContext(context); await this.updateComplete; - await this.send(input); + input && (await this.send(input)); } } ) @@ -316,7 +367,9 @@ export class ChatPanelInput extends WithDisposable(LitElement) { const { images, status } = this.chatContextValue; const hasImages = images.length > 0; const maxHeight = hasImages ? 272 + 2 : 200 + 2; - + const networkDisabled = !!this.chatContextValue.images.length; + const networkActive = !!this.networkSearchConfig.enabled.value; + const uploadDisabled = networkActive && !networkDisabled; return html`