From 109ac62fd22ae030ffd0778aea34b5d56445c731 Mon Sep 17 00:00:00 2001
From: "Giau. Tran Minh" <giau.tmg@gmail.com>
Date: Wed, 9 Nov 2022 03:56:00 +0700
Subject: [PATCH] provider: fixed the path for provider (#43)

* provider: fixed the path for provider

Signed-off-by: Giau. Tran Minh <hello@giautm.dev>

* fix: support relative path for dir

Signed-off-by: Giau. Tran Minh <hello@giautm.dev>

Signed-off-by: Giau. Tran Minh <hello@giautm.dev>
---
 go.mod                             |  2 +-
 internal/atlas/atlas.go            | 36 +++++++++++++++++++++---------
 internal/atlas/atlas_test.go       |  9 ++++++--
 internal/provider/provider.go      | 25 +++++++++++++--------
 internal/provider/provider_test.go |  2 +-
 main.go                            |  2 +-
 6 files changed, 51 insertions(+), 25 deletions(-)

diff --git a/go.mod b/go.mod
index b09f020..128e6f1 100644
--- a/go.mod
+++ b/go.mod
@@ -10,6 +10,7 @@ require (
 	github.com/hashicorp/terraform-plugin-framework v0.14.0
 	github.com/hashicorp/terraform-plugin-framework-validators v0.5.0
 	github.com/hashicorp/terraform-plugin-go v0.14.0
+	github.com/hashicorp/terraform-plugin-log v0.7.0
 	github.com/hashicorp/terraform-plugin-sdk/v2 v2.23.0
 	github.com/lib/pq v1.10.5
 	github.com/mattn/go-sqlite3 v1.14.10
@@ -44,7 +45,6 @@ require (
 	github.com/hashicorp/logutils v1.0.0 // indirect
 	github.com/hashicorp/terraform-exec v0.17.3 // indirect
 	github.com/hashicorp/terraform-json v0.14.0 // indirect
-	github.com/hashicorp/terraform-plugin-log v0.7.0 // indirect
 	github.com/hashicorp/terraform-registry-address v0.0.0-20220623143253-7d51757b572c // indirect
 	github.com/hashicorp/terraform-svchost v0.0.0-20200729002733-f050f53b9734 // indirect
 	github.com/hashicorp/yamux v0.0.0-20181012175058-2f1d1f20f75d // indirect
diff --git a/internal/atlas/atlas.go b/internal/atlas/atlas.go
index a52b9af..1f1e168 100644
--- a/internal/atlas/atlas.go
+++ b/internal/atlas/atlas.go
@@ -7,10 +7,12 @@ import (
 	"os"
 	"os/exec"
 	"path"
+	"path/filepath"
 	"runtime"
 	"strings"
 
 	"github.com/hashicorp/terraform-plugin-framework/diag"
+	"github.com/hashicorp/terraform-plugin-log/tflog"
 )
 
 type (
@@ -38,8 +40,8 @@ type (
 // NewClient returns a new Atlas client.
 // The client will try to find the Atlas CLI in the current directory,
 // and in the PATH.
-func NewClient(name string) (*Client, error) {
-	path, err := execPath(name)
+func NewClient(ctx context.Context, dir, name string) (*Client, error) {
+	path, err := execPath(ctx, dir, name)
 	if err != nil {
 		return nil, err
 	}
@@ -53,10 +55,14 @@ func NewClientWithPath(path string) *Client {
 
 // Apply runs the `migrate apply` command.
 func (c *Client) Apply(ctx context.Context, data *ApplyParams) (*ApplyReport, error) {
+	dir, err := filepath.Abs(data.DirURL)
+	if err != nil {
+		return nil, err
+	}
 	args := []string{
 		"migrate", "apply", "--log", "{{ json . }}",
 		"--url", data.URL,
-		"--dir", fmt.Sprintf("file://%s", data.DirURL),
+		"--dir", fmt.Sprintf("file://%s", dir),
 	}
 	if data.RevisionsSchema != "" {
 		args = append(args, "--revisions-schema", data.RevisionsSchema)
@@ -79,10 +85,14 @@ func (c *Client) Apply(ctx context.Context, data *ApplyParams) (*ApplyReport, er
 
 // Status runs the `migrate status` command.
 func (c *Client) Status(ctx context.Context, data *StatusParams) (*StatusReport, error) {
+	dir, err := filepath.Abs(data.DirURL)
+	if err != nil {
+		return nil, err
+	}
 	args := []string{
 		"migrate", "status", "--log", "{{ json . }}",
 		"--url", data.URL,
-		"--dir", fmt.Sprintf("file://%s", data.DirURL),
+		"--dir", fmt.Sprintf("file://%s", dir),
 	}
 	if data.RevisionsSchema != "" {
 		args = append(args, "--revisions-schema", data.RevisionsSchema)
@@ -159,18 +169,22 @@ func (r StatusReport) Amount(version string) (amount uint, ok bool) {
 	return amount, false
 }
 
-func execPath(name string) (string, error) {
+func execPath(ctx context.Context, dir, name string) (string, error) {
 	if runtime.GOOS == "windows" {
 		name += ".exe"
 	}
-	wd, err := os.Getwd()
-	if err != nil {
-		return "", err
-	}
-	p := path.Join(wd, name)
-	if _, err = os.Stat(p); os.IsExist(err) {
+	tflog.Debug(ctx, "atlas: looking for the Atlas CLI in the current directory", map[string]interface{}{
+		"dir":  dir,
+		"path": path.Join(dir, name),
+		"name": name,
+	})
+	p := path.Join(dir, name)
+	if _, err := os.Stat(p); os.IsExist(err) {
 		return p, nil
 	}
+	tflog.Debug(ctx, "atlas: looking for the Atlas CLI in the $PATH", map[string]interface{}{
+		"name": name,
+	})
 	// If the binary is not in the current directory,
 	// try to find it in the PATH.
 	return exec.LookPath(name)
diff --git a/internal/atlas/atlas_test.go b/internal/atlas/atlas_test.go
index 0762557..740b1f8 100644
--- a/internal/atlas/atlas_test.go
+++ b/internal/atlas/atlas_test.go
@@ -3,6 +3,7 @@ package atlas_test
 import (
 	"context"
 	"fmt"
+	"os"
 	"testing"
 
 	_ "ariga.io/atlas/sql/mysql"
@@ -42,7 +43,9 @@ func Test_MigrateApply(t *testing.T) {
 			wantTarget: "20221101165415",
 		},
 	}
-	c, err := atlas.NewClient("atlas")
+	wd, err := os.Getwd()
+	r.NoError(err)
+	c, err := atlas.NewClient(context.Background(), wd, "atlas")
 	r.NoError(err)
 	for _, tt := range tests {
 		t.Run(tt.name, func(t *testing.T) {
@@ -83,7 +86,9 @@ func Test_MigrateStatus(t *testing.T) {
 			wantNext:    "20221101163823",
 		},
 	}
-	c, err := atlas.NewClient("atlas")
+	wd, err := os.Getwd()
+	r.NoError(err)
+	c, err := atlas.NewClient(context.Background(), wd, "atlas")
 	r.NoError(err)
 	for _, tt := range tests {
 		t.Run(tt.name, func(t *testing.T) {
diff --git a/internal/provider/provider.go b/internal/provider/provider.go
index b90c11e..35dc9b2 100644
--- a/internal/provider/provider.go
+++ b/internal/provider/provider.go
@@ -3,8 +3,10 @@ package provider
 import (
 	"bytes"
 	"context"
+	"fmt"
 	"os"
 	"path"
+	"runtime"
 
 	_ "ariga.io/atlas/sql/mysql"
 	_ "ariga.io/atlas/sql/postgres"
@@ -29,6 +31,8 @@ type (
 	AtlasProvider struct {
 		// client is the client used to interact with the Atlas CLI.
 		client *atlas.Client
+		// dir is the directory where the provider is installed.
+		dir string
 		// version is set to the provider version on release, "dev" when the
 		// provider is built and ran locally, and "test" when running acceptance
 		// testing.
@@ -53,9 +57,16 @@ const (
 )
 
 // New returns a new provider.
-func New(version, commit string) func() provider.Provider {
+func New(address, version, commit string) func() provider.Provider {
+	wd, err := os.Getwd()
+	if err != nil {
+		panic(err)
+	}
+	providersDir := path.Join(wd, ".terraform", "providers")
+	platform := fmt.Sprintf("%s_%s", runtime.GOOS, runtime.GOARCH)
 	return func() provider.Provider {
 		return &AtlasProvider{
+			dir:     path.Join(providersDir, address, version, platform),
 			version: version,
 		}
 	}
@@ -78,7 +89,7 @@ func (p *AtlasProvider) GetSchema(ctx context.Context) (tfsdk.Schema, diag.Diagn
 
 // Configure implements provider.Provider.
 func (p *AtlasProvider) Configure(ctx context.Context, req provider.ConfigureRequest, resp *provider.ConfigureResponse) {
-	c, err := atlas.NewClient("atlas")
+	c, err := atlas.NewClient(ctx, p.dir, "atlas")
 	if err != nil {
 		resp.Diagnostics.AddError("Failed to create client", err.Error())
 		return
@@ -106,7 +117,7 @@ func (p *AtlasProvider) Resources(ctx context.Context) []func() resource.Resourc
 
 // ConfigValidators returns a list of functions which will all be performed during validation.
 func (p *AtlasProvider) ValidateConfig(ctx context.Context, req provider.ValidateConfigRequest, resp *provider.ValidateConfigResponse) {
-	msg := checkForUpdate(ctx, p.version)()
+	msg := checkForUpdate(ctx, p.dir, p.version)()
 	if msg != "" {
 		resp.Diagnostics.AddWarning(
 			"Update Available",
@@ -118,7 +129,7 @@ func (p *AtlasProvider) ValidateConfig(ctx context.Context, req provider.Validat
 func noText() string { return "" }
 
 // checkForUpdate checks for version updates and security advisories for Atlas.
-func checkForUpdate(ctx context.Context, version string) func() string {
+func checkForUpdate(ctx context.Context, dir, version string) func() string {
 	done := make(chan struct{})
 	// Users may skip update checking behavior.
 	if v := os.Getenv(envNoUpdate); v != "" {
@@ -128,14 +139,10 @@ func checkForUpdate(ctx context.Context, version string) func() string {
 	if !semver.IsValid(version) {
 		return noText
 	}
-	curDir, err := os.Getwd()
-	if err != nil {
-		return noText
-	}
 	var message string
 	go func() {
 		defer close(done)
-		vc := vercheck.New(vercheckURL, path.Join(curDir, versionFile))
+		vc := vercheck.New(vercheckURL, path.Join(dir, versionFile))
 		payload, err := vc.Check(version)
 		if err != nil {
 			return
diff --git a/internal/provider/provider_test.go b/internal/provider/provider_test.go
index c91355c..da65f04 100644
--- a/internal/provider/provider_test.go
+++ b/internal/provider/provider_test.go
@@ -13,7 +13,7 @@ import (
 // CLI command executed to create a provider server to which the CLI can
 // reattach.
 var testAccProtoV6ProviderFactories = map[string]func() (tfprotov6.ProviderServer, error){
-	"atlas": providerserver.NewProtocol6WithError(provider.New("test", "")()),
+	"atlas": providerserver.NewProtocol6WithError(provider.New("registry.terraform.io/ariga/atlas", "test", "")()),
 }
 
 func testAccPreCheck(t *testing.T) {
diff --git a/main.go b/main.go
index 588d904..0060a86 100644
--- a/main.go
+++ b/main.go
@@ -39,7 +39,7 @@ func main() {
 		Debug:   debug,
 	}
 
-	err := providerserver.Serve(context.Background(), provider.New(version, commit), opts)
+	err := providerserver.Serve(context.Background(), provider.New(opts.Address, version, commit), opts)
 	if err != nil {
 		log.Fatal(err.Error())
 	}