Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
sqs committed Jan 7, 2024
1 parent 4275f70 commit 78b1e51
Show file tree
Hide file tree
Showing 6 changed files with 99 additions and 32 deletions.
9 changes: 5 additions & 4 deletions provider/docs/bin/search.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import { readFile } from 'fs/promises'
import path from 'path'
import { createClient } from '../src/client/client'
import { type CorpusIndex } from '../src/corpus/index/corpusIndex'
import { fromJSON } from '../src/corpus/index/corpusIndex'

const args = process.argv.slice(2)

Expand All @@ -25,7 +25,7 @@ if (args.length !== 2) {
process.exit(1)
}

const index = JSON.parse(await readFile(indexFile, 'utf8')) as CorpusIndex
const index = fromJSON(JSON.parse(await readFile(indexFile, 'utf8')))

const client = createClient(index, { logger: message => console.error('# ' + message) })

Expand All @@ -37,8 +37,9 @@ for (const [i, result] of results.slice(0, MAX_RESULTS).entries()) {
if (i !== 0) {
console.log()
}
console.log(`#${i + 1} [${result.score.toFixed(3)}] ${doc.doc.url ?? `doc${doc.doc.id}`}#chunk${result.chunk}`)
console.log(`${indent(truncate(result.excerpt.replaceAll('\n\n', '\n'), 500), '\t')}`)
console.log(`#${i + 1} [${result.score.toFixed(3)}] ${doc.doc.url ?? ''} doc${doc.doc.id}#chunk${result.chunk}`)
const chunk = doc.chunks[result.chunk]
console.log(`${indent(truncate(chunk.text.replaceAll('\n\n', '\n'), 500), '\t')}`)
}

