-
Notifications
You must be signed in to change notification settings - Fork 328
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
PosgreSQL hybrid search #958
base: main
Are you sure you want to change the base?
Changes from all commits
df47b1b
c5c3099
f0a0ddd
a51cb41
100b1b8
a81ece7
664edd5
733683e
a32eb20
5450227
d9ff235
e8d0398
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -59,6 +59,14 @@ | |
|
||
this._columnsListNoEmbeddings = $"{this._colId},{this._colTags},{this._colContent},{this._colPayload}"; | ||
this._columnsListWithEmbeddings = $"{this._colId},{this._colTags},{this._colContent},{this._colPayload},{this._colEmbedding}"; | ||
this._columnsListHybrid = $"{this._colId},{this._colTags},{this._colContent},{this._colPayload},{this._colEmbedding}"; | ||
this._columnsListHybridCoalesce = $@" | ||
COALESCE(semantic_search.{this._colId}, keyword_search.{this._colId}) AS {this._colId}, | ||
COALESCE(semantic_search.{this._colTags}, keyword_search.{this._colTags}) AS {this._colTags}, | ||
COALESCE(semantic_search.{this._colContent}, keyword_search.{this._colContent}) AS {this._colContent}, | ||
COALESCE(semantic_search.{this._colPayload}, keyword_search.{this._colPayload}) AS {this._colPayload}, | ||
COALESCE(semantic_search.{this._colEmbedding}, keyword_search.{this._colEmbedding}) AS {this._colEmbedding}, | ||
"; | ||
|
||
this._createTableSql = string.Empty; | ||
if (config.CreateTableSql?.Count > 0) | ||
|
@@ -138,6 +146,8 @@ | |
CancellationToken cancellationToken = default) | ||
{ | ||
var origInputTableName = tableName; | ||
var indexTags = this.WithTableNamePrefix(tableName) + "_idx_tags"; | ||
var indexContent = this.WithTableNamePrefix(tableName) + "_idx_content"; | ||
tableName = this.WithSchemaAndTableNamePrefix(tableName); | ||
this._log.LogTrace("Creating table: {0}", tableName); | ||
|
||
|
@@ -175,7 +185,8 @@ | |
{this._colContent} TEXT DEFAULT '' NOT NULL, | ||
{this._colPayload} JSONB DEFAULT '{{}}'::JSONB NOT NULL | ||
); | ||
CREATE INDEX IF NOT EXISTS idx_tags ON {tableName} USING GIN({this._colTags}); | ||
CREATE INDEX IF NOT EXISTS ""{indexTags}"" ON {tableName} USING GIN({this._colTags}); | ||
CREATE INDEX IF NOT EXISTS ""{indexContent}"" ON {tableName} USING GIN(to_tsvector('english',{this._colContent})); | ||
COMMIT; | ||
"; | ||
#pragma warning restore CA2100 | ||
|
@@ -388,23 +399,27 @@ | |
/// Get a list of records | ||
/// </summary> | ||
/// <param name="tableName">Table containing the records to fetch</param> | ||
/// <param name="query">Prompt query. Only used in the case of hybrid search</param> | ||
/// <param name="target">Source vector to compare for similarity</param> | ||
/// <param name="minSimilarity">Minimum similarity threshold</param> | ||
/// <param name="filterSql">SQL filter to apply</param> | ||
/// <param name="sqlUserValues">List of user values passed with placeholders to avoid SQL injection</param> | ||
/// <param name="limit">Max number of records to retrieve</param> | ||
/// <param name="offset">Records to skip from the top</param> | ||
/// <param name="withEmbeddings">Whether to include embedding vectors</param> | ||
/// <param name="useHybridSearch">Whether to use hybrid search or vector search</param> | ||
/// <param name="cancellationToken">Async task cancellation token</param> | ||
public async IAsyncEnumerable<(PostgresMemoryRecord record, double similarity)> GetSimilarAsync( | ||
string tableName, | ||
string query, | ||
Vector target, | ||
double minSimilarity, | ||
string? filterSql = null, | ||
Dictionary<string, object>? sqlUserValues = null, | ||
int limit = 1, | ||
int offset = 0, | ||
bool withEmbeddings = false, | ||
bool useHybridSearch = false, | ||
[EnumeratorCancellation] CancellationToken cancellationToken = default) | ||
{ | ||
tableName = this.WithSchemaAndTableNamePrefix(tableName); | ||
|
@@ -413,14 +428,19 @@ | |
|
||
// Column names | ||
string columns = withEmbeddings ? this._columnsListWithEmbeddings : this._columnsListNoEmbeddings; | ||
string columnsHibrid = this._columnsListHybrid; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: var name typo |
||
string columnsListHybridCoalesce = this._columnsListHybridCoalesce; | ||
Comment on lines
+431
to
+432
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. these 2 vars could be removed, is there a reason to copy the values? |
||
|
||
// Filtering logic, including filter by similarity | ||
// | ||
filterSql = filterSql?.Trim().Replace(PostgresSchema.PlaceholdersTags, this._colTags, StringComparison.Ordinal); | ||
if (string.IsNullOrWhiteSpace(filterSql)) | ||
{ | ||
filterSql = "TRUE"; | ||
} | ||
|
||
string filterSqlHybridText = filterSql; | ||
Check warning on line 442 in extensions/Postgres/Postgres/Internals/PostgresDbClient.cs GitHub Actions / Unit Tests (9.0.x, ubuntu-latest)
Check warning on line 442 in extensions/Postgres/Postgres/Internals/PostgresDbClient.cs GitHub Actions / Unit Tests (9.0.x, ubuntu-latest)
Check failure on line 442 in extensions/Postgres/Postgres/Internals/PostgresDbClient.cs GitHub Actions / Build (9.0.x, ubuntu-latest, Debug)
Check failure on line 442 in extensions/Postgres/Postgres/Internals/PostgresDbClient.cs GitHub Actions / Build (9.0.x, ubuntu-latest, Debug)
Check warning on line 442 in extensions/Postgres/Postgres/Internals/PostgresDbClient.cs GitHub Actions / Unit Tests (9.0.x, windows-latest)
Check warning on line 442 in extensions/Postgres/Postgres/Internals/PostgresDbClient.cs GitHub Actions / Unit Tests (9.0.x, windows-latest)
Check failure on line 442 in extensions/Postgres/Postgres/Internals/PostgresDbClient.cs GitHub Actions / Build (9.0.x, ubuntu-latest, Release)
|
||
|
||
Check warning on line 443 in extensions/Postgres/Postgres/Internals/PostgresDbClient.cs GitHub Actions / Unit Tests (9.0.x, ubuntu-latest)
Check warning on line 443 in extensions/Postgres/Postgres/Internals/PostgresDbClient.cs GitHub Actions / Unit Tests (9.0.x, ubuntu-latest)
Check failure on line 443 in extensions/Postgres/Postgres/Internals/PostgresDbClient.cs GitHub Actions / Build (9.0.x, ubuntu-latest, Debug)
Check failure on line 443 in extensions/Postgres/Postgres/Internals/PostgresDbClient.cs GitHub Actions / Build (9.0.x, ubuntu-latest, Debug)
Check warning on line 443 in extensions/Postgres/Postgres/Internals/PostgresDbClient.cs GitHub Actions / Unit Tests (9.0.x, windows-latest)
Check warning on line 443 in extensions/Postgres/Postgres/Internals/PostgresDbClient.cs GitHub Actions / Unit Tests (9.0.x, windows-latest)
Check failure on line 443 in extensions/Postgres/Postgres/Internals/PostgresDbClient.cs GitHub Actions / Build (9.0.x, ubuntu-latest, Release)
|
||
var maxDistance = 1 - minSimilarity; | ||
filterSql += $" AND {this._colEmbedding} <=> @embedding < @maxDistance"; | ||
|
||
|
@@ -440,16 +460,51 @@ | |
#pragma warning disable CA2100 // SQL reviewed | ||
string colDistance = "__distance"; | ||
|
||
// When using 1 - (embedding <=> target) the index is not being used, therefore we calculate | ||
// the similarity (1 - distance) later. Furthermore, colDistance can't be used in the WHERE clause. | ||
cmd.CommandText = @$" | ||
SELECT {columns}, {this._colEmbedding} <=> @embedding AS {colDistance} | ||
FROM {tableName} | ||
WHERE {filterSql} | ||
ORDER BY {colDistance} ASC | ||
if (useHybridSearch) | ||
{ | ||
// When using 1 - (embedding <=> target) the index is not being used, therefore we calculate | ||
// the similarity (1 - distance) later. Furthermore, colDistance can't be used in the WHERE clause. | ||
cmd.CommandText = @$" | ||
WITH semantic_search AS ( | ||
SELECT {columnsHibrid}, RANK () OVER (ORDER BY {this._colEmbedding} <=> @embedding) AS rank | ||
FROM {tableName} | ||
WHERE {filterSql} | ||
ORDER BY {this._colEmbedding} <=> @embedding | ||
LIMIT @limit | ||
), | ||
keyword_search AS ( | ||
SELECT {columnsHibrid}, RANK () OVER (ORDER BY ts_rank_cd(to_tsvector('english', {this._colContent}), query) DESC) | ||
FROM {tableName}, plainto_tsquery('english', @query) query | ||
WHERE {filterSqlHybridText} AND to_tsvector('english', {this._colContent}) @@ query | ||
ORDER BY ts_rank_cd(to_tsvector('english', {this._colContent}), query) DESC | ||
Comment on lines
+477
to
+479
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ditto: language should be configurable |
||
LIMIT @limit | ||
) | ||
SELECT | ||
{columnsListHybridCoalesce} | ||
COALESCE(1.0 / (60 + semantic_search.rank), 0.0) + | ||
COALESCE(1.0 / (60 + keyword_search.rank), 0.0) AS {colDistance} | ||
Comment on lines
+484
to
+485
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. could you document or point to documentation explaining these formulas? e.g. what is |
||
FROM semantic_search | ||
FULL OUTER JOIN keyword_search ON semantic_search.{this._colId} = keyword_search.{this._colId} | ||
ORDER BY {colDistance} DESC | ||
LIMIT @limit | ||
OFFSET @offset | ||
"; | ||
cmd.Parameters.AddWithValue("@query", query); | ||
cmd.Parameters.AddWithValue("@minSimilarity", minSimilarity); | ||
} | ||
else | ||
{ | ||
// When using 1 - (embedding <=> target) the index is not being used, therefore we calculate | ||
// the similarity (1 - distance) later. Furthermore, colDistance can't be used in the WHERE clause. | ||
cmd.CommandText = @$" | ||
SELECT {columns}, {this._colEmbedding} <=> @embedding AS {colDistance} | ||
FROM {tableName} | ||
WHERE {filterSql} | ||
ORDER BY {colDistance} ASC | ||
LIMIT @limit | ||
OFFSET @offset | ||
"; | ||
} | ||
|
||
cmd.Parameters.AddWithValue("@embedding", target); | ||
cmd.Parameters.AddWithValue("@maxDistance", maxDistance); | ||
|
@@ -692,6 +747,8 @@ | |
private readonly string _colPayload; | ||
private readonly string _columnsListNoEmbeddings; | ||
private readonly string _columnsListWithEmbeddings; | ||
private readonly string _columnsListHybrid; | ||
private readonly string _columnsListHybridCoalesce; | ||
private readonly bool _dbNamePresent; | ||
|
||
/// <summary> | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: remove the trailing comma (consistency with the way these values are used)