From 11f9edd59baee65f271fdf8623baa93e0958469b Mon Sep 17 00:00:00 2001 From: Bradley Axen Date: Wed, 26 Feb 2025 17:35:17 +0100 Subject: [PATCH] draft: use rust messages in typescript --- crates/goose-server/src/routes/reply.rs | 385 +++----------- ui/desktop/src/components/ChatView.tsx | 118 ++--- ui/desktop/src/components/GooseMessage.tsx | 79 ++- .../src/components/GooseResponseForm.tsx | 21 +- ui/desktop/src/components/Splash.tsx | 1 - ui/desktop/src/components/SplashPills.tsx | 7 +- ui/desktop/src/components/ToolInvocations.tsx | 60 +-- ui/desktop/src/components/UserMessage.tsx | 16 +- ui/desktop/src/hooks/useMessageStream.ts | 483 ++++++++++++++++++ ui/desktop/src/types/message.ts | 199 ++++++++ ui/desktop/src/utils/generateId.ts | 7 + 11 files changed, 935 insertions(+), 441 deletions(-) create mode 100644 ui/desktop/src/hooks/useMessageStream.ts create mode 100644 ui/desktop/src/types/message.ts create mode 100644 ui/desktop/src/utils/generateId.ts diff --git a/crates/goose-server/src/routes/reply.rs b/crates/goose-server/src/routes/reply.rs index 03166c173..131108131 100644 --- a/crates/goose-server/src/routes/reply.rs +++ b/crates/goose-server/src/routes/reply.rs @@ -11,7 +11,7 @@ use futures::{stream::StreamExt, Stream}; use goose::message::{Message, MessageContent}; use mcp_core::{content::Content, role::Role}; -use serde::Deserialize; +use serde::{Deserialize, Serialize}; use serde_json::{json, Value}; use std::{ convert::Infallible, @@ -23,33 +23,13 @@ use tokio::sync::mpsc; use tokio::time::timeout; use tokio_stream::wrappers::ReceiverStream; -// Types matching the incoming JSON structure +// Direct message serialization for the chat request #[derive(Debug, Deserialize)] struct ChatRequest { - messages: Vec, + messages: Vec, } -#[derive(Debug, Deserialize)] -struct IncomingMessage { - role: String, - content: String, - #[serde(default)] - #[serde(rename = "toolInvocations")] - tool_invocations: Vec, -} - -#[derive(Debug, Deserialize)] -struct 
ToolInvocation { - state: String, - #[serde(rename = "toolCallId")] - tool_call_id: String, - #[serde(rename = "toolName")] - tool_name: String, - args: Value, - result: Option>, -} - -// Custom SSE response type that implements the Vercel AI SDK protocol +// Custom SSE response type for streaming messages pub struct SseResponse { rx: ReceiverStream, } @@ -79,188 +59,32 @@ impl IntoResponse for SseResponse { .header("Content-Type", "text/event-stream") .header("Cache-Control", "no-cache") .header("Connection", "keep-alive") - .header("x-vercel-ai-data-stream", "v1") .body(body) .unwrap() } } -// Convert incoming messages to our internal Message type -fn convert_messages(incoming: Vec) -> Vec { - let mut messages = Vec::new(); - - for msg in incoming { - match msg.role.as_str() { - "user" => { - messages.push(Message::user().with_text(msg.content)); - } - "assistant" => { - // First handle any tool invocations - each represents a complete request/response cycle - for tool in msg.tool_invocations { - if tool.state == "result" { - // Add the original tool request from assistant - let tool_call = mcp_core::tool::ToolCall { - name: tool.tool_name, - arguments: tool.args, - }; - messages.push( - Message::assistant() - .with_tool_request(tool.tool_call_id.clone(), Ok(tool_call)), - ); - - // Add the tool response from user - if let Some(result) = &tool.result { - messages.push( - Message::user() - .with_tool_response(tool.tool_call_id, Ok(result.clone())), - ); - } - } - } - - // Then add the assistant's text response after tool interactions - if !msg.content.is_empty() { - messages.push(Message::assistant().with_text(msg.content)); - } - } - _ => { - tracing::warn!("Unknown role: {}", msg.role); - } - } - } - - messages -} - -// Protocol-specific message formatting -struct ProtocolFormatter; - -impl ProtocolFormatter { - fn format_text(text: &str) -> String { - let encoded_text = serde_json::to_string(text).unwrap_or_else(|_| String::new()); - format!("0:{}\n", 
encoded_text) - } - - fn format_tool_call(id: &str, name: &str, args: &Value) -> String { - // Tool calls start with "9:" - let tool_call = json!({ - "toolCallId": id, - "toolName": name, - "args": args - }); - format!("9:{}\n", tool_call) - } - - fn format_tool_response(id: &str, result: &Vec) -> String { - // Tool responses start with "a:" - let response = json!({ - "toolCallId": id, - "result": result, - }); - format!("a:{}\n", response) - } - - fn format_error(error: &str) -> String { - // Error messages start with "3:" in the new protocol. - let encoded_error = serde_json::to_string(error).unwrap_or_else(|_| String::new()); - format!("3:{}\n", encoded_error) - } - - fn format_finish(reason: &str) -> String { - // Finish messages start with "d:" - let finish = json!({ - "finishReason": reason, - "usage": { - "promptTokens": 0, - "completionTokens": 0 - } - }); - format!("d:{}\n", finish) - } +// Message event types for SSE streaming +#[derive(Debug, Serialize)] +#[serde(tag = "type")] +enum MessageEvent { + Message { message: Message }, + Error { error: String }, + Finish { reason: String }, } -async fn stream_message( - message: Message, +// Stream a message as an SSE event +async fn stream_event( + event: MessageEvent, tx: &mpsc::Sender, ) -> Result<(), mpsc::error::SendError> { - match message.role { - Role::User => { - // Handle tool responses - for content in message.content { - // I believe with the protocol we aren't intended to pass back user messages, so we only deal with - // the tool responses here - if let MessageContent::ToolResponse(response) = content { - // We should return a result for either an error or a success - match response.tool_result { - Ok(result) => { - tx.send(ProtocolFormatter::format_tool_response( - &response.id, - &result, - )) - .await?; - } - Err(err) => { - // Send an error message first - tx.send(ProtocolFormatter::format_error(&err.to_string())) - .await?; - // Then send an empty tool response to maintain the protocol - let 
result = - vec![Content::text(format!("Error: {}", err)).with_priority(0.0)]; - tx.send(ProtocolFormatter::format_tool_response( - &response.id, - &result, - )) - .await?; - } - } - } - } - } - Role::Assistant => { - for content in message.content { - match content { - MessageContent::ToolRequest(request) => { - match request.tool_call { - Ok(tool_call) => { - tx.send(ProtocolFormatter::format_tool_call( - &request.id, - &tool_call.name, - &tool_call.arguments, - )) - .await?; - } - Err(err) => { - // Send a placeholder tool call to maintain protocol - tx.send(ProtocolFormatter::format_tool_call( - &request.id, - "invalid_tool", - &json!({"error": err.to_string()}), - )) - .await?; - } - } - } - MessageContent::Text(text) => { - for line in text.text.lines() { - let modified_line = format!("{}\n", line); - tx.send(ProtocolFormatter::format_text(&modified_line)) - .await?; - } - } - MessageContent::ToolConfirmationRequest(_) => { - // skip tool confirmation requests - } - MessageContent::Image(_) => { - // skip images - } - MessageContent::ToolResponse(_) => { - // skip tool responses - } - } - } - } - } - Ok(()) + let json = serde_json::to_string(&event).unwrap_or_else(|e| { + format!( + r#"{{"type":"Error","error":"Failed to serialize event: {}"}}"#, + e + ) + }); + tx.send(format!("data: {}\n\n", json)).await } async fn handler( @@ -278,19 +102,12 @@ async fn handler( return Err(StatusCode::UNAUTHORIZED); } - // Check protocol header (optional in our case) - if let Some(protocol) = headers.get("x-protocol") { - if protocol.to_str().map(|p| p != "data").unwrap_or(true) { - return Err(StatusCode::BAD_REQUEST); - } - } - // Create channel for streaming let (tx, rx) = mpsc::channel(100); let stream = ReceiverStream::new(rx); - // Convert incoming messages - let messages = convert_messages(request.messages); + // Get messages directly from the request + let messages = request.messages; // Get a lock on the shared agent let agent = state.agent.clone(); @@ -301,10 
+118,20 @@ async fn handler( let agent = match agent.as_ref() { Some(agent) => agent, None => { - let _ = tx - .send(ProtocolFormatter::format_error("No agent configured")) - .await; - let _ = tx.send(ProtocolFormatter::format_finish("error")).await; + let _ = stream_event( + MessageEvent::Error { + error: "No agent configured".to_string(), + }, + &tx, + ) + .await; + let _ = stream_event( + MessageEvent::Finish { + reason: "error".to_string(), + }, + &tx, + ) + .await; return; } }; @@ -313,10 +140,20 @@ async fn handler( Ok(stream) => stream, Err(e) => { tracing::error!("Failed to start reply stream: {:?}", e); - let _ = tx - .send(ProtocolFormatter::format_error(&e.to_string())) - .await; - let _ = tx.send(ProtocolFormatter::format_finish("error")).await; + let _ = stream_event( + MessageEvent::Error { + error: e.to_string(), + }, + &tx, + ) + .await; + let _ = stream_event( + MessageEvent::Finish { + reason: "error".to_string(), + }, + &tx, + ) + .await; return; } }; @@ -326,25 +163,32 @@ async fn handler( response = timeout(Duration::from_millis(500), stream.next()) => { match response { Ok(Some(Ok(message))) => { - if let Err(e) = stream_message(message, &tx).await { + if let Err(e) = stream_event(MessageEvent::Message { message }, &tx).await { tracing::error!("Error sending message through channel: {}", e); - let _ = tx.send(ProtocolFormatter::format_error(&e.to_string())).await; + let _ = stream_event( + MessageEvent::Error { + error: e.to_string(), + }, + &tx, + ).await; break; } } Ok(Some(Err(e))) => { tracing::error!("Error processing message: {}", e); - let _ = tx.send(ProtocolFormatter::format_error(&e.to_string())).await; + let _ = stream_event( + MessageEvent::Error { + error: e.to_string(), + }, + &tx, + ).await; break; } Ok(None) => { break; } - Err(_) => { // Heartbeat, used to detect disconnected clients and then end running tools. 
+ Err(_) => { // Heartbeat, used to detect disconnected clients if tx.is_closed() { - // Kill any running processes when the client disconnects - // TODO is this used? I suspect post MCP this is on the server instead - // goose::process_store::kill_processes(); break; } continue; @@ -354,24 +198,30 @@ async fn handler( } } - // Send finish message - let _ = tx.send(ProtocolFormatter::format_finish("stop")).await; + // Send finish event + let _ = stream_event( + MessageEvent::Finish { + reason: "stop".to_string(), + }, + &tx, + ) + .await; }); Ok(SseResponse::new(stream)) } -#[derive(Debug, Deserialize, serde::Serialize)] +#[derive(Debug, Deserialize, Serialize)] struct AskRequest { prompt: String, } -#[derive(Debug, serde::Serialize)] +#[derive(Debug, Serialize)] struct AskResponse { response: String, } -// simple ask an AI for a response, non streaming +// Simple ask an AI for a response, non streaming async fn ask_handler( State(state): State, headers: HeaderMap, @@ -478,85 +328,6 @@ mod tests { } } - #[test] - fn test_convert_messages_user_only() { - let incoming = vec![IncomingMessage { - role: "user".to_string(), - content: "Hello".to_string(), - tool_invocations: vec![], - }]; - - let messages = convert_messages(incoming); - assert_eq!(messages.len(), 1); - assert_eq!(messages[0].role, Role::User); - assert!( - matches!(&messages[0].content[0], MessageContent::Text(text) if text.text == "Hello") - ); - } - - #[test] - fn test_convert_messages_with_tool_invocation() { - let tool_result = vec![Content::text("tool response").with_priority(0.0)]; - let incoming = vec![IncomingMessage { - role: "assistant".to_string(), - content: "".to_string(), - tool_invocations: vec![ToolInvocation { - state: "result".to_string(), - tool_call_id: "123".to_string(), - tool_name: "test_tool".to_string(), - args: json!({"key": "value"}), - result: Some(tool_result.clone()), - }], - }]; - - let messages = convert_messages(incoming); - assert_eq!(messages.len(), 2); // Tool request 
and response - - // Check tool request - assert_eq!(messages[0].role, Role::Assistant); - assert!( - matches!(&messages[0].content[0], MessageContent::ToolRequest(req) if req.id == "123") - ); - - // Check tool response - assert_eq!(messages[1].role, Role::User); - assert!( - matches!(&messages[1].content[0], MessageContent::ToolResponse(resp) if resp.id == "123") - ); - } - - #[test] - fn test_protocol_formatter() { - // Test text formatting - let text = "Hello world"; - let formatted = ProtocolFormatter::format_text(text); - assert_eq!(formatted, "0:\"Hello world\"\n"); - - // Test tool call formatting - let formatted = - ProtocolFormatter::format_tool_call("123", "test_tool", &json!({"key": "value"})); - assert!(formatted.starts_with("9:")); - assert!(formatted.contains("\"toolCallId\":\"123\"")); - assert!(formatted.contains("\"toolName\":\"test_tool\"")); - - // Test tool response formatting - let result = vec![Content::text("response").with_priority(0.0)]; - let formatted = ProtocolFormatter::format_tool_response("123", &result); - assert!(formatted.starts_with("a:")); - assert!(formatted.contains("\"toolCallId\":\"123\"")); - - // Test error formatting - let formatted = ProtocolFormatter::format_error("Test error"); - println!("Formatted error: {}", formatted); - assert!(formatted.starts_with("3:")); - assert!(formatted.contains("Test error")); - - // Test finish formatting - let formatted = ProtocolFormatter::format_finish("stop"); - assert!(formatted.starts_with("d:")); - assert!(formatted.contains("\"finishReason\":\"stop\"")); - } - mod integration_tests { use super::*; use axum::{body::Body, http::Request}; @@ -575,7 +346,7 @@ mod tests { }); let agent = AgentFactory::create("reference", mock_provider).unwrap(); let state = AppState { - config: Arc::new(Mutex::new(HashMap::new())), // Add this line + config: Arc::new(Mutex::new(HashMap::new())), agent: Arc::new(Mutex::new(Some(agent))), secret_key: "test-secret".to_string(), }; @@ -604,4 +375,4 @@ mod 
tests { assert_eq!(response.status(), StatusCode::OK); } } -} +} \ No newline at end of file diff --git a/ui/desktop/src/components/ChatView.tsx b/ui/desktop/src/components/ChatView.tsx index 12669ca33..834fc5ef1 100644 --- a/ui/desktop/src/components/ChatView.tsx +++ b/ui/desktop/src/components/ChatView.tsx @@ -1,5 +1,4 @@ import React, { useEffect, useRef, useState } from 'react'; -import { Message, useChat } from '../ai-sdk-fork/useChat'; import { getApiUrl } from '../config'; import BottomMenu from './BottomMenu'; import FlappyGoose from './FlappyGoose'; @@ -14,15 +13,13 @@ import UserMessage from './UserMessage'; import { askAi } from '../utils/askAI'; import Splash from './Splash'; import 'react-toastify/dist/ReactToastify.css'; +import { useMessageStream } from '../hooks/useMessageStream'; +import { Message, createUserMessage, getTextContent } from '../types/message'; export interface ChatType { id: number; title: string; - messages: Array<{ - id: string; - role: 'function' | 'system' | 'user' | 'assistant' | 'data' | 'tool'; - content: string; - }>; + messages: Message[]; } export default function ChatView({ setView }: { setView: (view: View) => void }) { @@ -39,14 +36,27 @@ export default function ChatView({ setView }: { setView: (view: View) => void }) const [showGame, setShowGame] = useState(false); const scrollRef = useRef(null); - const { messages, append, stop, isLoading, error, setMessages } = useChat({ + const { + messages, + append, + stop, + isLoading, + error, + setMessages, + input, + setInput, + handleInputChange, + handleSubmit: submitMessage + } = useMessageStream({ api: getApiUrl('/reply'), initialMessages: chat?.messages || [], - onFinish: async (message, _) => { + onFinish: async (message, reason) => { window.electron.stopPowerSaveBlocker(); - const fetchResponses = await askAi(message.content); - setMessageMetadata((prev) => ({ ...prev, [message.id]: fetchResponses })); + // Extract text content from the message to pass to askAi + const 
messageText = getTextContent(message); + const fetchResponses = await askAi(messageText); + setMessageMetadata((prev) => ({ ...prev, [message.id || ''] : fetchResponses })); const timeSinceLastInteraction = Date.now() - lastInteractionTime; window.electron.logInfo('last interaction:' + lastInteractionTime); @@ -58,6 +68,11 @@ export default function ChatView({ setView }: { setView: (view: View) => void }) }); } }, + onToolCall: (toolCall) => { + // Handle tool calls if needed + console.log('Tool call received:', toolCall); + // Implement tool call handling logic here + } }); // Update chat messages when they change @@ -78,10 +93,7 @@ export default function ChatView({ setView }: { setView: (view: View) => void }) const content = customEvent.detail?.value || ''; if (content.trim()) { setLastInteractionTime(Date.now()); - append({ - role: 'user', - content, - }); + append(createUserMessage(content)); if (scrollRef.current?.scrollToBottom) { scrollRef.current.scrollToBottom(); } @@ -97,47 +109,38 @@ export default function ChatView({ setView }: { setView: (view: View) => void }) setLastInteractionTime(Date.now()); window.electron.stopPowerSaveBlocker(); - const lastMessage: Message = messages[messages.length - 1]; - if (lastMessage.role === 'user' && lastMessage.toolInvocations === undefined) { - // Remove the last user message. 
+ // Handle stopping the message stream + const lastMessage = messages[messages.length - 1]; + if (lastMessage && lastMessage.role === 'user') { + // Remove the last user message if it's the most recent one if (messages.length > 1) { setMessages(messages.slice(0, -1)); } else { setMessages([]); } - } else if (lastMessage.role === 'assistant' && lastMessage.toolInvocations !== undefined) { - // Add messaging about interrupted ongoing tool invocations - const newLastMessage: Message = { - ...lastMessage, - toolInvocations: lastMessage.toolInvocations.map((invocation) => { - if (invocation.state !== 'result') { - return { - ...invocation, - result: [ - { - audience: ['user'], - text: 'Interrupted.\n', - type: 'text', - }, - { - audience: ['assistant'], - text: 'Interrupted by the user to make a correction.\n', - type: 'text', - }, - ], - state: 'result', - }; - } else { - return invocation; - } - }), - }; - - const updatedMessages = [...messages.slice(0, -1), newLastMessage]; - setMessages(updatedMessages); } + // Note: Tool call interruption handling would need to be implemented + // differently with the new message format }; + // Filter out standalone tool response messages for rendering + // They will be shown as part of the tool invocation in the assistant message + const filteredMessages = messages.filter(message => { + // Keep all assistant messages and user messages that aren't just tool responses + if (message.role === 'assistant') return true; + + // For user messages, check if they're only tool responses + if (message.role === 'user') { + const hasOnlyToolResponses = message.content.every(c => 'ToolResponse' in c); + const hasTextContent = message.content.some(c => 'Text' in c); + + // Keep the message if it has text content or is not just tool responses + return hasTextContent || !hasOnlyToolResponses; + } + + return true; + }); + return (
@@ -145,19 +148,19 @@ export default function ChatView({ setView }: { setView: (view: View) => void })
{messages.length === 0 ? ( - + append(createUserMessage(text))} /> ) : ( - {messages.map((message) => ( -
+ {filteredMessages.map((message, index) => ( +
{message.role === 'user' ? ( ) : ( append(createUserMessage(text))} /> )}
@@ -166,20 +169,17 @@ export default function ChatView({ setView }: { setView: (view: View) => void })
{error.message || 'Honk! Goose experienced an error while responding'} - {error.status && (Status: {error.status})}
{ + // Find the last user message const lastUserMessage = messages.reduceRight( (found, m) => found || (m.role === 'user' ? m : null), - null + null as Message | null ); if (lastUserMessage) { - append({ - role: 'user', - content: lastUserMessage.content, - }); + append(lastUserMessage); } }} > @@ -206,4 +206,4 @@ export default function ChatView({ setView }: { setView: (view: View) => void }) {showGame && setShowGame(false)} />}
); -} +} \ No newline at end of file diff --git a/ui/desktop/src/components/GooseMessage.tsx b/ui/desktop/src/components/GooseMessage.tsx index a19af957c..70d44bbba 100644 --- a/ui/desktop/src/components/GooseMessage.tsx +++ b/ui/desktop/src/components/GooseMessage.tsx @@ -1,41 +1,96 @@ -import React from 'react'; +import React, { useMemo } from 'react'; import ToolInvocations from './ToolInvocations'; import LinkPreview from './LinkPreview'; import GooseResponseForm from './GooseResponseForm'; import { extractUrls } from '../utils/urlUtils'; import MarkdownContent from './MarkdownContent'; +import { Message, getTextContent, getToolRequests, getToolResponses } from '../types/message'; interface GooseMessageProps { - message: any; - messages: any[]; + message: Message; + messages: Message[]; metadata?: any; - append: (value: any) => void; + append: (value: string) => void; } export default function GooseMessage({ message, metadata, messages, append }: GooseMessageProps) { + // Extract text content from the message + const textContent = getTextContent(message); + + // Get tool requests from the message + const toolRequests = getToolRequests(message); + // Extract URLs under a few conditions // 1. The message is purely text // 2. The link wasn't also present in the previous message // 3. The message contains the explicit http:// or https:// protocol at the beginning const messageIndex = messages?.findIndex((msg) => msg.id === message.id); const previousMessage = messageIndex > 0 ? messages[messageIndex - 1] : null; - const previousUrls = previousMessage ? extractUrls(previousMessage.content) : []; - const urls = !message.toolInvocations ? extractUrls(message.content, previousUrls) : []; + const previousUrls = previousMessage ? extractUrls(getTextContent(previousMessage)) : []; + const urls = toolRequests.length === 0 ? 
extractUrls(textContent, previousUrls) : []; + + // Find tool responses that correspond to the tool requests in this message + const toolResponsesMap = useMemo(() => { + const responseMap = new Map(); + + // Look for tool responses in subsequent messages + if (messageIndex !== undefined && messageIndex >= 0) { + for (let i = messageIndex + 1; i < messages.length; i++) { + const responses = getToolResponses(messages[i]); + + for (const response of responses) { + // Check if this response matches any of our tool requests + const matchingRequest = toolRequests.find((req) => req.id === response.id); + if (matchingRequest) { + responseMap.set(response.id, response); + } + } + } + } + + return responseMap; + }, [messages, messageIndex, toolRequests]); + + // Convert tool requests to the format expected by ToolInvocations + const toolInvocations = useMemo(() => { + const invocations = toolRequests + .map((toolRequest) => { + const toolCall = toolRequest.tool_call.Ok; + + if (!toolCall) { + return null; + } + + const toolResponse = toolResponsesMap.get(toolRequest.id); + + return { + toolCallId: toolRequest.id, + toolName: toolCall.name, + args: toolCall.arguments, + state: toolResponse ? 'result' : 'running', + result: toolResponse?.tool_result?.Ok || undefined, + }; + }) + .filter(Boolean); + + return invocations; + }, [toolRequests, toolResponsesMap]); return (
- {message.content && ( + {/* Always show the top content area if there are tool calls, even if textContent is empty */} + {(textContent || toolInvocations.length > 0) && (
0 ? 'rounded-b-none' : ''}`} > - + {textContent ? : null}
)} - {message.toolInvocations && ( + {toolInvocations.length > 0 && (
- +
)}
@@ -53,7 +108,7 @@ export default function GooseMessage({ message, metadata, messages, append }: Go {/* NOTE from alexhancock on 1/14/2025 - disabling again temporarily due to non-determinism in when the forms show up */} {false && metadata && (
- +
)}
diff --git a/ui/desktop/src/components/GooseResponseForm.tsx b/ui/desktop/src/components/GooseResponseForm.tsx index a5f811c28..b48d6f473 100644 --- a/ui/desktop/src/components/GooseResponseForm.tsx +++ b/ui/desktop/src/components/GooseResponseForm.tsx @@ -3,6 +3,7 @@ import MarkdownContent from './MarkdownContent'; import { Button } from './ui/button'; import { cn } from '../utils'; import { Send } from './icons'; +import { createUserMessage } from '../types/message'; interface FormField { label: string; @@ -21,7 +22,7 @@ interface DynamicForm { interface GooseResponseFormProps { message: string; metadata: any; - append: (value: any) => void; + append: (value: string) => void; } export default function GooseResponseForm({ @@ -103,31 +104,19 @@ export default function GooseResponseForm({ }; const handleAccept = () => { - const message = { - content: 'Yes - go ahead.', - role: 'user', - }; - append(message); + append('Yes - go ahead.'); }; const handleSubmit = () => { if (selectedOption !== null && options[selectedOption]) { - const message = { - content: `Yes - continue with: ${options[selectedOption].optionTitle}`, - role: 'user', - }; - append(message); + append(`Yes - continue with: ${options[selectedOption].optionTitle}`); } }; const handleFormSubmit = (e: React.FormEvent) => { e.preventDefault(); if (dynamicForm) { - const message = { - content: JSON.stringify(formValues), - role: 'user', - }; - append(message); + append(JSON.stringify(formValues)); } }; diff --git a/ui/desktop/src/components/Splash.tsx b/ui/desktop/src/components/Splash.tsx index 7f90135ac..4058fae2d 100644 --- a/ui/desktop/src/components/Splash.tsx +++ b/ui/desktop/src/components/Splash.tsx @@ -1,6 +1,5 @@ import React from 'react'; import SplashPills from './SplashPills'; -import { Goose, Rain } from './icons/Goose'; import GooseLogo from './GooseLogo'; export default function Splash({ append }) { diff --git a/ui/desktop/src/components/SplashPills.tsx 
b/ui/desktop/src/components/SplashPills.tsx index a98593dc0..d738e4112 100644 --- a/ui/desktop/src/components/SplashPills.tsx +++ b/ui/desktop/src/components/SplashPills.tsx @@ -5,11 +5,8 @@ function SplashPill({ content, append, className = '', longForm = '' }) {
{ - const message = { - content: longForm || content, - role: 'user', - }; - await append(message); + // Use the longForm text if provided, otherwise use the content + await append(longForm || content); }} >
{content}
diff --git a/ui/desktop/src/components/ToolInvocations.tsx b/ui/desktop/src/components/ToolInvocations.tsx index 7de42dd4b..fb71a97d1 100644 --- a/ui/desktop/src/components/ToolInvocations.tsx +++ b/ui/desktop/src/components/ToolInvocations.tsx @@ -6,8 +6,21 @@ import MarkdownContent from './MarkdownContent'; import { snakeToTitleCase } from '../utils'; import { LoadingPlaceholder } from './LoadingPlaceholder'; import { ChevronUp } from 'lucide-react'; +import { Content } from '../types/message'; + +interface ToolInvocation { + toolCallId: string; + toolName: string; + args: any; + state: 'running' | 'result'; + result?: Content[]; +} + +interface ToolInvocationsProps { + toolInvocations: ToolInvocation[]; +} -export default function ToolInvocations({ toolInvocations }) { +export default function ToolInvocations({ toolInvocations }: ToolInvocationsProps) { return ( <> {toolInvocations.map((toolInvocation) => ( @@ -17,7 +30,7 @@ export default function ToolInvocations({ toolInvocations }) { ); } -function ToolInvocation({ toolInvocation }) { +function ToolInvocation({ toolInvocation }: { toolInvocation: ToolInvocation }) { return (
@@ -34,7 +47,7 @@ function ToolInvocation({ toolInvocation }) { interface ToolCallProps { call: { - state: 'call' | 'result'; + state: 'running' | 'result'; toolCallId: string; toolName: string; args: Record; @@ -58,28 +71,13 @@ function ToolCall({ call }: ToolCallProps) { ); } -interface Annotations { - audience?: string[]; // Array of audience types - priority?: number; // Priority value between 0 and 1 -} - -interface ResultItem { - text?: string; - type: 'text' | 'image'; - mimeType?: string; - data?: string; // Base64 encoded image data - annotations?: Annotations; -} - interface ToolResultProps { result: { - message?: string; - result?: ResultItem[]; + result?: Content[]; state?: string; toolCallId?: string; toolName?: string; args?: any; - input_todo?: any; }; } @@ -95,8 +93,7 @@ function ToolResult({ result }: ToolResultProps) { // Find results where either audience is not set, or it's set to a list that contains user const filteredResults = results.filter( - (item: ResultItem) => - !item.annotations?.audience || item.annotations?.audience?.includes('user') + (item) => !item.audience || item.audience?.includes('user') ); if (filteredResults.length === 0) return null; @@ -107,21 +104,21 @@ function ToolResult({ result }: ToolResultProps) { ); }; - const shouldShowExpanded = (item: ResultItem, index: number) => { + const shouldShowExpanded = (item: Content, index: number) => { // (priority is defined and > 0.5) OR already in the expandedItems return ( - (item.annotations?.priority !== undefined && item.annotations?.priority >= 0.5) || + (item.priority !== undefined && item.priority >= 0.5) || expandedItems.includes(index) ); }; return (
- {filteredResults.map((item: ResultItem, index: number) => { + {filteredResults.map((item, index) => { const isExpanded = shouldShowExpanded(item, index); // minimize if priority is not set or < 0.5 const shouldMinimize = - item.annotations?.priority === undefined || item.annotations?.priority < 0.5; + item.priority === undefined || item.priority < 0.5; return (
{shouldMinimize && ( @@ -137,23 +134,12 @@ function ToolResult({ result }: ToolResultProps) { )} {(isExpanded || !shouldMinimize) && ( <> - {item.type === 'text' && item.text && ( + {item.text && ( )} - {item.type === 'image' && item.data && item.mimeType && ( - Tool result { - console.error('Failed to load image: Invalid MIME-type encoded image data'); - e.currentTarget.style.display = 'none'; - }} - /> - )} )}
diff --git a/ui/desktop/src/components/UserMessage.tsx b/ui/desktop/src/components/UserMessage.tsx index a37a8e92d..42bb072cb 100644 --- a/ui/desktop/src/components/UserMessage.tsx +++ b/ui/desktop/src/components/UserMessage.tsx @@ -2,16 +2,24 @@ import React from 'react'; import LinkPreview from './LinkPreview'; import { extractUrls } from '../utils/urlUtils'; import MarkdownContent from './MarkdownContent'; +import { Message, getTextContent } from '../types/message'; -export default function UserMessage({ message }) { +interface UserMessageProps { + message: Message; +} + +export default function UserMessage({ message }: UserMessageProps) { + // Extract text content from the message + const textContent = getTextContent(message); + // Extract URLs which explicitly contain the http:// or https:// protocol - const urls = extractUrls(message.content, []); + const urls = extractUrls(textContent, []); return (
- +
{/* TODO(alexhancock): Re-enable link previews once styled well again */} @@ -25,4 +33,4 @@ export default function UserMessage({ message }) {
);
-}
+}
\ No newline at end of file
diff --git a/ui/desktop/src/hooks/useMessageStream.ts b/ui/desktop/src/hooks/useMessageStream.ts
new file mode 100644
index 000000000..a91d45878
--- /dev/null
+++ b/ui/desktop/src/hooks/useMessageStream.ts
@@ -0,0 +1,483 @@
+import { useState, useCallback, useEffect, useRef, useId } from 'react';
+import useSWR, { KeyedMutator } from 'swr';
+import { getSecretKey } from '../config';
+import { Message, createUserMessage, hasCompletedToolCalls } from '../types/message';
+
+// Event types for SSE stream
+type MessageEvent =
+  | { type: 'Message', message: Message }
+  | { type: 'Error', error: string }
+  | { type: 'Finish', reason: string };
+
+export interface UseMessageStreamOptions {
+  /**
+   * The API endpoint that accepts a `{ messages: Message[] }` object and returns
+   * a stream of messages. Defaults to `/api/chat/reply`.
+   */
+  api?: string;
+
+  /**
+   * A unique identifier for the chat. If not provided, a random one will be
+   * generated. When provided, the hook with the same `id` will
+   * have shared states across components.
+   */
+  id?: string;
+
+  /**
+   * Initial messages of the chat. Useful to load an existing chat history.
+   */
+  initialMessages?: Message[];
+
+  /**
+   * Initial input of the chat.
+   */
+  initialInput?: string;
+
+  /**
+   * Callback function to be called when a tool call is received.
+   * You can optionally return a result for the tool call.
+   */
+  onToolCall?: (toolCall: any) => void | Promise<void> | any;
+
+  /**
+   * Callback function to be called when the API response is received.
+   */
+  onResponse?: (response: Response) => void | Promise<void>;
+
+  /**
+   * Callback function to be called when the assistant message is finished streaming.
+   */
+  onFinish?: (message: Message, reason: string) => void;
+
+  /**
+   * Callback function to be called when an error is encountered.
+   */
+  onError?: (error: Error) => void;
+
+  /**
+   * HTTP headers to be sent with the API request.
+   */
+  headers?: Record<string, string> | Headers;
+
+  /**
+   * Extra body object to be sent with the API request.
+   */
+  body?: object;
+
+  /**
+   * Maximum number of sequential LLM calls (steps), e.g. when you use tool calls.
+   * Default is 1.
+   */
+  maxSteps?: number;
+}
+
+export interface UseMessageStreamHelpers {
+  /** Current messages in the chat */
+  messages: Message[];
+
+  /** The error object of the API request */
+  error: undefined | Error;
+
+  /**
+   * Append a user message to the chat list. This triggers the API call to fetch
+   * the assistant's response.
+   */
+  append: (message: Message | string) => Promise<void>;
+
+  /**
+   * Reload the last AI chat response for the given chat history.
+   */
+  reload: () => Promise<void>;
+
+  /**
+   * Abort the current request immediately.
+   */
+  stop: () => void;
+
+  /**
+   * Update the `messages` state locally.
+   */
+  setMessages: (messages: Message[] | ((messages: Message[]) => Message[])) => void;
+
+  /** The current value of the input */
+  input: string;
+
+  /** setState-powered method to update the input value */
+  setInput: React.Dispatch<React.SetStateAction<string>>;
+
+  /** An input/textarea-ready onChange handler to control the value of the input */
+  handleInputChange: (
+    e: React.ChangeEvent<HTMLInputElement> | React.ChangeEvent<HTMLTextAreaElement>
+  ) => void;
+
+  /** Form submission handler to automatically reset input and append a user message */
+  handleSubmit: (event?: { preventDefault?: () => void }) => void;
+
+  /** Whether the API request is in progress */
+  isLoading: boolean;
+
+  /** Add a tool result to a tool call */
+  addToolResult: ({ toolCallId, result }: { toolCallId: string; result: any }) => void;
+}
+
+/**
+ * Hook for streaming messages directly from the server using the native Goose message format
+ */
+export function useMessageStream({
+  api = '/api/chat/reply',
+  id,
+  initialMessages = [],
+  initialInput = '',
+  onToolCall,
+  onResponse,
+  onFinish,
+  onError,
+  headers,
+  body,
+  maxSteps = 1,
+}: UseMessageStreamOptions = {}): UseMessageStreamHelpers {
+  // Generate a unique id for the chat if not provided
+  const hookId = useId();
+  const idKey = id ?? hookId;
+  const chatKey = typeof api === 'string' ? [api, idKey] : idKey;
+
+  // Store the chat state in SWR, using the chatId as the key to share states
+  const { data: messages, mutate } = useSWR<Message[]>([chatKey, 'messages'], null, {
+    fallbackData: initialMessages,
+  });
+
+  // Keep the latest messages in a ref
+  const messagesRef = useRef<Message[]>(messages || []);
+  useEffect(() => {
+    messagesRef.current = messages || [];
+  }, [messages]);
+
+  // We store loading state in another hook to sync loading states across hook invocations
+  const { data: isLoading = false, mutate: mutateLoading } = useSWR<boolean>(
+    [chatKey, 'loading'],
+    null
+  );
+
+  const { data: error = undefined, mutate: setError } = useSWR<undefined | Error>(
+    [chatKey, 'error'],
+    null
+  );
+
+  // Abort controller to cancel the current API call
+  const abortControllerRef = useRef<AbortController | null>(null);
+
+  // Extra metadata for requests
+  const extraMetadataRef = useRef({
+    headers,
+    body,
+  });
+
+  useEffect(() => {
+    extraMetadataRef.current = {
+      headers,
+      body,
+    };
+  }, [headers, body]);
+
+  // Process the SSE stream from the server
+  const processMessageStream = useCallback(
+    async (response: Response, currentMessages: Message[]) => {
+      if (!response.body) {
+        throw new Error('Response body is empty');
+      }
+
+      const reader = response.body.getReader();
+      const decoder = new TextDecoder();
+      let buffer = '';
+
+      try {
+        while (true) {
+          const { done, value } = await reader.read();
+          if (done) break;
+
+          // Decode the chunk and add it to our buffer
+          buffer += decoder.decode(value, { stream: true });
+
+          // Process complete SSE events
+          const events = buffer.split('\n\n');
+          buffer = events.pop() || ''; // Keep the last incomplete event in the buffer
+
+          for (const event of events) {
+            if (event.startsWith('data: ')) {
+              try {
+                const data = event.slice(6); // Remove 'data: ' prefix
+                const parsedEvent = JSON.parse(data) as MessageEvent;
+
+                switch (parsedEvent.type) {
+                  case 'Message':
+                    // Update messages with the new message
+                    currentMessages = [...currentMessages, parsedEvent.message];
+                    mutate(currentMessages, false);
+                    break;
+
+                  case 'Error':
+                    throw new Error(parsedEvent.error);
+
+                  case 'Finish':
+                    // Call onFinish with the last message if available
+                    if (onFinish && currentMessages.length > 0) {
+                      const lastMessage = currentMessages[currentMessages.length - 1];
+                      onFinish(lastMessage, parsedEvent.reason);
+                    }
+                    break;
+                }
+              } catch (e) {
+                // Only malformed JSON is skippable; a server 'Error' event must end the stream
+                if (!(e instanceof SyntaxError)) throw e;
+                console.error('Error parsing SSE event:', e);
+                onError?.(e);
+              }
+            }
+          }
+        }
+      } catch (e) {
+        if (e instanceof Error && e.name !== 'AbortError') {
+          console.error('Error reading SSE stream:', e);
+        }
+        // Rethrow: sendRequest ignores aborts, reports other failures once
+        // (onError + error state), and must not auto-resubmit a partial stream.
+        throw e;
+      } finally {
+        reader.releaseLock();
+      }
+
+      return currentMessages;
+    },
+    [mutate, onFinish, onError]
+  );
+
+  // Send a request to the server
+  const sendRequest = useCallback(
+    async (requestMessages: Message[]) => {
+      try {
+        mutateLoading(true);
+        setError(undefined);
+
+        // Create abort controller
+        const abortController = new AbortController();
+        abortControllerRef.current = abortController;
+
+        // Log the request messages for debugging
+        console.log('Sending messages to server:', JSON.stringify(requestMessages, null, 2));
+
+        // Send request to the server
+        const response = await fetch(api, {
+          method: 'POST',
+          headers: {
+            'Content-Type': 'application/json',
+            'X-Secret-Key': getSecretKey(),
+            ...extraMetadataRef.current.headers,
+          },
+          body: JSON.stringify({
+            messages: requestMessages,
+            ...extraMetadataRef.current.body,
+          }),
+          signal: abortController.signal,
+        });
+
+        if (onResponse) {
+          await onResponse(response);
+        }
+
+        if (!response.ok) {
+          const text = await response.text();
+          throw new Error(text || `Error ${response.status}: ${response.statusText}`);
+        }
+
+        // Process the SSE stream
+        const updatedMessages = await processMessageStream(response, requestMessages);
+
+        // Auto-submit when all tool calls in the last assistant message have results
+        if (maxSteps > 1 && updatedMessages.length > requestMessages.length) {
+          const lastMessage = updatedMessages[updatedMessages.length - 1];
+          if (lastMessage.role === 'assistant' && hasCompletedToolCalls(lastMessage)) {
+            // Count trailing assistant messages to prevent infinite loops
+            let assistantCount = 0;
+            for (let i = updatedMessages.length - 1; i >= 0; i--) {
+              if (updatedMessages[i].role === 'assistant') {
+                assistantCount++;
+              } else {
+                break;
+              }
+            }
+
+            if (assistantCount < maxSteps) {
+              await sendRequest(updatedMessages);
+            }
+          }
+        }
+
+        abortControllerRef.current = null;
+      } catch (err) {
+        // Ignore abort errors as they are expected
+        if ((err as any).name === 'AbortError') {
+          abortControllerRef.current = null;
+          return;
+        }
+
+        if (onError && err instanceof Error) {
+          onError(err);
+        }
+
+        setError(err as Error);
+      } finally {
+        mutateLoading(false);
+      }
+    },
+    [api, processMessageStream, mutateLoading, setError, onResponse, onError, maxSteps]
+  );
+
+  // Append a new message and send request
+  const append = useCallback(
+    async (message: Message | string) => {
+      // If a string is passed, convert it to a Message object
+      const messageToAppend = typeof message === 'string'
+        ? createUserMessage(message)
+        : message;
+
+      console.log('Appending message:', JSON.stringify(messageToAppend, null, 2));
+
+      const currentMessages = [...messagesRef.current, messageToAppend];
+      mutate(currentMessages, false);
+      await sendRequest(currentMessages);
+    },
+    [mutate, sendRequest]
+  );
+
+  // Reload the last message
+  const reload = useCallback(async () => {
+    const currentMessages = messagesRef.current;
+    if (currentMessages.length === 0) {
+      return;
+    }
+
+    // Remove last assistant message if present
+    const lastMessage = currentMessages[currentMessages.length - 1];
+    const messagesToSend = lastMessage.role === 'assistant'
+      ? currentMessages.slice(0, -1)
+      : currentMessages;
+
+    await sendRequest(messagesToSend);
+  }, [sendRequest]);
+
+  // Stop the current request
+  const stop = useCallback(() => {
+    if (abortControllerRef.current) {
+      abortControllerRef.current.abort();
+      abortControllerRef.current = null;
+    }
+  }, []);
+
+  // Set messages directly
+  const setMessages = useCallback(
+    (messagesOrFn: Message[] | ((messages: Message[]) => Message[])) => {
+      if (typeof messagesOrFn === 'function') {
+        const newMessages = messagesOrFn(messagesRef.current);
+        mutate(newMessages, false);
+        messagesRef.current = newMessages;
+      } else {
+        mutate(messagesOrFn, false);
+        messagesRef.current = messagesOrFn;
+      }
+    },
+    [mutate]
+  );
+
+  // Input state and handlers
+  const [input, setInput] = useState(initialInput);
+
+  const handleInputChange = useCallback(
+    (e: React.ChangeEvent<HTMLInputElement> | React.ChangeEvent<HTMLTextAreaElement>) => {
+      setInput(e.target.value);
+    },
+    []
+  );
+
+  const handleSubmit = useCallback(
+    async (event?: { preventDefault?: () => void }) => {
+      event?.preventDefault?.();
+      if (!input.trim()) return;
+
+      console.log('handleSubmit called with input:', input);
+      setInput(''); // Clear before awaiting so text typed while streaming is not wiped
+      await append(input);
+    },
+    [input, append]
+  );
+
+  // Add tool result to a message
+  const addToolResult = useCallback(
+    ({ toolCallId, result }: { toolCallId: string; result: any }) => {
+      const currentMessages = messagesRef.current;
+
+      // Find the last assistant message with the tool call
+      let lastAssistantIndex = -1;
+      for (let i = currentMessages.length - 1; i >= 0; i--) {
+        if (currentMessages[i].role === 'assistant') {
+          const toolRequests = currentMessages[i].content.filter(
+            content => 'ToolRequest' in content &&
+                      content.ToolRequest.id === toolCallId
+          );
+          if (toolRequests.length > 0) {
+            lastAssistantIndex = i;
+            break;
+          }
+        }
+      }
+
+      if (lastAssistantIndex === -1) return;
+
+      // Create a tool response message (`Ok` casing matches the ToolResponse type)
+      const toolResponseMessage: Message = {
+        role: 'user' as const,
+        created: Math.floor(Date.now() / 1000),
+        content: [
+          {
+            ToolResponse: {
+              id: toolCallId,
+              tool_result: {
+                Ok: Array.isArray(result) ? result : [{ text: String(result), priority: 0 }],
+              },
+            },
+          },
+        ],
+      };
+
+      // Insert the tool response after the assistant message
+      const updatedMessages = [
+        ...currentMessages.slice(0, lastAssistantIndex + 1),
+        toolResponseMessage,
+        ...currentMessages.slice(lastAssistantIndex + 1),
+      ];
+
+      mutate(updatedMessages, false);
+      messagesRef.current = updatedMessages;
+
+      // Auto-submit if we have tool results
+      if (maxSteps > 1) {
+        sendRequest(updatedMessages);
+      }
+    },
+    [mutate, maxSteps, sendRequest]
+  );
+
+  return {
+    messages: messages || [],
+    error,
+    append,
+    reload,
+    stop,
+    setMessages,
+    input,
+    setInput,
+    handleInputChange,
+    handleSubmit,
+    isLoading: isLoading || false,
+    addToolResult,
+  };
+}
\ No newline at end of file
diff --git a/ui/desktop/src/types/message.ts b/ui/desktop/src/types/message.ts
new file mode 100644
index 000000000..eb92fe820
--- /dev/null
+++ b/ui/desktop/src/types/message.ts
@@ -0,0 +1,199 @@
+/**
+ * Message types that match the Rust message structures
+ * for direct serialization between client and server
+ */
+
+export type Role = 'user' | 'assistant';
+
+export interface TextContent {
+  text: string;
+  annotations?: any;
+}
+
+export interface ImageContent {
+  data: string;
+  mime_type: string;
+  annotations?: any;
+}
+
+export interface ToolCall {
+  name: string;
+  arguments: any;
+}
+
+export interface Content {
+  text?: string;
+  priority?: number;
+  // Add other content types as needed
+}
+
+export interface ToolRequest {
+  id: string;
+  tool_call: {
+    Ok?: ToolCall;
+    Err?: string;
+  };
+}
+
+export interface ToolResponse {
+  id: string;
+  tool_result: {
+    Ok?: Content[];
+    Err?: string;
+  };
+}
+
+export interface ToolConfirmationRequest {
+  id: string;
+  tool_name: string;
+  arguments: any;
+  prompt?: string;
+}
+
+export type MessageContent =
+  | { Text: TextContent }
+  | { Image: ImageContent }
+  | { ToolRequest: ToolRequest }
+  | { ToolResponse: ToolResponse }
+  | { ToolConfirmationRequest: ToolConfirmationRequest };
+
+export interface Message {
+  id?: string;
+  role: Role;
+  created: number;
+  content: MessageContent[];
+}
+
+// Helper functions to create messages
+export function createUserMessage(text: string): Message {
+  return {
+    id: generateId(),
+    role: 'user',
+    created: Math.floor(Date.now() / 1000),
+    content: [{ Text: { text } }],
+  };
+}
+
+export function createAssistantMessage(text: string): Message {
+  return {
+    id: generateId(),
+    role: 'assistant',
+    created: Math.floor(Date.now() / 1000),
+    content: [{ Text: { text } }],
+  };
+}
+
+export function createToolRequestMessage(id: string, toolName: string, args: any): Message {
+  return {
+    id: generateId(),
+    role: 'assistant',
+    created: Math.floor(Date.now() / 1000),
+    content: [
+      {
+        ToolRequest: {
+          id,
+          tool_call: {
+            Ok: {
+              // Using Ok to match the server format
+              name: toolName,
+              arguments: args,
+            },
+          },
+        },
+      },
+    ],
+  };
+}
+
+export function createToolResponseMessage(id: string, result: Content[]): Message {
+  return {
+    id: generateId(),
+    role: 'user',
+    created: Math.floor(Date.now() / 1000),
+    content: [
+      {
+        ToolResponse: {
+          id,
+          tool_result: {
+            Ok: result, // Using Ok to match the server format
+          },
+        },
+      },
+    ],
+  };
+}
+
+export function createToolErrorResponseMessage(id: string, error: string): Message {
+  return {
+    id: generateId(),
+    role: 'user',
+    created: Math.floor(Date.now() / 1000),
+    content: [
+      {
+        ToolResponse: {
+          id,
+          tool_result: {
+            Err: error, // Using Err to match the server format
+          },
+        },
+      },
+    ],
+  };
+}
+
+// Generate a unique ID for messages
+function generateId(): string {
+  return Math.random().toString(36).substring(2, 10);
+}
+
+// Helper functions to extract content from messages
+export function getTextContent(message: Message): string {
+  return message.content
+    .filter((content): content is { Text: TextContent } => 'Text' in content)
+    .map((content) => content.Text.text)
+    .join('\n');
+}
+
+export function getToolRequests(message: Message): ToolRequest[] {
+  // Try both casing variations
+  return message.content
+    .filter(
+      (content): content is { ToolRequest: ToolRequest } | { toolRequest: ToolRequest } =>
+        'ToolRequest' in content || 'toolRequest' in content
+    )
+    .map((content) => {
+      if ('ToolRequest' in content) {
+        return content.ToolRequest;
+      } else {
+        // Handle potential lowercase property name
+        return (content as any).toolRequest;
+      }
+    });
+}
+
+export function getToolResponses(message: Message): ToolResponse[] {
+  // Try both casing variations
+  return message.content
+    .filter(
+      (content): content is { ToolResponse: ToolResponse } | { toolResponse: ToolResponse } =>
+        'ToolResponse' in content || 'toolResponse' in content
+    )
+    .map((content) => {
+      if ('ToolResponse' in content) {
+        return content.ToolResponse;
+      } else {
+        // Handle potential lowercase property name
+        return (content as any).toolResponse;
+      }
+    });
+}
+
+export function hasCompletedToolCalls(message: Message): boolean {
+  const toolRequests = getToolRequests(message);
+  if (toolRequests.length === 0) return false;
+
+  // For now, we'll assume all tool calls are completed when this is checked
+  // In a real implementation, you'd need to check if all tool requests have responses
+  // by looking through subsequent messages
+  return true;
+}
diff --git a/ui/desktop/src/utils/generateId.ts b/ui/desktop/src/utils/generateId.ts
new file mode 100644
index 000000000..d4f664e75
--- /dev/null
+++ b/ui/desktop/src/utils/generateId.ts
@@ -0,0 +1,7 @@
+/**
+ * Generate a random ID string
+ * @returns A random string ID
+ */
+export function generateId(): string {
+  return Math.random().toString(36).substring(2, 10);
+}
\ No newline at end of file