Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

HHH-19054 - pgvector support for CockroachDB #9641

Merged
merged 1 commit into from
Jan 20, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 38 additions & 1 deletion docker_db.sh
Original file line number Diff line number Diff line change
Expand Up @@ -812,7 +812,44 @@ hana() {
}

cockroachdb() {
cockroachdb_24_1
cockroachdb_24_3
}

cockroachdb_24_3() {
$CONTAINER_CLI rm -f cockroach || true
LOG_CONFIG="
sinks:
stderr:
channels: all
filter: ERROR
redact: false
exit-on-error: true
"
$CONTAINER_CLI run -d --name=cockroach -m 6g -p 26257:26257 -p 8080:8080 ${DB_IMAGE_COCKROACHDB_24_3:-cockroachdb/cockroach:v24.3.3} start-single-node \
--insecure --store=type=mem,size=0.25 --advertise-addr=localhost --log="$LOG_CONFIG"
OUTPUT=
while [[ $OUTPUT != *"CockroachDB node starting"* ]]; do
echo "Waiting for CockroachDB to start..."
sleep 10
# Note we need to redirect stderr to stdout to capture the logs
OUTPUT=$($CONTAINER_CLI logs cockroach 2>&1)
done
echo "Enabling experimental box2d operators and some optimized settings for running the tests"
#settings documented in https://www.cockroachlabs.com/docs/v24.1/local-testing#use-a-local-single-node-cluster-with-in-memory-storage
$CONTAINER_CLI exec cockroach bash -c "cat <<EOF | ./cockroach sql --insecure
SET CLUSTER SETTING sql.spatial.experimental_box2d_comparison_operators.enabled = on;
SET CLUSTER SETTING kv.range_merge.queue_interval = '50ms';
SET CLUSTER SETTING jobs.registry.interval.gc = '30s';
SET CLUSTER SETTING jobs.registry.interval.cancel = '180s';
SET CLUSTER SETTING jobs.retention_time = '5s';
SET CLUSTER SETTING sql.stats.automatic_collection.enabled = false;
ALTER RANGE default CONFIGURE ZONE USING "gc.ttlseconds" = 300;
ALTER DATABASE system CONFIGURE ZONE USING "gc.ttlseconds" = 300;

quit
EOF
"
echo "Cockroachdb successfully started"
}

cockroachdb_24_1() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import org.hibernate.boot.model.FunctionContributions;
import org.hibernate.boot.model.FunctionContributor;
import org.hibernate.dialect.CockroachDialect;
import org.hibernate.dialect.Dialect;
import org.hibernate.dialect.PostgreSQLDialect;
import org.hibernate.query.sqm.function.SqmFunctionRegistry;
Expand All @@ -24,7 +25,8 @@ public void contributeFunctions(FunctionContributions functionContributions) {
final TypeConfiguration typeConfiguration = functionContributions.getTypeConfiguration();
final BasicTypeRegistry basicTypeRegistry = typeConfiguration.getBasicTypeRegistry();
final Dialect dialect = functionContributions.getDialect();
if ( dialect instanceof PostgreSQLDialect ) {
if ( dialect instanceof PostgreSQLDialect ||
dialect instanceof CockroachDialect ) {
final BasicType<Double> doubleType = basicTypeRegistry.resolve( StandardBasicTypes.DOUBLE );
final BasicType<Integer> integerType = basicTypeRegistry.resolve( StandardBasicTypes.INTEGER );
functionRegistry.patternDescriptorBuilder( "cosine_distance", "?1<=>?2" )
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

import org.hibernate.boot.model.TypeContributions;
import org.hibernate.boot.model.TypeContributor;
import org.hibernate.dialect.CockroachDialect;
import org.hibernate.dialect.Dialect;
import org.hibernate.dialect.PostgreSQLDialect;
import org.hibernate.engine.jdbc.Size;
Expand All @@ -34,7 +35,8 @@ public class PGVectorTypeContributor implements TypeContributor {
@Override
public void contribute(TypeContributions typeContributions, ServiceRegistry serviceRegistry) {
final Dialect dialect = serviceRegistry.requireService( JdbcServices.class ).getDialect();
if ( dialect instanceof PostgreSQLDialect ) {
if ( dialect instanceof PostgreSQLDialect ||
dialect instanceof CockroachDialect ) {
final TypeConfiguration typeConfiguration = typeContributions.getTypeConfiguration();
final JavaTypeRegistry javaTypeRegistry = typeConfiguration.getJavaTypeRegistry();
final JdbcTypeRegistry jdbcTypeRegistry = typeConfiguration.getJdbcTypeRegistry();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,10 @@

import org.hibernate.annotations.Array;
import org.hibernate.annotations.JdbcTypeCode;
import org.hibernate.dialect.CockroachDialect;
import org.hibernate.dialect.PostgreSQLDialect;
import org.hibernate.testing.orm.junit.RequiresDialects;
import org.hibernate.testing.orm.junit.SkipForDialect;
import org.hibernate.type.SqlTypes;

import org.hibernate.testing.orm.junit.DomainModel;
Expand All @@ -32,7 +35,10 @@
*/
@DomainModel(annotatedClasses = PGVectorTest.VectorEntity.class)
@SessionFactory
@RequiresDialect(value = PostgreSQLDialect.class, matchSubTypes = false)
@RequiresDialects({
@RequiresDialect(value = PostgreSQLDialect.class, matchSubTypes = false),
@RequiresDialect(value = CockroachDialect.class, majorVersion = 24, minorVersion = 2)
})
public class PGVectorTest {

private static final float[] V1 = new float[]{ 1, 2, 3 };
Expand Down Expand Up @@ -166,6 +172,7 @@ public void testVectorNorm(SessionFactoryScope scope) {
}

@Test
@SkipForDialect(dialectClass = CockroachDialect.class, reason = "CockroachDB does not currently support the sum() function on vector type" )
public void testVectorSum(SessionFactoryScope scope) {
scope.inTransaction( em -> {
//tag::vector-sum-example[]
Expand All @@ -178,6 +185,7 @@ public void testVectorSum(SessionFactoryScope scope) {
}

@Test
@SkipForDialect(dialectClass = CockroachDialect.class, reason = "CockroachDB does not currently support the avg() function on vector type" )
public void testVectorAvg(SessionFactoryScope scope) {
scope.inTransaction( em -> {
//tag::vector-avg-example[]
Expand Down
Loading