diff --git a/cmd/atlas/internal/docker/docker.go b/cmd/atlas/internal/docker/docker.go index a8070d58451..017301ff953 100644 --- a/cmd/atlas/internal/docker/docker.go +++ b/cmd/atlas/internal/docker/docker.go @@ -430,15 +430,25 @@ func (c *Container) Wait(ctx context.Context, timeout time.Duration) error { // URL returns a URL to connect to the Container. func (c *Container) URL() (*url.URL, error) { + host := "localhost" + // Check if the DOCKER_HOST env var is set. + // If it is, use the host from the URL. + if h := os.Getenv("DOCKER_HOST"); h != "" { + u, err := url.Parse(h) + if err != nil { + return nil, err + } + host = u.Hostname() + } switch c.cfg.driver { case DriverClickHouse: - return url.Parse(fmt.Sprintf("clickhouse://:%s@%s:%s/%s", c.Passphrase, "localhost", c.Port, c.cfg.Database)) + return url.Parse(fmt.Sprintf("clickhouse://:%s@%s:%s/%s", c.Passphrase, host, c.Port, c.cfg.Database)) case DriverSQLServer: - return url.Parse(fmt.Sprintf("sqlserver://sa:%s@localhost:%s?database=%s", passSQLServer, c.Port, c.cfg.Database)) + return url.Parse(fmt.Sprintf("sqlserver://sa:%s@%s:%s?database=%s", passSQLServer, host, c.Port, c.cfg.Database)) case DriverPostgres: - return url.Parse(fmt.Sprintf("postgres://postgres:%s@localhost:%s/%s?sslmode=disable", c.Passphrase, c.Port, c.cfg.Database)) + return url.Parse(fmt.Sprintf("postgres://postgres:%s@%s:%s/%s?sslmode=disable", c.Passphrase, host, c.Port, c.cfg.Database)) case DriverMySQL, DriverMariaDB: - return url.Parse(fmt.Sprintf("%s://root:%s@localhost:%s/%s", c.cfg.driver, c.Passphrase, c.Port, c.cfg.Database)) + return url.Parse(fmt.Sprintf("%s://root:%s@%s:%s/%s", c.cfg.driver, c.Passphrase, host, c.Port, c.cfg.Database)) default: return nil, fmt.Errorf("unknown driver: %q", c.cfg.driver) } diff --git a/cmd/atlas/internal/docker/docker_test.go b/cmd/atlas/internal/docker/docker_test.go index 3b2ed113d8d..99162abf51e 100644 --- a/cmd/atlas/internal/docker/docker_test.go +++ b/cmd/atlas/internal/docker/docker_test.go @@ -414,3 +414,16 @@ func TestImageURL(t *testing.T) { require.Equal(t, u, got.String()) } } + +func TestContainerURL(t *testing.T) { + c := &Container{cfg: Config{driver: "postgres"}, Passphrase: "pass", Port: "5432"} + u, err := c.URL() + require.NoError(t, err) + require.Equal(t, "postgres://postgres:pass@localhost:5432/?sslmode=disable", u.String()) + + // With DOCKER_HOST + t.Setenv("DOCKER_HOST", "tcp://host.docker.internal:2375") + u, err = c.URL() + require.NoError(t, err) + require.Equal(t, "postgres://postgres:pass@host.docker.internal:5432/?sslmode=disable", u.String()) +}