diff --git a/components/Chat/Chat.tsx b/components/Chat/Chat.tsx index 2e3aa069..0826097f 100644 --- a/components/Chat/Chat.tsx +++ b/components/Chat/Chat.tsx @@ -1,22 +1,21 @@ -import { AI_DISCLAIMER, AI_SEARCH_UNSUBMITTED } from "@/lib/constants/common"; -import { MessageTypes, StreamingMessage } from "@/types/components/chat"; -import React, { useEffect, useState } from "react"; import { StyledResponseActions, StyledResponseDisclaimer, StyledUnsubmitted, } from "@/components/Chat/Response/Response.styled"; import { defaultState, useSearchState } from "@/context/search-context"; +import { AI_DISCLAIMER, AI_SEARCH_UNSUBMITTED } from "@/lib/constants/common"; +import React, { useEffect, useState } from "react"; -import Announcement from "@/components/Shared/Announcement"; -import { Button } from "@nulib/design-system"; import ChatFeedback from "@/components/Chat/Feedback/Feedback"; import ChatResponse from "@/components/Chat/Response/Response"; +import Announcement from "@/components/Shared/Announcement"; import Container from "@/components/Shared/Container"; -import { Work } from "@nulib/dcapi-types"; -import { prepareQuestion } from "@/lib/chat-helpers"; import useChatSocket from "@/hooks/useChatSocket"; import useQueryParams from "@/hooks/useQueryParams"; +import { prepareQuestion } from "@/lib/chat-helpers"; +import { Work } from "@nulib/dcapi-types"; +import { Button } from "@nulib/design-system"; const Chat = ({ totalResults, diff --git a/components/Chat/Response/Response.tsx b/components/Chat/Response/Response.tsx index 1d64e6d6..ef41e3bf 100644 --- a/components/Chat/Response/Response.tsx +++ b/components/Chat/Response/Response.tsx @@ -6,12 +6,11 @@ import { StyledResponseWrapper, } from "./Response.styled"; -import BouncingLoader from "@/components/Shared/BouncingLoader"; -import Container from "@/components/Shared/Container"; import ResponseImages from "@/components/Chat/Response/Images"; import ResponseMarkdown from "@/components/Chat/Response/Markdown"; +import BouncingLoader from "@/components/Shared/BouncingLoader"; +import Container from "@/components/Shared/Container"; import { StreamingMessage } from "@/types/components/chat"; -import { Work } from "@nulib/dcapi-types"; interface ChatResponseProps { conversationRef?: string; @@ -35,7 +34,7 @@ const ChatResponse: React.FC = ({ const { type } = message; if (type === "token") { - setStreamedMessage((prev) => prev + message?.message); + setStreamedMessage((prev) => prev + message.message); } if (type === "answer") { @@ -51,19 +50,18 @@ const ChatResponse: React.FC = ({ } if (type === "tool_start") { - // @ts-ignore - const { tool, input } = message?.message; + const { tool, input } = message.message; let interstitialMessage = ""; switch (tool) { case "discover_fields": interstitialMessage = "Discovering fields"; break; case "search": - interstitialMessage = `Searching for: ${input?.query}`; + interstitialMessage = `Searching for: ${input.query}`; break; case "aggregate": console.log(`aggregate input`, input); - interstitialMessage = `Aggregating ${input?.agg_field} by ${input?.term_field} ${input?.term}`; + interstitialMessage = `Aggregating ${input.agg_field} by ${input.term_field} ${input.term}`; break; default: console.warn("Unknown tool_start message", message); @@ -87,7 +85,7 @@ const ChatResponse: React.FC = ({ <> {prev} diff --git a/types/components/chat.ts b/types/components/chat.ts index ca22a2fb..a1d77799 100644 --- a/types/components/chat.ts +++ b/types/components/chat.ts @@ -1,62 +1,90 @@ import { Work } from "@nulib/dcapi-types"; -export type MessageTypes = - | "answer" - | "aggregation_result" - | "final" - | "final_message" - | "search_result" - | "start" - | "stop" - | "token" - | "tool_start"; - -type MessageAggregationResult = { - buckets: [ - { - key: string; - doc_count: number; - }, - ]; - doc_count_error_upper_bound: number; - sum_other_doc_count: number; -}; - -type MessageSearchResult = Array; - -type MessageModel = { - model: string; -}; - -type MessageTool = - | { - tool: "discover_fields"; - input: {}; - } - | { - tool: "search"; - input: { - query: string; - }; - } - | { - tool: "aggregate"; - input: { agg_field: string; term_field: string; term: string }; - }; - -type MessageShape = - | string - | MessageAggregationResult - | MessageSearchResult - | MessageModel - | MessageTool; - -export type StreamingMessage = { +export type Ref = { ref: string; - message?: MessageShape; - type: MessageTypes; }; +export type AggregationResultMessage = { + type: "aggregation_result"; + message: { + buckets: [ + { + key: string; + doc_count: number; + }, + ]; + doc_count_error_upper_bound: number; + sum_other_doc_count: number; + }; +}; + +export type AgentFinalMessage = { + type: "final"; + message: string; +}; + +export type LLMAnswerMessage = { + type: "answer"; + message: string; +}; + +export type LLMFinalMessage = { + type: "final_message"; +}; + +export type LLMTokenMessage = { + type: "token"; + message: string; +}; + +export type LLMStopMessage = { + type: "stop"; +}; + +export type SearchResultMessage = { + type: "search_result"; + message: Array; +}; + +export type StartMessage = { + type: "start"; + message: { + model: string; + }; +}; + +export type ToolStartMessage = { + type: "tool_start"; + message: + | { + tool: "discover_fields"; + input: {}; + } + | { + tool: "search"; + input: { + query: string; + }; + } + | { + tool: "aggregate"; + input: { agg_field: string; term_field: string; term: string }; + }; +}; + +export type StreamingMessage = Ref & + ( + | AggregationResultMessage + | AgentFinalMessage + | LLMAnswerMessage + | LLMFinalMessage + | LLMTokenMessage + | LLMStopMessage + | SearchResultMessage + | StartMessage + | ToolStartMessage + ); + export type ChatConfig = { auth: string; endpoint: string;