Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(frontend): Wait for events before rendering messages #4994

Merged
merged 11 commits into from
Nov 14, 2024
93 changes: 93 additions & 0 deletions frontend/__tests__/hooks/use-rate.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
import { act, renderHook } from "@testing-library/react";
import { afterEach, beforeEach, describe, expect, it, vi } from "vitest";
import { useRate } from "#/utils/use-rate";

describe("useRate", () => {
beforeEach(() => {
vi.useFakeTimers();
});

afterEach(() => {
vi.useRealTimers();
});

it("should initialize", () => {
const { result } = renderHook(() => useRate());

expect(result.current.items).toHaveLength(0);
expect(result.current.rate).toBeNull();
expect(result.current.lastUpdated).toBeNull();
expect(result.current.isUnderThreshold).toBe(true);
});

it("should handle the case of a single element", () => {
const { result } = renderHook(() => useRate());

act(() => {
result.current.record(123);
});

expect(result.current.items).toHaveLength(1);
expect(result.current.lastUpdated).not.toBeNull();
});

it("should return the difference between the last two elements", () => {
const { result } = renderHook(() => useRate());

vi.setSystemTime(500);
act(() => {
result.current.record(4);
});

vi.advanceTimersByTime(500);
act(() => {
result.current.record(9);
});

expect(result.current.items).toHaveLength(2);
expect(result.current.rate).toBe(5);
expect(result.current.lastUpdated).toBe(1000);
});

it("should update isUnderThreshold after [threshold]ms of no activity", () => {
const { result } = renderHook(() => useRate({ threshold: 500 }));

expect(result.current.isUnderThreshold).toBe(true);

act(() => {
// not sure if fake timers is buggy with intervals,
// but I need to call it twice to register
vi.advanceTimersToNextTimer();
vi.advanceTimersToNextTimer();
});

expect(result.current.isUnderThreshold).toBe(false);
});

it("should return an isUnderThreshold boolean", () => {
const { result } = renderHook(() => useRate({ threshold: 500 }));

vi.setSystemTime(500);
act(() => {
result.current.record(400);
});
act(() => {
result.current.record(1000);
});

expect(result.current.isUnderThreshold).toBe(false);

act(() => {
result.current.record(1500);
});

expect(result.current.isUnderThreshold).toBe(true);

act(() => {
vi.advanceTimersToNextTimer();
vi.advanceTimersToNextTimer();
});

expect(result.current.isUnderThreshold).toBe(false);
});
});
55 changes: 31 additions & 24 deletions frontend/src/components/chat-interface.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@ const isErrorMessage = (
): message is ErrorMessage => "error" in message;

