From 4f1538dfa3dd2983cff9dcd152b5e7f73b73464c Mon Sep 17 00:00:00 2001 From: pirosiki197 Date: Sat, 1 Feb 2025 20:57:24 +0900 Subject: [PATCH] feat(cache): enhance cache strategy of transaction --- template/cache.go | 56 +++++++++++++++++-- template/cache.tmpl | 56 +++++++++++++++++-- template/driver.go | 39 ++++++++------ template/driver.tmpl | 39 ++++++++------ template/stmt.go | 124 ++++++++++++++++++++++++++++--------------- template/stmt.tmpl | 124 ++++++++++++++++++++++++++++--------------- test/cache/cache.go | 56 +++++++++++++++++-- test/cache/driver.go | 39 ++++++++------ test/cache/reset.go | 2 +- test/cache/stmt.go | 124 ++++++++++++++++++++++++++++--------------- test/cache_test.go | 124 +++++++++++++++++++++++++++++++++++++++++-- 11 files changed, 588 insertions(+), 195 deletions(-) diff --git a/template/cache.go b/template/cache.go index 2f04e23..3ccf9c4 100644 --- a/template/cache.go +++ b/template/cache.go @@ -5,8 +5,41 @@ import ( "database/sql/driver" "fmt" "strings" + "sync" + "sync/atomic" + "time" + + "github.com/motoki317/sc" + "github.com/traP-jp/isuc/domains" ) +type cacheWithInfo struct { + *sc.Cache[string, *cacheRows] + query string + info domains.CachePlanSelectQuery + uniqueOnly bool // if true, query is like "SELECT * FROM table WHERE pk = ?" + lastUpdate atomic.Int64 // time.Time.UnixNano() + lastUpdateByKey syncMap[int64] +} + +func (c *cacheWithInfo) updateTx() { + c.lastUpdate.Store(time.Now().UnixNano()) +} + +func (c *cacheWithInfo) updateByKeyTx(key string) { + c.lastUpdateByKey.Store(key, time.Now().UnixNano()) +} + +func (c *cacheWithInfo) isNewerThan(key string, t int64) bool { + if c.lastUpdate.Load() > t { + return true + } + if update, ok := c.lastUpdateByKey.Load(key); ok && update > t { + return true + } + return false +} + type ( queryKey struct{} stmtKey struct{} @@ -18,7 +51,7 @@ type ( func ExportMetrics() string { res := "" for query, cache := range caches { - stats := cache.cache.Stats() + stats := cache.Stats() progress := "[" for i := 0; i < 20; i++ { if i < int(stats.HitRatio()*20) { @@ -43,7 +76,7 @@ type CacheStats struct { func ExportCacheStats() map[string]CacheStats { res := make(map[string]CacheStats) for query, cache := range caches { - stats := cache.cache.Stats() + stats := cache.Stats() res[query] = CacheStats{ Query: query, HitRatio: stats.HitRatio(), @@ -56,7 +89,7 @@ func ExportCacheStats() map[string]CacheStats { func PurgeAllCaches() { for _, cache := range caches { - cache.cache.Purge() + cache.Purge() } } @@ -109,3 +142,20 @@ func replaceFn(ctx context.Context, key string) (*cacheRows, error) { } return cacheRows.clone(), nil } + +type syncMap[T any] struct { + m sync.Map +} + +func (m *syncMap[T]) Load(key string) (T, bool) { + var zero T + v, ok := m.m.Load(key) + if !ok { + return zero, false + } + return v.(T), true +} + +func (m *syncMap[T]) Store(key string, value T) { + m.m.Store(key, value) +} diff --git a/template/cache.tmpl b/template/cache.tmpl index db4f6d8..90f69fe 100644 --- a/template/cache.tmpl +++ b/template/cache.tmpl @@ -5,8 +5,41 @@ import ( "database/sql/driver" "fmt" "strings" + "sync" + "sync/atomic" + "time" + + "github.com/motoki317/sc" + "github.com/traP-jp/isuc/domains" ) +type cacheWithInfo struct { + *sc.Cache[string, *cacheRows] + query string + info domains.CachePlanSelectQuery + uniqueOnly bool // if true, query is like "SELECT * FROM table WHERE pk = ?" + lastUpdate atomic.Int64 // time.Time.UnixNano() + lastUpdateByKey syncMap[int64] +} + +func (c *cacheWithInfo) updateTx() { + c.lastUpdate.Store(time.Now().UnixNano()) +} + +func (c *cacheWithInfo) updateByKeyTx(key string) { + c.lastUpdateByKey.Store(key, time.Now().UnixNano()) +} + +func (c *cacheWithInfo) isNewerThan(key string, t int64) bool { + if c.lastUpdate.Load() > t { + return true + } + if update, ok := c.lastUpdateByKey.Load(key); ok && update > t { + return true + } + return false +} + type ( queryKey struct{} stmtKey struct{} @@ -18,7 +51,7 @@ type ( func ExportMetrics() string { res := "" for query, cache := range caches { - stats := cache.cache.Stats() + stats := cache.Stats() progress := "[" for i := 0; i < 20; i++ { if i < int(stats.HitRatio()*20) { @@ -43,7 +76,7 @@ type CacheStats struct { func ExportCacheStats() map[string]CacheStats { res := make(map[string]CacheStats) for query, cache := range caches { - stats := cache.cache.Stats() + stats := cache.Stats() res[query] = CacheStats{ Query: query, HitRatio: stats.HitRatio(), @@ -56,7 +89,7 @@ func ExportCacheStats() map[string]CacheStats { func PurgeAllCaches() { for _, cache := range caches { - cache.cache.Purge() + cache.Purge() } } @@ -109,3 +142,20 @@ func replaceFn(ctx context.Context, key string) (*cacheRows, error) { } return cacheRows.clone(), nil } + +type syncMap[T any] struct { + m sync.Map +} + +func (m *syncMap[T]) Load(key string) (T, bool) { + var zero T + v, ok := m.m.Load(key) + if !ok { + return zero, false + } + return v.(T), true +} + +func (m *syncMap[T]) Store(key string, value T) { + m.m.Store(key, value) +} diff --git a/template/driver.go b/template/driver.go index da7a84c..7b2ba75 100644 --- a/template/driver.go +++ b/template/driver.go @@ -51,18 +51,18 @@ func init() { conditions := query.Select.Conditions if isSingleUniqueCondition(conditions, query.Select.Table) { - caches[normalized] = cacheWithInfo{ + caches[normalized] = &cacheWithInfo{ + Cache: sc.NewMust(replaceFn, 10*time.Minute, 10*time.Minute), query: normalized, info: *query.Select, - cache: sc.NewMust(replaceFn, 10*time.Minute, 10*time.Minute), uniqueOnly: true, } continue } - caches[query.Query] = cacheWithInfo{ + caches[query.Query] = &cacheWithInfo{ + Cache: sc.NewMust(replaceFn, 10*time.Minute, 10*time.Minute), query: query.Query, info: *query.Select, - cache: sc.NewMust(replaceFn, 10*time.Minute, 10*time.Minute), uniqueOnly: false, } @@ -105,7 +105,8 @@ var ( type cacheConn struct { inner driver.Conn tx bool - cleanUp []func() + txStart int64 // time.Time.UnixNano() + cleanUp cleanUpTask } func (c *cacheConn) Prepare(rawQuery string) (driver.Stmt, error) { @@ -148,23 +149,26 @@ func (c *cacheConn) Begin() (driver.Tx, error) { return nil, err } c.tx = true + c.txStart = time.Now().UnixNano() return &cacheTx{conn: c, inner: inner}, nil } func (c *cacheConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) { + var inner driver.Tx + var err error if i, ok := c.inner.(driver.ConnBeginTx); ok { - inner, err := i.BeginTx(ctx, opts) + inner, err = i.BeginTx(ctx, opts) + if err != nil { + return nil, err + } + } else { + inner, err = c.inner.Begin() if err != nil { return nil, err } - c.tx = true - return &cacheTx{conn: c, inner: inner}, nil - } - inner, err := c.inner.Begin() - if err != nil { - return nil, err } c.tx = true + c.txStart = time.Now().UnixNano() return &cacheTx{conn: c, inner: inner}, nil } @@ -185,10 +189,13 @@ type cacheTx struct { func (t *cacheTx) Commit() error { t.conn.tx = false defer func() { - for _, c := range t.conn.cleanUp { - c() + for _, c := range t.conn.cleanUp.purge { + c.Purge() + } + for _, forget := range t.conn.cleanUp.forget { + forget.cache.Forget(forget.key) } - t.conn.cleanUp = t.conn.cleanUp[:0] + t.conn.cleanUp.reset() }() return t.inner.Commit() } @@ -196,7 +203,7 @@ func (t *cacheTx) Commit() error { func (t *cacheTx) Rollback() error { t.conn.tx = false // no need to clean up - t.conn.cleanUp = nil + t.conn.cleanUp.reset() return t.inner.Rollback() } diff --git a/template/driver.tmpl b/template/driver.tmpl index 391f28e..ad29c21 100644 --- a/template/driver.tmpl +++ b/template/driver.tmpl @@ -50,18 +50,18 @@ func init() { conditions := query.Select.Conditions if isSingleUniqueCondition(conditions, query.Select.Table) { - caches[normalized] = cacheWithInfo{ + caches[normalized] = &cacheWithInfo{ + Cache: sc.NewMust(replaceFn, 10*time.Minute, 10*time.Minute), query: normalized, info: *query.Select, - cache: sc.NewMust(replaceFn, 10*time.Minute, 10*time.Minute), uniqueOnly: true, } continue } - caches[query.Query] = cacheWithInfo{ + caches[query.Query] = &cacheWithInfo{ + Cache: sc.NewMust(replaceFn, 10*time.Minute, 10*time.Minute), query: query.Query, info: *query.Select, - cache: sc.NewMust(replaceFn, 10*time.Minute, 10*time.Minute), uniqueOnly: false, } @@ -104,7 +104,8 @@ var ( type cacheConn struct { inner driver.Conn tx bool - cleanUp []func() + txStart int64 // time.Time.UnixNano() + cleanUp cleanUpTask } func (c *cacheConn) Prepare(rawQuery string) (driver.Stmt, error) { @@ -147,23 +148,26 @@ func (c *cacheConn) Begin() (driver.Tx, error) { return nil, err } c.tx = true + c.txStart = time.Now().UnixNano() return &cacheTx{conn: c, inner: inner}, nil } func (c *cacheConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) { + var inner driver.Tx + var err error if i, ok := c.inner.(driver.ConnBeginTx); ok { - inner, err := i.BeginTx(ctx, opts) + inner, err = i.BeginTx(ctx, opts) + if err != nil { + return nil, err + } + } else { + inner, err = c.inner.Begin() if err != nil { return nil, err } - c.tx = true - return &cacheTx{conn: c, inner: inner}, nil - } - inner, err := c.inner.Begin() - if err != nil { - return nil, err } c.tx = true + c.txStart = time.Now().UnixNano() return &cacheTx{conn: c, inner: inner}, nil } @@ -184,10 +188,13 @@ type cacheTx struct { func (t *cacheTx) Commit() error { t.conn.tx = false defer func() { - for _, c := range t.conn.cleanUp { - c() + for _, c := range t.conn.cleanUp.purge { + c.Purge() + } + for _, forget := range t.conn.cleanUp.forget { + forget.cache.Forget(forget.key) } - t.conn.cleanUp = t.conn.cleanUp[:0] + t.conn.cleanUp.reset() }() return t.inner.Commit() } @@ -195,7 +202,7 @@ func (t *cacheTx) Commit() error { func (t *cacheTx) Rollback() error { t.conn.tx = false // no need to clean up - t.conn.cleanUp = nil + t.conn.cleanUp.reset() return t.inner.Rollback() } diff --git a/template/stmt.go b/template/stmt.go index 8b201bc..f8d9ced 100644 --- a/template/stmt.go +++ b/template/stmt.go @@ -6,22 +6,14 @@ import ( "log" "slices" - "github.com/motoki317/sc" "github.com/traP-jp/isuc/domains" "github.com/traP-jp/isuc/normalizer" ) -type cacheWithInfo struct { - query string - info domains.CachePlanSelectQuery - uniqueOnly bool // if true, query is like "SELECT * FROM table WHERE pk = ?" - cache *sc.Cache[string, *cacheRows] -} - // NOTE: no write happens to this map, so it's safe to use in concurrent environment -var caches = make(map[string]cacheWithInfo) +var caches = make(map[string]*cacheWithInfo) -var cacheByTable = make(map[string][]cacheWithInfo) +var cacheByTable = make(map[string][]*cacheWithInfo) var _ driver.Stmt = &customCacheStatement{} @@ -101,10 +93,19 @@ func (c *cacheConn) ExecContext(ctx context.Context, rawQuery string, nvargs []d } if !c.tx { - for _, cleanUp := range c.cleanUp { - cleanUp() + c.cleanUp.do() + c.cleanUp.reset() + } else { + // cleanups are deferred until the transaction is committed + // because we need to forget the cache only if the transaction is committed + + // update the cache + for _, c := range c.cleanUp.purge { + c.updateTx() + } + for _, forget := range c.cleanUp.forget { + forget.cache.updateByKeyTx(forget.key) } - c.cleanUp = c.cleanUp[:0] } return res, err @@ -117,7 +118,7 @@ func (c *cacheConn) execInsert(ctx context.Context, rawQuery string, queryInfo d } cleanUp := handleInsertQuery(queryInfo.Query, *queryInfo.Insert, args) - c.cleanUp = append(c.cleanUp, cleanUp...) + c.cleanUp.append(cleanUp) return inner.ExecContext(ctx, rawQuery, nvargs) } @@ -126,7 +127,7 @@ func (c *cacheConn) execUpdate(ctx context.Context, rawQuery string, queryInfo d args := namedToValue(nvargs) cleanUp := handleUpdateQuery(*queryInfo.Update, args) - c.cleanUp = append(c.cleanUp, cleanUp...) + c.cleanUp.append(cleanUp) return inner.ExecContext(ctx, rawQuery, nvargs) } @@ -135,7 +136,7 @@ func (c *cacheConn) execDelete(ctx context.Context, rawQuery string, queryInfo d args := namedToValue(nvargs) cleanUp := handleDeleteQuery(*queryInfo.Delete, args) - c.cleanUp = append(c.cleanUp, cleanUp...) + c.cleanUp.append(cleanUp) return inner.ExecContext(ctx, rawQuery, nvargs) } @@ -150,7 +151,7 @@ func (s *customCacheStatement) Query(args []driver.Value) (driver.Rows, error) { return s.inQuery(args) } - rows, err := caches[cacheName(s.query)].cache.Get(ctx, cacheKey(args)) + rows, err := caches[cacheName(s.query)].Get(ctx, cacheKey(args)) if err != nil { return nil, err } @@ -169,7 +170,7 @@ func (s *customCacheStatement) inQuery(args []driver.Value) (driver.Rows, error) var cache *cacheWithInfo for _, c := range cacheByTable[table] { if len(c.info.Conditions) == 1 && c.info.Conditions[0].Column == s.queryInfo.Select.Conditions[0].Column && c.info.Conditions[0].Operator == domains.CachePlanOperator_EQ { - cache = &c + cache = c } } if cache == nil { @@ -185,7 +186,7 @@ func (s *customCacheStatement) inQuery(args []driver.Value) (driver.Rows, error) } ctx := context.WithValue(context.Background(), stmtKey{}, stmt) ctx = context.WithValue(ctx, argsKey{}, []driver.Value{condValue}) - rows, err := cache.cache.Get(ctx, cacheKey([]driver.Value{condValue})) + rows, err := cache.Get(ctx, cacheKey([]driver.Value{condValue})) if err != nil { return nil, err } @@ -201,10 +202,6 @@ func (c *cacheConn) QueryContext(ctx context.Context, rawQuery string, nvargs [] return nil, driver.ErrSkip } - if c.tx { - return inner.QueryContext(ctx, rawQuery, nvargs) - } - normalizedQuery := normalizer.NormalizeQuery(rawQuery) queryInfo, ok := queryMap[normalizedQuery] @@ -226,11 +223,19 @@ func (c *cacheConn) QueryContext(ctx context.Context, rawQuery string, nvargs [] args[i] = nv.Value } - cache := caches[queryInfo.Query].cache + cache := caches[queryInfo.Query] + key := cacheKey(args) + + if c.tx && cache.isNewerThan(key, c.txStart) { + // cache is newer than the transaction start time + // we should not use the cache + return inner.QueryContext(ctx, rawQuery, nvargs) + } + cachectx := context.WithValue(ctx, namedValueArgsKey{}, nvargs) cachectx = context.WithValue(cachectx, queryerCtxKey{}, inner) cachectx = context.WithValue(cachectx, queryKey{}, rawQuery) - rows, err := cache.Get(cachectx, cacheKey(args)) + rows, err := cache.Get(cachectx, key) if err != nil { return nil, err } @@ -252,7 +257,7 @@ func (c *cacheConn) inQuery(ctx context.Context, query string, args []driver.Nam var cache *cacheWithInfo for _, c := range cacheByTable[table] { if len(c.info.Conditions) == 1 && c.info.Conditions[0].Column == queryInfo.Select.Conditions[0].Column && c.info.Conditions[0].Operator == domains.CachePlanOperator_EQ { - cache = &c + cache = c } } if cache == nil { @@ -265,7 +270,7 @@ func (c *cacheConn) inQuery(ctx context.Context, query string, args []driver.Nam cacheCtx := context.WithValue(ctx, queryKey{}, cache.query) cacheCtx = context.WithValue(cacheCtx, queryerCtxKey{}, inner) cacheCtx = context.WithValue(cacheCtx, namedValueArgsKey{}, nvargs) - rows, err := cache.cache.Get(cacheCtx, cacheKey([]driver.Value{condValue.Value})) + rows, err := cache.Get(cacheCtx, cacheKey([]driver.Value{condValue.Value})) if err != nil { return nil, err } @@ -275,7 +280,7 @@ func (c *cacheConn) inQuery(ctx context.Context, query string, args []driver.Nam return mergeCachedRows(allRows), nil } -func handleInsertQuery(query string, queryInfo domains.CachePlanInsertQuery, insertValues []driver.Value) (cleanUP []func()) { +func handleInsertQuery(query string, queryInfo domains.CachePlanInsertQuery, insertValues []driver.Value) (cleanUp cleanUpTask) { table := queryInfo.Table insertArgs, _ := normalizer.NormalizeArgs(query) @@ -290,7 +295,7 @@ func handleInsertQuery(query string, queryInfo domains.CachePlanInsertQuery, ins cacheConditions := cache.info.Conditions isComplexQuery := len(cacheConditions) != 1 || len(insertArgs.ExtraArgs) > 0 || cacheConditions[0].Operator != domains.CachePlanOperator_EQ if isComplexQuery { - cleanUP = append(cleanUP, cache.cache.Purge) + cleanUp.purge = append(cleanUp.purge, cache) continue } @@ -301,21 +306,23 @@ func handleInsertQuery(query string, queryInfo domains.CachePlanInsertQuery, ins // select query: "SELECT * FROM table WHERE col1 = ?" // forget the cache for row := range rows { - cleanUP = append(cleanUP, func() { cache.cache.Forget(cacheKey([]driver.Value{row[insertColumnIdx]})) }) + cleanUp.forget = append(cleanUp.forget, forgetTask{cache, cacheKey([]driver.Value{row[insertColumnIdx]})}) } } else { - cleanUP = append(cleanUP, cache.cache.Purge) + cleanUp.purge = append(cleanUp.purge, cache) } } - return cleanUP + return cleanUp } -func handleUpdateQuery(queryInfo domains.CachePlanUpdateQuery, args []driver.Value) (cleanUp []func()) { +func handleUpdateQuery(queryInfo domains.CachePlanUpdateQuery, args []driver.Value) cleanUpTask { // TODO: support composite primary key and other unique key table := queryInfo.Table updateConditions := queryInfo.Conditions + var cleanUp cleanUpTask + // if query is NOT "UPDATE `table` SET ... WHERE `unique_col` = ?" if !isSingleUniqueCondition(updateConditions, table) { for _, cache := range cacheByTable[table] { @@ -323,9 +330,9 @@ func handleUpdateQuery(queryInfo domains.CachePlanUpdateQuery, args []driver.Val // no need to purge because the cache does not contain the updated column continue } - cleanUp = append(cleanUp, cache.cache.Purge) + cleanUp.purge = append(cleanUp.purge, cache) } - return + return cleanUp } updateCondition := updateConditions[0] @@ -340,18 +347,20 @@ func handleUpdateQuery(queryInfo domains.CachePlanUpdateQuery, args []driver.Val cacheConditions := cache.info.Conditions if isSingleUniqueCondition(cacheConditions, table) && cacheConditions[0].Column == updateCondition.Column { // forget only the updated row - cleanUp = append(cleanUp, func() { cache.cache.Forget(cacheKey([]driver.Value{uniqueValue})) }) + cleanUp.forget = append(cleanUp.forget, forgetTask{cache, cacheKey([]driver.Value{uniqueValue})}) } else { - cleanUp = append(cleanUp, cache.cache.Purge) + cleanUp.purge = append(cleanUp.purge, cache) } } return cleanUp } -func handleDeleteQuery(queryInfo domains.CachePlanDeleteQuery, args []driver.Value) (cleanUp []func()) { +func handleDeleteQuery(queryInfo domains.CachePlanDeleteQuery, args []driver.Value) cleanUpTask { table := queryInfo.Table + var cleanUp cleanUpTask + // if query is like "DELETE FROM table WHERE unique = ?" var deleteByUnique bool if len(queryInfo.Conditions) == 1 { @@ -361,10 +370,8 @@ func handleDeleteQuery(queryInfo domains.CachePlanDeleteQuery, args []driver.Val } if !deleteByUnique { // we should purge all cache - for _, cache := range cacheByTable[table] { - cleanUp = append(cleanUp, cache.cache.Purge) - } - return + cleanUp.purge = append(cleanUp.purge, cacheByTable[table]...) + return cleanUp } uniqueValue := args[queryInfo.Conditions[0].Placeholder.Index] @@ -373,9 +380,9 @@ func handleDeleteQuery(queryInfo domains.CachePlanDeleteQuery, args []driver.Val if cache.uniqueOnly { // query like "SELECT * FROM table WHERE pk = ?" // we should forget the cache - cleanUp = append(cleanUp, func() { cache.cache.Forget(cacheKey([]driver.Value{uniqueValue})) }) + cleanUp.forget = append(cleanUp.forget, forgetTask{cache, cacheKey([]driver.Value{uniqueValue})}) } else { - cleanUp = append(cleanUp, cache.cache.Purge) + cleanUp.purge = append(cleanUp.purge, cache) } } @@ -418,3 +425,32 @@ func namedToValue(nvargs []driver.NamedValue) []driver.Value { } return args } + +type forgetTask struct { + cache *cacheWithInfo + key string +} + +type cleanUpTask struct { + purge []*cacheWithInfo + forget []forgetTask +} + +func (c *cleanUpTask) reset() { + c.purge = c.purge[:0] + c.forget = c.forget[:0] +} + +func (c *cleanUpTask) do() { + for _, cache := range c.purge { + cache.Purge() + } + for _, forget := range c.forget { + forget.cache.Forget(forget.key) + } +} + +func (c *cleanUpTask) append(tasks cleanUpTask) { + c.purge = append(c.purge, tasks.purge...) + c.forget = append(c.forget, tasks.forget...) +} diff --git a/template/stmt.tmpl b/template/stmt.tmpl index 12dd166..82c9570 100644 --- a/template/stmt.tmpl +++ b/template/stmt.tmpl @@ -6,22 +6,14 @@ import ( "log" "slices" - "github.com/motoki317/sc" "github.com/traP-jp/isuc/domains" "github.com/traP-jp/isuc/normalizer" ) -type cacheWithInfo struct { - query string - info domains.CachePlanSelectQuery - uniqueOnly bool // if true, query is like "SELECT * FROM table WHERE pk = ?" - cache *sc.Cache[string, *cacheRows] -} - // NOTE: no write happens to this map, so it's safe to use in concurrent environment -var caches = make(map[string]cacheWithInfo) +var caches = make(map[string]*cacheWithInfo) -var cacheByTable = make(map[string][]cacheWithInfo) +var cacheByTable = make(map[string][]*cacheWithInfo) var _ driver.Stmt = &customCacheStatement{} @@ -101,10 +93,19 @@ func (c *cacheConn) ExecContext(ctx context.Context, rawQuery string, nvargs []d } if !c.tx { - for _, cleanUp := range c.cleanUp { - cleanUp() + c.cleanUp.do() + c.cleanUp.reset() + } else { + // cleanups are deferred until the transaction is committed + // because we need to forget the cache only if the transaction is committed + + // update the cache + for _, c := range c.cleanUp.purge { + c.updateTx() + } + for _, forget := range c.cleanUp.forget { + forget.cache.updateByKeyTx(forget.key) } - c.cleanUp = c.cleanUp[:0] } return res, err @@ -117,7 +118,7 @@ func (c *cacheConn) execInsert(ctx context.Context, rawQuery string, queryInfo d } cleanUp := handleInsertQuery(queryInfo.Query, *queryInfo.Insert, args) - c.cleanUp = append(c.cleanUp, cleanUp...) + c.cleanUp.append(cleanUp) return inner.ExecContext(ctx, rawQuery, nvargs) } @@ -126,7 +127,7 @@ func (c *cacheConn) execUpdate(ctx context.Context, rawQuery string, queryInfo d args := namedToValue(nvargs) cleanUp := handleUpdateQuery(*queryInfo.Update, args) - c.cleanUp = append(c.cleanUp, cleanUp...) + c.cleanUp.append(cleanUp) return inner.ExecContext(ctx, rawQuery, nvargs) } @@ -135,7 +136,7 @@ func (c *cacheConn) execDelete(ctx context.Context, rawQuery string, queryInfo d args := namedToValue(nvargs) cleanUp := handleDeleteQuery(*queryInfo.Delete, args) - c.cleanUp = append(c.cleanUp, cleanUp...) + c.cleanUp.append(cleanUp) return inner.ExecContext(ctx, rawQuery, nvargs) } @@ -150,7 +151,7 @@ func (s *customCacheStatement) Query(args []driver.Value) (driver.Rows, error) { return s.inQuery(args) } - rows, err := caches[cacheName(s.query)].cache.Get(ctx, cacheKey(args)) + rows, err := caches[cacheName(s.query)].Get(ctx, cacheKey(args)) if err != nil { return nil, err } @@ -169,7 +170,7 @@ func (s *customCacheStatement) inQuery(args []driver.Value) (driver.Rows, error) var cache *cacheWithInfo for _, c := range cacheByTable[table] { if len(c.info.Conditions) == 1 && c.info.Conditions[0].Column == s.queryInfo.Select.Conditions[0].Column && c.info.Conditions[0].Operator == domains.CachePlanOperator_EQ { - cache = &c + cache = c } } if cache == nil { @@ -185,7 +186,7 @@ func (s *customCacheStatement) inQuery(args []driver.Value) (driver.Rows, error) } ctx := context.WithValue(context.Background(), stmtKey{}, stmt) ctx = context.WithValue(ctx, argsKey{}, []driver.Value{condValue}) - rows, err := cache.cache.Get(ctx, cacheKey([]driver.Value{condValue})) + rows, err := cache.Get(ctx, cacheKey([]driver.Value{condValue})) if err != nil { return nil, err } @@ -201,10 +202,6 @@ func (c *cacheConn) QueryContext(ctx context.Context, rawQuery string, nvargs [] return nil, driver.ErrSkip } - if c.tx { - return inner.QueryContext(ctx, rawQuery, nvargs) - } - normalizedQuery := normalizer.NormalizeQuery(rawQuery) queryInfo, ok := queryMap[normalizedQuery] @@ -226,11 +223,19 @@ func (c *cacheConn) QueryContext(ctx context.Context, rawQuery string, nvargs [] args[i] = nv.Value } - cache := caches[queryInfo.Query].cache + cache := caches[queryInfo.Query] + key := cacheKey(args) + + if c.tx && cache.isNewerThan(key, c.txStart) { + // cache is newer than the transaction start time + // we should not use the cache + return inner.QueryContext(ctx, rawQuery, nvargs) + } + cachectx := context.WithValue(ctx, namedValueArgsKey{}, nvargs) cachectx = context.WithValue(cachectx, queryerCtxKey{}, inner) cachectx = context.WithValue(cachectx, queryKey{}, rawQuery) - rows, err := cache.Get(cachectx, cacheKey(args)) + rows, err := cache.Get(cachectx, key) if err != nil { return nil, err } @@ -252,7 +257,7 @@ func (c *cacheConn) inQuery(ctx context.Context, query string, args []driver.Nam var cache *cacheWithInfo for _, c := range cacheByTable[table] { if len(c.info.Conditions) == 1 && c.info.Conditions[0].Column == queryInfo.Select.Conditions[0].Column && c.info.Conditions[0].Operator == domains.CachePlanOperator_EQ { - cache = &c + cache = c } } if cache == nil { @@ -265,7 +270,7 @@ func (c *cacheConn) inQuery(ctx context.Context, query string, args []driver.Nam cacheCtx := context.WithValue(ctx, queryKey{}, cache.query) cacheCtx = context.WithValue(cacheCtx, queryerCtxKey{}, inner) cacheCtx = context.WithValue(cacheCtx, namedValueArgsKey{}, nvargs) - rows, err := cache.cache.Get(cacheCtx, cacheKey([]driver.Value{condValue.Value})) + rows, err := cache.Get(cacheCtx, cacheKey([]driver.Value{condValue.Value})) if err != nil { return nil, err } @@ -275,7 +280,7 @@ func (c *cacheConn) inQuery(ctx context.Context, query string, args []driver.Nam return mergeCachedRows(allRows), nil } -func handleInsertQuery(query string, queryInfo domains.CachePlanInsertQuery, insertValues []driver.Value) (cleanUP []func()) { +func handleInsertQuery(query string, queryInfo domains.CachePlanInsertQuery, insertValues []driver.Value) (cleanUp cleanUpTask) { table := queryInfo.Table insertArgs, _ := normalizer.NormalizeArgs(query) @@ -290,7 +295,7 @@ func handleInsertQuery(query string, queryInfo domains.CachePlanInsertQuery, ins cacheConditions := cache.info.Conditions isComplexQuery := len(cacheConditions) != 1 || len(insertArgs.ExtraArgs) > 0 || cacheConditions[0].Operator != domains.CachePlanOperator_EQ if isComplexQuery { - cleanUP = append(cleanUP, cache.cache.Purge) + cleanUp.purge = append(cleanUp.purge, cache) continue } @@ -301,21 +306,23 @@ func handleInsertQuery(query string, queryInfo domains.CachePlanInsertQuery, ins // select query: "SELECT * FROM table WHERE col1 = ?" // forget the cache for row := range rows { - cleanUP = append(cleanUP, func() { cache.cache.Forget(cacheKey([]driver.Value{row[insertColumnIdx]})) }) + cleanUp.forget = append(cleanUp.forget, forgetTask{cache, cacheKey([]driver.Value{row[insertColumnIdx]})}) } } else { - cleanUP = append(cleanUP, cache.cache.Purge) + cleanUp.purge = append(cleanUp.purge, cache) } } - return cleanUP + return cleanUp } -func handleUpdateQuery(queryInfo domains.CachePlanUpdateQuery, args []driver.Value) (cleanUp []func()) { +func handleUpdateQuery(queryInfo domains.CachePlanUpdateQuery, args []driver.Value) cleanUpTask { // TODO: support composite primary key and other unique key table := queryInfo.Table updateConditions := queryInfo.Conditions + var cleanUp cleanUpTask + // if query is NOT "UPDATE `table` SET ... WHERE `unique_col` = ?" if !isSingleUniqueCondition(updateConditions, table) { for _, cache := range cacheByTable[table] { @@ -323,9 +330,9 @@ func handleUpdateQuery(queryInfo domains.CachePlanUpdateQuery, args []driver.Val // no need to purge because the cache does not contain the updated column continue } - cleanUp = append(cleanUp, cache.cache.Purge) + cleanUp.purge = append(cleanUp.purge, cache) } - return + return cleanUp } updateCondition := updateConditions[0] @@ -340,18 +347,20 @@ func handleUpdateQuery(queryInfo domains.CachePlanUpdateQuery, args []driver.Val cacheConditions := cache.info.Conditions if isSingleUniqueCondition(cacheConditions, table) && cacheConditions[0].Column == updateCondition.Column { // forget only the updated row - cleanUp = append(cleanUp, func() { cache.cache.Forget(cacheKey([]driver.Value{uniqueValue})) }) + cleanUp.forget = append(cleanUp.forget, forgetTask{cache, cacheKey([]driver.Value{uniqueValue})}) } else { - cleanUp = append(cleanUp, cache.cache.Purge) + cleanUp.purge = append(cleanUp.purge, cache) } } return cleanUp } -func handleDeleteQuery(queryInfo domains.CachePlanDeleteQuery, args []driver.Value) (cleanUp []func()) { +func handleDeleteQuery(queryInfo domains.CachePlanDeleteQuery, args []driver.Value) cleanUpTask { table := queryInfo.Table + var cleanUp cleanUpTask + // if query is like "DELETE FROM table WHERE unique = ?" var deleteByUnique bool if len(queryInfo.Conditions) == 1 { @@ -361,10 +370,8 @@ func handleDeleteQuery(queryInfo domains.CachePlanDeleteQuery, args []driver.Val } if !deleteByUnique { // we should purge all cache - for _, cache := range cacheByTable[table] { - cleanUp = append(cleanUp, cache.cache.Purge) - } - return + cleanUp.purge = append(cleanUp.purge, cacheByTable[table]...) + return cleanUp } uniqueValue := args[queryInfo.Conditions[0].Placeholder.Index] @@ -373,9 +380,9 @@ func handleDeleteQuery(queryInfo domains.CachePlanDeleteQuery, args []driver.Val if cache.uniqueOnly { // query like "SELECT * FROM table WHERE pk = ?" // we should forget the cache - cleanUp = append(cleanUp, func() { cache.cache.Forget(cacheKey([]driver.Value{uniqueValue})) }) + cleanUp.forget = append(cleanUp.forget, forgetTask{cache, cacheKey([]driver.Value{uniqueValue})}) } else { - cleanUp = append(cleanUp, cache.cache.Purge) + cleanUp.purge = append(cleanUp.purge, cache) } } @@ -418,3 +425,32 @@ func namedToValue(nvargs []driver.NamedValue) []driver.Value { } return args } + +type forgetTask struct { + cache *cacheWithInfo + key string +} + +type cleanUpTask struct { + purge []*cacheWithInfo + forget []forgetTask +} + +func (c *cleanUpTask) reset() { + c.purge = c.purge[:0] + c.forget = c.forget[:0] +} + +func (c *cleanUpTask) do() { + for _, cache := range c.purge { + cache.Purge() + } + for _, forget := range c.forget { + forget.cache.Forget(forget.key) + } +} + +func (c *cleanUpTask) append(tasks cleanUpTask) { + c.purge = append(c.purge, tasks.purge...) + c.forget = append(c.forget, tasks.forget...) +} diff --git a/test/cache/cache.go b/test/cache/cache.go index cb88f66..2b9d5a2 100644 --- a/test/cache/cache.go +++ b/test/cache/cache.go @@ -5,8 +5,41 @@ import ( "database/sql/driver" "fmt" "strings" + "sync" + "sync/atomic" + "time" + + "github.com/motoki317/sc" + "github.com/traP-jp/isuc/domains" ) +type cacheWithInfo struct { + *sc.Cache[string, *cacheRows] + query string + info domains.CachePlanSelectQuery + uniqueOnly bool // if true, query is like "SELECT * FROM table WHERE pk = ?" + lastUpdate atomic.Int64 // time.Time.UnixNano() + lastUpdateByKey syncMap[int64] +} + +func (c *cacheWithInfo) updateTx() { + c.lastUpdate.Store(time.Now().UnixNano()) +} + +func (c *cacheWithInfo) updateByKeyTx(key string) { + c.lastUpdateByKey.Store(key, time.Now().UnixNano()) +} + +func (c *cacheWithInfo) isNewerThan(key string, t int64) bool { + if c.lastUpdate.Load() > t { + return true + } + if update, ok := c.lastUpdateByKey.Load(key); ok && update > t { + return true + } + return false +} + type ( queryKey struct{} stmtKey struct{} @@ -18,7 +51,7 @@ type ( func ExportMetrics() string { res := "" for query, cache := range caches { - stats := cache.cache.Stats() + stats := cache.Stats() progress := "[" for i := 0; i < 20; i++ { if i < int(stats.HitRatio()*20) { @@ -43,7 +76,7 @@ type CacheStats struct { func ExportCacheStats() map[string]CacheStats { res := make(map[string]CacheStats) for query, cache := range caches { - stats := cache.cache.Stats() + stats := cache.Stats() res[query] = CacheStats{ Query: query, HitRatio: stats.HitRatio(), @@ -56,7 +89,7 @@ func ExportCacheStats() map[string]CacheStats { func PurgeAllCaches() { for _, cache := range caches { - cache.cache.Purge() + cache.Purge() } } @@ -109,3 +142,20 @@ func replaceFn(ctx context.Context, key string) (*cacheRows, error) { } return cacheRows.clone(), nil } + +type syncMap[T any] struct { + m sync.Map +} + +func (m *syncMap[T]) Load(key string) (T, bool) { + var zero T + v, ok := m.m.Load(key) + if !ok { + return zero, false + } + return v.(T), true +} + +func (m *syncMap[T]) Store(key string, value T) { + m.m.Store(key, value) +} diff --git a/test/cache/driver.go b/test/cache/driver.go index 6301c05..aa0b77a 100644 --- a/test/cache/driver.go +++ b/test/cache/driver.go @@ -131,18 +131,18 @@ func init() { conditions := query.Select.Conditions if isSingleUniqueCondition(conditions, query.Select.Table) { - caches[normalized] = cacheWithInfo{ + caches[normalized] = &cacheWithInfo{ + Cache: sc.NewMust(replaceFn, 10*time.Minute, 10*time.Minute), query: normalized, info: *query.Select, - cache: sc.NewMust(replaceFn, 10*time.Minute, 10*time.Minute), uniqueOnly: true, } continue } - caches[query.Query] = cacheWithInfo{ + caches[query.Query] = &cacheWithInfo{ + Cache: sc.NewMust(replaceFn, 10*time.Minute, 10*time.Minute), query: query.Query, info: *query.Select, - cache: sc.NewMust(replaceFn, 10*time.Minute, 10*time.Minute), uniqueOnly: false, } @@ -185,7 +185,8 @@ var ( type cacheConn struct { inner driver.Conn tx bool - cleanUp []func() + txStart int64 // time.Time.UnixNano() + cleanUp cleanUpTask } func (c *cacheConn) Prepare(rawQuery string) (driver.Stmt, error) { @@ -228,23 +229,26 @@ func (c *cacheConn) Begin() (driver.Tx, error) { return nil, err } c.tx = true + c.txStart = time.Now().UnixNano() return &cacheTx{conn: c, inner: inner}, nil } func (c *cacheConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) { + var inner driver.Tx + var err error if i, ok := c.inner.(driver.ConnBeginTx); ok { - inner, err := i.BeginTx(ctx, opts) + inner, err = i.BeginTx(ctx, opts) + if err != nil { + return nil, err + } + } else { + inner, err = c.inner.Begin() if err != nil { return nil, err } - c.tx = true - return &cacheTx{conn: c, inner: inner}, nil - } - inner, err := c.inner.Begin() - if err != nil { - return nil, err } c.tx = true + c.txStart = time.Now().UnixNano() return &cacheTx{conn: c, inner: inner}, nil } @@ -265,10 +269,13 @@ type cacheTx struct { func (t *cacheTx) Commit() error { t.conn.tx = false defer func() { - for _, c := range t.conn.cleanUp { - c() + for _, c := range t.conn.cleanUp.purge { + c.Purge() + } + for _, forget := range t.conn.cleanUp.forget { + forget.cache.Forget(forget.key) } - t.conn.cleanUp = t.conn.cleanUp[:0] + t.conn.cleanUp.reset() }() return t.inner.Commit() } @@ -276,7 +283,7 @@ func (t *cacheTx) Commit() error { func (t *cacheTx) Rollback() error { t.conn.tx = false // no need to clean up - t.conn.cleanUp = nil + t.conn.cleanUp.reset() return t.inner.Rollback() } diff --git a/test/cache/reset.go b/test/cache/reset.go index e26df24..4ba2077 100644 --- a/test/cache/reset.go +++ b/test/cache/reset.go @@ -9,7 +9,7 @@ import ( func ResetCache() { for key := range caches { v := caches[key] - *v.cache = *sc.NewMust(replaceFn, 10*time.Minute, 10*time.Minute) + *v.Cache = *sc.NewMust(replaceFn, 10*time.Minute, 10*time.Minute) caches[key] = v } } diff --git a/test/cache/stmt.go b/test/cache/stmt.go index a61e3ba..1adb510 100644 --- a/test/cache/stmt.go +++ b/test/cache/stmt.go @@ -6,22 +6,14 @@ import ( "log" "slices" - "github.com/motoki317/sc" "github.com/traP-jp/isuc/domains" "github.com/traP-jp/isuc/normalizer" ) -type cacheWithInfo struct { - query string - info domains.CachePlanSelectQuery - uniqueOnly bool // if true, query is like "SELECT * FROM table WHERE pk = ?" - cache *sc.Cache[string, *cacheRows] -} - // NOTE: no write happens to this map, so it's safe to use in concurrent environment -var caches = make(map[string]cacheWithInfo) +var caches = make(map[string]*cacheWithInfo) -var cacheByTable = make(map[string][]cacheWithInfo) +var cacheByTable = make(map[string][]*cacheWithInfo) var _ driver.Stmt = &customCacheStatement{} @@ -101,10 +93,19 @@ func (c *cacheConn) ExecContext(ctx context.Context, rawQuery string, nvargs []d } if !c.tx { - for _, cleanUp := range c.cleanUp { - cleanUp() + c.cleanUp.do() + c.cleanUp.reset() + } else { + // cleanups are deferred until the transaction is committed + // because we need to forget the cache only if the transaction is committed + + // update the cache + for _, c := range c.cleanUp.purge { + c.updateTx() + } + for _, forget := range c.cleanUp.forget { + forget.cache.updateByKeyTx(forget.key) } - c.cleanUp = c.cleanUp[:0] } return res, err @@ -117,7 +118,7 @@ func (c *cacheConn) execInsert(ctx context.Context, rawQuery string, queryInfo d } cleanUp := handleInsertQuery(queryInfo.Query, *queryInfo.Insert, args) - c.cleanUp = append(c.cleanUp, cleanUp...) + c.cleanUp.append(cleanUp) return inner.ExecContext(ctx, rawQuery, nvargs) } @@ -126,7 +127,7 @@ func (c *cacheConn) execUpdate(ctx context.Context, rawQuery string, queryInfo d args := namedToValue(nvargs) cleanUp := handleUpdateQuery(*queryInfo.Update, args) - c.cleanUp = append(c.cleanUp, cleanUp...) + c.cleanUp.append(cleanUp) return inner.ExecContext(ctx, rawQuery, nvargs) } @@ -135,7 +136,7 @@ func (c *cacheConn) execDelete(ctx context.Context, rawQuery string, queryInfo d args := namedToValue(nvargs) cleanUp := handleDeleteQuery(*queryInfo.Delete, args) - c.cleanUp = append(c.cleanUp, cleanUp...) + c.cleanUp.append(cleanUp) return inner.ExecContext(ctx, rawQuery, nvargs) } @@ -150,7 +151,7 @@ func (s *customCacheStatement) Query(args []driver.Value) (driver.Rows, error) { return s.inQuery(args) } - rows, err := caches[cacheName(s.query)].cache.Get(ctx, cacheKey(args)) + rows, err := caches[cacheName(s.query)].Get(ctx, cacheKey(args)) if err != nil { return nil, err } @@ -169,7 +170,7 @@ func (s *customCacheStatement) inQuery(args []driver.Value) (driver.Rows, error) var cache *cacheWithInfo for _, c := range cacheByTable[table] { if len(c.info.Conditions) == 1 && c.info.Conditions[0].Column == s.queryInfo.Select.Conditions[0].Column && c.info.Conditions[0].Operator == domains.CachePlanOperator_EQ { - cache = &c + cache = c } } if cache == nil { @@ -185,7 +186,7 @@ func (s *customCacheStatement) inQuery(args []driver.Value) (driver.Rows, error) } ctx := context.WithValue(context.Background(), stmtKey{}, stmt) ctx = context.WithValue(ctx, argsKey{}, []driver.Value{condValue}) - rows, err := cache.cache.Get(ctx, cacheKey([]driver.Value{condValue})) + rows, err := cache.Get(ctx, cacheKey([]driver.Value{condValue})) if err != nil { return nil, err } @@ -201,10 +202,6 @@ func (c *cacheConn) QueryContext(ctx context.Context, rawQuery string, nvargs [] return nil, driver.ErrSkip } - if c.tx { - return inner.QueryContext(ctx, rawQuery, nvargs) - } - normalizedQuery := normalizer.NormalizeQuery(rawQuery) queryInfo, ok := queryMap[normalizedQuery] @@ -226,11 +223,19 @@ func (c *cacheConn) QueryContext(ctx context.Context, rawQuery string, nvargs [] args[i] = nv.Value } - cache := caches[queryInfo.Query].cache + cache := caches[queryInfo.Query] + key := cacheKey(args) + + if c.tx && cache.isNewerThan(key, c.txStart) { + // cache is newer than the transaction start time + // we should not use the cache + return inner.QueryContext(ctx, rawQuery, nvargs) + } + cachectx := context.WithValue(ctx, namedValueArgsKey{}, nvargs) cachectx = context.WithValue(cachectx, queryerCtxKey{}, inner) cachectx = context.WithValue(cachectx, queryKey{}, rawQuery) - rows, err := cache.Get(cachectx, cacheKey(args)) + rows, err := cache.Get(cachectx, key) if err != nil { return nil, err } @@ -252,7 +257,7 @@ func (c *cacheConn) inQuery(ctx context.Context, query string, args []driver.Nam var cache *cacheWithInfo for _, c := range cacheByTable[table] { if len(c.info.Conditions) == 1 && c.info.Conditions[0].Column == queryInfo.Select.Conditions[0].Column && c.info.Conditions[0].Operator == domains.CachePlanOperator_EQ { - cache = &c + cache = c } } if cache == nil { @@ -265,7 +270,7 @@ func (c *cacheConn) inQuery(ctx context.Context, query string, args []driver.Nam cacheCtx := context.WithValue(ctx, queryKey{}, cache.query) cacheCtx = context.WithValue(cacheCtx, queryerCtxKey{}, inner) cacheCtx = context.WithValue(cacheCtx, namedValueArgsKey{}, nvargs) - rows, err := cache.cache.Get(cacheCtx, cacheKey([]driver.Value{condValue.Value})) + rows, err := cache.Get(cacheCtx, cacheKey([]driver.Value{condValue.Value})) if err != nil { return nil, err } @@ -275,7 +280,7 @@ func (c *cacheConn) inQuery(ctx context.Context, query string, args []driver.Nam return mergeCachedRows(allRows), nil } -func handleInsertQuery(query string, queryInfo domains.CachePlanInsertQuery, insertValues []driver.Value) (cleanUP []func()) { +func handleInsertQuery(query string, queryInfo domains.CachePlanInsertQuery, insertValues []driver.Value) (cleanUp cleanUpTask) { table := queryInfo.Table insertArgs, _ := normalizer.NormalizeArgs(query) @@ -290,7 +295,7 @@ func handleInsertQuery(query string, queryInfo domains.CachePlanInsertQuery, ins cacheConditions := cache.info.Conditions isComplexQuery := len(cacheConditions) != 1 || len(insertArgs.ExtraArgs) > 0 || cacheConditions[0].Operator != domains.CachePlanOperator_EQ if isComplexQuery { - cleanUP = append(cleanUP, cache.cache.Purge) + cleanUp.purge = append(cleanUp.purge, cache) continue } @@ -301,21 +306,23 @@ func handleInsertQuery(query string, queryInfo domains.CachePlanInsertQuery, ins // select query: "SELECT * FROM table WHERE col1 = ?" // forget the cache for row := range rows { - cleanUP = append(cleanUP, func() { cache.cache.Forget(cacheKey([]driver.Value{row[insertColumnIdx]})) }) + cleanUp.forget = append(cleanUp.forget, forgetTask{cache, cacheKey([]driver.Value{row[insertColumnIdx]})}) } } else { - cleanUP = append(cleanUP, cache.cache.Purge) + cleanUp.purge = append(cleanUp.purge, cache) } } - return cleanUP + return cleanUp } -func handleUpdateQuery(queryInfo domains.CachePlanUpdateQuery, args []driver.Value) (cleanUp []func()) { +func handleUpdateQuery(queryInfo domains.CachePlanUpdateQuery, args []driver.Value) cleanUpTask { // TODO: support composite primary key and other unique key table := queryInfo.Table updateConditions := queryInfo.Conditions + var cleanUp cleanUpTask + // if query is NOT "UPDATE `table` SET ... WHERE `unique_col` = ?" if !isSingleUniqueCondition(updateConditions, table) { for _, cache := range cacheByTable[table] { @@ -323,9 +330,9 @@ func handleUpdateQuery(queryInfo domains.CachePlanUpdateQuery, args []driver.Val // no need to purge because the cache does not contain the updated column continue } - cleanUp = append(cleanUp, cache.cache.Purge) + cleanUp.purge = append(cleanUp.purge, cache) } - return + return cleanUp } updateCondition := updateConditions[0] @@ -340,18 +347,20 @@ func handleUpdateQuery(queryInfo domains.CachePlanUpdateQuery, args []driver.Val cacheConditions := cache.info.Conditions if isSingleUniqueCondition(cacheConditions, table) && cacheConditions[0].Column == updateCondition.Column { // forget only the updated row - cleanUp = append(cleanUp, func() { cache.cache.Forget(cacheKey([]driver.Value{uniqueValue})) }) + cleanUp.forget = append(cleanUp.forget, forgetTask{cache, cacheKey([]driver.Value{uniqueValue})}) } else { - cleanUp = append(cleanUp, cache.cache.Purge) + cleanUp.purge = append(cleanUp.purge, cache) } } return cleanUp } -func handleDeleteQuery(queryInfo domains.CachePlanDeleteQuery, args []driver.Value) (cleanUp []func()) { +func handleDeleteQuery(queryInfo domains.CachePlanDeleteQuery, args []driver.Value) cleanUpTask { table := queryInfo.Table + var cleanUp cleanUpTask + // if query is like "DELETE FROM table WHERE unique = ?" var deleteByUnique bool if len(queryInfo.Conditions) == 1 { @@ -361,10 +370,8 @@ func handleDeleteQuery(queryInfo domains.CachePlanDeleteQuery, args []driver.Val } if !deleteByUnique { // we should purge all cache - for _, cache := range cacheByTable[table] { - cleanUp = append(cleanUp, cache.cache.Purge) - } - return + cleanUp.purge = append(cleanUp.purge, cacheByTable[table]...) + return cleanUp } uniqueValue := args[queryInfo.Conditions[0].Placeholder.Index] @@ -373,9 +380,9 @@ func handleDeleteQuery(queryInfo domains.CachePlanDeleteQuery, args []driver.Val if cache.uniqueOnly { // query like "SELECT * FROM table WHERE pk = ?" // we should forget the cache - cleanUp = append(cleanUp, func() { cache.cache.Forget(cacheKey([]driver.Value{uniqueValue})) }) + cleanUp.forget = append(cleanUp.forget, forgetTask{cache, cacheKey([]driver.Value{uniqueValue})}) } else { - cleanUp = append(cleanUp, cache.cache.Purge) + cleanUp.purge = append(cleanUp.purge, cache) } } @@ -418,3 +425,32 @@ func namedToValue(nvargs []driver.NamedValue) []driver.Value { } return args } + +type forgetTask struct { + cache *cacheWithInfo + key string +} + +type cleanUpTask struct { + purge []*cacheWithInfo + forget []forgetTask +} + +func (c *cleanUpTask) reset() { + c.purge = c.purge[:0] + c.forget = c.forget[:0] +} + +func (c *cleanUpTask) do() { + for _, cache := range c.purge { + cache.Purge() + } + for _, forget := range c.forget { + forget.cache.Forget(forget.key) + } +} + +func (c *cleanUpTask) append(tasks cleanUpTask) { + c.purge = append(c.purge, tasks.purge...) + c.forget = append(c.forget, tasks.forget...) +} diff --git a/test/cache_test.go b/test/cache_test.go index 66bc810..0de78e2 100644 --- a/test/cache_test.go +++ b/test/cache_test.go @@ -2,6 +2,7 @@ package test import ( "database/sql" + "sync" "testing" "time" @@ -200,8 +201,13 @@ func TestTransaction(t *testing.T) { afterUpdate := make(chan struct{}) beforeCommit := make(chan struct{}) + var wg sync.WaitGroup + wg.Add(2) + // transaction go func() { + defer wg.Done() + tx, err := db.Beginx() if err != nil { errCh <- err @@ -227,6 +233,8 @@ func TestTransaction(t *testing.T) { // select go func() { + defer wg.Done() + <-afterUpdate var user User @@ -246,11 +254,11 @@ func TestTransaction(t *testing.T) { t.Log("select completed") }() - if err := <-errCh; err != nil { - t.Fatal(err) - } - if err := <-errCh; err != nil { - t.Fatal(err) + wg.Wait() + for range len(errCh) { + if err := <-errCh; err != nil { + t.Fatal(err) + } } // now user must be updated @@ -263,3 +271,109 @@ func TestTransaction(t *testing.T) { updated.Name = "updated" AssertUser(t, updated, user) } + +func TestSelectTransaction(t *testing.T) { + cache.ResetCache() + db := NewDB(t) + + tx := db.MustBegin() + defer tx.Rollback() + + var user User + err := tx.Get(&user, "SELECT * FROM `users` WHERE `id` = ?", 1) + if err != nil { + t.Fatal(err) + } + + AssertUser(t, InitialData[0], user) + + // cache hit + err = tx.Get(&user, "SELECT * FROM `users` WHERE `id` = ?", 1) + if err != nil { + t.Fatal(err) + } + + AssertUser(t, InitialData[0], user) + + if err := tx.Commit(); err != nil { + t.Fatal(err) + } + + stats := cache.ExportCacheStats()[normalizer.NormalizeQuery("SELECT * FROM `users` WHERE `id` = ?")] + assert.Equal(t, 1, stats.Hits) + assert.Equal(t, 1, stats.Misses) +} + +func TestFuzzyRead(t *testing.T) { + cache.ResetCache() + db := NewDB(t) + + errCh := make(chan error, 2) + afterFirstQuery := make(chan struct{}) + afterUpdate := make(chan struct{}) + + var wg sync.WaitGroup + wg.Add(2) + + // transaction 1 + go func() { + defer wg.Done() + + tx := db.MustBegin() + defer tx.Rollback() + + var user1 User + err := tx.Get(&user1, "SELECT * FROM `users` WHERE `id` = ?", 1) + if err != nil { + errCh <- err + return + } + AssertUser(t, InitialData[0], user1) + + close(afterFirstQuery) + + <-afterUpdate + + var user2 User + err = tx.Get(&user2, "SELECT * FROM `users` WHERE `id` = ?", 1) + if err != nil { + errCh <- err + return + } + + AssertUser(t, user1, user2) + + errCh <- tx.Commit() + }() + + // transaction 2 + go func() { + defer wg.Done() + + tx := db.MustBegin() + defer tx.Rollback() + + <-afterFirstQuery + + _, err := tx.Exec("UPDATE `users` SET `name` = ? WHERE `id` = ?", "updated", 1) + if err != nil { + errCh <- err + return + } + + err = tx.Commit() + if err != nil { + errCh <- err + return + } + + close(afterUpdate) + }() + + wg.Wait() + for range len(errCh) { + if err := <-errCh; err != nil { + t.Fatal(err) + } + } +}