Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

PosgreSQL hybrid search #958

Open
wants to merge 12 commits into
base: main
Choose a base branch
from
40 changes: 17 additions & 23 deletions extensions/Postgres/Postgres.TestApplication/Program.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Copyright (c) Microsoft. All rights reserved.

using Microsoft.KernelMemory;
using Microsoft.KernelMemory.AI.Ollama;
using Microsoft.KernelMemory.DocumentStorage.DevTools;
using Microsoft.KernelMemory.FileSystem.DevTools;

Expand All @@ -26,16 +27,13 @@ private static async Task Test1()
var postgresConfig = cfg.GetSection("KernelMemory:Services:Postgres").Get<PostgresConfig>();
ArgumentNullExceptionEx.ThrowIfNull(postgresConfig, nameof(postgresConfig), "Postgres config not found");

var azureOpenAIEmbeddingConfig = cfg.GetSection("KernelMemory:Services:AzureOpenAIEmbedding").Get<AzureOpenAIConfig>();
ArgumentNullExceptionEx.ThrowIfNull(azureOpenAIEmbeddingConfig, nameof(azureOpenAIEmbeddingConfig), "AzureOpenAIEmbedding config not found");

var azureOpenAITextConfig = cfg.GetSection("KernelMemory:Services:AzureOpenAIText").Get<AzureOpenAIConfig>();
ArgumentNullExceptionEx.ThrowIfNull(azureOpenAITextConfig, nameof(azureOpenAITextConfig), "AzureOpenAIText config not found");
var ollamaConfig = cfg.GetSection("KernelMemory:Services:Ollama").Get<OllamaConfig>();
ArgumentNullExceptionEx.ThrowIfNull(ollamaConfig, nameof(ollamaConfig), "Ollama config not found");

// Concatenate our 'WithPostgresMemoryDb()' after 'WithOpenAIDefaults()' from the core nuget
var mem1 = new KernelMemoryBuilder()
.WithAzureOpenAITextGeneration(azureOpenAITextConfig)
.WithAzureOpenAITextEmbeddingGeneration(azureOpenAIEmbeddingConfig)
.WithOllamaTextEmbeddingGeneration(ollamaConfig)
.WithOllamaTextGeneration(ollamaConfig)
.WithPostgresMemoryDb(postgresConfig)
.WithSimpleFileStorage(SimpleFileStorageConfig.Persistent)
.Build();
Expand All @@ -44,16 +42,16 @@ private static async Task Test1()
var mem2 = new KernelMemoryBuilder()
.WithPostgresMemoryDb(postgresConfig)
.WithSimpleFileStorage(SimpleFileStorageConfig.Persistent)
.WithAzureOpenAITextGeneration(azureOpenAITextConfig)
.WithAzureOpenAITextEmbeddingGeneration(azureOpenAIEmbeddingConfig)
.WithOllamaTextEmbeddingGeneration(ollamaConfig)
.WithOllamaTextGeneration(ollamaConfig)
.Build();

// Concatenate our 'WithPostgresMemoryDb()' before and after KM builder extension methods from the core nuget
var mem3 = new KernelMemoryBuilder()
.WithSimpleFileStorage(SimpleFileStorageConfig.Persistent)
.WithAzureOpenAITextGeneration(azureOpenAITextConfig)
.WithOllamaTextEmbeddingGeneration(ollamaConfig)
.WithOllamaTextGeneration(ollamaConfig)
.WithPostgresMemoryDb(postgresConfig)
.WithAzureOpenAITextEmbeddingGeneration(azureOpenAIEmbeddingConfig)
.Build();

