Skip to content

Commit

Permalink
cmd/atlas/internal/docker: add pgvector image (#3347)
Browse files Browse the repository at this point in the history
  • Loading branch information
a8m authored Jan 28, 2025
1 parent 112f6bd commit d723234
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 2 deletions.
20 changes: 18 additions & 2 deletions cmd/atlas/internal/docker/docker.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,11 @@ const (
DriverClickHouse = "clickhouse"
)

const (
PostgresPostGIS = "postgis"
PostgresPGVector = "pgvector"
)

// FromURL parses a URL in the format of
// "docker://driver/tag[/dbname]" and returns a Config.
func FromURL(u *url.URL, opts ...ConfigOption) (*Config, error) {
Expand Down Expand Up @@ -136,8 +141,10 @@ func FromURL(u *url.URL, opts ...ConfigOption) (*Config, error) {
)
}
cfg, err = MariaDB(tag, append(baseOpts, opts...)...)
case "postgis":
baseOpts = append(baseOpts, Image("postgis/postgis:"+tag))
case PostgresPostGIS:
baseOpts = append(baseOpts, Image(
fmt.Sprintf("%[1]s/%[1]s:%[2]s", PostgresPostGIS, tag),
))
if dbName != "" && dbName != "postgres" {
// Create manually the PostgreSQL database instead of using the POSTGRES_DB because
// PostGIS automatically creates and install the following extensions and schemas:
Expand All @@ -149,6 +156,15 @@ func FromURL(u *url.URL, opts ...ConfigOption) (*Config, error) {
}
driver = DriverPostgres
cfg, err = PostgreSQL(tag, append(baseOpts, opts...)...)
case PostgresPGVector:
baseOpts = append(baseOpts, Image(
fmt.Sprintf("%[1]s/%[1]s:%[2]s", PostgresPGVector, tag),
))
if dbName != "" {
baseOpts = append(baseOpts, Database(dbName), Env("POSTGRES_DB="+dbName))
}
driver = DriverPostgres
cfg, err = PostgreSQL(tag, append(baseOpts, opts...)...)
case DriverPostgres:
if dbName != "" {
baseOpts = append(baseOpts, Database(dbName), Env("POSTGRES_DB="+dbName))
Expand Down
30 changes: 30 additions & 0 deletions cmd/atlas/internal/docker/docker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,7 @@ func TestFromURL(t *testing.T) {
Out: io.Discard,
}, cfg)

// PostGIS.
u, err = url.Parse("docker://postgis/14-3.4")
require.NoError(t, err)
cfg, err = FromURL(u)
Expand Down Expand Up @@ -169,6 +170,35 @@ func TestFromURL(t *testing.T) {
setup: []string{`CREATE DATABASE "dev"`},
}, cfg)

// PGVector.
u, err = url.Parse("docker://pgvector/pg17")
require.NoError(t, err)
cfg, err = FromURL(u)
require.NoError(t, err)
require.Equal(t, &Config{
driver: "postgres",
Image: "pgvector/pgvector:pg17",
Database: "postgres",
Env: []string{"POSTGRES_PASSWORD=pass"},
User: url.UserPassword("postgres", pass),
Port: "5432",
Out: io.Discard,
}, cfg)

u, err = url.Parse("docker://pgvector/pg17/dev")
require.NoError(t, err)
cfg, err = FromURL(u)
require.NoError(t, err)
require.Equal(t, &Config{
driver: "postgres",
Image: "pgvector/pgvector:pg17",
Database: "dev",
Env: []string{"POSTGRES_PASSWORD=pass", "POSTGRES_DB=dev"},
User: url.UserPassword("postgres", pass),
Port: "5432",
Out: io.Discard,
}, cfg)

// SQL Server
u, err = url.Parse("docker://sqlserver")
require.NoError(t, err)
Expand Down

0 comments on commit d723234

Please sign in to comment.