From 37736729c40f1e7ad18521d1eb5c2a13e314ebde Mon Sep 17 00:00:00 2001 From: Brad Davidson Date: Fri, 26 Jan 2024 01:09:47 +0000 Subject: [PATCH] Fix handling of count with non-zero revision Signed-off-by: Brad Davidson --- pkg/drivers/generic/generic.go | 26 +++++++++++++++++++++---- pkg/drivers/nats/backend.go | 4 ++-- pkg/drivers/nats/backend_test.go | 6 +++--- pkg/drivers/nats/kv.go | 31 +++++++++++++++++++++++------- pkg/drivers/nats/logger.go | 8 ++++---- pkg/logstructured/logstructured.go | 8 ++++---- pkg/logstructured/sqllog/sql.go | 8 ++++++-- pkg/server/list.go | 21 ++++++++++++-------- pkg/server/types.go | 5 +++-- 9 files changed, 81 insertions(+), 36 deletions(-) diff --git a/pkg/drivers/generic/generic.go b/pkg/drivers/generic/generic.go index 481f051e..42e99a51 100644 --- a/pkg/drivers/generic/generic.go +++ b/pkg/drivers/generic/generic.go @@ -91,7 +91,8 @@ type Generic struct { RevisionSQL string ListRevisionStartSQL string GetRevisionAfterSQL string - CountSQL string + CountCurrentSQL string + CountRevisionSQL string AfterSQL string DeleteSQL string CompactSQL string @@ -219,12 +220,18 @@ func Open(ctx context.Context, driverName, dataSourceName string, connPoolConfig ListRevisionStartSQL: q(fmt.Sprintf(listSQL, "AND mkv.id <= ?"), paramCharacter, numbered), GetRevisionAfterSQL: q(fmt.Sprintf(listSQL, idOfKey), paramCharacter, numbered), - CountSQL: q(fmt.Sprintf(` + CountCurrentSQL: q(fmt.Sprintf(` SELECT (%s), COUNT(c.theid) FROM ( %s ) c`, revSQL, fmt.Sprintf(listSQL, "")), paramCharacter, numbered), + CountRevisionSQL: q(fmt.Sprintf(` + SELECT (%s), COUNT(c.theid) + FROM ( + %s + ) c`, revSQL, fmt.Sprintf(listSQL, "AND mkv.id <= ?")), paramCharacter, numbered), + AfterSQL: q(fmt.Sprintf(` SELECT (%s), (%s), %s FROM kine AS kv @@ -360,13 +367,24 @@ func (d *Generic) List(ctx context.Context, prefix, startKey string, limit, revi return d.query(ctx, sql, prefix, revision, startKey, revision, includeDeleted) } -func (d *Generic) Count(ctx context.Context, prefix string) (int64, int64, error) { +func (d *Generic) CountCurrent(ctx context.Context, prefix string) (int64, int64, error) { + var ( + rev sql.NullInt64 + id int64 + ) + + row := d.queryRow(ctx, d.CountCurrentSQL, prefix, false) + err := row.Scan(&rev, &id) + return rev.Int64, id, err +} + +func (d *Generic) Count(ctx context.Context, prefix string, revision int64) (int64, int64, error) { var ( rev sql.NullInt64 id int64 ) - row := d.queryRow(ctx, d.CountSQL, prefix, false) + row := d.queryRow(ctx, d.CountRevisionSQL, prefix, revision, false) err := row.Scan(&rev, &id) return rev.Int64, id, err } diff --git a/pkg/drivers/nats/backend.go b/pkg/drivers/nats/backend.go index c96fa902..5cbbcdb0 100644 --- a/pkg/drivers/nats/backend.go +++ b/pkg/drivers/nats/backend.go @@ -135,8 +135,8 @@ func (b *Backend) CurrentRevision(ctx context.Context) (int64, error) { } // Count returns an exact count of the number of matching keys and the current revision of the database. -func (b *Backend) Count(ctx context.Context, prefix string) (int64, int64, error) { - count, err := b.kv.Count(ctx, prefix) +func (b *Backend) Count(ctx context.Context, prefix string, revision int64) (int64, int64, error) { + count, err := b.kv.Count(ctx, prefix, revision) if err != nil { return 0, 0, err } diff --git a/pkg/drivers/nats/backend_test.go b/pkg/drivers/nats/backend_test.go index eea2bbd3..e5ba3f33 100644 --- a/pkg/drivers/nats/backend_test.go +++ b/pkg/drivers/nats/backend_test.go @@ -129,14 +129,14 @@ func TestBackend_Create(t *testing.T) { time.Sleep(2 * time.Millisecond) - srev, count, err := b.Count(ctx, "/") + srev, count, err := b.Count(ctx, "/", 0) noErr(t, err) expEqual(t, 4, srev) expEqual(t, 4, count) time.Sleep(time.Second) - srev, count, err = b.Count(ctx, "/") + srev, count, err = b.Count(ctx, "/", 0) noErr(t, err) expEqual(t, 4, srev) expEqual(t, 3, count) @@ -149,7 +149,7 @@ func TestBackend_Create(t *testing.T) { time.Sleep(2 * time.Millisecond) - srev, count, err = b.Count(ctx, "/") + srev, count, err = b.Count(ctx, "/", 0) noErr(t, err) expEqual(t, 6, srev) expEqual(t, 4, count) diff --git a/pkg/drivers/nats/kv.go b/pkg/drivers/nats/kv.go index 9a67eebb..3949f8fe 100644 --- a/pkg/drivers/nats/kv.go +++ b/pkg/drivers/nats/kv.go @@ -376,7 +376,7 @@ type keySeq struct { seq uint64 } -func (e *KeyValue) Count(ctx context.Context, prefix string) (int64, error) { +func (e *KeyValue) Count(ctx context.Context, prefix string, revision int64) (int64, error) { it := e.bt.Iter() if prefix != "" { @@ -396,11 +396,27 @@ func (e *KeyValue) Count(ctx context.Context, prefix string) (int64, error) { break } v := it.Value() - so := v[len(v)-1] - if so.op == jetstream.KeyValuePut { - if so.ex.IsZero() || so.ex.After(now) { - count++ + // Get the latest update for the key. + if revision <= 0 { + so := v[len(v)-1] + if so.op == jetstream.KeyValuePut { + if so.ex.IsZero() || so.ex.After(now) { + count++ + } + } + } else { + // Find the latest update below the given revision. + for i := len(v) - 1; i >= 0; i-- { + so := v[i] + if so.seq <= uint64(revision) { + if so.op == jetstream.KeyValuePut { + if so.ex.IsZero() || so.ex.After(now) { + count++ + } + } + break + } } } @@ -429,6 +445,7 @@ func (e *KeyValue) List(ctx context.Context, prefix, startKey string, limit, rev } var matches []*keySeq + now := time.Now() e.btm.RLock() @@ -448,7 +465,7 @@ func (e *KeyValue) List(ctx context.Context, prefix, startKey string, limit, rev if revision <= 0 { so := v[len(v)-1] if so.op == jetstream.KeyValuePut { - if so.ex.IsZero() || so.ex.After(time.Now()) { + if so.ex.IsZero() || so.ex.After(now) { matches = append(matches, &keySeq{key: k, seq: so.seq}) } } @@ -458,7 +475,7 @@ func (e *KeyValue) List(ctx context.Context, prefix, startKey string, limit, rev so := v[i] if so.seq <= uint64(revision) { if so.op == jetstream.KeyValuePut { - if so.ex.IsZero() || so.ex.After(time.Now()) { + if so.ex.IsZero() || so.ex.After(now) { matches = append(matches, &keySeq{key: k, seq: so.seq}) } } diff --git a/pkg/drivers/nats/logger.go b/pkg/drivers/nats/logger.go index f794fc66..913cf7d0 100644 --- a/pkg/drivers/nats/logger.go +++ b/pkg/drivers/nats/logger.go @@ -81,15 +81,15 @@ func (b *BackendLogger) List(ctx context.Context, prefix, startKey string, limit } // Count returns an exact count of the number of matching keys and the current revision of the database -func (b *BackendLogger) Count(ctx context.Context, prefix string) (revRet int64, count int64, err error) { +func (b *BackendLogger) Count(ctx context.Context, prefix string, revision int64) (revRet int64, count int64, err error) { start := time.Now() defer func() { dur := time.Since(start) - fStr := "COUNT %s => rev=%d, count=%d, err=%v, duration=%s" - b.logMethod(dur, fStr, prefix, revRet, count, err, dur) + fStr := "COUNT %s, rev=%d => rev=%d, count=%d, err=%v, duration=%s" + b.logMethod(dur, fStr, prefix, revision, revRet, count, err, dur) }() - return b.backend.Count(ctx, prefix) + return b.backend.Count(ctx, prefix, revision) } func (b *BackendLogger) Update(ctx context.Context, key string, value []byte, revision, lease int64) (revRet int64, kvRet *server.KeyValue, updateRet bool, errRet error) { diff --git a/pkg/logstructured/logstructured.go b/pkg/logstructured/logstructured.go index 90a71480..982bcbcd 100644 --- a/pkg/logstructured/logstructured.go +++ b/pkg/logstructured/logstructured.go @@ -22,7 +22,7 @@ type Log interface { List(ctx context.Context, prefix, startKey string, limit, revision int64, includeDeletes bool) (int64, []*server.Event, error) After(ctx context.Context, prefix string, revision, limit int64) (int64, []*server.Event, error) Watch(ctx context.Context, prefix string) <-chan []*server.Event - Count(ctx context.Context, prefix string) (int64, int64, error) + Count(ctx context.Context, prefix string, revision int64) (int64, int64, error) Append(ctx context.Context, event *server.Event) (int64, error) DbSize(ctx context.Context) (int64, error) } @@ -198,11 +198,11 @@ func (l *LogStructured) List(ctx context.Context, prefix, startKey string, limit return rev, kvs, nil } -func (l *LogStructured) Count(ctx context.Context, prefix string) (revRet int64, count int64, err error) { +func (l *LogStructured) Count(ctx context.Context, prefix string, revision int64) (revRet int64, count int64, err error) { defer func() { - logrus.Tracef("COUNT %s => rev=%d, count=%d, err=%v", prefix, revRet, count, err) + logrus.Tracef("COUNT %s, rev=%d => rev=%d, count=%d, err=%v", prefix, revision, revRet, count, err) }() - rev, count, err := l.log.Count(ctx, prefix) + rev, count, err := l.log.Count(ctx, prefix, revision) if err != nil { return 0, 0, err } diff --git a/pkg/logstructured/sqllog/sql.go b/pkg/logstructured/sqllog/sql.go index 59cf6ebf..4fe0e41e 100644 --- a/pkg/logstructured/sqllog/sql.go +++ b/pkg/logstructured/sqllog/sql.go @@ -524,11 +524,15 @@ func canSkipRevision(rev, skip int64, skipTime time.Time) bool { return rev == skip && time.Since(skipTime) > time.Second } -func (s *SQLLog) Count(ctx context.Context, prefix string) (int64, int64, error) { +func (s *SQLLog) Count(ctx context.Context, prefix string, revision int64) (int64, int64, error) { if strings.HasSuffix(prefix, "/") { prefix += "%" } - return s.d.Count(ctx, prefix) + + if revision == 0 { + return s.d.CountCurrent(ctx, prefix) + } + return s.d.Count(ctx, prefix, revision) } func (s *SQLLog) Append(ctx context.Context, event *server.Event) (int64, error) { diff --git a/pkg/server/list.go b/pkg/server/list.go index e256e415..9c66b528 100644 --- a/pkg/server/list.go +++ b/pkg/server/list.go @@ -20,13 +20,14 @@ func (l *LimitedServer) list(ctx context.Context, r *etcdserverpb.RangeRequest) prefix = prefix + "/" } start := string(bytes.TrimRight(r.Key, "\x00")) + revision := r.Revision if r.CountOnly { - rev, count, err := l.backend.Count(ctx, prefix) + rev, count, err := l.backend.Count(ctx, prefix, revision) if err != nil { return nil, err } - logrus.Tracef("LIST COUNT key=%s, end=%s, revision=%d, currentRev=%d count=%d", r.Key, r.RangeEnd, r.Revision, rev, count) + logrus.Tracef("LIST COUNT key=%s, end=%s, revision=%d, currentRev=%d count=%d", r.Key, r.RangeEnd, revision, rev, count) return &RangeResponse{ Header: txnHeader(rev), Count: count, @@ -38,29 +39,33 @@ func (l *LimitedServer) list(ctx context.Context, r *etcdserverpb.RangeRequest) limit++ } - rev, kvs, err := l.backend.List(ctx, prefix, start, limit, r.Revision) + rev, kvs, err := l.backend.List(ctx, prefix, start, limit, revision) if err != nil { return nil, err } - logrus.Tracef("LIST key=%s, end=%s, revision=%d, currentRev=%d count=%d, limit=%d", r.Key, r.RangeEnd, r.Revision, rev, len(kvs), r.Limit) + logrus.Tracef("LIST key=%s, end=%s, revision=%d, currentRev=%d count=%d, limit=%d", r.Key, r.RangeEnd, revision, rev, len(kvs), r.Limit) resp := &RangeResponse{ Header: txnHeader(rev), Count: int64(len(kvs)), Kvs: kvs, } + // count the actual number of results if there are more items in the db. if limit > 0 && resp.Count > r.Limit { resp.More = true resp.Kvs = kvs[0 : limit-1] - // count the actual number of results if there are more items in the db. - _, count, err := l.backend.Count(ctx, prefix) + if revision == 0 { + revision = rev + } + + rev, resp.Count, err = l.backend.Count(ctx, prefix, revision) if err != nil { return nil, err } - logrus.Tracef("LIST COUNT key=%s, end=%s, revision=%d, currentRev=%d count=%d", r.Key, r.RangeEnd, r.Revision, rev, count) - resp.Count = count + logrus.Tracef("LIST COUNT key=%s, end=%s, revision=%d, currentRev=%d count=%d", r.Key, r.RangeEnd, revision, rev, resp.Count) + resp.Header = txnHeader(rev) } return resp, nil diff --git a/pkg/server/types.go b/pkg/server/types.go index 989f71d7..8b040829 100644 --- a/pkg/server/types.go +++ b/pkg/server/types.go @@ -23,7 +23,7 @@ type Backend interface { Create(ctx context.Context, key string, value []byte, lease int64) (int64, error) Delete(ctx context.Context, key string, revision int64) (int64, *KeyValue, bool, error) List(ctx context.Context, prefix, startKey string, limit, revision int64) (int64, []*KeyValue, error) - Count(ctx context.Context, prefix string) (int64, int64, error) + Count(ctx context.Context, prefix string, revision int64) (int64, int64, error) Update(ctx context.Context, key string, value []byte, revision, lease int64) (int64, *KeyValue, bool, error) Watch(ctx context.Context, key string, revision int64) WatchResult DbSize(ctx context.Context) (int64, error) @@ -33,7 +33,8 @@ type Backend interface { type Dialect interface { ListCurrent(ctx context.Context, prefix string, limit int64, includeDeleted bool) (*sql.Rows, error) List(ctx context.Context, prefix, startKey string, limit, revision int64, includeDeleted bool) (*sql.Rows, error) - Count(ctx context.Context, prefix string) (int64, int64, error) + CountCurrent(ctx context.Context, prefix string) (int64, int64, error) + Count(ctx context.Context, prefix string, revision int64) (int64, int64, error) CurrentRevision(ctx context.Context) (int64, error) After(ctx context.Context, prefix string, rev, limit int64) (*sql.Rows, error) Insert(ctx context.Context, key string, create, delete bool, createRevision, previousRevision int64, ttl int64, value, prevValue []byte) (int64, error)