Skip to content

Commit

Permalink
update NeedKms logic
Browse files Browse the repository at this point in the history
  • Loading branch information
qingyang-hu committed Nov 20, 2024
1 parent 12a1530 commit d4a0765
Show file tree
Hide file tree
Showing 9 changed files with 40 additions and 4 deletions.
2 changes: 1 addition & 1 deletion Taskfile.yml
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ tasks:
- go test -exec "env PKG_CONFIG_PATH=${PKG_CONFIG_PATH} LD_LIBRARY_PATH=${LD_LIBRARY_PATH}" ${BUILD_TAGS} -v -timeout {{.TEST_TIMEOUT}}s ./internal/integration -run TestClientSideEncryptionProse/kms_tls_tests >> test.suite

evg-test-retry-kms-requests:
- go test -exec "env PKG_CONFIG_PATH=${PKG_CONFIG_PATH} LD_LIBRARY_PATH=${LD_LIBRARY_PATH}" ${BUILD_TAGS} -v -timeout {{.TEST_TIMEOUT}}s ./internal/integration -run TestClientSideEncryptionProse/kms_retry_tests >> test.suite
- go test -exec "env PKG_CONFIG_PATH=${PKG_CONFIG_PATH} LD_LIBRARY_PATH=${LD_LIBRARY_PATH}" ${BUILD_TAGS} -v -timeout 300s ./internal/integration -run TestClientSideEncryptionProse/kms_retry_tests >> test.suite

evg-test-load-balancers:
# Load balancer should be tested with all unified tests as well as tests in the following
Expand Down
2 changes: 1 addition & 1 deletion etc/install-libmongocrypt.sh
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# This script installs libmongocrypt into an "install" directory.
set -eux

LIBMONGOCRYPT_TAG="1.11.0"
LIBMONGOCRYPT_TAG="1.12.0"

