Skip to content

Commit

Permalink
Fix query cache singleflight race condition (#2489)
Browse files Browse the repository at this point in the history
* Fix query cache singleflight race condition

* Upgrade Go version

* Fix query key type

* Fix type issue
  • Loading branch information
begelundmuller authored May 24, 2023
1 parent e4615b9 commit 9a9e528
Show file tree
Hide file tree
Showing 9 changed files with 162 additions and 240 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/go-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ jobs:
steps:
- uses: actions/setup-go@v3
with:
go-version: 1.19.x
go-version: 1.20.3
- uses: actions/checkout@v3
- uses: actions/cache@v3
with:
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/web-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ jobs:
- name: Set up go for E2E
uses: actions/setup-go@v3
with:
go-version: 1.19.x
go-version: 1.20.3
- name: go build cache
uses: actions/cache@v3
with:
Expand Down
96 changes: 0 additions & 96 deletions runtime/caches.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,32 +3,16 @@ package runtime
import (
"context"
"errors"
"fmt"
"sync"

"github.com/dgraph-io/ristretto"
"github.com/hashicorp/golang-lru/simplelru"
"github.com/rilldata/rill/runtime/drivers"
"github.com/rilldata/rill/runtime/pkg/observability"
"github.com/rilldata/rill/runtime/pkg/singleflight"
"github.com/rilldata/rill/runtime/services/catalog"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/metric"
"go.opentelemetry.io/otel/metric/global"
"go.uber.org/zap"
)

var errConnectionCacheClosed = errors.New("connectionCache: closed")

var (
meter = global.Meter("runtime")
queryCacheHitsCounter = observability.Must(meter.Int64ObservableCounter("query_cache.hits"))
queryCacheMissesCounter = observability.Must(meter.Int64ObservableCounter("query_cache.misses"))
queryCacheItemCountGauge = observability.Must(meter.Int64ObservableGauge("query_cache.items"))
queryCacheSizeBytesGauge = observability.Must(meter.Int64ObservableGauge("query_cache.size", metric.WithUnit("bytes")))
queryCacheEntrySizeHistogram = observability.Must(meter.Int64Histogram("query_cache.entry_size", metric.WithUnit("bytes")))
)

// cache for instance specific connections only
// all instance specific connections should be opened via connection cache only
type connectionCache struct {
Expand Down Expand Up @@ -156,83 +140,3 @@ func (c *migrationMetaCache) evict(ctx context.Context, instID string) {
defer c.lock.Unlock()
c.cache.Remove(instID)
}

// queryCache caches query results in a size-bounded ristretto cache.
// Concurrent lookups for the same key are deduplicated through a
// singleflight group so the underlying query executes at most once at a time.
type queryCache struct {
	cache *ristretto.Cache     // stores results, bounded by byte cost
	group *singleflight.Group  // collapses concurrent loads of the same key
}

// newQueryCache creates a queryCache backed by a ristretto cache whose total
// cost is capped at sizeInBytes, and registers OpenTelemetry callbacks that
// export the cache's hit/miss/item-count/size metrics on every collection.
// It panics when sizeInBytes is 100 or less, or when the underlying cache
// cannot be constructed.
func newQueryCache(sizeInBytes int64) *queryCache {
	if sizeInBytes <= 100 {
		panic(fmt.Sprintf("invalid cache size should be greater than 100: %v", sizeInBytes))
	}

	// Use 5% of cache memory for storing counters. Each counter takes roughly 3 bytes.
	// Recommended value is 10x the number of items in cache when full.
	// Tune this again based on metrics.
	numCounters := int64(float64(sizeInBytes) * 0.05 / 3)
	maxCost := int64(float64(sizeInBytes) * 0.95)

	rc, err := ristretto.NewCache(&ristretto.Config{
		NumCounters: numCounters,
		MaxCost:     maxCost,
		BufferItems: 64,
		Metrics:     true,
	})
	if err != nil {
		panic(err)
	}

	// Surface ristretto's internal statistics through the runtime meter.
	observability.Must(meter.RegisterCallback(func(ctx context.Context, observer metric.Observer) error {
		m := rc.Metrics
		observer.ObserveInt64(queryCacheHitsCounter, int64(m.Hits()))
		observer.ObserveInt64(queryCacheMissesCounter, int64(m.Misses()))
		observer.ObserveInt64(queryCacheItemCountGauge, int64(m.KeysAdded()-m.KeysEvicted()))
		observer.ObserveInt64(queryCacheSizeBytesGauge, int64(m.CostAdded()-m.CostEvicted()))
		return nil
	}, queryCacheHitsCounter, queryCacheMissesCounter, queryCacheItemCountGauge, queryCacheSizeBytesGauge))

	return &queryCache{
		cache: rc,
		group: &singleflight.Group{},
	}
}

// getOrLoad gets the key from cache if present. If absent, it looks up the key using the loadFn and puts it into cache before returning value.
//
// Concurrent callers with the same key are collapsed via singleflight: at
// most one loadFn runs per key at a time and the other callers receive its
// result. The bool return reports whether the value was served from cache;
// only the caller whose loadFn actually executed observes it as false.
// NOTE(review): loadFn must return a *QueryResult — the type assertion below
// panics otherwise; confirm against all call sites.
func (c *queryCache) getOrLoad(ctx context.Context, key, queryName string, loadFn func(context.Context) (any, error)) (any, bool, error) {
	// Fast path: serve straight from the cache without entering singleflight.
	if val, ok := c.cache.Get(key); ok {
		return val, true, nil
	}

	cached := true
	val, err := c.group.Do(ctx, key, func(ctx context.Context) (interface{}, error) {
		// check the cache again
		// (a previous singleflight winner may have populated it while we
		// were waiting to enter Do)
		if val, ok := c.cache.Get(key); ok {
			return val, nil
		}

		val, err := loadFn(ctx)
		if err != nil {
			return nil, err
		}

		// only one caller of load gets this return value
		cached = false
		cachedObject := val.(*QueryResult)
		// Record the entry's size for observability before storing it,
		// attributed to the query type that produced it.
		attrs := attribute.NewSet(
			attribute.String("query", queryName),
		)
		queryCacheEntrySizeHistogram.Record(ctx, cachedObject.Bytes, metric.WithAttributeSet(attrs))
		c.cache.Set(key, cachedObject.Value, cachedObject.Bytes)
		return cachedObject.Value, nil
	})
	if err != nil {
		return nil, false, err
	}

	return val, cached, nil
}

// add inserts val under key with the given cost and reports whether
// ristretto accepted the entry (admission may reject it).
// nolint:unused // use in tests
func (c *queryCache) add(key, val any, cost int64) bool {
	accepted := c.cache.Set(key, val, cost)
	return accepted
}

// get looks key up in the underlying cache, returning the stored value and
// whether it was present.
// nolint:unused // use in tests
func (c *queryCache) get(key any) (any, bool) {
	val, found := c.cache.Get(key)
	return val, found
}
77 changes: 0 additions & 77 deletions runtime/caches_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,8 @@ package runtime

import (
"context"
"sync"
"testing"
"time"

"github.com/c2h5oh/datasize"
"github.com/google/uuid"
_ "github.com/rilldata/rill/runtime/drivers/sqlite"
"github.com/stretchr/testify/require"
Expand All @@ -33,77 +30,3 @@ func TestConnectionCache(t *testing.T) {
require.True(t, conn1 == conn2)
require.False(t, conn2 == conn3)
}

// TestNilValues verifies that the query cache distinguishes a stored nil
// value (present) from a missing key (absent).
func TestNilValues(t *testing.T) {
	qc := newQueryCache(int64(datasize.MB * 100))
	defer qc.cache.Close()

	key := queryCacheKey{"1", "1", "1"}.String()

	// A regular value round-trips through the cache.
	qc.add(key, "value", 1)
	qc.cache.Wait()
	v, ok := qc.get(key)
	require.Equal(t, "value", v)
	require.True(t, ok)

	// Overwriting with nil keeps the key present, with a nil value.
	qc.add(key, nil, 1)
	qc.cache.Wait()
	v, ok = qc.get(key)
	require.Nil(t, v)
	require.True(t, ok)

	// An unknown key reports absent.
	v, ok = qc.get(queryCacheKey{"nosuch", "nosuch", "nosuch"}.String())
	require.Nil(t, v)
	require.False(t, ok)
}

// Test_queryCache_getOrLoad exercises singleflight deduplication and context
// cancellation in getOrLoad: five goroutines share one key; even-indexed
// callers get a 100ms timeout (shorter than the 200ms load) and must fail,
// odd-indexed callers must receive the first caller's result from cache.
// NOTE(review): timing-sensitive — relies on the 10ms stagger to make
// goroutine 0 the singleflight leader.
func Test_queryCache_getOrLoad(t *testing.T) {
	qc := newQueryCache(int64(datasize.MB))
	defer qc.cache.Close()

	// loadFn that takes 200ms unless its context is canceled first.
	f := func(ctx context.Context) (interface{}, error) {
		for {
			select {
			case <-ctx.Done():
				// Handle context cancellation
				return nil, ctx.Err()
			case <-time.After(200 * time.Millisecond):
				// Simulate some work
				return &QueryResult{Value: "hello"}, nil
			}
		}
	}
	errs := make([]error, 5)
	values := make([]interface{}, 5)
	cached := make([]bool, 5)
	var wg sync.WaitGroup
	for i := 0; i < 5; i++ {
		wg.Add(1)
		go func(i int) {
			var ctx context.Context
			var cancel context.CancelFunc
			if i%2 == 0 {
				// cancel all even requests
				ctx, cancel = context.WithTimeout(context.Background(), 100*time.Millisecond)
			} else {
				ctx, cancel = context.WithCancel(context.TODO())
			}
			defer cancel()
			defer wg.Done()
			values[i], cached[i], errs[i] = qc.getOrLoad(ctx, "key", "query", f)

		}(i)
		time.Sleep(10 * time.Millisecond) // ensure that first goroutine starts the work
	}
	wg.Wait()

	// Goroutine 0 led the load, so its result is not cached — and its
	// 100ms timeout fired before the 200ms load completed.
	require.False(t, cached[0])
	require.Error(t, errs[0])
	for i := 1; i < 5; i++ {
		if i%2 == 0 {
			// Even callers timed out waiting on the shared flight.
			require.Error(t, errs[i])
		} else {
			// Odd callers outlived the load and got the shared result.
			require.True(t, cached[i])
			require.NoError(t, errs[i])
			require.Equal(t, values[i], "hello")
		}
	}
}
21 changes: 11 additions & 10 deletions runtime/pkg/singleflight/singleflight.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,34 +44,34 @@ func newPanicError(v interface{}) error {
}

// call is an in-flight or completed singleflight.Do call
type call struct {
type call[V any] struct {
ctx context.Context
cancel context.CancelFunc
counter uint
val interface{}
val V
err error
}

// Group represents a class of work and forms a namespace in
// which units of work can be executed with duplicate suppression.
type Group struct {
mu sync.Mutex // protects m
m map[string]*call // lazily initialized
type Group[K comparable, V any] struct {
mu sync.Mutex // protects m
m map[K]*call[V] // lazily initialized
}

// Do executes and returns the results of the given function, making
// sure that only one execution is in-flight for a given key at a
// time. If a duplicate comes in, the duplicate caller waits for the
// original to complete and receives the same results.
func (g *Group) Do(ctx context.Context, key string, fn func(context.Context) (any, error)) (any, error) {
func (g *Group[K, V]) Do(ctx context.Context, key K, fn func(context.Context) (V, error)) (V, error) {
g.mu.Lock()
if g.m == nil {
g.m = make(map[string]*call)
g.m = make(map[K]*call[V])
}
c, ok := g.m[key]
if !ok {
cctx, cancel := withCancelAndContextValues(ctx)
c = &call{
c = &call[V]{
ctx: cctx,
cancel: cancel,
}
Expand All @@ -97,7 +97,8 @@ func (g *Group) Do(ctx context.Context, key string, fn func(context.Context) (an
g.mu.Unlock()

if ctx.Err() != nil {
return nil, ctx.Err()
var empty V
return empty, ctx.Err()
}

pErr := &panicError{}
Expand All @@ -110,7 +111,7 @@ func (g *Group) Do(ctx context.Context, key string, fn func(context.Context) (an
}

// doCall handles the single call for a key.
func (g *Group) doCall(c *call, key string, fn func(ctx context.Context) (interface{}, error)) {
func (g *Group[K, V]) doCall(c *call[V], key K, fn func(ctx context.Context) (V, error)) {
normalReturn := false
recovered := false

Expand Down
Loading

0 comments on commit 9a9e528

Please sign in to comment.