await mem1.DeleteIndexAsync("index1");
Expand Down Expand Up @@ -92,22 +90,20 @@ private static async Task Test1()
private static async Task Test2()
{
var postgresConfig = new PostgresConfig();
var azureOpenAIEmbeddingConfig = new AzureOpenAIConfig();
var azureOpenAITextConfig = new AzureOpenAIConfig();
var ollamaConfig = new OllamaConfig();

new ConfigurationBuilder()
.AddJsonFile("appsettings.json")
.AddJsonFile("appsettings.development.json", optional: true)
.AddJsonFile("appsettings.Development.json", optional: true)
.Build()
.BindSection("KernelMemory:Services:Postgres", postgresConfig)
.BindSection("KernelMemory:Services:AzureOpenAIEmbedding", azureOpenAIEmbeddingConfig)
.BindSection("KernelMemory:Services:AzureOpenAIText", azureOpenAITextConfig);
.BindSection("KernelMemory:Services:Ollama", ollamaConfig);

var memory = new KernelMemoryBuilder()
.WithPostgresMemoryDb(postgresConfig)
.WithAzureOpenAITextGeneration(azureOpenAITextConfig)
.WithAzureOpenAITextEmbeddingGeneration(azureOpenAIEmbeddingConfig)
.WithOllamaTextGeneration(ollamaConfig)
.WithOllamaTextEmbeddingGeneration(ollamaConfig)
.WithSimpleFileStorage(new SimpleFileStorageConfig
{
StorageType = FileSystemTypes.Disk,
Expand Down Expand Up @@ -140,8 +136,7 @@ private static async Task Test2()
private static async Task Test3()
{
var postgresConfig = new PostgresConfig();
var azureOpenAIEmbeddingConfig = new AzureOpenAIConfig();
var azureOpenAITextConfig = new AzureOpenAIConfig();
var ollamaConfig = new OllamaConfig();

// Note: using appsettings.custom-sql.json
new ConfigurationBuilder()
Expand All @@ -151,13 +146,12 @@ private static async Task Test3()
.AddJsonFile("appsettings.custom-sql.json")
.Build()
.BindSection("KernelMemory:Services:Postgres", postgresConfig)
.BindSection("KernelMemory:Services:AzureOpenAIEmbedding", azureOpenAIEmbeddingConfig)
.BindSection("KernelMemory:Services:AzureOpenAIText", azureOpenAITextConfig);
.BindSection("KernelMemory:Services:Ollama", ollamaConfig);

var memory = new KernelMemoryBuilder()
.WithPostgresMemoryDb(postgresConfig)
.WithAzureOpenAITextGeneration(azureOpenAITextConfig)
.WithAzureOpenAITextEmbeddingGeneration(azureOpenAIEmbeddingConfig)
.WithOllamaTextGeneration(ollamaConfig)
.WithOllamaTextEmbeddingGeneration(ollamaConfig)
.WithSimpleFileStorage(new SimpleFileStorageConfig
{
StorageType = FileSystemTypes.Disk,
Expand Down
73 changes: 65 additions & 8 deletions extensions/Postgres/Postgres/Internals/PostgresDbClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,14 @@

this._columnsListNoEmbeddings = $"{this._colId},{this._colTags},{this._colContent},{this._colPayload}";
this._columnsListWithEmbeddings = $"{this._colId},{this._colTags},{this._colContent},{this._colPayload},{this._colEmbedding}";
this._columnsListHybrid = $"{this._colId},{this._colTags},{this._colContent},{this._colPayload},{this._colEmbedding}";
this._columnsListHybridCoalesce = $@"
COALESCE(semantic_search.{this._colId}, keyword_search.{this._colId}) AS {this._colId},
COALESCE(semantic_search.{this._colTags}, keyword_search.{this._colTags}) AS {this._colTags},
COALESCE(semantic_search.{this._colContent}, keyword_search.{this._colContent}) AS {this._colContent},
COALESCE(semantic_search.{this._colPayload}, keyword_search.{this._colPayload}) AS {this._colPayload},
COALESCE(semantic_search.{this._colEmbedding}, keyword_search.{this._colEmbedding}) AS {this._colEmbedding},
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: remove the trailing comma (consistency with the way these values are used)

";

this._createTableSql = string.Empty;
if (config.CreateTableSql?.Count > 0)
Expand Down Expand Up @@ -138,6 +146,8 @@
CancellationToken cancellationToken = default)
{
var origInputTableName = tableName;
var indexTags = this.WithTableNamePrefix(tableName) + "_idx_tags";
var indexContent = this.WithTableNamePrefix(tableName) + "_idx_content";
tableName = this.WithSchemaAndTableNamePrefix(tableName);
this._log.LogTrace("Creating table: {0}", tableName);

Expand Down Expand Up @@ -175,7 +185,8 @@
{this._colContent} TEXT DEFAULT '' NOT NULL,
{this._colPayload} JSONB DEFAULT '{{}}'::JSONB NOT NULL
);
CREATE INDEX IF NOT EXISTS idx_tags ON {tableName} USING GIN({this._colTags});
CREATE INDEX IF NOT EXISTS ""{indexTags}"" ON {tableName} USING GIN({this._colTags});
CREATE INDEX IF NOT EXISTS ""{indexContent}"" ON {tableName} USING GIN(to_tsvector('english',{this._colContent}));
COMMIT;
";
#pragma warning restore CA2100
Expand Down Expand Up @@ -388,23 +399,27 @@
/// Get a list of records
/// </summary>
/// <param name="tableName">Table containing the records to fetch</param>
/// <param name="query">Prompt query. Only used in the case of hybrid search</param>
/// <param name="target">Source vector to compare for similarity</param>
/// <param name="minSimilarity">Minimum similarity threshold</param>
/// <param name="filterSql">SQL filter to apply</param>
/// <param name="sqlUserValues">List of user values passed with placeholders to avoid SQL injection</param>
/// <param name="limit">Max number of records to retrieve</param>
/// <param name="offset">Records to skip from the top</param>
/// <param name="withEmbeddings">Whether to include embedding vectors</param>
/// <param name="useHybridSearch">Whether to use hybrid search or vector search</param>
/// <param name="cancellationToken">Async task cancellation token</param>
public async IAsyncEnumerable<(PostgresMemoryRecord record, double similarity)> GetSimilarAsync(
string tableName,
string query,
Vector target,
double minSimilarity,
string? filterSql = null,
Dictionary<string, object>? sqlUserValues = null,
int limit = 1,
int offset = 0,
bool withEmbeddings = false,
bool useHybridSearch = false,
[EnumeratorCancellation] CancellationToken cancellationToken = default)
{
tableName = this.WithSchemaAndTableNamePrefix(tableName);
Expand All @@ -413,14 +428,19 @@

// Column names
string columns = withEmbeddings ? this._columnsListWithEmbeddings : this._columnsListNoEmbeddings;
string columnsHibrid = this._columnsListHybrid;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: var name typo columnsHibrid => columnsHybrid

string columnsListHybridCoalesce = this._columnsListHybridCoalesce;
Comment on lines +431 to +432
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

these 2 vars could be removed, is there a reason to copy the values?


// Filtering logic, including filter by similarity
//
filterSql = filterSql?.Trim().Replace(PostgresSchema.PlaceholdersTags, this._colTags, StringComparison.Ordinal);
if (string.IsNullOrWhiteSpace(filterSql))
{
filterSql = "TRUE";
}

string filterSqlHybridText = filterSql;

Check warning on line 442 in extensions/Postgres/Postgres/Internals/PostgresDbClient.cs

View workflow job for this annotation

GitHub Actions / Unit Tests (9.0.x, ubuntu-latest)

Fix formatting (https://learn.microsoft.com/dotnet/fundamentals/code-analysis/style-rules/ide0055)

Check warning on line 442 in extensions/Postgres/Postgres/Internals/PostgresDbClient.cs

View workflow job for this annotation

GitHub Actions / Unit Tests (9.0.x, ubuntu-latest)

Fix formatting (https://learn.microsoft.com/dotnet/fundamentals/code-analysis/style-rules/ide0055)

Check failure on line 442 in extensions/Postgres/Postgres/Internals/PostgresDbClient.cs

View workflow job for this annotation

GitHub Actions / Build (9.0.x, ubuntu-latest, Debug)

Fix formatting (https://learn.microsoft.com/dotnet/fundamentals/code-analysis/style-rules/ide0055)

Check failure on line 442 in extensions/Postgres/Postgres/Internals/PostgresDbClient.cs

View workflow job for this annotation

GitHub Actions / Build (9.0.x, ubuntu-latest, Debug)

Fix formatting (https://learn.microsoft.com/dotnet/fundamentals/code-analysis/style-rules/ide0055)

Check warning on line 442 in extensions/Postgres/Postgres/Internals/PostgresDbClient.cs

View workflow job for this annotation

GitHub Actions / Unit Tests (9.0.x, windows-latest)

Fix formatting (https://learn.microsoft.com/dotnet/fundamentals/code-analysis/style-rules/ide0055)

Check warning on line 442 in extensions/Postgres/Postgres/Internals/PostgresDbClient.cs

View workflow job for this annotation

GitHub Actions / Unit Tests (9.0.x, windows-latest)

Fix formatting (https://learn.microsoft.com/dotnet/fundamentals/code-analysis/style-rules/ide0055)

Check failure on line 442 in extensions/Postgres/Postgres/Internals/PostgresDbClient.cs

View workflow job for this annotation

GitHub Actions / Build (9.0.x, ubuntu-latest, Release)

Fix formatting (https://learn.microsoft.com/dotnet/fundamentals/code-analysis/style-rules/ide0055)

Check failure on line 442 in extensions/Postgres/Postgres/Internals/PostgresDbClient.cs

View workflow job for this annotation

GitHub Actions / Build (9.0.x, ubuntu-latest, Release)

Fix formatting (https://learn.microsoft.com/dotnet/fundamentals/code-analysis/style-rules/ide0055)

Check warning on line 443 in extensions/Postgres/Postgres/Internals/PostgresDbClient.cs

View workflow job for this annotation

GitHub Actions / Unit Tests (9.0.x, ubuntu-latest)

Fix formatting (https://learn.microsoft.com/dotnet/fundamentals/code-analysis/style-rules/ide0055)

Check warning on line 443 in extensions/Postgres/Postgres/Internals/PostgresDbClient.cs

View workflow job for this annotation

GitHub Actions / Unit Tests (9.0.x, ubuntu-latest)

Fix formatting (https://learn.microsoft.com/dotnet/fundamentals/code-analysis/style-rules/ide0055)

Check failure on line 443 in extensions/Postgres/Postgres/Internals/PostgresDbClient.cs

View workflow job for this annotation

GitHub Actions / Build (9.0.x, ubuntu-latest, Debug)

Fix formatting (https://learn.microsoft.com/dotnet/fundamentals/code-analysis/style-rules/ide0055)

Check failure on line 443 in extensions/Postgres/Postgres/Internals/PostgresDbClient.cs

View workflow job for this annotation

GitHub Actions / Build (9.0.x, ubuntu-latest, Debug)

Fix formatting (https://learn.microsoft.com/dotnet/fundamentals/code-analysis/style-rules/ide0055)

Check warning on line 443 in extensions/Postgres/Postgres/Internals/PostgresDbClient.cs

View workflow job for this annotation

GitHub Actions / Unit Tests (9.0.x, windows-latest)

Fix formatting (https://learn.microsoft.com/dotnet/fundamentals/code-analysis/style-rules/ide0055)

Check warning on line 443 in extensions/Postgres/Postgres/Internals/PostgresDbClient.cs

View workflow job for this annotation

GitHub Actions / Unit Tests (9.0.x, windows-latest)

Fix formatting (https://learn.microsoft.com/dotnet/fundamentals/code-analysis/style-rules/ide0055)

Check failure on line 443 in extensions/Postgres/Postgres/Internals/PostgresDbClient.cs

View workflow job for this annotation

GitHub Actions / Build (9.0.x, ubuntu-latest, Release)

Fix formatting (https://learn.microsoft.com/dotnet/fundamentals/code-analysis/style-rules/ide0055)

Check failure on line 443 in extensions/Postgres/Postgres/Internals/PostgresDbClient.cs

View workflow job for this annotation

GitHub Actions / Build (9.0.x, ubuntu-latest, Release)

Fix formatting (https://learn.microsoft.com/dotnet/fundamentals/code-analysis/style-rules/ide0055)
var maxDistance = 1 - minSimilarity;
filterSql += $" AND {this._colEmbedding} <=> @embedding < @maxDistance";

Expand All @@ -440,16 +460,51 @@
#pragma warning disable CA2100 // SQL reviewed
string colDistance = "__distance";

// When using 1 - (embedding <=> target) the index is not being used, therefore we calculate
// the similarity (1 - distance) later. Furthermore, colDistance can't be used in the WHERE clause.
cmd.CommandText = @$"
SELECT {columns}, {this._colEmbedding} <=> @embedding AS {colDistance}
FROM {tableName}
WHERE {filterSql}
ORDER BY {colDistance} ASC
if (useHybridSearch)
{
// When using 1 - (embedding <=> target) the index is not being used, therefore we calculate
// the similarity (1 - distance) later. Furthermore, colDistance can't be used in the WHERE clause.
cmd.CommandText = @$"
WITH semantic_search AS (
SELECT {columnsHibrid}, RANK () OVER (ORDER BY {this._colEmbedding} <=> @embedding) AS rank
FROM {tableName}
WHERE {filterSql}
ORDER BY {this._colEmbedding} <=> @embedding
LIMIT @limit
),
keyword_search AS (
SELECT {columnsHibrid}, RANK () OVER (ORDER BY ts_rank_cd(to_tsvector('english', {this._colContent}), query) DESC)
FROM {tableName}, plainto_tsquery('english', @query) query
WHERE {filterSqlHybridText} AND to_tsvector('english', {this._colContent}) @@ query
ORDER BY ts_rank_cd(to_tsvector('english', {this._colContent}), query) DESC
Comment on lines +477 to +479
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto: language should be configurable

LIMIT @limit
)
SELECT
{columnsListHybridCoalesce}
COALESCE(1.0 / (60 + semantic_search.rank), 0.0) +
COALESCE(1.0 / (60 + keyword_search.rank), 0.0) AS {colDistance}
Comment on lines +484 to +485
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could you document or point to documentation explaining these formulas? e.g. what is 60?

FROM semantic_search
FULL OUTER JOIN keyword_search ON semantic_search.{this._colId} = keyword_search.{this._colId}
ORDER BY {colDistance} DESC
LIMIT @limit
OFFSET @offset
";
cmd.Parameters.AddWithValue("@query", query);
cmd.Parameters.AddWithValue("@minSimilarity", minSimilarity);
}
else
{
// When using 1 - (embedding <=> target) the index is not being used, therefore we calculate
// the similarity (1 - distance) later. Furthermore, colDistance can't be used in the WHERE clause.
cmd.CommandText = @$"
SELECT {columns}, {this._colEmbedding} <=> @embedding AS {colDistance}
FROM {tableName}
WHERE {filterSql}
ORDER BY {colDistance} ASC
LIMIT @limit
OFFSET @offset
";
}

cmd.Parameters.AddWithValue("@embedding", target);
cmd.Parameters.AddWithValue("@maxDistance", maxDistance);
Expand Down Expand Up @@ -692,6 +747,8 @@
private readonly string _colPayload;
private readonly string _columnsListNoEmbeddings;
private readonly string _columnsListWithEmbeddings;
private readonly string _columnsListHybrid;
private readonly string _columnsListHybridCoalesce;
private readonly bool _dbNamePresent;

/// <summary>
Expand Down
6 changes: 6 additions & 0 deletions extensions/Postgres/Postgres/PostgresConfig.cs
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,12 @@ public class PostgresConfig
/// </summary>
public List<string> CreateTableSql { get; set; } = [];

/// <summary>
/// Important: when using hybrid search, relevance scores
/// are very different from when using just vector search.
/// </summary>
public bool UseHybridSearch { get; set; } = false;

/// <summary>
/// Create a new instance of the configuration
/// </summary>
Expand Down
5 changes: 5 additions & 0 deletions extensions/Postgres/Postgres/PostgresMemory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ public sealed class PostgresMemory : IMemoryDb, IDisposable, IAsyncDisposable
private readonly ITextEmbeddingGenerator _embeddingGenerator;
private readonly ILogger<PostgresMemory> _log;

private readonly bool _useHybridSearch;

/// <summary>
/// Create a new instance of Postgres KM connector
/// </summary>
Expand All @@ -40,6 +42,7 @@ public PostgresMemory(
ILoggerFactory? loggerFactory = null)
{
this._log = (loggerFactory ?? DefaultLogger.Factory).CreateLogger<PostgresMemory>();
this._useHybridSearch = config.UseHybridSearch;

this._embeddingGenerator = embeddingGenerator;
if (this._embeddingGenerator == null)
Expand Down Expand Up @@ -159,12 +162,14 @@ await this._db.UpsertAsync(

var records = this._db.GetSimilarAsync(
index,
query: text,
target: new Vector(textEmbedding.Data),
minSimilarity: minRelevance,
filterSql: sql,
sqlUserValues: unsafeSqlUserValues,
limit: limit,
withEmbeddings: withEmbeddings,
useHybridSearch: this._useHybridSearch,
cancellationToken: cancellationToken).ConfigureAwait(false);

await foreach ((PostgresMemoryRecord record, double similarity) result in records)
Expand Down
5 changes: 4 additions & 1 deletion service/Service/appsettings.json
Original file line number Diff line number Diff line change
Expand Up @@ -624,7 +624,10 @@
"ConnectionString": "Host=localhost;Port=5432;Username=public;Password=;Database=public",
// Mandatory prefix to add to the name of table managed by KM,
// e.g. to exclude other tables in the same schema.
"TableNamePrefix": "km-"
"TableNamePrefix": "km-",
// Hybrid search is not enabled by default. Note that when using hybrid search
// relevance scores are different, usually lower, than when using just vector search
"UseHybridSearch": false,
},
"Qdrant": {
// Qdrant endpoint
Expand Down
Loading