Skip to content

Commit

Permalink
Merge pull request #1418 from ydb-platform/pool-sema
Browse files Browse the repository at this point in the history
refactoring of internal/pool
  • Loading branch information
asmyasnikov authored Aug 19, 2024
2 parents aa4bd13 + 3a98bcb commit c93c91c
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 184 deletions.
1 change: 1 addition & 0 deletions internal/pool/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,5 @@ import (
var (
errClosedPool = errors.New("closed pool")
errItemIsNotAlive = errors.New("item is not alive")
errPoolIsOverflow = errors.New("pool is overflow")
)
223 changes: 70 additions & 153 deletions internal/pool/pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,6 @@ type (
IsAlive() bool
Close(ctx context.Context) error
}
safeStats struct {
mu xsync.RWMutex
v Stats
onChange func(Stats)
}
statsItemAddr struct {
v *int
onChange func(func())
}
Pool[PT Item[T], T any] struct {
trace *Trace
limit int
Expand All @@ -37,80 +28,18 @@ type (
createTimeout time.Duration
closeTimeout time.Duration

mu xsync.Mutex
sema chan struct{}

mu xsync.RWMutex
idle []PT
index map[PT]struct{}
done chan struct{}

stats *safeStats
stats *Stats
}
option[PT Item[T], T any] func(p *Pool[PT, T])
)

func (field statsItemAddr) Inc() {
field.onChange(func() {
*field.v++
})
}

func (field statsItemAddr) Dec() {
field.onChange(func() {
*field.v--
})
}

func (s *safeStats) Get() Stats {
s.mu.RLock()
defer s.mu.RUnlock()

return s.v
}

func (s *safeStats) Index() statsItemAddr {
s.mu.RLock()
defer s.mu.RUnlock()

return statsItemAddr{
v: &s.v.Index,
onChange: func(f func()) {
s.mu.WithLock(f)
if s.onChange != nil {
s.onChange(s.Get())
}
},
}
}

func (s *safeStats) Idle() statsItemAddr {
s.mu.RLock()
defer s.mu.RUnlock()

return statsItemAddr{
v: &s.v.Idle,
onChange: func(f func()) {
s.mu.WithLock(f)
if s.onChange != nil {
s.onChange(s.Get())
}
},
}
}

func (s *safeStats) InUse() statsItemAddr {
s.mu.RLock()
defer s.mu.RUnlock()

return statsItemAddr{
v: &s.v.InUse,
onChange: func(f func()) {
s.mu.WithLock(f)
if s.onChange != nil {
s.onChange(s.Get())
}
},
}
}

func WithCreateFunc[PT Item[T], T any](f func(ctx context.Context) (PT, error)) option[PT, T] {
return func(p *Pool[PT, T]) {
p.createItem = f
Expand Down Expand Up @@ -170,13 +99,10 @@ func New[PT Item[T], T any](
}()

p.createItem = createItemWithTimeoutHandling(p.createItem, p)

p.sema = make(chan struct{}, p.limit)
p.idle = make([]PT, 0, p.limit)
p.index = make(map[PT]struct{}, p.limit)
p.stats = &safeStats{
v: Stats{Limit: p.limit},
onChange: p.trace.OnChange,
}
p.stats = &Stats{Limit: p.limit}

return p
}
Expand Down Expand Up @@ -263,7 +189,7 @@ func createItemWithContext[PT Item[T], T any](
if len(p.index) < p.limit {
p.idle = append(p.idle, newItem)
p.index[newItem] = struct{}{}
p.stats.Index().Inc()
p.stats.Index++
needCloseItem = false
}

Expand All @@ -276,10 +202,13 @@ func createItemWithContext[PT Item[T], T any](
}

func (p *Pool[PT, T]) Stats() Stats {
return p.stats.Get()
p.mu.RLock()
defer p.mu.RUnlock()

return *p.stats
}

func (p *Pool[PT, T]) getItem(ctx context.Context) (_ PT, finalErr error) {
func (p *Pool[PT, T]) getItem(ctx context.Context) (item PT, finalErr error) {
onDone := p.trace.OnGet(&GetStartInfo{
Context: &ctx,
Call: stack.FunctionID("github.com/ydb-platform/ydb-go-sdk/v3/internal/pool.(*Pool).getItem"),
Expand All @@ -290,56 +219,29 @@ func (p *Pool[PT, T]) getItem(ctx context.Context) (_ PT, finalErr error) {
})
}()

if err := ctx.Err(); err != nil {
return nil, xerrors.WithStackTrace(err)
}
p.mu.Lock()
defer p.mu.Unlock()

for {
select {
case <-p.done:
return nil, xerrors.WithStackTrace(errClosedPool)
case <-ctx.Done():
return nil, xerrors.WithStackTrace(ctx.Err())
default:
var item PT
p.mu.WithLock(func() {
if len(p.idle) > 0 {
item, p.idle = p.idle[0], p.idle[1:]
p.stats.Idle().Dec()
}
})
if len(p.idle) > 0 {
item, p.idle = p.idle[0], p.idle[1:]
p.stats.Idle--

if item != nil {
if item.IsAlive() {
return item, nil
}
_ = p.closeItem(ctx, item)
p.mu.WithLock(func() {
delete(p.index, item)
})
p.stats.Index().Dec()
}
var err error
var newItem PT
p.mu.WithLock(func() {
if len(p.index) >= p.limit {
return
}
newItem, err = p.createItem(ctx)
if err != nil {
return
}
p.index[newItem] = struct{}{}
p.stats.Index().Inc()
})
if err != nil {
return nil, xerrors.WithStackTrace(err)
}
if newItem != nil {
return newItem, nil
}
if item.IsAlive() {
return item, nil
}

_ = p.closeItem(ctx, item)
}

newItem, err := p.createItem(ctx)
if err != nil {
return nil, xerrors.WithStackTrace(xerrors.Retryable(err))
}

p.stats.Index++
p.index[newItem] = struct{}{}

return newItem, nil
}

func (p *Pool[PT, T]) putItem(ctx context.Context, item PT) (finalErr error) {
Expand All @@ -353,37 +255,28 @@ func (p *Pool[PT, T]) putItem(ctx context.Context, item PT) (finalErr error) {
})
}()

if err := ctx.Err(); err != nil {
return xerrors.WithStackTrace(err)
if !item.IsAlive() {
_ = p.closeItem(ctx, item)

return xerrors.WithStackTrace(errItemIsNotAlive)
}

select {
case <-p.done:
return xerrors.WithStackTrace(errClosedPool)
default:
if !item.IsAlive() {
_ = p.closeItem(ctx, item)
p.mu.Lock()
defer p.mu.Unlock()

p.mu.WithLock(func() {
delete(p.index, item)
})
p.stats.Index().Dec()
if len(p.idle) >= p.limit {
_ = p.closeItem(ctx, item)

return xerrors.WithStackTrace(errItemIsNotAlive)
}
return xerrors.WithStackTrace(errPoolIsOverflow)
}

p.mu.WithLock(func() {
p.idle = append(p.idle, item)
})
p.stats.Idle().Inc()
p.idle = append(p.idle, item)
p.stats.Idle--

return nil
}
return nil
}

func (p *Pool[PT, T]) closeItem(ctx context.Context, item PT) error {
ctx = xcontext.ValueOnly(ctx)

var cancel context.CancelFunc
if d := p.closeTimeout; d > 0 {
ctx, cancel = xcontext.WithTimeout(ctx, d)
Expand All @@ -392,6 +285,13 @@ func (p *Pool[PT, T]) closeItem(ctx context.Context, item PT) error {
}
defer cancel()

defer func() {
p.mu.WithLock(func() {
delete(p.index, item)
p.stats.Index--
})
}()

return item.Close(ctx)
}

Expand All @@ -406,6 +306,17 @@ func (p *Pool[PT, T]) try(ctx context.Context, f func(ctx context.Context, item
})
}()

select {
case <-p.done:
return xerrors.WithStackTrace(errClosedPool)
case <-ctx.Done():
return xerrors.WithStackTrace(ctx.Err())
case p.sema <- struct{}{}:
defer func() {
<-p.sema
}()
}

item, err := p.getItem(ctx)
if err != nil {
if xerrors.IsYdb(err) {
Expand All @@ -419,8 +330,14 @@ func (p *Pool[PT, T]) try(ctx context.Context, f func(ctx context.Context, item
_ = p.putItem(ctx, item)
}()

p.stats.InUse().Inc()
defer p.stats.InUse().Dec()
p.mu.Lock()
p.stats.InUse++
p.mu.Unlock()
defer func() {
p.mu.Lock()
p.stats.InUse--
p.mu.Unlock()
}()

err = f(ctx, item)
if err != nil {
Expand Down
33 changes: 2 additions & 31 deletions internal/pool/pool_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package pool
import (
"context"
"errors"
"math/rand"
"sync"
"sync/atomic"
"testing"
Expand Down Expand Up @@ -287,7 +286,7 @@ func TestPool(t *testing.T) {
require.NoError(t, err)
}()
wg.Wait()
}, xtest.StopAfter(42*time.Second))
}, xtest.StopAfter(14*time.Second))
})
t.Run("ParallelCreation", func(t *testing.T) {
xtest.TestManyTimes(t, func(t testing.TB) {
Expand All @@ -313,34 +312,6 @@ func TestPool(t *testing.T) {
require.Equal(t, DefaultLimit, stats.Limit)
require.Equal(t, 0, stats.InUse)
require.LessOrEqual(t, stats.Idle, DefaultLimit)
}, xtest.StopAfter(30*time.Second))
}, xtest.StopAfter(14*time.Second))
})
}

func TestSafeStatsRace(t *testing.T) {
xtest.TestManyTimes(t, func(t testing.TB) {
var (
wg sync.WaitGroup
s = &safeStats{}
)
wg.Add(1000)
for range make([]struct{}, 1000) {
go func() {
defer wg.Done()
require.NotPanics(t, func() {
switch rand.Int31n(4) { //nolint:gosec
case 0:
s.Index().Inc()
case 1:
s.InUse().Inc()
case 2:
s.Idle().Inc()
default:
s.Get()
}
})
}()
}
wg.Wait()
}, xtest.StopAfter(5*time.Second))
}

0 comments on commit c93c91c

Please sign in to comment.