From 7eb615a8ec4517868fb3692503b6ad0b41a058f9 Mon Sep 17 00:00:00 2001 From: ahaapple Date: Mon, 6 Jan 2025 11:11:15 +0800 Subject: [PATCH] AI-Powered Search History version 1 --- frontend/app/api/history-search/route.ts | 34 ++++ frontend/components/modal/search-model.tsx | 116 ++++++++++++++ .../components/sidebar/search-history.tsx | 21 +-- frontend/components/sidebar/sidebar-close.tsx | 3 +- .../components/sidebar/sidebar-header.tsx | 35 ++++ vector/db.ts | 106 +++++++++---- vector/ingest.ts | 15 +- vector/memfree_index.ts | 149 ++++++++++++++++++ vector/schema.ts | 16 +- vector/test/compact.test.ts | 7 +- vector/test/search.test.ts | 110 ++++++------- vector/test/vector.test.ts | 23 +-- 12 files changed, 506 insertions(+), 129 deletions(-) create mode 100644 frontend/app/api/history-search/route.ts create mode 100644 frontend/components/modal/search-model.tsx create mode 100644 frontend/components/sidebar/sidebar-header.tsx create mode 100644 vector/memfree_index.ts diff --git a/frontend/app/api/history-search/route.ts b/frontend/app/api/history-search/route.ts new file mode 100644 index 00000000..a8ba3196 --- /dev/null +++ b/frontend/app/api/history-search/route.ts @@ -0,0 +1,34 @@ +import { auth } from '@/auth'; +import { API_TOKEN, VECTOR_HOST } from '@/lib/env'; +import { NextResponse } from 'next/server'; + +export async function GET(request: Request) { + const { searchParams } = new URL(request.url); + const query = searchParams.get('q'); + + const session = await auth(); + if (!session?.user) { + return NextResponse.json({ message: 'Unauthorized' }, { status: 401 }); + } + const searchUrl = `${VECTOR_HOST}/api/vector/search`; + const response = await fetch(searchUrl, { + method: 'POST', + headers: { + Accept: 'application/json', + Authorization: API_TOKEN!, + }, + body: JSON.stringify({ + userId: session?.user.id, + query, + }), + }); + + if (!response.ok) { + throw new Error(`Error! status: ${response.status}`); + } + + const result = await response.json(); + console.log(result); + + return NextResponse.json(result); +} diff --git a/frontend/components/modal/search-model.tsx b/frontend/components/modal/search-model.tsx new file mode 100644 index 00000000..8dbed5bc --- /dev/null +++ b/frontend/components/modal/search-model.tsx @@ -0,0 +1,116 @@ +// components/SearchDialog.tsx +'use client'; + +import { useState } from 'react'; +import { useRouter } from 'next/navigation'; +import { Input } from '@/components/ui/input'; +import { Button } from '@/components/ui/button'; +import { Dialog, DialogContent, DialogHeader, DialogTitle } from '@/components/ui/dialog'; +import { Loader2, MessageCircle } from 'lucide-react'; +import { ScrollArea } from '@/components/ui/scroll-area'; + +interface SearchResult { + id: string; + title: string; + url: string; +} + +interface SearchDialogProps { + openSearch: boolean; + onOpenModelChange: (open: boolean) => void; +} + +interface SearchResult { + id: string; + title: string; + url: string; + text: string; +} + +export function SearchDialog({ openSearch: open, onOpenModelChange: onOpenChange }: SearchDialogProps) { + const router = useRouter(); + const [query, setQuery] = useState(''); + const [results, setResults] = useState([]); + const [isLoading, setIsLoading] = useState(false); + + const handleSearch = async (searchQuery: string) => { + if (!searchQuery.trim()) { + setResults([]); + return; + } + + setIsLoading(true); + try { + const response = await fetch(`/api/history-search?q=${encodeURIComponent(searchQuery)}`); + const data = await response.json(); + console.log(data); + setResults(data); + } catch (error) { + console.error('search error:', error); + setResults([]); + } finally { + setIsLoading(false); + } + }; + + const handleResultClick = (url: string) => { + router.push('/search/' + url); + onOpenChange(false); + }; + + return ( + + + + AI-Powered Search History + + +
+
+ setQuery(e.target.value)} + className="flex-1" + autoFocus + /> + +
+ + + {isLoading ? ( +
+
Searching ...
+
+ ) : results.length === 0 ? ( +
+
No Result
+
+ ) : ( +
+ {results.map((result) => ( +
handleResultClick(result.url)} + > +
+ +
+
+

{result.title}

+

{result.text}

+
+
+ ))} +
+ )} +
+
+
+
+ ); +} diff --git a/frontend/components/sidebar/search-history.tsx b/frontend/components/sidebar/search-history.tsx index 1c22dc7d..be46cd8e 100644 --- a/frontend/components/sidebar/search-history.tsx +++ b/frontend/components/sidebar/search-history.tsx @@ -1,16 +1,17 @@ import * as React from 'react'; - -import Link from 'next/link'; -import Image from 'next/image'; +import dynamic from 'next/dynamic'; import { SidebarList } from './sidebar-list'; -import { siteConfig } from '@/config'; -import { SidebarClose } from '@/components/sidebar/sidebar-close'; import { SignInButton } from '@/components/layout/sign-in-button'; import { User } from '@/lib/types'; import { NewSearchButton } from '@/components/shared/new-search-button'; import { buttonVariants } from '@/components/ui/button'; +const SidebarHeader = dynamic(() => import('@/components/sidebar/sidebar-header').then((mod) => mod.SidebarHeader), { + ssr: false, + loading: () =>
, +}); + interface SearchHistoryProps { user: User; } @@ -18,15 +19,7 @@ interface SearchHistoryProps { export async function SearchHistory({ user }: SearchHistoryProps) { return (
-
- - MemFree Logo - {siteConfig.name} - -
- -
-
+ {!user && }
diff --git a/frontend/components/sidebar/sidebar-close.tsx b/frontend/components/sidebar/sidebar-close.tsx index f79ddfc3..a67c88e4 100644 --- a/frontend/components/sidebar/sidebar-close.tsx +++ b/frontend/components/sidebar/sidebar-close.tsx @@ -19,8 +19,7 @@ export function SidebarClose() { toggleSidebar(); }} > - - Toggle Sidebar + )} diff --git a/frontend/components/sidebar/sidebar-header.tsx b/frontend/components/sidebar/sidebar-header.tsx new file mode 100644 index 00000000..51111996 --- /dev/null +++ b/frontend/components/sidebar/sidebar-header.tsx @@ -0,0 +1,35 @@ +'use client'; + +import * as React from 'react'; + +import Link from 'next/link'; +import Image from 'next/image'; + +import { siteConfig } from '@/config'; +import { SidebarClose } from '@/components/sidebar/sidebar-close'; +import { Button } from '@/components/ui/button'; +import { SearchDialog } from '@/components/modal/search-model'; +import { Search } from 'lucide-react'; + +export async function SidebarHeader() { + const [open, setOpen] = React.useState(false); + return ( +
+ + MemFree Logo + {siteConfig.name} + +
+ + +
+ +
+ ); +} diff --git a/vector/db.ts b/vector/db.ts index c9aacd7d..32f5023c 100644 --- a/vector/db.ts +++ b/vector/db.ts @@ -7,52 +7,96 @@ export class LanceDB { private config: DatabaseConfig; private db: any; private dbSchema: DBSchema; + private tableCreationLocks = new Map>(); constructor(config: DatabaseConfig, schema: DBSchema) { this.config = config; this.dbSchema = schema; } + private async withLock( + key: string, + fn: () => Promise + ): Promise { + if (this.tableCreationLocks.has(key)) { + return this.tableCreationLocks.get(key)! as Promise; + } + + const promise = fn().finally(() => { + this.tableCreationLocks.delete(key); + }); + this.tableCreationLocks.set(key, promise); + return promise; + } + + async getTable(tableName: string): Promise { + try { + if (!this.db) { + await this.connect(); + } + + if (!this.db) { + throw new Error("Database connection not established."); + } + + if ((await this.db.tableNames()).includes(tableName)) { + return this.db.openTable(tableName); + } else { + // to avoid Conflicting Append and Overwrite Transactions + return this.withLock(tableName, async () => { + // double check if table is created by another thread + const currentTableNames = await this.db.tableNames(); + if (currentTableNames.includes(tableName)) { + return this.db.openTable(tableName); + } + + console.log("Creating table", tableName); + return this.db.createEmptyTable(tableName, this.dbSchema.schema, { + mode: "create", + existOk: false, + }); + }); + } + } catch (error) { + console.error("Error getting table", tableName, error); + throw error; + } + } + private isS3Config(options: any): options is S3Config { return "bucket" in options; } async connect(): Promise { - if (this.config.type === "s3") { - if (!this.isS3Config(this.config.options)) { - throw new Error("Invalid S3 configuration"); - } - - const { bucket, awsAccessKeyId, awsSecretAccessKey, region, s3Express } = - this.config.options || {}; - this.db = await lancedb.connect(bucket, { - storageOptions: { + try { + if (this.config.type === "s3") { + if (!this.isS3Config(this.config.options)) { + throw new Error("Invalid S3 configuration"); + } + + const { + bucket, awsAccessKeyId, awsSecretAccessKey, region, s3Express, - }, - }); - } else { - const { localDirectory } = - (this.config.options as LocalConfig) || process.cwd(); - this.db = await lancedb.connect(localDirectory); - } - return this.db; - } - - async getTable(tableName: string): Promise { - if (!this.db) { - await this.connect(); - } - - if ((await this.db.tableNames()).includes(tableName)) { - return this.db.openTable(tableName); - } else { - return this.db.createEmptyTable(tableName, this.dbSchema.schema, { - mode: "create", - existOk: true, - }); + } = this.config.options || {}; + this.db = await lancedb.connect(bucket, { + storageOptions: { + awsAccessKeyId, + awsSecretAccessKey, + region, + s3Express, + }, + }); + } else { + const { localDirectory } = + (this.config.options as LocalConfig) || process.cwd(); + this.db = await lancedb.connect(localDirectory); + } + } catch (error) { + console.error("Error connecting to database", error); + throw error; } } diff --git a/vector/ingest.ts b/vector/ingest.ts index d52c2e98..3f8dc42b 100644 --- a/vector/ingest.ts +++ b/vector/ingest.ts @@ -13,7 +13,7 @@ const mdSplitter = RecursiveCharacterTextSplitter.fromLanguage("markdown", { chunkOverlap: 40, }); -const textSplitter = new RecursiveCharacterTextSplitter({ +export const textSplitter = new RecursiveCharacterTextSplitter({ chunkSize: 400, chunkOverlap: 40, }); @@ -56,6 +56,17 @@ export async function processIngestion( const table = await db.append(userId, data); } +export async function appendData( + userId: string, + data: Array> +) { + const table = await db.append(userId, data); +} + +export async function compact(userId: string) { + await db.compact(userId); +} + export async function ingest_md( url: string, userId: string, @@ -97,7 +108,7 @@ export async function ingest_url(url: string, userId: string) { await processIngestion(url, userId, markdown, title, image ?? ""); } -async function addVectors( +export async function addVectors( image: string, title: string, url: string, diff --git a/vector/memfree_index.ts b/vector/memfree_index.ts new file mode 100644 index 00000000..5b1fb137 --- /dev/null +++ b/vector/memfree_index.ts @@ -0,0 +1,149 @@ +import { addVectors, appendData, compact, textSplitter } from "./ingest"; +import { redis } from "./redis"; +import type { Search } from "./type"; + +export const SEARCH_KEY = "search:"; +export const USER_SEARCH_KEY = "user:search:"; +export const LAST_INDEXED_TIME_KEY = "user:last_indexed_time:"; + +interface BatchSearchResult { + searches: Search[]; + lastTimestamp: number; + hasMore: boolean; +} + +async function getBatchSearchesByTimestamp( + userId: string, + fromTimestamp: number, + batchSize: number +): Promise { + try { + const searchIds: string[] = await redis.zrange( + USER_SEARCH_KEY + userId, + fromTimestamp, + "+inf", + { + byScore: true, + offset: 0, + count: batchSize, + } + ); + + if (!searchIds || searchIds.length === 0) { + return { + searches: [], + lastTimestamp: fromTimestamp, + hasMore: false, + }; + } + + const pipeline = redis.pipeline(); + searchIds.forEach((searchId) => { + pipeline.hgetall(searchId); + }); + + const results = (await pipeline.exec()) as Search[]; + + const lastSearchScore = await redis.zscore( + USER_SEARCH_KEY + userId, + searchIds[searchIds.length - 1] + ); + + return { + searches: results, + lastTimestamp: lastSearchScore || fromTimestamp, + hasMore: searchIds.length === batchSize, + }; + } catch (error) { + console.error("Failed to get batch searches:", error); + return { + searches: [], + lastTimestamp: fromTimestamp, + hasMore: false, + }; + } +} + +export async function processAllUserSearchMessages( + userId: string, + limit: number = 20 +) { + try { + let lastIndexedTime = + Number(await redis.get(LAST_INDEXED_TIME_KEY + userId)) || 0; + + console.time("processSearchMessages"); + while (true) { + console.time(`Processing batch from ${lastIndexedTime}`); + + const { searches, lastTimestamp, hasMore } = + await getBatchSearchesByTimestamp(userId, lastIndexedTime, limit); + + if (!searches || searches.length === 0) { + break; + } + + await Promise.all( + searches.map(async (search) => { + try { + if (!search?.messages || !Array.isArray(search.messages)) { + return; + } + + const messageDocumentsPromises = search.messages + .filter((message) => message.content) + .map((message) => + textSplitter.createDocuments([message.content]) + ); + + const titleDocumentsPromise = textSplitter.createDocuments([ + search.title, + ]); + + const [titleDocuments, ...messageDocumentsArrays] = + await Promise.all([ + titleDocumentsPromise, + ...messageDocumentsPromises, + ]); + + const documents = [ + ...messageDocumentsArrays.flat(), + ...titleDocuments, + ]; + + const data = await addVectors( + "", + search.title, + search.id, + documents + ); + console.log("data length", data.length); + await appendData(userId, data); + + console.log("search title ", search.title, "search id ", search.id); + } catch (error) { + console.error(`Failed to process search ${search.id}:`, error); + } + }) + ); + + lastIndexedTime = lastTimestamp; + await redis.set(LAST_INDEXED_TIME_KEY + userId, lastIndexedTime); + + if (!hasMore) { + break; + } + } + + console.timeEnd("processSearchMessages"); + + await compact(userId); + + return true; + } catch (error) { + console.error("Failed to process search messages:", error); + return false; + } +} + +await processAllUserSearchMessages(process.env.TEST_USER!, 20); diff --git a/vector/schema.ts b/vector/schema.ts index 76e51cd1..bb84bf5a 100644 --- a/vector/schema.ts +++ b/vector/schema.ts @@ -42,4 +42,18 @@ export const documentSchema: DBSchema = { ]), }; -SchemaFactory.registerSchema("document", documentSchema); +SchemaFactory.registerSchema(documentSchema.name, documentSchema); + +export const testSchema: DBSchema = { + name: "test", + schema: new Schema([ + new Field("create_time", new Float64(), true), + new Field("text", new Utf8(), true), + new Field( + "vector", + new FixedSizeList(DIMENSIONS, new Field("item", new Float32())), + true + ), + ]), +}; +SchemaFactory.registerSchema(testSchema.name, testSchema); diff --git a/vector/test/compact.test.ts b/vector/test/compact.test.ts index 22abd885..ef91e03e 100644 --- a/vector/test/compact.test.ts +++ b/vector/test/compact.test.ts @@ -1,12 +1,15 @@ import { describe, it } from "bun:test"; -import { compact } from "../db"; +import { DatabaseFactory } from "../db"; +import { testConfig } from "../config"; +import { testSchema } from "../schema"; const testUser = process.env.TEST_USER || "localTest"; +const db = DatabaseFactory.createDatabase(testConfig, testSchema); describe("compact", () => { it("should compact succesully", async () => { console.time("compact"); - await compact(testUser); + await db.compact(testUser); console.timeEnd("compact"); }, 5000000); }); diff --git a/vector/test/search.test.ts b/vector/test/search.test.ts index b565d65a..237594f0 100644 --- a/vector/test/search.test.ts +++ b/vector/test/search.test.ts @@ -26,67 +26,67 @@ describe("/api/vector/search endpoint", () => { expect(json).toEqual(expect.any(Object)); }, 100000); - it("seaech with url", async () => { - const query = "memfree"; + // it("seaech with url", async () => { + // const query = "memfree"; - const response = await fetch(`${host}/api/vector/search`, { - method: "POST", - headers: { - "Content-Type": "application/json", - Authorization: API_TOKEN, - }, - body: JSON.stringify({ - query: query, - userId: testUser, - url: "https://www.memfree.me/", - }), - }); + // const response = await fetch(`${host}/api/vector/search`, { + // method: "POST", + // headers: { + // "Content-Type": "application/json", + // Authorization: API_TOKEN, + // }, + // body: JSON.stringify({ + // query: query, + // userId: testUser, + // url: "https://www.memfree.me/", + // }), + // }); - const json = await response.json(); - console.log(json); - expect(response.status).toBe(200); - expect(json).toEqual(expect.any(Object)); - }, 100000); + // const json = await response.json(); + // console.log(json); + // expect(response.status).toBe(200); + // expect(json).toEqual(expect.any(Object)); + // }, 100000); - it("should return empty results for not found user", async () => { - const query = "memfree"; - const notFoundUser = "notFoundUser"; + // it("should return empty results for not found user", async () => { + // const query = "memfree"; + // const notFoundUser = "notFoundUser"; - const response = await fetch(`${host}/api/vector/search`, { - method: "POST", - headers: { - "Content-Type": "application/json", - Authorization: API_TOKEN, - }, - body: JSON.stringify({ - query: query, - userId: notFoundUser, - }), - }); + // const response = await fetch(`${host}/api/vector/search`, { + // method: "POST", + // headers: { + // "Content-Type": "application/json", + // Authorization: API_TOKEN, + // }, + // body: JSON.stringify({ + // query: query, + // userId: notFoundUser, + // }), + // }); - const json = await response.json(); - console.log(json); - expect(response.status).toBe(200); - expect(json).toEqual([]); - }, 10000); + // const json = await response.json(); + // console.log(json); + // expect(response.status).toBe(200); + // expect(json).toEqual([]); + // }, 10000); - it("should return 401", async () => { - const query = "memfree"; - const notFoundUser = "notFoundUser"; + // it("should return 401", async () => { + // const query = "memfree"; + // const notFoundUser = "notFoundUser"; - const response = await fetch(`${host}/api/vector/search`, { - method: "POST", - headers: { - "Content-Type": "application/json", - }, - body: JSON.stringify({ - query: query, - userId: notFoundUser, - }), - }); + // const response = await fetch(`${host}/api/vector/search`, { + // method: "POST", + // headers: { + // "Content-Type": "application/json", + // }, + // body: JSON.stringify({ + // query: query, + // userId: notFoundUser, + // }), + // }); - const json = await response.json(); - console.log(json); - expect(response.status).toBe(401); - }, 10000); + // const json = await response.json(); + // console.log(json); + // expect(response.status).toBe(401); + // }, 10000); }); diff --git a/vector/test/vector.test.ts b/vector/test/vector.test.ts index a96c4d62..b43f7503 100644 --- a/vector/test/vector.test.ts +++ b/vector/test/vector.test.ts @@ -1,31 +1,10 @@ import { describe, it, expect } from "bun:test"; import { getEmbedding } from "../embedding/embedding"; -import { - Schema, - Field, - Float32, - FixedSizeList, - Utf8, - Float64, -} from "apache-arrow"; import { DIMENSIONS, testConfig } from "../config"; -import { SchemaFactory } from "../schema"; +import { SchemaFactory, testSchema } from "../schema"; import type { DBSchema } from "../type"; import { DatabaseFactory } from "../db"; -const testSchema: DBSchema = { - name: "test", - schema: new Schema([ - new Field("create_time", new Float64(), true), - new Field("text", new Utf8(), true), - new Field( - "vector", - new FixedSizeList(DIMENSIONS, new Field("item", new Float32())), - true - ), - ]), -}; -SchemaFactory.registerSchema(testSchema.name, testSchema); const db = DatabaseFactory.createDatabase(testConfig, testSchema); describe("vector test", () => {