Skip to content

Commit

Permalink
feat(frontend): Wait for events before rendering messages (#4994)
Browse files Browse the repository at this point in the history
Co-authored-by: mamoodi <[email protected]>
  • Loading branch information
amanape and mamoodi authored Nov 14, 2024
1 parent fac5237 commit 01cacf7
Show file tree
Hide file tree
Showing 4 changed files with 204 additions and 25 deletions.
93 changes: 93 additions & 0 deletions frontend/__tests__/hooks/use-rate.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
import { act, renderHook } from "@testing-library/react";
import { afterEach, beforeEach, describe, expect, it, vi } from "vitest";
import { useRate } from "#/utils/use-rate";

describe("useRate", () => {
beforeEach(() => {
vi.useFakeTimers();
});

afterEach(() => {
vi.useRealTimers();
});

it("should initialize", () => {
const { result } = renderHook(() => useRate());

expect(result.current.items).toHaveLength(0);
expect(result.current.rate).toBeNull();
expect(result.current.lastUpdated).toBeNull();
expect(result.current.isUnderThreshold).toBe(true);
});

it("should handle the case of a single element", () => {
const { result } = renderHook(() => useRate());

act(() => {
result.current.record(123);
});

expect(result.current.items).toHaveLength(1);
expect(result.current.lastUpdated).not.toBeNull();
});

it("should return the difference between the last two elements", () => {
const { result } = renderHook(() => useRate());

vi.setSystemTime(500);
act(() => {
result.current.record(4);
});

vi.advanceTimersByTime(500);
act(() => {
result.current.record(9);
});

expect(result.current.items).toHaveLength(2);
expect(result.current.rate).toBe(5);
expect(result.current.lastUpdated).toBe(1000);
});

it("should update isUnderThreshold after [threshold]ms of no activity", () => {
const { result } = renderHook(() => useRate({ threshold: 500 }));

expect(result.current.isUnderThreshold).toBe(true);

act(() => {
// not sure if fake timers is buggy with intervals,
// but I need to call it twice to register
vi.advanceTimersToNextTimer();
vi.advanceTimersToNextTimer();
});

expect(result.current.isUnderThreshold).toBe(false);
});

it("should return an isUnderThreshold boolean", () => {
const { result } = renderHook(() => useRate({ threshold: 500 }));

vi.setSystemTime(500);
act(() => {
result.current.record(400);
});
act(() => {
result.current.record(1000);
});

expect(result.current.isUnderThreshold).toBe(false);

act(() => {
result.current.record(1500);
});

expect(result.current.isUnderThreshold).toBe(true);

act(() => {
vi.advanceTimersToNextTimer();
vi.advanceTimersToNextTimer();
});

expect(result.current.isUnderThreshold).toBe(false);
});
});
55 changes: 31 additions & 24 deletions frontend/src/components/chat-interface.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@ const isErrorMessage = (
): message is ErrorMessage => "error" in message;

export function ChatInterface() {
const { send } = useWsClient();
const { send, isLoadingMessages } = useWsClient();

const dispatch = useDispatch();
const scrollRef = React.useRef<HTMLDivElement>(null);
const { scrollDomToBottom, onChatBodyScroll, hitBottom } =
Expand Down Expand Up @@ -101,30 +102,36 @@ export function ChatInterface() {
onScroll={(e) => onChatBodyScroll(e.currentTarget)}
className="flex flex-col grow overflow-y-auto overflow-x-hidden px-4 pt-4 gap-2"
>
{messages.map((message, index) =>
isErrorMessage(message) ? (
<ErrorMessage
key={index}
id={message.id}
message={message.message}
/>
) : (
<ChatMessage
key={index}
type={message.sender}
message={message.content}
>
{message.imageUrls.length > 0 && (
<ImageCarousel size="small" images={message.imageUrls} />
)}
{messages.length - 1 === index &&
message.sender === "assistant" &&
curAgentState === AgentState.AWAITING_USER_CONFIRMATION && (
<ConfirmationButtons />
)}
</ChatMessage>
),
{isLoadingMessages && (
<div className="flex justify-center">
<div className="w-6 h-6 border-2 border-t-[4px] border-primary-500 rounded-full animate-spin" />
</div>
)}
{!isLoadingMessages &&
messages.map((message, index) =>
isErrorMessage(message) ? (
<ErrorMessage
key={index}
id={message.id}
message={message.message}
/>
) : (
<ChatMessage
key={index}
type={message.sender}
message={message.content}
>
{message.imageUrls.length > 0 && (
<ImageCarousel size="small" images={message.imageUrls} />
)}
{messages.length - 1 === index &&
message.sender === "assistant" &&
curAgentState === AgentState.AWAITING_USER_CONFIRMATION && (
<ConfirmationButtons />
)}
</ChatMessage>
),
)}
</div>

<div className="flex flex-col gap-[6px] px-4 pb-4">
Expand Down
14 changes: 13 additions & 1 deletion frontend/src/context/ws-client-provider.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@ import ActionType from "#/types/ActionType";
import EventLogger from "#/utils/event-logger";
import AgentState from "#/types/AgentState";
import { handleAssistantMessage } from "#/services/actions";
import { useRate } from "#/utils/use-rate";

const isOpenHandsMessage = (event: Record<string, unknown>) =>
event.action === "message";

const RECONNECT_RETRIES = 5;

Expand All @@ -17,12 +21,14 @@ export enum WsClientProviderStatus {

interface UseWsClient {
status: WsClientProviderStatus;
isLoadingMessages: boolean;
events: Record<string, unknown>[];
send: (event: Record<string, unknown>) => void;
}

const WsClientContext = React.createContext<UseWsClient>({
status: WsClientProviderStatus.STOPPED,
isLoadingMessages: true,
events: [],
send: () => {
throw new Error("not connected");
Expand Down Expand Up @@ -51,6 +57,8 @@ export function WsClientProvider({
const [events, setEvents] = React.useState<Record<string, unknown>[]>([]);
const [retryCount, setRetryCount] = React.useState(RECONNECT_RETRIES);

const messageRateHandler = useRate({ threshold: 500 });

function send(event: Record<string, unknown>) {
if (!wsRef.current) {
EventLogger.error("WebSocket is not connected.");
Expand All @@ -71,6 +79,9 @@ export function WsClientProvider({

function handleMessage(messageEvent: MessageEvent) {
const event = JSON.parse(messageEvent.data);
if (isOpenHandsMessage(event)) {
messageRateHandler.record(new Date().getTime());
}
setEvents((prevEvents) => [...prevEvents, event]);
if (event.extras?.agent_state === AgentState.INIT) {
setStatus(WsClientProviderStatus.ACTIVE);
Expand Down Expand Up @@ -177,10 +188,11 @@ export function WsClientProvider({
const value = React.useMemo<UseWsClient>(
() => ({
status,
isLoadingMessages: messageRateHandler.isUnderThreshold,
events,
send,
}),
[status, events],
[status, messageRateHandler.isUnderThreshold, events],
);

return (
Expand Down
67 changes: 67 additions & 0 deletions frontend/src/utils/use-rate.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
import React from "react";

interface UseRateProps {
threshold: number;
}

const DEFAULT_CONFIG: UseRateProps = { threshold: 1000 };

export const useRate = (config = DEFAULT_CONFIG) => {
const [items, setItems] = React.useState<number[]>([]);
const [rate, setRate] = React.useState<number | null>(null);
const [lastUpdated, setLastUpdated] = React.useState<number | null>(null);
const [isUnderThreshold, setIsUnderThreshold] = React.useState(true);

/**
* Record an entry in order to calculate the rate
* @param entry Entry to record
*
* @example
* record(new Date().getTime());
*/
const record = (entry: number) => {
setItems((prev) => [...prev, entry]);
setLastUpdated(new Date().getTime());
};

/**
* Update the rate based on the last two entries (if available)
*/
const updateRate = () => {
if (items.length > 1) {
const newRate = items[items.length - 1] - items[items.length - 2];
setRate(newRate);

if (newRate <= config.threshold) setIsUnderThreshold(true);
else setIsUnderThreshold(false);
}
};

React.useEffect(() => {
updateRate();
}, [items]);

React.useEffect(() => {
// Set up an interval to check if the time since the last update exceeds the threshold
// If it does, set isUnderThreshold to false, otherwise set it to true
// This ensures that the component can react to periods of inactivity
const intervalId = setInterval(() => {
if (lastUpdated !== null) {
const timeSinceLastUpdate = new Date().getTime() - lastUpdated;
setIsUnderThreshold(timeSinceLastUpdate <= config.threshold);
} else {
setIsUnderThreshold(false);
}
}, config.threshold);

return () => clearInterval(intervalId);
}, [lastUpdated, config.threshold]);

return {
items,
rate,
lastUpdated,
isUnderThreshold,
record,
};
};

0 comments on commit 01cacf7

Please sign in to comment.