# Install libmongocrypt based on OS.
if [ "Windows_NT" = "${OS:-}" ]; then
Expand Down
4 changes: 2 additions & 2 deletions internal/integration/client_side_encryption_prose_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3018,7 +3018,7 @@ func TestClientSideEncryptionProse(t *testing.T) {
clientEncryption, err := mongo.NewClientEncryption(keyVaultClient, ceo)
require.NoError(mt, err, "error on NewClientEncryption: %v", err)

err = setFailPoint("http", 1)
err = setFailPoint("network", 1)
require.NoError(mt, err, "mock server error: %v", err)

dkOpts := options.DataKey().SetMasterKey(
Expand All @@ -3032,7 +3032,7 @@ func TestClientSideEncryptionProse(t *testing.T) {
keyID, err = clientEncryption.CreateDataKey(context.Background(), "aws", dkOpts)
require.NoError(mt, err, "error in CreateDataKey: %v", err)

err = setFailPoint("http", 1)
err = setFailPoint("network", 1)
require.NoError(mt, err, "mock server error: %v", err)

testVal := bson.RawValue{Type: bson.TypeInt32, Value: bsoncore.AppendInt32(nil, 123)}
Expand Down
2 changes: 2 additions & 0 deletions mongo/client_encryption.go
Original file line number Diff line number Diff line change
Expand Up @@ -188,12 +188,14 @@ func (ce *ClientEncryption) CreateDataKey(
}

// create data key document
fmt.Println("CreateDataKey")
dataKeyDoc, err := ce.crypt.CreateDataKey(ctx, kmsProvider, co)
if err != nil {
return bson.Binary{}, err
}

// insert key into key vault
fmt.Println("InsertOne")
_, err = ce.keyVaultColl.InsertOne(ctx, dataKeyDoc)
if err != nil {
return bson.Binary{}, err
Expand Down
11 changes: 11 additions & 0 deletions x/mongo/driver/crypt.go
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,7 @@ func (c *crypt) executeStateMachine(ctx context.Context, cryptCtx *mongocrypt.Co
var err error
for {
state := cryptCtx.State()
fmt.Println("state", state)
switch state {
case mongocrypt.NeedMongoCollInfo:
err = c.collectionInfo(ctx, cryptCtx, db)
Expand Down Expand Up @@ -341,6 +342,7 @@ func (c *crypt) retrieveKeys(ctx context.Context, cryptCtx *mongocrypt.Context)
}

func (c *crypt) decryptKeys(cryptCtx *mongocrypt.Context) error {
c.mongoCrypt.EnableRetry()
for {
kmsCtx := cryptCtx.NextKmsContext()
if kmsCtx == nil {
Expand Down Expand Up @@ -376,8 +378,10 @@ func (c *crypt) decryptKey(kmsCtx *mongocrypt.KmsContext) error {
if tlsCfg == nil {
tlsCfg = &tls.Config{MinVersion: tls.VersionTLS12}
}
fmt.Println("dial", addr, kmsProvider, tlsCfg)
conn, err := tls.Dial("tcp", addr, tlsCfg)
if err != nil {
fmt.Println("dial error", err)
return err
}
defer func() {
Expand All @@ -388,18 +392,25 @@ func (c *crypt) decryptKey(kmsCtx *mongocrypt.KmsContext) error {
return err
}
if _, err = conn.Write(msg); err != nil {
fmt.Println("conn write", err)
return err
}

for {
bytesNeeded := kmsCtx.BytesNeeded()
fmt.Println("bytesNeeded", bytesNeeded)
if bytesNeeded == 0 {
return nil
}

res := make([]byte, bytesNeeded)
bytesRead, err := conn.Read(res)
if err != nil && !errors.Is(err, io.EOF) {
fail := kmsCtx.Fail()
fmt.Println("conn read", err, fail)
if fail {
continue
}
return err
}

Expand Down
5 changes: 5 additions & 0 deletions x/mongo/driver/mongocrypt/mongocrypt.go
Original file line number Diff line number Diff line change
Expand Up @@ -522,3 +522,8 @@ func (m *MongoCrypt) GetKmsProviders(ctx context.Context) (bsoncore.Document, er
}
return builder.Build(), nil
}

// EnableRetry enables retry.
func (m *MongoCrypt) EnableRetry() {
_ = C.mongocrypt_setopt_retry_kms(m.wrapped, true)
}
8 changes: 8 additions & 0 deletions x/mongo/driver/mongocrypt/mongocrypt_kms_context.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ package mongocrypt

// #include <mongocrypt.h>
import "C"
import "time"

// KmsContext represents a mongocrypt_kms_ctx_t handle.
type KmsContext struct {
Expand Down Expand Up @@ -41,6 +42,8 @@ func (kc *KmsContext) KMSProvider() string {

// Message returns the message to send to the KMS.
func (kc *KmsContext) Message() ([]byte, error) {
time.Sleep(time.Duration(C.mongocrypt_kms_ctx_usleep(kc.wrapped)) * time.Microsecond)

msgBinary := newBinary()
defer msgBinary.close()

Expand Down Expand Up @@ -74,3 +77,8 @@ func (kc *KmsContext) createErrorFromStatus() error {
C.mongocrypt_kms_ctx_status(kc.wrapped, status)
return errorFromStatus(status)
}

// Fail returns a boolean indicating whether the failed request may be retried.
func (kc *KmsContext) Fail() bool {
return bool(C.mongocrypt_kms_ctx_fail(kc.wrapped))
}
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,8 @@ func (kc *KmsContext) BytesNeeded() int32 {
func (kc *KmsContext) FeedResponse([]byte) error {
panic(cseNotSupportedMsg)
}

// Fail returns a boolean indicating whether the failed request may be retried.
func (kc *KmsContext) Fail() bool {
panic(cseNotSupportedMsg)
}
5 changes: 5 additions & 0 deletions x/mongo/driver/mongocrypt/mongocrypt_not_enabled.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,3 +95,8 @@ func (m *MongoCrypt) Close() {
func (m *MongoCrypt) GetKmsProviders(context.Context) (bsoncore.Document, error) {
panic(cseNotSupportedMsg)
}

// EnableRetry enables retry.
func (m *MongoCrypt) EnableRetry() {
panic(cseNotSupportedMsg)
}

0 comments on commit d4a0765

Please sign in to comment.