Skip to content

Commit

Permalink
Merge pull request #2670 from stakwork/ask-question-for-selected
Browse files Browse the repository at this point in the history
feat: context ai search
  • Loading branch information
Rassl authored Feb 7, 2025
2 parents cb7a91c + bb8589d commit f714834
Show file tree
Hide file tree
Showing 5 changed files with 91 additions and 17 deletions.
57 changes: 52 additions & 5 deletions src/components/App/SideBar/AiSearch/index.tsx
Original file line number Diff line number Diff line change
@@ -1,37 +1,78 @@
import { FormProvider, useForm } from 'react-hook-form'
import { ClipLoader } from 'react-spinners'
import styled from 'styled-components'
import { Flex } from '~/components/common/Flex'
import SearchIcon from '~/components/Icons/SearchIcon'
import { SearchBar } from '~/components/SearchBar'
import { Flex } from '~/components/common/Flex'
import { useAiSummaryStore, useHasAiChatsResponseLoading } from '~/stores/useAiSummaryStore'
import { useDataStore } from '~/stores/useDataStore'
import { useGraphStore } from '~/stores/useGraphStore'
import { useSchemaStore } from '~/stores/useSchemaStore'
import { useUserStore } from '~/stores/useUserStore'
import { NodeExtended } from '~/types'
import { colors } from '~/utils'

