diff --git a/.github/workflows/go-test.yml b/.github/workflows/go-test.yml index 7babfcef25d..10d27864ec8 100644 --- a/.github/workflows/go-test.yml +++ b/.github/workflows/go-test.yml @@ -11,7 +11,7 @@ jobs: steps: - uses: actions/setup-go@v3 with: - go-version: 1.19.x + go-version: 1.20.3 - uses: actions/checkout@v3 - uses: actions/cache@v3 with: diff --git a/.github/workflows/web-test.yml b/.github/workflows/web-test.yml index 3eb6c633529..d2dd3de75af 100644 --- a/.github/workflows/web-test.yml +++ b/.github/workflows/web-test.yml @@ -40,7 +40,7 @@ jobs: - name: Set up go for E2E uses: actions/setup-go@v3 with: - go-version: 1.19.x + go-version: 1.20.3 - name: go build cache uses: actions/cache@v3 with: diff --git a/runtime/caches.go b/runtime/caches.go index 8574a6b835c..4e5ffc6368b 100644 --- a/runtime/caches.go +++ b/runtime/caches.go @@ -3,32 +3,16 @@ package runtime import ( "context" "errors" - "fmt" "sync" - "github.com/dgraph-io/ristretto" "github.com/hashicorp/golang-lru/simplelru" "github.com/rilldata/rill/runtime/drivers" - "github.com/rilldata/rill/runtime/pkg/observability" - "github.com/rilldata/rill/runtime/pkg/singleflight" "github.com/rilldata/rill/runtime/services/catalog" - "go.opentelemetry.io/otel/attribute" - "go.opentelemetry.io/otel/metric" - "go.opentelemetry.io/otel/metric/global" "go.uber.org/zap" ) var errConnectionCacheClosed = errors.New("connectionCache: closed") -var ( - meter = global.Meter("runtime") - queryCacheHitsCounter = observability.Must(meter.Int64ObservableCounter("query_cache.hits")) - queryCacheMissesCounter = observability.Must(meter.Int64ObservableCounter("query_cache.misses")) - queryCacheItemCountGauge = observability.Must(meter.Int64ObservableGauge("query_cache.items")) - queryCacheSizeBytesGauge = observability.Must(meter.Int64ObservableGauge("query_cache.size", metric.WithUnit("bytes"))) - queryCacheEntrySizeHistogram = observability.Must(meter.Int64Histogram("query_cache.entry_size", metric.WithUnit("bytes"))) -) - // cache for instance specific connections only // all instance specific connections should be opened via connection cache only type connectionCache struct { @@ -156,83 +140,3 @@ func (c *migrationMetaCache) evict(ctx context.Context, instID string) { defer c.lock.Unlock() c.cache.Remove(instID) } - -type queryCache struct { - cache *ristretto.Cache - group *singleflight.Group -} - -func newQueryCache(sizeInBytes int64) *queryCache { - if sizeInBytes <= 100 { - panic(fmt.Sprintf("invalid cache size should be greater than 100: %v", sizeInBytes)) - } - cache, err := ristretto.NewCache(&ristretto.Config{ - // Use 5% of cache memory for storing counters. Each counter takes roughly 3 bytes. - // Recommended value is 10x the number of items in cache when full. - // Tune this again based on metrics. - NumCounters: int64(float64(sizeInBytes) * 0.05 / 3), - MaxCost: int64(float64(sizeInBytes) * 0.95), - BufferItems: 64, - Metrics: true, - }) - if err != nil { - panic(err) - } - - observability.Must(meter.RegisterCallback(func(ctx context.Context, observer metric.Observer) error { - observer.ObserveInt64(queryCacheHitsCounter, int64(cache.Metrics.Hits())) - observer.ObserveInt64(queryCacheMissesCounter, int64(cache.Metrics.Misses())) - observer.ObserveInt64(queryCacheItemCountGauge, int64(cache.Metrics.KeysAdded()-cache.Metrics.KeysEvicted())) - observer.ObserveInt64(queryCacheSizeBytesGauge, int64(cache.Metrics.CostAdded()-cache.Metrics.CostEvicted())) - return nil - }, queryCacheHitsCounter, queryCacheMissesCounter, queryCacheItemCountGauge, queryCacheSizeBytesGauge)) - return &queryCache{ - cache: cache, - group: &singleflight.Group{}, - } -} - -// getOrLoad gets the key from cache if present. If absent, it looks up the key using the loadFn and puts it into cache before returning value. -func (c *queryCache) getOrLoad(ctx context.Context, key, queryName string, loadFn func(context.Context) (any, error)) (any, bool, error) { - if val, ok := c.cache.Get(key); ok { - return val, true, nil - } - - cached := true - val, err := c.group.Do(ctx, key, func(ctx context.Context) (interface{}, error) { - // check the cache again - if val, ok := c.cache.Get(key); ok { - return val, nil - } - - val, err := loadFn(ctx) - if err != nil { - return nil, err - } - - // only one caller of load gets this return value - cached = false - cachedObject := val.(*QueryResult) - attrs := attribute.NewSet( - attribute.String("query", queryName), - ) - queryCacheEntrySizeHistogram.Record(ctx, cachedObject.Bytes, metric.WithAttributeSet(attrs)) - c.cache.Set(key, cachedObject.Value, cachedObject.Bytes) - return cachedObject.Value, nil - }) - if err != nil { - return nil, false, err - } - - return val, cached, nil -} - -// nolint:unused // use in tests -func (c *queryCache) add(key, val any, cost int64) bool { - return c.cache.Set(key, val, cost) -} - -// nolint:unused // use in tests -func (c *queryCache) get(key any) (any, bool) { - return c.cache.Get(key) -} diff --git a/runtime/caches_test.go b/runtime/caches_test.go index 01276eda2ce..9545bfa4230 100644 --- a/runtime/caches_test.go +++ b/runtime/caches_test.go @@ -2,11 +2,8 @@ package runtime import ( "context" - "sync" "testing" - "time" - "github.com/c2h5oh/datasize" "github.com/google/uuid" _ "github.com/rilldata/rill/runtime/drivers/sqlite" "github.com/stretchr/testify/require" @@ -33,77 +30,3 @@ func TestConnectionCache(t *testing.T) { require.True(t, conn1 == conn2) require.False(t, conn2 == conn3) } - -func TestNilValues(t *testing.T) { - qc := newQueryCache(int64(datasize.MB * 100)) - defer qc.cache.Close() - - qc.add(queryCacheKey{"1", "1", "1"}.String(), "value", 1) - qc.cache.Wait() - v, ok := qc.get(queryCacheKey{"1", "1", "1"}.String()) - require.Equal(t, "value", v) - require.True(t, ok) - - qc.add(queryCacheKey{"1", "1", "1"}.String(), nil, 1) - qc.cache.Wait() - v, ok = qc.get(queryCacheKey{"1", "1", "1"}.String()) - require.Nil(t, v) - require.True(t, ok) - - v, ok = qc.get(queryCacheKey{"nosuch", "nosuch", "nosuch"}.String()) - require.Nil(t, v) - require.False(t, ok) -} - -func Test_queryCache_getOrLoad(t *testing.T) { - qc := newQueryCache(int64(datasize.MB)) - defer qc.cache.Close() - - f := func(ctx context.Context) (interface{}, error) { - for { - select { - case <-ctx.Done(): - // Handle context cancellation - return nil, ctx.Err() - case <-time.After(200 * time.Millisecond): - // Simulate some work - return &QueryResult{Value: "hello"}, nil - } - } - } - errs := make([]error, 5) - values := make([]interface{}, 5) - cached := make([]bool, 5) - var wg sync.WaitGroup - for i := 0; i < 5; i++ { - wg.Add(1) - go func(i int) { - var ctx context.Context - var cancel context.CancelFunc - if i%2 == 0 { - // cancel all even requests - ctx, cancel = context.WithTimeout(context.Background(), 100*time.Millisecond) - } else { - ctx, cancel = context.WithCancel(context.TODO()) - } - defer cancel() - defer wg.Done() - values[i], cached[i], errs[i] = qc.getOrLoad(ctx, "key", "query", f) - - }(i) - time.Sleep(10 * time.Millisecond) // ensure that first goroutine starts the work - } - wg.Wait() - - require.False(t, cached[0]) - require.Error(t, errs[0]) - for i := 1; i < 5; i++ { - if i%2 == 0 { - require.Error(t, errs[i]) - } else { - require.True(t, cached[i]) - require.NoError(t, errs[i]) - require.Equal(t, values[i], "hello") - } - } -} diff --git a/runtime/pkg/singleflight/singleflight.go b/runtime/pkg/singleflight/singleflight.go index 1e97b15b87b..66e4145cfd4 100644 --- a/runtime/pkg/singleflight/singleflight.go +++ b/runtime/pkg/singleflight/singleflight.go @@ -44,34 +44,34 @@ func newPanicError(v interface{}) error { } // call is an in-flight or completed singleflight.Do call -type call struct { +type call[V any] struct { ctx context.Context cancel context.CancelFunc counter uint - val interface{} + val V err error } // Group represents a class of work and forms a namespace in // which units of work can be executed with duplicate suppression. -type Group struct { - mu sync.Mutex // protects m - m map[string]*call // lazily initialized +type Group[K comparable, V any] struct { + mu sync.Mutex // protects m + m map[K]*call[V] // lazily initialized } // Do executes and returns the results of the given function, making // sure that only one execution is in-flight for a given key at a // time. If a duplicate comes in, the duplicate caller waits for the // original to complete and receives the same results. -func (g *Group) Do(ctx context.Context, key string, fn func(context.Context) (any, error)) (any, error) { +func (g *Group[K, V]) Do(ctx context.Context, key K, fn func(context.Context) (V, error)) (V, error) { g.mu.Lock() if g.m == nil { - g.m = make(map[string]*call) + g.m = make(map[K]*call[V]) } c, ok := g.m[key] if !ok { cctx, cancel := withCancelAndContextValues(ctx) - c = &call{ + c = &call[V]{ ctx: cctx, cancel: cancel, } @@ -97,7 +97,8 @@ func (g *Group) Do(ctx context.Context, key string, fn func(context.Context) (an g.mu.Unlock() if ctx.Err() != nil { - return nil, ctx.Err() + var empty V + return empty, ctx.Err() } pErr := &panicError{} @@ -110,7 +111,7 @@ func (g *Group) Do(ctx context.Context, key string, fn func(context.Context) (an } // doCall handles the single call for a key. -func (g *Group) doCall(c *call, key string, fn func(ctx context.Context) (interface{}, error)) { +func (g *Group[K, V]) doCall(c *call[V], key K, fn func(ctx context.Context) (V, error)) { normalReturn := false recovered := false diff --git a/runtime/pkg/singleflight/singleflight_test.go b/runtime/pkg/singleflight/singleflight_test.go index 3b762fdae50..e5bbe35dfec 100644 --- a/runtime/pkg/singleflight/singleflight_test.go +++ b/runtime/pkg/singleflight/singleflight_test.go @@ -15,8 +15,8 @@ import ( ) func TestDo(t *testing.T) { - var g Group - v, err := g.Do(context.Background(), "key", func(context.Context) (interface{}, error) { + var g Group[string, string] + v, err := g.Do(context.Background(), "key", func(context.Context) (string, error) { return "bar", nil }) if got, want := fmt.Sprintf("%v (%T)", v, v), "bar (string)"; got != want { @@ -28,9 +28,9 @@ func TestDo(t *testing.T) { } func TestDoErr(t *testing.T) { - var g Group + var g Group[string, any] someErr := errors.New("Some error") - v, err := g.Do(context.Background(), "key", func(context.Context) (interface{}, error) { + v, err := g.Do(context.Background(), "key", func(context.Context) (any, error) { return nil, someErr }) if err != someErr { @@ -42,11 +42,11 @@ func TestDoErr(t *testing.T) { } func TestDoDupSuppress(t *testing.T) { - var g Group + var g Group[string, string] var wg1, wg2 sync.WaitGroup c := make(chan string, 1) var calls int32 - fn := func(ctx context.Context) (interface{}, error) { + fn := func(ctx context.Context) (string, error) { if atomic.AddInt32(&calls, 1) == 1 { // First invocation. wg1.Done() @@ -72,7 +72,7 @@ func TestDoDupSuppress(t *testing.T) { t.Errorf("Do error: %v", err) return } - if s, _ := v.(string); s != "bar" { + if v != "bar" { t.Errorf("Do = %T %v; want %q", v, v, "bar") } }() @@ -90,7 +90,7 @@ func TestDoDupSuppress(t *testing.T) { // Test singleflight behaves correctly after Do panic. // See https://github.com/golang/go/issues/41133 func TestPanicDo(t *testing.T) { - var g Group + var g Group[string, any] fn := func(context.Context) (interface{}, error) { panic("invalid memory address or nil pointer dereference") } @@ -127,7 +127,7 @@ func TestPanicDo(t *testing.T) { } func TestGoexitDo(t *testing.T) { - var g Group + var g Group[string, any] fn := func(ctx context.Context) (interface{}, error) { runtime.Goexit() return nil, nil @@ -159,14 +159,14 @@ func TestGoexitDo(t *testing.T) { } func TestContextCancelledForSome(t *testing.T) { - s := Group{} + var g Group[string, string] - f := func(ctx context.Context) (interface{}, error) { + f := func(ctx context.Context) (string, error) { for { select { case <-ctx.Done(): // Handle context cancellation - return nil, ctx.Err() + return "", ctx.Err() case <-time.After(200 * time.Millisecond): // Simulate some work return "hello", nil @@ -189,9 +189,7 @@ func TestContextCancelledForSome(t *testing.T) { } defer cancel() defer wg.Done() - values[i], errs[i] = s.Do(ctx, "key", func(ctx context.Context) (interface{}, error) { - return f(ctx) - }) + values[i], errs[i] = g.Do(ctx, "key", f) }(i) time.Sleep(10 * time.Millisecond) // ensure that first goroutine starts the work } @@ -208,14 +206,14 @@ func TestContextCancelledForSome(t *testing.T) { } func TestContextCancelledForAll(t *testing.T) { - s := Group{} + var g Group[string, string] - f := func(ctx context.Context) (interface{}, error) { + f := func(ctx context.Context) (string, error) { for { select { case <-ctx.Done(): // Handle context cancellation - return nil, ctx.Err() + return "", ctx.Err() case <-time.After(200 * time.Millisecond): // Simulate some work return "hello", nil @@ -231,9 +229,7 @@ func TestContextCancelledForAll(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) defer cancel() defer wg.Done() - values[i], errs[i] = s.Do(ctx, "key", func(ctx context.Context) (interface{}, error) { - return f(ctx) - }) + values[i], errs[i] = g.Do(ctx, "key", f) }(i) } wg.Wait() diff --git a/runtime/queries/column_timeseries.go b/runtime/queries/column_timeseries.go index 7411d49f7d8..1345ce49f8d 100644 --- a/runtime/queries/column_timeseries.go +++ b/runtime/queries/column_timeseries.go @@ -78,17 +78,17 @@ func (q *ColumnTimeseries) Resolve(ctx context.Context, rt *runtime.Runtime, ins return fmt.Errorf("not available for dialect '%s'", olap.Dialect()) } - return olap.WithConnection(ctx, priority, func(ctx context.Context, ensuredCtx context.Context) error { - timeRange, err := q.resolveNormaliseTimeRange(ctx, rt, instanceID, priority) - if err != nil { - return err - } + timeRange, err := q.resolveNormaliseTimeRange(ctx, rt, instanceID, priority) + if err != nil { + return err + } - if timeRange.Interval == runtimev1.TimeGrain_TIME_GRAIN_UNSPECIFIED { - q.Result = &ColumnTimeseriesResult{} - return nil - } + if timeRange.Interval == runtimev1.TimeGrain_TIME_GRAIN_UNSPECIFIED { + q.Result = &ColumnTimeseriesResult{} + return nil + } + return olap.WithConnection(ctx, priority, func(ctx context.Context, ensuredCtx context.Context) error { filter, args, err := buildFilterClauseForMetricsViewFilter(q.Filters, olap.Dialect()) if err != nil { return err @@ -290,15 +290,13 @@ func (q *ColumnTimeseries) createTimestampRollupReduction( valueColumn string, ) ([]*runtimev1.TimeSeriesValue, error) { safeTimestampColumnName := safeName(timestampColumnName) - tc := &TableCardinality{ - TableName: tableName, - } - err := tc.Resolve(ctx, rt, instanceID, priority) + + rowCount, err := q.resolveRowCount(ctx, tableName, olap, priority) if err != nil { return nil, err } - if tc.Result < int64(q.Pixels*4) { + if rowCount < int64(q.Pixels*4) { rows, err := olap.Execute(ctx, &drivers.Statement{ Query: `SELECT ` + safeTimestampColumnName + ` as ts, "` + valueColumn + `" as count FROM "` + tableName + `"`, Priority: priority, @@ -425,6 +423,32 @@ func (q *ColumnTimeseries) createTimestampRollupReduction( return results, nil } +func (q *ColumnTimeseries) resolveRowCount(ctx context.Context, tableName string, olap drivers.OLAPStore, priority int) (int64, error) { + rows, err := olap.Execute(ctx, &drivers.Statement{ + Query: fmt.Sprintf("SELECT count(*) AS count FROM %s", safeName(tableName)), + Priority: priority, + }) + if err != nil { + return 0, err + } + defer rows.Close() + + var count int64 + for rows.Next() { + err = rows.Scan(&count) + if err != nil { + return 0, err + } + } + + err = rows.Err() + if err != nil { + return 0, err + } + + return count, nil +} + // normaliseMeasures is called before this method so measure.SqlName will be non empty func getExpressionColumnsFromMeasures(measures []*runtimev1.ColumnTimeSeriesRequest_BasicMeasure) string { var result string diff --git a/runtime/query.go b/runtime/query.go index e873590d12f..f00aa5520dd 100644 --- a/runtime/query.go +++ b/runtime/query.go @@ -4,6 +4,22 @@ import ( "context" "fmt" "strings" + + "github.com/dgraph-io/ristretto" + "github.com/rilldata/rill/runtime/pkg/observability" + "github.com/rilldata/rill/runtime/pkg/singleflight" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/metric" + "go.opentelemetry.io/otel/metric/global" +) + +var ( + meter = global.Meter("runtime") + queryCacheHitsCounter = observability.Must(meter.Int64ObservableCounter("query_cache.hits")) + queryCacheMissesCounter = observability.Must(meter.Int64ObservableCounter("query_cache.misses")) + queryCacheItemCountGauge = observability.Must(meter.Int64ObservableGauge("query_cache.items")) + queryCacheSizeBytesGauge = observability.Must(meter.Int64ObservableGauge("query_cache.size", metric.WithUnit("bytes"))) + queryCacheEntrySizeHistogram = observability.Must(meter.Int64Histogram("query_cache.entry_size", metric.WithUnit("bytes"))) ) type QueryResult struct { @@ -26,16 +42,6 @@ type Query interface { Resolve(ctx context.Context, rt *Runtime, instanceID string, priority int) error } -type queryCacheKey struct { - instanceID string - queryKey string - dependencyKey string -} - -func (q queryCacheKey) String() string { - return fmt.Sprintf("InstanceID:%sQueryKey:%sDependencyKey:%s", q.instanceID, q.queryKey, q.dependencyKey) -} - func (r *Runtime) Query(ctx context.Context, instanceID string, query Query, priority int) error { // If key is empty, skip caching qk := query.Key() @@ -73,27 +79,96 @@ func (r *Runtime) Query(ctx context.Context, instanceID string, query Query, pri instanceID: instanceID, queryKey: query.Key(), dependencyKey: depKey, + }.String() + + // Try to get from cache + if val, ok := r.queryCache.cache.Get(key); ok { + return query.UnmarshalResult(val) } - val, ok, err := r.queryCache.getOrLoad(ctx, key.String(), queryName(query), func(ctx context.Context) (any, error) { + // Load with singleflight + owner := false + val, err := r.queryCache.singleflight.Do(ctx, key, func(ctx context.Context) (any, error) { + // Try cache again + if val, ok := r.queryCache.cache.Get(key); ok { + return val, nil + } + + // Load err := query.Resolve(ctx, r, instanceID, priority) if err != nil { return nil, err } + owner = true res := query.MarshalResult() - return res, nil + r.queryCache.cache.Set(key, res.Value, res.Bytes) + queryCacheEntrySizeHistogram.Record(ctx, res.Bytes, metric.WithAttributes(attribute.String("query", queryName(query)))) + return res.Value, nil }) if err != nil { return err } - if ok { + if !owner { return query.UnmarshalResult(val) } return nil } +type queryCacheKey struct { + instanceID string + queryKey string + dependencyKey string +} + +func (k queryCacheKey) String() string { + return fmt.Sprintf("inst:%s deps:%s qry:%s", k.instanceID, k.dependencyKey, k.queryKey) +} + +type queryCache struct { + cache *ristretto.Cache + singleflight *singleflight.Group[string, any] + metrics metric.Registration +} + +func newQueryCache(sizeInBytes int64) *queryCache { + if sizeInBytes <= 100 { + panic(fmt.Sprintf("invalid cache size should be greater than 100: %v", sizeInBytes)) + } + cache, err := ristretto.NewCache(&ristretto.Config{ + // Use 5% of cache memory for storing counters. Each counter takes roughly 3 bytes. + // Recommended value is 10x the number of items in cache when full. + // Tune this again based on metrics. + NumCounters: int64(float64(sizeInBytes) * 0.05 / 3), + MaxCost: int64(float64(sizeInBytes) * 0.95), + BufferItems: 64, + Metrics: true, + }) + if err != nil { + panic(err) + } + + metrics := observability.Must(meter.RegisterCallback(func(ctx context.Context, observer metric.Observer) error { + observer.ObserveInt64(queryCacheHitsCounter, int64(cache.Metrics.Hits())) + observer.ObserveInt64(queryCacheMissesCounter, int64(cache.Metrics.Misses())) + observer.ObserveInt64(queryCacheItemCountGauge, int64(cache.Metrics.KeysAdded()-cache.Metrics.KeysEvicted())) + observer.ObserveInt64(queryCacheSizeBytesGauge, int64(cache.Metrics.CostAdded()-cache.Metrics.CostEvicted())) + return nil + }, queryCacheHitsCounter, queryCacheMissesCounter, queryCacheItemCountGauge, queryCacheSizeBytesGauge)) + + return &queryCache{ + cache: cache, + singleflight: &singleflight.Group[string, any]{}, + metrics: metrics, + } +} + +func (c *queryCache) close() error { + c.cache.Close() + return c.metrics.Unregister() +} + func queryName(q Query) string { nameWithPkg := fmt.Sprintf("%T", q) _, after, _ := strings.Cut(nameWithPkg, ".") diff --git a/runtime/runtime.go b/runtime/runtime.go index 18a7b7d8ada..ad9bb190ca0 100644 --- a/runtime/runtime.go +++ b/runtime/runtime.go @@ -2,6 +2,7 @@ package runtime import ( "context" + "errors" "fmt" "math" @@ -55,11 +56,9 @@ func New(opts *Options, logger *zap.Logger) (*Runtime, error) { } func (r *Runtime) Close() error { - err1 := r.metastore.Close() - err2 := r.connCache.Close() - r.queryCache.cache.Close() - if err1 != nil { - return err1 - } - return err2 + return errors.Join( + r.metastore.Close(), + r.connCache.Close(), + r.queryCache.close(), + ) }