diff --git a/.evergreen/config.yml b/.evergreen/config.yml index b078af8066..d2d382cf17 100644 --- a/.evergreen/config.yml +++ b/.evergreen/config.yml @@ -350,6 +350,23 @@ functions: chmod +x $i done + assume-ec2-role: + - command: ec2.assume_role + params: + role_arn: ${aws_test_secrets_role} + + run-oidc-auth-test-with-test-credentials: + - command: shell.exec + type: test + params: + working_dir: src/go.mongodb.org/mongo-driver + shell: bash + include_expansions_in_env: ["DRIVERS_TOOLS", "AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY", "AWS_SESSION_TOKEN"] + script: | + ${PREPARE_SHELL} + export OIDC="oidc" + bash ${PROJECT_DIRECTORY}/etc/run-oidc-test.sh 'make -s evg-test-oidc-auth' + run-make: - command: shell.exec type: test @@ -560,8 +577,6 @@ functions: working_dir: src/go.mongodb.org/mongo-driver script: | ${PREPARE_SHELL} - - IS_SERVERLESS_PROXY="${IS_SERVERLESS_PROXY}" \ bash etc/run-serverless-test.sh run-atlas-data-lake-test: @@ -1954,6 +1969,60 @@ tasks: popd ./.evergreen/run-deployed-lambda-aws-tests.sh + - name: "oidc-auth-test-latest" + commands: + - func: "run-oidc-auth-test-with-test-credentials" + + - name: "oidc-auth-test-azure-latest" + commands: + - command: shell.exec + params: + working_dir: src/go.mongodb.org/mongo-driver + shell: bash + script: |- + set -o errexit + ${PREPARE_SHELL} + export AZUREOIDC_DRIVERS_TAR_FILE=/tmp/mongo-go-driver.tar.gz + # we need to statically link libc to avoid the situation where the VM has a different + # version of libc + go build -tags osusergo,netgo -ldflags '-w -extldflags "-static -lgcc -lc"' -o test ./cmd/testoidcauth/main.go + rm "$AZUREOIDC_DRIVERS_TAR_FILE" || true + tar -cf $AZUREOIDC_DRIVERS_TAR_FILE ./test + tar -uf $AZUREOIDC_DRIVERS_TAR_FILE ./etc + rm "$AZUREOIDC_DRIVERS_TAR_FILE".gz || true + gzip $AZUREOIDC_DRIVERS_TAR_FILE + export AZUREOIDC_DRIVERS_TAR_FILE=/tmp/mongo-go-driver.tar.gz + # Define the command to run on the azure VM. + # Ensure that we source the environment file created for us, set up any other variables we need, + # and then run our test suite on the vm. + export AZUREOIDC_TEST_CMD="PROJECT_DIRECTORY='.' OIDC_ENV=azure OIDC=oidc ./etc/run-oidc-test.sh ./test" + bash $DRIVERS_TOOLS/.evergreen/auth_oidc/azure/run-driver-test.sh + + - name: "oidc-auth-test-gcp-latest" + commands: + - command: shell.exec + params: + working_dir: src/go.mongodb.org/mongo-driver + shell: bash + script: |- + set -o errexit + ${PREPARE_SHELL} + export GCPOIDC_DRIVERS_TAR_FILE=/tmp/mongo-go-driver.tar.gz + # we need to statically link libc to avoid the situation where the VM has a different + # version of libc + go build -tags osusergo,netgo -ldflags '-w -extldflags "-static -lgcc -lc"' -o test ./cmd/testoidcauth/main.go + rm "$GCPOIDC_DRIVERS_TAR_FILE" || true + tar -cf $GCPOIDC_DRIVERS_TAR_FILE ./test + tar -uf $GCPOIDC_DRIVERS_TAR_FILE ./etc + rm "$GCPOIDC_DRIVERS_TAR_FILE".gz || true + gzip $GCPOIDC_DRIVERS_TAR_FILE + export GCPOIDC_DRIVERS_TAR_FILE=/tmp/mongo-go-driver.tar.gz + # Define the command to run on the gcp VM. + # Ensure that we source the environment file created for us, set up any other variables we need, + # and then run our test suite on the vm. + export GCPOIDC_TEST_CMD="PROJECT_DIRECTORY='.' 
OIDC_ENV=gcp OIDC=oidc ./etc/run-oidc-test.sh ./test" + bash $DRIVERS_TOOLS/.evergreen/auth_oidc/gcp/run-driver-test.sh + - name: "test-search-index" commands: - func: "bootstrap-mongo-orchestration" @@ -2014,7 +2083,7 @@ axes: - id: "windows-64-go-1-20" display_name: "Windows 64-bit" run_on: - - windows-vsCurrent-latest-small + - windows-vsCurrent-small variables: GCC_PATH: "/cygdrive/c/ProgramData/chocolatey/lib/mingw/tools/install/mingw64/bin" GO_DIST: "C:\\golang\\go1.20" @@ -2038,7 +2107,7 @@ axes: - id: "windows-64-go-1-20" display_name: "Windows 64-bit" run_on: - - windows-vsCurrent-latest-small + - windows-vsCurrent-small variables: GCC_PATH: "/cygdrive/c/ProgramData/chocolatey/lib/mingw/tools/install/mingw64/bin" GO_DIST: "C:\\golang\\go1.20" @@ -2070,7 +2139,7 @@ axes: - id: "windows-64-vsCurrent-latest-small-go-1-20" display_name: "Windows 64-bit" run_on: - - windows-vsCurrent-latest-small + - windows-vsCurrent-small variables: GCC_PATH: "/cygdrive/c/ProgramData/chocolatey/lib/mingw/tools/install/mingw64/bin" GO_DIST: "C:\\golang\\go1.20" @@ -2108,17 +2177,6 @@ axes: variables: GO_DIST: "/opt/golang/go1.20" - - id: serverless-type - display_name: "Serverless Type" - values: - - id: "original" - display_name: "Serverless" - - id: "proxy" - display_name: "Serverless Proxy" - variables: - VAULT_NAME: "serverless_next" - IS_SERVERLESS_PROXY: "true" - task_groups: - name: serverless_task_group setup_group_can_fail_task: true @@ -2247,6 +2305,79 @@ task_groups: tasks: - testazurekms-task + - name: testoidc_task_group + setup_group: + - func: fetch-source + - func: prepare-resources + - func: fix-absolute-paths + - func: make-files-executable + - func: assume-ec2-role + - command: shell.exec + params: + shell: bash + include_expansions_in_env: ["AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY", "AWS_SESSION_TOKEN"] + script: | + ${PREPARE_SHELL} + ${DRIVERS_TOOLS}/.evergreen/auth_oidc/setup.sh + teardown_task: + - command: subprocess.exec + params: + binary: bash + args: + - ${DRIVERS_TOOLS}/.evergreen/auth_oidc/teardown.sh + setup_group_can_fail_task: true + setup_group_timeout_secs: 1800 + tasks: + - oidc-auth-test-latest + + - name: testazureoidc_task_group + setup_group: + - func: fetch-source + - func: prepare-resources + - func: fix-absolute-paths + - func: make-files-executable + - command: subprocess.exec + params: + binary: bash + env: + AZUREOIDC_VMNAME_PREFIX: "GO_DRIVER" + args: + - ${DRIVERS_TOOLS}/.evergreen/auth_oidc/azure/create-and-setup-vm.sh + teardown_task: + - command: subprocess.exec + params: + binary: bash + args: + - ${DRIVERS_TOOLS}/.evergreen/auth_oidc/azure/delete-vm.sh + setup_group_can_fail_task: true + setup_group_timeout_secs: 1800 + tasks: + - oidc-auth-test-azure-latest + + - name: testgcpoidc_task_group + setup_group: + - func: fetch-source + - func: prepare-resources + - func: fix-absolute-paths + - func: make-files-executable + - command: subprocess.exec + params: + binary: bash + env: + AZUREOIDC_VMNAME_PREFIX: "GO_DRIVER" + args: + - ${DRIVERS_TOOLS}/.evergreen/auth_oidc/gcp/setup.sh + teardown_task: + - command: subprocess.exec + params: + binary: bash + args: + - ${DRIVERS_TOOLS}/.evergreen/auth_oidc/gcp/teardown.sh + setup_group_can_fail_task: true + setup_group_timeout_secs: 1800 + tasks: + - oidc-auth-test-gcp-latest + - name: test-aws-lambda-task-group setup_group: - func: fetch-source @@ -2391,23 +2522,48 @@ buildvariants: tasks: - name: "test-docker-runner" - - matrix_name: "tests-36-with-zlib-support" + - matrix_name: 
"tests-rhel-36-with-zlib-support" tags: ["pullrequest"] - matrix_spec: { version: ["3.6"], os-ssl-32: ["windows-64-go-1-20", "rhel87-64-go-1-20"] } + matrix_spec: { version: ["3.6"], os-ssl-32: ["rhel87-64-go-1-20"] } display_name: "${version} ${os-ssl-32}" tasks: - name: ".test !.enterprise-auth !.snappy !.zstd" - - matrix_name: "tests-40-with-zlib-support" + - matrix_name: "tests-windows-36-with-zlib-support" + matrix_spec: { version: ["3.6"], os-ssl-32: ["windows-64-go-1-20"] } + display_name: "${version} ${os-ssl-32}" + tasks: + - name: ".test !.enterprise-auth !.snappy !.zstd" + + - matrix_name: "tests-rhel-40-with-zlib-support" tags: ["pullrequest"] - matrix_spec: { version: ["4.0"], os-ssl-40: ["windows-64-go-1-20", "rhel87-64-go-1-20"] } + matrix_spec: { version: ["4.0"], os-ssl-40: ["rhel87-64-go-1-20"] } + display_name: "${version} ${os-ssl-40}" + tasks: + - name: ".test !.enterprise-auth !.snappy !.zstd" + + - matrix_name: "tests-windows-40-with-zlib-support" + matrix_spec: { version: ["4.0"], os-ssl-40: ["windows-64-go-1-20"] } display_name: "${version} ${os-ssl-40}" tasks: - name: ".test !.enterprise-auth !.snappy !.zstd" - - matrix_name: "tests-42-plus-zlib-zstd-support" + - matrix_name: "tests-rhel-42-plus-zlib-zstd-support" + tags: ["pullrequest"] + matrix_spec: { version: ["4.2", "4.4", "5.0", "6.0", "7.0", "8.0"], os-ssl-40: ["rhel87-64-go-1-20"] } + display_name: "${version} ${os-ssl-40}" + tasks: + - name: ".test !.enterprise-auth !.snappy" + + - matrix_name: "tests-windows-42-plus-zlib-zstd-support" + matrix_spec: { version: ["4.2", "4.4", "5.0", "6.0", "7.0"], os-ssl-40: ["windows-64-go-1-20"] } + display_name: "${version} ${os-ssl-40}" + tasks: + - name: ".test !.enterprise-auth !.snappy" + + - matrix_name: "tests-windows-80-zlib-zstd-support" tags: ["pullrequest"] - matrix_spec: { version: ["4.2", "4.4", "5.0", "6.0", "7.0", "8.0"], os-ssl-40: ["windows-64-go-1-20", "rhel87-64-go-1-20"] } + matrix_spec: { version: ["8.0"], os-ssl-40: ["windows-64-go-1-20"] } display_name: "${version} ${os-ssl-40}" tasks: - name: ".test !.enterprise-auth !.snappy" @@ -2494,14 +2650,8 @@ buildvariants: - matrix_name: "serverless" tags: ["pullrequest"] - matrix_spec: { os-serverless: "*", serverless-type: "original" } - display_name: "${serverless-type} ${os-serverless}" - tasks: - - "serverless_task_group" - - - matrix_name: "serverless-proxy" - matrix_spec: { os-serverless: "*", serverless-type: "proxy" } - display_name: "${serverless-type} ${os-serverless}" + matrix_spec: { os-serverless: "*" } + display_name: "Serverless ${os-serverless}" tasks: - "serverless_task_group" @@ -2561,3 +2711,17 @@ buildvariants: - name: testazurekms_task_group batchtime: 20160 # Use a batchtime of 14 days as suggested by the CSFLE test README - testazurekms-fail-task + + - name: testoidc-variant + display_name: "OIDC" + run_on: + - ubuntu2204-large + expansions: + GO_DIST: "/opt/golang/go1.20" + tasks: + - name: testoidc_task_group + batchtime: 20160 # Use a batchtime of 14 days as suggested by the CSFLE test README + - name: testazureoidc_task_group + batchtime: 20160 # Use a batchtime of 14 days as suggested by the CSFLE test README + - name: testgcpoidc_task_group + batchtime: 20160 # Use a batchtime of 14 days as suggested by the CSFLE test README diff --git a/.github/workflows/codeql.yml b/.github/workflows/codeql.yml new file mode 100644 index 0000000000..21c81a32fa --- /dev/null +++ b/.github/workflows/codeql.yml @@ -0,0 +1,51 @@ +name: "CodeQL" + +on: + push: + branches: [ "v1", "cloud-*", 
"master", "release/*" ] + pull_request: + branches: [ "v1", "cloud-*", "master", "release/*" ] + schedule: + - cron: '36 17 * * 0' + workflow_call: + inputs: + ref: + required: true + type: string + +jobs: + analyze: + name: Analyze (${{ matrix.language }}) + runs-on: ${{ (matrix.language == 'swift' && 'macos-latest') || 'ubuntu-latest' }} + timeout-minutes: ${{ (matrix.language == 'swift' && 120) || 360 }} + permissions: + # required for all workflows + security-events: write + + strategy: + fail-fast: false + matrix: + include: + - language: go + build-mode: manual + + steps: + - name: Checkout repository + uses: actions/checkout@v4 + + # Initializes the CodeQL tools for scanning. + - name: Initialize CodeQL + uses: github/codeql-action/init@v3 + with: + languages: ${{ matrix.language }} + build-mode: ${{ matrix.build-mode }} + + - if: matrix.build-mode == 'manual' + shell: bash + run: | + make build + + - name: Perform CodeQL Analysis + uses: github/codeql-action/analyze@v3 + with: + category: "/language:${{matrix.language}}" diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml new file mode 100644 index 0000000000..0f4d446237 --- /dev/null +++ b/.github/workflows/release.yml @@ -0,0 +1,88 @@ +name: Release + +on: + workflow_dispatch: + inputs: + version: + description: "The new version to set" + required: true + prev_version: + description: "The previous tagged version" + required: true + push_changes: + description: "Push changes?" + default: true + type: boolean + +defaults: + run: + shell: bash -eux {0} + +env: + # Changes per branch + SILK_ASSET_GROUP: mongodb-go-driver-v1 + EVERGREEN_PROJECT: mongo-go-driver-v1 + +jobs: + pre-publish: + environment: release + runs-on: ubuntu-latest + permissions: + id-token: write + contents: write + outputs: + prev_version: ${{ steps.pre-publish.outputs.prev_version }} + steps: + - uses: mongodb-labs/drivers-github-tools/secure-checkout@v2 + with: + app_id: ${{ vars.APP_ID }} + private_key: ${{ secrets.APP_PRIVATE_KEY }} + - uses: mongodb-labs/drivers-github-tools/setup@v2 + with: + aws_role_arn: ${{ secrets.AWS_ROLE_ARN }} + aws_region_name: ${{ vars.AWS_REGION_NAME }} + aws_secret_id: ${{ secrets.AWS_SECRET_ID }} + artifactory_username: ${{ vars.ARTIFACTORY_USERNAME }} + - name: Pre Publish + id: pre-publish + uses: mongodb-labs/drivers-github-tools/golang/pre-publish@v2 + with: + version: ${{ inputs.version }} + push_changes: ${{ inputs.push_changes }} + + static-scan: + needs: [pre-publish] + permissions: + security-events: write + uses: ./.github/workflows/codeql.yml + with: + ref: ${{ github.ref }} + + publish: + needs: [pre-publish, static-scan] + runs-on: ubuntu-latest + environment: release + permissions: + id-token: write + contents: write + security-events: read + steps: + - uses: mongodb-labs/drivers-github-tools/secure-checkout@v2 + with: + app_id: ${{ vars.APP_ID }} + private_key: ${{ secrets.APP_PRIVATE_KEY }} + - uses: mongodb-labs/drivers-github-tools/setup@v2 + with: + aws_role_arn: ${{ secrets.AWS_ROLE_ARN }} + aws_region_name: ${{ vars.AWS_REGION_NAME }} + aws_secret_id: ${{ secrets.AWS_SECRET_ID }} + artifactory_username: ${{ vars.ARTIFACTORY_USERNAME }} + - name: Publish + uses: mongodb-labs/drivers-github-tools/golang/publish@v2 + with: + version: ${{ inputs.version }} + silk_asset_group: ${{ env.SILK_ASSET_GROUP }} + evergreen_project: ${{ env.EVERGREEN_PROJECT }} + prev_version: ${{ inputs.prev_version }} + push_changes: ${{ inputs.push_changes }} + token: ${{ env.GH_TOKEN }} diff --git a/Makefile 
b/Makefile index 88bc756390..b38bb4b6f0 100644 --- a/Makefile +++ b/Makefile @@ -132,6 +132,11 @@ evg-test-atlas-data-lake: evg-test-enterprise-auth: go run -tags gssapi ./cmd/testentauth/main.go +.PHONY: evg-test-oidc-auth +evg-test-oidc-auth: + go run ./cmd/testoidcauth/main.go + go run -race ./cmd/testoidcauth/main.go + .PHONY: evg-test-kmip evg-test-kmip: go test -exec "env PKG_CONFIG_PATH=$(PKG_CONFIG_PATH) LD_LIBRARY_PATH=$(LD_LIBRARY_PATH) DYLD_LIBRARY_PATH=$(MACOS_LIBRARY_PATH)" $(BUILD_TAGS) -v -timeout $(TEST_TIMEOUT)s ./mongo/integration -run TestClientSideEncryptionSpec/kmipKMS >> test.suite diff --git a/bson/bsoncodec/default_value_decoders.go b/bson/bsoncodec/default_value_decoders.go index 7e08aab35e..fc4a7b1dbf 100644 --- a/bson/bsoncodec/default_value_decoders.go +++ b/bson/bsoncodec/default_value_decoders.go @@ -330,7 +330,7 @@ func (DefaultValueDecoders) intDecodeType(dc DecodeContext, vr bsonrw.ValueReade case reflect.Int64: return reflect.ValueOf(i64), nil case reflect.Int: - if int64(int(i64)) != i64 { // Can we fit this inside of an int + if i64 > math.MaxInt { // Can we fit this inside of an int return emptyValue, fmt.Errorf("%d overflows int", i64) } @@ -434,7 +434,7 @@ func (dvd DefaultValueDecoders) UintDecodeValue(dc DecodeContext, vr bsonrw.Valu return fmt.Errorf("%d overflows uint64", i64) } case reflect.Uint: - if i64 < 0 || int64(uint(i64)) != i64 { // Can we fit this inside of an uint + if i64 < 0 || uint64(i64) > uint64(math.MaxUint) { // Can we fit this inside of an uint return fmt.Errorf("%d overflows uint", i64) } default: diff --git a/bson/bsoncodec/uint_codec.go b/bson/bsoncodec/uint_codec.go index 8525472769..39b07135b1 100644 --- a/bson/bsoncodec/uint_codec.go +++ b/bson/bsoncodec/uint_codec.go @@ -164,11 +164,15 @@ func (uic *UIntCodec) decodeType(dc DecodeContext, vr bsonrw.ValueReader, t refl return reflect.ValueOf(uint64(i64)), nil case reflect.Uint: - if i64 < 0 || int64(uint(i64)) != i64 { // Can we fit this inside of an uint + if i64 < 0 { + return emptyValue, fmt.Errorf("%d overflows uint", i64) + } + v := uint64(i64) + if v > math.MaxUint { // Can we fit this inside of an uint return emptyValue, fmt.Errorf("%d overflows uint", i64) } - return reflect.ValueOf(uint(i64)), nil + return reflect.ValueOf(uint(v)), nil default: return emptyValue, ValueDecoderError{ Name: "UintDecodeValue", diff --git a/bson/bsonrw/extjson_wrappers.go b/bson/bsonrw/extjson_wrappers.go index 9695704246..af6ae7b76b 100644 --- a/bson/bsonrw/extjson_wrappers.go +++ b/bson/bsonrw/extjson_wrappers.go @@ -95,9 +95,9 @@ func (ejv *extJSONValue) parseBinary() (b []byte, subType byte, err error) { return nil, 0, fmt.Errorf("$binary subType value should be string, but instead is %s", val.t) } - i, err := strconv.ParseInt(val.v.(string), 16, 64) + i, err := strconv.ParseUint(val.v.(string), 16, 8) if err != nil { - return nil, 0, fmt.Errorf("invalid $binary subType string: %s", val.v.(string)) + return nil, 0, fmt.Errorf("invalid $binary subType string: %q: %w", val.v.(string), err) } subType = byte(i) diff --git a/bson/bsonrw/value_reader.go b/bson/bsonrw/value_reader.go index a242bb57cf..0e07d50558 100644 --- a/bson/bsonrw/value_reader.go +++ b/bson/bsonrw/value_reader.go @@ -842,7 +842,7 @@ func (vr *valueReader) peekLength() (int32, error) { } idx := vr.offset - return (int32(vr.d[idx]) | int32(vr.d[idx+1])<<8 | int32(vr.d[idx+2])<<16 | int32(vr.d[idx+3])<<24), nil + return int32(binary.LittleEndian.Uint32(vr.d[idx:])), nil } func (vr *valueReader) readLength() (int32, 
error) { return vr.readi32() } @@ -854,7 +854,7 @@ func (vr *valueReader) readi32() (int32, error) { idx := vr.offset vr.offset += 4 - return (int32(vr.d[idx]) | int32(vr.d[idx+1])<<8 | int32(vr.d[idx+2])<<16 | int32(vr.d[idx+3])<<24), nil + return int32(binary.LittleEndian.Uint32(vr.d[idx:])), nil } func (vr *valueReader) readu32() (uint32, error) { @@ -864,7 +864,7 @@ func (vr *valueReader) readu32() (uint32, error) { idx := vr.offset vr.offset += 4 - return (uint32(vr.d[idx]) | uint32(vr.d[idx+1])<<8 | uint32(vr.d[idx+2])<<16 | uint32(vr.d[idx+3])<<24), nil + return binary.LittleEndian.Uint32(vr.d[idx:]), nil } func (vr *valueReader) readi64() (int64, error) { @@ -874,8 +874,7 @@ func (vr *valueReader) readi64() (int64, error) { idx := vr.offset vr.offset += 8 - return int64(vr.d[idx]) | int64(vr.d[idx+1])<<8 | int64(vr.d[idx+2])<<16 | int64(vr.d[idx+3])<<24 | - int64(vr.d[idx+4])<<32 | int64(vr.d[idx+5])<<40 | int64(vr.d[idx+6])<<48 | int64(vr.d[idx+7])<<56, nil + return int64(binary.LittleEndian.Uint64(vr.d[idx:])), nil } func (vr *valueReader) readu64() (uint64, error) { @@ -885,6 +884,5 @@ func (vr *valueReader) readu64() (uint64, error) { idx := vr.offset vr.offset += 8 - return uint64(vr.d[idx]) | uint64(vr.d[idx+1])<<8 | uint64(vr.d[idx+2])<<16 | uint64(vr.d[idx+3])<<24 | - uint64(vr.d[idx+4])<<32 | uint64(vr.d[idx+5])<<40 | uint64(vr.d[idx+6])<<48 | uint64(vr.d[idx+7])<<56, nil + return binary.LittleEndian.Uint64(vr.d[idx:]), nil } diff --git a/bson/raw_value.go b/bson/raw_value.go index 4d1bfb3160..a8088e1e30 100644 --- a/bson/raw_value.go +++ b/bson/raw_value.go @@ -88,8 +88,12 @@ func (rv RawValue) UnmarshalWithRegistry(r *bsoncodec.Registry, val interface{}) return dec.DecodeValue(bsoncodec.DecodeContext{Registry: r}, vr, rval) } -// UnmarshalWithContext performs the same unmarshalling as Unmarshal but uses the provided DecodeContext -// instead of the one attached or the default registry. +// UnmarshalWithContext performs the same unmarshalling as Unmarshal but uses +// the provided DecodeContext instead of the one attached or the default +// registry. +// +// Deprecated: Use [RawValue.UnmarshalWithRegistry] with a custom registry to customize +// unmarshal behavior instead. func (rv RawValue) UnmarshalWithContext(dc *bsoncodec.DecodeContext, val interface{}) error { if dc == nil { return ErrNilContext diff --git a/bson/registry.go b/bson/registry.go index b5b0f35687..d6afb2850e 100644 --- a/bson/registry.go +++ b/bson/registry.go @@ -10,15 +10,27 @@ import ( "go.mongodb.org/mongo-driver/bson/bsoncodec" ) -// DefaultRegistry is the default bsoncodec.Registry. It contains the default codecs and the -// primitive codecs. +// DefaultRegistry is the default bsoncodec.Registry. It contains the default +// codecs and the primitive codecs. +// +// Deprecated: Use [NewRegistry] to construct a new default registry. To use a +// custom registry when marshaling or unmarshaling, use the "SetRegistry" method +// on an [Encoder] or [Decoder] instead: +// +// dec, err := bson.NewDecoder(bsonrw.NewBSONDocumentReader(data)) +// if err != nil { +// panic(err) +// } +// dec.SetRegistry(reg) +// +// See [Encoder] and [Decoder] for more examples. var DefaultRegistry = NewRegistry() // NewRegistryBuilder creates a new RegistryBuilder configured with the default encoders and // decoders from the bsoncodec.DefaultValueEncoders and bsoncodec.DefaultValueDecoders types and the // PrimitiveCodecs type in this package. // -// Deprecated: Use NewRegistry instead. 
+// Deprecated: Use [NewRegistry] instead. func NewRegistryBuilder() *bsoncodec.RegistryBuilder { rb := bsoncodec.NewRegistryBuilder() bsoncodec.DefaultValueEncoders{}.RegisterDefaultEncoders(rb) diff --git a/cmd/testatlas/main.go b/cmd/testatlas/atlas_test.go similarity index 82% rename from cmd/testatlas/main.go rename to cmd/testatlas/atlas_test.go index ae1b15fcbc..1b60c64769 100644 --- a/cmd/testatlas/main.go +++ b/cmd/testatlas/atlas_test.go @@ -11,6 +11,8 @@ import ( "errors" "flag" "fmt" + "os" + "testing" "time" "go.mongodb.org/mongo-driver/bson" @@ -19,15 +21,19 @@ import ( "go.mongodb.org/mongo-driver/mongo/options" ) -func main() { +func TestMain(m *testing.M) { flag.Parse() + os.Exit(m.Run()) +} + +func TestAtlas(t *testing.T) { uris := flag.Args() ctx := context.Background() - fmt.Printf("Running atlas tests for %d uris\n", len(uris)) + t.Logf("Running atlas tests for %d uris\n", len(uris)) for idx, uri := range uris { - fmt.Printf("Running test %d\n", idx) + t.Logf("Running test %d\n", idx) // Set a low server selection timeout so we fail fast if there are errors. clientOpts := options.Client(). @@ -36,18 +42,18 @@ func main() { // Run basic connectivity test. if err := runTest(ctx, clientOpts); err != nil { - panic(fmt.Sprintf("error running test with TLS at index %d: %v", idx, err)) + t.Fatalf("error running test with TLS at index %d: %v", idx, err) } // Run the connectivity test with InsecureSkipVerify to ensure SNI is done correctly even if verification is // disabled. clientOpts.TLSConfig.InsecureSkipVerify = true if err := runTest(ctx, clientOpts); err != nil { - panic(fmt.Sprintf("error running test with tlsInsecure at index %d: %v", idx, err)) + t.Fatalf("error running test with tlsInsecure at index %d: %v", idx, err) } } - fmt.Println("Finished!") + t.Logf("Finished!") } func runTest(ctx context.Context, clientOpts *options.ClientOptions) error { diff --git a/cmd/testoidcauth/main.go b/cmd/testoidcauth/main.go new file mode 100644 index 0000000000..4bed494c4a --- /dev/null +++ b/cmd/testoidcauth/main.go @@ -0,0 +1,1571 @@ +// Copyright (C) MongoDB, Inc. 2022-present. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. 
You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + +package main + +import ( + "context" + "fmt" + "log" + "os" + "path" + "reflect" + "sync" + "time" + "unsafe" + + "go.mongodb.org/mongo-driver/bson" + "go.mongodb.org/mongo-driver/mongo" + "go.mongodb.org/mongo-driver/mongo/options" + "go.mongodb.org/mongo-driver/x/mongo/driver/auth" +) + +var uriAdmin = os.Getenv("MONGODB_URI") +var uriSingle = os.Getenv("MONGODB_URI_SINGLE") +var uriMulti = os.Getenv("MONGODB_URI_MULTI") +var oidcTokenDir = os.Getenv("OIDC_TOKEN_DIR") + +var oidcDomain = os.Getenv("OIDC_DOMAIN") + +func explicitUser(user string) string { + return fmt.Sprintf("%s@%s", user, oidcDomain) +} + +func tokenFile(user string) string { + return path.Join(oidcTokenDir, user) +} + +func connectAdminClinet() (*mongo.Client, error) { + return mongo.Connect(context.Background(), options.Client().ApplyURI(uriAdmin)) +} + +func connectWithMachineCB(uri string, cb options.OIDCCallback) (*mongo.Client, error) { + opts := options.Client().ApplyURI(uri) + + opts.Auth.OIDCMachineCallback = cb + return mongo.Connect(context.Background(), opts) +} + +func connectWithHumanCB(uri string, cb options.OIDCCallback) (*mongo.Client, error) { + opts := options.Client().ApplyURI(uri) + + opts.Auth.OIDCHumanCallback = cb + return mongo.Connect(context.Background(), opts) +} + +func connectWithMachineCBAndProperties(uri string, cb options.OIDCCallback, props map[string]string) (*mongo.Client, error) { + opts := options.Client().ApplyURI(uri) + + opts.Auth.OIDCMachineCallback = cb + opts.Auth.AuthMechanismProperties = props + return mongo.Connect(context.Background(), opts) +} + +func main() { + // be quiet linter + _ = tokenFile("test_user2") + + hasError := false + aux := func(test_name string, f func() error) { + fmt.Printf("%s...", test_name) + err := f() + if err != nil { + fmt.Println("Test Error: ", err) + fmt.Println("...Failed") + hasError = true + } else { + fmt.Println("...Ok") + } + } + env := os.Getenv("OIDC_ENV") + switch env { + case "": + aux("machine_1_1_callbackIsCalled", machine11callbackIsCalled) + aux("machine_1_2_callbackIsCalledOnlyOneForMultipleConnections", machine12callbackIsCalledOnlyOneForMultipleConnections) + aux("machine_2_1_validCallbackInputs", machine21validCallbackInputs) + aux("machine_2_3_oidcCallbackReturnMissingData", machine23oidcCallbackReturnMissingData) + aux("machine_2_4_invalidClientConfigurationWithCallback", machine24invalidClientConfigurationWithCallback) + aux("machine_3_1_failureWithCachedTokensFetchANewTokenAndRetryAuth", machine31failureWithCachedTokensFetchANewTokenAndRetryAuth) + aux("machine_3_2_authFailuresWithoutCachedTokensReturnsAnError", machine32authFailuresWithoutCachedTokensReturnsAnError) + aux("machine_3_3_UnexpectedErrorCodeDoesNotClearTheCache", machine33UnexpectedErrorCodeDoesNotClearTheCache) + aux("machine_4_1_reauthenticationSucceeds", machine41ReauthenticationSucceeds) + aux("machine_4_2_readCommandsFailIfReauthenticationFails", machine42ReadCommandsFailIfReauthenticationFails) + aux("machine_4_3_writeCommandsFailIfReauthenticationFails", machine43WriteCommandsFailIfReauthenticationFails) + aux("human_1_1_singlePrincipalImplictUsername", human11singlePrincipalImplictUsername) + aux("human_1_2_singlePrincipalExplicitUsername", human12singlePrincipalExplicitUsername) + aux("human_1_3_mulitplePrincipalUser1", human13mulitplePrincipalUser1) + aux("human_1_4_mulitplePrincipalUser2", human14mulitplePrincipalUser2) + 
aux("human_1_5_multiplPrincipalNoUser", human15mulitplePrincipalNoUser) + aux("human_1_6_allowedHostsBlocked", human16allowedHostsBlocked) + aux("human_1_7_allowedHostsInConnectionStringIgnored", human17AllowedHostsInConnectionStringIgnored) + aux("human_2_1_validCallbackInputs", human21validCallbackInputs) + aux("human_2_2_CallbackReturnsMissingData", human22CallbackReturnsMissingData) + aux("human_2_3_RefreshTokenIsPassedToCallback", human23RefreshTokenIsPassedToCallback) + aux("human_3_1_usesSpeculativeAuth", human31usesSpeculativeAuth) + aux("human_3_2_doesNotUseSpecualtiveAuth", human32doesNotUseSpecualtiveAuth) + aux("human_4_1_reauthenticationSucceeds", human41ReauthenticationSucceeds) + aux("human_4_2_reauthenticationSucceedsNoRefresh", human42ReauthenticationSucceedsNoRefreshToken) + aux("human_4_3_reauthenticationSucceedsAfterRefreshFails", human43ReauthenticationSucceedsAfterRefreshFails) + aux("human_4_4_reauthenticationFails", human44ReauthenticationFails) + case "azure": + aux("machine_5_1_azureWithNoUsername", machine51azureWithNoUsername) + aux("machine_5_2_azureWithNoUsername", machine52azureWithBadUsername) + case "gcp": + aux("machine_6_1_gcpWithNoUsername", machine61gcpWithNoUsername) + default: + log.Fatal("Unknown OIDC_ENV: ", env) + } + if hasError { + log.Fatal("One or more tests failed") + } +} + +func machine11callbackIsCalled() error { + callbackCount := 0 + var callbackFailed error + countMutex := sync.Mutex{} + + client, err := connectWithMachineCB(uriSingle, func(ctx context.Context, args *options.OIDCArgs) (*options.OIDCCredential, error) { + countMutex.Lock() + defer countMutex.Unlock() + callbackCount++ + t := time.Now().Add(time.Hour) + tokenFile := tokenFile("test_user1") + accessToken, err := os.ReadFile(tokenFile) + if err != nil { + callbackFailed = fmt.Errorf("machine_1_1: failed reading token file: %v", err) + } + return &options.OIDCCredential{ + AccessToken: string(accessToken), + ExpiresAt: &t, + RefreshToken: nil, + }, nil + }) + + defer client.Disconnect(context.Background()) + + if err != nil { + return fmt.Errorf("machine_1_1: failed connecting client: %v", err) + } + + coll := client.Database("test").Collection("test") + + _, err = coll.Find(context.Background(), bson.D{}) + if err != nil { + return fmt.Errorf("machine_1_1: failed executing Find: %v", err) + } + countMutex.Lock() + defer countMutex.Unlock() + if callbackCount != 1 { + return fmt.Errorf("machine_1_1: expected callback count to be 1, got %d", callbackCount) + } + return callbackFailed +} + +func machine12callbackIsCalledOnlyOneForMultipleConnections() error { + callbackCount := 0 + var callbackFailed error + countMutex := sync.Mutex{} + + client, err := connectWithMachineCB(uriSingle, func(ctx context.Context, args *options.OIDCArgs) (*options.OIDCCredential, error) { + countMutex.Lock() + defer countMutex.Unlock() + callbackCount++ + t := time.Now().Add(time.Hour) + tokenFile := tokenFile("test_user1") + accessToken, err := os.ReadFile(tokenFile) + if err != nil { + callbackFailed = fmt.Errorf("machine_1_2: failed reading token file: %v", err) + } + return &options.OIDCCredential{ + AccessToken: string(accessToken), + ExpiresAt: &t, + RefreshToken: nil, + }, nil + }) + + defer client.Disconnect(context.Background()) + + if err != nil { + return fmt.Errorf("machine_1_2: failed connecting client: %v", err) + } + + var wg sync.WaitGroup + + var findFailed error + for i := 0; i < 10; i++ { + wg.Add(1) + go func() { + defer wg.Done() + coll := 
client.Database("test").Collection("test") + _, err := coll.Find(context.Background(), bson.D{}) + if err != nil { + findFailed = fmt.Errorf("machine_1_2: failed executing Find: %v", err) + } + }() + } + + wg.Wait() + countMutex.Lock() + defer countMutex.Unlock() + if callbackCount != 1 { + return fmt.Errorf("machine_1_2: expected callback count to be 1, got %d", callbackCount) + } + if callbackFailed != nil { + return callbackFailed + } + return findFailed +} + +func machine21validCallbackInputs() error { + callbackCount := 0 + var callbackFailed error + countMutex := sync.Mutex{} + + client, err := connectWithMachineCB(uriSingle, func(ctx context.Context, args *options.OIDCArgs) (*options.OIDCCredential, error) { + if args.RefreshToken != nil { + callbackFailed = fmt.Errorf("machine_2_1: expected RefreshToken to be nil, got %v", args.RefreshToken) + } + timeout, ok := ctx.Deadline() + if !ok { + callbackFailed = fmt.Errorf("machine_2_1: expected context to have deadline, got %v", ctx) + } + if timeout.Before(time.Now()) { + callbackFailed = fmt.Errorf("machine_2_1: expected timeout to be in the future, got %v", timeout) + } + if args.Version < 1 { + callbackFailed = fmt.Errorf("machine_2_1: expected Version to be at least 1, got %d", args.Version) + } + if args.IDPInfo != nil { + callbackFailed = fmt.Errorf("machine_2_1: expected IdpID to be nil for Machine flow, got %v", args.IDPInfo) + } + countMutex.Lock() + defer countMutex.Unlock() + callbackCount++ + t := time.Now().Add(time.Hour) + tokenFile := tokenFile("test_user1") + accessToken, err := os.ReadFile(tokenFile) + if err != nil { + fmt.Printf("machine_2_1: failed reading token file: %v", err) + } + return &options.OIDCCredential{ + AccessToken: string(accessToken), + ExpiresAt: &t, + RefreshToken: nil, + }, nil + }) + + defer client.Disconnect(context.Background()) + + if err != nil { + return fmt.Errorf("machine_2_1: failed connecting client: %v", err) + } + + coll := client.Database("test").Collection("test") + + _, err = coll.Find(context.Background(), bson.D{}) + if err != nil { + return fmt.Errorf("machine_2_1: failed executing Find: %v", err) + } + countMutex.Lock() + defer countMutex.Unlock() + if callbackCount != 1 { + return fmt.Errorf("machine_2_1: expected callback count to be 1, got %d", callbackCount) + } + return callbackFailed +} + +func machine23oidcCallbackReturnMissingData() error { + callbackCount := 0 + countMutex := sync.Mutex{} + + client, err := connectWithMachineCB(uriSingle, func(ctx context.Context, args *options.OIDCArgs) (*options.OIDCCredential, error) { + countMutex.Lock() + defer countMutex.Unlock() + callbackCount++ + t := time.Now().Add(time.Hour) + return &options.OIDCCredential{ + AccessToken: "", + ExpiresAt: &t, + RefreshToken: nil, + }, nil + }) + + defer client.Disconnect(context.Background()) + + if err != nil { + return fmt.Errorf("machine_2_3: failed connecting client: %v", err) + } + + coll := client.Database("test").Collection("test") + + _, err = coll.Find(context.Background(), bson.D{}) + if err == nil { + return fmt.Errorf("machine_2_3: should have failed to executed Find, but succeeded") + } + countMutex.Lock() + defer countMutex.Unlock() + if callbackCount != 1 { + return fmt.Errorf("machine_2_3: expected callback count to be 1, got %d", callbackCount) + } + return nil +} + +func machine24invalidClientConfigurationWithCallback() error { + _, err := connectWithMachineCBAndProperties(uriSingle, func(ctx context.Context, args *options.OIDCArgs) (*options.OIDCCredential, error) { + t := 
time.Now().Add(time.Hour) + return &options.OIDCCredential{ + AccessToken: "", + ExpiresAt: &t, + RefreshToken: nil, + }, nil + }, + map[string]string{"ENVIRONMENT": "test"}, + ) + if err == nil { + return fmt.Errorf("machine_2_4: succeeded building client when it should fail") + } + return nil +} + +func machine31failureWithCachedTokensFetchANewTokenAndRetryAuth() error { + callbackCount := 0 + var callbackFailed error + countMutex := sync.Mutex{} + + client, err := connectWithMachineCB(uriSingle, func(ctx context.Context, args *options.OIDCArgs) (*options.OIDCCredential, error) { + countMutex.Lock() + defer countMutex.Unlock() + callbackCount++ + t := time.Now().Add(time.Hour) + tokenFile := tokenFile("test_user1") + accessToken, err := os.ReadFile(tokenFile) + if err != nil { + callbackFailed = fmt.Errorf("machine_3_1: failed reading token file: %v", err) + } + return &options.OIDCCredential{ + AccessToken: string(accessToken), + ExpiresAt: &t, + RefreshToken: nil, + }, nil + }) + + defer client.Disconnect(context.Background()) + + if err != nil { + return fmt.Errorf("machine_3_1: failed connecting client: %v", err) + } + + // Poison the cache with a random token + clientElem := reflect.ValueOf(client).Elem() + authenticatorField := clientElem.FieldByName("authenticator") + authenticatorField = reflect.NewAt( + authenticatorField.Type(), + unsafe.Pointer(authenticatorField.UnsafeAddr())).Elem() + // This is the only usage of the x packages in the test, showing that the public interface is + // correct. + authenticatorField.Interface().(*auth.OIDCAuthenticator).SetAccessToken("some random happy sunshine string") + + coll := client.Database("test").Collection("test") + + _, err = coll.Find(context.Background(), bson.D{}) + if err != nil { + return fmt.Errorf("machine_3_1: failed executing Find: %v", err) + } + countMutex.Lock() + defer countMutex.Unlock() + if callbackCount != 1 { + return fmt.Errorf("machine_3_1: expected callback count to be 1, got %d", callbackCount) + } + return callbackFailed +} + +func machine32authFailuresWithoutCachedTokensReturnsAnError() error { + callbackCount := 0 + var callbackFailed error + countMutex := sync.Mutex{} + + client, err := connectWithMachineCB(uriSingle, func(ctx context.Context, args *options.OIDCArgs) (*options.OIDCCredential, error) { + countMutex.Lock() + defer countMutex.Unlock() + callbackCount++ + t := time.Now().Add(time.Hour) + return &options.OIDCCredential{ + AccessToken: "this is a bad, bad token", + ExpiresAt: &t, + RefreshToken: nil, + }, nil + }) + + defer client.Disconnect(context.Background()) + + if err != nil { + return fmt.Errorf("machine_3_2: failed connecting client: %v", err) + } + + coll := client.Database("test").Collection("test") + _, err = coll.Find(context.Background(), bson.D{}) + if err == nil { + return fmt.Errorf("machine_3_2: Find succeeded when it should fail") + } + countMutex.Lock() + defer countMutex.Unlock() + if callbackCount != 1 { + return fmt.Errorf("machine_3_2: expected callback count to be 1, got %d", callbackCount) + } + return callbackFailed +} + +func machine33UnexpectedErrorCodeDoesNotClearTheCache() error { + callbackCount := 0 + var callbackFailed error + countMutex := sync.Mutex{} + + adminClient, err := connectAdminClinet() + if err != nil { + return fmt.Errorf("machine_3_3: failed connecting admin client: %v", err) + } + defer adminClient.Disconnect(context.Background()) + + client, err := connectWithMachineCB(uriSingle, func(ctx context.Context, args *options.OIDCArgs) 
(*options.OIDCCredential, error) { + countMutex.Lock() + defer countMutex.Unlock() + callbackCount++ + t := time.Now().Add(time.Hour) + tokenFile := tokenFile("test_user1") + accessToken, err := os.ReadFile(tokenFile) + if err != nil { + callbackFailed = fmt.Errorf("machine_3_3: failed reading token file: %v", err) + } + return &options.OIDCCredential{ + AccessToken: string(accessToken), + ExpiresAt: &t, + RefreshToken: nil, + }, nil + }) + + defer client.Disconnect(context.Background()) + + if err != nil { + return fmt.Errorf("machine_3_3: failed connecting client: %v", err) + } + + coll := client.Database("test").Collection("test") + + res := adminClient.Database("admin").RunCommand(context.Background(), bson.D{ + {Key: "configureFailPoint", Value: "failCommand"}, + {Key: "mode", Value: bson.D{ + {Key: "times", Value: 1}, + }}, + {Key: "data", Value: bson.D{ + {Key: "failCommands", Value: bson.A{ + "saslStart", + }}, + {Key: "errorCode", Value: 20}, + }}, + }) + + if res.Err() != nil { + return fmt.Errorf("machine_3_3: failed setting failpoint: %v", res.Err()) + } + + _, err = coll.Find(context.Background(), bson.D{}) + if err == nil { + return fmt.Errorf("machine_3_3: Find succeeded when it should fail") + } + + countMutex.Lock() + defer countMutex.Unlock() + if callbackCount != 1 { + return fmt.Errorf("machine_3_3: expected callback count to be 1, got %d", callbackCount) + } + + _, err = coll.Find(context.Background(), bson.D{}) + if err != nil { + return fmt.Errorf("machine_3_3: failed executing Find: %v", err) + } + if callbackCount != 1 { + return fmt.Errorf("machine_3_3: expected callback count to be 1, got %d", callbackCount) + } + return callbackFailed +} + +func machine41ReauthenticationSucceeds() error { + callbackCount := 0 + var callbackFailed error + countMutex := sync.Mutex{} + + adminClient, err := connectAdminClinet() + if err != nil { + return fmt.Errorf("machine_4_1: failed connecting admin client: %v", err) + } + defer adminClient.Disconnect(context.Background()) + + client, err := connectWithMachineCB(uriSingle, func(ctx context.Context, args *options.OIDCArgs) (*options.OIDCCredential, error) { + countMutex.Lock() + defer countMutex.Unlock() + callbackCount++ + t := time.Now().Add(time.Hour) + tokenFile := tokenFile("test_user1") + accessToken, err := os.ReadFile(tokenFile) + if err != nil { + callbackFailed = fmt.Errorf("machine_4_1: failed reading token file: %v", err) + } + return &options.OIDCCredential{ + AccessToken: string(accessToken), + ExpiresAt: &t, + RefreshToken: nil, + }, nil + }) + + defer client.Disconnect(context.Background()) + + if err != nil { + return fmt.Errorf("machine_4_1: failed connecting client: %v", err) + } + + coll := client.Database("test").Collection("test") + res := adminClient.Database("admin").RunCommand(context.Background(), bson.D{ + {Key: "configureFailPoint", Value: "failCommand"}, + {Key: "mode", Value: bson.D{ + {Key: "times", Value: 1}, + }}, + {Key: "data", Value: bson.D{ + {Key: "failCommands", Value: bson.A{ + "find", + }}, + {Key: "errorCode", Value: 391}, + }}, + }) + + if res.Err() != nil { + return fmt.Errorf("machine_4_1: failed setting failpoint: %v", res.Err()) + } + + _, err = coll.Find(context.Background(), bson.D{}) + if err != nil { + return fmt.Errorf("machine_4_1: failed executing Find: %v", err) + } + countMutex.Lock() + defer countMutex.Unlock() + if callbackCount != 2 { + return fmt.Errorf("machine_4_1: expected callback count to be 2, got %d", callbackCount) + } + return callbackFailed +} + +func 
machine42ReadCommandsFailIfReauthenticationFails() error { + callbackCount := 0 + var callbackFailed error + firstCall := true + countMutex := sync.Mutex{} + + adminClient, err := connectAdminClinet() + if err != nil { + return fmt.Errorf("machine_4_2: failed connecting admin client: %v", err) + } + defer adminClient.Disconnect(context.Background()) + + client, err := connectWithMachineCB(uriSingle, func(ctx context.Context, args *options.OIDCArgs) (*options.OIDCCredential, error) { + countMutex.Lock() + defer countMutex.Unlock() + callbackCount++ + t := time.Now().Add(time.Hour) + if firstCall { + firstCall = false + tokenFile := tokenFile("test_user1") + accessToken, err := os.ReadFile(tokenFile) + if err != nil { + callbackFailed = fmt.Errorf("machine_4_2: failed reading token file: %v", err) + } + return &options.OIDCCredential{ + AccessToken: string(accessToken), + ExpiresAt: &t, + RefreshToken: nil, + }, nil + } + return &options.OIDCCredential{ + AccessToken: "this is a bad, bad token", + ExpiresAt: &t, + RefreshToken: nil, + }, nil + + }) + + defer client.Disconnect(context.Background()) + + if err != nil { + return fmt.Errorf("machine_4_2: failed connecting client: %v", err) + } + + coll := client.Database("test").Collection("test") + _, err = coll.Find(context.Background(), bson.D{}) + if err != nil { + return fmt.Errorf("machine_4_2: failed executing Find: %v", err) + } + + res := adminClient.Database("admin").RunCommand(context.Background(), bson.D{ + {Key: "configureFailPoint", Value: "failCommand"}, + {Key: "mode", Value: bson.D{ + {Key: "times", Value: 1}, + }}, + {Key: "data", Value: bson.D{ + {Key: "failCommands", Value: bson.A{ + "find", + }}, + {Key: "errorCode", Value: 391}, + }}, + }) + + if res.Err() != nil { + return fmt.Errorf("machine_4_2: failed setting failpoint: %v", res.Err()) + } + + _, err = coll.Find(context.Background(), bson.D{}) + if err == nil { + return fmt.Errorf("machine_4_2: Find succeeded when it should fail") + } + + countMutex.Lock() + defer countMutex.Unlock() + if callbackCount != 2 { + return fmt.Errorf("machine_4_2: expected callback count to be 2, got %d", callbackCount) + } + return callbackFailed +} + +func machine43WriteCommandsFailIfReauthenticationFails() error { + callbackCount := 0 + var callbackFailed error + firstCall := true + countMutex := sync.Mutex{} + + adminClient, err := connectAdminClinet() + if err != nil { + return fmt.Errorf("machine_4_3: failed connecting admin client: %v", err) + } + defer adminClient.Disconnect(context.Background()) + + client, err := connectWithMachineCB(uriSingle, func(ctx context.Context, args *options.OIDCArgs) (*options.OIDCCredential, error) { + countMutex.Lock() + defer countMutex.Unlock() + callbackCount++ + t := time.Now().Add(time.Hour) + if firstCall { + firstCall = false + tokenFile := tokenFile("test_user1") + accessToken, err := os.ReadFile(tokenFile) + if err != nil { + callbackFailed = fmt.Errorf("machine_4_3: failed reading token file: %v", err) + } + return &options.OIDCCredential{ + AccessToken: string(accessToken), + ExpiresAt: &t, + RefreshToken: nil, + }, nil + } + return &options.OIDCCredential{ + AccessToken: "this is a bad, bad token", + ExpiresAt: &t, + RefreshToken: nil, + }, nil + }) + + defer client.Disconnect(context.Background()) + + if err != nil { + return fmt.Errorf("machine_4_3: failed connecting client: %v", err) + } + + coll := client.Database("test").Collection("test") + _, err = coll.InsertOne(context.Background(), bson.D{}) + if err != nil { + return 
fmt.Errorf("machine_4_3: failed executing Insert: %v", err) + } + + res := adminClient.Database("admin").RunCommand(context.Background(), bson.D{ + {Key: "configureFailPoint", Value: "failCommand"}, + {Key: "mode", Value: bson.D{ + {Key: "times", Value: 1}, + }}, + {Key: "data", Value: bson.D{ + {Key: "failCommands", Value: bson.A{ + "insert", + }}, + {Key: "errorCode", Value: 391}, + }}, + }) + + if res.Err() != nil { + return fmt.Errorf("machine_4_3: failed setting failpoint: %v", res.Err()) + } + + _, err = coll.InsertOne(context.Background(), bson.D{}) + if err == nil { + return fmt.Errorf("machine_4_3: Insert succeeded when it should fail") + } + + countMutex.Lock() + defer countMutex.Unlock() + if callbackCount != 2 { + return fmt.Errorf("machine_4_3: expected callback count to be 2, got %d", callbackCount) + } + return callbackFailed +} + +func human11singlePrincipalImplictUsername() error { + callbackCount := 0 + var callbackFailed error + countMutex := sync.Mutex{} + + client, err := connectWithHumanCB(uriSingle, func(ctx context.Context, args *options.OIDCArgs) (*options.OIDCCredential, error) { + countMutex.Lock() + defer countMutex.Unlock() + callbackCount++ + t := time.Now().Add(time.Hour) + tokenFile := tokenFile("test_user1") + accessToken, err := os.ReadFile(tokenFile) + if err != nil { + callbackFailed = fmt.Errorf("human_1_1: failed reading token file: %v", err) + } + return &options.OIDCCredential{ + AccessToken: string(accessToken), + ExpiresAt: &t, + RefreshToken: nil, + }, nil + }) + + defer client.Disconnect(context.Background()) + + if err != nil { + return fmt.Errorf("human_1_1: failed connecting client: %v", err) + } + + coll := client.Database("test").Collection("test") + + _, err = coll.Find(context.Background(), bson.D{}) + if err != nil { + return fmt.Errorf("human_1_1: failed executing Find: %v", err) + } + countMutex.Lock() + defer countMutex.Unlock() + if callbackCount != 1 { + return fmt.Errorf("human_1_1: expected callback count to be 1, got %d", callbackCount) + } + return callbackFailed +} + +func human12singlePrincipalExplicitUsername() error { + callbackCount := 0 + var callbackFailed error + countMutex := sync.Mutex{} + + opts := options.Client().ApplyURI(uriSingle) + opts.Auth.OIDCHumanCallback = func(ctx context.Context, args *options.OIDCArgs) (*options.OIDCCredential, error) { + countMutex.Lock() + defer countMutex.Unlock() + callbackCount++ + t := time.Now().Add(time.Hour) + tokenFile := tokenFile("test_user1") + accessToken, err := os.ReadFile(tokenFile) + if err != nil { + callbackFailed = fmt.Errorf("human_1_2: failed reading token file: %v", err) + } + return &options.OIDCCredential{ + AccessToken: string(accessToken), + ExpiresAt: &t, + RefreshToken: nil, + }, nil + } + opts.Auth.Username = explicitUser("test_user1") + client, err := mongo.Connect(context.Background(), opts) + if err != nil { + return fmt.Errorf("human_1_2: failed connecting client: %v", err) + } + defer client.Disconnect(context.Background()) + + coll := client.Database("test").Collection("test") + + _, err = coll.Find(context.Background(), bson.D{}) + if err != nil { + return fmt.Errorf("human_1_2: failed executing Find: %v", err) + } + countMutex.Lock() + defer countMutex.Unlock() + if callbackCount != 1 { + return fmt.Errorf("human_1_2: expected callback count to be 1, got %d", callbackCount) + } + return callbackFailed +} + +func human13mulitplePrincipalUser1() error { + callbackCount := 0 + var callbackFailed error + countMutex := sync.Mutex{} + + opts := 
options.Client().ApplyURI(uriMulti) + opts.Auth.OIDCHumanCallback = func(ctx context.Context, args *options.OIDCArgs) (*options.OIDCCredential, error) { + countMutex.Lock() + defer countMutex.Unlock() + callbackCount++ + t := time.Now().Add(time.Hour) + tokenFile := tokenFile("test_user1") + accessToken, err := os.ReadFile(tokenFile) + if err != nil { + callbackFailed = fmt.Errorf("human_1_3: failed reading token file: %v", err) + } + return &options.OIDCCredential{ + AccessToken: string(accessToken), + ExpiresAt: &t, + RefreshToken: nil, + }, nil + } + opts.Auth.Username = explicitUser("test_user1") + client, err := mongo.Connect(context.Background(), opts) + if err != nil { + return fmt.Errorf("human_1_3: failed connecting client: %v", err) + } + defer client.Disconnect(context.Background()) + + coll := client.Database("test").Collection("test") + + _, err = coll.Find(context.Background(), bson.D{}) + if err != nil { + return fmt.Errorf("human_1_3: failed executing Find: %v", err) + } + countMutex.Lock() + defer countMutex.Unlock() + if callbackCount != 1 { + return fmt.Errorf("human_1_3: expected callback count to be 1, got %d", callbackCount) + } + return callbackFailed +} + +func human14mulitplePrincipalUser2() error { + callbackCount := 0 + var callbackFailed error + countMutex := sync.Mutex{} + + opts := options.Client().ApplyURI(uriMulti) + opts.Auth.OIDCHumanCallback = func(ctx context.Context, args *options.OIDCArgs) (*options.OIDCCredential, error) { + countMutex.Lock() + defer countMutex.Unlock() + callbackCount++ + t := time.Now().Add(time.Hour) + tokenFile := tokenFile("test_user2") + accessToken, err := os.ReadFile(tokenFile) + if err != nil { + callbackFailed = fmt.Errorf("human_1_4: failed reading token file: %v", err) + } + return &options.OIDCCredential{ + AccessToken: string(accessToken), + ExpiresAt: &t, + RefreshToken: nil, + }, nil + } + opts.Auth.Username = explicitUser("test_user2") + client, err := mongo.Connect(context.Background(), opts) + if err != nil { + return fmt.Errorf("human_1_4: failed connecting client: %v", err) + } + defer client.Disconnect(context.Background()) + + coll := client.Database("test").Collection("test") + + _, err = coll.Find(context.Background(), bson.D{}) + if err != nil { + return fmt.Errorf("human_1_4: failed executing Find: %v", err) + } + countMutex.Lock() + defer countMutex.Unlock() + if callbackCount != 1 { + return fmt.Errorf("human_1_4: expected callback count to be 1, got %d", callbackCount) + } + return callbackFailed +} + +func human15mulitplePrincipalNoUser() error { + callbackCount := 0 + var callbackFailed error + countMutex := sync.Mutex{} + + client, err := connectWithHumanCB(uriMulti, func(ctx context.Context, args *options.OIDCArgs) (*options.OIDCCredential, error) { + countMutex.Lock() + defer countMutex.Unlock() + callbackCount++ + t := time.Now().Add(time.Hour) + tokenFile := tokenFile("test_user1") + accessToken, err := os.ReadFile(tokenFile) + if err != nil { + callbackFailed = fmt.Errorf("human_1_5: failed reading token file: %v", err) + } + return &options.OIDCCredential{ + AccessToken: string(accessToken), + ExpiresAt: &t, + RefreshToken: nil, + }, nil + }) + if err != nil { + return fmt.Errorf("human_1_5: failed connecting client: %v", err) + } + defer client.Disconnect(context.Background()) + + coll := client.Database("test").Collection("test") + + _, err = coll.Find(context.Background(), bson.D{}) + if err == nil { + return fmt.Errorf("human_1_5: Find succeeded when it should fail") + } + countMutex.Lock() + 
defer countMutex.Unlock() + if callbackCount != 0 { + return fmt.Errorf("human_1_5: expected callback count to be 0, got %d", callbackCount) + } + return callbackFailed +} + +func human16allowedHostsBlocked() error { + var callbackFailed error + { + opts := options.Client().ApplyURI(uriSingle) + opts.Auth.OIDCHumanCallback = func(ctx context.Context, args *options.OIDCArgs) (*options.OIDCCredential, error) { + t := time.Now().Add(time.Hour) + tokenFile := tokenFile("test_user1") + accessToken, err := os.ReadFile(tokenFile) + if err != nil { + callbackFailed = fmt.Errorf("human_1_6: failed reading token file: %v", err) + } + return &options.OIDCCredential{ + AccessToken: string(accessToken), + ExpiresAt: &t, + RefreshToken: nil, + }, nil + } + opts.Auth.AuthMechanismProperties = map[string]string{"ALLOWED_HOSTS": ""} + client, err := mongo.Connect(context.Background(), opts) + if err != nil { + return fmt.Errorf("human_1_6: failed connecting client: %v", err) + } + defer client.Disconnect(context.Background()) + + coll := client.Database("test").Collection("test") + + _, err = coll.Find(context.Background(), bson.D{}) + if err == nil { + return fmt.Errorf("human_1_6: Find succeeded when it should fail with empty 'ALLOWED_HOSTS'") + } + } + { + opts := options.Client().ApplyURI("mongodb://localhost/?authMechanism=MONGODB-OIDC&ignored=example.com") + opts.Auth.OIDCHumanCallback = func(ctx context.Context, args *options.OIDCArgs) (*options.OIDCCredential, error) { + t := time.Now().Add(time.Hour) + tokenFile := tokenFile("test_user1") + accessToken, err := os.ReadFile(tokenFile) + if err != nil { + callbackFailed = fmt.Errorf("human_1_6: failed reading token file: %v", err) + } + return &options.OIDCCredential{ + AccessToken: string(accessToken), + ExpiresAt: &t, + RefreshToken: nil, + }, nil + } + opts.Auth.AuthMechanismProperties = map[string]string{"ALLOWED_HOSTS": "example.com"} + client, err := mongo.Connect(context.Background(), opts) + if err != nil { + return fmt.Errorf("human_1_6: failed connecting client: %v", err) + } + defer client.Disconnect(context.Background()) + + coll := client.Database("test").Collection("test") + + _, err = coll.Find(context.Background(), bson.D{}) + if err == nil { + return fmt.Errorf("human_1_6: Find succeeded when it should fail with 'ALLOWED_HOSTS' 'example.com'") + } + } + return callbackFailed +} + +func human17AllowedHostsInConnectionStringIgnored() error { + uri := "mongodb+srv://example.com/?authMechanism=MONGODB-OIDC&authMechanismProperties=ALLOWED_HOSTS:%5B%22example.com%22%5D" + opts := options.Client().ApplyURI(uri) + err := opts.Validate() + if err == nil { + return fmt.Errorf("human_1_7: succeeded in applying URI which should produce an error") + } + return nil +} + +func human21validCallbackInputs() error { + callbackCount := 0 + var callbackFailed error + countMutex := sync.Mutex{} + + client, err := connectWithHumanCB(uriSingle, func(ctx context.Context, args *options.OIDCArgs) (*options.OIDCCredential, error) { + countMutex.Lock() + defer countMutex.Unlock() + callbackCount++ + t := time.Now().Add(time.Hour) + if args.Version != 1 { + callbackFailed = fmt.Errorf("human_2_1: expected version to be 1, got %d", args.Version) + } + if args.IDPInfo == nil { + callbackFailed = fmt.Errorf("human_2_1: expected IDPInfo to be non-nil, previous error: (%v)", callbackFailed) + } + tokenFile := tokenFile("test_user1") + accessToken, err := os.ReadFile(tokenFile) + if err != nil { + callbackFailed = fmt.Errorf("human_2_1: failed reading token file: 
%v, previous error: (%v)", err, callbackFailed) + } + return &options.OIDCCredential{ + AccessToken: string(accessToken), + ExpiresAt: &t, + RefreshToken: nil, + }, nil + }) + + defer client.Disconnect(context.Background()) + + if err != nil { + return fmt.Errorf("human_2_1: failed connecting client: %v", err) + } + + coll := client.Database("test").Collection("test") + + _, err = coll.Find(context.Background(), bson.D{}) + if err != nil { + return fmt.Errorf("human_2_1: failed executing Find: %v", err) + } + countMutex.Lock() + defer countMutex.Unlock() + if callbackCount != 1 { + return fmt.Errorf("human_2_1: expected callback count to be 1, got %d", callbackCount) + } + return callbackFailed +} + +func human22CallbackReturnsMissingData() error { + callbackCount := 0 + countMutex := sync.Mutex{} + + client, err := connectWithHumanCB(uriSingle, func(ctx context.Context, args *options.OIDCArgs) (*options.OIDCCredential, error) { + countMutex.Lock() + defer countMutex.Unlock() + callbackCount++ + return &options.OIDCCredential{}, nil + }) + + defer client.Disconnect(context.Background()) + + if err != nil { + return fmt.Errorf("human_2_2: failed connecting client: %v", err) + } + + coll := client.Database("test").Collection("test") + + _, err = coll.Find(context.Background(), bson.D{}) + if err == nil { + return fmt.Errorf("human_2_2: Find succeeded when it should fail") + } + countMutex.Lock() + defer countMutex.Unlock() + if callbackCount != 1 { + return fmt.Errorf("human_2_2: expected callback count to be 1, got %d", callbackCount) + } + return nil +} + +func human23RefreshTokenIsPassedToCallback() error { + callbackCount := 0 + var callbackFailed error + countMutex := sync.Mutex{} + + adminClient, err := connectAdminClinet() + if err != nil { + return fmt.Errorf("human_2_3: failed connecting admin client: %v", err) + } + defer adminClient.Disconnect(context.Background()) + + client, err := connectWithHumanCB(uriSingle, func(ctx context.Context, args *options.OIDCArgs) (*options.OIDCCredential, error) { + countMutex.Lock() + defer countMutex.Unlock() + callbackCount++ + if callbackCount == 1 && args.RefreshToken != nil { + callbackFailed = fmt.Errorf("human_2_3: expected refresh token to be nil first time, got %v, previous error: (%v)", args.RefreshToken, callbackFailed) + } + if callbackCount == 2 && args.RefreshToken == nil { + callbackFailed = fmt.Errorf("human_2_3: expected refresh token to be non-nil second time, got %v, previous error: (%v)", args.RefreshToken, callbackFailed) + } + t := time.Now().Add(time.Hour) + tokenFile := tokenFile("test_user1") + accessToken, err := os.ReadFile(tokenFile) + if err != nil { + callbackFailed = fmt.Errorf("human_2_3: failed reading token file: %v", err) + } + rt := "this is fake" + return &options.OIDCCredential{ + AccessToken: string(accessToken), + ExpiresAt: &t, + RefreshToken: &rt, + }, nil + }) + + defer client.Disconnect(context.Background()) + + if err != nil { + return fmt.Errorf("human_2_3: failed connecting client: %v", err) + } + + res := adminClient.Database("admin").RunCommand(context.Background(), bson.D{ + {Key: "configureFailPoint", Value: "failCommand"}, + {Key: "mode", Value: bson.D{ + {Key: "times", Value: 1}, + }}, + {Key: "data", Value: bson.D{ + {Key: "failCommands", Value: bson.A{ + "find", + }}, + {Key: "errorCode", Value: 391}, + }}, + }) + + if res.Err() != nil { + return fmt.Errorf("human_2_3: failed to set failpoint") + } + + coll := client.Database("test").Collection("test") + + _, err = 
coll.Find(context.Background(), bson.D{}) + if err != nil { + return fmt.Errorf("human_2_3: failed executing Find: %v", err) + } + countMutex.Lock() + defer countMutex.Unlock() + if callbackCount != 2 { + return fmt.Errorf("human_2_3: expected callback count to be 2, got %d", callbackCount) + } + return callbackFailed +} + +func human31usesSpeculativeAuth() error { + adminClient, err := connectAdminClinet() + if err != nil { + return fmt.Errorf("human_3_1: failed connecting admin client: %v", err) + } + defer adminClient.Disconnect(context.Background()) + + client, err := connectWithHumanCB(uriSingle, func(ctx context.Context, args *options.OIDCArgs) (*options.OIDCCredential, error) { + // the callback should not even be called due to spec auth. + return &options.OIDCCredential{}, nil + }) + + if err != nil { + return fmt.Errorf("human_3_1: failed connecting client: %v", err) + } + defer client.Disconnect(context.Background()) + + // We deviate from the Prose test since the failPoint on find with no error code does not seem to + // work. Rather we put an access token in the cache to force speculative auth. + tokenFile := tokenFile("test_user1") + accessToken, err := os.ReadFile(tokenFile) + if err != nil { + return fmt.Errorf("human_3_1: failed reading token file: %v", err) + } + clientElem := reflect.ValueOf(client).Elem() + authenticatorField := clientElem.FieldByName("authenticator") + authenticatorField = reflect.NewAt( + authenticatorField.Type(), + unsafe.Pointer(authenticatorField.UnsafeAddr())).Elem() + // This is the only usage of the x packages in the test, showing the the public interface is + // correct. + authenticatorField.Interface().(*auth.OIDCAuthenticator).SetAccessToken(string(accessToken)) + + res := adminClient.Database("admin").RunCommand(context.Background(), bson.D{ + {Key: "configureFailPoint", Value: "failCommand"}, + {Key: "mode", Value: bson.D{ + {Key: "times", Value: 1}, + }}, + {Key: "data", Value: bson.D{ + {Key: "failCommands", Value: bson.A{ + "saslStart", + }}, + {Key: "errorCode", Value: 18}, + }}, + }) + + if res.Err() != nil { + return fmt.Errorf("human_3_1: failed to set failpoint") + } + + coll := client.Database("test").Collection("test") + _, err = coll.Find(context.Background(), bson.D{}) + if err != nil { + return fmt.Errorf("human_3_1: failed executing Find: %v", err) + } + + return nil +} + +func human32doesNotUseSpecualtiveAuth() error { + var callbackFailed error + + adminClient, err := connectAdminClinet() + if err != nil { + return fmt.Errorf("human_3_2: failed connecting admin client: %v", err) + } + defer adminClient.Disconnect(context.Background()) + + client, err := connectWithHumanCB(uriSingle, func(ctx context.Context, args *options.OIDCArgs) (*options.OIDCCredential, error) { + t := time.Now().Add(time.Hour) + tokenFile := tokenFile("test_user1") + accessToken, err := os.ReadFile(tokenFile) + if err != nil { + callbackFailed = fmt.Errorf("human_3_2: failed reading token file: %v", err) + } + return &options.OIDCCredential{ + AccessToken: string(accessToken), + ExpiresAt: &t, + RefreshToken: nil, + }, nil + }) + + defer client.Disconnect(context.Background()) + + if err != nil { + return fmt.Errorf("human_3_2: failed connecting client: %v", err) + } + + res := adminClient.Database("admin").RunCommand(context.Background(), bson.D{ + {Key: "configureFailPoint", Value: "failCommand"}, + {Key: "mode", Value: bson.D{ + {Key: "times", Value: 1}, + }}, + {Key: "data", Value: bson.D{ + {Key: "failCommands", Value: bson.A{ + "saslStart", + }}, 
+ {Key: "errorCode", Value: 18}, + }}, + }) + + if res.Err() != nil { + return fmt.Errorf("human_3_2: failed to set failpoint") + } + + coll := client.Database("test").Collection("test") + + _, err = coll.Find(context.Background(), bson.D{}) + if err == nil { + return fmt.Errorf("human_3_2: Find succeeded when it should fail") + } + return callbackFailed +} + +func human41ReauthenticationSucceeds() error { + return nil +} + +func human42ReauthenticationSucceedsNoRefreshToken() error { + callbackCount := 0 + var callbackFailed error + countMutex := sync.Mutex{} + + adminClient, err := connectAdminClinet() + if err != nil { + return fmt.Errorf("human_4_2: failed connecting admin client: %v", err) + } + defer adminClient.Disconnect(context.Background()) + + client, err := connectWithHumanCB(uriSingle, func(ctx context.Context, args *options.OIDCArgs) (*options.OIDCCredential, error) { + countMutex.Lock() + defer countMutex.Unlock() + callbackCount++ + t := time.Now().Add(time.Hour) + tokenFile := tokenFile("test_user1") + accessToken, err := os.ReadFile(tokenFile) + if err != nil { + callbackFailed = fmt.Errorf("human_4_2: failed reading token file: %v", err) + } + return &options.OIDCCredential{ + AccessToken: string(accessToken), + ExpiresAt: &t, + RefreshToken: nil, + }, nil + }) + + defer client.Disconnect(context.Background()) + + if err != nil { + return fmt.Errorf("human_4_2: failed connecting client: %v", err) + } + + coll := client.Database("test").Collection("test") + + _, err = coll.Find(context.Background(), bson.D{}) + if err != nil { + return fmt.Errorf("human_4_2: failed executing Find: %v", err) + } + + countMutex.Lock() + if callbackCount != 1 { + return fmt.Errorf("human_4_2: expected callback count to be 1, got %d", callbackCount) + } + countMutex.Unlock() + + res := adminClient.Database("admin").RunCommand(context.Background(), bson.D{ + {Key: "configureFailPoint", Value: "failCommand"}, + {Key: "mode", Value: bson.D{ + {Key: "times", Value: 1}, + }}, + {Key: "data", Value: bson.D{ + {Key: "failCommands", Value: bson.A{ + "find", + }}, + {Key: "errorCode", Value: 391}, + }}, + }) + + if res.Err() != nil { + return fmt.Errorf("human_4_2: failed to set failpoint") + } + + _, err = coll.Find(context.Background(), bson.D{}) + if err != nil { + return fmt.Errorf("human_4_2: failed executing Find: %v", err) + } + + countMutex.Lock() + if callbackCount != 2 { + return fmt.Errorf("human_4_2: expected callback count to be 2, got %d", callbackCount) + } + countMutex.Unlock() + return callbackFailed +} + +func human43ReauthenticationSucceedsAfterRefreshFails() error { + callbackCount := 0 + var callbackFailed error + countMutex := sync.Mutex{} + + adminClient, err := connectAdminClinet() + if err != nil { + return fmt.Errorf("human_4_3: failed connecting admin client: %v", err) + } + defer adminClient.Disconnect(context.Background()) + + client, err := connectWithHumanCB(uriSingle, func(ctx context.Context, args *options.OIDCArgs) (*options.OIDCCredential, error) { + countMutex.Lock() + defer countMutex.Unlock() + callbackCount++ + t := time.Now().Add(time.Hour) + tokenFile := tokenFile("test_user1") + accessToken, err := os.ReadFile(tokenFile) + if err != nil { + callbackFailed = fmt.Errorf("human_4_3: failed reading token file: %v", err) + } + refreshToken := "bad token" + return &options.OIDCCredential{ + AccessToken: string(accessToken), + ExpiresAt: &t, + RefreshToken: &refreshToken, + }, nil + }) + + defer client.Disconnect(context.Background()) + + if err != nil { + return 
fmt.Errorf("human_4_3: failed connecting client: %v", err) + } + + coll := client.Database("test").Collection("test") + + _, err = coll.Find(context.Background(), bson.D{}) + if err != nil { + return fmt.Errorf("human_4_3: failed executing Find: %v", err) + } + + countMutex.Lock() + if callbackCount != 1 { + return fmt.Errorf("human_4_3: expected callback count to be 1, got %d", callbackCount) + } + countMutex.Unlock() + + res := adminClient.Database("admin").RunCommand(context.Background(), bson.D{ + {Key: "configureFailPoint", Value: "failCommand"}, + {Key: "mode", Value: bson.D{ + {Key: "times", Value: 1}, + }}, + {Key: "data", Value: bson.D{ + {Key: "failCommands", Value: bson.A{ + "find", + }}, + {Key: "errorCode", Value: 391}, + }}, + }) + + if res.Err() != nil { + return fmt.Errorf("human_4_3: failed to set failpoint") + } + + _, err = coll.Find(context.Background(), bson.D{}) + if err != nil { + return fmt.Errorf("human_4_3: failed executing Find: %v", err) + } + + countMutex.Lock() + if callbackCount != 2 { + return fmt.Errorf("human_4_3: expected callback count to be 2, got %d", callbackCount) + } + countMutex.Unlock() + return callbackFailed +} + +func human44ReauthenticationFails() error { + callbackCount := 0 + var callbackFailed error + countMutex := sync.Mutex{} + + adminClient, err := connectAdminClinet() + if err != nil { + return fmt.Errorf("human_4_4: failed connecting admin client: %v", err) + } + defer adminClient.Disconnect(context.Background()) + + client, err := connectWithHumanCB(uriSingle, func(ctx context.Context, args *options.OIDCArgs) (*options.OIDCCredential, error) { + countMutex.Lock() + defer countMutex.Unlock() + callbackCount++ + badToken := "bad token" + t := time.Now().Add(time.Hour) + if callbackCount == 1 { + tokenFile := tokenFile("test_user1") + accessToken, err := os.ReadFile(tokenFile) + if err != nil { + callbackFailed = fmt.Errorf("human_4_4: failed reading token file: %v", err) + } + return &options.OIDCCredential{ + AccessToken: string(accessToken), + ExpiresAt: &t, + RefreshToken: &badToken, + }, nil + } + return &options.OIDCCredential{ + AccessToken: badToken, + ExpiresAt: &t, + RefreshToken: &badToken, + }, fmt.Errorf("failed to refresh token") + }) + + defer client.Disconnect(context.Background()) + + if err != nil { + return fmt.Errorf("human_4_4: failed connecting client: %v", err) + } + + coll := client.Database("test").Collection("test") + + _, err = coll.Find(context.Background(), bson.D{}) + if err != nil { + return fmt.Errorf("human_4_4: failed executing Find: %v", err) + } + + countMutex.Lock() + if callbackCount != 1 { + return fmt.Errorf("human_4_4: expected callback count to be 1, got %d", callbackCount) + } + countMutex.Unlock() + + res := adminClient.Database("admin").RunCommand(context.Background(), bson.D{ + {Key: "configureFailPoint", Value: "failCommand"}, + {Key: "mode", Value: bson.D{ + {Key: "times", Value: 1}, + }}, + {Key: "data", Value: bson.D{ + {Key: "failCommands", Value: bson.A{ + "find", + }}, + {Key: "errorCode", Value: 391}, + }}, + }) + + if res.Err() != nil { + return fmt.Errorf("human_4_4: failed to set failpoint") + } + + _, err = coll.Find(context.Background(), bson.D{}) + if err == nil { + return fmt.Errorf("human_4_4: Find succeeded when it should fail") + } + + countMutex.Lock() + if callbackCount != 3 { + return fmt.Errorf("human_4_4: expected callback count to be 3, got %d", callbackCount) + } + countMutex.Unlock() + return callbackFailed +} + +func machine51azureWithNoUsername() error { + opts := 
options.Client().ApplyURI(uriSingle) + if opts == nil || opts.Auth == nil { + return fmt.Errorf("machine_5_1: failed parsing uri: %q", uriSingle) + } + client, err := mongo.Connect(context.Background(), opts) + if err != nil { + return fmt.Errorf("machine_5_1: failed connecting client: %v", err) + } + defer client.Disconnect(context.Background()) + + coll := client.Database("test").Collection("test") + + _, err = coll.Find(context.Background(), bson.D{}) + if err != nil { + return fmt.Errorf("machine_5_1: failed executing Find: %v", err) + } + return nil +} + +func machine52azureWithBadUsername() error { + opts := options.Client().ApplyURI(uriSingle) + if opts == nil || opts.Auth == nil { + return fmt.Errorf("machine_5_2: failed parsing uri: %q", uriSingle) + } + opts.Auth.Username = "bad" + client, err := mongo.Connect(context.Background(), opts) + if err != nil { + return fmt.Errorf("machine_5_2: failed connecting client: %v", err) + } + defer client.Disconnect(context.Background()) + + coll := client.Database("test").Collection("test") + + _, err = coll.Find(context.Background(), bson.D{}) + if err == nil { + return fmt.Errorf("machine_5_2: Find succeeded when it should fail") + } + return nil +} + +func machine61gcpWithNoUsername() error { + opts := options.Client().ApplyURI(uriSingle) + if opts == nil || opts.Auth == nil { + return fmt.Errorf("machine_6_1: failed parsing uri: %q", uriSingle) + } + client, err := mongo.Connect(context.Background(), opts) + if err != nil { + return fmt.Errorf("machine_6_1: failed connecting client: %v", err) + } + defer client.Disconnect(context.Background()) + + coll := client.Database("test").Collection("test") + + _, err = coll.Find(context.Background(), bson.D{}) + if err != nil { + return fmt.Errorf("machine_6_1: failed executing Find: %v", err) + } + return nil +} diff --git a/etc/run-atlas-test.sh b/etc/run-atlas-test.sh index 6ca6775b55..ae240f6cbf 100644 --- a/etc/run-atlas-test.sh +++ b/etc/run-atlas-test.sh @@ -7,5 +7,5 @@ set +x # Get the atlas secrets. . ${DRIVERS_TOOLS}/.evergreen/secrets_handling/setup-secrets.sh drivers/atlas_connect -echo "Running cmd/testatlas/main.go" -go run ./cmd/testatlas/main.go "$ATLAS_REPL" "$ATLAS_SHRD" "$ATLAS_FREE" "$ATLAS_TLS11" "$ATLAS_TLS12" "$ATLAS_SERVERLESS" "$ATLAS_SRV_REPL" "$ATLAS_SRV_SHRD" "$ATLAS_SRV_FREE" "$ATLAS_SRV_TLS11" "$ATLAS_SRV_TLS12" "$ATLAS_SRV_SERVERLESS" >> test.suite +echo "Running cmd/testatlas" +go test -v -run ^TestAtlas$ go.mongodb.org/mongo-driver/cmd/testatlas -args "$ATLAS_REPL" "$ATLAS_SHRD" "$ATLAS_FREE" "$ATLAS_TLS11" "$ATLAS_TLS12" "$ATLAS_SERVERLESS" "$ATLAS_SRV_REPL" "$ATLAS_SRV_SHRD" "$ATLAS_SRV_FREE" "$ATLAS_SRV_TLS11" "$ATLAS_SRV_TLS12" "$ATLAS_SRV_SERVERLESS" >> test.suite diff --git a/etc/run-oidc-test.sh b/etc/run-oidc-test.sh new file mode 100644 index 0000000000..4548a124a1 --- /dev/null +++ b/etc/run-oidc-test.sh @@ -0,0 +1,33 @@ +#!/usr/bin/env bash +# run-oidc-test +# Runs oidc auth tests. +set -eu + +echo "Running MONGODB-OIDC authentication tests" + +OIDC_ENV="${OIDC_ENV:-"test"}" + +if [ $OIDC_ENV == "test" ]; then + # Make sure DRIVERS_TOOLS is set. 
+ if [ -z "$DRIVERS_TOOLS" ]; then + echo "Must specify DRIVERS_TOOLS" + exit 1 + fi + source ${DRIVERS_TOOLS}/.evergreen/auth_oidc/secrets-export.sh + +elif [ $OIDC_ENV == "azure" ]; then + source ./env.sh + +elif [ $OIDC_ENV == "gcp" ]; then + source ./secrets-export.sh + +else + echo "Unrecognized OIDC_ENV $OIDC_ENV" + exit 1 +fi + +export TEST_AUTH_OIDC=1 +export COVERAGE=1 +export AUTH="auth" + +$1 diff --git a/etc/run-serverless-test.sh b/etc/run-serverless-test.sh index 95c7d28ff3..9d858a0610 100644 --- a/etc/run-serverless-test.sh +++ b/etc/run-serverless-test.sh @@ -5,7 +5,6 @@ source ${DRIVERS_TOOLS}/.evergreen/serverless/secrets-export.sh AUTH="auth" \ SSL="ssl" \ MONGODB_URI="${SERVERLESS_URI}" \ - IS_SERVERLESS_PROXY="${IS_SERVERLESS_PROXY}" \ SERVERLESS="serverless" \ MAKEFILE_TARGET=evg-test-serverless \ sh ${PROJECT_DIRECTORY}/.evergreen/run-tests.sh diff --git a/internal/logger/io_sink.go b/internal/logger/io_sink.go index c5ff1474b4..0a6c1bdcab 100644 --- a/internal/logger/io_sink.go +++ b/internal/logger/io_sink.go @@ -9,6 +9,7 @@ package logger import ( "encoding/json" "io" + "math" "sync" "time" ) @@ -36,7 +37,11 @@ func NewIOSink(out io.Writer) *IOSink { // Info will write a JSON-encoded message to the io.Writer. func (sink *IOSink) Info(_ int, msg string, keysAndValues ...interface{}) { - kvMap := make(map[string]interface{}, len(keysAndValues)/2+2) + mapSize := len(keysAndValues) / 2 + if math.MaxInt-mapSize >= 2 { + mapSize += 2 + } + kvMap := make(map[string]interface{}, mapSize) kvMap[KeyTimestamp] = time.Now().UnixNano() kvMap[KeyMessage] = msg diff --git a/mongo/bulk_write.go b/mongo/bulk_write.go index 3fdb67b9a2..40f1181e0e 100644 --- a/mongo/bulk_write.go +++ b/mongo/bulk_write.go @@ -186,7 +186,7 @@ func (bw *bulkWrite) runInsert(ctx context.Context, batch bulkWriteBatch) (opera Database(bw.collection.db.name).Collection(bw.collection.name). Deployment(bw.collection.client.deployment).Crypt(bw.collection.client.cryptFLE). ServerAPI(bw.collection.client.serverAPI).Timeout(bw.collection.client.timeout). - Logger(bw.collection.client.logger) + Logger(bw.collection.client.logger).Authenticator(bw.collection.client.authenticator) if bw.comment != nil { comment, err := marshalValue(bw.comment, bw.collection.bsonOpts, bw.collection.registry) if err != nil { @@ -256,7 +256,7 @@ func (bw *bulkWrite) runDelete(ctx context.Context, batch bulkWriteBatch) (opera Database(bw.collection.db.name).Collection(bw.collection.name). Deployment(bw.collection.client.deployment).Crypt(bw.collection.client.cryptFLE).Hint(hasHint). ServerAPI(bw.collection.client.serverAPI).Timeout(bw.collection.client.timeout). - Logger(bw.collection.client.logger) + Logger(bw.collection.client.logger).Authenticator(bw.collection.client.authenticator) if bw.comment != nil { comment, err := marshalValue(bw.comment, bw.collection.bsonOpts, bw.collection.registry) if err != nil { @@ -387,7 +387,8 @@ func (bw *bulkWrite) runUpdate(ctx context.Context, batch bulkWriteBatch) (opera Database(bw.collection.db.name).Collection(bw.collection.name). Deployment(bw.collection.client.deployment).Crypt(bw.collection.client.cryptFLE).Hint(hasHint). ArrayFilters(hasArrayFilters).ServerAPI(bw.collection.client.serverAPI). - Timeout(bw.collection.client.timeout).Logger(bw.collection.client.logger) + Timeout(bw.collection.client.timeout).Logger(bw.collection.client.logger). 
+ Authenticator(bw.collection.client.authenticator) if bw.comment != nil { comment, err := marshalValue(bw.comment, bw.collection.bsonOpts, bw.collection.registry) if err != nil { diff --git a/mongo/change_stream.go b/mongo/change_stream.go index 8d0a2031de..3ea8baf1f2 100644 --- a/mongo/change_stream.go +++ b/mongo/change_stream.go @@ -137,7 +137,8 @@ func newChangeStream(ctx context.Context, config changeStreamConfig, pipeline in ReadPreference(config.readPreference).ReadConcern(config.readConcern). Deployment(cs.client.deployment).ClusterClock(cs.client.clock). CommandMonitor(cs.client.monitor).Session(cs.sess).ServerSelector(cs.selector).Retry(driver.RetryNone). - ServerAPI(cs.client.serverAPI).Crypt(config.crypt).Timeout(cs.client.timeout) + ServerAPI(cs.client.serverAPI).Crypt(config.crypt).Timeout(cs.client.timeout). + Authenticator(cs.client.authenticator) if cs.options.Collation != nil { cs.aggregate.Collation(bsoncore.Document(cs.options.Collation.ToDocument())) diff --git a/mongo/client.go b/mongo/client.go index 280749c7dd..00f4f363ae 100644 --- a/mongo/client.go +++ b/mongo/client.go @@ -26,6 +26,7 @@ import ( "go.mongodb.org/mongo-driver/mongo/writeconcern" "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" "go.mongodb.org/mongo-driver/x/mongo/driver" + "go.mongodb.org/mongo-driver/x/mongo/driver/auth" "go.mongodb.org/mongo-driver/x/mongo/driver/mongocrypt" mcopts "go.mongodb.org/mongo-driver/x/mongo/driver/mongocrypt/options" "go.mongodb.org/mongo-driver/x/mongo/driver/operation" @@ -79,6 +80,7 @@ type Client struct { metadataClientFLE *Client internalClientFLE *Client encryptedFieldsMap map[string]interface{} + authenticator driver.Authenticator } // Connect creates a new Client and then initializes it using the Connect method. This is equivalent to calling @@ -209,11 +211,40 @@ func NewClient(opts ...*options.ClientOptions) (*Client, error) { clientOpt.SetMaxPoolSize(defaultMaxPoolSize) } - if err != nil { - return nil, err + if clientOpt.Auth != nil { + var oidcMachineCallback auth.OIDCCallback + if clientOpt.Auth.OIDCMachineCallback != nil { + oidcMachineCallback = func(ctx context.Context, args *driver.OIDCArgs) (*driver.OIDCCredential, error) { + cred, err := clientOpt.Auth.OIDCMachineCallback(ctx, convertOIDCArgs(args)) + return (*driver.OIDCCredential)(cred), err + } + } + + var oidcHumanCallback auth.OIDCCallback + if clientOpt.Auth.OIDCHumanCallback != nil { + oidcHumanCallback = func(ctx context.Context, args *driver.OIDCArgs) (*driver.OIDCCredential, error) { + cred, err := clientOpt.Auth.OIDCHumanCallback(ctx, convertOIDCArgs(args)) + return (*driver.OIDCCredential)(cred), err + } + } + + // Create an authenticator for the client + client.authenticator, err = auth.CreateAuthenticator(clientOpt.Auth.AuthMechanism, &auth.Cred{ + Source: clientOpt.Auth.AuthSource, + Username: clientOpt.Auth.Username, + Password: clientOpt.Auth.Password, + PasswordSet: clientOpt.Auth.PasswordSet, + Props: clientOpt.Auth.AuthMechanismProperties, + OIDCMachineCallback: oidcMachineCallback, + OIDCHumanCallback: oidcHumanCallback, + }, clientOpt.HTTPClient) + if err != nil { + return nil, err + } } - cfg, err := topology.NewConfig(clientOpt, client.clock) + cfg, err := topology.NewConfigWithAuthenticator(clientOpt, client.clock, client.authenticator) + if err != nil { return nil, err } @@ -235,6 +266,19 @@ func NewClient(opts ...*options.ClientOptions) (*Client, error) { return client, nil } +// convertOIDCArgs converts the internal *driver.OIDCArgs into the equivalent +// public type 
*options.OIDCArgs. +func convertOIDCArgs(args *driver.OIDCArgs) *options.OIDCArgs { + if args == nil { + return nil + } + return &options.OIDCArgs{ + Version: args.Version, + IDPInfo: (*options.IDPInfo)(args.IDPInfo), + RefreshToken: args.RefreshToken, + } +} + // Connect initializes the Client by starting background monitoring goroutines. // If the Client was created using the NewClient function, this method must be called before a Client can be used. // @@ -694,7 +738,7 @@ func (c *Client) ListDatabases(ctx context.Context, filter interface{}, opts ... op := operation.NewListDatabases(filterDoc). Session(sess).ReadPreference(c.readPreference).CommandMonitor(c.monitor). ServerSelector(selector).ClusterClock(c.clock).Database("admin").Deployment(c.deployment).Crypt(c.cryptFLE). - ServerAPI(c.serverAPI).Timeout(c.timeout) + ServerAPI(c.serverAPI).Timeout(c.timeout).Authenticator(c.authenticator) if ldo.NameOnly != nil { op = op.NameOnly(*ldo.NameOnly) diff --git a/mongo/client_test.go b/mongo/client_test.go index 013c1ae6bb..0a96e54501 100644 --- a/mongo/client_test.go +++ b/mongo/client_test.go @@ -11,6 +11,7 @@ import ( "errors" "math" "os" + "reflect" "testing" "time" @@ -18,11 +19,13 @@ import ( "go.mongodb.org/mongo-driver/event" "go.mongodb.org/mongo-driver/internal/assert" "go.mongodb.org/mongo-driver/internal/integtest" + "go.mongodb.org/mongo-driver/internal/require" "go.mongodb.org/mongo-driver/mongo/options" "go.mongodb.org/mongo-driver/mongo/readconcern" "go.mongodb.org/mongo-driver/mongo/readpref" "go.mongodb.org/mongo-driver/mongo/writeconcern" "go.mongodb.org/mongo-driver/tag" + "go.mongodb.org/mongo-driver/x/mongo/driver" "go.mongodb.org/mongo-driver/x/mongo/driver/mongocrypt" "go.mongodb.org/mongo-driver/x/mongo/driver/session" "go.mongodb.org/mongo-driver/x/mongo/driver/topology" @@ -502,3 +505,76 @@ func TestClient(t *testing.T) { } }) } + +// Test that convertOIDCArgs exhaustively copies all fields of a driver.OIDCArgs +// into an options.OIDCArgs. +func TestConvertOIDCArgs(t *testing.T) { + refreshToken := "test refresh token" + + testCases := []struct { + desc string + args *driver.OIDCArgs + }{ + { + desc: "populated args", + args: &driver.OIDCArgs{ + Version: 9, + IDPInfo: &driver.IDPInfo{ + Issuer: "test issuer", + ClientID: "test client ID", + RequestScopes: []string{"test scope 1", "test scope 2"}, + }, + RefreshToken: &refreshToken, + }, + }, + { + desc: "nil", + args: nil, + }, + { + desc: "nil IDPInfo and RefreshToken", + args: &driver.OIDCArgs{ + Version: 9, + IDPInfo: nil, + RefreshToken: nil, + }, + }, + } + + for _, tc := range testCases { + tc := tc // Capture range variable. 
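+ // Shadowing matters here because t.Parallel() defers the subtest body; on Go versions before 1.22 the loop variable would otherwise be shared across iterations.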
+ + t.Run(tc.desc, func(t *testing.T) { + t.Parallel() + + got := convertOIDCArgs(tc.args) + + if tc.args == nil { + assert.Nil(t, got, "expected nil when input is nil") + return + } + + require.Equal(t, + 3, + reflect.ValueOf(*tc.args).NumField(), + "expected the driver.OIDCArgs struct to have exactly 3 fields") + require.Equal(t, + 3, + reflect.ValueOf(*got).NumField(), + "expected the options.OIDCArgs struct to have exactly 3 fields") + + assert.Equal(t, + tc.args.Version, + got.Version, + "expected Version field to be equal") + assert.EqualValues(t, + tc.args.IDPInfo, + got.IDPInfo, + "expected IDPInfo field to be convertible to equal values") + assert.Equal(t, + tc.args.RefreshToken, + got.RefreshToken, + "expected RefreshToken field to be equal") + }) + } +} diff --git a/mongo/collection.go b/mongo/collection.go index 4cf6fd1a1a..dbe238a9e3 100644 --- a/mongo/collection.go +++ b/mongo/collection.go @@ -291,7 +291,8 @@ func (coll *Collection) insert(ctx context.Context, documents []interface{}, ServerSelector(selector).ClusterClock(coll.client.clock). Database(coll.db.name).Collection(coll.name). Deployment(coll.client.deployment).Crypt(coll.client.cryptFLE).Ordered(true). - ServerAPI(coll.client.serverAPI).Timeout(coll.client.timeout).Logger(coll.client.logger) + ServerAPI(coll.client.serverAPI).Timeout(coll.client.timeout).Logger(coll.client.logger). + Authenticator(coll.client.authenticator) imo := options.MergeInsertManyOptions(opts...) if imo.BypassDocumentValidation != nil && *imo.BypassDocumentValidation { op = op.BypassDocumentValidation(*imo.BypassDocumentValidation) @@ -471,7 +472,8 @@ func (coll *Collection) delete(ctx context.Context, filter interface{}, deleteOn ServerSelector(selector).ClusterClock(coll.client.clock). Database(coll.db.name).Collection(coll.name). Deployment(coll.client.deployment).Crypt(coll.client.cryptFLE).Ordered(true). - ServerAPI(coll.client.serverAPI).Timeout(coll.client.timeout).Logger(coll.client.logger) + ServerAPI(coll.client.serverAPI).Timeout(coll.client.timeout).Logger(coll.client.logger). + Authenticator(coll.client.authenticator) if do.Comment != nil { comment, err := marshalValue(do.Comment, coll.bsonOpts, coll.registry) if err != nil { @@ -588,7 +590,7 @@ func (coll *Collection) updateOrReplace(ctx context.Context, filter bsoncore.Doc Database(coll.db.name).Collection(coll.name). Deployment(coll.client.deployment).Crypt(coll.client.cryptFLE).Hint(uo.Hint != nil). ArrayFilters(uo.ArrayFilters != nil).Ordered(true).ServerAPI(coll.client.serverAPI). - Timeout(coll.client.timeout).Logger(coll.client.logger) + Timeout(coll.client.timeout).Logger(coll.client.logger).Authenticator(coll.client.authenticator) if uo.Let != nil { let, err := marshal(uo.Let, coll.bsonOpts, coll.registry) if err != nil { @@ -861,7 +863,8 @@ func aggregate(a aggregateParams) (cur *Cursor, err error) { ServerAPI(a.client.serverAPI). HasOutputStage(hasOutputStage). Timeout(a.client.timeout). - MaxTime(ao.MaxTime) + MaxTime(ao.MaxTime). + Authenticator(a.client.authenticator) // Omit "maxTimeMS" from operations that return a user-managed cursor to // prevent confusing "cursor not found" errors. To maintain existing @@ -992,7 +995,7 @@ func (coll *Collection) CountDocuments(ctx context.Context, filter interface{}, op := operation.NewAggregate(pipelineArr).Session(sess).ReadConcern(rc).ReadPreference(coll.readPreference). CommandMonitor(coll.client.monitor).ServerSelector(selector).ClusterClock(coll.client.clock).Database(coll.db.name). 
Collection(coll.name).Deployment(coll.client.deployment).Crypt(coll.client.cryptFLE).ServerAPI(coll.client.serverAPI). - Timeout(coll.client.timeout).MaxTime(countOpts.MaxTime) + Timeout(coll.client.timeout).MaxTime(countOpts.MaxTime).Authenticator(coll.client.authenticator) if countOpts.Collation != nil { op.Collation(bsoncore.Document(countOpts.Collation.ToDocument())) } @@ -1077,7 +1080,7 @@ func (coll *Collection) EstimatedDocumentCount(ctx context.Context, Database(coll.db.name).Collection(coll.name).CommandMonitor(coll.client.monitor). Deployment(coll.client.deployment).ReadConcern(rc).ReadPreference(coll.readPreference). ServerSelector(selector).Crypt(coll.client.cryptFLE).ServerAPI(coll.client.serverAPI). - Timeout(coll.client.timeout).MaxTime(co.MaxTime) + Timeout(coll.client.timeout).MaxTime(co.MaxTime).Authenticator(coll.client.authenticator) if co.Comment != nil { comment, err := marshalValue(co.Comment, coll.bsonOpts, coll.registry) @@ -1144,7 +1147,7 @@ func (coll *Collection) Distinct(ctx context.Context, fieldName string, filter i Database(coll.db.name).Collection(coll.name).CommandMonitor(coll.client.monitor). Deployment(coll.client.deployment).ReadConcern(rc).ReadPreference(coll.readPreference). ServerSelector(selector).Crypt(coll.client.cryptFLE).ServerAPI(coll.client.serverAPI). - Timeout(coll.client.timeout).MaxTime(option.MaxTime) + Timeout(coll.client.timeout).MaxTime(option.MaxTime).Authenticator(coll.client.authenticator) if option.Collation != nil { op.Collation(bsoncore.Document(option.Collation.ToDocument())) @@ -1257,7 +1260,7 @@ func (coll *Collection) find( ClusterClock(coll.client.clock).Database(coll.db.name).Collection(coll.name). Deployment(coll.client.deployment).Crypt(coll.client.cryptFLE).ServerAPI(coll.client.serverAPI). Timeout(coll.client.timeout).MaxTime(fo.MaxTime).Logger(coll.client.logger). - OmitCSOTMaxTimeMS(omitCSOTMaxTimeMS) + OmitCSOTMaxTimeMS(omitCSOTMaxTimeMS).Authenticator(coll.client.authenticator) cursorOpts := coll.client.createBaseCursorOptions() @@ -1521,7 +1524,7 @@ func (coll *Collection) FindOneAndDelete(ctx context.Context, filter interface{} } fod := options.MergeFindOneAndDeleteOptions(opts...) op := operation.NewFindAndModify(f).Remove(true).ServerAPI(coll.client.serverAPI).Timeout(coll.client.timeout). - MaxTime(fod.MaxTime) + MaxTime(fod.MaxTime).Authenticator(coll.client.authenticator) if fod.Collation != nil { op = op.Collation(bsoncore.Document(fod.Collation.ToDocument())) } @@ -1601,7 +1604,8 @@ func (coll *Collection) FindOneAndReplace(ctx context.Context, filter interface{ fo := options.MergeFindOneAndReplaceOptions(opts...) op := operation.NewFindAndModify(f).Update(bsoncore.Value{Type: bsontype.EmbeddedDocument, Data: r}). - ServerAPI(coll.client.serverAPI).Timeout(coll.client.timeout).MaxTime(fo.MaxTime) + ServerAPI(coll.client.serverAPI).Timeout(coll.client.timeout).MaxTime(fo.MaxTime).Authenticator(coll.client.authenticator) + if fo.BypassDocumentValidation != nil && *fo.BypassDocumentValidation { op = op.BypassDocumentValidation(*fo.BypassDocumentValidation) } @@ -1688,7 +1692,7 @@ func (coll *Collection) FindOneAndUpdate(ctx context.Context, filter interface{} fo := options.MergeFindOneAndUpdateOptions(opts...) op := operation.NewFindAndModify(f).ServerAPI(coll.client.serverAPI).Timeout(coll.client.timeout). 
- MaxTime(fo.MaxTime) + MaxTime(fo.MaxTime).Authenticator(coll.client.authenticator) u, err := marshalUpdateValue(update, coll.bsonOpts, coll.registry, true) if err != nil { @@ -1894,7 +1898,8 @@ func (coll *Collection) drop(ctx context.Context) error { ServerSelector(selector).ClusterClock(coll.client.clock). Database(coll.db.name).Collection(coll.name). Deployment(coll.client.deployment).Crypt(coll.client.cryptFLE). - ServerAPI(coll.client.serverAPI).Timeout(coll.client.timeout) + ServerAPI(coll.client.serverAPI).Timeout(coll.client.timeout). + Authenticator(coll.client.authenticator) err = op.Execute(ctx) // ignore namespace not found errors diff --git a/mongo/database.go b/mongo/database.go index 57c0186eca..5344c9641e 100644 --- a/mongo/database.go +++ b/mongo/database.go @@ -189,7 +189,7 @@ func (db *Database) processRunCommand(ctx context.Context, cmd interface{}, ServerSelector(readSelect).ClusterClock(db.client.clock). Database(db.name).Deployment(db.client.deployment). Crypt(db.client.cryptFLE).ReadPreference(ro.ReadPreference).ServerAPI(db.client.serverAPI). - Timeout(db.client.timeout).Logger(db.client.logger), sess, nil + Timeout(db.client.timeout).Logger(db.client.logger).Authenticator(db.client.authenticator), sess, nil } // RunCommand executes the given command against the database. @@ -308,7 +308,7 @@ func (db *Database) Drop(ctx context.Context) error { Session(sess).WriteConcern(wc).CommandMonitor(db.client.monitor). ServerSelector(selector).ClusterClock(db.client.clock). Database(db.name).Deployment(db.client.deployment).Crypt(db.client.cryptFLE). - ServerAPI(db.client.serverAPI) + ServerAPI(db.client.serverAPI).Authenticator(db.client.authenticator) err = op.Execute(ctx) @@ -402,7 +402,7 @@ func (db *Database) ListCollections(ctx context.Context, filter interface{}, opt Session(sess).ReadPreference(db.readPreference).CommandMonitor(db.client.monitor). ServerSelector(selector).ClusterClock(db.client.clock). Database(db.name).Deployment(db.client.deployment).Crypt(db.client.cryptFLE). - ServerAPI(db.client.serverAPI).Timeout(db.client.timeout) + ServerAPI(db.client.serverAPI).Timeout(db.client.timeout).Authenticator(db.client.authenticator) cursorOpts := db.client.createBaseCursorOptions() @@ -679,7 +679,7 @@ func (db *Database) createCollection(ctx context.Context, name string, opts ...* func (db *Database) createCollectionOperation(name string, opts ...*options.CreateCollectionOptions) (*operation.Create, error) { cco := options.MergeCreateCollectionOptions(opts...) - op := operation.NewCreate(name).ServerAPI(db.client.serverAPI) + op := operation.NewCreate(name).ServerAPI(db.client.serverAPI).Authenticator(db.client.authenticator) if cco.Capped != nil { op.Capped(*cco.Capped) @@ -805,7 +805,8 @@ func (db *Database) CreateView(ctx context.Context, viewName, viewOn string, pip op := operation.NewCreate(viewName). ViewOn(viewOn). Pipeline(pipelineArray). - ServerAPI(db.client.serverAPI) + ServerAPI(db.client.serverAPI). + Authenticator(db.client.authenticator) cvo := options.MergeCreateViewOptions(opts...) if cvo.Collation != nil { op.Collation(bsoncore.Document(cvo.Collation.ToDocument())) diff --git a/mongo/index_view.go b/mongo/index_view.go index 8d3555d0b0..db65f75072 100644 --- a/mongo/index_view.go +++ b/mongo/index_view.go @@ -94,7 +94,7 @@ func (iv IndexView) List(ctx context.Context, opts ...*options.ListIndexesOption ServerSelector(selector).ClusterClock(iv.coll.client.clock). Database(iv.coll.db.name).Collection(iv.coll.name). 
Deployment(iv.coll.client.deployment).ServerAPI(iv.coll.client.serverAPI). - Timeout(iv.coll.client.timeout) + Timeout(iv.coll.client.timeout).Authenticator(iv.coll.client.authenticator) cursorOpts := iv.coll.client.createBaseCursorOptions() @@ -262,7 +262,7 @@ func (iv IndexView) CreateMany(ctx context.Context, models []IndexModel, opts .. Session(sess).WriteConcern(wc).ClusterClock(iv.coll.client.clock). Database(iv.coll.db.name).Collection(iv.coll.name).CommandMonitor(iv.coll.client.monitor). Deployment(iv.coll.client.deployment).ServerSelector(selector).ServerAPI(iv.coll.client.serverAPI). - Timeout(iv.coll.client.timeout).MaxTime(option.MaxTime) + Timeout(iv.coll.client.timeout).MaxTime(option.MaxTime).Authenticator(iv.coll.client.authenticator) if option.CommitQuorum != nil { commitQuorum, err := marshalValue(option.CommitQuorum, iv.coll.bsonOpts, iv.coll.registry) if err != nil { @@ -367,7 +367,7 @@ func (iv IndexView) createOptionsDoc(opts *options.IndexOptions) (bsoncore.Docum return optsDoc, nil } -func (iv IndexView) drop(ctx context.Context, name string, opts ...*options.DropIndexesOptions) (bson.Raw, error) { +func (iv IndexView) drop(ctx context.Context, index any, opts ...*options.DropIndexesOptions) (bson.Raw, error) { if ctx == nil { ctx = context.Background() } @@ -397,12 +397,12 @@ func (iv IndexView) drop(ctx context.Context, name string, opts ...*options.Drop // TODO(GODRIVER-3038): This operation should pass CSE to the DropIndexes // Crypt setter to be applied to the operation. - op := operation.NewDropIndexes(name). - Session(sess).WriteConcern(wc).CommandMonitor(iv.coll.client.monitor). + op := operation.NewDropIndexes(index).Session(sess).WriteConcern(wc).CommandMonitor(iv.coll.client.monitor). ServerSelector(selector).ClusterClock(iv.coll.client.clock). Database(iv.coll.db.name).Collection(iv.coll.name). Deployment(iv.coll.client.deployment).ServerAPI(iv.coll.client.serverAPI). - Timeout(iv.coll.client.timeout).MaxTime(dio.MaxTime) + Timeout(iv.coll.client.timeout).MaxTime(dio.MaxTime). + Authenticator(iv.coll.client.authenticator) err = op.Execute(ctx) if err != nil { @@ -435,6 +435,20 @@ func (iv IndexView) DropOne(ctx context.Context, name string, opts ...*options.D return iv.drop(ctx, name, opts...) } +// DropOneWithKey drops a collection index by key using the dropIndexes operation. If the operation succeeds, this returns +// a BSON document in the form {nIndexesWas: }. The "nIndexesWas" field in the response contains the number of +// indexes that existed prior to the drop. +// +// This function is useful to drop an index using its key specification instead of its name. +func (iv IndexView) DropOneWithKey(ctx context.Context, keySpecDocument interface{}, opts ...*options.DropIndexesOptions) (bson.Raw, error) { + doc, err := marshal(keySpecDocument, iv.coll.bsonOpts, iv.coll.registry) + if err != nil { + return nil, err + } + + return iv.drop(ctx, doc, opts...) +} + // DropAll executes a dropIndexes operation to drop all indexes on the collection. If the operation succeeds, this // returns a BSON document in the form {nIndexesWas: }. The "nIndexesWas" field in the response contains the // number of indexes that existed prior to the drop. diff --git a/mongo/integration/client_test.go b/mongo/integration/client_test.go index 0139d273da..0e478537b4 100644 --- a/mongo/integration/client_test.go +++ b/mongo/integration/client_test.go @@ -526,31 +526,23 @@ func TestClient(t *testing.T) { // Assert that the minimum RTT is eventually >250ms. 
topo := getTopologyFromClient(mt.Client) - assert.Soon(mt, func(ctx context.Context) { - for { - // Stop loop if callback has been canceled. - select { - case <-ctx.Done(): - return - default: - } - - time.Sleep(100 * time.Millisecond) - - // Wait for all of the server's minimum RTTs to be >250ms. - done := true - for _, desc := range topo.Description().Servers { - server, err := topo.FindServer(desc) - assert.Nil(mt, err, "FindServer error: %v", err) - if server.RTTMonitor().Min() <= 250*time.Millisecond { - done = false - } - } - if done { - return + callback := func() bool { + // Wait for all of the server's minimum RTTs to be >250ms. + for _, desc := range topo.Description().Servers { + server, err := topo.FindServer(desc) + assert.NoError(mt, err, "FindServer error: %v", err) + if server.RTTMonitor().Min() <= 250*time.Millisecond { + return false // the tick should wait for 100ms in this case } } - }, 10*time.Second) + + return true + } + assert.Eventually(t, + callback, + 10*time.Second, + 100*time.Millisecond, + "expected that the minimum RTT is eventually >250ms") }) // Test that if the minimum RTT is greater than the remaining timeout for an operation, the @@ -574,31 +566,23 @@ func TestClient(t *testing.T) { // Assert that the minimum RTT is eventually >250ms. topo := getTopologyFromClient(mt.Client) - assert.Soon(mt, func(ctx context.Context) { - for { - // Stop loop if callback has been canceled. - select { - case <-ctx.Done(): - return - default: - } - - time.Sleep(100 * time.Millisecond) - - // Wait for all of the server's minimum RTTs to be >250ms. - done := true - for _, desc := range topo.Description().Servers { - server, err := topo.FindServer(desc) - assert.Nil(mt, err, "FindServer error: %v", err) - if server.RTTMonitor().Min() <= 250*time.Millisecond { - done = false - } - } - if done { - return + callback := func() bool { + // Wait for all of the server's minimum RTTs to be >250ms. + for _, desc := range topo.Description().Servers { + server, err := topo.FindServer(desc) + assert.NoError(mt, err, "FindServer error: %v", err) + if server.RTTMonitor().Min() <= 250*time.Millisecond { + return false } } - }, 10*time.Second) + + return true + } + assert.Eventually(t, + callback, + 10*time.Second, + 100*time.Millisecond, + "expected that the minimum RTT is eventually >250ms") // Once we've waited for the minimum RTT for the single server to be >250ms, run a bunch of // Ping operations with a timeout of 250ms and expect that they return errors. @@ -625,31 +609,23 @@ func TestClient(t *testing.T) { // Assert that RTT90s are eventually >300ms. topo := getTopologyFromClient(mt.Client) - assert.Soon(mt, func(ctx context.Context) { - for { - // Stop loop if callback has been canceled. - select { - case <-ctx.Done(): - return - default: - } - - time.Sleep(100 * time.Millisecond) - - // Wait for all of the server's RTT90s to be >300ms. - done := true - for _, desc := range topo.Description().Servers { - server, err := topo.FindServer(desc) - assert.Nil(mt, err, "FindServer error: %v", err) - if server.RTTMonitor().P90() <= 300*time.Millisecond { - done = false - } - } - if done { - return + callback := func() bool { + // Wait for all of the server's RTT90s to be >300ms. 
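+ // assert.Eventually polls this callback every 100ms and fails the test if it has not returned true within 10 seconds.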
+ for _, desc := range topo.Description().Servers { + server, err := topo.FindServer(desc) + assert.NoError(mt, err, "FindServer error: %v", err) + if server.RTTMonitor().P90() <= 300*time.Millisecond { + return false } } - }, 10*time.Second) + + return true + } + assert.Eventually(t, + callback, + 10*time.Second, + 100*time.Millisecond, + "expected that the RTT90s are eventually >300ms") }) // Test that if Timeout is set and the RTT90 is greater than the remaining timeout for an operation, the @@ -676,31 +652,23 @@ func TestClient(t *testing.T) { // Assert that RTT90s are eventually >275ms. topo := getTopologyFromClient(mt.Client) - assert.Soon(mt, func(ctx context.Context) { - for { - // Stop loop if callback has been canceled. - select { - case <-ctx.Done(): - return - default: - } - - time.Sleep(100 * time.Millisecond) - - // Wait for all of the server's RTT90s to be >275ms. - done := true - for _, desc := range topo.Description().Servers { - server, err := topo.FindServer(desc) - assert.Nil(mt, err, "FindServer error: %v", err) - if server.RTTMonitor().P90() <= 275*time.Millisecond { - done = false - } - } - if done { - return + callback := func() bool { + // Wait for all of the server's RTT90s to be >275ms. + for _, desc := range topo.Description().Servers { + server, err := topo.FindServer(desc) + assert.NoError(mt, err, "FindServer error: %v", err) + if server.RTTMonitor().P90() <= 275*time.Millisecond { + return false } } - }, 10*time.Second) + + return true + } + assert.Eventually(t, + callback, + 10*time.Second, + 100*time.Millisecond, + "expected that the RTT90s are eventually >275ms") // Once we've waited for the RTT90 for the servers to be >275ms, run 10 Ping operations // with a timeout of 275ms and expect that they return timeout errors. diff --git a/mongo/integration/csot_prose_test.go b/mongo/integration/csot_prose_test.go index 4f9f112b3f..c8ddfd68df 100644 --- a/mongo/integration/csot_prose_test.go +++ b/mongo/integration/csot_prose_test.go @@ -89,13 +89,18 @@ func TestCSOTProse(t *testing.T) { mt.RunOpts("serverSelectionTimeoutMS honored if timeoutMS is not set", mtOpts, func(mt *mtest.T) { mt.Parallel() - callback := func(ctx context.Context) { - err := mt.Client.Ping(ctx, nil) - assert.NotNil(mt, err, "expected Ping error, got nil") + callback := func() bool { + err := mt.Client.Ping(context.Background(), nil) + assert.Error(mt, err, "expected Ping error, got nil") + return true } // Assert that Ping fails within 150ms due to server selection timeout. - assert.Soon(mt, callback, 150*time.Millisecond) + assert.Eventually(t, + callback, + 150*time.Millisecond, + time.Millisecond, + "expected ping to fail within 150ms") }) cliOpts = options.Client().ApplyURI("mongodb://invalid/?timeoutMS=100&serverSelectionTimeoutMS=200") @@ -103,13 +108,18 @@ func TestCSOTProse(t *testing.T) { mt.RunOpts("timeoutMS honored for server selection if it's lower than serverSelectionTimeoutMS", mtOpts, func(mt *mtest.T) { mt.Parallel() - callback := func(ctx context.Context) { - err := mt.Client.Ping(ctx, nil) - assert.NotNil(mt, err, "expected Ping error, got nil") + callback := func() bool { + err := mt.Client.Ping(context.Background(), nil) + assert.Error(mt, err, "expected Ping error, got nil") + return true } // Assert that Ping fails within 150ms due to timeout. 
- assert.Soon(mt, callback, 150*time.Millisecond) + assert.Eventually(t, + callback, + 150*time.Millisecond, + time.Millisecond, + "expected ping to fail within 150ms") }) cliOpts = options.Client().ApplyURI("mongodb://invalid/?timeoutMS=200&serverSelectionTimeoutMS=100") @@ -117,13 +127,18 @@ func TestCSOTProse(t *testing.T) { mt.RunOpts("serverSelectionTimeoutMS honored for server selection if it's lower than timeoutMS", mtOpts, func(mt *mtest.T) { mt.Parallel() - callback := func(ctx context.Context) { - err := mt.Client.Ping(ctx, nil) - assert.NotNil(mt, err, "expected Ping error, got nil") + callback := func() bool { + err := mt.Client.Ping(context.Background(), nil) + assert.Error(mt, err, "expected Ping error, got nil") + return true } // Assert that Ping fails within 150ms due to server selection timeout. - assert.Soon(mt, callback, 150*time.Millisecond) + assert.Eventually(t, + callback, + 150*time.Millisecond, + time.Millisecond, + "expected ping to fail within 150ms") }) cliOpts = options.Client().ApplyURI("mongodb://invalid/?timeoutMS=0&serverSelectionTimeoutMS=100") @@ -131,13 +146,18 @@ func TestCSOTProse(t *testing.T) { mt.RunOpts("serverSelectionTimeoutMS honored for server selection if timeoutMS=0", mtOpts, func(mt *mtest.T) { mt.Parallel() - callback := func(ctx context.Context) { - err := mt.Client.Ping(ctx, nil) - assert.NotNil(mt, err, "expected Ping error, got nil") + callback := func() bool { + err := mt.Client.Ping(context.Background(), nil) + assert.Error(mt, err, "expected Ping error, got nil") + return true } // Assert that Ping fails within 150ms due to server selection timeout. - assert.Soon(mt, callback, 150*time.Millisecond) + assert.Eventually(t, + callback, + 150*time.Millisecond, + time.Millisecond, + "expected ping to fail within 150ms") }) }) } diff --git a/mongo/integration/csot_test.go b/mongo/integration/csot_test.go index 3eb0328616..fb1cc340a2 100644 --- a/mongo/integration/csot_test.go +++ b/mongo/integration/csot_test.go @@ -26,13 +26,12 @@ import ( // Test automatic "maxTimeMS" appending and connection closing behavior when // CSOT is disabled and enabled. 
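+// It also covers the deprecated operation-level MaxTime option on Find and Aggregate.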
-func TestCSOT(t *testing.T) { +func TestCSOT_maxTimeMS(t *testing.T) { mt := mtest.New(t, mtest.NewOptions().CreateClient(false)) testCases := []struct { desc string commandName string - setup func(coll *mongo.Collection) error operation func(ctx context.Context, coll *mongo.Collection) error topologies []mtest.TopologyKind @@ -54,10 +53,6 @@ func TestCSOT(t *testing.T) { { desc: "FindOne", commandName: "find", - setup: func(coll *mongo.Collection) error { - _, err := coll.InsertOne(context.Background(), bson.D{}) - return err - }, operation: func(ctx context.Context, coll *mongo.Collection) error { return coll.FindOne(ctx, bson.D{}).Err() }, @@ -68,10 +63,6 @@ func TestCSOT(t *testing.T) { { desc: "Find", commandName: "find", - setup: func(coll *mongo.Collection) error { - _, err := coll.InsertOne(context.Background(), bson.D{}) - return err - }, operation: func(ctx context.Context, coll *mongo.Collection) error { _, err := coll.Find(ctx, bson.D{}) return err @@ -83,10 +74,6 @@ func TestCSOT(t *testing.T) { { desc: "FindOneAndDelete", commandName: "findAndModify", - setup: func(coll *mongo.Collection) error { - _, err := coll.InsertOne(context.Background(), bson.D{}) - return err - }, operation: func(ctx context.Context, coll *mongo.Collection) error { return coll.FindOneAndDelete(ctx, bson.D{}).Err() }, @@ -97,10 +84,6 @@ func TestCSOT(t *testing.T) { { desc: "FindOneAndUpdate", commandName: "findAndModify", - setup: func(coll *mongo.Collection) error { - _, err := coll.InsertOne(context.Background(), bson.D{}) - return err - }, operation: func(ctx context.Context, coll *mongo.Collection) error { return coll.FindOneAndUpdate(ctx, bson.D{}, bson.M{"$set": bson.M{"key": "value"}}).Err() }, @@ -111,10 +94,6 @@ func TestCSOT(t *testing.T) { { desc: "FindOneAndReplace", commandName: "findAndModify", - setup: func(coll *mongo.Collection) error { - _, err := coll.InsertOne(context.Background(), bson.D{}) - return err - }, operation: func(ctx context.Context, coll *mongo.Collection) error { return coll.FindOneAndReplace(ctx, bson.D{}, bson.D{}).Err() }, @@ -243,10 +222,6 @@ func TestCSOT(t *testing.T) { { desc: "Cursor getMore", commandName: "getMore", - setup: func(coll *mongo.Collection) error { - _, err := coll.InsertMany(context.Background(), []interface{}{bson.D{}, bson.D{}}) - return err - }, operation: func(ctx context.Context, coll *mongo.Collection) error { cursor, err := coll.Find(ctx, bson.D{}, options.Find().SetBatchSize(1)) if err != nil { @@ -261,6 +236,14 @@ func TestCSOT(t *testing.T) { }, } + // insertTwoDocuments inserts two documents in the test collection. + insertTwoDocuments := func(mt *mtest.T) { + mt.Helper() + + _, err := mt.Coll.InsertMany(context.Background(), []interface{}{bson.D{}, bson.D{}}) + require.NoError(mt, err, "InsertMany error") + } + // getStartedEvent returns the first command started event that matches the // specified command name. getStartedEvent := func(mt *mtest.T, command string) *event.CommandStartedEvent { @@ -281,12 +264,13 @@ func TestCSOT(t *testing.T) { return nil } - // assertMaxTimeMSIsSet asserts that "maxTimeMS" is set to a positive value - // on the given command document. - assertMaxTimeMSIsSet := func(mt *mtest.T, command bson.Raw) { + // getMaxTimeMS asserts that "maxTimeMS" is set on the command document for + // the given command name and returns the value. 
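+ // The test fails if the field is missing or its value is not greater than zero.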
+ getMaxTimeMS := func(mt *mtest.T, command string) int64 { mt.Helper() - maxTimeVal := command.Lookup("maxTimeMS") + evt := getStartedEvent(mt, command) + maxTimeVal := evt.Command.Lookup("maxTimeMS") require.Greater(mt, len(maxTimeVal.Value), @@ -300,14 +284,18 @@ func TestCSOT(t *testing.T) { maxTimeVal.Int64(), int64(0), "expected maxTimeMS value to be greater than 0") + + return maxTimeVal.Int64() } // assertMaxTimeMSIsSet asserts that "maxTimeMS" is not set on the given // command document. - assertMaxTimeMSNotSet := func(mt *mtest.T, command bson.Raw) { + assertMaxTimeMSNotSet := func(mt *mtest.T, command string) { mt.Helper() - _, err := command.LookupErr("maxTimeMS") + evt := getStartedEvent(mt, command) + + _, err := evt.Command.LookupErr("maxTimeMS") assert.ErrorIs(mt, err, bsoncore.ErrElementNotFound, @@ -318,41 +306,34 @@ func TestCSOT(t *testing.T) { mt.RunOpts(tc.desc, mtest.NewOptions().Topologies(tc.topologies...), func(mt *mtest.T) { mt.Run("maxTimeMS", func(mt *mtest.T) { mt.Run("timeoutMS not set", func(mt *mtest.T) { - if tc.setup != nil { - err := tc.setup(mt.Coll) - require.NoError(mt, err) - } + // Insert some documents so the collection isn't empty. + insertTwoDocuments(mt) err := tc.operation(context.Background(), mt.Coll) require.NoError(mt, err) - - evt := getStartedEvent(mt, tc.commandName) - assertMaxTimeMSNotSet(mt, evt.Command) + assertMaxTimeMSNotSet(mt, tc.commandName) }) csotOpts := mtest.NewOptions().ClientOptions(options.Client().SetTimeout(10 * time.Second)) mt.RunOpts("timeoutMS and context.Background", csotOpts, func(mt *mtest.T) { - if tc.setup != nil { - err := tc.setup(mt.Coll) - require.NoError(mt, err) - } + // Insert some documents so the collection isn't empty. + insertTwoDocuments(mt) err := tc.operation(context.Background(), mt.Coll) require.NoError(mt, err) - evt := getStartedEvent(mt, tc.commandName) - if tc.sendsMaxTimeMSWithTimeoutMS { - assertMaxTimeMSIsSet(mt, evt.Command) - } else { - assertMaxTimeMSNotSet(mt, evt.Command) + if !tc.sendsMaxTimeMSWithTimeoutMS { + assertMaxTimeMSNotSet(mt, tc.commandName) + return } + + maxTimeMS := getMaxTimeMS(mt, tc.commandName) + assert.Greater(mt, maxTimeMS, int64(0), "expected maxTimeMS to be greater than 0") }) mt.RunOpts("timeoutMS and Context with deadline", csotOpts, func(mt *mtest.T) { - if tc.setup != nil { - err := tc.setup(mt.Coll) - require.NoError(mt, err) - } + // Insert some documents so the collection isn't empty. + insertTwoDocuments(mt) ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() @@ -360,12 +341,13 @@ func TestCSOT(t *testing.T) { err := tc.operation(ctx, mt.Coll) require.NoError(mt, err) - evt := getStartedEvent(mt, tc.commandName) - if tc.sendsMaxTimeMSWithContextDeadline { - assertMaxTimeMSIsSet(mt, evt.Command) - } else { - assertMaxTimeMSNotSet(mt, evt.Command) + if !tc.sendsMaxTimeMSWithContextDeadline { + assertMaxTimeMSNotSet(mt, tc.commandName) + return } + + maxTimeMS := getMaxTimeMS(mt, tc.commandName) + assert.Greater(mt, maxTimeMS, int64(0), "expected maxTimeMS to be greater than 0") }) }) @@ -375,10 +357,8 @@ func TestCSOT(t *testing.T) { Topologies(mtest.Single, mtest.ReplicaSet). MinServerVersion("4.2") mt.RunOpts("prevents connection closure with timeoutMS", opts, func(mt *mtest.T) { - if tc.setup != nil { - err := tc.setup(mt.Coll) - require.NoError(mt, err) - } + // Insert some documents so the collection isn't empty. 
+ insertTwoDocuments(mt) mt.SetFailPoint(mtest.FailPoint{ ConfigureFailPoint: "failCommand", @@ -403,7 +383,7 @@ func TestCSOT(t *testing.T) { cancel() if !mongo.IsTimeout(err) { - t.Logf("CSOT-disabled operation %d returned a non-timeout error: %v", i, err) + t.Errorf("CSOT-disabled operation %d returned a non-timeout error: %v", i, err) } } @@ -428,7 +408,7 @@ func TestCSOT(t *testing.T) { cancel() if !mongo.IsTimeout(err) { - t.Logf("CSOT-enabled operation %d returned a non-timeout error: %v", i, err) + t.Errorf("CSOT-enabled operation %d returned a non-timeout error: %v", i, err) } } @@ -441,8 +421,10 @@ func TestCSOT(t *testing.T) { }) } - csotOpts := mtest.NewOptions().ClientOptions(options.Client().SetTimeout(10 * time.Second)) - mt.RunOpts("maxTimeMS is omitted for values greater than 2147483647ms", csotOpts, func(mt *mtest.T) { + mt.Run("maxTimeMS is omitted for values greater than 2147483647ms", func(mt *mtest.T) { + // Set a client-level timeoutMS value. + mt.ResetClient(options.Client().SetTimeout(10 * time.Second)) + ctx, cancel := context.WithTimeout(context.Background(), (2147483647+1000)*time.Millisecond) defer cancel() _, err := mt.Coll.InsertOne(ctx, bson.D{}) @@ -455,6 +437,152 @@ func TestCSOT(t *testing.T) { bsoncore.ErrElementNotFound, "expected maxTimeMS BSON value to be missing, but is present") }) + + // Deprecated MaxTime option tests. + mt.Run("Find uses MaxTime option when no other timeouts are set", func(mt *mtest.T) { + // Insert some documents so the collection isn't empty. + insertTwoDocuments(mt) + + // Set a 5-second MaxTime value. + opts := options.Find().SetMaxTime(5 * time.Second) + + cursor, err := mt.Coll.Find(context.Background(), bson.D{}, opts) + require.NoError(mt, err, "Find error") + err = cursor.Close(context.Background()) + require.NoError(mt, err, "Cursor.Close error") + + // Assert that maxTimeMS is set and that it's equal to the MaxTime + // value. + maxTimeMS := getMaxTimeMS(mt, "find") + assert.Equal(mt, + int64(5_000), + maxTimeMS, + "expected maxTimeMS to be equal to the MaxTime value") + }) + mt.Run("Find ignores MaxTime option when timeoutMS is set", func(mt *mtest.T) { + // Insert some documents so the collection isn't empty. + insertTwoDocuments(mt) + + // Set a 10-second client-level timeoutMS value . + mt.ResetClient(options.Client().SetTimeout(10 * time.Second)) + + // Set a 5-second MaxTime value. + opts := options.Find().SetMaxTime(5 * time.Second) + + cursor, err := mt.Coll.Find(context.Background(), bson.D{}, opts) + require.NoError(mt, err, "Find error") + err = cursor.Close(context.Background()) + require.NoError(mt, err, "Cursor.Close error") + + // Assert that maxTimeMS is set and that it's greater than the + // MaxTime value. + maxTimeMS := getMaxTimeMS(mt, "find") + assert.Greater(mt, + maxTimeMS, + int64(5_000), + "expected maxTimeMS to be greater than the MaxTime value") + }) + // TODO(GODRIVER-2944): Remove this test once the "timeoutMode" option is + // supported. + mt.Run("Find uses MaxTime option when timeoutMS and Context with deadline are set", func(mt *mtest.T) { + // Insert some documents so the collection isn't empty. + insertTwoDocuments(mt) + + // Set a 10-second client-level timeoutMS value . + mt.ResetClient(options.Client().SetTimeout(10 * time.Second)) + + // Set a 10-second operation-level Context timeout. + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + // Set a 5-second MaxTime value. 
+ opts := options.Find().SetMaxTime(5 * time.Second) + + cursor, err := mt.Coll.Find(ctx, bson.D{}, opts) + require.NoError(mt, err, "Find error") + err = cursor.Close(context.Background()) + require.NoError(mt, err, "Cursor.Close error") + + // Assert that maxTimeMS is set and that it's equal to the MaxTime + // value. + maxTimeMS := getMaxTimeMS(mt, "find") + assert.Equal(mt, + int64(5_000), + maxTimeMS, + "expected maxTimeMS to be equal to the MaxTime value") + }) + mt.Run("Aggregate uses MaxTime option when no other timeouts are set", func(mt *mtest.T) { + // Insert some documents so the collection isn't empty. + insertTwoDocuments(mt) + + // Set a 5-second MaxTime value. + opts := options.Aggregate().SetMaxTime(5 * time.Second) + + cursor, err := mt.Coll.Aggregate(context.Background(), bson.D{}, opts) + require.NoError(mt, err, "Aggregate error") + err = cursor.Close(context.Background()) + require.NoError(mt, err, "Cursor.Close error") + + // Assert that maxTimeMS is set and that it's equal to the MaxTime + // value. + maxTimeMS := getMaxTimeMS(mt, "aggregate") + assert.Equal(mt, + int64(5_000), + maxTimeMS, + "expected maxTimeMS to be equal to the MaxTime value") + }) + mt.Run("Aggregate ignores MaxTime option when timeoutMS is set", func(mt *mtest.T) { + // Insert some documents so the collection isn't empty. + insertTwoDocuments(mt) + + // Set a 10-second client-level timeoutMS value . + mt.ResetClient(options.Client().SetTimeout(10 * time.Second)) + + // Set a 5-second MaxTime value. + opts := options.Aggregate().SetMaxTime(5 * time.Second) + + cursor, err := mt.Coll.Aggregate(context.Background(), bson.D{}, opts) + require.NoError(mt, err, "Aggregate error") + err = cursor.Close(context.Background()) + require.NoError(mt, err, "Cursor.Close error") + + // Assert that maxTimeMS is set and that it's greater than the + // MaxTime value. + maxTimeMS := getMaxTimeMS(mt, "aggregate") + assert.Greater(mt, + maxTimeMS, + int64(5_000), + "expected maxTimeMS to be greater than the MaxTime value") + }) + // TODO(GODRIVER-2944): Remove this test once the "timeoutMode" option is + // supported. + mt.Run("Aggregate uses MaxTime option when timeoutMS and Context with deadline are set", func(mt *mtest.T) { + // Insert some documents so the collection isn't empty. + insertTwoDocuments(mt) + + // Set a 10-second client-level timeoutMS value . + mt.ResetClient(options.Client().SetTimeout(10 * time.Second)) + + // Set a 10-second operation-level Context timeout. + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + // Set a 5-second MaxTime value. + opts := options.Aggregate().SetMaxTime(5 * time.Second) + + cursor, err := mt.Coll.Aggregate(ctx, bson.D{}, opts) + require.NoError(mt, err, "Aggregate error") + err = cursor.Close(context.Background()) + require.NoError(mt, err, "Cursor.Close error") + + // Assert that maxTimeMS is set and that it's equal to the MaxTime + // value. 
+ maxTimeMS := getMaxTimeMS(mt, "aggregate") + assert.Equal(mt, + int64(5_000), + maxTimeMS, + "expected maxTimeMS to be equal to the MaxTime value") + }) } func TestCSOT_errors(t *testing.T) { diff --git a/mongo/integration/index_view_test.go b/mongo/integration/index_view_test.go index bff69150d1..1c83a8ca34 100644 --- a/mongo/integration/index_view_test.go +++ b/mongo/integration/index_view_test.go @@ -607,6 +607,88 @@ func TestIndexView(t *testing.T) { } assert.Nil(mt, cursor.Err(), "cursor error: %v", cursor.Err()) }) + mt.Run("drop with key", func(mt *mtest.T) { + tests := []struct { + name string + models []mongo.IndexModel + index any + want string + }{ + { + name: "custom index name and unique indexes", + models: []mongo.IndexModel{ + { + Keys: bson.D{{"username", int32(1)}}, + Options: options.Index().SetUnique(true).SetName("myidx"), + }, + }, + index: bson.D{{"username", int32(1)}}, + want: "myidx", + }, + { + name: "normal generated index name", + models: []mongo.IndexModel{ + { + Keys: bson.D{{"foo", int32(-1)}}, + }, + }, + index: bson.D{{"foo", int32(-1)}}, + want: "foo_-1", + }, + { + name: "compound index", + models: []mongo.IndexModel{ + { + Keys: bson.D{{"foo", int32(1)}, {"bar", int32(1)}}, + }, + }, + index: bson.D{{"foo", int32(1)}, {"bar", int32(1)}}, + want: "foo_1_bar_1", + }, + { + name: "text index", + models: []mongo.IndexModel{ + { + Keys: bson.D{{"plot1", "text"}, {"plot2", "text"}}, + }, + }, + // Key is automatically set to Full Text Search for any text index + index: bson.D{{"_fts", "text"}, {"_ftsx", int32(1)}}, + want: "plot1_text_plot2_text", + }, + } + + for _, test := range tests { + mt.Run(test.name, func(mt *mtest.T) { + iv := mt.Coll.Indexes() + indexNames, err := iv.CreateMany(context.Background(), test.models) + + s, _ := test.index.(bson.D) + for _, name := range indexNames { + verifyIndexExists(mt, iv, index{ + Key: s, + Name: name, + }) + } + + assert.NoError(mt, err) + assert.Equal(mt, len(test.models), len(indexNames), "expected %v index names, got %v", len(test.models), len(indexNames)) + + _, err = iv.DropOneWithKey(context.Background(), test.index) + assert.Nil(mt, err, "DropOne error: %v", err) + + cursor, err := iv.List(context.Background()) + assert.Nil(mt, err, "List error: %v", err) + for cursor.Next(context.Background()) { + var idx index + err = cursor.Decode(&idx) + assert.Nil(mt, err, "Decode error: %v (document %v)", err, cursor.Current) + assert.NotEqual(mt, test.want, idx.Name, "found index %v after dropping", test.want) + } + assert.Nil(mt, cursor.Err(), "cursor error: %v", cursor.Err()) + }) + } + }) mt.Run("drop all", func(mt *mtest.T) { iv := mt.Coll.Indexes() names, err := iv.CreateMany(context.Background(), []mongo.IndexModel{ diff --git a/mongo/integration/mtest/mongotest.go b/mongo/integration/mtest/mongotest.go index f92b5c583f..25f30849b0 100644 --- a/mongo/integration/mtest/mongotest.go +++ b/mongo/integration/mtest/mongotest.go @@ -639,25 +639,25 @@ func (t *T) createTestClient() { // Setup command monitor var customMonitor = clientOpts.Monitor clientOpts.SetMonitor(&event.CommandMonitor{ - Started: func(_ context.Context, cse *event.CommandStartedEvent) { + Started: func(ctx context.Context, cse *event.CommandStartedEvent) { if customMonitor != nil && customMonitor.Started != nil { - customMonitor.Started(context.Background(), cse) + customMonitor.Started(ctx, cse) } t.monitorLock.Lock() defer t.monitorLock.Unlock() t.started = append(t.started, cse) }, - Succeeded: func(_ context.Context, cse 
*event.CommandSucceededEvent) { + Succeeded: func(ctx context.Context, cse *event.CommandSucceededEvent) { if customMonitor != nil && customMonitor.Succeeded != nil { - customMonitor.Succeeded(context.Background(), cse) + customMonitor.Succeeded(ctx, cse) } t.monitorLock.Lock() defer t.monitorLock.Unlock() t.succeeded = append(t.succeeded, cse) }, - Failed: func(_ context.Context, cfe *event.CommandFailedEvent) { + Failed: func(ctx context.Context, cfe *event.CommandFailedEvent) { if customMonitor != nil && customMonitor.Failed != nil { - customMonitor.Failed(context.Background(), cfe) + customMonitor.Failed(ctx, cfe) } t.monitorLock.Lock() defer t.monitorLock.Unlock() diff --git a/mongo/integration/mtest/opmsg_deployment.go b/mongo/integration/mtest/opmsg_deployment.go index 2215f84b38..2ddc23c413 100644 --- a/mongo/integration/mtest/opmsg_deployment.go +++ b/mongo/integration/mtest/opmsg_deployment.go @@ -61,6 +61,13 @@ func (c *connection) WriteWireMessage(context.Context, []byte) error { return nil } +func (c *connection) OIDCTokenGenID() uint64 { + return 0 +} + +func (c *connection) SetOIDCTokenGenID(uint64) { +} + // ReadWireMessage returns the next response in the connection's list of responses. func (c *connection) ReadWireMessage(_ context.Context) ([]byte, error) { var dst []byte diff --git a/mongo/integration/sdam_error_handling_test.go b/mongo/integration/sdam_error_handling_test.go index 58cac9ccdd..4a2baf542d 100644 --- a/mongo/integration/sdam_error_handling_test.go +++ b/mongo/integration/sdam_error_handling_test.go @@ -85,23 +85,13 @@ func TestSDAMErrorHandling(t *testing.T) { assert.NotNil(mt, err, "expected InsertOne error, got nil") assert.True(mt, mongo.IsTimeout(err), "expected timeout error, got %v", err) assert.True(mt, mongo.IsNetworkError(err), "expected network error, got %v", err) + // Assert that the pool is cleared within 2 seconds. - assert.Soon(mt, func(ctx context.Context) { - ticker := time.NewTicker(100 * time.Millisecond) - defer ticker.Stop() - - for { - select { - case <-ticker.C: - case <-ctx.Done(): - return - } - - if tpm.IsPoolCleared() { - return - } - } - }, 2*time.Second) + assert.Eventually(t, + tpm.IsPoolCleared, + 2*time.Second, + 100*time.Millisecond, + "expected pool is cleared within 2 seconds") }) mt.RunOpts("pool cleared on non-timeout network error", noClientOpts, func(mt *mtest.T) { @@ -131,22 +121,11 @@ func TestSDAMErrorHandling(t *testing.T) { SetMinPoolSize(5)) // Assert that the pool is cleared within 2 seconds. - assert.Soon(mt, func(ctx context.Context) { - ticker := time.NewTicker(100 * time.Millisecond) - defer ticker.Stop() - - for { - select { - case <-ticker.C: - case <-ctx.Done(): - return - } - - if tpm.IsPoolCleared() { - return - } - } - }, 2*time.Second) + assert.Eventually(t, + tpm.IsPoolCleared, + 2*time.Second, + 100*time.Millisecond, + "expected pool is cleared within 2 seconds") }) mt.Run("foreground", func(mt *mtest.T) { @@ -175,22 +154,11 @@ func TestSDAMErrorHandling(t *testing.T) { assert.False(mt, mongo.IsTimeout(err), "expected non-timeout error, got %v", err) // Assert that the pool is cleared within 2 seconds. 
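The retry loops previously built around assert.Soon are replaced throughout this patch with the testify-style assert.Eventually helper, which polls a func() bool condition at a fixed interval until it returns true or the overall wait elapses. A small annotated sketch of the call shape used below:

// condition is any func() bool; here it reports whether the pool was cleared.
assert.Eventually(t,
	tpm.IsPoolCleared,    // condition polled repeatedly
	2*time.Second,        // total time to wait for the condition
	100*time.Millisecond, // polling interval between checks
	"expected pool to be cleared within 2 seconds")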
- assert.Soon(mt, func(ctx context.Context) { - ticker := time.NewTicker(100 * time.Millisecond) - defer ticker.Stop() - - for { - select { - case <-ticker.C: - case <-ctx.Done(): - return - } - - if tpm.IsPoolCleared() { - return - } - } - }, 2*time.Second) + assert.Eventually(t, + tpm.IsPoolCleared, + 2*time.Second, + 100*time.Millisecond, + "expected pool is cleared within 2 seconds") }) }) }) diff --git a/mongo/integration/sdam_prose_test.go b/mongo/integration/sdam_prose_test.go index f91bab1176..3107dcb97d 100644 --- a/mongo/integration/sdam_prose_test.go +++ b/mongo/integration/sdam_prose_test.go @@ -11,6 +11,8 @@ import ( "net" "os" "runtime" + "sync" + "sync/atomic" "testing" "time" @@ -124,28 +126,23 @@ func TestSDAMProse(t *testing.T) { AppName: "streamingRttTest", }, }) - callback := func(ctx context.Context) { - for { - // Stop loop if callback has been canceled. - select { - case <-ctx.Done(): - return - default: + callback := func() bool { + // We don't know which server received the failpoint command, so we wait until any of the server + // RTTs cross the threshold. + for _, serverDesc := range testTopology.Description().Servers { + if serverDesc.AverageRTT > 250*time.Millisecond { + return true } - - // We don't know which server received the failpoint command, so we wait until any of the server - // RTTs cross the threshold. - for _, serverDesc := range testTopology.Description().Servers { - if serverDesc.AverageRTT > 250*time.Millisecond { - return - } - } - - // The next update will be in ~500ms. - time.Sleep(500 * time.Millisecond) } + + // The next update will be in ~500ms. + return false } - assert.Soon(t, callback, defaultCallbackTimeout) + assert.Eventually(t, + callback, + defaultCallbackTimeout, + 500*time.Millisecond, + "expected average rtt heartbeats at least within every 500 ms period") }) }) @@ -237,4 +234,45 @@ func TestServerHeartbeatStartedEvent(t *testing.T) { } assert.Equal(t, expectedEvents, actualEvents) }) + + mt := mtest.New(t) + + mt.Run("polling must await frequency", func(mt *mtest.T) { + var heartbeatStartedCount atomic.Int64 + + servers := map[string]bool{} + serversMu := sync.RWMutex{} // Guard the servers set + + serverMonitor := &event.ServerMonitor{ + ServerHeartbeatStarted: func(*event.ServerHeartbeatStartedEvent) { + heartbeatStartedCount.Add(1) + }, + TopologyDescriptionChanged: func(evt *event.TopologyDescriptionChangedEvent) { + serversMu.Lock() + defer serversMu.Unlock() + + for _, srv := range evt.NewDescription.Servers { + servers[srv.Addr.String()] = true + } + }, + } + + // Create a client with heartbeatFrequency=100ms, + // serverMonitoringMode=poll. Use SDAM to record the number of times the + // a heartbeat is started and the number of servers discovered. + mt.ResetClient(options.Client(). + SetServerMonitor(serverMonitor). + SetServerMonitoringMode(options.ServerMonitoringModePoll)) + + // Per specifications, minHeartbeatFrequencyMS=500ms. So, within the first + // 500ms the heartbeatStartedCount should be LEQ to the number of discovered + // servers. 
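The new heartbeat test relies on poll-mode server monitoring combined with an SDAM event monitor. A hedged sketch of that client configuration outside of mtest (the URI is a placeholder):

// Assumes imports of context, sync/atomic, mongo, options, and event.
var heartbeats atomic.Int64
monitor := &event.ServerMonitor{
	ServerHeartbeatStarted: func(*event.ServerHeartbeatStartedEvent) {
		heartbeats.Add(1)
	},
}
client, err := mongo.Connect(context.Background(), options.Client().
	ApplyURI(uri). // placeholder
	SetServerMonitor(monitor).
	SetServerMonitoringMode(options.ServerMonitoringModePoll))
if err != nil {
	panic(err)
}
_ = client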
+ time.Sleep(500 * time.Millisecond) + + serversMu.Lock() + serverCount := int64(len(servers)) + serversMu.Unlock() + + assert.LessOrEqual(mt, heartbeatStartedCount.Load(), serverCount) + }) } diff --git a/mongo/integration/search_index_prose_test.go b/mongo/integration/search_index_prose_test.go index 3d7e0ffb10..2c3207332d 100644 --- a/mongo/integration/search_index_prose_test.go +++ b/mongo/integration/search_index_prose_test.go @@ -311,4 +311,151 @@ func TestSearchIndexProse(t *testing.T) { actual := doc.Lookup("latestDefinition").Value assert.Equal(mt, expected, actual, "unmatched definition") }) + + case7CollName, err := uuid.New() + assert.NoError(mt, err, "failed to create random collection name for case #7") + + mt.RunOpts("case 7: Driver can successfully handle search index types when creating indexes", + mtest.NewOptions().CollectionName(case7CollName.String()), + func(mt *mtest.T) { + ctx := context.Background() + + _, err := mt.Coll.InsertOne(ctx, bson.D{}) + require.NoError(mt, err, "failed to insert") + + view := mt.Coll.SearchIndexes() + + definition := bson.D{{"mappings", bson.D{{"dynamic", false}}}} + indexName := "test-search-index-case7-implicit" + opts := options.SearchIndexes().SetName(indexName) + index, err := view.CreateOne(ctx, mongo.SearchIndexModel{ + Definition: definition, + Options: opts, + }) + require.NoError(mt, err, "failed to create index") + require.Equal(mt, indexName, index, "unmatched name") + var doc bson.Raw + for doc == nil { + cursor, err := view.List(ctx, opts) + require.NoError(mt, err, "failed to list") + + if !cursor.Next(ctx) { + break + } + name := cursor.Current.Lookup("name").StringValue() + queryable := cursor.Current.Lookup("queryable").Boolean() + indexType := cursor.Current.Lookup("type").StringValue() + if name == indexName && queryable { + doc = cursor.Current + assert.Equal(mt, indexType, "search") + } else { + t.Logf("cursor: %s, sleep 5 seconds...", cursor.Current.String()) + time.Sleep(5 * time.Second) + } + } + + indexName = "test-search-index-case7-explicit" + opts = options.SearchIndexes().SetName(indexName).SetType("search") + index, err = view.CreateOne(ctx, mongo.SearchIndexModel{ + Definition: definition, + Options: opts, + }) + require.NoError(mt, err, "failed to create index") + require.Equal(mt, indexName, index, "unmatched name") + doc = nil + for doc == nil { + cursor, err := view.List(ctx, opts) + require.NoError(mt, err, "failed to list") + + if !cursor.Next(ctx) { + break + } + name := cursor.Current.Lookup("name").StringValue() + queryable := cursor.Current.Lookup("queryable").Boolean() + indexType := cursor.Current.Lookup("type").StringValue() + if name == indexName && queryable { + doc = cursor.Current + assert.Equal(mt, indexType, "search") + } else { + t.Logf("cursor: %s, sleep 5 seconds...", cursor.Current.String()) + time.Sleep(5 * time.Second) + } + } + + indexName = "test-search-index-case7-vector" + type vectorDefinitionField struct { + Type string `bson:"type"` + Path string `bson:"path"` + NumDimensions int `bson:"numDimensions"` + Similarity string `bson:"similarity"` + } + + type vectorDefinition struct { + Fields []vectorDefinitionField `bson:"fields"` + } + + opts = options.SearchIndexes().SetName(indexName).SetType("vectorSearch") + index, err = view.CreateOne(ctx, mongo.SearchIndexModel{ + Definition: vectorDefinition{ + Fields: []vectorDefinitionField{{"vector", "path", 1536, "euclidean"}}, + }, + Options: opts, + }) + require.NoError(mt, err, "failed to create index") + require.Equal(mt, 
indexName, index, "unmatched name") + doc = nil + for doc == nil { + cursor, err := view.List(ctx, opts) + require.NoError(mt, err, "failed to list") + + if !cursor.Next(ctx) { + break + } + name := cursor.Current.Lookup("name").StringValue() + queryable := cursor.Current.Lookup("queryable").Boolean() + indexType := cursor.Current.Lookup("type").StringValue() + if name == indexName && queryable { + doc = cursor.Current + assert.Equal(mt, indexType, "vectorSearch") + } else { + t.Logf("cursor: %s, sleep 5 seconds...", cursor.Current.String()) + time.Sleep(5 * time.Second) + } + } + }) + + case8CollName, err := uuid.New() + assert.NoError(mt, err, "failed to create random collection name for case #8") + + mt.RunOpts("case 8: Driver requires explicit type to create a vector search index", + mtest.NewOptions().CollectionName(case8CollName.String()), + func(mt *mtest.T) { + ctx := context.Background() + + _, err := mt.Coll.InsertOne(ctx, bson.D{}) + require.NoError(mt, err, "failed to insert") + + view := mt.Coll.SearchIndexes() + + type vectorDefinitionField struct { + Type string `bson:"type"` + Path string `bson:"path"` + NumDimensions int `bson:"numDimensions"` + Similarity string `bson:"similarity"` + } + + type vectorDefinition struct { + Fields []vectorDefinitionField `bson:"fields"` + } + + const indexName = "test-search-index-case7-vector" + opts := options.SearchIndexes().SetName(indexName) + _, err = view.CreateOne(ctx, mongo.SearchIndexModel{ + Definition: vectorDefinition{ + Fields: []vectorDefinitionField{{"vector", "plot_embedding", 1536, "euclidean"}}, + }, + Options: opts, + }) + assert.ErrorContains(mt, err, "Attribute mappings missing") + }) } diff --git a/mongo/integration/unified/collection_operation_execution.go b/mongo/integration/unified/collection_operation_execution.go index 978ce13f00..1235e4d62d 100644 --- a/mongo/integration/unified/collection_operation_execution.go +++ b/mongo/integration/unified/collection_operation_execution.go @@ -326,6 +326,7 @@ func executeCreateSearchIndex(ctx context.Context, operation *operation) (*opera var m struct { Definition interface{} Name *string + Type *string } err = bson.Unmarshal(val.Document(), &m) if err != nil { @@ -334,6 +335,7 @@ func executeCreateSearchIndex(ctx context.Context, operation *operation) (*opera model.Definition = m.Definition model.Options = options.SearchIndexes() model.Options.Name = m.Name + model.Options.Type = m.Type default: return nil, fmt.Errorf("unrecognized createSearchIndex option %q", key) } @@ -369,6 +371,7 @@ func executeCreateSearchIndexes(ctx context.Context, operation *operation) (*ope var m struct { Definition interface{} Name *string + Type *string } err = bson.Unmarshal(val.Value, &m) if err != nil { @@ -379,6 +382,7 @@ func executeCreateSearchIndexes(ctx context.Context, operation *operation) (*ope Options: options.SearchIndexes(), } model.Options.Name = m.Name + model.Options.Type = m.Type models = append(models, model) } default: diff --git a/mongo/integration/unified/unified_spec_runner.go b/mongo/integration/unified/unified_spec_runner.go index a2f1b8c102..41628d0c6c 100644 --- a/mongo/integration/unified/unified_spec_runner.go +++ b/mongo/integration/unified/unified_spec_runner.go @@ -10,7 +10,6 @@ import ( "context" "fmt" "io/ioutil" - "os" "path" "strings" "testing" @@ -69,10 +68,6 @@ var ( "operation is retried multiple times for non-zero timeoutMS - aggregate on database": "maxTimeMS is disabled on find and aggregate. 
See DRIVERS-2722.", } - skippedServerlessProxyTests = map[string]string{ - "errors during the initial connection hello are ignored": "Serverless Proxy does not support failpoints on hello (see GODRIVER-3157)", - } - logMessageValidatorTimeout = 10 * time.Millisecond lowHeartbeatFrequency = 50 * time.Millisecond ) @@ -256,11 +251,6 @@ func (tc *TestCase) Run(ls LoggerSkipper) error { if skipReason, ok := skippedTests[tc.Description]; ok { ls.Skipf("skipping due to known failure: %q", skipReason) } - // If we're running against a Serverless Proxy instance, also check the - // tests that should be skipped only for Serverless Proxy. - if skipReason, ok := skippedServerlessProxyTests[tc.Description]; ok && os.Getenv("IS_SERVERLESS_PROXY") == "true" { - ls.Skipf("skipping due to known failure with Serverless Proxy: %q", skipReason) - } // Validate that we support the schema declared by the test file before attempting to use its contents. if err := checkSchemaVersion(tc.schemaVersion); err != nil { diff --git a/mongo/integration/unified_runner_events_helper_test.go b/mongo/integration/unified_runner_events_helper_test.go index 780be40de9..5afc510e14 100644 --- a/mongo/integration/unified_runner_events_helper_test.go +++ b/mongo/integration/unified_runner_events_helper_test.go @@ -87,31 +87,23 @@ func waitForEvent(mt *mtest.T, test *testCase, op *operation) { eventType := op.Arguments.Lookup("event").StringValue() expectedCount := int(op.Arguments.Lookup("count").Int32()) - callback := func(ctx context.Context) { - for { - // Stop loop if callback has been canceled. - select { - case <-ctx.Done(): - return - default: - } - - var count int - // Spec tests only ever wait for ServerMarkedUnknown SDAM events for the time being. - if eventType == "ServerMarkedUnknownEvent" { - count = test.monitor.getServerMarkedUnknownCount() - } else { - count = test.monitor.getPoolEventCount(eventType) - } - - if count >= expectedCount { - return - } - time.Sleep(100 * time.Millisecond) + callback := func() bool { + var count int + // Spec tests only ever wait for ServerMarkedUnknown SDAM events for the time being. + if eventType == "ServerMarkedUnknownEvent" { + count = test.monitor.getServerMarkedUnknownCount() + } else { + count = test.monitor.getPoolEventCount(eventType) } + + return count >= expectedCount } - assert.Soon(mt, callback, defaultCallbackTimeout) + assert.Eventually(mt, + callback, + defaultCallbackTimeout, + 100*time.Millisecond, + "expected spec tests to only wait for Server Marked Unknown SDAM events") } func assertEventCount(mt *mtest.T, testCase *testCase, op *operation) { @@ -134,23 +126,16 @@ func recordPrimary(mt *mtest.T, testCase *testCase) { } func waitForPrimaryChange(mt *mtest.T, testCase *testCase, op *operation) { - callback := func(ctx context.Context) { - for { - // Stop loop if callback has been canceled. - select { - case <-ctx.Done(): - return - default: - } - - if getPrimaryAddress(mt, testCase.testTopology, false) != testCase.recordedPrimary { - return - } - } + callback := func() bool { + return getPrimaryAddress(mt, testCase.testTopology, false) != testCase.recordedPrimary } timeout := convertValueToMilliseconds(mt, op.Arguments.Lookup("timeoutMS")) - assert.Soon(mt, callback, timeout) + assert.Eventually(mt, + callback, + timeout, + 100*time.Millisecond, + "expected primary address to be different within the timeout period") } // getPrimaryAddress returns the address of the current primary. 
If failFast is true, the server selection fast path diff --git a/mongo/options/clientoptions.go b/mongo/options/clientoptions.go index db56745919..180d039969 100644 --- a/mongo/options/clientoptions.go +++ b/mongo/options/clientoptions.go @@ -15,6 +15,7 @@ import ( "errors" "fmt" "io/ioutil" + "math" "net" "net/http" "strings" @@ -110,6 +111,34 @@ type Credential struct { Username string Password string PasswordSet bool + OIDCMachineCallback OIDCCallback + OIDCHumanCallback OIDCCallback +} + +// OIDCCallback is the type for both Human and Machine Callback flows. +// RefreshToken will always be nil in the OIDCArgs for the Machine flow. +type OIDCCallback func(context.Context, *OIDCArgs) (*OIDCCredential, error) + +// OIDCArgs contains the arguments for the OIDC callback. +type OIDCArgs struct { + Version int + IDPInfo *IDPInfo + RefreshToken *string +} + +// OIDCCredential contains the access token and refresh token. +type OIDCCredential struct { + AccessToken string + ExpiresAt *time.Time + RefreshToken *string +} + +// IDPInfo contains the information needed to perform OIDC authentication with +// an Identity Provider. +type IDPInfo struct { + Issuer string + ClientID string + RequestScopes []string } // BSONOptions are optional BSON marshaling and unmarshaling behaviors. @@ -1177,7 +1206,19 @@ func addClientCertFromSeparateFiles(cfg *tls.Config, keyFile, certFile, keyPassw return "", err } - data := make([]byte, 0, len(keyData)+len(certData)+1) + keySize := len(keyData) + if keySize > 64*1024*1024 { + return "", errors.New("X.509 key must be less than 64 MiB") + } + certSize := len(certData) + if certSize > 64*1024*1024 { + return "", errors.New("X.509 certificate must be less than 64 MiB") + } + dataSize := keySize + certSize + 1 + if dataSize > math.MaxInt { + return "", errors.New("size overflow") + } + data := make([]byte, 0, dataSize) data = append(data, keyData...) data = append(data, '\n') data = append(data, certData...) diff --git a/mongo/options/searchindexoptions.go b/mongo/options/searchindexoptions.go index 9774d615ba..8cb8a08b78 100644 --- a/mongo/options/searchindexoptions.go +++ b/mongo/options/searchindexoptions.go @@ -9,6 +9,7 @@ package options // SearchIndexesOptions represents options that can be used to configure a SearchIndexView. type SearchIndexesOptions struct { Name *string + Type *string } // SearchIndexes creates a new SearchIndexesOptions instance. @@ -22,6 +23,12 @@ func (sio *SearchIndexesOptions) SetName(name string) *SearchIndexesOptions { return sio } +// SetType sets the value for the Type field. +func (sio *SearchIndexesOptions) SetType(typ string) *SearchIndexesOptions { + sio.Type = &typ + return sio +} + // CreateSearchIndexesOptions represents options that can be used to configure a SearchIndexView.CreateOne or // SearchIndexView.CreateMany operation. 
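With the new Type field and SetType setter, callers can declare the Atlas search index type explicitly. A minimal sketch of creating a vector search index with the new option (collection and field names are illustrative):

// Assumes imports of context, mongo, options, and bson; coll is an existing *mongo.Collection.
definition := bson.D{{"fields", bson.A{bson.D{
	{"type", "vector"},
	{"path", "plot_embedding"},
	{"numDimensions", 1536},
	{"similarity", "euclidean"},
}}}}
opts := options.SearchIndexes().SetName("vector_index").SetType("vectorSearch")
name, err := coll.SearchIndexes().CreateOne(context.Background(), mongo.SearchIndexModel{
	Definition: definition,
	Options:    opts,
})
if err != nil {
	panic(err) // non-Atlas deployments reject search index commands
}
_ = name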
type CreateSearchIndexesOptions struct { diff --git a/mongo/search_index_view.go b/mongo/search_index_view.go index 695a396425..3253a73a2b 100644 --- a/mongo/search_index_view.go +++ b/mongo/search_index_view.go @@ -108,6 +108,9 @@ func (siv SearchIndexView) CreateMany( if model.Options != nil && model.Options.Name != nil { indexes = bsoncore.AppendStringElement(indexes, "name", *model.Options.Name) } + if model.Options != nil && model.Options.Type != nil { + indexes = bsoncore.AppendStringElement(indexes, "type", *model.Options.Type) + } indexes = bsoncore.AppendDocumentElement(indexes, "definition", definition) indexes, err = bsoncore.AppendDocumentEnd(indexes, iidx) @@ -140,7 +143,7 @@ func (siv SearchIndexView) CreateMany( ServerSelector(selector).ClusterClock(siv.coll.client.clock). Collection(siv.coll.name).Database(siv.coll.db.name). Deployment(siv.coll.client.deployment).ServerAPI(siv.coll.client.serverAPI). - Timeout(siv.coll.client.timeout) + Timeout(siv.coll.client.timeout).Authenticator(siv.coll.client.authenticator) err = op.Execute(ctx) if err != nil { @@ -195,7 +198,7 @@ func (siv SearchIndexView) DropOne( ServerSelector(selector).ClusterClock(siv.coll.client.clock). Collection(siv.coll.name).Database(siv.coll.db.name). Deployment(siv.coll.client.deployment).ServerAPI(siv.coll.client.serverAPI). - Timeout(siv.coll.client.timeout) + Timeout(siv.coll.client.timeout).Authenticator(siv.coll.client.authenticator) err = op.Execute(ctx) if de, ok := err.(driver.Error); ok && de.NamespaceNotFound() { @@ -249,7 +252,7 @@ func (siv SearchIndexView) UpdateOne( ServerSelector(selector).ClusterClock(siv.coll.client.clock). Collection(siv.coll.name).Database(siv.coll.db.name). Deployment(siv.coll.client.deployment).ServerAPI(siv.coll.client.serverAPI). - Timeout(siv.coll.client.timeout) + Timeout(siv.coll.client.timeout).Authenticator(siv.coll.client.authenticator) return op.Execute(ctx) } diff --git a/mongo/session.go b/mongo/session.go index 8f1e029b95..77be4ab6db 100644 --- a/mongo/session.go +++ b/mongo/session.go @@ -296,7 +296,8 @@ func (s *sessionImpl) AbortTransaction(ctx context.Context) error { _ = operation.NewAbortTransaction().Session(s.clientSession).ClusterClock(s.client.clock).Database("admin"). Deployment(s.deployment).WriteConcern(s.clientSession.CurrentWc).ServerSelector(selector). Retry(driver.RetryOncePerCommand).CommandMonitor(s.client.monitor). - RecoveryToken(bsoncore.Document(s.clientSession.RecoveryToken)).ServerAPI(s.client.serverAPI).Execute(ctx) + RecoveryToken(bsoncore.Document(s.clientSession.RecoveryToken)).ServerAPI(s.client.serverAPI). + Authenticator(s.client.authenticator).Execute(ctx) s.clientSession.Aborting = false _ = s.clientSession.AbortTransaction() @@ -328,7 +329,7 @@ func (s *sessionImpl) CommitTransaction(ctx context.Context) error { Session(s.clientSession).ClusterClock(s.client.clock).Database("admin").Deployment(s.deployment). WriteConcern(s.clientSession.CurrentWc).ServerSelector(selector).Retry(driver.RetryOncePerCommand). CommandMonitor(s.client.monitor).RecoveryToken(bsoncore.Document(s.clientSession.RecoveryToken)). 
- ServerAPI(s.client.serverAPI).MaxTime(s.clientSession.CurrentMct) + ServerAPI(s.client.serverAPI).MaxTime(s.clientSession.CurrentMct).Authenticator(s.client.authenticator) err = op.Execute(ctx) // Return error without updating transaction state if it is a timeout, as the transaction has not diff --git a/mongo/with_transactions_test.go b/mongo/with_transactions_test.go index af7ce98b0c..544053b973 100644 --- a/mongo/with_transactions_test.go +++ b/mongo/with_transactions_test.go @@ -399,19 +399,23 @@ func TestConvenientTransactions(t *testing.T) { // Insert a document within a session and manually cancel context before // "commitTransaction" can be sent. - callback := func(ctx context.Context) { - transactionCtx, cancel := context.WithCancel(ctx) - + callback := func() bool { + transactionCtx, cancel := context.WithCancel(context.Background()) _, _ = sess.WithTransaction(transactionCtx, func(ctx SessionContext) (interface{}, error) { _, err := coll.InsertOne(ctx, bson.M{"x": 1}) - assert.Nil(t, err, "InsertOne error: %v", err) + assert.NoError(t, err, "InsertOne error: %v", err) cancel() return nil, nil }) + return true } // Assert that transaction is canceled within 500ms and not 2 seconds. - assert.Soon(t, callback, 500*time.Millisecond) + assert.Eventually(t, + callback, + 500*time.Millisecond, + time.Millisecond, + "expected transaction to be canceled within 500ms") // Assert that AbortTransaction was started once and succeeded. assert.Equal(t, 1, len(abortStarted), "expected 1 abortTransaction started event, got %d", len(abortStarted)) @@ -459,19 +463,24 @@ func TestConvenientTransactions(t *testing.T) { assert.Nil(t, err, "StartSession error: %v", err) defer sess.EndSession(context.Background()) - callback := func(ctx context.Context) { + callback := func() bool { // Create transaction context with short timeout. - withTransactionContext, cancel := context.WithTimeout(ctx, time.Nanosecond) + withTransactionContext, cancel := context.WithTimeout(context.Background(), time.Nanosecond) defer cancel() _, _ = sess.WithTransaction(withTransactionContext, func(ctx SessionContext) (interface{}, error) { _, err := coll.InsertOne(ctx, bson.D{{}}) return nil, err }) + return true } // Assert that transaction fails within 500ms and not 2 seconds. - assert.Soon(t, callback, 500*time.Millisecond) + assert.Eventually(t, + callback, + 500*time.Millisecond, + time.Millisecond, + "expected transaction to fail within 500ms") }) t.Run("canceled context before callback does not retry", func(t *testing.T) { withTransactionTimeout = 2 * time.Second @@ -489,19 +498,24 @@ func TestConvenientTransactions(t *testing.T) { assert.Nil(t, err, "StartSession error: %v", err) defer sess.EndSession(context.Background()) - callback := func(ctx context.Context) { + callback := func() bool { // Create transaction context and cancel it immediately. - withTransactionContext, cancel := context.WithTimeout(ctx, 2*time.Second) + withTransactionContext, cancel := context.WithTimeout(context.Background(), 2*time.Second) cancel() _, _ = sess.WithTransaction(withTransactionContext, func(ctx SessionContext) (interface{}, error) { _, err := coll.InsertOne(ctx, bson.D{{}}) return nil, err }) + return true } // Assert that transaction fails within 500ms and not 2 seconds. 
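The rewritten convenience-transaction tests above all drive Session.WithTransaction through contexts with short deadlines. For reference, a hedged sketch of that calling pattern from application code (sess and coll are assumed to already exist):

// Assumes imports of context, time, mongo, and bson.
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()

_, err := sess.WithTransaction(ctx, func(ctx mongo.SessionContext) (interface{}, error) {
	res, err := coll.InsertOne(ctx, bson.D{{"x", 1}})
	return res, err
})
if err != nil {
	// handle the error, for example a context deadline error after the timeout elapses
}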
- assert.Soon(t, callback, 500*time.Millisecond) + assert.Eventually(t, + callback, + 500*time.Millisecond, + time.Millisecond, + "expected transaction to fail within 500ms") }) t.Run("slow operation in callback retries", func(t *testing.T) { withTransactionTimeout = 2 * time.Second @@ -540,8 +554,8 @@ func TestConvenientTransactions(t *testing.T) { assert.Nil(t, err, "StartSession error: %v", err) defer sess.EndSession(context.Background()) - callback := func(ctx context.Context) { - _, err = sess.WithTransaction(ctx, func(ctx SessionContext) (interface{}, error) { + callback := func() bool { + _, err = sess.WithTransaction(context.Background(), func(ctx SessionContext) (interface{}, error) { // Set a timeout of 300ms to cause a timeout on first insertOne // and force a retry. c, cancel := context.WithTimeout(ctx, 300*time.Millisecond) @@ -550,11 +564,17 @@ func TestConvenientTransactions(t *testing.T) { _, err := coll.InsertOne(c, bson.D{{}}) return nil, err }) - assert.Nil(t, err, "WithTransaction error: %v", err) + assert.NoError(t, err, "WithTransaction error: %v", err) + return true } // Assert that transaction passes within 2 seconds. - assert.Soon(t, callback, 2*time.Second) + assert.Eventually(t, + callback, + withTransactionTimeout, + time.Millisecond, + "expected transaction to be passed within 2s") + }) } diff --git a/sbom.json b/sbom.json new file mode 100644 index 0000000000..d561c09385 --- /dev/null +++ b/sbom.json @@ -0,0 +1,11 @@ +{ + "metadata": { + "timestamp": "2024-06-04T11:44:11.689753+00:00" + }, + "components": [], + "serialNumber": "urn:uuid:6687021d-b80d-46ed-acc9-031a17e582a3", + "version": 1, + "$schema": "http://cyclonedx.org/schema/bom-1.5.schema.json", + "bomFormat": "CycloneDX", + "specVersion": "1.5" +} diff --git a/testdata/index-management/createSearchIndex.json b/testdata/index-management/createSearchIndex.json index f9c4e44d3e..327cb61259 100644 --- a/testdata/index-management/createSearchIndex.json +++ b/testdata/index-management/createSearchIndex.json @@ -50,7 +50,8 @@ "mappings": { "dynamic": true } - } + }, + "type": "search" } }, "expectError": { @@ -73,7 +74,8 @@ "mappings": { "dynamic": true } - } + }, + "type": "search" } ], "$db": "database0" @@ -97,7 +99,8 @@ "dynamic": true } }, - "name": "test index" + "name": "test index", + "type": "search" } }, "expectError": { @@ -121,7 +124,68 @@ "dynamic": true } }, - "name": "test index" + "name": "test index", + "type": "search" + } + ], + "$db": "database0" + } + } + } + ] + } + ] + }, + { + "description": "create a vector search index", + "operations": [ + { + "name": "createSearchIndex", + "object": "collection0", + "arguments": { + "model": { + "definition": { + "fields": [ + { + "type": "vector", + "path": "plot_embedding", + "numDimensions": 1536, + "similarity": "euclidean" + } + ] + }, + "name": "test index", + "type": "vectorSearch" + } + }, + "expectError": { + "isError": true, + "errorContains": "Atlas" + } + } + ], + "expectEvents": [ + { + "client": "client0", + "events": [ + { + "commandStartedEvent": { + "command": { + "createSearchIndexes": "collection0", + "indexes": [ + { + "definition": { + "fields": [ + { + "type": "vector", + "path": "plot_embedding", + "numDimensions": 1536, + "similarity": "euclidean" + } + ] + }, + "name": "test index", + "type": "vectorSearch" } ], "$db": "database0" diff --git a/testdata/index-management/createSearchIndex.yml b/testdata/index-management/createSearchIndex.yml index 2e3cf50f8d..a32546cacf 100644 --- 
a/testdata/index-management/createSearchIndex.yml +++ b/testdata/index-management/createSearchIndex.yml @@ -26,7 +26,7 @@ tests: - name: createSearchIndex object: *collection0 arguments: - model: { definition: &definition { mappings: { dynamic: true } } } + model: { definition: &definition { mappings: { dynamic: true } } , type: 'search' } expectError: # This test always errors in a non-Atlas environment. The test functions as a unit test by asserting # that the driver constructs and sends the correct command. @@ -39,7 +39,7 @@ tests: - commandStartedEvent: command: createSearchIndexes: *collection0 - indexes: [ { definition: *definition } ] + indexes: [ { definition: *definition, type: 'search'} ] $db: *database0 - description: "name provided for an index definition" @@ -47,7 +47,7 @@ tests: - name: createSearchIndex object: *collection0 arguments: - model: { definition: &definition { mappings: { dynamic: true } } , name: 'test index' } + model: { definition: &definition { mappings: { dynamic: true } } , name: 'test index', type: 'search' } expectError: # This test always errors in a non-Atlas environment. The test functions as a unit test by asserting # that the driver constructs and sends the correct command. @@ -60,5 +60,27 @@ tests: - commandStartedEvent: command: createSearchIndexes: *collection0 - indexes: [ { definition: *definition, name: 'test index' } ] + indexes: [ { definition: *definition, name: 'test index', type: 'search' } ] + $db: *database0 + + - description: "create a vector search index" + operations: + - name: createSearchIndex + object: *collection0 + arguments: + model: { definition: &definition { fields: [ {"type": "vector", "path": "plot_embedding", "numDimensions": 1536, "similarity": "euclidean"} ] } + , name: 'test index', type: 'vectorSearch' } + expectError: + # This test always errors in a non-Atlas environment. The test functions as a unit test by asserting + # that the driver constructs and sends the correct command. + # The expected error message was changed in SERVER-83003. Check for the substring "Atlas" shared by both error messages. 
+ isError: true + errorContains: Atlas + expectEvents: + - client: *client0 + events: + - commandStartedEvent: + command: + createSearchIndexes: *collection0 + indexes: [ { definition: *definition, name: 'test index', type: 'vectorSearch' } ] $db: *database0 diff --git a/testdata/index-management/createSearchIndexes.json b/testdata/index-management/createSearchIndexes.json index 3cf56ce12e..d91d7d9cf3 100644 --- a/testdata/index-management/createSearchIndexes.json +++ b/testdata/index-management/createSearchIndexes.json @@ -83,7 +83,8 @@ "mappings": { "dynamic": true } - } + }, + "type": "search" } ] }, @@ -107,7 +108,8 @@ "mappings": { "dynamic": true } - } + }, + "type": "search" } ], "$db": "database0" @@ -132,7 +134,8 @@ "dynamic": true } }, - "name": "test index" + "name": "test index", + "type": "search" } ] }, @@ -157,7 +160,70 @@ "dynamic": true } }, - "name": "test index" + "name": "test index", + "type": "search" + } + ], + "$db": "database0" + } + } + } + ] + } + ] + }, + { + "description": "create a vector search index", + "operations": [ + { + "name": "createSearchIndexes", + "object": "collection0", + "arguments": { + "models": [ + { + "definition": { + "fields": [ + { + "type": "vector", + "path": "plot_embedding", + "numDimensions": 1536, + "similarity": "euclidean" + } + ] + }, + "name": "test index", + "type": "vectorSearch" + } + ] + }, + "expectError": { + "isError": true, + "errorContains": "Atlas" + } + } + ], + "expectEvents": [ + { + "client": "client0", + "events": [ + { + "commandStartedEvent": { + "command": { + "createSearchIndexes": "collection0", + "indexes": [ + { + "definition": { + "fields": [ + { + "type": "vector", + "path": "plot_embedding", + "numDimensions": 1536, + "similarity": "euclidean" + } + ] + }, + "name": "test index", + "type": "vectorSearch" } ], "$db": "database0" diff --git a/testdata/index-management/createSearchIndexes.yml b/testdata/index-management/createSearchIndexes.yml index db8f02e551..cac442cb87 100644 --- a/testdata/index-management/createSearchIndexes.yml +++ b/testdata/index-management/createSearchIndexes.yml @@ -48,7 +48,7 @@ tests: - name: createSearchIndexes object: *collection0 arguments: - models: [ { definition: &definition { mappings: { dynamic: true } } } ] + models: [ { definition: &definition { mappings: { dynamic: true } } , type: 'search' } ] expectError: # This test always errors in a non-Atlas environment. The test functions as a unit test by asserting # that the driver constructs and sends the correct command. @@ -61,7 +61,7 @@ tests: - commandStartedEvent: command: createSearchIndexes: *collection0 - indexes: [ { definition: *definition } ] + indexes: [ { definition: *definition, type: 'search'} ] $db: *database0 - description: "name provided for an index definition" @@ -69,7 +69,7 @@ tests: - name: createSearchIndexes object: *collection0 arguments: - models: [ { definition: &definition { mappings: { dynamic: true } } , name: 'test index' } ] + models: [ { definition: &definition { mappings: { dynamic: true } } , name: 'test index' , type: 'search' } ] expectError: # This test always errors in a non-Atlas environment. The test functions as a unit test by asserting # that the driver constructs and sends the correct command. 
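The updated spec tests expect the optional type to be serialized for each index document sent in createSearchIndexes. A sketch of the corresponding driver call, mirroring the two index shapes exercised above (names are illustrative):

// Assumes imports of context, mongo, options, and bson; coll is an existing *mongo.Collection.
models := []mongo.SearchIndexModel{
	{
		Definition: bson.D{{"mappings", bson.D{{"dynamic", true}}}},
		Options:    options.SearchIndexes().SetName("default_search").SetType("search"),
	},
	{
		Definition: bson.D{{"fields", bson.A{bson.D{
			{"type", "vector"}, {"path", "plot_embedding"},
			{"numDimensions", 1536}, {"similarity", "euclidean"},
		}}}},
		Options: options.SearchIndexes().SetName("vector_index").SetType("vectorSearch"),
	},
}
names, err := coll.SearchIndexes().CreateMany(context.Background(), models)
if err != nil {
	panic(err)
}
_ = names // one returned name per model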
@@ -82,5 +82,27 @@ tests: - commandStartedEvent: command: createSearchIndexes: *collection0 - indexes: [ { definition: *definition, name: 'test index' } ] + indexes: [ { definition: *definition, name: 'test index', type: 'search' } ] + $db: *database0 + + - description: "create a vector search index" + operations: + - name: createSearchIndexes + object: *collection0 + arguments: + models: [ { definition: &definition { fields: [ {"type": "vector", "path": "plot_embedding", "numDimensions": 1536, "similarity": "euclidean"} ] }, + name: 'test index' , type: 'vectorSearch' } ] + expectError: + # This test always errors in a non-Atlas environment. The test functions as a unit test by asserting + # that the driver constructs and sends the correct command. + # The expected error message was changed in SERVER-83003. Check for the substring "Atlas" shared by both error messages. + isError: true + errorContains: Atlas + expectEvents: + - client: *client0 + events: + - commandStartedEvent: + command: + createSearchIndexes: *collection0 + indexes: [ { definition: *definition, name: 'test index', type: 'vectorSearch' } ] $db: *database0 diff --git a/version/version.go b/version/version.go index 040c707064..6fb59e745f 100644 --- a/version/version.go +++ b/version/version.go @@ -8,4 +8,4 @@ package version // import "go.mongodb.org/mongo-driver/version" // Driver is the current version of the driver. -var Driver = "v1.16.0-prerelease" +var Driver = "v1.17.0-prerelease" diff --git a/x/bsonx/bsoncore/bsoncore.go b/x/bsonx/bsoncore/bsoncore.go index 88133293ea..03925d7ada 100644 --- a/x/bsonx/bsoncore/bsoncore.go +++ b/x/bsonx/bsoncore/bsoncore.go @@ -8,6 +8,7 @@ package bsoncore // import "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" import ( "bytes" + "encoding/binary" "fmt" "math" "strconv" @@ -706,17 +707,16 @@ func ReserveLength(dst []byte) (int32, []byte) { // UpdateLength updates the length at index with length and returns the []byte. func UpdateLength(dst []byte, index, length int32) []byte { - dst[index] = byte(length) - dst[index+1] = byte(length >> 8) - dst[index+2] = byte(length >> 16) - dst[index+3] = byte(length >> 24) + binary.LittleEndian.PutUint32(dst[index:], uint32(length)) return dst } func appendLength(dst []byte, l int32) []byte { return appendi32(dst, l) } func appendi32(dst []byte, i32 int32) []byte { - return append(dst, byte(i32), byte(i32>>8), byte(i32>>16), byte(i32>>24)) + b := []byte{0, 0, 0, 0} + binary.LittleEndian.PutUint32(b, uint32(i32)) + return append(dst, b...) } // ReadLength reads an int32 length from src and returns the length and the remaining bytes. If @@ -734,27 +734,26 @@ func readi32(src []byte) (int32, []byte, bool) { if len(src) < 4 { return 0, src, false } - return (int32(src[0]) | int32(src[1])<<8 | int32(src[2])<<16 | int32(src[3])<<24), src[4:], true + return int32(binary.LittleEndian.Uint32(src)), src[4:], true } func appendi64(dst []byte, i64 int64) []byte { - return append(dst, - byte(i64), byte(i64>>8), byte(i64>>16), byte(i64>>24), - byte(i64>>32), byte(i64>>40), byte(i64>>48), byte(i64>>56), - ) + b := []byte{0, 0, 0, 0, 0, 0, 0, 0} + binary.LittleEndian.PutUint64(b, uint64(i64)) + return append(dst, b...) 
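The bsoncore helpers now delegate to encoding/binary instead of hand-rolled byte shifts. A small equivalence check using only the standard library:

// encoding/binary writes the same little-endian layout the old shift code produced.
i := int32(-2)
buf := make([]byte, 4)
binary.LittleEndian.PutUint32(buf, uint32(i)) // buf == []byte{0xfe, 0xff, 0xff, 0xff}
got := int32(binary.LittleEndian.Uint32(buf)) // got == -2, same result as the old readi32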
} func readi64(src []byte) (int64, []byte, bool) { if len(src) < 8 { return 0, src, false } - i64 := (int64(src[0]) | int64(src[1])<<8 | int64(src[2])<<16 | int64(src[3])<<24 | - int64(src[4])<<32 | int64(src[5])<<40 | int64(src[6])<<48 | int64(src[7])<<56) - return i64, src[8:], true + return int64(binary.LittleEndian.Uint64(src)), src[8:], true } func appendu32(dst []byte, u32 uint32) []byte { - return append(dst, byte(u32), byte(u32>>8), byte(u32>>16), byte(u32>>24)) + b := []byte{0, 0, 0, 0} + binary.LittleEndian.PutUint32(b, u32) + return append(dst, b...) } func readu32(src []byte) (uint32, []byte, bool) { @@ -762,23 +761,20 @@ func readu32(src []byte) (uint32, []byte, bool) { return 0, src, false } - return (uint32(src[0]) | uint32(src[1])<<8 | uint32(src[2])<<16 | uint32(src[3])<<24), src[4:], true + return binary.LittleEndian.Uint32(src), src[4:], true } func appendu64(dst []byte, u64 uint64) []byte { - return append(dst, - byte(u64), byte(u64>>8), byte(u64>>16), byte(u64>>24), - byte(u64>>32), byte(u64>>40), byte(u64>>48), byte(u64>>56), - ) + b := []byte{0, 0, 0, 0, 0, 0, 0, 0} + binary.LittleEndian.PutUint64(b, u64) + return append(dst, b...) } func readu64(src []byte) (uint64, []byte, bool) { if len(src) < 8 { return 0, src, false } - u64 := (uint64(src[0]) | uint64(src[1])<<8 | uint64(src[2])<<16 | uint64(src[3])<<24 | - uint64(src[4])<<32 | uint64(src[5])<<40 | uint64(src[6])<<48 | uint64(src[7])<<56) - return u64, src[8:], true + return binary.LittleEndian.Uint64(src), src[8:], true } // keep in sync with readcstringbytes diff --git a/x/mongo/driver/auth/auth.go b/x/mongo/driver/auth/auth.go index 6eeaf0ee01..f6471cea26 100644 --- a/x/mongo/driver/auth/auth.go +++ b/x/mongo/driver/auth/auth.go @@ -19,8 +19,11 @@ import ( "go.mongodb.org/mongo-driver/x/mongo/driver/session" ) +// Config contains the configuration for an Authenticator. +type Config = driver.AuthConfig + // AuthenticatorFactory constructs an authenticator. -type AuthenticatorFactory func(cred *Cred) (Authenticator, error) +type AuthenticatorFactory func(*Cred, *http.Client) (Authenticator, error) var authFactories = make(map[string]AuthenticatorFactory) @@ -33,12 +36,13 @@ func init() { RegisterAuthenticatorFactory(GSSAPI, newGSSAPIAuthenticator) RegisterAuthenticatorFactory(MongoDBX509, newMongoDBX509Authenticator) RegisterAuthenticatorFactory(MongoDBAWS, newMongoDBAWSAuthenticator) + RegisterAuthenticatorFactory(MongoDBOIDC, newOIDCAuthenticator) } // CreateAuthenticator creates an authenticator. 
-func CreateAuthenticator(name string, cred *Cred) (Authenticator, error) { +func CreateAuthenticator(name string, cred *Cred, httpClient *http.Client) (Authenticator, error) { if f, ok := authFactories[name]; ok { - return f(cred) + return f(cred, httpClient) } return nil, newAuthError(fmt.Sprintf("unknown authenticator: %s", name), nil) @@ -61,7 +65,6 @@ type HandshakeOptions struct { ClusterClock *session.ClusterClock ServerAPI *driver.ServerAPIOptions LoadBalanced bool - HTTPClient *http.Client } type authHandshaker struct { @@ -97,12 +100,17 @@ func (ah *authHandshaker) GetHandshakeInformation(ctx context.Context, addr addr return driver.HandshakeInformation{}, newAuthError("failed to create conversation", err) } - firstMsg, err := ah.conversation.FirstMessage() - if err != nil { - return driver.HandshakeInformation{}, newAuthError("failed to create speculative authentication message", err) - } + // It is possible for the speculative conversation to be nil even without error if the authenticator + // cannot perform speculative authentication. An example of this is MONGODB-OIDC when there is + // no AccessToken in the cache. + if ah.conversation != nil { + firstMsg, err := ah.conversation.FirstMessage() + if err != nil { + return driver.HandshakeInformation{}, newAuthError("failed to create speculative authentication message", err) + } - op = op.SpeculativeAuthenticate(firstMsg) + op = op.SpeculativeAuthenticate(firstMsg) + } } } @@ -132,7 +140,6 @@ func (ah *authHandshaker) FinishHandshake(ctx context.Context, conn driver.Conne ClusterClock: ah.options.ClusterClock, HandshakeInfo: ah.handshakeInfo, ServerAPI: ah.options.ServerAPI, - HTTPClient: ah.options.HTTPClient, } if err := ah.authenticate(ctx, cfg); err != nil { @@ -170,21 +177,8 @@ func Handshaker(h driver.Handshaker, options *HandshakeOptions) driver.Handshake } } -// Config holds the information necessary to perform an authentication attempt. -type Config struct { - Description description.Server - Connection driver.Connection - ClusterClock *session.ClusterClock - HandshakeInfo driver.HandshakeInformation - ServerAPI *driver.ServerAPIOptions - HTTPClient *http.Client -} - // Authenticator handles authenticating a connection. -type Authenticator interface { - // Auth authenticates the connection. - Auth(context.Context, *Config) error -} +type Authenticator = driver.Authenticator func newAuthError(msg string, inner error) error { return &Error{ diff --git a/x/mongo/driver/auth/auth_test.go b/x/mongo/driver/auth/auth_test.go index 9145a21595..3c07ed2cd8 100644 --- a/x/mongo/driver/auth/auth_test.go +++ b/x/mongo/driver/auth/auth_test.go @@ -7,6 +7,7 @@ package auth_test import ( + "net/http" "testing" "github.com/google/go-cmp/cmp" @@ -39,7 +40,7 @@ func TestCreateAuthenticator(t *testing.T) { PasswordSet: true, } - a, err := CreateAuthenticator(test.name, cred) + a, err := CreateAuthenticator(test.name, cred, &http.Client{}) require.NoError(t, err) require.IsType(t, test.auth, a) }) diff --git a/x/mongo/driver/auth/cred.go b/x/mongo/driver/auth/cred.go index 7b2b8f17d0..a9685f6ed8 100644 --- a/x/mongo/driver/auth/cred.go +++ b/x/mongo/driver/auth/cred.go @@ -6,11 +6,9 @@ package auth -// Cred is a user's credential. 
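As the updated test shows, the authenticator factories now receive the *http.Client at construction time rather than through the per-attempt Config. A hedged sketch of the new call shape (credential values are placeholders):

// Assumes imports of net/http and the driver auth package.
cred := &auth.Cred{
	Source:      "admin",
	Username:    "user",
	Password:    "pencil",
	PasswordSet: true,
}
authenticator, err := auth.CreateAuthenticator("SCRAM-SHA-256", cred, &http.Client{})
if err != nil {
	panic(err)
}
_ = authenticator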
-type Cred struct { - Source string - Username string - Password string - PasswordSet bool - Props map[string]string -} +import ( + "go.mongodb.org/mongo-driver/x/mongo/driver" +) + +// Cred is the type of user credential +type Cred = driver.Cred diff --git a/x/mongo/driver/auth/default.go b/x/mongo/driver/auth/default.go index 6f2ca5224a..785a41951d 100644 --- a/x/mongo/driver/auth/default.go +++ b/x/mongo/driver/auth/default.go @@ -9,10 +9,13 @@ package auth import ( "context" "fmt" + "net/http" + + "go.mongodb.org/mongo-driver/x/mongo/driver" ) -func newDefaultAuthenticator(cred *Cred) (Authenticator, error) { - scram, err := newScramSHA256Authenticator(cred) +func newDefaultAuthenticator(cred *Cred, httpClient *http.Client) (Authenticator, error) { + scram, err := newScramSHA256Authenticator(cred, httpClient) if err != nil { return nil, newAuthError("failed to create internal authenticator", err) } @@ -25,6 +28,7 @@ func newDefaultAuthenticator(cred *Cred) (Authenticator, error) { return &DefaultAuthenticator{ Cred: cred, speculativeAuthenticator: speculative, + httpClient: httpClient, }, nil } @@ -36,6 +40,8 @@ type DefaultAuthenticator struct { // The authenticator to use for speculative authentication. Because the correct auth mechanism is unknown when doing // the initial hello, SCRAM-SHA-256 is used for the speculative attempt. speculativeAuthenticator SpeculativeAuthenticator + + httpClient *http.Client } var _ SpeculativeAuthenticator = (*DefaultAuthenticator)(nil) @@ -52,11 +58,11 @@ func (a *DefaultAuthenticator) Auth(ctx context.Context, cfg *Config) error { switch chooseAuthMechanism(cfg) { case SCRAMSHA256: - actual, err = newScramSHA256Authenticator(a.Cred) + actual, err = newScramSHA256Authenticator(a.Cred, a.httpClient) case SCRAMSHA1: - actual, err = newScramSHA1Authenticator(a.Cred) + actual, err = newScramSHA1Authenticator(a.Cred, a.httpClient) default: - actual, err = newMongoDBCRAuthenticator(a.Cred) + actual, err = newMongoDBCRAuthenticator(a.Cred, a.httpClient) } if err != nil { @@ -66,6 +72,11 @@ func (a *DefaultAuthenticator) Auth(ctx context.Context, cfg *Config) error { return actual.Auth(ctx, cfg) } +// Reauth reauthenticates the connection. +func (a *DefaultAuthenticator) Reauth(_ context.Context, _ *driver.AuthConfig) error { + return newAuthError("DefaultAuthenticator does not support reauthentication", nil) +} + // If a server provides a list of supported mechanisms, we choose // SCRAM-SHA-256 if it exists or else MUST use SCRAM-SHA-1. // Otherwise, we decide based on what is supported. diff --git a/x/mongo/driver/auth/gssapi.go b/x/mongo/driver/auth/gssapi.go index 4b860ba63f..037c944eb7 100644 --- a/x/mongo/driver/auth/gssapi.go +++ b/x/mongo/driver/auth/gssapi.go @@ -14,14 +14,16 @@ import ( "context" "fmt" "net" + "net/http" + "go.mongodb.org/mongo-driver/x/mongo/driver" "go.mongodb.org/mongo-driver/x/mongo/driver/auth/internal/gssapi" ) // GSSAPI is the mechanism name for GSSAPI. const GSSAPI = "GSSAPI" -func newGSSAPIAuthenticator(cred *Cred) (Authenticator, error) { +func newGSSAPIAuthenticator(cred *Cred, _ *http.Client) (Authenticator, error) { if cred.Source != "" && cred.Source != "$external" { return nil, newAuthError("GSSAPI source must be empty or $external", nil) } @@ -57,3 +59,8 @@ func (a *GSSAPIAuthenticator) Auth(ctx context.Context, cfg *Config) error { } return ConductSaslConversation(ctx, cfg, "$external", client) } + +// Reauth reauthenticates the connection. 
+func (a *GSSAPIAuthenticator) Reauth(_ context.Context, _ *driver.AuthConfig) error { + return newAuthError("GSSAPI does not support reauthentication", nil) +} diff --git a/x/mongo/driver/auth/gssapi_not_enabled.go b/x/mongo/driver/auth/gssapi_not_enabled.go index 7ba5fe860c..e50553c7a1 100644 --- a/x/mongo/driver/auth/gssapi_not_enabled.go +++ b/x/mongo/driver/auth/gssapi_not_enabled.go @@ -9,9 +9,11 @@ package auth +import "net/http" + // GSSAPI is the mechanism name for GSSAPI. const GSSAPI = "GSSAPI" -func newGSSAPIAuthenticator(*Cred) (Authenticator, error) { +func newGSSAPIAuthenticator(*Cred, *http.Client) (Authenticator, error) { return nil, newAuthError("GSSAPI support not enabled during build (-tags gssapi)", nil) } diff --git a/x/mongo/driver/auth/gssapi_not_supported.go b/x/mongo/driver/auth/gssapi_not_supported.go index 10312c228e..12046ff67c 100644 --- a/x/mongo/driver/auth/gssapi_not_supported.go +++ b/x/mongo/driver/auth/gssapi_not_supported.go @@ -11,12 +11,13 @@ package auth import ( "fmt" + "net/http" "runtime" ) // GSSAPI is the mechanism name for GSSAPI. const GSSAPI = "GSSAPI" -func newGSSAPIAuthenticator(cred *Cred) (Authenticator, error) { +func newGSSAPIAuthenticator(*Cred, *http.Client) (Authenticator, error) { return nil, newAuthError(fmt.Sprintf("GSSAPI is not supported on %s", runtime.GOOS), nil) } diff --git a/x/mongo/driver/auth/internal/gssapi/gss.go b/x/mongo/driver/auth/internal/gssapi/gss.go index abfa4db47c..496057882d 100644 --- a/x/mongo/driver/auth/internal/gssapi/gss.go +++ b/x/mongo/driver/auth/internal/gssapi/gss.go @@ -19,6 +19,7 @@ package gssapi */ import "C" import ( + "context" "fmt" "runtime" "strings" @@ -91,12 +92,12 @@ func (sc *SaslClient) Start() (string, []byte, error) { return mechName, nil, sc.getError("unable to initialize client") } - payload, err := sc.Next(nil) + payload, err := sc.Next(nil, nil) return mechName, payload, err } -func (sc *SaslClient) Next(challenge []byte) ([]byte, error) { +func (sc *SaslClient) Next(_ context.Context, challenge []byte) ([]byte, error) { var buf unsafe.Pointer var bufLen C.size_t diff --git a/x/mongo/driver/auth/internal/gssapi/sspi.go b/x/mongo/driver/auth/internal/gssapi/sspi.go index 6e7d3ed8ad..d73da025bb 100644 --- a/x/mongo/driver/auth/internal/gssapi/sspi.go +++ b/x/mongo/driver/auth/internal/gssapi/sspi.go @@ -12,6 +12,7 @@ package gssapi // #include "sspi_wrapper.h" import "C" import ( + "context" "fmt" "net" "strconv" @@ -120,7 +121,7 @@ func (sc *SaslClient) Start() (string, []byte, error) { return mechName, payload, err } -func (sc *SaslClient) Next(challenge []byte) ([]byte, error) { +func (sc *SaslClient) Next(_ context.Context, challenge []byte) ([]byte, error) { var outBuf C.PVOID var outBufLen C.ULONG diff --git a/x/mongo/driver/auth/mongodbaws.go b/x/mongo/driver/auth/mongodbaws.go index 7ae4b08998..c5cebaa27f 100644 --- a/x/mongo/driver/auth/mongodbaws.go +++ b/x/mongo/driver/auth/mongodbaws.go @@ -9,19 +9,24 @@ package auth import ( "context" "errors" + "net/http" "go.mongodb.org/mongo-driver/internal/aws/credentials" "go.mongodb.org/mongo-driver/internal/credproviders" + "go.mongodb.org/mongo-driver/x/mongo/driver" "go.mongodb.org/mongo-driver/x/mongo/driver/auth/creds" ) // MongoDBAWS is the mechanism name for MongoDBAWS. 
const MongoDBAWS = "MONGODB-AWS" -func newMongoDBAWSAuthenticator(cred *Cred) (Authenticator, error) { +func newMongoDBAWSAuthenticator(cred *Cred, httpClient *http.Client) (Authenticator, error) { if cred.Source != "" && cred.Source != "$external" { return nil, newAuthError("MONGODB-AWS source must be empty or $external", nil) } + if httpClient == nil { + return nil, errors.New("httpClient must not be nil") + } return &MongoDBAWSAuthenticator{ source: cred.Source, credentials: &credproviders.StaticProvider{ @@ -32,6 +37,7 @@ func newMongoDBAWSAuthenticator(cred *Cred) (Authenticator, error) { SessionToken: cred.Props["AWS_SESSION_TOKEN"], }, }, + httpClient: httpClient, }, nil } @@ -39,15 +45,12 @@ func newMongoDBAWSAuthenticator(cred *Cred) (Authenticator, error) { type MongoDBAWSAuthenticator struct { source string credentials *credproviders.StaticProvider + httpClient *http.Client } // Auth authenticates the connection. func (a *MongoDBAWSAuthenticator) Auth(ctx context.Context, cfg *Config) error { - httpClient := cfg.HTTPClient - if httpClient == nil { - return errors.New("cfg.HTTPClient must not be nil") - } - providers := creds.NewAWSCredentialProvider(httpClient, a.credentials) + providers := creds.NewAWSCredentialProvider(a.httpClient, a.credentials) adapter := &awsSaslAdapter{ conversation: &awsConversation{ credentials: providers.Cred, @@ -60,6 +63,11 @@ func (a *MongoDBAWSAuthenticator) Auth(ctx context.Context, cfg *Config) error { return nil } +// Reauth reauthenticates the connection. +func (a *MongoDBAWSAuthenticator) Reauth(_ context.Context, _ *driver.AuthConfig) error { + return newAuthError("AWS authentication does not support reauthentication", nil) +} + type awsSaslAdapter struct { conversation *awsConversation } @@ -74,7 +82,7 @@ func (a *awsSaslAdapter) Start() (string, []byte, error) { return MongoDBAWS, step, nil } -func (a *awsSaslAdapter) Next(challenge []byte) ([]byte, error) { +func (a *awsSaslAdapter) Next(_ context.Context, challenge []byte) ([]byte, error) { step, err := a.conversation.Step(challenge) if err != nil { return nil, err diff --git a/x/mongo/driver/auth/mongodbcr.go b/x/mongo/driver/auth/mongodbcr.go index 6e2c2f4dcb..a988011b36 100644 --- a/x/mongo/driver/auth/mongodbcr.go +++ b/x/mongo/driver/auth/mongodbcr.go @@ -10,6 +10,7 @@ import ( "context" "fmt" "io" + "net/http" // Ignore gosec warning "Blocklisted import crypto/md5: weak cryptographic primitive". We need // to use MD5 here to implement the MONGODB-CR specification. @@ -28,7 +29,7 @@ import ( // MongoDB 4.0. const MONGODBCR = "MONGODB-CR" -func newMongoDBCRAuthenticator(cred *Cred) (Authenticator, error) { +func newMongoDBCRAuthenticator(cred *Cred, _ *http.Client) (Authenticator, error) { return &MongoDBCRAuthenticator{ DB: cred.Source, Username: cred.Username, @@ -97,6 +98,11 @@ func (a *MongoDBCRAuthenticator) Auth(ctx context.Context, cfg *Config) error { return nil } +// Reauth reauthenticates the connection. +func (a *MongoDBCRAuthenticator) Reauth(_ context.Context, _ *driver.AuthConfig) error { + return newAuthError("MONGODB-CR does not support reauthentication", nil) +} + func (a *MongoDBCRAuthenticator) createKey(nonce string) string { // Ignore gosec warning "Use of weak cryptographic primitive". We need to use MD5 here to // implement the MONGODB-CR specification. 
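Reauthentication support (needed for OIDC) is added to the Authenticator contract, and mechanisms that cannot re-run their conversation simply report an auth error, as the hunks above do. A sketch of a conforming implementation, assuming driver.Authenticator now requires both Auth and Reauth:

// Hypothetical mechanism used only to illustrate the assumed interface shape.
type staticAuthenticator struct{}

func (staticAuthenticator) Auth(ctx context.Context, cfg *driver.AuthConfig) error {
	return nil // initial authentication would run here
}

func (staticAuthenticator) Reauth(context.Context, *driver.AuthConfig) error {
	return errors.New("this mechanism does not support reauthentication")
}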
diff --git a/x/mongo/driver/auth/oidc.go b/x/mongo/driver/auth/oidc.go new file mode 100644 index 0000000000..454a1f635d --- /dev/null +++ b/x/mongo/driver/auth/oidc.go @@ -0,0 +1,542 @@ +// Copyright (C) MongoDB, Inc. 2024-present. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + +package auth + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "regexp" + "strings" + "sync" + "time" + + "go.mongodb.org/mongo-driver/bson" + "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" + "go.mongodb.org/mongo-driver/x/mongo/driver" +) + +// MongoDBOIDC is the string constant for the MONGODB-OIDC authentication mechanism. +const MongoDBOIDC = "MONGODB-OIDC" + +// const tokenResourceProp = "TOKEN_RESOURCE" +const environmentProp = "ENVIRONMENT" +const resourceProp = "TOKEN_RESOURCE" +const allowedHostsProp = "ALLOWED_HOSTS" + +const azureEnvironmentValue = "azure" +const gcpEnvironmentValue = "gcp" +const testEnvironmentValue = "test" + +const apiVersion = 1 +const invalidateSleepTimeout = 100 * time.Millisecond + +// The CSOT specification says to apply a 1-minute timeout if "CSOT is not applied". That's +// ambiguous for the v1.x Go Driver because it could mean either "no timeout provided" or "CSOT not +// enabled". Always use a maximum timeout duration of 1 minute, allowing us to ignore the ambiguity. +// Contexts with a shorter timeout are unaffected. +const machineCallbackTimeout = time.Minute +const humanCallbackTimeout = 5 * time.Minute + +var defaultAllowedHosts = []*regexp.Regexp{ + regexp.MustCompile(`^.*[.]mongodb[.]net(:\d+)?$`), + regexp.MustCompile(`^.*[.]mongodb-qa[.]net(:\d+)?$`), + regexp.MustCompile(`^.*[.]mongodb-dev[.]net(:\d+)?$`), + regexp.MustCompile(`^.*[.]mongodbgov[.]net(:\d+)?$`), + regexp.MustCompile(`^localhost(:\d+)?$`), + regexp.MustCompile(`^127[.]0[.]0[.]1(:\d+)?$`), + regexp.MustCompile(`^::1(:\d+)?$`), +} + +// OIDCCallback is a function that takes a context and OIDCArgs and returns an OIDCCredential. +type OIDCCallback = driver.OIDCCallback + +// OIDCArgs contains the arguments for the OIDC callback. +type OIDCArgs = driver.OIDCArgs + +// OIDCCredential contains the access token and refresh token. +type OIDCCredential = driver.OIDCCredential + +// IDPInfo contains the information needed to perform OIDC authentication with an Identity Provider. +type IDPInfo = driver.IDPInfo + +var _ driver.Authenticator = (*OIDCAuthenticator)(nil) +var _ SpeculativeAuthenticator = (*OIDCAuthenticator)(nil) +var _ SaslClient = (*oidcOneStep)(nil) +var _ SaslClient = (*oidcTwoStep)(nil) + +// OIDCAuthenticator is synchronized and handles caching of the access token, refreshToken, +// and IDPInfo. It also provides a mechanism to refresh the access token, but this functionality +// is only for the OIDC Human flow. +type OIDCAuthenticator struct { + mu sync.Mutex // Guards all of the info in the OIDCAuthenticator struct. + + AuthMechanismProperties map[string]string + OIDCMachineCallback OIDCCallback + OIDCHumanCallback OIDCCallback + + allowedHosts *[]*regexp.Regexp + userName string + httpClient *http.Client + accessToken string + refreshToken *string + idpInfo *IDPInfo + tokenGenID uint64 +} + +// SetAccessToken allows for manually setting the access token for the OIDCAuthenticator, this is +// only for testing purposes. 
+func (oa *OIDCAuthenticator) SetAccessToken(accessToken string) { + oa.mu.Lock() + defer oa.mu.Unlock() + oa.accessToken = accessToken +} + +func newOIDCAuthenticator(cred *Cred, httpClient *http.Client) (Authenticator, error) { + if cred.Password != "" { + return nil, fmt.Errorf("password cannot be specified for %q", MongoDBOIDC) + } + if cred.Props != nil { + if env, ok := cred.Props[environmentProp]; ok { + switch strings.ToLower(env) { + case azureEnvironmentValue: + fallthrough + case gcpEnvironmentValue: + if _, ok := cred.Props[resourceProp]; !ok { + return nil, fmt.Errorf("%q must be specified for %q %q", resourceProp, env, environmentProp) + } + fallthrough + case testEnvironmentValue: + if cred.OIDCMachineCallback != nil || cred.OIDCHumanCallback != nil { + return nil, fmt.Errorf("OIDC callbacks are not allowed for %q %q", env, environmentProp) + } + } + } + } + oa := &OIDCAuthenticator{ + userName: cred.Username, + httpClient: httpClient, + AuthMechanismProperties: cred.Props, + OIDCMachineCallback: cred.OIDCMachineCallback, + OIDCHumanCallback: cred.OIDCHumanCallback, + } + err := oa.setAllowedHosts() + return oa, err +} + +func createPatternsForGlobs(hosts []string) ([]*regexp.Regexp, error) { + var err error + ret := make([]*regexp.Regexp, len(hosts)) + for i := range hosts { + hosts[i] = strings.ReplaceAll(hosts[i], ".", "[.]") + hosts[i] = strings.ReplaceAll(hosts[i], "*", ".*") + hosts[i] = "^" + hosts[i] + "(:\\d+)?$" + ret[i], err = regexp.Compile(hosts[i]) + if err != nil { + return nil, err + } + } + return ret, nil +} + +func (oa *OIDCAuthenticator) setAllowedHosts() error { + if oa.AuthMechanismProperties == nil { + oa.allowedHosts = &defaultAllowedHosts + return nil + } + allowedHosts, ok := oa.AuthMechanismProperties[allowedHostsProp] + if !ok { + oa.allowedHosts = &defaultAllowedHosts + return nil + } + globs := strings.Split(allowedHosts, ",") + ret, err := createPatternsForGlobs(globs) + if err != nil { + return err + } + oa.allowedHosts = &ret + return nil +} + +func (oa *OIDCAuthenticator) validateConnectionAddressWithAllowedHosts(conn driver.Connection) error { + if oa.allowedHosts == nil { + // should be unreachable, but this is a safety check. + return newAuthError(fmt.Sprintf("%q missing", allowedHostsProp), nil) + } + allowedHosts := *oa.allowedHosts + if len(allowedHosts) == 0 { + return newAuthError(fmt.Sprintf("empty %q specified", allowedHostsProp), nil) + } + for _, pattern := range allowedHosts { + if pattern.MatchString(string(conn.Address())) { + return nil + } + } + return newAuthError(fmt.Sprintf("address %q not allowed by %q: %v", conn.Address(), allowedHostsProp, allowedHosts), nil) +} + +type oidcOneStep struct { + userName string + accessToken string +} + +type oidcTwoStep struct { + conn driver.Connection + oa *OIDCAuthenticator +} + +func jwtStepRequest(accessToken string) []byte { + return bsoncore.NewDocumentBuilder(). + AppendString("jwt", accessToken). 
+ Build() +} + +func principalStepRequest(principal string) []byte { + doc := bsoncore.NewDocumentBuilder() + if principal != "" { + doc.AppendString("n", principal) + } + return doc.Build() +} + +func (oos *oidcOneStep) Start() (string, []byte, error) { + return MongoDBOIDC, jwtStepRequest(oos.accessToken), nil +} + +func (oos *oidcOneStep) Next(context.Context, []byte) ([]byte, error) { + return nil, newAuthError("unexpected step in OIDC authentication", nil) +} + +func (*oidcOneStep) Completed() bool { + return true +} + +func (ots *oidcTwoStep) Start() (string, []byte, error) { + return MongoDBOIDC, principalStepRequest(ots.oa.userName), nil +} + +func (ots *oidcTwoStep) Next(ctx context.Context, msg []byte) ([]byte, error) { + var idpInfo IDPInfo + err := bson.Unmarshal(msg, &idpInfo) + if err != nil { + return nil, fmt.Errorf("error unmarshaling BSON document: %w", err) + } + + accessToken, err := ots.oa.getAccessToken(ctx, + ots.conn, + &OIDCArgs{ + Version: apiVersion, + // idpInfo is nil for machine callbacks in the current spec. + IDPInfo: &idpInfo, + // there is no way there could be a refresh token when there is no IDPInfo. + RefreshToken: nil, + }, + // two-step callbacks are always human callbacks. + ots.oa.OIDCHumanCallback) + + return jwtStepRequest(accessToken), err +} + +func (*oidcTwoStep) Completed() bool { + return true +} + +func (oa *OIDCAuthenticator) providerCallback() (OIDCCallback, error) { + env, ok := oa.AuthMechanismProperties[environmentProp] + if !ok { + return nil, nil + } + + switch env { + case azureEnvironmentValue: + resource, ok := oa.AuthMechanismProperties[resourceProp] + if !ok { + return nil, newAuthError(fmt.Sprintf("%q must be specified for Azure OIDC", resourceProp), nil) + } + return getAzureOIDCCallback(oa.userName, resource, oa.httpClient), nil + case gcpEnvironmentValue: + resource, ok := oa.AuthMechanismProperties[resourceProp] + if !ok { + return nil, newAuthError(fmt.Sprintf("%q must be specified for GCP OIDC", resourceProp), nil) + } + return getGCPOIDCCallback(resource, oa.httpClient), nil + } + + return nil, fmt.Errorf("%q %q not supported for MONGODB-OIDC", environmentProp, env) +} + +// getAzureOIDCCallback returns the callback for the Azure Identity Provider. +func getAzureOIDCCallback(clientID string, resource string, httpClient *http.Client) OIDCCallback { + // return the callback parameterized by the clientID and resource, also passing in the user + // configured httpClient. 
+ return func(ctx context.Context, args *OIDCArgs) (*OIDCCredential, error) { + resource = url.QueryEscape(resource) + var uri string + if clientID != "" { + uri = fmt.Sprintf("http://169.254.169.254/metadata/identity/oauth2/token?api-version=2018-02-01&resource=%s&client_id=%s", resource, clientID) + } else { + uri = fmt.Sprintf("http://169.254.169.254/metadata/identity/oauth2/token?api-version=2018-02-01&resource=%s", resource) + } + req, err := http.NewRequestWithContext(ctx, http.MethodGet, uri, nil) + if err != nil { + return nil, newAuthError("error creating http request to Azure Identity Provider", err) + } + req.Header.Add("Metadata", "true") + req.Header.Add("Accept", "application/json") + resp, err := httpClient.Do(req) + if err != nil { + return nil, newAuthError("error getting access token from Azure Identity Provider", err) + } + defer resp.Body.Close() + var azureResp struct { + AccessToken string `json:"access_token"` + ExpiresOn int64 `json:"expires_on,string"` + } + + if resp.StatusCode != http.StatusOK { + return nil, newAuthError(fmt.Sprintf("failed to get a valid response from Azure Identity Provider, http code: %d", resp.StatusCode), nil) + } + err = json.NewDecoder(resp.Body).Decode(&azureResp) + if err != nil { + return nil, newAuthError("failed parsing result from Azure Identity Provider", err) + } + expireTime := time.Unix(azureResp.ExpiresOn, 0) + return &OIDCCredential{ + AccessToken: azureResp.AccessToken, + ExpiresAt: &expireTime, + }, nil + } +} + +// getGCPOIDCCallback returns the callback for the GCP Identity Provider. +func getGCPOIDCCallback(resource string, httpClient *http.Client) OIDCCallback { + // return the callback parameterized by the resource, also passing in the user + // configured httpClient. + return func(ctx context.Context, args *OIDCArgs) (*OIDCCredential, error) { + resource = url.QueryEscape(resource) + uri := fmt.Sprintf("http://metadata/computeMetadata/v1/instance/service-accounts/default/identity?audience=%s", resource) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, uri, nil) + if err != nil { + return nil, newAuthError("error creating http request to GCP Identity Provider", err) + } + req.Header.Add("Metadata-Flavor", "Google") + resp, err := httpClient.Do(req) + if err != nil { + return nil, newAuthError("error getting access token from GCP Identity Provider", err) + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + return nil, newAuthError(fmt.Sprintf("failed to get a valid response from GCP Identity Provider, http code: %d", resp.StatusCode), nil) + } + accessToken, err := io.ReadAll(resp.Body) + if err != nil { + return nil, newAuthError("failed reading response from GCP Identity Provider", err) + } + return &OIDCCredential{ + AccessToken: string(accessToken), + ExpiresAt: nil, + }, nil + } +} + +func (oa *OIDCAuthenticator) getAccessToken( + ctx context.Context, + conn driver.Connection, + args *OIDCArgs, + callback OIDCCallback, +) (string, error) { + oa.mu.Lock() + defer oa.mu.Unlock() + + if oa.accessToken != "" { + return oa.accessToken, nil + } + + // Attempt to refresh the access token if a refresh token is available.
+ if args.RefreshToken != nil { + cred, err := callback(ctx, args) + if err == nil && cred != nil { + oa.accessToken = cred.AccessToken + oa.tokenGenID++ + conn.SetOIDCTokenGenID(oa.tokenGenID) + oa.refreshToken = cred.RefreshToken + return cred.AccessToken, nil + } + oa.refreshToken = nil + args.RefreshToken = nil + } + // If we get here this means there either was no refresh token or the refresh token failed. + cred, err := callback(ctx, args) + if err != nil { + return "", err + } + // This line should never occur, if go conventions are followed, but it is a safety check such + // that we do not throw nil pointer errors to our users if they abuse the API. + if cred == nil { + return "", newAuthError("OIDC callback returned nil credential with no specified error", nil) + } + + oa.accessToken = cred.AccessToken + oa.tokenGenID++ + conn.SetOIDCTokenGenID(oa.tokenGenID) + oa.refreshToken = cred.RefreshToken + // always set the IdPInfo, in most cases, this should just be recopying the same pointer, or nil + // in the machine flow. + oa.idpInfo = args.IDPInfo + return cred.AccessToken, nil +} + +// invalidateAccessToken invalidates the access token, if the force flag is set to true (which is +// only on a Reauth call) or if the tokenGenID of the connection is greater than or equal to the +// tokenGenID of the OIDCAuthenticator. It should never actually be greater than, but only equal, +// but this is a safety check, since extra invalidation is only a performance impact, not a +// correctness impact. +func (oa *OIDCAuthenticator) invalidateAccessToken(conn driver.Connection) { + oa.mu.Lock() + defer oa.mu.Unlock() + tokenGenID := conn.OIDCTokenGenID() + // If the connection used in a Reauth is a new connection it will not have a correct tokenGenID, + it will instead be set to 0. In the absence of information, the only safe thing to do is to + invalidate the cached accessToken. + if tokenGenID == 0 || tokenGenID >= oa.tokenGenID { + oa.accessToken = "" + conn.SetOIDCTokenGenID(0) + } +} + +// Reauth reauthenticates the connection when the server returns a 391 code. Reauth is part of the +// driver.Authenticator interface. +func (oa *OIDCAuthenticator) Reauth(ctx context.Context, cfg *Config) error { + oa.invalidateAccessToken(cfg.Connection) + return oa.Auth(ctx, cfg) +} + +// Auth authenticates the connection. +func (oa *OIDCAuthenticator) Auth(ctx context.Context, cfg *Config) error { + var err error + + if cfg == nil { + return newAuthError(fmt.Sprintf("config must be set for %q authentication", MongoDBOIDC), nil) + } + conn := cfg.Connection + + oa.mu.Lock() + cachedAccessToken := oa.accessToken + cachedRefreshToken := oa.refreshToken + cachedIDPInfo := oa.idpInfo + oa.mu.Unlock() + + if cachedAccessToken != "" { + err = ConductSaslConversation(ctx, cfg, "$external", &oidcOneStep{ + userName: oa.userName, + accessToken: cachedAccessToken, + }) + if err == nil { + return nil + } + // this seems like it could be incorrect since we could be invalidating an access token that + has already been replaced by a different auth attempt, but the TokenGenID will prevent + that from happening. + oa.invalidateAccessToken(conn) + time.Sleep(invalidateSleepTimeout) + } + + if oa.OIDCHumanCallback != nil { + return oa.doAuthHuman(ctx, cfg, oa.OIDCHumanCallback, cachedIDPInfo, cachedRefreshToken) + } + + // Handle user provided or automatic provider machine callback.
+ var machineCallback OIDCCallback + if oa.OIDCMachineCallback != nil { + machineCallback = oa.OIDCMachineCallback + } else { + machineCallback, err = oa.providerCallback() + if err != nil { + return fmt.Errorf("error getting built-in OIDC provider: %w", err) + } + } + + if machineCallback != nil { + return oa.doAuthMachine(ctx, cfg, machineCallback) + } + return newAuthError("no OIDC callback provided", nil) +} + +func (oa *OIDCAuthenticator) doAuthHuman(ctx context.Context, cfg *Config, humanCallback OIDCCallback, idpInfo *IDPInfo, refreshToken *string) error { + // Ensure that the connection address is allowed by the allowed hosts. + err := oa.validateConnectionAddressWithAllowedHosts(cfg.Connection) + if err != nil { + return err + } + subCtx, cancel := context.WithTimeout(ctx, humanCallbackTimeout) + defer cancel() + // If the idpInfo exists, we can just do one step + if idpInfo != nil { + accessToken, err := oa.getAccessToken(subCtx, + cfg.Connection, + &OIDCArgs{ + Version: apiVersion, + // idpInfo is nil for machine callbacks in the current spec. + IDPInfo: idpInfo, + RefreshToken: refreshToken, + }, + humanCallback) + if err != nil { + return err + } + return ConductSaslConversation( + subCtx, + cfg, + "$external", + &oidcOneStep{accessToken: accessToken}, + ) + } + // otherwise, we need the two step where we ask the server for the IdPInfo first. + ots := &oidcTwoStep{ + conn: cfg.Connection, + oa: oa, + } + return ConductSaslConversation(subCtx, cfg, "$external", ots) +} + +func (oa *OIDCAuthenticator) doAuthMachine(ctx context.Context, cfg *Config, machineCallback OIDCCallback) error { + subCtx, cancel := context.WithTimeout(ctx, machineCallbackTimeout) + accessToken, err := oa.getAccessToken(subCtx, + cfg.Connection, + &OIDCArgs{ + Version: apiVersion, + // idpInfo is nil for machine callbacks in the current spec. + IDPInfo: nil, + RefreshToken: nil, + }, + machineCallback) + cancel() + if err != nil { + return err + } + return ConductSaslConversation( + ctx, + cfg, + "$external", + &oidcOneStep{accessToken: accessToken}, + ) +} + +// CreateSpeculativeConversation creates a speculative conversation for OIDC authentication. +func (oa *OIDCAuthenticator) CreateSpeculativeConversation() (SpeculativeConversation, error) { + oa.mu.Lock() + defer oa.mu.Unlock() + accessToken := oa.accessToken + if accessToken == "" { + return nil, nil // Skip speculative auth. + } + + return newSaslConversation(&oidcOneStep{accessToken: accessToken}, "$external", true), nil +} diff --git a/x/mongo/driver/auth/oidc_test.go b/x/mongo/driver/auth/oidc_test.go new file mode 100644 index 0000000000..dcb941aff1 --- /dev/null +++ b/x/mongo/driver/auth/oidc_test.go @@ -0,0 +1,44 @@ +// Copyright (C) MongoDB, Inc. 2017-present. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. 
You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + +package auth + +import ( + "regexp" + "testing" + + "go.mongodb.org/mongo-driver/internal/assert" +) + +func TestCreatePatternsForGlobs(t *testing.T) { + t.Run("transform allowedHosts patterns", func(t *testing.T) { + + hosts := []string{ + "*.mongodb.net", + "*.mongodb-qa.net", + "*.mongodb-dev.net", + "*.mongodbgov.net", + "localhost", + "127.0.0.1", + "::1", + } + + check, err := createPatternsForGlobs(hosts) + assert.NoError(t, err) + assert.Equal(t, + []*regexp.Regexp{ + regexp.MustCompile(`^.*[.]mongodb[.]net(:\d+)?$`), + regexp.MustCompile(`^.*[.]mongodb-qa[.]net(:\d+)?$`), + regexp.MustCompile(`^.*[.]mongodb-dev[.]net(:\d+)?$`), + regexp.MustCompile(`^.*[.]mongodbgov[.]net(:\d+)?$`), + regexp.MustCompile(`^localhost(:\d+)?$`), + regexp.MustCompile(`^127[.]0[.]0[.]1(:\d+)?$`), + regexp.MustCompile(`^::1(:\d+)?$`), + }, + check, + ) + }) +} diff --git a/x/mongo/driver/auth/plain.go b/x/mongo/driver/auth/plain.go index 532d43e39f..9fce7ec383 100644 --- a/x/mongo/driver/auth/plain.go +++ b/x/mongo/driver/auth/plain.go @@ -8,12 +8,15 @@ package auth import ( "context" + "net/http" + + "go.mongodb.org/mongo-driver/x/mongo/driver" ) // PLAIN is the mechanism name for PLAIN. const PLAIN = "PLAIN" -func newPlainAuthenticator(cred *Cred) (Authenticator, error) { +func newPlainAuthenticator(cred *Cred, _ *http.Client) (Authenticator, error) { return &PlainAuthenticator{ Username: cred.Username, Password: cred.Password, @@ -34,6 +37,11 @@ func (a *PlainAuthenticator) Auth(ctx context.Context, cfg *Config) error { }) } +// Reauth reauthenticates the connection. +func (a *PlainAuthenticator) Reauth(_ context.Context, _ *driver.AuthConfig) error { + return newAuthError("Plain authentication does not support reauthentication", nil) +} + type plainSaslClient struct { username string password string @@ -46,7 +54,7 @@ func (c *plainSaslClient) Start() (string, []byte, error) { return PLAIN, b, nil } -func (c *plainSaslClient) Next([]byte) ([]byte, error) { +func (c *plainSaslClient) Next(context.Context, []byte) ([]byte, error) { return nil, newAuthError("unexpected server challenge", nil) } diff --git a/x/mongo/driver/auth/sasl.go b/x/mongo/driver/auth/sasl.go index 2a84b53a64..1ef67f02b0 100644 --- a/x/mongo/driver/auth/sasl.go +++ b/x/mongo/driver/auth/sasl.go @@ -19,7 +19,7 @@ import ( // SaslClient is the client piece of a sasl conversation. type SaslClient interface { Start() (string, []byte, error) - Next(challenge []byte) ([]byte, error) + Next(ctx context.Context, challenge []byte) ([]byte, error) Completed() bool } @@ -118,7 +118,7 @@ func (sc *saslConversation) Finish(ctx context.Context, cfg *Config, firstRespon return nil } - payload, err = sc.client.Next(saslResp.Payload) + payload, err = sc.client.Next(ctx, saslResp.Payload) if err != nil { return newError(err, sc.mechanism) } @@ -156,7 +156,6 @@ func (sc *saslConversation) Finish(ctx context.Context, cfg *Config, firstRespon func ConductSaslConversation(ctx context.Context, cfg *Config, authSource string, client SaslClient) error { // Create a non-speculative SASL conversation. 
conversation := newSaslConversation(client, authSource, false) - saslStartDoc, err := conversation.FirstMessage() if err != nil { return newError(err, conversation.mechanism) diff --git a/x/mongo/driver/auth/scram.go b/x/mongo/driver/auth/scram.go index c1238cd6a9..8c04ce32cc 100644 --- a/x/mongo/driver/auth/scram.go +++ b/x/mongo/driver/auth/scram.go @@ -14,10 +14,12 @@ package auth import ( "context" + "net/http" "github.com/xdg-go/scram" "github.com/xdg-go/stringprep" "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" + "go.mongodb.org/mongo-driver/x/mongo/driver" ) const ( @@ -35,7 +37,7 @@ var ( ) ) -func newScramSHA1Authenticator(cred *Cred) (Authenticator, error) { +func newScramSHA1Authenticator(cred *Cred, _ *http.Client) (Authenticator, error) { passdigest := mongoPasswordDigest(cred.Username, cred.Password) client, err := scram.SHA1.NewClientUnprepped(cred.Username, passdigest, "") if err != nil { @@ -49,7 +51,7 @@ func newScramSHA1Authenticator(cred *Cred) (Authenticator, error) { }, nil } -func newScramSHA256Authenticator(cred *Cred) (Authenticator, error) { +func newScramSHA256Authenticator(cred *Cred, _ *http.Client) (Authenticator, error) { passprep, err := stringprep.SASLprep.Prepare(cred.Password) if err != nil { return nil, newAuthError("error SASLprepping password", err) @@ -84,6 +86,11 @@ func (a *ScramAuthenticator) Auth(ctx context.Context, cfg *Config) error { return nil } +// Reauth reauthenticates the connection. +func (a *ScramAuthenticator) Reauth(_ context.Context, _ *driver.AuthConfig) error { + return newAuthError("SCRAM does not support reauthentication", nil) +} + // CreateSpeculativeConversation creates a speculative conversation for SCRAM authentication. func (a *ScramAuthenticator) CreateSpeculativeConversation() (SpeculativeConversation, error) { return newSaslConversation(a.createSaslClient(), a.source, true), nil @@ -112,7 +119,7 @@ func (a *scramSaslAdapter) Start() (string, []byte, error) { return a.mechanism, []byte(step), nil } -func (a *scramSaslAdapter) Next(challenge []byte) ([]byte, error) { +func (a *scramSaslAdapter) Next(_ context.Context, challenge []byte) ([]byte, error) { step, err := a.conversation.Step(string(challenge)) if err != nil { return nil, err diff --git a/x/mongo/driver/auth/scram_test.go b/x/mongo/driver/auth/scram_test.go index ef30a07364..0a745885ee 100644 --- a/x/mongo/driver/auth/scram_test.go +++ b/x/mongo/driver/auth/scram_test.go @@ -8,6 +8,7 @@ package auth import ( "context" + "net/http" "testing" "go.mongodb.org/mongo-driver/internal/assert" @@ -38,7 +39,7 @@ func TestSCRAM(t *testing.T) { t.Run("conversation", func(t *testing.T) { testCases := []struct { name string - createAuthenticatorFn func(*Cred) (Authenticator, error) + createAuthenticatorFn func(*Cred, *http.Client) (Authenticator, error) payloads [][]byte nonce string }{ @@ -49,11 +50,13 @@ func TestSCRAM(t *testing.T) { } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - authenticator, err := tc.createAuthenticatorFn(&Cred{ - Username: "user", - Password: "pencil", - Source: "admin", - }) + authenticator, err := tc.createAuthenticatorFn( + &Cred{ + Username: "user", + Password: "pencil", + Source: "admin", + }, + &http.Client{}) assert.Nil(t, err, "error creating authenticator: %v", err) sa, _ := authenticator.(*ScramAuthenticator) sa.client = sa.client.WithNonceGenerator(func() string { diff --git a/x/mongo/driver/auth/speculative_scram_test.go b/x/mongo/driver/auth/speculative_scram_test.go index a159891adc..9108fe1d21 100644 --- 
a/x/mongo/driver/auth/speculative_scram_test.go +++ b/x/mongo/driver/auth/speculative_scram_test.go @@ -9,6 +9,7 @@ package auth import ( "bytes" "context" + "net/http" "testing" "go.mongodb.org/mongo-driver/bson" @@ -63,7 +64,7 @@ func TestSpeculativeSCRAM(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Create a SCRAM authenticator and overwrite the nonce generator to make the conversation // deterministic. - authenticator, err := CreateAuthenticator(tc.mechanism, cred) + authenticator, err := CreateAuthenticator(tc.mechanism, cred, &http.Client{}) assert.Nil(t, err, "CreateAuthenticator error: %v", err) setNonce(t, authenticator, tc.nonce) @@ -148,7 +149,7 @@ func TestSpeculativeSCRAM(t *testing.T) { for _, tc := range testCases { t.Run(tc.mechanism, func(t *testing.T) { - authenticator, err := CreateAuthenticator(tc.mechanism, cred) + authenticator, err := CreateAuthenticator(tc.mechanism, cred, &http.Client{}) assert.Nil(t, err, "CreateAuthenticator error: %v", err) setNonce(t, authenticator, tc.nonce) diff --git a/x/mongo/driver/auth/speculative_x509_test.go b/x/mongo/driver/auth/speculative_x509_test.go index 85bd93191b..e26b448e79 100644 --- a/x/mongo/driver/auth/speculative_x509_test.go +++ b/x/mongo/driver/auth/speculative_x509_test.go @@ -9,6 +9,7 @@ package auth import ( "bytes" "context" + "net/http" "testing" "go.mongodb.org/mongo-driver/bson" @@ -32,7 +33,7 @@ func TestSpeculativeX509(t *testing.T) { // Tests for X509 when the hello response contains a reply to the speculative authentication attempt. The // driver should not send any more commands after the hello. - authenticator, err := CreateAuthenticator("MONGODB-X509", &Cred{}) + authenticator, err := CreateAuthenticator("MONGODB-X509", &Cred{}, &http.Client{}) assert.Nil(t, err, "CreateAuthenticator error: %v", err) handshaker := Handshaker(nil, &HandshakeOptions{ Authenticator: authenticator, @@ -76,7 +77,7 @@ func TestSpeculativeX509(t *testing.T) { // Tests for X509 when the hello response does not contain a reply to the speculative authentication attempt. // The driver should send an authenticate command after the hello. - authenticator, err := CreateAuthenticator("MONGODB-X509", &Cred{}) + authenticator, err := CreateAuthenticator("MONGODB-X509", &Cred{}, &http.Client{}) assert.Nil(t, err, "CreateAuthenticator error: %v", err) handshaker := Handshaker(nil, &HandshakeOptions{ Authenticator: authenticator, diff --git a/x/mongo/driver/auth/x509.go b/x/mongo/driver/auth/x509.go index 03a9d750e2..3e84f516f8 100644 --- a/x/mongo/driver/auth/x509.go +++ b/x/mongo/driver/auth/x509.go @@ -8,6 +8,7 @@ package auth import ( "context" + "net/http" "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" "go.mongodb.org/mongo-driver/x/mongo/driver" @@ -17,7 +18,7 @@ import ( // MongoDBX509 is the mechanism name for MongoDBX509. const MongoDBX509 = "MONGODB-X509" -func newMongoDBX509Authenticator(cred *Cred) (Authenticator, error) { +func newMongoDBX509Authenticator(cred *Cred, _ *http.Client) (Authenticator, error) { return &MongoDBX509Authenticator{User: cred.Username}, nil } @@ -76,3 +77,8 @@ func (a *MongoDBX509Authenticator) Auth(ctx context.Context, cfg *Config) error return nil } + +// Reauth reauthenticates the connection. 
+func (a *MongoDBX509Authenticator) Reauth(_ context.Context, _ *driver.AuthConfig) error { + return newAuthError("X509 does not support reauthentication", nil) +} diff --git a/x/mongo/driver/compression.go b/x/mongo/driver/compression.go index d79b024b74..d9a6c68fee 100644 --- a/x/mongo/driver/compression.go +++ b/x/mongo/driver/compression.go @@ -30,7 +30,11 @@ type CompressionOpts struct { // destination writer. It panics on any errors and should only be used at // package initialization time. func mustZstdNewWriter(lvl zstd.EncoderLevel) *zstd.Encoder { - enc, err := zstd.NewWriter(nil, zstd.WithEncoderLevel(lvl)) + enc, err := zstd.NewWriter( + nil, + zstd.WithWindowSize(8<<20), // Set window size to 8MB. + zstd.WithEncoderLevel(lvl), + ) if err != nil { panic(err) } @@ -105,6 +109,13 @@ func (e *zlibEncoder) Encode(dst, src []byte) ([]byte, error) { return dst, nil } +var zstdBufPool = sync.Pool{ + New: func() interface{} { + s := make([]byte, 0) + return &s + }, +} + // CompressPayload takes a byte slice and compresses it according to the options passed func CompressPayload(in []byte, opts CompressionOpts) ([]byte, error) { switch opts.Compressor { @@ -123,7 +134,13 @@ func CompressPayload(in []byte, opts CompressionOpts) ([]byte, error) { if err != nil { return nil, err } - return encoder.EncodeAll(in, nil), nil + ptr := zstdBufPool.Get().(*[]byte) + b := encoder.EncodeAll(in, *ptr) + dst := make([]byte, len(b)) + copy(dst, b) + *ptr = b[:0] + zstdBufPool.Put(ptr) + return dst, nil default: return nil, fmt.Errorf("unknown compressor ID %v", opts.Compressor) } diff --git a/x/mongo/driver/connstring/connstring.go b/x/mongo/driver/connstring/connstring.go index 686458e292..a8adafb8f8 100644 --- a/x/mongo/driver/connstring/connstring.go +++ b/x/mongo/driver/connstring/connstring.go @@ -302,6 +302,13 @@ func (u *ConnString) setDefaultAuthParams(dbName string) error { u.AuthSource = "admin" } } + case "mongodb-oidc": + if u.AuthSource == "" { + u.AuthSource = dbName + if u.AuthSource == "" { + u.AuthSource = "$external" + } + } case "": // Only set auth source if there is a request for authentication via non-empty credentials. if u.AuthSource == "" && (u.AuthMechanismProperties != nil || u.Username != "" || u.PasswordSet) { @@ -781,6 +788,10 @@ func (u *ConnString) validateAuth() error { if u.AuthMechanismProperties != nil { return fmt.Errorf("SCRAM-SHA-256 cannot have mechanism properties") } + case "mongodb-oidc": + if u.Password != "" { + return fmt.Errorf("password cannot be specified for MONGODB-OIDC") + } case "": if u.UsernameSet && u.Username == "" { return fmt.Errorf("username required if URI contains user info") diff --git a/x/mongo/driver/driver.go b/x/mongo/driver/driver.go index 900729bf87..363f4d6be3 100644 --- a/x/mongo/driver/driver.go +++ b/x/mongo/driver/driver.go @@ -24,6 +24,63 @@ import ( "go.mongodb.org/mongo-driver/x/mongo/driver/session" ) +// AuthConfig holds the information necessary to perform an authentication attempt. +// this was moved from the auth package to avoid a circular dependency. The auth package +// reexports this under the old name to avoid breaking the public api. +type AuthConfig struct { + Description description.Server + Connection Connection + ClusterClock *session.ClusterClock + HandshakeInfo HandshakeInformation + ServerAPI *ServerAPIOptions +} + +// OIDCCallback is the type for both Human and Machine Callback flows. RefreshToken will always be +// nil in the OIDCArgs for the Machine flow. 
+type OIDCCallback func(context.Context, *OIDCArgs) (*OIDCCredential, error) + +// OIDCArgs contains the arguments for the OIDC callback. +type OIDCArgs struct { + Version int + IDPInfo *IDPInfo + RefreshToken *string +} + +// OIDCCredential contains the access token and refresh token. +type OIDCCredential struct { + AccessToken string + ExpiresAt *time.Time + RefreshToken *string +} + +// IDPInfo contains the information needed to perform OIDC authentication with an Identity Provider. +type IDPInfo struct { + Issuer string `bson:"issuer"` + ClientID string `bson:"clientId"` + RequestScopes []string `bson:"requestScopes"` +} + +// Authenticator handles authenticating a connection. The implementers of this interface +// are all in the auth package. Most authentication mechanisms do not allow for Reauth, +// but this is included in the interface so that whenever a new mechanism is added, it +// must be explicitly considered. +type Authenticator interface { + // Auth authenticates the connection. + Auth(context.Context, *AuthConfig) error + Reauth(context.Context, *AuthConfig) error +} + +// Cred is a user's credential. +type Cred struct { + Source string + Username string + Password string + PasswordSet bool + Props map[string]string + OIDCMachineCallback OIDCCallback + OIDCHumanCallback OIDCCallback +} + // Deployment is implemented by types that can select a server from a deployment. type Deployment interface { SelectServer(context.Context, description.ServerSelector) (Server, error) @@ -79,6 +136,8 @@ type Connection interface { DriverConnectionID() uint64 // TODO(GODRIVER-2824): change type to int64. Address() address.Address Stale() bool + OIDCTokenGenID() uint64 + SetOIDCTokenGenID(uint64) } // RTTMonitor represents a round-trip-time monitor. diff --git a/x/mongo/driver/drivertest/channel_conn.go b/x/mongo/driver/drivertest/channel_conn.go index 27be4c264d..d002398a5b 100644 --- a/x/mongo/driver/drivertest/channel_conn.go +++ b/x/mongo/driver/drivertest/channel_conn.go @@ -26,6 +26,16 @@ type ChannelConn struct { Desc description.Server } +// OIDCTokenGenID implements the driver.Connection interface by returning the OIDCToken generation +// (which is always 0) +func (c *ChannelConn) OIDCTokenGenID() uint64 { + return 0 +} + +// SetOIDCTokenGenID implements the driver.Connection interface by setting the OIDCToken generation +// (which is always 0) +func (c *ChannelConn) SetOIDCTokenGenID(uint64) {} + // WriteWireMessage implements the driver.Connection interface. func (c *ChannelConn) WriteWireMessage(ctx context.Context, wm []byte) error { // Copy wm in case it came from a buffer pool. diff --git a/x/mongo/driver/operation.go b/x/mongo/driver/operation.go index b557002293..cea3543d14 100644 --- a/x/mongo/driver/operation.go +++ b/x/mongo/driver/operation.go @@ -315,6 +315,10 @@ type Operation struct { // [Operation.MaxTime]. OmitCSOTMaxTimeMS bool + // Authenticator is the authenticator to use for this operation when a reauthentication is + // required. + Authenticator Authenticator + // omitReadPreference is a boolean that indicates whether to omit the // read preference from the command. This omition includes the case // where a default read preference is used when the operation @@ -912,6 +916,28 @@ func (op Operation) Execute(ctx context.Context) error { operationErr.Labels = tt.Labels operationErr.Raw = tt.Raw case Error: + // 391 is the reauthentication required error code, so we will attempt a reauth and + // retry the operation, if it is successful. 
+ if tt.Code == 391 { + if op.Authenticator != nil { + cfg := AuthConfig{ + Description: conn.Description(), + Connection: conn, + ClusterClock: op.Clock, + ServerAPI: op.ServerAPI, + } + if err := op.Authenticator.Reauth(ctx, &cfg); err != nil { + return fmt.Errorf("error reauthenticating: %w", err) + } + if op.Client != nil && op.Client.Committing { + // Apply majority write concern for retries + op.Client.UpdateCommitTransactionWriteConcern() + op.WriteConcern = op.Client.CurrentWc + } + resetForRetry(tt) + continue + } + } if tt.HasErrorLabel(TransientTransactionError) || tt.HasErrorLabel(UnknownTransactionCommitResult) { if err := op.Client.ClearPinnedResources(); err != nil { return err @@ -1574,11 +1600,17 @@ func (op Operation) addClusterTime(dst []byte, desc description.SelectedServer) // operation's MaxTimeMS if set. If no MaxTimeMS is set on the operation, and context is // not a Timeout context, calculateMaxTimeMS returns 0. func (op Operation) calculateMaxTimeMS(ctx context.Context, mon RTTMonitor) (uint64, error) { - if csot.IsTimeoutContext(ctx) { - if op.OmitCSOTMaxTimeMS { - return 0, nil - } - + // If CSOT is enabled and we're not omitting the CSOT-calculated maxTimeMS + // value, then calculate maxTimeMS. + // + // This allows commands that do not currently send CSOT-calculated maxTimeMS + // (e.g. Find and Aggregate) to still use a manually-provided maxTimeMS + // value. + // + // TODO(GODRIVER-2944): Remove or refactor this logic when we add the + // "timeoutMode" option, which will allow users to opt-in to the + // CSOT-calculated maxTimeMS values if that's the behavior they want. + if csot.IsTimeoutContext(ctx) && !op.OmitCSOTMaxTimeMS { if deadline, ok := ctx.Deadline(); ok { remainingTimeout := time.Until(deadline) rtt90 := mon.P90() @@ -1893,7 +1925,6 @@ func (op Operation) decodeResult(ctx context.Context, opcode wiremessage.OpCode, return nil, errors.New("malformed wire message: insufficient bytes to read single document") } case wiremessage.DocumentSequence: - // TODO(GODRIVER-617): Implement document sequence returns. _, _, wm, ok = wiremessage.ReadMsgSectionDocumentSequence(wm) if !ok { return nil, errors.New("malformed wire message: insufficient bytes to read document sequence") diff --git a/x/mongo/driver/operation/abort_transaction.go b/x/mongo/driver/operation/abort_transaction.go index 9413727130..aeee533533 100644 --- a/x/mongo/driver/operation/abort_transaction.go +++ b/x/mongo/driver/operation/abort_transaction.go @@ -21,6 +21,7 @@ import ( // AbortTransaction performs an abortTransaction operation. type AbortTransaction struct { + authenticator driver.Authenticator recoveryToken bsoncore.Document session *session.Client clock *session.ClusterClock @@ -66,6 +67,7 @@ func (at *AbortTransaction) Execute(ctx context.Context) error { WriteConcern: at.writeConcern, ServerAPI: at.serverAPI, Name: driverutil.AbortTransactionOp, + Authenticator: at.authenticator, }.Execute(ctx) } @@ -199,3 +201,13 @@ func (at *AbortTransaction) ServerAPI(serverAPI *driver.ServerAPIOptions) *Abort at.serverAPI = serverAPI return at } + +// Authenticator sets the authenticator to use for this operation. 
+func (at *AbortTransaction) Authenticator(authenticator driver.Authenticator) *AbortTransaction { + if at == nil { + at = new(AbortTransaction) + } + + at.authenticator = authenticator + return at +} diff --git a/x/mongo/driver/operation/aggregate.go b/x/mongo/driver/operation/aggregate.go index 44467df8fd..df6b8fa9dd 100644 --- a/x/mongo/driver/operation/aggregate.go +++ b/x/mongo/driver/operation/aggregate.go @@ -25,6 +25,7 @@ import ( // Aggregate represents an aggregate operation. type Aggregate struct { + authenticator driver.Authenticator allowDiskUse *bool batchSize *int32 bypassDocumentValidation *bool @@ -115,6 +116,7 @@ func (a *Aggregate) Execute(ctx context.Context) error { Timeout: a.timeout, Name: driverutil.AggregateOp, OmitCSOTMaxTimeMS: a.omitCSOTMaxTimeMS, + Authenticator: a.authenticator, }.Execute(ctx) } @@ -433,3 +435,13 @@ func (a *Aggregate) OmitCSOTMaxTimeMS(omit bool) *Aggregate { a.omitCSOTMaxTimeMS = omit return a } + +// Authenticator sets the authenticator to use for this operation. +func (a *Aggregate) Authenticator(authenticator driver.Authenticator) *Aggregate { + if a == nil { + a = new(Aggregate) + } + + a.authenticator = authenticator + return a +} diff --git a/x/mongo/driver/operation/command.go b/x/mongo/driver/operation/command.go index 35283794a3..9dd10f3cb0 100644 --- a/x/mongo/driver/operation/command.go +++ b/x/mongo/driver/operation/command.go @@ -22,6 +22,7 @@ import ( // Command is used to run a generic operation. type Command struct { + authenticator driver.Authenticator command bsoncore.Document database string deployment driver.Deployment @@ -107,6 +108,7 @@ func (c *Command) Execute(ctx context.Context) error { ServerAPI: c.serverAPI, Timeout: c.timeout, Logger: c.logger, + Authenticator: c.authenticator, }.Execute(ctx) } @@ -219,3 +221,13 @@ func (c *Command) Logger(logger *logger.Logger) *Command { c.logger = logger return c } + +// Authenticator sets the authenticator to use for this operation. +func (c *Command) Authenticator(authenticator driver.Authenticator) *Command { + if c == nil { + c = new(Command) + } + + c.authenticator = authenticator + return c +} diff --git a/x/mongo/driver/operation/commit_transaction.go b/x/mongo/driver/operation/commit_transaction.go index 11c6f69ddf..6b402bdf63 100644 --- a/x/mongo/driver/operation/commit_transaction.go +++ b/x/mongo/driver/operation/commit_transaction.go @@ -22,6 +22,7 @@ import ( // CommitTransaction attempts to commit a transaction. type CommitTransaction struct { + authenticator driver.Authenticator maxTime *time.Duration recoveryToken bsoncore.Document session *session.Client @@ -68,6 +69,7 @@ func (ct *CommitTransaction) Execute(ctx context.Context) error { WriteConcern: ct.writeConcern, ServerAPI: ct.serverAPI, Name: driverutil.CommitTransactionOp, + Authenticator: ct.authenticator, }.Execute(ctx) } @@ -201,3 +203,13 @@ func (ct *CommitTransaction) ServerAPI(serverAPI *driver.ServerAPIOptions) *Comm ct.serverAPI = serverAPI return ct } + +// Authenticator sets the authenticator to use for this operation. +func (ct *CommitTransaction) Authenticator(authenticator driver.Authenticator) *CommitTransaction { + if ct == nil { + ct = new(CommitTransaction) + } + + ct.authenticator = authenticator + return ct +} diff --git a/x/mongo/driver/operation/count.go b/x/mongo/driver/operation/count.go index 8de1e9f8d9..eaafc9a244 100644 --- a/x/mongo/driver/operation/count.go +++ b/x/mongo/driver/operation/count.go @@ -25,6 +25,7 @@ import ( // Count represents a count operation. 
type Count struct { + authenticator driver.Authenticator maxTime *time.Duration query bsoncore.Document session *session.Client @@ -128,6 +129,7 @@ func (c *Count) Execute(ctx context.Context) error { ServerAPI: c.serverAPI, Timeout: c.timeout, Name: driverutil.CountOp, + Authenticator: c.authenticator, }.Execute(ctx) // Swallow error if NamespaceNotFound(26) is returned from aggregate on non-existent namespace @@ -311,3 +313,13 @@ func (c *Count) Timeout(timeout *time.Duration) *Count { c.timeout = timeout return c } + +// Authenticator sets the authenticator to use for this operation. +func (c *Count) Authenticator(authenticator driver.Authenticator) *Count { + if c == nil { + c = new(Count) + } + + c.authenticator = authenticator + return c +} diff --git a/x/mongo/driver/operation/create.go b/x/mongo/driver/operation/create.go index 45b26cb707..4878e2c777 100644 --- a/x/mongo/driver/operation/create.go +++ b/x/mongo/driver/operation/create.go @@ -20,6 +20,7 @@ import ( // Create represents a create operation. type Create struct { + authenticator driver.Authenticator capped *bool collation bsoncore.Document changeStreamPreAndPostImages bsoncore.Document @@ -77,6 +78,7 @@ func (c *Create) Execute(ctx context.Context) error { Selector: c.selector, WriteConcern: c.writeConcern, ServerAPI: c.serverAPI, + Authenticator: c.authenticator, }.Execute(ctx) } @@ -399,3 +401,13 @@ func (c *Create) ClusteredIndex(ci bsoncore.Document) *Create { c.clusteredIndex = ci return c } + +// Authenticator sets the authenticator to use for this operation. +func (c *Create) Authenticator(authenticator driver.Authenticator) *Create { + if c == nil { + c = new(Create) + } + + c.authenticator = authenticator + return c +} diff --git a/x/mongo/driver/operation/create_indexes.go b/x/mongo/driver/operation/create_indexes.go index 77daf676a4..464c1762de 100644 --- a/x/mongo/driver/operation/create_indexes.go +++ b/x/mongo/driver/operation/create_indexes.go @@ -24,21 +24,22 @@ import ( // CreateIndexes performs a createIndexes operation. type CreateIndexes struct { - commitQuorum bsoncore.Value - indexes bsoncore.Document - maxTime *time.Duration - session *session.Client - clock *session.ClusterClock - collection string - monitor *event.CommandMonitor - crypt driver.Crypt - database string - deployment driver.Deployment - selector description.ServerSelector - writeConcern *writeconcern.WriteConcern - result CreateIndexesResult - serverAPI *driver.ServerAPIOptions - timeout *time.Duration + authenticator driver.Authenticator + commitQuorum bsoncore.Value + indexes bsoncore.Document + maxTime *time.Duration + session *session.Client + clock *session.ClusterClock + collection string + monitor *event.CommandMonitor + crypt driver.Crypt + database string + deployment driver.Deployment + selector description.ServerSelector + writeConcern *writeconcern.WriteConcern + result CreateIndexesResult + serverAPI *driver.ServerAPIOptions + timeout *time.Duration } // CreateIndexesResult represents a createIndexes result returned by the server. @@ -119,6 +120,7 @@ func (ci *CreateIndexes) Execute(ctx context.Context) error { ServerAPI: ci.serverAPI, Timeout: ci.timeout, Name: driverutil.CreateIndexesOp, + Authenticator: ci.authenticator, }.Execute(ctx) } @@ -278,3 +280,13 @@ func (ci *CreateIndexes) Timeout(timeout *time.Duration) *CreateIndexes { ci.timeout = timeout return ci } + +// Authenticator sets the authenticator to use for this operation. 
+func (ci *CreateIndexes) Authenticator(authenticator driver.Authenticator) *CreateIndexes { + if ci == nil { + ci = new(CreateIndexes) + } + + ci.authenticator = authenticator + return ci +} diff --git a/x/mongo/driver/operation/create_search_indexes.go b/x/mongo/driver/operation/create_search_indexes.go index cb0d807952..8185d27fe1 100644 --- a/x/mongo/driver/operation/create_search_indexes.go +++ b/x/mongo/driver/operation/create_search_indexes.go @@ -22,18 +22,19 @@ import ( // CreateSearchIndexes performs a createSearchIndexes operation. type CreateSearchIndexes struct { - indexes bsoncore.Document - session *session.Client - clock *session.ClusterClock - collection string - monitor *event.CommandMonitor - crypt driver.Crypt - database string - deployment driver.Deployment - selector description.ServerSelector - result CreateSearchIndexesResult - serverAPI *driver.ServerAPIOptions - timeout *time.Duration + authenticator driver.Authenticator + indexes bsoncore.Document + session *session.Client + clock *session.ClusterClock + collection string + monitor *event.CommandMonitor + crypt driver.Crypt + database string + deployment driver.Deployment + selector description.ServerSelector + result CreateSearchIndexesResult + serverAPI *driver.ServerAPIOptions + timeout *time.Duration } // CreateSearchIndexResult represents a single search index result in CreateSearchIndexesResult. @@ -116,6 +117,7 @@ func (csi *CreateSearchIndexes) Execute(ctx context.Context) error { Selector: csi.selector, ServerAPI: csi.serverAPI, Timeout: csi.timeout, + Authenticator: csi.authenticator, }.Execute(ctx) } @@ -237,3 +239,13 @@ func (csi *CreateSearchIndexes) Timeout(timeout *time.Duration) *CreateSearchInd csi.timeout = timeout return csi } + +// Authenticator sets the authenticator to use for this operation. +func (csi *CreateSearchIndexes) Authenticator(authenticator driver.Authenticator) *CreateSearchIndexes { + if csi == nil { + csi = new(CreateSearchIndexes) + } + + csi.authenticator = authenticator + return csi +} diff --git a/x/mongo/driver/operation/delete.go b/x/mongo/driver/operation/delete.go index bf95cf496d..298ec44196 100644 --- a/x/mongo/driver/operation/delete.go +++ b/x/mongo/driver/operation/delete.go @@ -25,25 +25,26 @@ import ( // Delete performs a delete operation type Delete struct { - comment bsoncore.Value - deletes []bsoncore.Document - ordered *bool - session *session.Client - clock *session.ClusterClock - collection string - monitor *event.CommandMonitor - crypt driver.Crypt - database string - deployment driver.Deployment - selector description.ServerSelector - writeConcern *writeconcern.WriteConcern - retry *driver.RetryMode - hint *bool - result DeleteResult - serverAPI *driver.ServerAPIOptions - let bsoncore.Document - timeout *time.Duration - logger *logger.Logger + authenticator driver.Authenticator + comment bsoncore.Value + deletes []bsoncore.Document + ordered *bool + session *session.Client + clock *session.ClusterClock + collection string + monitor *event.CommandMonitor + crypt driver.Crypt + database string + deployment driver.Deployment + selector description.ServerSelector + writeConcern *writeconcern.WriteConcern + retry *driver.RetryMode + hint *bool + result DeleteResult + serverAPI *driver.ServerAPIOptions + let bsoncore.Document + timeout *time.Duration + logger *logger.Logger } // DeleteResult represents a delete result returned by the server. 
@@ -116,6 +117,7 @@ func (d *Delete) Execute(ctx context.Context) error { Timeout: d.timeout, Logger: d.logger, Name: driverutil.DeleteOp, + Authenticator: d.authenticator, }.Execute(ctx) } @@ -328,3 +330,13 @@ func (d *Delete) Logger(logger *logger.Logger) *Delete { return d } + +// Authenticator sets the authenticator to use for this operation. +func (d *Delete) Authenticator(authenticator driver.Authenticator) *Delete { + if d == nil { + d = new(Delete) + } + + d.authenticator = authenticator + return d +} diff --git a/x/mongo/driver/operation/distinct.go b/x/mongo/driver/operation/distinct.go index b7e675ce42..484d96b66b 100644 --- a/x/mongo/driver/operation/distinct.go +++ b/x/mongo/driver/operation/distinct.go @@ -24,6 +24,7 @@ import ( // Distinct performs a distinct operation. type Distinct struct { + authenticator driver.Authenticator collation bsoncore.Document key *string maxTime *time.Duration @@ -107,6 +108,7 @@ func (d *Distinct) Execute(ctx context.Context) error { ServerAPI: d.serverAPI, Timeout: d.timeout, Name: driverutil.DistinctOp, + Authenticator: d.authenticator, }.Execute(ctx) } @@ -311,3 +313,13 @@ func (d *Distinct) Timeout(timeout *time.Duration) *Distinct { d.timeout = timeout return d } + +// Authenticator sets the authenticator to use for this operation. +func (d *Distinct) Authenticator(authenticator driver.Authenticator) *Distinct { + if d == nil { + d = new(Distinct) + } + + d.authenticator = authenticator + return d +} diff --git a/x/mongo/driver/operation/drop_collection.go b/x/mongo/driver/operation/drop_collection.go index 8c65967564..5a32c2f8d4 100644 --- a/x/mongo/driver/operation/drop_collection.go +++ b/x/mongo/driver/operation/drop_collection.go @@ -23,18 +23,19 @@ import ( // DropCollection performs a drop operation. type DropCollection struct { - session *session.Client - clock *session.ClusterClock - collection string - monitor *event.CommandMonitor - crypt driver.Crypt - database string - deployment driver.Deployment - selector description.ServerSelector - writeConcern *writeconcern.WriteConcern - result DropCollectionResult - serverAPI *driver.ServerAPIOptions - timeout *time.Duration + authenticator driver.Authenticator + session *session.Client + clock *session.ClusterClock + collection string + monitor *event.CommandMonitor + crypt driver.Crypt + database string + deployment driver.Deployment + selector description.ServerSelector + writeConcern *writeconcern.WriteConcern + result DropCollectionResult + serverAPI *driver.ServerAPIOptions + timeout *time.Duration } // DropCollectionResult represents a dropCollection result returned by the server. @@ -104,6 +105,7 @@ func (dc *DropCollection) Execute(ctx context.Context) error { ServerAPI: dc.serverAPI, Timeout: dc.timeout, Name: driverutil.DropOp, + Authenticator: dc.authenticator, }.Execute(ctx) } @@ -222,3 +224,13 @@ func (dc *DropCollection) Timeout(timeout *time.Duration) *DropCollection { dc.timeout = timeout return dc } + +// Authenticator sets the authenticator to use for this operation. 
+func (dc *DropCollection) Authenticator(authenticator driver.Authenticator) *DropCollection { + if dc == nil { + dc = new(DropCollection) + } + + dc.authenticator = authenticator + return dc +} diff --git a/x/mongo/driver/operation/drop_database.go b/x/mongo/driver/operation/drop_database.go index a8f9b45ba4..19956210d1 100644 --- a/x/mongo/driver/operation/drop_database.go +++ b/x/mongo/driver/operation/drop_database.go @@ -21,15 +21,16 @@ import ( // DropDatabase performs a dropDatabase operation type DropDatabase struct { - session *session.Client - clock *session.ClusterClock - monitor *event.CommandMonitor - crypt driver.Crypt - database string - deployment driver.Deployment - selector description.ServerSelector - writeConcern *writeconcern.WriteConcern - serverAPI *driver.ServerAPIOptions + authenticator driver.Authenticator + session *session.Client + clock *session.ClusterClock + monitor *event.CommandMonitor + crypt driver.Crypt + database string + deployment driver.Deployment + selector description.ServerSelector + writeConcern *writeconcern.WriteConcern + serverAPI *driver.ServerAPIOptions } // NewDropDatabase constructs and returns a new DropDatabase. @@ -55,6 +56,7 @@ func (dd *DropDatabase) Execute(ctx context.Context) error { WriteConcern: dd.writeConcern, ServerAPI: dd.serverAPI, Name: driverutil.DropDatabaseOp, + Authenticator: dd.authenticator, }.Execute(ctx) } @@ -154,3 +156,13 @@ func (dd *DropDatabase) ServerAPI(serverAPI *driver.ServerAPIOptions) *DropDatab dd.serverAPI = serverAPI return dd } + +// Authenticator sets the authenticator to use for this operation. +func (dd *DropDatabase) Authenticator(authenticator driver.Authenticator) *DropDatabase { + if dd == nil { + dd = new(DropDatabase) + } + + dd.authenticator = authenticator + return dd +} diff --git a/x/mongo/driver/operation/drop_indexes.go b/x/mongo/driver/operation/drop_indexes.go index 0c3d459707..9cbd797be2 100644 --- a/x/mongo/driver/operation/drop_indexes.go +++ b/x/mongo/driver/operation/drop_indexes.go @@ -23,20 +23,21 @@ import ( // DropIndexes performs an dropIndexes operation. type DropIndexes struct { - index *string - maxTime *time.Duration - session *session.Client - clock *session.ClusterClock - collection string - monitor *event.CommandMonitor - crypt driver.Crypt - database string - deployment driver.Deployment - selector description.ServerSelector - writeConcern *writeconcern.WriteConcern - result DropIndexesResult - serverAPI *driver.ServerAPIOptions - timeout *time.Duration + authenticator driver.Authenticator + index any + maxTime *time.Duration + session *session.Client + clock *session.ClusterClock + collection string + monitor *event.CommandMonitor + crypt driver.Crypt + database string + deployment driver.Deployment + selector description.ServerSelector + writeConcern *writeconcern.WriteConcern + result DropIndexesResult + serverAPI *driver.ServerAPIOptions + timeout *time.Duration } // DropIndexesResult represents a dropIndexes result returned by the server. @@ -65,9 +66,9 @@ func buildDropIndexesResult(response bsoncore.Document) (DropIndexesResult, erro } // NewDropIndexes constructs and returns a new DropIndexes. 
-func NewDropIndexes(index string) *DropIndexes { +func NewDropIndexes(index any) *DropIndexes { return &DropIndexes{ - index: &index, + index: index, } } @@ -101,25 +102,33 @@ func (di *DropIndexes) Execute(ctx context.Context) error { ServerAPI: di.serverAPI, Timeout: di.timeout, Name: driverutil.DropIndexesOp, + Authenticator: di.authenticator, }.Execute(ctx) } func (di *DropIndexes) command(dst []byte, _ description.SelectedServer) ([]byte, error) { dst = bsoncore.AppendStringElement(dst, "dropIndexes", di.collection) - if di.index != nil { - dst = bsoncore.AppendStringElement(dst, "index", *di.index) + + switch di.index.(type) { + case string: + dst = bsoncore.AppendStringElement(dst, "index", di.index.(string)) + case bsoncore.Document: + if di.index != nil { + dst = bsoncore.AppendDocumentElement(dst, "index", di.index.(bsoncore.Document)) + } } + return dst, nil } // Index specifies the name of the index to drop. If '*' is specified, all indexes will be dropped. -func (di *DropIndexes) Index(index string) *DropIndexes { +func (di *DropIndexes) Index(index any) *DropIndexes { if di == nil { di = new(DropIndexes) } - di.index = &index + di.index = index return di } @@ -242,3 +251,13 @@ func (di *DropIndexes) Timeout(timeout *time.Duration) *DropIndexes { di.timeout = timeout return di } + +// Authenticator sets the authenticator to use for this operation. +func (di *DropIndexes) Authenticator(authenticator driver.Authenticator) *DropIndexes { + if di == nil { + di = new(DropIndexes) + } + + di.authenticator = authenticator + return di +} diff --git a/x/mongo/driver/operation/drop_search_index.go b/x/mongo/driver/operation/drop_search_index.go index 3992c83165..3d273434d5 100644 --- a/x/mongo/driver/operation/drop_search_index.go +++ b/x/mongo/driver/operation/drop_search_index.go @@ -21,18 +21,19 @@ import ( // DropSearchIndex performs an dropSearchIndex operation. type DropSearchIndex struct { - index string - session *session.Client - clock *session.ClusterClock - collection string - monitor *event.CommandMonitor - crypt driver.Crypt - database string - deployment driver.Deployment - selector description.ServerSelector - result DropSearchIndexResult - serverAPI *driver.ServerAPIOptions - timeout *time.Duration + authenticator driver.Authenticator + index string + session *session.Client + clock *session.ClusterClock + collection string + monitor *event.CommandMonitor + crypt driver.Crypt + database string + deployment driver.Deployment + selector description.ServerSelector + result DropSearchIndexResult + serverAPI *driver.ServerAPIOptions + timeout *time.Duration } // DropSearchIndexResult represents a dropSearchIndex result returned by the server. @@ -93,6 +94,7 @@ func (dsi *DropSearchIndex) Execute(ctx context.Context) error { Selector: dsi.selector, ServerAPI: dsi.serverAPI, Timeout: dsi.timeout, + Authenticator: dsi.authenticator, }.Execute(ctx) } @@ -212,3 +214,13 @@ func (dsi *DropSearchIndex) Timeout(timeout *time.Duration) *DropSearchIndex { dsi.timeout = timeout return dsi } + +// Authenticator sets the authenticator to use for this operation. 
+func (dsi *DropSearchIndex) Authenticator(authenticator driver.Authenticator) *DropSearchIndex { + if dsi == nil { + dsi = new(DropSearchIndex) + } + + dsi.authenticator = authenticator + return dsi +} diff --git a/x/mongo/driver/operation/end_sessions.go b/x/mongo/driver/operation/end_sessions.go index 52f300bb7f..8b24b3d8c2 100644 --- a/x/mongo/driver/operation/end_sessions.go +++ b/x/mongo/driver/operation/end_sessions.go @@ -20,15 +20,16 @@ import ( // EndSessions performs an endSessions operation. type EndSessions struct { - sessionIDs bsoncore.Document - session *session.Client - clock *session.ClusterClock - monitor *event.CommandMonitor - crypt driver.Crypt - database string - deployment driver.Deployment - selector description.ServerSelector - serverAPI *driver.ServerAPIOptions + authenticator driver.Authenticator + sessionIDs bsoncore.Document + session *session.Client + clock *session.ClusterClock + monitor *event.CommandMonitor + crypt driver.Crypt + database string + deployment driver.Deployment + selector description.ServerSelector + serverAPI *driver.ServerAPIOptions } // NewEndSessions constructs and returns a new EndSessions. @@ -61,6 +62,7 @@ func (es *EndSessions) Execute(ctx context.Context) error { Selector: es.selector, ServerAPI: es.serverAPI, Name: driverutil.EndSessionsOp, + Authenticator: es.authenticator, }.Execute(ctx) } @@ -161,3 +163,13 @@ func (es *EndSessions) ServerAPI(serverAPI *driver.ServerAPIOptions) *EndSession es.serverAPI = serverAPI return es } + +// Authenticator sets the authenticator to use for this operation. +func (es *EndSessions) Authenticator(authenticator driver.Authenticator) *EndSessions { + if es == nil { + es = new(EndSessions) + } + + es.authenticator = authenticator + return es +} diff --git a/x/mongo/driver/operation/find.go b/x/mongo/driver/operation/find.go index 8950fde86d..c71b7d755e 100644 --- a/x/mongo/driver/operation/find.go +++ b/x/mongo/driver/operation/find.go @@ -25,6 +25,7 @@ import ( // Find performs a find operation. type Find struct { + authenticator driver.Authenticator allowDiskUse *bool allowPartialResults *bool awaitData *bool @@ -112,6 +113,7 @@ func (f *Find) Execute(ctx context.Context) error { Logger: f.logger, Name: driverutil.FindOp, OmitCSOTMaxTimeMS: f.omitCSOTMaxTimeMS, + Authenticator: f.authenticator, }.Execute(ctx) } @@ -575,3 +577,13 @@ func (f *Find) Logger(logger *logger.Logger) *Find { f.logger = logger return f } + +// Authenticator sets the authenticator to use for this operation. +func (f *Find) Authenticator(authenticator driver.Authenticator) *Find { + if f == nil { + f = new(Find) + } + + f.authenticator = authenticator + return f +} diff --git a/x/mongo/driver/operation/find_and_modify.go b/x/mongo/driver/operation/find_and_modify.go index 7faf561135..ea365ccb23 100644 --- a/x/mongo/driver/operation/find_and_modify.go +++ b/x/mongo/driver/operation/find_and_modify.go @@ -25,6 +25,7 @@ import ( // FindAndModify performs a findAndModify operation. 
type FindAndModify struct { + authenticator driver.Authenticator arrayFilters bsoncore.Array bypassDocumentValidation *bool collation bsoncore.Document @@ -145,6 +146,7 @@ func (fam *FindAndModify) Execute(ctx context.Context) error { ServerAPI: fam.serverAPI, Timeout: fam.timeout, Name: driverutil.FindAndModifyOp, + Authenticator: fam.authenticator, }.Execute(ctx) } @@ -477,3 +479,13 @@ func (fam *FindAndModify) Timeout(timeout *time.Duration) *FindAndModify { fam.timeout = timeout return fam } + +// Authenticator sets the authenticator to use for this operation. +func (fam *FindAndModify) Authenticator(authenticator driver.Authenticator) *FindAndModify { + if fam == nil { + fam = new(FindAndModify) + } + + fam.authenticator = authenticator + return fam +} diff --git a/x/mongo/driver/operation/hello.go b/x/mongo/driver/operation/hello.go index 16f2ebf6c0..60c99f063d 100644 --- a/x/mongo/driver/operation/hello.go +++ b/x/mongo/driver/operation/hello.go @@ -36,6 +36,7 @@ const driverName = "mongo-go-driver" // Hello is used to run the handshake operation. type Hello struct { + authenticator driver.Authenticator appname string compressors []string saslSupportedMechs string @@ -649,3 +650,13 @@ func (h *Hello) GetHandshakeInformation(ctx context.Context, _ address.Address, func (h *Hello) FinishHandshake(context.Context, driver.Connection) error { return nil } + +// Authenticator sets the authenticator to use for this operation. +func (h *Hello) Authenticator(authenticator driver.Authenticator) *Hello { + if h == nil { + h = new(Hello) + } + + h.authenticator = authenticator + return h +} diff --git a/x/mongo/driver/operation/insert.go b/x/mongo/driver/operation/insert.go index 7da4b8b0fb..f5afe31169 100644 --- a/x/mongo/driver/operation/insert.go +++ b/x/mongo/driver/operation/insert.go @@ -25,6 +25,7 @@ import ( // Insert performs an insert operation. type Insert struct { + authenticator driver.Authenticator bypassDocumentValidation *bool comment bsoncore.Value documents []bsoncore.Document @@ -115,6 +116,7 @@ func (i *Insert) Execute(ctx context.Context) error { Timeout: i.timeout, Logger: i.logger, Name: driverutil.InsertOp, + Authenticator: i.authenticator, }.Execute(ctx) } @@ -306,3 +308,13 @@ func (i *Insert) Logger(logger *logger.Logger) *Insert { i.logger = logger return i } + +// Authenticator sets the authenticator to use for this operation. +func (i *Insert) Authenticator(authenticator driver.Authenticator) *Insert { + if i == nil { + i = new(Insert) + } + + i.authenticator = authenticator + return i +} diff --git a/x/mongo/driver/operation/listDatabases.go b/x/mongo/driver/operation/listDatabases.go index c70248e2a9..3df171e37a 100644 --- a/x/mongo/driver/operation/listDatabases.go +++ b/x/mongo/driver/operation/listDatabases.go @@ -24,6 +24,7 @@ import ( // ListDatabases performs a listDatabases operation. type ListDatabases struct { + authenticator driver.Authenticator filter bsoncore.Document authorizedDatabases *bool nameOnly *bool @@ -165,6 +166,7 @@ func (ld *ListDatabases) Execute(ctx context.Context) error { ServerAPI: ld.serverAPI, Timeout: ld.timeout, Name: driverutil.ListDatabasesOp, + Authenticator: ld.authenticator, }.Execute(ctx) } @@ -327,3 +329,13 @@ func (ld *ListDatabases) Timeout(timeout *time.Duration) *ListDatabases { ld.timeout = timeout return ld } + +// Authenticator sets the authenticator to use for this operation. 
+func (ld *ListDatabases) Authenticator(authenticator driver.Authenticator) *ListDatabases { + if ld == nil { + ld = new(ListDatabases) + } + + ld.authenticator = authenticator + return ld +} diff --git a/x/mongo/driver/operation/list_collections.go b/x/mongo/driver/operation/list_collections.go index 6fe68fa033..1e39f5bfbe 100644 --- a/x/mongo/driver/operation/list_collections.go +++ b/x/mongo/driver/operation/list_collections.go @@ -22,6 +22,7 @@ import ( // ListCollections performs a listCollections operation. type ListCollections struct { + authenticator driver.Authenticator filter bsoncore.Document nameOnly *bool authorizedCollections *bool @@ -83,6 +84,7 @@ func (lc *ListCollections) Execute(ctx context.Context) error { ServerAPI: lc.serverAPI, Timeout: lc.timeout, Name: driverutil.ListCollectionsOp, + Authenticator: lc.authenticator, }.Execute(ctx) } @@ -259,3 +261,13 @@ func (lc *ListCollections) Timeout(timeout *time.Duration) *ListCollections { lc.timeout = timeout return lc } + +// Authenticator sets the authenticator to use for this operation. +func (lc *ListCollections) Authenticator(authenticator driver.Authenticator) *ListCollections { + if lc == nil { + lc = new(ListCollections) + } + + lc.authenticator = authenticator + return lc +} diff --git a/x/mongo/driver/operation/list_indexes.go b/x/mongo/driver/operation/list_indexes.go index 79d50eca95..433344f307 100644 --- a/x/mongo/driver/operation/list_indexes.go +++ b/x/mongo/driver/operation/list_indexes.go @@ -21,19 +21,20 @@ import ( // ListIndexes performs a listIndexes operation. type ListIndexes struct { - batchSize *int32 - maxTime *time.Duration - session *session.Client - clock *session.ClusterClock - collection string - monitor *event.CommandMonitor - database string - deployment driver.Deployment - selector description.ServerSelector - retry *driver.RetryMode - crypt driver.Crypt - serverAPI *driver.ServerAPIOptions - timeout *time.Duration + authenticator driver.Authenticator + batchSize *int32 + maxTime *time.Duration + session *session.Client + clock *session.ClusterClock + collection string + monitor *event.CommandMonitor + database string + deployment driver.Deployment + selector description.ServerSelector + retry *driver.RetryMode + crypt driver.Crypt + serverAPI *driver.ServerAPIOptions + timeout *time.Duration result driver.CursorResponse } @@ -85,6 +86,7 @@ func (li *ListIndexes) Execute(ctx context.Context) error { ServerAPI: li.serverAPI, Timeout: li.timeout, Name: driverutil.ListIndexesOp, + Authenticator: li.authenticator, }.Execute(ctx) } @@ -233,3 +235,13 @@ func (li *ListIndexes) Timeout(timeout *time.Duration) *ListIndexes { li.timeout = timeout return li } + +// Authenticator sets the authenticator to use for this operation. +func (li *ListIndexes) Authenticator(authenticator driver.Authenticator) *ListIndexes { + if li == nil { + li = new(ListIndexes) + } + + li.authenticator = authenticator + return li +} diff --git a/x/mongo/driver/operation/update.go b/x/mongo/driver/operation/update.go index 881b1bcf7b..1070e7ca70 100644 --- a/x/mongo/driver/operation/update.go +++ b/x/mongo/driver/operation/update.go @@ -26,6 +26,7 @@ import ( // Update performs an update operation. 
type Update struct { + authenticator driver.Authenticator bypassDocumentValidation *bool comment bsoncore.Value ordered *bool @@ -167,6 +168,7 @@ func (u *Update) Execute(ctx context.Context) error { Timeout: u.timeout, Logger: u.logger, Name: driverutil.UpdateOp, + Authenticator: u.authenticator, }.Execute(ctx) } @@ -414,3 +416,13 @@ func (u *Update) Logger(logger *logger.Logger) *Update { u.logger = logger return u } + +// Authenticator sets the authenticator to use for this operation. +func (u *Update) Authenticator(authenticator driver.Authenticator) *Update { + if u == nil { + u = new(Update) + } + + u.authenticator = authenticator + return u +} diff --git a/x/mongo/driver/operation/update_search_index.go b/x/mongo/driver/operation/update_search_index.go index 64f2da7f6f..4ed9946c69 100644 --- a/x/mongo/driver/operation/update_search_index.go +++ b/x/mongo/driver/operation/update_search_index.go @@ -21,19 +21,20 @@ import ( // UpdateSearchIndex performs a updateSearchIndex operation. type UpdateSearchIndex struct { - index string - definition bsoncore.Document - session *session.Client - clock *session.ClusterClock - collection string - monitor *event.CommandMonitor - crypt driver.Crypt - database string - deployment driver.Deployment - selector description.ServerSelector - result UpdateSearchIndexResult - serverAPI *driver.ServerAPIOptions - timeout *time.Duration + authenticator driver.Authenticator + index string + definition bsoncore.Document + session *session.Client + clock *session.ClusterClock + collection string + monitor *event.CommandMonitor + crypt driver.Crypt + database string + deployment driver.Deployment + selector description.ServerSelector + result UpdateSearchIndexResult + serverAPI *driver.ServerAPIOptions + timeout *time.Duration } // UpdateSearchIndexResult represents a single index in the updateSearchIndexResult result. @@ -95,6 +96,7 @@ func (usi *UpdateSearchIndex) Execute(ctx context.Context) error { Selector: usi.selector, ServerAPI: usi.serverAPI, Timeout: usi.timeout, + Authenticator: usi.authenticator, }.Execute(ctx) } @@ -225,3 +227,13 @@ func (usi *UpdateSearchIndex) Timeout(timeout *time.Duration) *UpdateSearchIndex usi.timeout = timeout return usi } + +// Authenticator sets the authenticator to use for this operation. +func (usi *UpdateSearchIndex) Authenticator(authenticator driver.Authenticator) *UpdateSearchIndex { + if usi == nil { + usi = new(UpdateSearchIndex) + } + + usi.authenticator = authenticator + return usi +} diff --git a/x/mongo/driver/operation_test.go b/x/mongo/driver/operation_test.go index 6445c9d0f6..27ef3a090d 100644 --- a/x/mongo/driver/operation_test.go +++ b/x/mongo/driver/operation_test.go @@ -789,6 +789,8 @@ func (m *mockConnection) SupportsStreaming() bool { return m.rCanStream func (m *mockConnection) CurrentlyStreaming() bool { return m.rStreaming } func (m *mockConnection) SetStreaming(streaming bool) { m.rStreaming = streaming } func (m *mockConnection) Stale() bool { return false } +func (m *mockConnection) OIDCTokenGenID() uint64 { return 0 } +func (m *mockConnection) SetOIDCTokenGenID(uint64) {} // TODO:(GODRIVER-2824) replace return type with int64. 
func (m *mockConnection) DriverConnectionID() uint64 { return 0 } diff --git a/x/mongo/driver/session/client_session.go b/x/mongo/driver/session/client_session.go index 8dac0932de..4a6be9c5e4 100644 --- a/x/mongo/driver/session/client_session.go +++ b/x/mongo/driver/session/client_session.go @@ -90,6 +90,8 @@ type LoadBalancedTransactionConnection interface { DriverConnectionID() uint64 // TODO(GODRIVER-2824): change type to int64. Address() address.Address Stale() bool + OIDCTokenGenID() uint64 + SetOIDCTokenGenID(uint64) // Functions copied over from driver.PinnedConnection that are not part of Connection or Expirable. PinToCursor() error diff --git a/x/mongo/driver/topology/connection.go b/x/mongo/driver/topology/connection.go index 649e87b3d1..49a613aef8 100644 --- a/x/mongo/driver/topology/connection.go +++ b/x/mongo/driver/topology/connection.go @@ -82,6 +82,10 @@ type connection struct { // awaitingResponse indicates that the server response was not completely // read before returning the connection to the pool. awaitingResponse bool + + // oidcTokenGenID is the monotonic generation ID for OIDC tokens, used to invalidate + // accessTokens in the OIDC authenticator cache. + oidcTokenGenID uint64 } // newConnection handles the creation of a connection. It does not connect the connection. @@ -606,6 +610,8 @@ type Connection struct { refCount int cleanupPoolFn func() + oidcTokenGenID uint64 + // cleanupServerFn resets the server state when a connection is returned to the connection pool // via Close() or expired via Expire(). cleanupServerFn func() @@ -860,6 +866,16 @@ func configureTLS(ctx context.Context, return client, nil } +// OIDCTokenGenID returns the OIDC token generation ID. +func (c *Connection) OIDCTokenGenID() uint64 { + return c.oidcTokenGenID +} + +// SetOIDCTokenGenID sets the OIDC token generation ID. +func (c *Connection) SetOIDCTokenGenID(genID uint64) { + c.oidcTokenGenID = genID +} + // TODO: Naming? // cancellListener listens for context cancellation and notifies listeners via a @@ -903,3 +919,11 @@ func (c *cancellListener) StopListening() bool { c.done <- struct{}{} return c.aborted } + +func (c *connection) OIDCTokenGenID() uint64 { + return c.oidcTokenGenID +} + +func (c *connection) SetOIDCTokenGenID(genID uint64) { + c.oidcTokenGenID = genID +} diff --git a/x/mongo/driver/topology/connection_test.go b/x/mongo/driver/topology/connection_test.go index dc774b469b..946f74d8f2 100644 --- a/x/mongo/driver/topology/connection_test.go +++ b/x/mongo/driver/topology/connection_test.go @@ -236,13 +236,18 @@ func TestConnection(t *testing.T) { conn := newConnection("", connOpts...) var connectErr error - callback := func(ctx context.Context) { - connectCtx, cancel := context.WithTimeout(ctx, tc.contextTimeout) + callback := func() bool { + connectCtx, cancel := context.WithTimeout(context.Background(), tc.contextTimeout) defer cancel() connectErr = conn.connect(connectCtx) + return true } - assert.Soon(t, callback, tc.maxConnectTime) + assert.Eventually(t, + callback, + tc.maxConnectTime, + time.Millisecond, + "expected timeout to apply to socket establishment after maximum connect time") ce, ok := connectErr.(ConnectionError) assert.True(t, ok, "expected error %v to be of type %T", connectErr, ConnectionError{}) @@ -271,13 +276,18 @@ func TestConnection(t *testing.T) { conn := newConnection(address.Address(l.Addr().String()), connOpts...) 
var connectErr error - callback := func(ctx context.Context) { - connectCtx, cancel := context.WithTimeout(ctx, tc.contextTimeout) + callback := func() bool { + connectCtx, cancel := context.WithTimeout(context.Background(), tc.contextTimeout) defer cancel() connectErr = conn.connect(connectCtx) + return true } - assert.Soon(t, callback, tc.maxConnectTime) + assert.Eventually(t, + callback, + tc.maxConnectTime, + time.Millisecond, + "expected timeout to apply to TLS handshake after maximum connect time") ce, ok := connectErr.(ConnectionError) assert.True(t, ok, "expected error %v to be of type %T", connectErr, ConnectionError{}) diff --git a/x/mongo/driver/topology/fsm.go b/x/mongo/driver/topology/fsm.go index 2acf527b9d..1d097b65c7 100644 --- a/x/mongo/driver/topology/fsm.go +++ b/x/mongo/driver/topology/fsm.go @@ -22,7 +22,7 @@ var ( MinSupportedMongoDBVersion = "3.6" // SupportedWireVersions is the range of wire versions supported by the driver. - SupportedWireVersions = description.NewVersionRange(6, 21) + SupportedWireVersions = description.NewVersionRange(6, 25) ) type fsm struct { diff --git a/x/mongo/driver/topology/server.go b/x/mongo/driver/topology/server.go index 99f8dd618b..a29eea4a6d 100644 --- a/x/mongo/driver/topology/server.go +++ b/x/mongo/driver/topology/server.go @@ -666,7 +666,7 @@ func (s *Server) update() { s.monitorOnce.Do(s.rttMonitor.connect) } - if isStreamable(s) || connectionIsStreaming || transitionedFromNetworkError { + if isStreamingEnabled(s) && (isStreamable(s) || connectionIsStreaming) || transitionedFromNetworkError { continue } diff --git a/x/mongo/driver/topology/topology_errors_test.go b/x/mongo/driver/topology/topology_errors_test.go index c09ef9731c..c7dc7336e9 100644 --- a/x/mongo/driver/topology/topology_errors_test.go +++ b/x/mongo/driver/topology/topology_errors_test.go @@ -46,15 +46,21 @@ func TestTopologyErrors(t *testing.T) { assert.Nil(t, err, "error creating topology: %v", err) var serverSelectionErr error - callback := func(ctx context.Context) { - selectServerCtx, cancel := context.WithTimeout(ctx, 10*time.Millisecond) + callback := func() bool { + selectServerCtx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond) defer cancel() state := newServerSelectionState(selectNone, make(<-chan time.Time)) subCh := make(<-chan description.Topology) _, serverSelectionErr = topo.selectServerFromSubscription(selectServerCtx, subCh, state) + return true } - assert.Soon(t, callback, 150*time.Millisecond) + assert.Eventually(t, + callback, + 150*time.Millisecond, + time.Millisecond, + "expected context deadline to fail within 150ms") + assert.True(t, errors.Is(serverSelectionErr, context.DeadlineExceeded), "expected %v, received %v", context.DeadlineExceeded, serverSelectionErr) }) diff --git a/x/mongo/driver/topology/topology_options.go b/x/mongo/driver/topology/topology_options.go index b5eb4a9729..0563e5524e 100644 --- a/x/mongo/driver/topology/topology_options.go +++ b/x/mongo/driver/topology/topology_options.go @@ -72,8 +72,30 @@ func newLogger(opts *options.LoggerOptions) (*logger.Logger, error) { } // NewConfig will translate data from client options into a topology config for building non-default deployments. -// Server and topology options are not honored if a custom deployment is used. 
func NewConfig(co *options.ClientOptions, clock *session.ClusterClock) (*Config, error) { + // Auth & Database & Password & Username + if co.Auth != nil { + cred := &auth.Cred{ + Username: co.Auth.Username, + Password: co.Auth.Password, + PasswordSet: co.Auth.PasswordSet, + Props: co.Auth.AuthMechanismProperties, + Source: co.Auth.AuthSource, + } + mechanism := co.Auth.AuthMechanism + authenticator, err := auth.CreateAuthenticator(mechanism, cred, co.HTTPClient) + if err != nil { + return nil, err + } + return NewConfigWithAuthenticator(co, clock, authenticator) + } + return NewConfigWithAuthenticator(co, clock, nil) +} + +// NewConfigWithAuthenticator will translate data from client options into a topology config for building non-default deployments. +// Server and topology options are not honored if a custom deployment is used. It uses a passed in +// authenticator to authenticate the connection. +func NewConfigWithAuthenticator(co *options.ClientOptions, clock *session.ClusterClock, authenticator driver.Authenticator) (*Config, error) { var serverAPI *driver.ServerAPIOptions if err := co.Validate(); err != nil { @@ -180,11 +202,6 @@ func NewConfig(co *options.ClientOptions, clock *session.ClusterClock) (*Config, } } - authenticator, err := auth.CreateAuthenticator(mechanism, cred) - if err != nil { - return nil, err - } - handshakeOpts := &auth.HandshakeOptions{ AppName: appName, Authenticator: authenticator, @@ -192,7 +209,6 @@ func NewConfig(co *options.ClientOptions, clock *session.ClusterClock) (*Config, ServerAPI: serverAPI, LoadBalanced: loadBalanced, ClusterClock: clock, - HTTPClient: co.HTTPClient, } if mechanism == "" { diff --git a/x/mongo/driver/wiremessage/wiremessage.go b/x/mongo/driver/wiremessage/wiremessage.go index fbdd21753f..2199f855ba 100644 --- a/x/mongo/driver/wiremessage/wiremessage.go +++ b/x/mongo/driver/wiremessage/wiremessage.go @@ -15,6 +15,7 @@ package wiremessage import ( "bytes" + "encoding/binary" "strings" "sync/atomic" @@ -238,10 +239,11 @@ func ReadHeader(src []byte) (length, requestID, responseTo int32, opcode OpCode, if len(src) < 16 { return 0, 0, 0, 0, src, false } - length = (int32(src[0]) | int32(src[1])<<8 | int32(src[2])<<16 | int32(src[3])<<24) - requestID = (int32(src[4]) | int32(src[5])<<8 | int32(src[6])<<16 | int32(src[7])<<24) - responseTo = (int32(src[8]) | int32(src[9])<<8 | int32(src[10])<<16 | int32(src[11])<<24) - opcode = OpCode(int32(src[12]) | int32(src[13])<<8 | int32(src[14])<<16 | int32(src[15])<<24) + + length = readi32unsafe(src) + requestID = readi32unsafe(src[4:]) + responseTo = readi32unsafe(src[8:]) + opcode = OpCode(readi32unsafe(src[12:])) return length, requestID, responseTo, opcode, src[16:], true } @@ -577,12 +579,16 @@ func ReadKillCursorsCursorIDs(src []byte, numIDs int32) (cursorIDs []int64, rem return cursorIDs, src, true } -func appendi32(dst []byte, i32 int32) []byte { - return append(dst, byte(i32), byte(i32>>8), byte(i32>>16), byte(i32>>24)) +func appendi32(dst []byte, x int32) []byte { + b := []byte{0, 0, 0, 0} + binary.LittleEndian.PutUint32(b, uint32(x)) + return append(dst, b...) } -func appendi64(b []byte, i int64) []byte { - return append(b, byte(i), byte(i>>8), byte(i>>16), byte(i>>24), byte(i>>32), byte(i>>40), byte(i>>48), byte(i>>56)) +func appendi64(dst []byte, x int64) []byte { + b := []byte{0, 0, 0, 0, 0, 0, 0, 0} + binary.LittleEndian.PutUint64(b, uint64(x)) + return append(dst, b...) 
} func appendCString(b []byte, str string) []byte { @@ -594,21 +600,18 @@ func readi32(src []byte) (int32, []byte, bool) { if len(src) < 4 { return 0, src, false } - - return (int32(src[0]) | int32(src[1])<<8 | int32(src[2])<<16 | int32(src[3])<<24), src[4:], true + return readi32unsafe(src), src[4:], true } func readi32unsafe(src []byte) int32 { - return (int32(src[0]) | int32(src[1])<<8 | int32(src[2])<<16 | int32(src[3])<<24) + return int32(binary.LittleEndian.Uint32(src)) } func readi64(src []byte) (int64, []byte, bool) { if len(src) < 8 { return 0, src, false } - i64 := (int64(src[0]) | int64(src[1])<<8 | int64(src[2])<<16 | int64(src[3])<<24 | - int64(src[4])<<32 | int64(src[5])<<40 | int64(src[6])<<48 | int64(src[7])<<56) - return i64, src[8:], true + return int64(binary.LittleEndian.Uint64(src)), src[8:], true } func readcstring(src []byte) (string, []byte, bool) { diff --git a/x/mongo/driver/wiremessage/wiremessage_test.go b/x/mongo/driver/wiremessage/wiremessage_test.go new file mode 100644 index 0000000000..26cb2637a6 --- /dev/null +++ b/x/mongo/driver/wiremessage/wiremessage_test.go @@ -0,0 +1,472 @@ +// Copyright (C) MongoDB, Inc. 2024-present. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + +package wiremessage + +import ( + "math" + "testing" + + "go.mongodb.org/mongo-driver/internal/assert" + "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" +) + +func TestAppendHeaderStart(t *testing.T) { + testCases := []struct { + desc string + dst []byte + reqid int32 + respto int32 + opcode OpCode + wantIdx int32 + wantBytes []byte + }{ + { + desc: "OP_MSG", + reqid: 2, + respto: 1, + opcode: OpMsg, + wantIdx: 0, + wantBytes: []byte{0, 0, 0, 0, 2, 0, 0, 0, 1, 0, 0, 0, 221, 7, 0, 0}, + }, + { + desc: "OP_QUERY", + reqid: 2, + respto: 1, + opcode: OpQuery, + wantIdx: 0, + wantBytes: []byte{0, 0, 0, 0, 2, 0, 0, 0, 1, 0, 0, 0, 212, 7, 0, 0}, + }, + { + desc: "non-empty buffer", + dst: []byte{0, 99}, + reqid: 2, + respto: 1, + opcode: OpMsg, + wantIdx: 2, + wantBytes: []byte{0, 99, 0, 0, 0, 0, 2, 0, 0, 0, 1, 0, 0, 0, 221, 7, 0, 0}, + }, + } + + for _, tc := range testCases { + tc := tc // Capture range variable. 
+ + t.Run(tc.desc, func(t *testing.T) { + t.Parallel() + + idx, b := AppendHeaderStart(tc.dst, tc.reqid, tc.respto, tc.opcode) + assert.Equal(t, tc.wantIdx, idx, "appended slice index does not match") + assert.Equal(t, tc.wantBytes, b, "appended bytes do not match") + }) + } +} + +func TestReadHeader(t *testing.T) { + testCases := []struct { + desc string + src []byte + wantLength int32 + wantRequestID int32 + wantResponseTo int32 + wantOpcode OpCode + wantRem []byte + wantOK bool + }{ + { + desc: "OP_MSG", + src: []byte{0, 0, 0, 0, 2, 0, 0, 0, 1, 0, 0, 0, 221, 7, 0, 0}, + wantLength: 0, + wantRequestID: 2, + wantResponseTo: 1, + wantOpcode: OpMsg, + wantRem: []byte{}, + wantOK: true, + }, + { + desc: "OP_QUERY", + src: []byte{0, 0, 0, 0, 2, 0, 0, 0, 1, 0, 0, 0, 212, 7, 0, 0}, + wantLength: 0, + wantRequestID: 2, + wantResponseTo: 1, + wantOpcode: OpQuery, + wantRem: []byte{}, + wantOK: true, + }, + { + desc: "not enough bytes", + src: []byte{0, 99}, + wantLength: 0, + wantRequestID: 0, + wantResponseTo: 0, + wantOpcode: 0, + wantRem: []byte{0, 99}, + wantOK: false, + }, + { + desc: "nil", + src: nil, + wantLength: 0, + wantRequestID: 0, + wantResponseTo: 0, + wantOpcode: 0, + wantRem: nil, + wantOK: false, + }, + } + + for _, tc := range testCases { + tc := tc // Capture range variable. + + t.Run(tc.desc, func(t *testing.T) { + t.Parallel() + + length, requestID, responseTo, opcode, rem, ok := ReadHeader(tc.src) + assert.Equal(t, tc.wantLength, length, "length does not match") + assert.Equal(t, tc.wantRequestID, requestID, "requestID does not match") + assert.Equal(t, tc.wantResponseTo, responseTo, "responseTo does not match") + assert.Equal(t, tc.wantOpcode, opcode, "OpCode does not match") + assert.Equal(t, tc.wantRem, rem, "remaining bytes do not match") + assert.Equal(t, tc.wantOK, ok, "OK does not match") + }) + } +} + +func TestReadMsgSectionDocumentSequence(t *testing.T) { + testCases := []struct { + desc string + src []byte + wantIdentifier string + wantDocs []bsoncore.Document + wantRem []byte + wantOK bool + }{ + { + desc: "valid document sequence", + // Data: | len=17 | "id" | empty doc | empty doc | + src: []byte{17, 0, 0, 0, 105, 100, 0, 5, 0, 0, 0, 0, 5, 0, 0, 0, 0}, + wantIdentifier: "id", + wantDocs: []bsoncore.Document{ + {0x5, 0x0, 0x0, 0x0, 0x0}, + {0x5, 0x0, 0x0, 0x0, 0x0}, + }, + wantRem: []byte{}, + wantOK: true, + }, + { + desc: "valid document sequence with remaining bytes", + // Data: | len=17 | "id" | empty doc | empty doc | rem | + src: []byte{17, 0, 0, 0, 105, 100, 0, 5, 0, 0, 0, 0, 5, 0, 0, 0, 0, 99, 99}, + wantIdentifier: "id", + wantDocs: []bsoncore.Document{ + {0x5, 0x0, 0x0, 0x0, 0x0}, + {0x5, 0x0, 0x0, 0x0, 0x0}, + }, + wantRem: []byte{99, 99}, + wantOK: true, + }, + { + desc: "not enough bytes", + src: []byte{0, 1}, + wantIdentifier: "", + wantDocs: nil, + wantRem: []byte{0, 1}, + wantOK: false, + }, + { + desc: "nil", + src: nil, + wantIdentifier: "", + wantDocs: nil, + wantRem: nil, + wantOK: false, + }, + } + + for _, tc := range testCases { + tc := tc // Capture range variable. 
+ + t.Run(tc.desc, func(t *testing.T) { + t.Parallel() + + identifier, docs, rem, ok := ReadMsgSectionDocumentSequence(tc.src) + assert.Equal(t, tc.wantIdentifier, identifier, "identifier does not match") + assert.Equal(t, tc.wantDocs, docs, "docs do not match") + assert.Equal(t, tc.wantRem, rem, "responseTo does not match") + assert.Equal(t, tc.wantOK, ok, "OK does not match") + }) + } +} + +func TestAppendi32(t *testing.T) { + testCases := []struct { + desc string + dst []byte + x int32 + want []byte + }{ + { + desc: "0", + x: 0, + want: []byte{0, 0, 0, 0}, + }, + { + desc: "1", + x: 1, + want: []byte{1, 0, 0, 0}, + }, + { + desc: "-1", + x: -1, + want: []byte{255, 255, 255, 255}, + }, + { + desc: "max", + x: math.MaxInt32, + want: []byte{255, 255, 255, 127}, + }, + { + desc: "min", + x: math.MinInt32, + want: []byte{0, 0, 0, 128}, + }, + { + desc: "non-empty dst", + dst: []byte{0, 1, 2, 3}, + x: 1, + want: []byte{0, 1, 2, 3, 1, 0, 0, 0}, + }, + } + + for _, tc := range testCases { + tc := tc // Capture range variable. + + t.Run(tc.desc, func(t *testing.T) { + t.Parallel() + + b := appendi32(tc.dst, tc.x) + assert.Equal(t, tc.want, b, "bytes do not match") + }) + } +} + +func TestAppendi64(t *testing.T) { + testCases := []struct { + desc string + dst []byte + x int64 + want []byte + }{ + { + desc: "0", + x: 0, + want: []byte{0, 0, 0, 0, 0, 0, 0, 0}, + }, + { + desc: "1", + x: 1, + want: []byte{1, 0, 0, 0, 0, 0, 0, 0}, + }, + { + desc: "-1", + x: -1, + want: []byte{255, 255, 255, 255, 255, 255, 255, 255}, + }, + { + desc: "max", + x: math.MaxInt64, + want: []byte{255, 255, 255, 255, 255, 255, 255, 127}, + }, + { + desc: "min", + x: math.MinInt64, + want: []byte{0, 0, 0, 0, 0, 0, 0, 128}, + }, + { + desc: "non-empty dst", + dst: []byte{0, 1, 2, 3}, + x: 1, + want: []byte{0, 1, 2, 3, 1, 0, 0, 0, 0, 0, 0, 0}, + }, + } + + for _, tc := range testCases { + tc := tc // Capture range variable. + + t.Run(tc.desc, func(t *testing.T) { + t.Parallel() + + b := appendi64(tc.dst, tc.x) + assert.Equal(t, tc.want, b, "bytes do not match") + }) + } +} + +func TestReadi32(t *testing.T) { + testCases := []struct { + desc string + src []byte + want int32 + wantRem []byte + wantOK bool + }{ + { + desc: "0", + src: []byte{0, 0, 0, 0}, + want: 0, + wantRem: []byte{}, + wantOK: true, + }, + { + desc: "1", + src: []byte{1, 0, 0, 0}, + want: 1, + wantRem: []byte{}, + wantOK: true, + }, + { + desc: "-1", + src: []byte{255, 255, 255, 255}, + want: -1, + wantRem: []byte{}, + wantOK: true, + }, + { + desc: "max", + src: []byte{255, 255, 255, 127}, + want: math.MaxInt32, + wantRem: []byte{}, + wantOK: true, + }, + { + desc: "min", + src: []byte{0, 0, 0, 128}, + want: math.MinInt32, + wantRem: []byte{}, + wantOK: true, + }, + { + desc: "non-empty remaining", + src: []byte{1, 0, 0, 0, 0, 1, 2, 3}, + want: 1, + wantRem: []byte{0, 1, 2, 3}, + wantOK: true, + }, + { + desc: "not enough bytes", + src: []byte{0, 1, 2}, + want: 0, + wantRem: []byte{0, 1, 2}, + wantOK: false, + }, + { + desc: "nil", + src: nil, + want: 0, + wantRem: nil, + wantOK: false, + }, + } + + for _, tc := range testCases { + tc := tc // Capture range variable. 
+ + t.Run(tc.desc, func(t *testing.T) { + t.Parallel() + + x, rem, ok := readi32(tc.src) + assert.Equal(t, tc.want, x, "int32 result does not match") + assert.Equal(t, tc.wantRem, rem, "remaining bytes do not match") + assert.Equal(t, tc.wantOK, ok, "OK does not match") + }) + } +} + +func TestReadi64(t *testing.T) { + testCases := []struct { + desc string + src []byte + want int64 + wantRem []byte + wantOK bool + }{ + { + desc: "0", + src: []byte{0, 0, 0, 0, 0, 0, 0, 0}, + want: 0, + wantRem: []byte{}, + wantOK: true, + }, + { + desc: "1", + src: []byte{1, 0, 0, 0, 0, 0, 0, 0}, + want: 1, + wantRem: []byte{}, + wantOK: true, + }, + { + desc: "-1", + src: []byte{255, 255, 255, 255, 255, 255, 255, 255}, + want: -1, + wantRem: []byte{}, + wantOK: true, + }, + { + desc: "max", + src: []byte{255, 255, 255, 255, 255, 255, 255, 127}, + want: math.MaxInt64, + wantRem: []byte{}, + wantOK: true, + }, + { + desc: "min", + src: []byte{0, 0, 0, 0, 0, 0, 0, 128}, + want: math.MinInt64, + wantRem: []byte{}, + wantOK: true, + }, + { + desc: "non-empty remaining", + src: []byte{1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 3}, + want: 1, + wantRem: []byte{0, 1, 2, 3}, + wantOK: true, + }, + { + desc: "not enough bytes", + src: []byte{0, 1, 2, 3, 4, 5, 6}, + want: 0, + wantRem: []byte{0, 1, 2, 3, 4, 5, 6}, + wantOK: false, + }, + { + desc: "not enough bytes", + src: []byte{0, 1, 2, 3, 4, 5, 6}, + want: 0, + wantRem: []byte{0, 1, 2, 3, 4, 5, 6}, + wantOK: false, + }, + { + desc: "nil", + src: nil, + want: 0, + wantRem: nil, + wantOK: false, + }, + } + + for _, tc := range testCases { + tc := tc // Capture range variable. + + t.Run(tc.desc, func(t *testing.T) { + t.Parallel() + + x, rem, ok := readi64(tc.src) + assert.Equal(t, tc.want, x, "int64 result does not match") + assert.Equal(t, tc.wantRem, rem, "remaining bytes do not match") + assert.Equal(t, tc.wantOK, ok, "OK does not match") + }) + } +}
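For reference, a minimal standalone sketch (not part of the patch above) of the exported wiremessage helpers that these new tests exercise: AppendHeaderStart writes a 16-byte header with a zero length placeholder, and ReadHeader decodes it back using the same little-endian layout that appendi32/readi32 and appendi64/readi64 now delegate to encoding/binary. The program only assumes the public package path go.mongodb.org/mongo-driver/x/mongo/driver/wiremessage; the literal request/response IDs are arbitrary example values.

package main

import (
	"fmt"

	"go.mongodb.org/mongo-driver/x/mongo/driver/wiremessage"
)

func main() {
	// Append an OP_MSG header (requestID=2, responseTo=1) to an empty buffer.
	// idx is the offset of the 4-byte length placeholder, which callers
	// normally back-fill once the rest of the message has been appended.
	idx, buf := wiremessage.AppendHeaderStart(nil, 2, 1, wiremessage.OpMsg)

	// Decode the header again; length is still the zero placeholder here.
	length, requestID, responseTo, opcode, rem, ok := wiremessage.ReadHeader(buf)

	fmt.Println(idx, length, requestID, responseTo, opcode == wiremessage.OpMsg, len(rem), ok)
	// Expected output: 0 0 2 1 true 0 true
}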