export function ChatInterface() {
const { send } = useWsClient();
const { send, isLoadingMessages } = useWsClient();

const dispatch = useDispatch();
const scrollRef = React.useRef<HTMLDivElement>(null);
const { scrollDomToBottom, onChatBodyScroll, hitBottom } =
Expand Down Expand Up @@ -101,30 +102,36 @@ export function ChatInterface() {
onScroll={(e) => onChatBodyScroll(e.currentTarget)}
className="flex flex-col grow overflow-y-auto overflow-x-hidden px-4 pt-4 gap-2"
>
{messages.map((message, index) =>
isErrorMessage(message) ? (
<ErrorMessage
key={index}
id={message.id}
message={message.message}
/>
) : (
<ChatMessage
key={index}
type={message.sender}
message={message.content}
>
{message.imageUrls.length > 0 && (
<ImageCarousel size="small" images={message.imageUrls} />
)}
{messages.length - 1 === index &&
message.sender === "assistant" &&
curAgentState === AgentState.AWAITING_USER_CONFIRMATION && (
<ConfirmationButtons />
)}
</ChatMessage>
),
{isLoadingMessages && (
<div className="flex justify-center">
<div className="w-6 h-6 border-2 border-t-[4px] border-primary-500 rounded-full animate-spin" />
</div>
)}
{!isLoadingMessages &&
messages.map((message, index) =>
isErrorMessage(message) ? (
<ErrorMessage
key={index}
id={message.id}
message={message.message}
/>
) : (
<ChatMessage
key={index}
type={message.sender}
message={message.content}
>
{message.imageUrls.length > 0 && (
<ImageCarousel size="small" images={message.imageUrls} />
)}
{messages.length - 1 === index &&
message.sender === "assistant" &&
curAgentState === AgentState.AWAITING_USER_CONFIRMATION && (
<ConfirmationButtons />
)}
</ChatMessage>
),
)}
</div>

<div className="flex flex-col gap-[6px] px-4 pb-4">
Expand Down
14 changes: 13 additions & 1 deletion frontend/src/context/ws-client-provider.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@ import ActionType from "#/types/ActionType";
import EventLogger from "#/utils/event-logger";
import AgentState from "#/types/AgentState";
import { handleAssistantMessage } from "#/services/actions";
import { useRate } from "#/utils/use-rate";

const isOpenHandsMessage = (event: Record<string, unknown>) =>
event.action === "message";

const RECONNECT_RETRIES = 5;

Expand All @@ -17,12 +21,14 @@ export enum WsClientProviderStatus {

interface UseWsClient {
status: WsClientProviderStatus;
isLoadingMessages: boolean;
events: Record<string, unknown>[];
send: (event: Record<string, unknown>) => void;
}

const WsClientContext = React.createContext<UseWsClient>({
status: WsClientProviderStatus.STOPPED,
isLoadingMessages: true,
events: [],
send: () => {
throw new Error("not connected");
Expand Down Expand Up @@ -51,6 +57,8 @@ export function WsClientProvider({
const [events, setEvents] = React.useState<Record<string, unknown>[]>([]);
const [retryCount, setRetryCount] = React.useState(RECONNECT_RETRIES);

const messageRateHandler = useRate({ threshold: 500 });

function send(event: Record<string, unknown>) {
if (!wsRef.current) {
EventLogger.error("WebSocket is not connected.");
Expand All @@ -71,6 +79,9 @@ export function WsClientProvider({

function handleMessage(messageEvent: MessageEvent) {
const event = JSON.parse(messageEvent.data);
if (isOpenHandsMessage(event)) {
messageRateHandler.record(new Date().getTime());
}
setEvents((prevEvents) => [...prevEvents, event]);
if (event.extras?.agent_state === AgentState.INIT) {
setStatus(WsClientProviderStatus.ACTIVE);
Expand Down Expand Up @@ -177,10 +188,11 @@ export function WsClientProvider({
const value = React.useMemo<UseWsClient>(
() => ({
status,
isLoadingMessages: messageRateHandler.isUnderThreshold,
events,
send,
}),
[status, events],
[status, messageRateHandler.isUnderThreshold, events],
);

return (
Expand Down
67 changes: 67 additions & 0 deletions frontend/src/utils/use-rate.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
import React from "react";

interface UseRateProps {
threshold: number;
}

const DEFAULT_CONFIG: UseRateProps = { threshold: 1000 };

export const useRate = (config = DEFAULT_CONFIG) => {
const [items, setItems] = React.useState<number[]>([]);
const [rate, setRate] = React.useState<number | null>(null);
const [lastUpdated, setLastUpdated] = React.useState<number | null>(null);
const [isUnderThreshold, setIsUnderThreshold] = React.useState(true);

/**
* Record an entry in order to calculate the rate
* @param entry Entry to record
*
* @example
* record(new Date().getTime());
*/
const record = (entry: number) => {
setItems((prev) => [...prev, entry]);
setLastUpdated(new Date().getTime());
};

/**
* Update the rate based on the last two entries (if available)
*/
const updateRate = () => {
if (items.length > 1) {
const newRate = items[items.length - 1] - items[items.length - 2];
setRate(newRate);

if (newRate <= config.threshold) setIsUnderThreshold(true);
else setIsUnderThreshold(false);
}
};

React.useEffect(() => {
updateRate();
}, [items]);

React.useEffect(() => {
// Set up an interval to check if the time since the last update exceeds the threshold
// If it does, set isUnderThreshold to false, otherwise set it to true
// This ensures that the component can react to periods of inactivity
const intervalId = setInterval(() => {
if (lastUpdated !== null) {
const timeSinceLastUpdate = new Date().getTime() - lastUpdated;
setIsUnderThreshold(timeSinceLastUpdate <= config.threshold);
} else {
setIsUnderThreshold(false);
}
}, config.threshold);

return () => clearInterval(intervalId);
}, [lastUpdated, config.threshold]);

return {
items,
rate,
lastUpdated,
isUnderThreshold,
record,
};
};
Loading