From 2ead5d6735824aee2e359afedad24f13a3f26903 Mon Sep 17 00:00:00 2001
From: Bingyi Sun <sunbingyi1992@gmail.com>
Date: Mon, 20 Jan 2025 14:07:05 +0800
Subject: [PATCH] fix: cherry pick warmup async (#39402)

related pr: https://github.com/milvus-io/milvus/pull/38690
issue: https://github.com/milvus-io/milvus/issues/38692

---------

Signed-off-by: sunby <sunbingyi1992@gmail.com>
---
 .../querynodev2/delegator/delegator_data.go   |  1 +
 internal/querynodev2/segments/manager_test.go |  1 +
 internal/querynodev2/segments/pool.go         |  2 +-
 internal/querynodev2/segments/reduce_test.go  |  1 +
 .../querynodev2/segments/retrieve_test.go     |  2 +
 internal/querynodev2/segments/search_test.go  |  2 +
 internal/querynodev2/segments/segment.go      | 67 +++++++++++++++++--
 .../querynodev2/segments/segment_loader.go    |  7 +-
 internal/querynodev2/segments/segment_test.go | 24 +++++++
 internal/querynodev2/server_test.go           |  1 +
 10 files changed, 101 insertions(+), 7 deletions(-)

diff --git a/internal/querynodev2/delegator/delegator_data.go b/internal/querynodev2/delegator/delegator_data.go
index d8521e794d1ae..e81eec6189e2a 100644
--- a/internal/querynodev2/delegator/delegator_data.go
+++ b/internal/querynodev2/delegator/delegator_data.go
@@ -106,6 +106,7 @@ func (sd *shardDelegator) ProcessInsert(insertRecords map[int64]*InsertData) {
 					DeltaPosition: insertData.StartPosition,
 					Level:         datapb.SegmentLevel_L1,
 				},
+				nil,
 			)
 			if err != nil {
 				log.Error("failed to create new segment",
diff --git a/internal/querynodev2/segments/manager_test.go b/internal/querynodev2/segments/manager_test.go
index e09002a43dcdc..25f9183a7c533 100644
--- a/internal/querynodev2/segments/manager_test.go
+++ b/internal/querynodev2/segments/manager_test.go
@@ -67,6 +67,7 @@ func (s *ManagerSuite) SetupTest() {
 				InsertChannel: s.channels[i],
 				Level:         s.levels[i],
 			},
+			nil,
 		)
 		s.Require().NoError(err)
 		s.segments = append(s.segments, segment)
diff --git a/internal/querynodev2/segments/pool.go b/internal/querynodev2/segments/pool.go
index 7bddca6169e83..eb7f56b2feb31 100644
--- a/internal/querynodev2/segments/pool.go
+++ b/internal/querynodev2/segments/pool.go
@@ -113,7 +113,7 @@ func initWarmupPool() {
 			conc.WithPreAlloc(false),
 			conc.WithDisablePurge(false),
 			conc.WithPreHandler(runtime.LockOSThread), // lock os thread for cgo thread disposal
-			conc.WithNonBlocking(true),                // make warming up non blocking
+			conc.WithNonBlocking(false),
 		)
 
 		warmupPool.Store(pool)
diff --git a/internal/querynodev2/segments/reduce_test.go b/internal/querynodev2/segments/reduce_test.go
index b1996385043ba..c1950d4b24599 100644
--- a/internal/querynodev2/segments/reduce_test.go
+++ b/internal/querynodev2/segments/reduce_test.go
@@ -90,6 +90,7 @@ func (suite *ReduceSuite) SetupTest() {
 			InsertChannel: fmt.Sprintf("by-dev-rootcoord-dml_0_%dv0", suite.collectionID),
 			Level:         datapb.SegmentLevel_Legacy,
 		},
+		nil,
 	)
 	suite.Require().NoError(err)
 
diff --git a/internal/querynodev2/segments/retrieve_test.go b/internal/querynodev2/segments/retrieve_test.go
index ea072d5ac387a..d29535029c0f2 100644
--- a/internal/querynodev2/segments/retrieve_test.go
+++ b/internal/querynodev2/segments/retrieve_test.go
@@ -96,6 +96,7 @@ func (suite *RetrieveSuite) SetupTest() {
 			InsertChannel: fmt.Sprintf("by-dev-rootcoord-dml_0_%dv0", suite.collectionID),
 			Level:         datapb.SegmentLevel_Legacy,
 		},
+		nil,
 	)
 	suite.Require().NoError(err)
 
@@ -124,6 +125,7 @@ func (suite *RetrieveSuite) SetupTest() {
 			InsertChannel: fmt.Sprintf("by-dev-rootcoord-dml_0_%dv0", suite.collectionID),
 			Level:         datapb.SegmentLevel_Legacy,
 		},
+		nil,
 	)
 	suite.Require().NoError(err)
 
