From e24df29d0c8d16d77c6868cfe312b1b3327b447d Mon Sep 17 00:00:00 2001 From: git-hulk Date: Tue, 7 Jan 2025 20:23:16 +0800 Subject: [PATCH] Fix data race when updating the raft snapshot and compact threshold --- store/engine/raft/node.go | 31 ++++++++++++++----------------- 1 file changed, 14 insertions(+), 17 deletions(-) diff --git a/store/engine/raft/node.go b/store/engine/raft/node.go index 70a7069..c5c9092 100644 --- a/store/engine/raft/node.go +++ b/store/engine/raft/node.go @@ -73,13 +73,12 @@ type Node struct { logger *zap.Logger peers sync.Map - mu sync.Mutex leader uint64 appliedIndex uint64 snapshotIndex uint64 confState raftpb.ConfState - snapshotThreshold uint64 - compactThreshold uint64 + snapshotThreshold atomic.Uint64 + compactThreshold atomic.Uint64 wg sync.WaitGroup shutdown chan struct{} @@ -97,14 +96,14 @@ func New(config *Config) (*Node, error) { logger := logger.Get().With(zap.Uint64("node_id", config.ID)) n := &Node{ - config: config, - leader: raft.None, - dataStore: NewDataStore(config.DataDir), - leaderChanged: make(chan bool), - snapshotThreshold: defaultSnapshotThreshold, - compactThreshold: defaultCompactThreshold, - logger: logger, - } + config: config, + leader: raft.None, + dataStore: NewDataStore(config.DataDir), + leaderChanged: make(chan bool), + logger: logger, + } + n.snapshotThreshold.Store(defaultSnapshotThreshold) + n.compactThreshold.Store(defaultCompactThreshold) if err := n.run(); err != nil { return nil, err } @@ -127,9 +126,7 @@ func (n *Node) ListPeers() map[uint64]string { } func (n *Node) SetSnapshotThreshold(threshold uint64) { - n.mu.Lock() - defer n.mu.Unlock() - n.snapshotThreshold = threshold + n.snapshotThreshold.Store(threshold) } func (n *Node) run() error { @@ -309,7 +306,7 @@ func (n *Node) runRaftMessages() error { } func (n *Node) triggerSnapshotIfNeed() error { - if n.appliedIndex-n.snapshotIndex <= n.snapshotThreshold { + if n.appliedIndex-n.snapshotIndex <= n.snapshotThreshold.Load() { return nil } snapshotBytes, err := n.dataStore.GetDataStoreSnapshot() @@ -325,8 +322,8 @@ func (n *Node) triggerSnapshotIfNeed() error { } compactIndex := uint64(1) - if n.appliedIndex > n.compactThreshold { - compactIndex = n.appliedIndex - n.compactThreshold + if n.appliedIndex > n.compactThreshold.Load() { + compactIndex = n.appliedIndex - n.compactThreshold.Load() } if err := n.dataStore.raftStorage.Compact(compactIndex); err != nil && !errors.Is(err, raft.ErrCompacted) { return err