export const AiSearch = () => {
export const AiSearch = ({ contextSearch }: { contextSearch?: boolean }) => {
const form = useForm<{ search: string }>({ mode: 'onChange' })
const { setAbortRequests } = useDataStore((s) => s)
const { setBudget } = useUserStore((s) => s)
const { reset } = form
const fetchAIData = useAiSummaryStore((s) => s.fetchAIData)
const { selectedNode } = useGraphStore((s) => s)
const normalizedSchemasByType = useSchemaStore((s) => s.normalizedSchemasByType)

const isLoading = useHasAiChatsResponseLoading()

let context = ''

function getNodeKeyDetails(nodeKey: string | undefined, currentSelectedNode: NodeExtended) {
let nodeKeyContextString = ''

if (!nodeKey) {
return nodeKeyContextString
}

const nodeKeyArr = nodeKey.split('-')

for (let i = 0; i < nodeKeyArr.length; i += 1) {
const key = nodeKeyArr[i]
const propertyValue = currentSelectedNode.properties ? currentSelectedNode.properties[key] : ''
const comma = i === nodeKeyArr.length - 1 ? '' : ','

nodeKeyContextString = `${nodeKeyContextString} ${key} - ${propertyValue}${comma}`
}

return nodeKeyContextString
}

const handleSubmit = form.handleSubmit(({ search }) => {
if (search.trim() === '') {
return
}

fetchAIData(setBudget, setAbortRequests, search)
if (contextSearch && selectedNode) {
const nodeType = selectedNode.node_type

const nodeKey = normalizedSchemasByType[nodeType].node_key

const nodeKeyContextString = getNodeKeyDetails(nodeKey, selectedNode)

context = `**${nodeType}: ${nodeKeyContextString}**`
}

fetchAIData(setBudget, setAbortRequests, search, undefined, context)
reset({ search: '' })
})

return (
<AiSearchWrapper>
<FormProvider {...form}>
<Search>
<SearchBar loading={isLoading} onSubmit={handleSubmit} placeholder="Ask follow-up" />
<SearchBar
loading={isLoading}
onSubmit={handleSubmit}
placeholder={contextSearch ? 'Ask a question' : 'Ask follow-up'}
/>
<InputButton
data-testid="search-ai_action_icon"
onClick={() => {
Expand All @@ -42,7 +83,13 @@ export const AiSearch = () => {
handleSubmit()
}}
>
{!isLoading ? <SearchIcon /> : <StyledClipLoader color={colors.lightGray} data-testid="loader" size="20" />}
{contextSearch && <SearchIcon />}
{!contextSearch &&
(!isLoading ? (
<SearchIcon />
) : (
<StyledClipLoader color={colors.lightGray} data-testid="loader" size="20" />
))}
</InputButton>
</Search>
</FormProvider>
Expand Down
15 changes: 14 additions & 1 deletion src/components/App/SideBar/SidebarSubView/__tests__/index.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,22 @@ Object.defineProperty(window, 'matchMedia', {
})

jest.mock('react-hook-form', () => ({
...jest.requireActual('react-hook-form'),
useFormContext: jest.fn(() => ({
setValue: jest.fn(),
register: jest.fn(),
watch: jest.fn(() => ''),
})),
useForm: jest.fn(() => ({
register: jest.fn(),
handleSubmit: jest.fn((fn) => (event) => fn(event)),
reset: jest.fn((fn) => () => fn()),
})),
}))

jest.mock('react-router-dom', () => ({
...jest.requireActual('react-router-dom'),
useNavigate: jest.fn(),
}))

jest.mock('~/stores/useDataStore', () => ({
Expand Down Expand Up @@ -63,7 +76,7 @@ const mockSelectedNode = {
describe('Test SideBarSubView', () => {
beforeEach(() => {
jest.clearAllMocks()
useDataStoreMock.mockReturnValue({ setTeachMe: jest.fn(), showTeachMe: false })
useDataStoreMock.mockReturnValue({ setTeachMe: jest.fn(), showTeachMe: false, setAbortRequests: jest.fn() })
useGraphStoreMock.mockReturnValue({ setSelectedNode: jest.fn() })
useSelectedNodeMock.mockReturnValue(mockSelectedNode)
useAppStoreMock.mockReturnValue({ setSidebarOpen: jest.fn() })
Expand Down
17 changes: 13 additions & 4 deletions src/components/App/SideBar/SidebarSubView/index.tsx
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import { Slide } from '@mui/material'
import styled from 'styled-components'
import { Flex } from '~/components/common/Flex'
import ChevronLeftIcon from '~/components/Icons/ChevronLeftIcon'
import CloseIcon from '~/components/Icons/CloseIcon'
import { Flex } from '~/components/common/Flex'
import { useAppStore } from '~/stores/useAppStore'
import { useGraphStore, useSelectedNode } from '~/stores/useGraphStore'
import { usePlayerStore } from '~/stores/usePlayerStore'
import { colors } from '~/utils/colors'
import { AiSearch } from '../AiSearch'
import { SelectedNodeView } from '../SelectedNodeView'
import { MediaPlayer } from './MediaPlayer'

Expand All @@ -29,9 +30,12 @@ export const SideBarSubView = ({ open }: Props) => {
>
<Wrapper>
<MediaPlayer key={playingNode?.ref_id} hidden={selectedNode?.ref_id !== playingNode?.ref_id} />
<ScrollWrapper>
<SelectedNodeView />
</ScrollWrapper>
<AiSearchScrollWrapper>
<ScrollWrapper>
<SelectedNodeView />
</ScrollWrapper>
<AiSearch contextSearch />
</AiSearchScrollWrapper>
<CloseButton
data-testid="close-sidebar-sub-view"
onClick={() => {
Expand Down Expand Up @@ -80,6 +84,11 @@ const CloseButton = styled(Flex)`
}
`

const AiSearchScrollWrapper = styled(Flex)`
flex: 1 1 100%;
overflow: hidden;
`

const ScrollWrapper = styled(Flex)`
flex: 1 1 100%;
border-radius: 16px;
Expand Down
11 changes: 7 additions & 4 deletions src/components/App/SideBar/__tests__/index.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ Object.defineProperty(window, 'matchMedia', {
})

jest.mock('react-hook-form', () => ({
...jest.requireActual('react-hook-form'),
useFormContext: jest.fn(() => ({
setValue: jest.fn(),
register: jest.fn(),
Expand Down Expand Up @@ -145,9 +146,10 @@ describe('Test SideBar', () => {

fireEvent.change(searchInput, { target: { value: 'Lightning Network' } })

const searchIcon = screen.getByTestId('search-icon')
const searchIcons = screen.getAllByTestId('search-icon')

expect(searchIcon).toBeInTheDocument()
expect(searchIcons.length).toBeGreaterThan(0) // Ensure at least one exists
expect(searchIcons[0]).toBeInTheDocument()
;(async () => {
await waitFor(() => {
expect(onSubmitMock).toHaveBeenCalled()
Expand Down Expand Up @@ -201,9 +203,10 @@ describe('Test SideBar', () => {
</MemoryRouter>,
)

const searchIcon = screen.getByTestId('search-icon')
const searchIcons = screen.getAllByTestId('search-icon')

expect(searchIcon).toBeInTheDocument()
expect(searchIcons.length).toBeGreaterThan(0) // Ensure at least one exists
expect(searchIcons[0]).toBeInTheDocument()
})

it('Ensure that the Trending component is displayed when there is no search term.', () => {
Expand Down
8 changes: 5 additions & 3 deletions src/stores/useAiSummaryStore/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ export type AiSummaryStore = {
setAbortRequests: (status: boolean) => void,
AISearchQuery?: string,
params?: FetchNodeParams,
context?: string,
) => void
newLoading: AIEntity | null
}
Expand Down Expand Up @@ -76,16 +77,17 @@ export const useAiSummaryStore = create<AiSummaryStore>()(

setAiRefId: (aiRefId) => set({ aiRefId }),

fetchAIData: async (setBudget, setAbortRequests, AISearchQuery = '', params) => {
fetchAIData: async (setBudget, setAbortRequests, AISearchQuery = '', params, context) => {
const { filters, addNewNode } = useDataStore.getState()
const currentPage = filters.skip
const itemsPerPage = filters.limit
const { currentSearch } = useAppStore.getState()
const { setAiSummaryAnswer, setNewLoading, aiRefId } = get()
const fullAiSearchQuery = context ? `${context} ${AISearchQuery}` : AISearchQuery

const ai = { ai_summary: String(!!AISearchQuery) }

addNewNode({ nodes: [{ ...QuestionNode, name: AISearchQuery, ref_id: AISearchQuery }], edges: [] })
addNewNode({ nodes: [{ ...QuestionNode, name: fullAiSearchQuery, ref_id: fullAiSearchQuery }], edges: [] })

if (AISearchQuery) {
setNewLoading({ question: AISearchQuery, answerLoading: true })
Expand All @@ -101,7 +103,7 @@ export const useAiSummaryStore = create<AiSummaryStore>()(
abortController = controller

const { node_type: filterNodeTypes, ...withoutNodeType } = filters
const word = AISearchQuery || currentSearch
const word = fullAiSearchQuery || currentSearch

const isLatest = isEqual(filters, defaultFilters) && !word

Expand Down

0 comments on commit f714834

Please sign in to comment.