diff --git a/internal/querynodev2/segments/search_test.go b/internal/querynodev2/segments/search_test.go
index 74929b44d17ed..b948ec37f08f7 100644
--- a/internal/querynodev2/segments/search_test.go
+++ b/internal/querynodev2/segments/search_test.go
@@ -87,6 +87,7 @@ func (suite *SearchSuite) SetupTest() {
 			InsertChannel: fmt.Sprintf("by-dev-rootcoord-dml_0_%dv0", suite.collectionID),
 			Level:         datapb.SegmentLevel_Legacy,
 		},
+		nil,
 	)
 	suite.Require().NoError(err)
 
@@ -115,6 +116,7 @@ func (suite *SearchSuite) SetupTest() {
 			InsertChannel: fmt.Sprintf("by-dev-rootcoord-dml_0_%dv0", suite.collectionID),
 			Level:         datapb.SegmentLevel_Legacy,
 		},
+		nil,
 	)
 	suite.Require().NoError(err)
 
diff --git a/internal/querynodev2/segments/segment.go b/internal/querynodev2/segments/segment.go
index 895557e626812..3b0237b27bfc7 100644
--- a/internal/querynodev2/segments/segment.go
+++ b/internal/querynodev2/segments/segment.go
@@ -32,6 +32,7 @@ import (
 	"io"
 	"runtime"
 	"strings"
+	"sync"
 	"time"
 	"unsafe"
 
@@ -261,6 +262,7 @@ type LocalSegment struct {
 	fields             *typeutil.ConcurrentMap[int64, *FieldInfo]
 	fieldIndexes       *typeutil.ConcurrentMap[int64, *IndexedFieldInfo]
 	space              *milvus_storage.Space
+	warmupDispatcher   *AsyncWarmupDispatcher
 }
 
 func NewSegment(ctx context.Context,
@@ -268,6 +270,7 @@ func NewSegment(ctx context.Context,
 	segmentType SegmentType,
 	version int64,
 	loadInfo *querypb.SegmentLoadInfo,
+	warmupDispatcher *AsyncWarmupDispatcher,
 ) (Segment, error) {
 	log := log.Ctx(ctx)
 	/*
@@ -326,9 +329,10 @@ func NewSegment(ctx context.Context,
 		fields:             typeutil.NewConcurrentMap[int64, *FieldInfo](),
 		fieldIndexes:       typeutil.NewConcurrentMap[int64, *IndexedFieldInfo](),
 
-		memSize:     atomic.NewInt64(-1),
-		rowNum:      atomic.NewInt64(-1),
-		insertCount: atomic.NewInt64(0),
+		memSize:          atomic.NewInt64(-1),
+		rowNum:           atomic.NewInt64(-1),
+		insertCount:      atomic.NewInt64(0),
+		warmupDispatcher: warmupDispatcher,
 	}
 
 	if err := segment.initializeSegment(); err != nil {
@@ -1507,7 +1511,7 @@ func (s *LocalSegment) WarmupChunkCache(ctx context.Context, fieldID int64, mmap
 			return nil, nil
 		}).Await()
 	case "async":
-		GetWarmupPool().Submit(func() (any, error) {
+		task := func() (any, error) {
 			// failed to wait for state update, return directly
 			if !s.ptrLock.BlockUntilDataLoadedOrReleased() {
 				return nil, nil
@@ -1527,7 +1531,8 @@ func (s *LocalSegment) WarmupChunkCache(ctx context.Context, fieldID int64, mmap
 			}
 			log.Info("warming up chunk cache asynchronously done")
 			return nil, nil
-		})
+		}
+		s.warmupDispatcher.AddTask(task)
 	default:
 		// no warming up
 	}
@@ -1666,3 +1671,55 @@ func (s *LocalSegment) indexNeedLoadRawData(schema *schemapb.CollectionSchema, i
 	}
 	return !typeutil.IsVectorType(fieldSchema.DataType) && s.HasRawData(indexInfo.IndexInfo.FieldID), nil
 }
+
+type (
+	// WarmupTask is a unit of warmup work; Run hands values of this type to GetWarmupPool().Submit.
+	WarmupTask = func() (any, error)
+	// AsyncWarmupDispatcher queues tasks from AddTask until Run forwards them to the warmup pool.
+	AsyncWarmupDispatcher struct {
+		mu     sync.RWMutex  // guards tasks
+		tasks  []WarmupTask  // pending tasks in arrival order
+		notify chan struct{} // capacity 1; wake-up signal for Run
+	}
+)
+
+// NewWarmupDispatcher returns an empty dispatcher; nothing is forwarded until Run is started.
+func NewWarmupDispatcher() *AsyncWarmupDispatcher {
+	return &AsyncWarmupDispatcher{
+		notify: make(chan struct{}, 1),
+	}
+}
+
+// AddTask queues task and signals Run; it never blocks on notify because a pending signal already covers it.
+func (d *AsyncWarmupDispatcher) AddTask(task func() (any, error)) {
+	d.mu.Lock()
+	d.tasks = append(d.tasks, task)
+	d.mu.Unlock()
+	select {
+	case d.notify <- struct{}{}:
+	default:
+	}
+}
+
+// Run forwards queued tasks to the warmup pool in arrival order until ctx is cancelled.
+// Tasks that have not been submitted yet when ctx is cancelled are dropped.
+func (d *AsyncWarmupDispatcher) Run(ctx context.Context) {
+	for {
+		select {
+		case <-ctx.Done():
+			return
+		case <-d.notify:
+			// Detach the batch under one lock; resetting to nil stops the queue pinning submitted closures.
+			d.mu.Lock()
+			tasks := d.tasks
+			d.tasks = nil
+			d.mu.Unlock()
+			for _, task := range tasks {
+				if ctx.Err() != nil {
+					return
+				}
+				GetWarmupPool().Submit(task)
+			}
+		}
+	}
+}
diff --git a/internal/querynodev2/segments/segment_loader.go b/internal/querynodev2/segments/segment_loader.go
index 98abdc1b39e07..5d70c2b6419bd 100644
--- a/internal/querynodev2/segments/segment_loader.go
+++ b/internal/querynodev2/segments/segment_loader.go
@@ -569,12 +569,15 @@ func NewLoader(
 	duf := NewDiskUsageFetcher(ctx)
 	go duf.Start()
 
+	warmupDispatcher := NewWarmupDispatcher()
+	go warmupDispatcher.Run(ctx)
 	loader := &segmentLoader{
 		manager:                   manager,
 		cm:                        cm,
 		loadingSegments:           typeutil.NewConcurrentMap[int64, *loadResult](),
 		committedResourceNotifier: syncutil.NewVersionedNotifier(),
 		duf:                       duf,
+		warmupDispatcher:          warmupDispatcher,
 	}
 
 	return loader
@@ -617,7 +620,8 @@ type segmentLoader struct {
 	committedResource         LoadResource
 	committedResourceNotifier *syncutil.VersionedNotifier
 
-	duf *diskUsageFetcher
+	duf              *diskUsageFetcher
+	warmupDispatcher *AsyncWarmupDispatcher
 }
 
 var _ Loader = (*segmentLoader)(nil)
@@ -700,6 +704,7 @@ func (loader *segmentLoader) Load(ctx context.Context,
 			segmentType,
 			version,
 			loadInfo,
+			loader.warmupDispatcher,
 		)
 		if err != nil {
 			log.Warn("load segment failed when create new segment",
diff --git a/internal/querynodev2/segments/segment_test.go b/internal/querynodev2/segments/segment_test.go
index 4042ff56ce330..0f1c31b653331 100644
--- a/internal/querynodev2/segments/segment_test.go
+++ b/internal/querynodev2/segments/segment_test.go
@@ -5,8 +5,11 @@ import (
 	"fmt"
 	"path/filepath"
 	"testing"
+	"time"
 
+	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/suite"
+	"go.uber.org/atomic"
 
 	"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
 	storage "github.com/milvus-io/milvus/internal/storage"
@@ -90,6 +93,7 @@ func (suite *SegmentSuite) SetupTest() {
 				},
 			},
 		},
+		nil,
 	)
 	suite.Require().NoError(err)
 
@@ -121,6 +125,7 @@ func (suite *SegmentSuite) SetupTest() {
 			InsertChannel: fmt.Sprintf("by-dev-rootcoord-dml_0_%dv0", suite.collectionID),
 			Level:         datapb.SegmentLevel_Legacy,
 		},
+		nil,
 	)
 	suite.Require().NoError(err)
 
@@ -221,3 +226,22 @@ func (suite *SegmentSuite) TestSegmentReleased() {
 func TestSegment(t *testing.T) {
 	suite.Run(t, new(SegmentSuite))
 }
+
+// TestWarmupDispatcher checks that every task queued through AddTask is eventually executed.
+func TestWarmupDispatcher(t *testing.T) {
+	d := NewWarmupDispatcher()
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel() // stop the Run goroutine when the test returns instead of leaking it
+	go d.Run(ctx)
+	completed := atomic.NewInt64(0)
+	taskCnt := 10000
+	for i := 0; i < taskCnt; i++ {
+		d.AddTask(func() (any, error) {
+			completed.Inc()
+			return nil, nil
+		})
+	}
+	assert.Eventually(t, func() bool {
+		return completed.Load() == int64(taskCnt)
+	}, 10*time.Second, 10*time.Millisecond) // short tick: do not wait a full second for the first check
+}
diff --git a/internal/querynodev2/server_test.go b/internal/querynodev2/server_test.go
index 89779bdc1c04f..0b43797746b50 100644
--- a/internal/querynodev2/server_test.go
+++ b/internal/querynodev2/server_test.go
@@ -236,6 +236,7 @@ func (suite *QueryNodeSuite) TestStop() {
 			Level:         datapb.SegmentLevel_Legacy,
 			InsertChannel: fmt.Sprintf("by-dev-rootcoord-dml_0_%dv0", 1),
 		},
+		nil,
 	)
 	suite.NoError(err)
 	suite.node.manager.Segment.Put(context.Background(), segments.SegmentTypeSealed, segment)