diff --git a/cmd/atlas/internal/docker/docker.go b/cmd/atlas/internal/docker/docker.go index 32096ac06db..5cde8f5e149 100644 --- a/cmd/atlas/internal/docker/docker.go +++ b/cmd/atlas/internal/docker/docker.go @@ -80,78 +80,87 @@ const ( ) // FromURL parses a URL in the format of -// "docker://image/tag" and returns a Config. +// "docker://image/tag[/dbname]" and returns a Config. func FromURL(u *url.URL) (*Config, error) { var ( - tag string - opts []ConfigOption - parts = strings.Split(strings.TrimPrefix(u.Path, "/"), "/") + parts = strings.Split(strings.TrimPrefix(u.Path, "/"), "/") + idxTag = len(parts) - 1 + dbName string ) - if len(parts) > 0 { - tag = parts[0] - } - switch n := len(parts); { - case n == 2 && !strings.Contains(parts[1], ":"): - opts = append(opts, Database(parts[1])) - case n == 3: - opts = append(opts, Database(parts[2])) - fallthrough - case n == 2: - parts[0] = fmt.Sprintf("%s/%s", parts[0], parts[1]) + // Check if the last part is a tag or a database name. + if idxTag > 0 && !strings.ContainsRune(parts[idxTag], ':') { + // The last part is not a tag, so it must be a database name. + dbName, idxTag = parts[idxTag], idxTag-1 } - // Support docker+driver://image/tag - if drv, ok := strings.CutPrefix(u.Scheme, "docker+"); ok { - img := Image(parts[0]) + var ( + opts []ConfigOption + tag string + ) + // Support docker+driver://[:] + driver, customImage := strings.CutPrefix(u.Scheme, "docker+") + if customImage { + // The image is fully specified in the URL. + img := path.Join(parts[:idxTag+1]...) if u.Host != "" && u.Host != "_" { - img = Image(u.Host, parts[0]) + img = path.Join(u.Host, img) + } + opts = append(opts, Image(img)) + } else { + driver = u.Host + if idxTag >= 0 { + tag = parts[idxTag] } - opts = append(opts, img) - u.Host = drv } var ( err error cfg *Config ) - switch u.Host { + switch driver { case DriverMySQL: - if len(parts) > 1 { - opts = append(opts, Env("MYSQL_DATABASE="+parts[1]), Setup(fmt.Sprintf("CREATE DATABASE IF NOT EXISTS `%s`", parts[1]))) + if dbName != "" { + opts = append(opts, + Database(dbName), + Env("MYSQL_DATABASE="+dbName), + Setup(fmt.Sprintf("CREATE DATABASE IF NOT EXISTS `%s`", dbName)), + ) } cfg, err = MySQL(tag, opts...) case "maria": - u.Host = DriverMariaDB + driver = DriverMariaDB fallthrough case DriverMariaDB: - if len(parts) > 1 { - opts = append(opts, Env("MYSQL_DATABASE="+parts[1]), Setup(fmt.Sprintf("CREATE DATABASE IF NOT EXISTS `%s`", parts[1]))) + if dbName != "" { + opts = append(opts, + Database(dbName), + Env("MYSQL_DATABASE="+dbName), + Setup(fmt.Sprintf("CREATE DATABASE IF NOT EXISTS `%s`", dbName)), + ) } cfg, err = MariaDB(tag, opts...) case "postgis": opts = append(opts, Image("postgis/postgis:"+tag)) - u.Host = DriverPostgres + driver = DriverPostgres fallthrough case DriverPostgres: - if len(parts) > 1 { - opts = append(opts, Env("POSTGRES_DB="+parts[1])) + if dbName != "" { + opts = append(opts, Database(dbName), Env("POSTGRES_DB="+dbName)) } cfg, err = PostgreSQL(tag, opts...) case DriverSQLServer: - if len(parts) > 1 { - if db := parts[1]; db != "master" { - opts = append(opts, Setup(fmt.Sprintf("CREATE DATABASE [%s]", db))) - } + if dbName != "" && dbName != "master" { + opts = append(opts, + Database(dbName), + Setup(fmt.Sprintf("CREATE DATABASE [%s]", dbName)), + ) } cfg, err = SQLServer(tag, opts...) - if err != nil { - return nil, err - } default: - return nil, fmt.Errorf("unsupported docker image %q", u.Host) + return nil, fmt.Errorf("unsupported docker image %q", driver) } if err != nil { return nil, err } - cfg.driver = u.Host + cfg.driver = driver return cfg, nil } diff --git a/cmd/atlas/internal/docker/docker_test.go b/cmd/atlas/internal/docker/docker_test.go index 4f09ff598eb..7bcdb6dff6c 100644 --- a/cmd/atlas/internal/docker/docker_test.go +++ b/cmd/atlas/internal/docker/docker_test.go @@ -300,6 +300,25 @@ func TestFromURL_CustomImage(t *testing.T) { db: "dev", dialect: "mysql", }, + // SQL Server. + { + url: "docker+sqlserver://mcr.microsoft.com/mssql/server:2022-latest", + image: "mcr.microsoft.com/mssql/server:2022-latest", + db: "master", + dialect: "sqlserver", + }, + { + url: "docker+sqlserver://mcr.microsoft.com/mssql/server:2022-latest/dev", + image: "mcr.microsoft.com/mssql/server:2022-latest", + db: "dev", + dialect: "sqlserver", + }, + { + url: "docker+sqlserver://mcr.microsoft.com/mssql/server:latest", + image: "mcr.microsoft.com/mssql/server:latest", + db: "master", + dialect: "sqlserver", + }, } { u, err := url.Parse(tt.url) require.NoError(t, err) @@ -323,7 +342,8 @@ func TestImageURL(t *testing.T) { require.Equal(t, u, got.String()) } for img, u := range map[string]string{ - "mcr.microsoft.com/azure-sql-edge:1.0.7": "docker+sqlserver://mcr.microsoft.com/azure-sql-edge:1.0.7", + "mcr.microsoft.com/azure-sql-edge:1.0.7": "docker+sqlserver://mcr.microsoft.com/azure-sql-edge:1.0.7", + "mcr.microsoft.com/mssql/server:2022-latest": "docker+sqlserver://mcr.microsoft.com/mssql/server:2022-latest", } { got, err := ImageURL(DriverSQLServer, img) require.NoError(t, err)