diff --git a/.evergreen/config.yml b/.evergreen/config.yml index 90da768350..bf51417028 100644 --- a/.evergreen/config.yml +++ b/.evergreen/config.yml @@ -553,6 +553,25 @@ functions: KMS_MOCK_SERVERS_RUNNING: "true" args: [*task-runner, evg-test-kmip] + run-retry-kms-requests: + - command: subprocess.exec + type: test + params: + binary: "bash" + env: + GO_BUILD_TAGS: cse + include_expansions_in_env: [AUTH, SSL, MONGODB_URI, TOPOLOGY, + MONGO_GO_DRIVER_COMPRESSOR] + args: [*task-runner, setup-test] + - command: subprocess.exec + type: test + params: + binary: "bash" + env: + KMS_FAILPOINT_CA_FILE: "${DRIVERS_TOOLS}/.evergreen/x509gen/ca.pem" + KMS_FAILPOINT_SERVER_RUNNING: "true" + args: [*task-runner, evg-test-retry-kms-requests] + run-fuzz-tests: - command: subprocess.exec type: test @@ -1440,7 +1459,7 @@ tasks: SSL: "nossl" - name: "test-kms-tls-invalid-cert" - tags: ["kms-tls"] + tags: ["kms-test"] commands: - func: bootstrap-mongo-orchestration vars: @@ -1456,7 +1475,7 @@ tasks: SSL: "nossl" - name: "test-kms-tls-invalid-hostname" - tags: ["kms-tls"] + tags: ["kms-test"] commands: - func: bootstrap-mongo-orchestration vars: @@ -1486,6 +1505,17 @@ tasks: AUTH: "noauth" SSL: "nossl" + - name: "test-retry-kms-requests" + tags: ["kms-test"] + commands: + - func: bootstrap-mongo-orchestration + vars: + TOPOLOGY: "server" + AUTH: "noauth" + SSL: "nossl" + - func: start-cse-servers + - func: run-retry-kms-requests + - name: "test-serverless" tags: ["serverless"] commands: @@ -2163,11 +2193,11 @@ buildvariants: tasks: - name: ".versioned-api" - - matrix_name: "kms-tls-test" + - matrix_name: "kms-test" matrix_spec: { version: ["7.0"], os-ssl-40: ["rhel87-64"] } - display_name: "KMS TLS ${os-ssl-40}" + display_name: "KMS TEST ${os-ssl-40}" tasks: - - name: ".kms-tls" + - name: ".kms-test" - matrix_name: "load-balancer-test" tags: ["pullrequest"] diff --git a/Taskfile.yml b/Taskfile.yml index f22427a640..d93a7e4e9b 100644 --- a/Taskfile.yml +++ b/Taskfile.yml @@ -141,6 +141,9 @@ tasks: evg-test-kms: - go test -exec "env PKG_CONFIG_PATH=${PKG_CONFIG_PATH} LD_LIBRARY_PATH=${LD_LIBRARY_PATH}" ${BUILD_TAGS} -v -timeout {{.TEST_TIMEOUT}}s ./internal/integration -run TestClientSideEncryptionProse/kms_tls_tests >> test.suite + evg-test-retry-kms-requests: + - go test -exec "env PKG_CONFIG_PATH=${PKG_CONFIG_PATH} LD_LIBRARY_PATH=${LD_LIBRARY_PATH}" ${BUILD_TAGS} -v -timeout {{.TEST_TIMEOUT}}s ./internal/integration -run TestClientSideEncryptionProse/kms_retry_tests >> test.suite + evg-test-load-balancers: # Load balancer should be tested with all unified tests as well as tests in the following # components: retryable reads, retryable writes, change streams, initial DNS seedlist discovery. diff --git a/etc/install-libmongocrypt.sh b/etc/install-libmongocrypt.sh index 646721a8f7..a94d648eae 100755 --- a/etc/install-libmongocrypt.sh +++ b/etc/install-libmongocrypt.sh @@ -3,7 +3,7 @@ # This script installs libmongocrypt into an "install" directory. set -eux -LIBMONGOCRYPT_TAG="1.11.0" +LIBMONGOCRYPT_TAG="1.12.0" # Install libmongocrypt based on OS. if [ "Windows_NT" = "${OS:-}" ]; then diff --git a/internal/integration/client_side_encryption_prose_test.go b/internal/integration/client_side_encryption_prose_test.go index 0620b9ff10..81011255f4 100644 --- a/internal/integration/client_side_encryption_prose_test.go +++ b/internal/integration/client_side_encryption_prose_test.go @@ -14,6 +14,7 @@ import ( "context" "crypto/tls" "encoding/base64" + "encoding/json" "fmt" "io/ioutil" "net" @@ -30,6 +31,7 @@ import ( "go.mongodb.org/mongo-driver/v2/internal/handshake" "go.mongodb.org/mongo-driver/v2/internal/integration/mtest" "go.mongodb.org/mongo-driver/v2/internal/integtest" + "go.mongodb.org/mongo-driver/v2/internal/require" "go.mongodb.org/mongo-driver/v2/mongo" "go.mongodb.org/mongo-driver/v2/mongo/options" "go.mongodb.org/mongo-driver/v2/mongo/writeconcern" @@ -2918,7 +2920,7 @@ func TestClientSideEncryptionProse(t *testing.T) { } }) - mt.RunOpts("22. range explicit encryption applies defaults", qeRunOpts22, func(mt *mtest.T) { + mt.RunOpts("23. range explicit encryption applies defaults", qeRunOpts22, func(mt *mtest.T) { err := mt.Client.Database("keyvault").Collection("datakeys").Drop(context.Background()) assert.Nil(mt, err, "error on Drop: %v", err) @@ -2979,6 +2981,147 @@ func TestClientSideEncryptionProse(t *testing.T) { assert.Greater(t, len(payload.Data), len(payloadDefaults.Data), "the returned payload size is expected to be greater than %d", len(payloadDefaults.Data)) }) }) + + mt.RunOpts("24. kms retry tests", noClientOpts, func(mt *mtest.T) { + kmsTlsTestcase := os.Getenv("KMS_FAILPOINT_SERVER_RUNNING") + if kmsTlsTestcase == "" { + mt.Skipf("Skipping test as KMS_FAILPOINT_SERVER_RUNNING is not set") + } + + mt.Parallel() + + tlsCAFile := os.Getenv("KMS_FAILPOINT_CA_FILE") + require.NotEqual(mt, tlsCAFile, "", "failed to load CA file") + + clientAndCATlsMap := map[string]interface{}{ + "tlsCAFile": tlsCAFile, + } + tlsCfg, err := options.BuildTLSConfig(clientAndCATlsMap) + require.NoError(mt, err, "BuildTLSConfig error: %v", err) + + setFailPoint := func(failure string, count int) error { + url := fmt.Sprintf("https://localhost:9003/set_failpoint/%s", failure) + var payloadBuf bytes.Buffer + body := map[string]int{"count": count} + json.NewEncoder(&payloadBuf).Encode(body) + req, err := http.NewRequest(http.MethodPost, url, &payloadBuf) + if err != nil { + return err + } + + client := &http.Client{ + Transport: &http.Transport{TLSClientConfig: tlsCfg}, + } + res, err := client.Do(req) + if err != nil { + return err + } + return res.Body.Close() + } + + kmsProviders := map[string]map[string]interface{}{ + "aws": { + "accessKeyId": awsAccessKeyID, + "secretAccessKey": awsSecretAccessKey, + }, + "azure": { + "tenantId": azureTenantID, + "clientId": azureClientID, + "clientSecret": azureClientSecret, + "identityPlatformEndpoint": "127.0.0.1:9003", + }, + "gcp": { + "email": gcpEmail, + "privateKey": gcpPrivateKey, + "endpoint": "127.0.0.1:9003", + }, + } + + dataKeys := []struct { + provider string + masterKey interface{} + }{ + {"aws", bson.D{ + {"region", "foo"}, + {"key", "bar"}, + {"endpoint", "127.0.0.1:9003"}, + }}, + {"azure", bson.D{ + {"keyVaultEndpoint", "127.0.0.1:9003"}, + {"keyName", "foo"}, + }}, + {"gcp", bson.D{ + {"projectId", "foo"}, + {"location", "bar"}, + {"keyRing", "baz"}, + {"keyName", "qux"}, + {"endpoint", "127.0.0.1:9003"}, + }}, + } + + testCases := []struct { + name string + failure string + }{ + {"Case 1: createDataKey and encrypt with TCP retry", "network"}, + {"Case 2: createDataKey and encrypt with HTTP retry", "http"}, + } + + for _, tc := range testCases { + for _, dataKey := range dataKeys { + mt.Run(fmt.Sprintf("%s_%s", tc.name, dataKey.provider), func(mt *mtest.T) { + keyVaultClient, err := mongo.Connect(options.Client().ApplyURI(mtest.ClusterURI())) + require.NoError(mt, err, "error on Connect: %v", err) + + ceo := options.ClientEncryption(). + SetKeyVaultNamespace(kvNamespace). + SetKmsProviders(kmsProviders). + SetTLSConfig(map[string]*tls.Config{dataKey.provider: tlsCfg}) + clientEncryption, err := mongo.NewClientEncryption(keyVaultClient, ceo) + require.NoError(mt, err, "error on NewClientEncryption: %v", err) + + err = setFailPoint(tc.failure, 1) + require.NoError(mt, err, "mock server error: %v", err) + + dkOpts := options.DataKey().SetMasterKey(dataKey.masterKey) + var keyID bson.Binary + keyID, err = clientEncryption.CreateDataKey(context.Background(), dataKey.provider, dkOpts) + require.NoError(mt, err, "error in CreateDataKey: %v", err) + + err = setFailPoint(tc.failure, 1) + require.NoError(mt, err, "mock server error: %v", err) + + testVal := bson.RawValue{Type: bson.TypeInt32, Value: bsoncore.AppendInt32(nil, 123)} + eo := options.Encrypt(). + SetKeyID(keyID). + SetAlgorithm("AEAD_AES_256_CBC_HMAC_SHA_512-Deterministic") + _, err = clientEncryption.Encrypt(context.Background(), testVal, eo) + require.NoError(mt, err, "error in Encrypt: %v", err) + }) + } + } + + for _, dataKey := range dataKeys { + mt.Run(fmt.Sprintf("Case 3: createDataKey fails after too many retries_%s", dataKey.provider), func(mt *mtest.T) { + keyVaultClient, err := mongo.Connect(options.Client().ApplyURI(mtest.ClusterURI())) + require.NoError(mt, err, "error on Connect: %v", err) + + ceo := options.ClientEncryption(). + SetKeyVaultNamespace(kvNamespace). + SetKmsProviders(kmsProviders). + SetTLSConfig(map[string]*tls.Config{dataKey.provider: tlsCfg}) + clientEncryption, err := mongo.NewClientEncryption(keyVaultClient, ceo) + require.NoError(mt, err, "error on NewClientEncryption: %v", err) + + err = setFailPoint("network", 4) + require.NoError(mt, err, "mock server error: %v", err) + + dkOpts := options.DataKey().SetMasterKey(dataKey.masterKey) + _, err = clientEncryption.CreateDataKey(context.Background(), dataKey.provider, dkOpts) + require.ErrorContains(mt, err, "KMS request failed after 3 retries due to a network error") + }) + } + }) } func getWatcher(mt *mtest.T, streamType mongo.StreamType, cpt *cseProseTest) watcher { diff --git a/x/mongo/driver/crypt.go b/x/mongo/driver/crypt.go index 4368fd125d..7a51b7a4b9 100644 --- a/x/mongo/driver/crypt.go +++ b/x/mongo/driver/crypt.go @@ -9,9 +9,7 @@ package driver import ( "context" "crypto/tls" - "errors" "fmt" - "io" "strings" "time" @@ -399,8 +397,8 @@ func (c *crypt) decryptKey(kmsCtx *mongocrypt.KmsContext) error { res := make([]byte, bytesNeeded) bytesRead, err := conn.Read(res) - if err != nil && !errors.Is(err, io.EOF) { - return err + if err != nil { + return kmsCtx.RequestError() } if err = kmsCtx.FeedResponse(res[:bytesRead]); err != nil { diff --git a/x/mongo/driver/mongocrypt/mongocrypt.go b/x/mongo/driver/mongocrypt/mongocrypt.go index 5f34f5cd71..7f7c3e8fc9 100644 --- a/x/mongo/driver/mongocrypt/mongocrypt.go +++ b/x/mongo/driver/mongocrypt/mongocrypt.go @@ -53,6 +53,7 @@ func NewMongoCrypt(opts *options.MongoCryptOptions) (*MongoCrypt, error) { if wrapped == nil { return nil, errors.New("could not create new mongocrypt object") } + C.mongocrypt_setopt_retry_kms(wrapped, true) httpClient := opts.HTTPClient if httpClient == nil { httpClient = httputil.DefaultHTTPClient @@ -85,7 +86,7 @@ func NewMongoCrypt(opts *options.MongoCryptOptions) (*MongoCrypt, error) { } if opts.BypassQueryAnalysis { - C.mongocrypt_setopt_bypass_query_analysis(wrapped) + C.mongocrypt_setopt_bypass_query_analysis(crypt.wrapped) } // If loading the crypt_shared library isn't disabled, set the default library search path "$SYSTEM" diff --git a/x/mongo/driver/mongocrypt/mongocrypt_kms_context.go b/x/mongo/driver/mongocrypt/mongocrypt_kms_context.go index 296a22315c..49baa37f2e 100644 --- a/x/mongo/driver/mongocrypt/mongocrypt_kms_context.go +++ b/x/mongo/driver/mongocrypt/mongocrypt_kms_context.go @@ -11,6 +11,7 @@ package mongocrypt // #include import "C" +import "time" // KmsContext represents a mongocrypt_kms_ctx_t handle. type KmsContext struct { @@ -41,6 +42,8 @@ func (kc *KmsContext) KMSProvider() string { // Message returns the message to send to the KMS. func (kc *KmsContext) Message() ([]byte, error) { + time.Sleep(time.Duration(C.mongocrypt_kms_ctx_usleep(kc.wrapped)) * time.Microsecond) + msgBinary := newBinary() defer msgBinary.close() @@ -74,3 +77,11 @@ func (kc *KmsContext) createErrorFromStatus() error { C.mongocrypt_kms_ctx_status(kc.wrapped, status) return errorFromStatus(status) } + +// RequestError returns the source of the network error for KMS requests. +func (kc *KmsContext) RequestError() error { + if bool(C.mongocrypt_kms_ctx_fail(kc.wrapped)) { + return nil + } + return kc.createErrorFromStatus() +} diff --git a/x/mongo/driver/mongocrypt/mongocrypt_kms_context_not_enabled.go b/x/mongo/driver/mongocrypt/mongocrypt_kms_context_not_enabled.go index 6bce2f0299..7968897648 100644 --- a/x/mongo/driver/mongocrypt/mongocrypt_kms_context_not_enabled.go +++ b/x/mongo/driver/mongocrypt/mongocrypt_kms_context_not_enabled.go @@ -37,3 +37,8 @@ func (kc *KmsContext) BytesNeeded() int32 { func (kc *KmsContext) FeedResponse([]byte) error { panic(cseNotSupportedMsg) } + +// RequestError returns the source of the network error for KMS requests. +func (kc *KmsContext) RequestError() error { + panic(cseNotSupportedMsg) +}