diff --git a/internal/datacoord/server.go b/internal/datacoord/server.go index d6042f1927525..5a6b61045070b 100644 --- a/internal/datacoord/server.go +++ b/internal/datacoord/server.go @@ -466,6 +466,13 @@ func (s *Server) startDataCoord() { sessionutil.SaveServerInfo(typeutil.DataCoordRole, s.session.GetServerID()) } +func (s *Server) GetServerID() int64 { + if s.session != nil { + return s.session.GetServerID() + } + return paramtable.GetNodeID() +} + func (s *Server) afterStart() {} func (s *Server) initCluster() error { diff --git a/internal/datanode/data_node.go b/internal/datanode/data_node.go index 8818c0d497bc7..3d04ef6ed3aeb 100644 --- a/internal/datanode/data_node.go +++ b/internal/datanode/data_node.go @@ -83,7 +83,6 @@ var Params *paramtable.ComponentParam = paramtable.Get() // `segmentCache` stores all flushing and flushed segments. type DataNode struct { ctx context.Context - serverID int64 cancel context.CancelFunc Role string stateCode atomic.Value // commonpb.StateCode_Initializing @@ -129,7 +128,7 @@ type DataNode struct { } // NewDataNode will return a DataNode with abnormal state. -func NewDataNode(ctx context.Context, factory dependency.Factory, serverID int64) *DataNode { +func NewDataNode(ctx context.Context, factory dependency.Factory) *DataNode { rand.Seed(time.Now().UnixNano()) ctx2, cancel2 := context.WithCancel(ctx) node := &DataNode{ @@ -140,7 +139,6 @@ func NewDataNode(ctx context.Context, factory dependency.Factory, serverID int64 rootCoord: nil, dataCoord: nil, factory: factory, - serverID: serverID, segmentCache: newCache(), compactionExecutor: newCompactionExecutor(), @@ -228,10 +226,10 @@ func (node *DataNode) initRateCollector() error { } func (node *DataNode) GetNodeID() int64 { - if node.serverID == 0 && node.session != nil { + if node.session != nil { return node.session.ServerID } - return node.serverID + return paramtable.GetNodeID() } func (node *DataNode) Init() error { @@ -246,7 +244,7 @@ func (node *DataNode) Init() error { return } - serverID := node.session.ServerID + serverID := node.GetNodeID() log := log.Ctx(node.ctx).With(zap.String("role", typeutil.DataNodeRole), zap.Int64("nodeID", serverID)) node.broker = broker.NewCoordBroker(node.rootCoord, node.dataCoord, serverID) diff --git a/internal/datanode/event_manager.go b/internal/datanode/event_manager.go index 5fda6f0f2c591..4d0ceea6c2fc3 100644 --- a/internal/datanode/event_manager.go +++ b/internal/datanode/event_manager.go @@ -92,7 +92,7 @@ func (node *DataNode) StartWatchChannels(ctx context.Context) { // serves the corner case for etcd connection lost and missing some events func (node *DataNode) checkWatchedList() error { // REF MEP#7 watch path should be [prefix]/channel/{node_id}/{channel_name} - prefix := path.Join(Params.CommonCfg.DataCoordWatchSubPath.GetValue(), fmt.Sprintf("%d", node.serverID)) + prefix := path.Join(Params.CommonCfg.DataCoordWatchSubPath.GetValue(), fmt.Sprintf("%d", node.GetNodeID())) keys, values, err := node.watchKv.LoadWithPrefix(prefix) if err != nil { return err diff --git a/internal/datanode/mock_test.go b/internal/datanode/mock_test.go index c5afe6cd5ae93..ab9a99ad8f2f3 100644 --- a/internal/datanode/mock_test.go +++ b/internal/datanode/mock_test.go @@ -83,7 +83,7 @@ var segID2SegInfo = map[int64]*datapb.SegmentInfo{ func newIDLEDataNodeMock(ctx context.Context, pkType schemapb.DataType) *DataNode { factory := dependency.NewDefaultFactory(true) - node := NewDataNode(ctx, factory, 1) + node := NewDataNode(ctx, factory) node.SetSession(&sessionutil.Session{SessionRaw: sessionutil.SessionRaw{ServerID: 1}}) node.dispClient = msgdispatcher.NewClient(factory, typeutil.DataNodeRole, paramtable.GetNodeID()) diff --git a/internal/distributed/datacoord/service.go b/internal/distributed/datacoord/service.go index 387017aaf003f..0b8aedce31db9 100644 --- a/internal/distributed/datacoord/service.go +++ b/internal/distributed/datacoord/service.go @@ -180,7 +180,7 @@ func (s *Server) startGrpcLoop(grpcPort int) { interceptor.ClusterValidationUnaryServerInterceptor(), interceptor.ServerIDValidationUnaryServerInterceptor(func() int64 { if s.serverID.Load() == 0 { - s.serverID.Store(paramtable.GetNodeID()) + s.serverID.Store(s.dataCoord.(*datacoord.Server).GetServerID()) } return s.serverID.Load() }), @@ -191,7 +191,7 @@ func (s *Server) startGrpcLoop(grpcPort int) { interceptor.ClusterValidationStreamServerInterceptor(), interceptor.ServerIDValidationStreamServerInterceptor(func() int64 { if s.serverID.Load() == 0 { - s.serverID.Store(paramtable.GetNodeID()) + s.serverID.Store(s.dataCoord.(*datacoord.Server).GetServerID()) } return s.serverID.Load() }), diff --git a/internal/distributed/datanode/service.go b/internal/distributed/datanode/service.go index dc240d695c473..5bbf12224e577 100644 --- a/internal/distributed/datanode/service.go +++ b/internal/distributed/datanode/service.go @@ -91,7 +91,7 @@ func NewServer(ctx context.Context, factory dependency.Factory) (*Server, error) } s.serverID.Store(paramtable.GetNodeID()) - s.datanode = dn.NewDataNode(s.ctx, s.factory, s.serverID.Load()) + s.datanode = dn.NewDataNode(s.ctx, s.factory) return s, nil }