diff --git a/CHANGELOG.md b/CHANGELOG.md index 5a64a7910..eb961a459 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,5 @@ +* Fixed updating last usage timestamp for smart parking of the conns + ## v3.59.0 * Added `Struct` support for `ydb.ParamsBuilder()` * Added support of `TzDate`,`TzDateTime`,`TzTimestamp` types in `ydb.ParamsBuilder()` diff --git a/internal/conn/conn.go b/internal/conn/conn.go index 69d3d724f..33aef1d05 100644 --- a/internal/conn/conn.go +++ b/internal/conn/conn.go @@ -7,6 +7,7 @@ import ( "sync/atomic" "time" + "github.com/jonboulle/clockwork" "github.com/ydb-platform/ydb-go-genproto/protos/Ydb" "google.golang.org/grpc" "google.golang.org/grpc/connectivity" @@ -55,7 +56,7 @@ type conn struct { endpoint endpoint.Endpoint // ro access closed bool state atomic.Uint32 - lastUsage time.Time + lastUsage *lastUsage onClose []func(*conn) onTransportErrors []func(ctx context.Context, cc Conn, cause error) } @@ -80,7 +81,7 @@ func (c *conn) LastUsage() time.Time { c.mtx.RLock() defer c.mtx.RUnlock() - return c.lastUsage + return c.lastUsage.Get() } func (c *conn) IsState(states ...State) bool { @@ -244,12 +245,6 @@ func (c *conn) onTransportError(ctx context.Context, cause error) { } } -func (c *conn) touchLastUsage() { - c.mtx.Lock() - defer c.mtx.Unlock() - c.lastUsage = time.Now() -} - func isAvailable(raw *grpc.ClientConn) bool { return raw != nil && raw.GetState() == connectivity.Ready } @@ -332,8 +327,7 @@ func (c *conn) Invoke( return c.wrapError(err) } - c.touchLastUsage() - defer c.touchLastUsage() + defer c.lastUsage.Lock()() ctx, traceID, err := meta.TraceID(ctx) if err != nil { @@ -418,8 +412,7 @@ func (c *conn) NewStream( return nil, c.wrapError(err) } - c.touchLastUsage() - defer c.touchLastUsage() + defer c.lastUsage.Lock()() ctx, traceID, err := meta.TraceID(ctx) if err != nil { @@ -494,10 +487,15 @@ func withOnTransportError(onTransportError func(ctx context.Context, cc Conn, ca } func newConn(e endpoint.Endpoint, config Config, opts ...option) *conn { + clock := clockwork.NewRealClock() c := &conn{ endpoint: e, config: config, done: make(chan struct{}), + lastUsage: &lastUsage{ + t: clock.Now(), + clock: clock, + }, } c.state.Store(uint32(Created)) for _, opt := range opts { diff --git a/internal/conn/grpc_client_stream.go b/internal/conn/grpc_client_stream.go index 120a5e3a7..8af9e3b15 100644 --- a/internal/conn/grpc_client_stream.go +++ b/internal/conn/grpc_client_stream.go @@ -30,6 +30,8 @@ func (s *grpcClientStream) CloseSend() (err error) { onDone(err) }() + defer s.c.lastUsage.Lock()() + err = s.ClientStream.CloseSend() if err != nil { @@ -59,6 +61,8 @@ func (s *grpcClientStream) SendMsg(m interface{}) (err error) { onDone(err) }() + defer s.c.lastUsage.Lock()() + err = s.ClientStream.SendMsg(m) if err != nil { @@ -96,6 +100,8 @@ func (s *grpcClientStream) RecvMsg(m interface{}) (err error) { onDone(err) }() + defer s.c.lastUsage.Lock()() + defer func() { if err != nil { md := s.ClientStream.Trailer() diff --git a/internal/conn/last_usage.go b/internal/conn/last_usage.go new file mode 100644 index 000000000..c04a30374 --- /dev/null +++ b/internal/conn/last_usage.go @@ -0,0 +1,45 @@ +package conn + +import ( + "sync" + "sync/atomic" + "time" + + "github.com/jonboulle/clockwork" + + "github.com/ydb-platform/ydb-go-sdk/v3/internal/xsync" +) + +type lastUsage struct { + locks atomic.Int64 + mu xsync.RWMutex + t time.Time + clock clockwork.Clock +} + +func (l *lastUsage) Get() time.Time { + if l.locks.CompareAndSwap(0, 1) { + defer func() { + l.locks.Add(-1) + }() + + l.mu.RLock() + defer l.mu.RUnlock() + + return l.t + } + + return l.clock.Now() +} + +func (l *lastUsage) Lock() (releaseFunc func()) { + l.locks.Add(1) + + return sync.OnceFunc(func() { + if l.locks.Add(-1) == 0 { + l.mu.WithLock(func() { + l.t = l.clock.Now() + }) + } + }) +} diff --git a/internal/conn/last_usage_test.go b/internal/conn/last_usage_test.go new file mode 100644 index 000000000..f39abfb6c --- /dev/null +++ b/internal/conn/last_usage_test.go @@ -0,0 +1,97 @@ +package conn + +import ( + "testing" + "time" + + "github.com/jonboulle/clockwork" + "github.com/stretchr/testify/require" +) + +func Test_lastUsage_Lock(t *testing.T) { + t.Run("NowFromLocked", func(t *testing.T) { + start := time.Unix(0, 0) + clock := clockwork.NewFakeClockAt(start) + lu := &lastUsage{ + t: start, + clock: clock, + } + t1 := lu.Get() + require.Equal(t, start, t1) + f := lu.Lock() + clock.Advance(time.Hour) + t2 := lu.Get() + require.Equal(t, start.Add(time.Hour), t2) + clock.Advance(time.Hour) + f() + t3 := lu.Get() + require.Equal(t, start.Add(2*time.Hour), t3) + clock.Advance(time.Hour) + t4 := lu.Get() + require.Equal(t, start.Add(2*time.Hour), t4) + }) + t.Run("UpdateAfterLastUnlock", func(t *testing.T) { + start := time.Unix(0, 0) + clock := clockwork.NewFakeClockAt(start) + lu := &lastUsage{ + t: start, + clock: clock, + } + t1 := lu.Get() + require.Equal(t, start, t1) + f1 := lu.Lock() + clock.Advance(time.Hour) + t2 := lu.Get() + require.Equal(t, start.Add(time.Hour), t2) + f2 := lu.Lock() + clock.Advance(time.Hour) + f1() + f3 := lu.Lock() + clock.Advance(time.Hour) + t3 := lu.Get() + require.Equal(t, start.Add(3*time.Hour), t3) + clock.Advance(time.Hour) + t4 := lu.Get() + require.Equal(t, start.Add(4*time.Hour), t4) + f3() + t5 := lu.Get() + require.Equal(t, start.Add(4*time.Hour), t5) + clock.Advance(time.Hour) + t6 := lu.Get() + require.Equal(t, start.Add(5*time.Hour), t6) + clock.Advance(time.Hour) + f2() + t7 := lu.Get() + require.Equal(t, start.Add(6*time.Hour), t7) + clock.Advance(time.Hour) + f2() + t8 := lu.Get() + require.Equal(t, start.Add(6*time.Hour), t8) + }) + t.Run("DeferRelease", func(t *testing.T) { + start := time.Unix(0, 0) + clock := clockwork.NewFakeClockAt(start) + lu := &lastUsage{ + t: start, + clock: clock, + } + func() { + t1 := lu.Get() + require.Equal(t, start, t1) + clock.Advance(time.Hour) + t2 := lu.Get() + require.Equal(t, start, t2) + clock.Advance(time.Hour) + defer lu.Lock()() + t3 := lu.Get() + require.Equal(t, start.Add(2*time.Hour), t3) + clock.Advance(time.Hour) + t4 := lu.Get() + require.Equal(t, start.Add(3*time.Hour), t4) + clock.Advance(time.Hour) + }() + clock.Advance(time.Hour) + t5 := lu.Get() + require.Equal(t, start.Add(4*time.Hour), t5) + }) +}