function truncate(text: string, maxLength: number): string {
Expand Down
51 changes: 28 additions & 23 deletions provider/docs/src/client/search.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,37 +14,42 @@ export interface SearchOptions {
* Search using multiple search methods.
*/
export async function search(index: CorpusIndex, query: Query, { logger }: SearchOptions): Promise<SearchResult[]> {
const allResults = (
await Promise.all(
Object.entries(SEARCH_METHODS).map(async ([name, searchFn]) => {
const t0 = performance.now()
const results = await searchFn(index, query)
logger?.(`search[${name}] took ${Math.round(performance.now() - t0)}ms`)
return results
})
)
).flat()
const allResults = await Promise.all(
Object.entries(SEARCH_METHODS).map(async ([name, searchFn]) => {
const t0 = performance.now()
const results = await searchFn(index, query)
logger?.(`search[${name}] took ${Math.round(performance.now() - t0)}ms`)
return [name, results] as [string, SearchResult[]]
})
)

// Sum scores for each chunk.
const combinedResults = new Map<DocID, Map<ChunkIndex, SearchResult>>()
for (const result of allResults) {
let docResults = combinedResults.get(result.doc)
if (!docResults) {
docResults = new Map<ChunkIndex, SearchResult>()
combinedResults.set(result.doc, docResults)
}
for (const [searchMethod, results] of allResults) {
for (const result of results) {
let docResults = combinedResults.get(result.doc)
if (!docResults) {
docResults = new Map<ChunkIndex, SearchResult>()
combinedResults.set(result.doc, docResults)
}

const chunkResult = docResults.get(result.chunk) ?? {
doc: result.doc,
chunk: result.chunk,
score: 0,
excerpt: result.excerpt,
const chunkResult: SearchResult = docResults.get(result.chunk) ?? {
doc: result.doc,
chunk: result.chunk,
score: 0,
scores: {},
excerpt: result.excerpt,
}
docResults.set(result.chunk, {
...chunkResult,
score: chunkResult.score + result.score,
scores: { ...chunkResult.scores, [searchMethod]: result.score },
})
}
docResults.set(result.chunk, { ...chunkResult, score: chunkResult.score + result.score })
}

const results = Array.from(combinedResults.values()).flatMap(docResults => Array.from(docResults.values()))
const MIN_SCORE = 0.3
const MIN_SCORE = 0.001
return results.filter(s => s.score >= MIN_SCORE).toSorted((a, b) => b.score - a.score)
}

Expand Down
18 changes: 14 additions & 4 deletions provider/docs/src/corpus/index/corpusIndex.test.ts
Original file line number Diff line number Diff line change
@@ -1,14 +1,24 @@
import { describe, expect, test } from 'vitest'
import { createCorpusArchive } from '../archive/corpusArchive'
import { type Doc, type DocID } from '../doc/doc'
import { createCorpusIndex } from './corpusIndex'
import { createCorpusIndex, fromJSON } from './corpusIndex'

export function doc(id: DocID, text: string): Doc {
return { id, text }
}

describe('indexCorpus', () => {
test('#docs', async () => {
expect((await createCorpusIndex(await createCorpusArchive([doc(1, 'a'), doc(2, 'b')]))).docs.length).toBe(2)
describe('indexCorpus', async () => {
const INDEX = await createCorpusIndex(await createCorpusArchive([doc(1, 'a'), doc(2, 'b')]))

test('docs', () => {
expect(INDEX.docs.length).toBe(2)
})

test('JSON-serializable', async () => {
const serialized = fromJSON(JSON.parse(JSON.stringify(INDEX)))
const indexWithoutToJSON = { ...INDEX }
delete (indexWithoutToJSON as any).toJSON
expect(serialized.docs[0].chunks[0].embeddings).toBeInstanceOf(Float32Array)
expect(serialized.docs).toEqual(INDEX.docs)
})
})
40 changes: 39 additions & 1 deletion provider/docs/src/corpus/index/corpusIndex.ts
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,16 @@ export async function createCorpusIndex(
): Promise<CorpusIndex> {
const docs = await indexCorpusDocs(archive, { contentExtractor })
const tfidf = createTFIDFIndex(docs)
return {
const index: CorpusIndex = {
docs,
tfidf,
}
const serializable = {
...index,
/** Handles serializing the Float32Array values. */
toJSON: () => toJSON(index),
}
return serializable
}

async function indexCorpusDocs(
Expand Down Expand Up @@ -68,3 +74,35 @@ async function indexCorpusDocs(
})
)
}

function toJSON(index: CorpusIndex): any {
return {
...index,
docs: index.docs.map(doc => ({
...doc,
chunks: doc.chunks.map(chunk => ({
...chunk,
embeddings: Array.from(chunk.embeddings),
})),
})),
}
}

/**
* Must be called on any {@link CorpusIndex} value that was deserialized using `JSON.parse`.
*/
export function fromJSON(indexData: any): CorpusIndex {
return {
...indexData,
// eslint-disable-next-line @typescript-eslint/no-unsafe-assignment, @typescript-eslint/no-unsafe-call, @typescript-eslint/no-unsafe-member-access
docs: indexData.docs.map((doc: any) => ({
...doc,
// eslint-disable-next-line @typescript-eslint/no-unsafe-assignment, @typescript-eslint/no-unsafe-call, @typescript-eslint/no-unsafe-member-access
chunks: doc.chunks.map((chunk: any) => ({
...chunk,
// eslint-disable-next-line @typescript-eslint/no-unsafe-member-access
embeddings: new Float32Array(chunk.embeddings),
})),
})),
}
}
7 changes: 7 additions & 0 deletions provider/docs/src/e2e.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@ describe('e2e', () => {
chunk: 3,
excerpt: 'Audio URL parsing\n\nTo parse an audio URL, use the `parseAudioURL` function.',
score: 0.764,
scores: {
embeddingsSearch: 0.662,
keywordSearch: 0.102,
},
},
])
})
Expand All @@ -29,5 +33,8 @@ describe('e2e', () => {
function roundScores(results: SearchResult[]) {
for (const result of results) {
result.score = Math.round(result.score * 1000) / 1000
for (const [searchMethod, score] of Object.entries(result.scores)) {
result.scores[searchMethod] = Math.round(score * 1000) / 1000
}
}
}
6 changes: 6 additions & 0 deletions provider/docs/src/search/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,12 @@ export interface Query {
export interface SearchResult {
doc: DocID
chunk: ChunkIndex

/** The final score after combining the individual scores from different search methods. */
score: number

/** Scores from all search methods that returned this result. */
scores: { [searchMethod: string]: number }

excerpt: string
}

0 comments on commit 78b1e51

Please sign in to comment.