diff --git a/transport/grpc/dial_test.go b/transport/grpc/dial_test.go index 91b98c7a102..e4ad07110a2 100644 --- a/transport/grpc/dial_test.go +++ b/transport/grpc/dial_test.go @@ -18,6 +18,7 @@ import ( "context" "errors" "net" + "os" "testing" "time" @@ -64,3 +65,65 @@ func TestGRPCHook(t *testing.T) { t.Error("expected a call to expected dialer, didn't get one") } } + +func TestIsDirectPathEnabled(t *testing.T) { + for _, testcase := range []struct { + name string + endpoint string + envVar string + want bool + }{ + { + name: "matches", + endpoint: "some-api", + envVar: "some-api", + want: true, + }, + { + name: "does not match", + endpoint: "some-api", + envVar: "some-other-api", + want: false, + }, + { + name: "matches in list", + endpoint: "some-api-2", + envVar: "some-api-1,some-api-2,some-api-3", + want: true, + }, + { + name: "empty env var", + endpoint: "", + envVar: "", + want: false, + }, + { + name: "trailing comma", + endpoint: "", + envVar: "foo,bar,", + want: false, + }, + { + name: "dns schemes are allowed", + endpoint: "dns:///foo", + envVar: "dns:///foo", + want: true, + }, + { + name: "non-dns schemes are disallowed", + endpoint: "https://foo", + envVar: "https://foo", + want: false, + }, + } { + t.Run(testcase.name, func(t *testing.T) { + if err := os.Setenv("GOOGLE_CLOUD_ENABLE_DIRECT_PATH", testcase.envVar); err != nil { + t.Fatal(err) + } + + if got := isDirectPathEnabled(testcase.endpoint); got != testcase.want { + t.Fatalf("got %v, want %v", got, testcase.want) + } + }) + } +}