diff --git a/go.mod b/go.mod index 045a5b2476..803d986a22 100644 --- a/go.mod +++ b/go.mod @@ -5,11 +5,13 @@ go 1.21 require ( github.com/cznic/mathutil v0.0.0-20181122101859-297441e03548 github.com/dgryski/go-farm v0.0.0-20190423205320-6a90982ecee2 + github.com/docker/go-units v0.5.0 github.com/gogo/protobuf v1.3.2 github.com/golang/protobuf v1.5.3 github.com/google/btree v1.1.2 github.com/google/uuid v1.3.1 github.com/grpc-ecosystem/go-grpc-middleware v1.1.0 + github.com/ninedraft/israce v0.0.3 github.com/opentracing/opentracing-go v1.2.0 github.com/pingcap/errors v0.11.5-0.20211224045212-9687c2b0f87c github.com/pingcap/failpoint v0.0.0-20220801062533-2eaa32854a6c diff --git a/go.sum b/go.sum index 54e9974c8e..1979f44a86 100644 --- a/go.sum +++ b/go.sum @@ -20,6 +20,8 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/dgryski/go-farm v0.0.0-20190423205320-6a90982ecee2 h1:tdlZCpZ/P9DhczCTSixgIKmwPv6+wP5DGjqLYw5SUiA= github.com/dgryski/go-farm v0.0.0-20190423205320-6a90982ecee2/go.mod h1:SqUrOPUnsFjfmXRMNPybcSiG0BgUW2AuFH8PAnS2iTw= +github.com/docker/go-units v0.5.0 h1:69rxXcBk27SvSaaxTtLh/8llcHD8vYHT7WSdRZ/jvr4= +github.com/docker/go-units v0.5.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk= github.com/fsnotify/fsnotify v1.6.0 h1:n+5WquG0fcWoWp6xPWfHdbskMCQaFnG6PfBrh1Ky4HY= github.com/fsnotify/fsnotify v1.6.0/go.mod h1:sl3t1tCWJFWoRz9R8WJCbQihKKwmorjAbSClcnxKAGw= github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= @@ -58,6 +60,8 @@ github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/matttproud/golang_protobuf_extensions v1.0.4 h1:mmDVorXM7PCGKw94cs5zkfA9PSy5pEvNWRP0ET0TIVo= github.com/matttproud/golang_protobuf_extensions v1.0.4/go.mod h1:BSXmuO+STAnVfrANrmjBb36TMTDstsz7MSK+HVaYKv4= +github.com/ninedraft/israce v0.0.3 h1:F/Y1u6OlvgE75Syv1WbBatyg3CjGCdxLojLE7ydv2yE= +github.com/ninedraft/israce v0.0.3/go.mod h1:4L1ITFl340650ZmexVbUcBwG18ozlWiMe47pltZAmn4= github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE= github.com/nxadm/tail v1.4.8/go.mod h1:+ncqLTQzXmGhMZNUePPaPqPvBxHAIsmXswZKocGu+AU= github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE= @@ -112,8 +116,6 @@ github.com/stretchr/testify v1.8.2 h1:+h33VjcLVPDHtOdpUCuF+7gSuG3yGIftsP1YvFihtJ github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= github.com/tiancaiamao/gp v0.0.0-20221230034425-4025bc8a4d4a h1:J/YdBZ46WKpXsxsW93SG+q0F8KI+yFrcIDT4c/RNoc4= github.com/tiancaiamao/gp v0.0.0-20221230034425-4025bc8a4d4a/go.mod h1:h4xBhSNtOeEosLJ4P7JyKXX7Cabg7AVkWCK5gV2vOrM= -github.com/tikv/pd/client v0.0.0-20231116062916-ef6ba8551e52 h1:wucAo/ks8INgayRVfbrzZ+BSWEwRLETj0XfngDcrZ4k= -github.com/tikv/pd/client v0.0.0-20231116062916-ef6ba8551e52/go.mod h1:cd6zBqRM9aogxf26K8NnFRPVtq9BnRE59tKEpX8IaWQ= github.com/tikv/pd/client v0.0.0-20240202093025-6978558f4e97 h1:LlsMm3/IIRRXrip19M0yYb4bZ7BvdQgzF77csNnGK1Y= github.com/tikv/pd/client v0.0.0-20240202093025-6978558f4e97/go.mod h1:AwjTSpM7CgAynYwB6qTG5R5fVC9/eXlQXiTO6zDL1HI= github.com/twmb/murmur3 v1.1.3 h1:D83U0XYKcHRYwYIpBKf3Pks91Z0Byda/9SJ8B6EMRcA= diff --git a/integration_tests/go.mod b/integration_tests/go.mod index a73de3fd1a..3688f8baab 100644 --- a/integration_tests/go.mod +++ b/integration_tests/go.mod @@ -25,7 +25,7 @@ 
require ( github.com/Azure/go-ntlmssp v0.0.0-20221128193559-754e69321358 // indirect github.com/BurntSushi/toml v1.3.2 // indirect github.com/beorn7/perks v1.0.1 // indirect - github.com/cespare/xxhash/v2 v2.2.0 // indirect + github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/cloudfoundry/gosigar v1.3.6 // indirect github.com/cockroachdb/errors v1.8.1 // indirect github.com/cockroachdb/logtags v0.0.0-20190617123548-eb05cc24525f // indirect diff --git a/integration_tests/go.sum b/integration_tests/go.sum index 39e3a5b00b..703a243d34 100644 --- a/integration_tests/go.sum +++ b/integration_tests/go.sum @@ -65,8 +65,8 @@ github.com/cenkalti/backoff/v4 v4.1.1 h1:G2HAfAmvm/GcKan2oOQpBXOd2tT2G57ZnZGWa1P github.com/cenkalti/backoff/v4 v4.1.1/go.mod h1:scbssz8iZGpm3xbr14ovlUdkxfGXNInqkPWOWmG2CLw= github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= github.com/cespare/xxhash/v2 v2.1.2/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= -github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44= -github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= +github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/cheggaaa/pb/v3 v3.0.8 h1:bC8oemdChbke2FHIIGy9mn4DPJ2caZYQnfbRqwmdCoA= github.com/cheggaaa/pb/v3 v3.0.8/go.mod h1:UICbiLec/XO6Hw6k+BHEtHeQFzzBH4i2/qk/ow1EJTA= github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= @@ -118,8 +118,8 @@ github.com/dgryski/go-farm v0.0.0-20190104051053-3adb47b1fb0f/go.mod h1:SqUrOPUn github.com/dgryski/go-farm v0.0.0-20190423205320-6a90982ecee2/go.mod h1:SqUrOPUnsFjfmXRMNPybcSiG0BgUW2AuFH8PAnS2iTw= github.com/dgryski/go-farm v0.0.0-20200201041132-a6ae2369ad13 h1:fAjc9m62+UWV/WAFKLNi6ZS0675eEUC9y3AlwSbQu1Y= github.com/dgryski/go-farm v0.0.0-20200201041132-a6ae2369ad13/go.mod h1:SqUrOPUnsFjfmXRMNPybcSiG0BgUW2AuFH8PAnS2iTw= -github.com/docker/go-units v0.4.0 h1:3uh0PgVws3nIA0Q+MwDC8yjEPf9zjRfZZWXZYDct3Tw= -github.com/docker/go-units v0.4.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk= +github.com/docker/go-units v0.5.0 h1:69rxXcBk27SvSaaxTtLh/8llcHD8vYHT7WSdRZ/jvr4= +github.com/docker/go-units v0.5.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk= github.com/dolthub/maphash v0.1.0 h1:bsQ7JsF4FkkWyrP3oCnFJgrCUAFbFf3kOl4L/QxPDyQ= github.com/dolthub/maphash v0.1.0/go.mod h1:gkg4Ch4CdCDu5h6PMriVLawB7koZ+5ijb9puGMV50a4= github.com/dolthub/swiss v0.2.1 h1:gs2osYs5SJkAaH5/ggVJqXQxRXtWshF6uE0lgR/Y3Gw= diff --git a/internal/locate/region_cache_test.go b/internal/locate/region_cache_test.go index 15a1865fe5..e012dbf4e9 100644 --- a/internal/locate/region_cache_test.go +++ b/internal/locate/region_cache_test.go @@ -55,6 +55,7 @@ import ( "github.com/tikv/client-go/v2/internal/mockstore/mocktikv" "github.com/tikv/client-go/v2/internal/retry" "github.com/tikv/client-go/v2/kv" + "github.com/tikv/client-go/v2/oracle" pd "github.com/tikv/pd/client" ) @@ -1055,7 +1056,7 @@ func (s *testRegionCacheSuite) TestRegionEpochOnTiFlash() { s.Equal(ctxTiFlash.Peer.Id, s.peer1) ctxTiFlash.Peer.Role = metapb.PeerRole_Learner r := ctxTiFlash.Meta - reqSend := NewRegionRequestSender(s.cache, nil) + reqSend := NewRegionRequestSender(s.cache, nil, oracle.NoopReadTSValidator{}) regionErr := &errorpb.Error{EpochNotMatch: &errorpb.EpochNotMatch{CurrentRegions: []*metapb.Region{r}}} 
reqSend.onRegionError(s.bo, ctxTiFlash, nil, regionErr) @@ -1691,7 +1692,7 @@ func (s *testRegionCacheSuite) TestShouldNotRetryFlashback() { ctx, err := s.cache.GetTiKVRPCContext(retry.NewBackofferWithVars(context.Background(), 100, nil), loc.Region, kv.ReplicaReadLeader, 0) s.NotNil(ctx) s.NoError(err) - reqSend := NewRegionRequestSender(s.cache, nil) + reqSend := NewRegionRequestSender(s.cache, nil, oracle.NoopReadTSValidator{}) shouldRetry, err := reqSend.onRegionError(s.bo, ctx, nil, &errorpb.Error{FlashbackInProgress: &errorpb.FlashbackInProgress{}}) s.Error(err) s.False(shouldRetry) diff --git a/internal/locate/region_request.go b/internal/locate/region_request.go index da7b1ab239..d4f335cc09 100644 --- a/internal/locate/region_request.go +++ b/internal/locate/region_request.go @@ -46,6 +46,7 @@ import ( "sync/atomic" "time" + "github.com/tikv/client-go/v2/oracle" "go.uber.org/zap" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" @@ -105,6 +106,7 @@ type RegionRequestSender struct { regionCache *RegionCache apiVersion kvrpcpb.APIVersion client client.Client + readTSValidator oracle.ReadTSValidator storeAddr string rpcError error replicaSelector *replicaSelector @@ -372,11 +374,12 @@ func (s *ReplicaAccessStats) String() string { } // NewRegionRequestSender creates a new sender. -func NewRegionRequestSender(regionCache *RegionCache, client client.Client) *RegionRequestSender { return &RegionRequestSender{ - regionCache: regionCache, - apiVersion: regionCache.codec.GetAPIVersion(), - client: client, +func NewRegionRequestSender(regionCache *RegionCache, client client.Client, readTSValidator oracle.ReadTSValidator) *RegionRequestSender { return &RegionRequestSender{ + regionCache: regionCache, + apiVersion: regionCache.codec.GetAPIVersion(), + client: client, + readTSValidator: readTSValidator, } } @@ -1558,6 +1561,11 @@ func (s *RegionRequestSender) SendReqCtx( } } + if err = s.validateReadTS(bo.GetCtx(), req); err != nil { + logutil.Logger(bo.GetCtx()).Error("validate read ts failed for request", zap.Stringer("reqType", req.Type), zap.Stringer("req", req.Req.(fmt.Stringer)), zap.Stringer("context", &req.Context), zap.Stack("stack"), zap.Error(err)) + return nil, nil, 0, err + } + // If the MaxExecutionDurationMs is not set yet, we set it to be the RPC timeout duration // so TiKV can give up the requests whose response TiDB cannot receive due to timeout. if req.Context.MaxExecutionDurationMs == 0 { @@ -2535,6 +2543,44 @@ func (s *RegionRequestSender) onRegionError( return false, nil } +func (s *RegionRequestSender) validateReadTS(ctx context.Context, req *tikvrpc.Request) error { + if req.StoreTp == tikvrpc.TiDB { + // Skip the check if the store type is TiDB. + return nil + } + + var readTS uint64 + switch req.Type { + case tikvrpc.CmdGet, tikvrpc.CmdScan, tikvrpc.CmdBatchGet, tikvrpc.CmdCop, tikvrpc.CmdCopStream, tikvrpc.CmdBatchCop, tikvrpc.CmdScanLock: + readTS = req.GetStartTS() + + // TODO: Check transactional write requests that have implicit reads. + // case tikvrpc.CmdPessimisticLock: + // readTS = req.PessimisticLock().GetForUpdateTs() + // case tikvrpc.CmdPrewrite: + // inner := req.Prewrite() + // readTS = inner.GetForUpdateTs() + // if readTS == 0 { + // readTS = inner.GetStartVersion() + // } + // case tikvrpc.CmdCheckTxnStatus: + // inner := req.CheckTxnStatus() + // // TiKV uses the greater one of these three fields to update the max_ts. 
+ // readTS = inner.GetLockTs() + // if inner.GetCurrentTs() != math.MaxUint64 && inner.GetCurrentTs() > readTS { + // readTS = inner.GetCurrentTs() + // } + // if inner.GetCallerStartTs() != math.MaxUint64 && inner.GetCallerStartTs() > readTS { + // readTS = inner.GetCallerStartTs() + // } + // case tikvrpc.CmdCheckSecondaryLocks, tikvrpc.CmdCleanup, tikvrpc.CmdBatchRollback: + // readTS = req.GetStartTS() + default: + return nil + } + return s.readTSValidator.ValidateReadTS(ctx, readTS, req.StaleRead, &oracle.Option{TxnScope: req.TxnScope}) +} + type staleReadMetricsCollector struct { } diff --git a/internal/locate/region_request3_test.go b/internal/locate/region_request3_test.go index 8d3fa4bbf8..13780afa4b 100644 --- a/internal/locate/region_request3_test.go +++ b/internal/locate/region_request3_test.go @@ -85,7 +85,7 @@ func (s *testRegionRequestToThreeStoresSuite) SetupTest() { s.cache = NewRegionCache(pdCli) s.bo = retry.NewNoopBackoff(context.Background()) client := mocktikv.NewRPCClient(s.cluster, s.mvccStore, nil) - s.regionRequestSender = NewRegionRequestSender(s.cache, client) + s.regionRequestSender = NewRegionRequestSender(s.cache, client, oracle.NoopReadTSValidator{}) } func (s *testRegionRequestToThreeStoresSuite) TearDownTest() { @@ -150,7 +150,8 @@ func (s *testRegionRequestToThreeStoresSuite) loadAndGetLeaderStore() (*Store, s } func (s *testRegionRequestToThreeStoresSuite) TestForwarding() { - s.regionRequestSender.regionCache.enableForwarding = true + sender := NewRegionRequestSender(s.cache, s.regionRequestSender.client, oracle.NoopReadTSValidator{}) + sender.regionCache.enableForwarding = true // First get the leader's addr from region cache leaderStore, leaderAddr := s.loadAndGetLeaderStore() @@ -1240,7 +1241,7 @@ func (s *testRegionRequestToThreeStoresSuite) TestSendReqFirstTimeout() { } resetStats := func() { reqTargetAddrs = make(map[string]struct{}) - s.regionRequestSender = NewRegionRequestSender(s.cache, mockRPCClient) + s.regionRequestSender = NewRegionRequestSender(s.cache, mockRPCClient, oracle.NoopReadTSValidator{}) s.regionRequestSender.Stats = NewRegionRequestRuntimeStats() } @@ -1564,7 +1565,7 @@ func (s *testRegionRequestToThreeStoresSuite) TestStaleReadTryFollowerAfterTimeo } return &tikvrpc.Response{Resp: &kvrpcpb.GetResponse{Value: []byte("value")}}, nil }} - s.regionRequestSender = NewRegionRequestSender(s.cache, mockRPCClient) + s.regionRequestSender = NewRegionRequestSender(s.cache, mockRPCClient, oracle.NoopReadTSValidator{}) s.regionRequestSender.Stats = NewRegionRequestRuntimeStats() getLocFn := func() *KeyLocation { loc, err := s.regionRequestSender.regionCache.LocateKey(bo, []byte("a")) diff --git a/internal/locate/region_request_state_test.go b/internal/locate/region_request_state_test.go index 826bd5a43e..0d742985cf 100644 --- a/internal/locate/region_request_state_test.go +++ b/internal/locate/region_request_state_test.go @@ -34,6 +34,7 @@ import ( "github.com/tikv/client-go/v2/internal/retry" "github.com/tikv/client-go/v2/kv" "github.com/tikv/client-go/v2/metrics" + "github.com/tikv/client-go/v2/oracle" "github.com/tikv/client-go/v2/tikvrpc" ) @@ -76,7 +77,7 @@ func (s *testRegionCacheStaleReadSuite) SetupTest() { s.cache = NewRegionCache(pdCli) s.bo = retry.NewNoopBackoff(context.Background()) client := mocktikv.NewRPCClient(s.cluster, s.mvccStore, nil) - s.regionRequestSender = NewRegionRequestSender(s.cache, client) + s.regionRequestSender = NewRegionRequestSender(s.cache, client, oracle.NoopReadTSValidator{}) s.setClient() 
s.injection = testRegionCacheFSMSuiteInjection{ unavailableStoreIDs: make(map[uint64]struct{}), diff --git a/internal/locate/region_request_test.go b/internal/locate/region_request_test.go index 40133913f2..580b6a0108 100644 --- a/internal/locate/region_request_test.go +++ b/internal/locate/region_request_test.go @@ -62,7 +62,9 @@ import ( "github.com/tikv/client-go/v2/internal/client/mock_server" "github.com/tikv/client-go/v2/internal/mockstore/mocktikv" "github.com/tikv/client-go/v2/internal/retry" + "github.com/tikv/client-go/v2/oracle" "github.com/tikv/client-go/v2/tikvrpc" + pd "github.com/tikv/pd/client" pderr "github.com/tikv/pd/client/errs" "google.golang.org/grpc" ) @@ -77,6 +79,7 @@ type testRegionRequestToSingleStoreSuite struct { store uint64 peer uint64 region uint64 + pdCli pd.Client cache *RegionCache bo *retry.Backoffer regionRequestSender *RegionRequestSender @@ -87,11 +90,11 @@ func (s *testRegionRequestToSingleStoreSuite) SetupTest() { s.mvccStore = mocktikv.MustNewMVCCStore() s.cluster = mocktikv.NewCluster(s.mvccStore) s.store, s.peer, s.region = mocktikv.BootstrapWithSingleStore(s.cluster) - pdCli := &CodecPDClient{mocktikv.NewPDClient(s.cluster), apicodec.NewCodecV1(apicodec.ModeTxn)} - s.cache = NewRegionCache(pdCli) + s.pdCli = &CodecPDClient{mocktikv.NewPDClient(s.cluster), apicodec.NewCodecV1(apicodec.ModeTxn)} + s.cache = NewRegionCache(s.pdCli) s.bo = retry.NewNoopBackoff(context.Background()) client := mocktikv.NewRPCClient(s.cluster, s.mvccStore, nil) - s.regionRequestSender = NewRegionRequestSender(s.cache, client) + s.regionRequestSender = NewRegionRequestSender(s.cache, client, oracle.NoopReadTSValidator{}) } func (s *testRegionRequestToSingleStoreSuite) TearDownTest() { @@ -573,7 +576,7 @@ func (s *testRegionRequestToSingleStoreSuite) TestNoReloadRegionForGrpcWhenCtxCa }() cli := client.NewRPCClient() - sender := NewRegionRequestSender(s.cache, cli) + sender := NewRegionRequestSender(s.cache, cli, oracle.NoopReadTSValidator{}) req := tikvrpc.NewRequest(tikvrpc.CmdRawPut, &kvrpcpb.RawPutRequest{ Key: []byte("key"), Value: []byte("value"), @@ -592,7 +595,7 @@ func (s *testRegionRequestToSingleStoreSuite) TestNoReloadRegionForGrpcWhenCtxCa Client: client.NewRPCClient(), redirectAddr: addr, } - sender = NewRegionRequestSender(s.cache, client1) + sender = NewRegionRequestSender(s.cache, client1, oracle.NoopReadTSValidator{}) sender.SendReq(s.bo, req, region.Region, 3*time.Second) // cleanup @@ -812,7 +815,7 @@ func (s *testRegionRequestToSingleStoreSuite) TestBatchClientSendLoopPanic() { cancel() }() req := tikvrpc.NewRequest(tikvrpc.CmdCop, &coprocessor.Request{Data: []byte("a"), StartTs: 1}) - regionRequestSender := NewRegionRequestSender(s.cache, fnClient) + regionRequestSender := NewRegionRequestSender(s.cache, fnClient, oracle.NoopReadTSValidator{}) regionRequestSender.regionCache.testingKnobs.mockRequestLiveness.Store((*livenessFunc)(&tf)) regionRequestSender.SendReq(bo, req, region.Region, client.ReadTimeoutShort) } diff --git a/oracle/oracle.go b/oracle/oracle.go index 7ace335ec0..88de9d3ae5 100644 --- a/oracle/oracle.go +++ b/oracle/oracle.go @@ -36,6 +36,7 @@ package oracle import ( "context" + "fmt" "time" ) @@ -64,12 +65,17 @@ type Oracle interface { GetExternalTimestamp(ctx context.Context) (uint64, error) SetExternalTimestamp(ctx context.Context, ts uint64) error - // ValidateSnapshotReadTS verifies whether it can be guaranteed that the given readTS doesn't exceed the maximum ts - // that has been allocated by the oracle, so that it's safe to use 
this ts to perform snapshot read, stale read, - // etc. + ReadTSValidator +} + +// ReadTSValidator is the interface that provides the ability to verify whether a timestamp is safe to be used +// for read operations, as part of the `Oracle` interface. +type ReadTSValidator interface { + // ValidateReadTS verifies whether it can be guaranteed that the given readTS doesn't exceed the maximum ts + // that has been allocated by the oracle, so that it's safe to use this ts to perform read operations. // Note that this method only checks the ts from the oracle's perspective. It doesn't check whether the snapshot // has been GCed. - ValidateSnapshotReadTS(ctx context.Context, readTS uint64, opt *Option) error + ValidateReadTS(ctx context.Context, readTS uint64, isStaleRead bool, opt *Option) error } // Future is a future which promises to return a timestamp. @@ -121,3 +127,29 @@ func GoTimeToTS(t time.Time) uint64 { func GoTimeToLowerLimitStartTS(now time.Time, maxTxnTimeUse int64) uint64 { return GoTimeToTS(now.Add(-time.Duration(maxTxnTimeUse) * time.Millisecond)) } + +// NoopReadTSValidator is a dummy implementation of ReadTSValidator that always lets the validation pass. +// Only use this for RPCs that are not related to ts (e.g. rawkv), or in tests where `Oracle` is not available +// and the validation is not necessary. +type NoopReadTSValidator struct{} + +func (NoopReadTSValidator) ValidateReadTS(ctx context.Context, readTS uint64, isStaleRead bool, opt *Option) error { + return nil +} + +// ErrFutureTSRead is the error returned when the read ts is newer than the maximum ts the oracle has allocated. +type ErrFutureTSRead struct { + ReadTS uint64 + CurrentTS uint64 +} + +func (e ErrFutureTSRead) Error() string { + return fmt.Sprintf("cannot set read timestamp to a future time, readTS: %d, currentTS: %d", e.ReadTS, e.CurrentTS) +} + +// ErrLatestStaleRead is the error returned when a stale read is requested with math.MaxUint64 as the read ts. +type ErrLatestStaleRead struct{} + +func (ErrLatestStaleRead) Error() string { + return "cannot set read ts to max uint64 for stale read" +} diff --git a/oracle/oracles/export_test.go b/oracle/oracles/export_test.go index 08df25783d..78e7c0a8bb 100644 --- a/oracle/oracles/export_test.go +++ b/oracle/oracles/export_test.go @@ -65,6 +65,6 @@ func SetEmptyPDOracleLastTs(oc oracle.Oracle, ts uint64) { case *pdOracle: lastTSInterface, _ := o.lastTSMap.LoadOrStore(oracle.GlobalTxnScope, &atomic.Pointer[lastTSO]{}) lastTSPointer := lastTSInterface.(*atomic.Pointer[lastTSO]) - lastTSPointer.Store(&lastTSO{tso: ts, arrival: ts}) + lastTSPointer.Store(&lastTSO{tso: ts, arrival: oracle.GetTimeFromTS(ts)}) } } diff --git a/oracle/oracles/local.go b/oracle/oracles/local.go index 1e6b747c98..e916286ac3 100644 --- a/oracle/oracles/local.go +++ b/oracle/oracles/local.go @@ -36,6 +36,7 @@ package oracles import ( "context" + "math" "sync" "time" @@ -136,13 +137,23 @@ func (l *localOracle) GetExternalTimestamp(ctx context.Context) (uint64, error) return l.getExternalTimestamp(ctx) } -func (l *localOracle) ValidateSnapshotReadTS(ctx context.Context, readTS uint64, opt *oracle.Option) error { +func (l *localOracle) ValidateReadTS(ctx context.Context, readTS uint64, isStaleRead bool, opt *oracle.Option) error { + if readTS == math.MaxUint64 { + if isStaleRead { + return oracle.ErrLatestStaleRead{} + } + return nil + } + currentTS, err := l.GetTimestamp(ctx, opt) if err != nil { return errors.Errorf("fail to validate read timestamp: %v", err) } if currentTS < readTS { - return errors.Errorf("cannot set read timestamp to a future time") + return oracle.ErrFutureTSRead{ + ReadTS: readTS, + CurrentTS: currentTS, + } } return nil } diff --git a/oracle/oracles/mock.go b/oracle/oracles/mock.go index 
183b4c2d60..da8874d5c8 100644 --- a/oracle/oracles/mock.go +++ b/oracle/oracles/mock.go @@ -36,6 +36,7 @@ package oracles import ( "context" + "math" "sync" "time" @@ -122,13 +123,27 @@ func (o *MockOracle) GetLowResolutionTimestampAsync(ctx context.Context, opt *or return o.GetTimestampAsync(ctx, opt) } -func (o *MockOracle) ValidateSnapshotReadTS(ctx context.Context, readTS uint64, opt *oracle.Option) error { +func (o *MockOracle) SetLowResolutionTimestampUpdateInterval(time.Duration) error { + return nil +} + +func (o *MockOracle) ValidateReadTS(ctx context.Context, readTS uint64, isStaleRead bool, opt *oracle.Option) error { + if readTS == math.MaxUint64 { + if isStaleRead { + return oracle.ErrLatestStaleRead{} + } + return nil + } + currentTS, err := o.GetTimestamp(ctx, opt) if err != nil { return errors.Errorf("fail to validate read timestamp: %v", err) } if currentTS < readTS { - return errors.Errorf("cannot set read timestamp to a future time") + return oracle.ErrFutureTSRead{ + ReadTS: readTS, + CurrentTS: currentTS, + } } return nil } diff --git a/oracle/oracles/pd.go b/oracle/oracles/pd.go index 6e7fb9b6f1..805d1b5c7a 100644 --- a/oracle/oracles/pd.go +++ b/oracle/oracles/pd.go @@ -37,6 +37,7 @@ package oracles import ( "context" "fmt" + "math" "strings" "sync" "sync/atomic" @@ -149,7 +150,7 @@ type pdOracle struct { // When the low resolution ts is not new enough and there are many concurrent stale read / snapshot read // operations that need to validate the read ts, we can use this to avoid too many concurrent GetTS calls by - // reusing a result for different `ValidateSnapshotReadTS` calls. This can be done because that +// reusing a result for different `ValidateReadTS` calls. This can be done because // we don't require the ts for validation to be strictly the latest one. // Note that the result can't be reused for different txnScopes. The txnScope is used as the key. tsForValidation singleflight.Group @@ -158,7 +159,7 @@ type pdOracle struct { // lastTSO stores the last timestamp oracle gets from PD server and the local time when the TSO is fetched. 
type lastTSO struct { tso uint64 - arrival uint64 + arrival time.Time } type PDOracleOptions struct { @@ -272,17 +273,13 @@ func (o *pdOracle) getTimestamp(ctx context.Context, txnScope string) (uint64, e return oracle.ComposeTS(physical, logical), nil } -func (o *pdOracle) getArrivalTimestamp() uint64 { - return oracle.GoTimeToTS(time.Now()) -} - func (o *pdOracle) setLastTS(ts uint64, txnScope string) { if txnScope == "" { txnScope = oracle.GlobalTxnScope } current := &lastTSO{ tso: ts, - arrival: o.getArrivalTimestamp(), + arrival: time.Now(), } lastTSInterface, ok := o.lastTSMap.Load(txnScope) if !ok { @@ -294,9 +291,12 @@ func (o *pdOracle) setLastTS(ts uint64, txnScope string) { lastTSPointer := lastTSInterface.(*atomic.Pointer[lastTSO]) for { last := lastTSPointer.Load() - if current.tso <= last.tso || current.arrival <= last.arrival { + if current.tso <= last.tso { return } + if last.arrival.After(current.arrival) { + current.arrival = last.arrival + } if lastTSPointer.CompareAndSwap(last, current) { return } @@ -561,8 +561,11 @@ func (o *pdOracle) getStaleTimestamp(txnScope string, prevSecond uint64) (uint64 if !ok { return 0, errors.Errorf("get stale timestamp fail, txnScope: %s", txnScope) } - ts, arrivalTS := last.tso, last.arrival - arrivalTime := oracle.GetTimeFromTS(arrivalTS) + return o.getStaleTimestampWithLastTS(last, prevSecond) +} + +func (o *pdOracle) getStaleTimestampWithLastTS(last *lastTSO, prevSecond uint64) (uint64, error) { + ts, arrivalTime := last.tso, last.arrival physicalTime := oracle.GetTimeFromTS(ts) if uint64(physicalTime.Unix()) <= prevSecond { return 0, errors.Errorf("invalid prevSecond %v", prevSecond) @@ -617,22 +620,34 @@ func (o *pdOracle) getCurrentTSForValidation(ctx context.Context, opt *oracle.Op } } -func (o *pdOracle) ValidateSnapshotReadTS(ctx context.Context, readTS uint64, opt *oracle.Option) error { - latestTS, err := o.GetLowResolutionTimestamp(ctx, opt) - // If we fail to get latestTS or the readTS exceeds it, get a timestamp from PD to double-check. +func (o *pdOracle) ValidateReadTS(ctx context.Context, readTS uint64, isStaleRead bool, opt *oracle.Option) (errRet error) { + if readTS == math.MaxUint64 { + if isStaleRead { + return oracle.ErrLatestStaleRead{} + } + return nil + } + + latestTSInfo, exists := o.getLastTSWithArrivalTS(opt.TxnScope) + // If we fail to get latestTSInfo or the readTS exceeds it, get a timestamp from PD to double-check. // But we don't need to strictly fetch the latest TS. So if there are already concurrent calls to this function // loading the latest TS, we can just reuse the same result to avoid too many concurrent GetTS calls. 
- if err != nil || readTS > latestTS { + if !exists || readTS > latestTSInfo.tso { currentTS, err := o.getCurrentTSForValidation(ctx, opt) if err != nil { return errors.Errorf("fail to validate read timestamp: %v", err) } - o.adjustUpdateLowResolutionTSIntervalWithRequestedStaleness(readTS, currentTS, time.Now()) + if isStaleRead { + o.adjustUpdateLowResolutionTSIntervalWithRequestedStaleness(readTS, currentTS, time.Now()) + } if readTS > currentTS { - return errors.Errorf("cannot set read timestamp to a future time") + return oracle.ErrFutureTSRead{ + ReadTS: readTS, + CurrentTS: currentTS, + } } - } else { - estimatedCurrentTS, err := o.getStaleTimestamp(opt.TxnScope, 0) + } else if isStaleRead { + estimatedCurrentTS, err := o.getStaleTimestampWithLastTS(latestTSInfo, 0) if err != nil { logutil.Logger(ctx).Warn("failed to estimate current ts by getStaleTimestamp for auto-adjusting update low resolution ts interval", zap.Error(err), zap.Uint64("readTS", readTS), zap.String("txnScope", opt.TxnScope)) @@ -643,6 +658,9 @@ func (o *pdOracle) ValidateSnapshotReadTS(ctx context.Context, readTS uint64, op return nil } +// adjustUpdateLowResolutionTSIntervalWithRequestedStaleness triggers adjustment of the update interval of low resolution +// ts, if necessary, to suit the usage of stale read. +// This method is not supposed to be called when performing non-stale-read operations. func (o *pdOracle) adjustUpdateLowResolutionTSIntervalWithRequestedStaleness(readTS uint64, currentTS uint64, now time.Time) { requiredStaleness := oracle.GetTimeFromTS(currentTS).Sub(oracle.GetTimeFromTS(readTS)) diff --git a/oracle/oracles/pd_test.go b/oracle/oracles/pd_test.go index 48739fd5ec..25345f3b85 100644 --- a/oracle/oracles/pd_test.go +++ b/oracle/oracles/pd_test.go @@ -237,40 +237,54 @@ func TestAdaptiveUpdateTSInterval(t *testing.T) { assert.Equal(t, adaptiveUpdateTSIntervalStateNormal, o.adaptiveUpdateIntervalState.state) } -func TestValidateSnapshotReadTS(t *testing.T) { - pdClient := MockPdClient{} - o, err := NewPdOracle(&pdClient, &PDOracleOptions{ - UpdateInterval: time.Second * 2, - }) - assert.NoError(t, err) - defer o.Close() - - ctx := context.Background() - opt := &oracle.Option{TxnScope: oracle.GlobalTxnScope} - ts, err := o.GetTimestamp(ctx, opt) - assert.NoError(t, err) - assert.GreaterOrEqual(t, ts, uint64(1)) +func TestValidateReadTS(t *testing.T) { + testImpl := func(staleRead bool) { + pdClient := MockPdClient{} + o, err := NewPdOracle(&pdClient, &PDOracleOptions{ + UpdateInterval: time.Second * 2, + }) + assert.NoError(t, err) + defer o.Close() + + ctx := context.Background() + opt := &oracle.Option{TxnScope: oracle.GlobalTxnScope} + + // MaxUint64 as the read ts is rejected for stale read, but allowed for non-stale read. + err = o.ValidateReadTS(ctx, math.MaxUint64, staleRead, opt) + if staleRead { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } - err = o.ValidateSnapshotReadTS(ctx, 1, opt) - assert.NoError(t, err) - ts, err = o.GetTimestamp(ctx, opt) - assert.NoError(t, err) - // The readTS exceeds the latest ts, so it first fails the check with the low resolution ts. Then it fallbacks to - // the fetching-from-PD path, and it can get the previous ts + 1, which can allow this validation to pass. - err = o.ValidateSnapshotReadTS(ctx, ts+1, opt) - assert.NoError(t, err) - // It can't pass if the readTS is newer than previous ts + 2. 
- ts, err = o.GetTimestamp(ctx, opt) - assert.NoError(t, err) - err = o.ValidateSnapshotReadTS(ctx, ts+2, opt) - assert.Error(t, err) + ts, err := o.GetTimestamp(ctx, opt) + assert.NoError(t, err) + assert.GreaterOrEqual(t, ts, uint64(1)) + + err = o.ValidateReadTS(ctx, 1, staleRead, opt) + assert.NoError(t, err) + ts, err = o.GetTimestamp(ctx, opt) + assert.NoError(t, err) + // The readTS exceeds the latest ts, so it first fails the check with the low resolution ts. Then it falls back to + // the fetching-from-PD path, and it can get the previous ts + 1, which can allow this validation to pass. + err = o.ValidateReadTS(ctx, ts+1, staleRead, opt) + assert.NoError(t, err) + // It can't pass if the readTS is larger than the previous ts + 1. + ts, err = o.GetTimestamp(ctx, opt) + assert.NoError(t, err) + err = o.ValidateReadTS(ctx, ts+2, staleRead, opt) + assert.Error(t, err) + + // Simulate other PD clients requesting timestamps. + ts, err = o.GetTimestamp(ctx, opt) + assert.NoError(t, err) + pdClient.logicalTimestamp.Add(2) + err = o.ValidateReadTS(ctx, ts+3, staleRead, opt) + assert.NoError(t, err) + } - // Simulate other PD clients requests a timestamp. - ts, err = o.GetTimestamp(ctx, opt) - assert.NoError(t, err) - pdClient.logicalTimestamp.Add(2) - err = o.ValidateSnapshotReadTS(ctx, ts+3, opt) - assert.NoError(t, err) + testImpl(true) + testImpl(false) } type MockPDClientWithPause struct { @@ -292,7 +306,7 @@ func (c *MockPDClientWithPause) Resume() { c.mu.Unlock() } -func TestValidateSnapshotReadTSReusingGetTSResult(t *testing.T) { +func TestValidateReadTSForStaleReadReusingGetTSResult(t *testing.T) { pdClient := &MockPDClientWithPause{} o, err := NewPdOracle(pdClient, &PDOracleOptions{ UpdateInterval: time.Second * 2, @@ -304,7 +318,7 @@ asyncValidate := func(ctx context.Context, readTS uint64) chan error { ch := make(chan error, 1) go func() { - err := o.ValidateSnapshotReadTS(ctx, readTS, &oracle.Option{TxnScope: oracle.GlobalTxnScope}) + err := o.ValidateReadTS(ctx, readTS, true, &oracle.Option{TxnScope: oracle.GlobalTxnScope}) ch <- err }() return ch @@ -313,7 +327,7 @@ noResult := func(ch chan error) { select { case <-ch: - assert.FailNow(t, "a ValidateSnapshotReadTS operation is not blocked while it's expected to be blocked") + assert.FailNow(t, "a ValidateReadTS operation is not blocked while it's expected to be blocked") default: } } @@ -391,3 +405,79 @@ } } } + +func TestValidateReadTSForNormalReadDoNotAffectUpdateInterval(t *testing.T) { + oracleInterface, err := NewPdOracle(&MockPdClient{}, &PDOracleOptions{ + UpdateInterval: time.Second * 2, + NoUpdateTS: true, + }) + assert.NoError(t, err) + o := oracleInterface.(*pdOracle) + defer o.Close() + + ctx := context.Background() + opt := &oracle.Option{TxnScope: oracle.GlobalTxnScope} + + // Validating read ts for non-stale-read requests must not trigger updating the adaptive update interval of + // low resolution ts. 
+ mustNoNotify := func() { + select { + case <-o.adaptiveUpdateIntervalState.shrinkIntervalCh: + assert.Fail(t, "expected no immediate notification to shrink the update interval, but a message was received") + default: + } + } + + ts, err := o.GetTimestamp(ctx, opt) + assert.NoError(t, err) + assert.GreaterOrEqual(t, ts, uint64(1)) + + err = o.ValidateReadTS(ctx, ts, false, opt) + assert.NoError(t, err) + mustNoNotify() + + // It loads `ts + 1` from the mock PD, and the check cannot pass. + err = o.ValidateReadTS(ctx, ts+2, false, opt) + assert.Error(t, err) + mustNoNotify() + + // Do the check again. It loads `ts + 2` from the mock PD, and the check passes. + err = o.ValidateReadTS(ctx, ts+2, false, opt) + assert.NoError(t, err) + mustNoNotify() +} + +func TestSetLastTSAlwaysPushTS(t *testing.T) { + oracleInterface, err := NewPdOracle(&MockPdClient{}, &PDOracleOptions{ + UpdateInterval: time.Second * 2, + NoUpdateTS: true, + }) + assert.NoError(t, err) + o := oracleInterface.(*pdOracle) + defer o.Close() + + var wg sync.WaitGroup + cancel := make(chan struct{}) + for i := 0; i < 100; i++ { + wg.Add(1) + go func() { + defer wg.Done() + ctx := context.Background() + for { + select { + case <-cancel: + return + default: + } + ts, err := o.GetTimestamp(ctx, &oracle.Option{TxnScope: oracle.GlobalTxnScope}) + assert.NoError(t, err) + lastTS, found := o.getLastTS(oracle.GlobalTxnScope) + assert.True(t, found) + assert.GreaterOrEqual(t, lastTS, ts) + } + }() + } + time.Sleep(time.Second) + close(cancel) + wg.Wait() +} diff --git a/rawkv/rawkv.go b/rawkv/rawkv.go index ddae7289b3..be965117c4 100644 --- a/rawkv/rawkv.go +++ b/rawkv/rawkv.go @@ -48,6 +48,7 @@ import ( "github.com/tikv/client-go/v2/internal/locate" "github.com/tikv/client-go/v2/internal/retry" "github.com/tikv/client-go/v2/metrics" + "github.com/tikv/client-go/v2/oracle" "github.com/tikv/client-go/v2/tikv" "github.com/tikv/client-go/v2/tikvrpc" pd "github.com/tikv/pd/client" @@ -687,7 +688,7 @@ func (c *Client) CompareAndSwap(ctx context.Context, key, previousValue, newValu func (c *Client) sendReq(ctx context.Context, key []byte, req *tikvrpc.Request, reverse bool) (*tikvrpc.Response, *locate.KeyLocation, error) { bo := retry.NewBackofferWithVars(ctx, rawkvMaxBackoff, nil) - sender := locate.NewRegionRequestSender(c.regionCache, c.rpcClient) + sender := locate.NewRegionRequestSender(c.regionCache, c.rpcClient, oracle.NoopReadTSValidator{}) for { var loc *locate.KeyLocation var err error @@ -784,7 +785,7 @@ func (c *Client) doBatchReq(bo *retry.Backoffer, batch kvrpc.Batch, options *raw }) } - sender := locate.NewRegionRequestSender(c.regionCache, c.rpcClient) + sender := locate.NewRegionRequestSender(c.regionCache, c.rpcClient, oracle.NoopReadTSValidator{}) req.MaxExecutionDurationMs = uint64(client.MaxWriteExecutionTime.Milliseconds()) resp, _, err := sender.SendReq(bo, req, batch.RegionID, client.ReadTimeoutShort) @@ -834,7 +835,7 @@ func (c *Client) doBatchReq(bo *retry.Backoffer, batch kvrpc.Batch, options *raw // TODO: Is there any better way to avoid duplicating code with func `sendReq` ? 
func (c *Client) sendDeleteRangeReq(ctx context.Context, startKey []byte, endKey []byte, opts *rawOptions) (*tikvrpc.Response, []byte, error) { bo := retry.NewBackofferWithVars(ctx, rawkvMaxBackoff, nil) - sender := locate.NewRegionRequestSender(c.regionCache, c.rpcClient) + sender := locate.NewRegionRequestSender(c.regionCache, c.rpcClient, oracle.NoopReadTSValidator{}) for { loc, err := c.regionCache.LocateKey(bo, startKey) if err != nil { @@ -936,7 +937,7 @@ func (c *Client) doBatchPut(bo *retry.Backoffer, batch kvrpc.Batch, opts *rawOpt Ttl: ttl, }) - sender := locate.NewRegionRequestSender(c.regionCache, c.rpcClient) + sender := locate.NewRegionRequestSender(c.regionCache, c.rpcClient, oracle.NoopReadTSValidator{}) req.MaxExecutionDurationMs = uint64(client.MaxWriteExecutionTime.Milliseconds()) req.ApiVersion = c.apiVersion resp, _, err := sender.SendReq(bo, req, batch.RegionID, client.ReadTimeoutShort) diff --git a/tikv/kv.go b/tikv/kv.go index aaf267784f..51a16accf6 100644 --- a/tikv/kv.go +++ b/tikv/kv.go @@ -462,7 +462,7 @@ func (s *KVStore) SupportDeleteRange() (supported bool) { func (s *KVStore) SendReq( bo *Backoffer, req *tikvrpc.Request, regionID locate.RegionVerID, timeout time.Duration, ) (*tikvrpc.Response, error) { - sender := locate.NewRegionRequestSender(s.regionCache, s.GetTiKVClient()) + sender := locate.NewRegionRequestSender(s.regionCache, s.GetTiKVClient(), s.oracle) resp, _, err := sender.SendReq(bo, req, regionID, timeout) return resp, err } diff --git a/tikv/region.go b/tikv/region.go index e935657c2e..2a3e2f9b3d 100644 --- a/tikv/region.go +++ b/tikv/region.go @@ -41,6 +41,7 @@ import ( "github.com/tikv/client-go/v2/internal/apicodec" "github.com/tikv/client-go/v2/internal/client" "github.com/tikv/client-go/v2/internal/locate" + "github.com/tikv/client-go/v2/oracle" "github.com/tikv/client-go/v2/tikvrpc" pd "github.com/tikv/pd/client" ) @@ -160,8 +161,8 @@ func GetStoreTypeByMeta(store *metapb.Store) tikvrpc.EndpointType { } // NewRegionRequestSender creates a new sender. -func NewRegionRequestSender(regionCache *RegionCache, client client.Client) *RegionRequestSender { - return locate.NewRegionRequestSender(regionCache, client) +func NewRegionRequestSender(regionCache *RegionCache, client client.Client, readTSValidator oracle.ReadTSValidator) *RegionRequestSender { + return locate.NewRegionRequestSender(regionCache, client, readTSValidator) } // LoadShuttingDown atomically loads ShuttingDown. 
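The hunks above settle a convention for the new `readTSValidator` argument: ts-free paths such as rawkv pass `oracle.NoopReadTSValidator{}`, while transactional paths pass the store's oracle, which implements `oracle.ReadTSValidator`. A minimal caller-side sketch of both choices, using the exported `tikv` wrappers shown in this diff; the helper names `newRawKVSender` and `newTxnSender` are hypothetical:

```go
package example

import (
	"github.com/tikv/client-go/v2/oracle"
	"github.com/tikv/client-go/v2/tikv"
)

// newRawKVSender builds a sender for ts-free RPCs (e.g. rawkv), where
// read-ts validation is intentionally a no-op.
func newRawKVSender(s *tikv.KVStore) *tikv.RegionRequestSender {
	return tikv.NewRegionRequestSender(s.GetRegionCache(), s.GetTiKVClient(), oracle.NoopReadTSValidator{})
}

// newTxnSender builds a sender for transactional paths; the store's oracle
// rejects read timestamps that exceed the maximum ts it has allocated.
func newTxnSender(s *tikv.KVStore) *tikv.RegionRequestSender {
	return tikv.NewRegionRequestSender(s.GetRegionCache(), s.GetTiKVClient(), s.GetOracle())
}
```

Wiring the validator in at construction time means every read request sent through the sender has its read ts checked before the RPC goes out, rather than trusting each caller to validate it.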
diff --git a/tikv/split_region.go b/tikv/split_region.go index 2844b3889c..6f9d1f9dd7 100644 --- a/tikv/split_region.go +++ b/tikv/split_region.go @@ -148,7 +148,7 @@ func (s *KVStore) batchSendSingleRegion(bo *Backoffer, batch kvrpc.Batch, scatte RequestSource: util.RequestSourceFromCtx(bo.GetCtx()), }) - sender := locate.NewRegionRequestSender(s.regionCache, s.GetTiKVClient()) + sender := locate.NewRegionRequestSender(s.regionCache, s.GetTiKVClient(), s.oracle) resp, _, err := sender.SendReq(bo, req, batch.RegionID, client.ReadTimeoutShort) batchResp := kvrpc.BatchResult{Response: resp} diff --git a/txnkv/transaction/commit.go b/txnkv/transaction/commit.go index 9e3eac4fe4..3818e39783 100644 --- a/txnkv/transaction/commit.go +++ b/txnkv/transaction/commit.go @@ -95,7 +95,7 @@ func (action actionCommit) handleSingleBatch(c *twoPhaseCommitter, bo *retry.Bac tBegin := time.Now() attempts := 0 - sender := locate.NewRegionRequestSender(c.store.GetRegionCache(), c.store.GetTiKVClient()) + sender := locate.NewRegionRequestSender(c.store.GetRegionCache(), c.store.GetTiKVClient(), c.store.GetOracle()) for { attempts++ reqBegin := time.Now() diff --git a/txnkv/transaction/pessimistic.go b/txnkv/transaction/pessimistic.go index 3db5670281..5855a08df8 100644 --- a/txnkv/transaction/pessimistic.go +++ b/txnkv/transaction/pessimistic.go @@ -184,7 +184,7 @@ func (action actionPessimisticLock) handleSingleBatch( time.Sleep(300 * time.Millisecond) return errors.WithStack(&tikverr.ErrWriteConflict{WriteConflict: nil}) } - sender := locate.NewRegionRequestSender(c.store.GetRegionCache(), c.store.GetTiKVClient()) + sender := locate.NewRegionRequestSender(c.store.GetRegionCache(), c.store.GetTiKVClient(), c.store.GetOracle()) startTime := time.Now() resp, _, err := sender.SendReq(bo, req, batch.region, client.ReadTimeoutShort) diagCtx.reqDuration = time.Since(startTime) diff --git a/txnkv/transaction/prewrite.go b/txnkv/transaction/prewrite.go index e83fee3f86..d74b5fe6d1 100644 --- a/txnkv/transaction/prewrite.go +++ b/txnkv/transaction/prewrite.go @@ -268,7 +268,7 @@ func (action actionPrewrite) handleSingleBatch( attempts := 0 req := c.buildPrewriteRequest(batch, txnSize) - sender := locate.NewRegionRequestSender(c.store.GetRegionCache(), c.store.GetTiKVClient()) + sender := locate.NewRegionRequestSender(c.store.GetRegionCache(), c.store.GetTiKVClient(), c.store.GetOracle()) var resolvingRecordToken *int defer func() { if err != nil { diff --git a/txnkv/txnsnapshot/client_helper.go b/txnkv/txnsnapshot/client_helper.go index bbd4f70cea..b6038d4240 100644 --- a/txnkv/txnsnapshot/client_helper.go +++ b/txnkv/txnsnapshot/client_helper.go @@ -40,6 +40,7 @@ import ( "github.com/tikv/client-go/v2/internal/client" "github.com/tikv/client-go/v2/internal/locate" "github.com/tikv/client-go/v2/internal/retry" + "github.com/tikv/client-go/v2/oracle" "github.com/tikv/client-go/v2/tikvrpc" "github.com/tikv/client-go/v2/txnkv/txnlock" "github.com/tikv/client-go/v2/util" @@ -62,6 +63,7 @@ type ClientHelper struct { committedLocks *util.TSSet client client.Client resolveLite bool + oracle oracle.Oracle Stats *locate.RegionRequestRuntimeStats } @@ -74,6 +76,7 @@ func NewClientHelper(store kvstore, resolvedLocks *util.TSSet, committedLocks *u committedLocks: committedLocks, client: store.GetTiKVClient(), resolveLite: resolveLite, + oracle: store.GetOracle(), } } @@ -136,7 +139,7 @@ func (ch *ClientHelper) ResolveLocksDone(callerStartTS uint64, token int) { // SendReqCtx wraps the SendReqCtx function and uses the resolved lock 
result in the kvrpcpb.Context. func (ch *ClientHelper) SendReqCtx(bo *retry.Backoffer, req *tikvrpc.Request, regionID locate.RegionVerID, timeout time.Duration, et tikvrpc.EndpointType, directStoreAddr string, opts ...locate.StoreSelectorOption) (*tikvrpc.Response, *locate.RPCContext, string, error) { - sender := locate.NewRegionRequestSender(ch.regionCache, ch.client) + sender := locate.NewRegionRequestSender(ch.regionCache, ch.client, ch.oracle) if len(directStoreAddr) > 0 { sender.SetStoreAddr(directStoreAddr) } diff --git a/txnkv/txnsnapshot/scan.go b/txnkv/txnsnapshot/scan.go index 59c8fca84b..0b99e93c49 100644 --- a/txnkv/txnsnapshot/scan.go +++ b/txnkv/txnsnapshot/scan.go @@ -197,7 +197,7 @@ func (s *Scanner) getData(bo *retry.Backoffer) error { zap.String("nextEndKey", kv.StrKey(s.nextEndKey)), zap.Bool("reverse", s.reverse), zap.Uint64("txnStartTS", s.startTS())) - sender := locate.NewRegionRequestSender(s.snapshot.store.GetRegionCache(), s.snapshot.store.GetTiKVClient()) + sender := locate.NewRegionRequestSender(s.snapshot.store.GetRegionCache(), s.snapshot.store.GetTiKVClient(), s.snapshot.store.GetOracle()) var reqEndKey, reqStartKey []byte var loc *locate.KeyLocation var resolvingRecordToken *int
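Because validation failures are now typed values (`oracle.ErrFutureTSRead`, `oracle.ErrLatestStaleRead`) rather than bare `errors.Errorf` strings, callers can branch on the failure mode with `errors.As`. A minimal sketch, assuming `v` is any `oracle.ReadTSValidator` (for example the PD oracle); the wrapper `checkReadTS` is hypothetical:

```go
package example

import (
	"context"
	"errors"
	"fmt"

	"github.com/tikv/client-go/v2/oracle"
)

// checkReadTS distinguishes the two typed validation errors introduced in
// the oracle package.
func checkReadTS(ctx context.Context, v oracle.ReadTSValidator, readTS uint64, staleRead bool) error {
	err := v.ValidateReadTS(ctx, readTS, staleRead, &oracle.Option{TxnScope: oracle.GlobalTxnScope})
	if err == nil {
		return nil
	}
	var futureErr oracle.ErrFutureTSRead
	if errors.As(err, &futureErr) {
		// The read ts is ahead of what the oracle has allocated; the caller
		// can see how far via futureErr.ReadTS and futureErr.CurrentTS.
		return fmt.Errorf("read ts %d is ahead of allocated ts %d: %w", futureErr.ReadTS, futureErr.CurrentTS, err)
	}
	var latestErr oracle.ErrLatestStaleRead
	if errors.As(err, &latestErr) {
		// Stale reads must carry a real ts; math.MaxUint64 ("read latest") is rejected.
		return err
	}
	return err
}
```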