diff --git a/go.mod b/go.mod index b09f020..128e6f1 100644 --- a/go.mod +++ b/go.mod @@ -10,6 +10,7 @@ require ( github.com/hashicorp/terraform-plugin-framework v0.14.0 github.com/hashicorp/terraform-plugin-framework-validators v0.5.0 github.com/hashicorp/terraform-plugin-go v0.14.0 + github.com/hashicorp/terraform-plugin-log v0.7.0 github.com/hashicorp/terraform-plugin-sdk/v2 v2.23.0 github.com/lib/pq v1.10.5 github.com/mattn/go-sqlite3 v1.14.10 @@ -44,7 +45,6 @@ require ( github.com/hashicorp/logutils v1.0.0 // indirect github.com/hashicorp/terraform-exec v0.17.3 // indirect github.com/hashicorp/terraform-json v0.14.0 // indirect - github.com/hashicorp/terraform-plugin-log v0.7.0 // indirect github.com/hashicorp/terraform-registry-address v0.0.0-20220623143253-7d51757b572c // indirect github.com/hashicorp/terraform-svchost v0.0.0-20200729002733-f050f53b9734 // indirect github.com/hashicorp/yamux v0.0.0-20181012175058-2f1d1f20f75d // indirect diff --git a/internal/atlas/atlas.go b/internal/atlas/atlas.go index a52b9af..1f1e168 100644 --- a/internal/atlas/atlas.go +++ b/internal/atlas/atlas.go @@ -7,10 +7,12 @@ import ( "os" "os/exec" "path" + "path/filepath" "runtime" "strings" "github.com/hashicorp/terraform-plugin-framework/diag" + "github.com/hashicorp/terraform-plugin-log/tflog" ) type ( @@ -38,8 +40,8 @@ type ( // NewClient returns a new Atlas client. // The client will try to find the Atlas CLI in the current directory, // and in the PATH. -func NewClient(name string) (*Client, error) { - path, err := execPath(name) +func NewClient(ctx context.Context, dir, name string) (*Client, error) { + path, err := execPath(ctx, dir, name) if err != nil { return nil, err } @@ -53,10 +55,14 @@ func NewClientWithPath(path string) *Client { // Apply runs the `migrate apply` command. func (c *Client) Apply(ctx context.Context, data *ApplyParams) (*ApplyReport, error) { + dir, err := filepath.Abs(data.DirURL) + if err != nil { + return nil, err + } args := []string{ "migrate", "apply", "--log", "{{ json . }}", "--url", data.URL, - "--dir", fmt.Sprintf("file://%s", data.DirURL), + "--dir", fmt.Sprintf("file://%s", dir), } if data.RevisionsSchema != "" { args = append(args, "--revisions-schema", data.RevisionsSchema) @@ -79,10 +85,14 @@ func (c *Client) Apply(ctx context.Context, data *ApplyParams) (*ApplyReport, er // Status runs the `migrate status` command. func (c *Client) Status(ctx context.Context, data *StatusParams) (*StatusReport, error) { + dir, err := filepath.Abs(data.DirURL) + if err != nil { + return nil, err + } args := []string{ "migrate", "status", "--log", "{{ json . }}", "--url", data.URL, - "--dir", fmt.Sprintf("file://%s", data.DirURL), + "--dir", fmt.Sprintf("file://%s", dir), } if data.RevisionsSchema != "" { args = append(args, "--revisions-schema", data.RevisionsSchema) @@ -159,18 +169,22 @@ func (r StatusReport) Amount(version string) (amount uint, ok bool) { return amount, false } -func execPath(name string) (string, error) { +func execPath(ctx context.Context, dir, name string) (string, error) { if runtime.GOOS == "windows" { name += ".exe" } - wd, err := os.Getwd() - if err != nil { - return "", err - } - p := path.Join(wd, name) - if _, err = os.Stat(p); os.IsExist(err) { + tflog.Debug(ctx, "atlas: looking for the Atlas CLI in the current directory", map[string]interface{}{ + "dir": dir, + "path": path.Join(dir, name), + "name": name, + }) + p := path.Join(dir, name) + if _, err := os.Stat(p); os.IsExist(err) { return p, nil } + tflog.Debug(ctx, "atlas: looking for the Atlas CLI in the $PATH", map[string]interface{}{ + "name": name, + }) // If the binary is not in the current directory, // try to find it in the PATH. return exec.LookPath(name) diff --git a/internal/atlas/atlas_test.go b/internal/atlas/atlas_test.go index 0762557..740b1f8 100644 --- a/internal/atlas/atlas_test.go +++ b/internal/atlas/atlas_test.go @@ -3,6 +3,7 @@ package atlas_test import ( "context" "fmt" + "os" "testing" _ "ariga.io/atlas/sql/mysql" @@ -42,7 +43,9 @@ func Test_MigrateApply(t *testing.T) { wantTarget: "20221101165415", }, } - c, err := atlas.NewClient("atlas") + wd, err := os.Getwd() + r.NoError(err) + c, err := atlas.NewClient(context.Background(), wd, "atlas") r.NoError(err) for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -83,7 +86,9 @@ func Test_MigrateStatus(t *testing.T) { wantNext: "20221101163823", }, } - c, err := atlas.NewClient("atlas") + wd, err := os.Getwd() + r.NoError(err) + c, err := atlas.NewClient(context.Background(), wd, "atlas") r.NoError(err) for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { diff --git a/internal/provider/provider.go b/internal/provider/provider.go index b90c11e..35dc9b2 100644 --- a/internal/provider/provider.go +++ b/internal/provider/provider.go @@ -3,8 +3,10 @@ package provider import ( "bytes" "context" + "fmt" "os" "path" + "runtime" _ "ariga.io/atlas/sql/mysql" _ "ariga.io/atlas/sql/postgres" @@ -29,6 +31,8 @@ type ( AtlasProvider struct { // client is the client used to interact with the Atlas CLI. client *atlas.Client + // dir is the directory where the provider is installed. + dir string // version is set to the provider version on release, "dev" when the // provider is built and ran locally, and "test" when running acceptance // testing. @@ -53,9 +57,16 @@ const ( ) // New returns a new provider. -func New(version, commit string) func() provider.Provider { +func New(address, version, commit string) func() provider.Provider { + wd, err := os.Getwd() + if err != nil { + panic(err) + } + providersDir := path.Join(wd, ".terraform", "providers") + platform := fmt.Sprintf("%s_%s", runtime.GOOS, runtime.GOARCH) return func() provider.Provider { return &AtlasProvider{ + dir: path.Join(providersDir, address, version, platform), version: version, } } @@ -78,7 +89,7 @@ func (p *AtlasProvider) GetSchema(ctx context.Context) (tfsdk.Schema, diag.Diagn // Configure implements provider.Provider. func (p *AtlasProvider) Configure(ctx context.Context, req provider.ConfigureRequest, resp *provider.ConfigureResponse) { - c, err := atlas.NewClient("atlas") + c, err := atlas.NewClient(ctx, p.dir, "atlas") if err != nil { resp.Diagnostics.AddError("Failed to create client", err.Error()) return @@ -106,7 +117,7 @@ func (p *AtlasProvider) Resources(ctx context.Context) []func() resource.Resourc // ConfigValidators returns a list of functions which will all be performed during validation. func (p *AtlasProvider) ValidateConfig(ctx context.Context, req provider.ValidateConfigRequest, resp *provider.ValidateConfigResponse) { - msg := checkForUpdate(ctx, p.version)() + msg := checkForUpdate(ctx, p.dir, p.version)() if msg != "" { resp.Diagnostics.AddWarning( "Update Available", @@ -118,7 +129,7 @@ func (p *AtlasProvider) ValidateConfig(ctx context.Context, req provider.Validat func noText() string { return "" } // checkForUpdate checks for version updates and security advisories for Atlas. -func checkForUpdate(ctx context.Context, version string) func() string { +func checkForUpdate(ctx context.Context, dir, version string) func() string { done := make(chan struct{}) // Users may skip update checking behavior. if v := os.Getenv(envNoUpdate); v != "" { @@ -128,14 +139,10 @@ func checkForUpdate(ctx context.Context, version string) func() string { if !semver.IsValid(version) { return noText } - curDir, err := os.Getwd() - if err != nil { - return noText - } var message string go func() { defer close(done) - vc := vercheck.New(vercheckURL, path.Join(curDir, versionFile)) + vc := vercheck.New(vercheckURL, path.Join(dir, versionFile)) payload, err := vc.Check(version) if err != nil { return diff --git a/internal/provider/provider_test.go b/internal/provider/provider_test.go index c91355c..da65f04 100644 --- a/internal/provider/provider_test.go +++ b/internal/provider/provider_test.go @@ -13,7 +13,7 @@ import ( // CLI command executed to create a provider server to which the CLI can // reattach. var testAccProtoV6ProviderFactories = map[string]func() (tfprotov6.ProviderServer, error){ - "atlas": providerserver.NewProtocol6WithError(provider.New("test", "")()), + "atlas": providerserver.NewProtocol6WithError(provider.New("registry.terraform.io/ariga/atlas", "test", "")()), } func testAccPreCheck(t *testing.T) { diff --git a/main.go b/main.go index 588d904..0060a86 100644 --- a/main.go +++ b/main.go @@ -39,7 +39,7 @@ func main() { Debug: debug, } - err := providerserver.Serve(context.Background(), provider.New(version, commit), opts) + err := providerserver.Serve(context.Background(), provider.New(opts.Address, version, commit), opts) if err != nil { log.Fatal(err.Error()) }