From 2ead5d6735824aee2e359afedad24f13a3f26903 Mon Sep 17 00:00:00 2001 From: Bingyi Sun <sunbingyi1992@gmail.com> Date: Mon, 20 Jan 2025 14:07:05 +0800 Subject: [PATCH] fix: cherry pick warmup async (#39402) related pr: https://github.com/milvus-io/milvus/pull/38690 issue: https://github.com/milvus-io/milvus/issues/38692 --------- Signed-off-by: sunby <sunbingyi1992@gmail.com> --- .../querynodev2/delegator/delegator_data.go | 1 + internal/querynodev2/segments/manager_test.go | 1 + internal/querynodev2/segments/pool.go | 2 +- internal/querynodev2/segments/reduce_test.go | 1 + .../querynodev2/segments/retrieve_test.go | 2 + internal/querynodev2/segments/search_test.go | 2 + internal/querynodev2/segments/segment.go | 67 +++++++++++++++++-- .../querynodev2/segments/segment_loader.go | 7 +- internal/querynodev2/segments/segment_test.go | 24 +++++++ internal/querynodev2/server_test.go | 1 + 10 files changed, 101 insertions(+), 7 deletions(-) diff --git a/internal/querynodev2/delegator/delegator_data.go b/internal/querynodev2/delegator/delegator_data.go index d8521e794d1ae..e81eec6189e2a 100644 --- a/internal/querynodev2/delegator/delegator_data.go +++ b/internal/querynodev2/delegator/delegator_data.go @@ -106,6 +106,7 @@ func (sd *shardDelegator) ProcessInsert(insertRecords map[int64]*InsertData) { DeltaPosition: insertData.StartPosition, Level: datapb.SegmentLevel_L1, }, + nil, ) if err != nil { log.Error("failed to create new segment", diff --git a/internal/querynodev2/segments/manager_test.go b/internal/querynodev2/segments/manager_test.go index e09002a43dcdc..25f9183a7c533 100644 --- a/internal/querynodev2/segments/manager_test.go +++ b/internal/querynodev2/segments/manager_test.go @@ -67,6 +67,7 @@ func (s *ManagerSuite) SetupTest() { InsertChannel: s.channels[i], Level: s.levels[i], }, + nil, ) s.Require().NoError(err) s.segments = append(s.segments, segment) diff --git a/internal/querynodev2/segments/pool.go b/internal/querynodev2/segments/pool.go index 7bddca6169e83..eb7f56b2feb31 100644 --- a/internal/querynodev2/segments/pool.go +++ b/internal/querynodev2/segments/pool.go @@ -113,7 +113,7 @@ func initWarmupPool() { conc.WithPreAlloc(false), conc.WithDisablePurge(false), conc.WithPreHandler(runtime.LockOSThread), // lock os thread for cgo thread disposal - conc.WithNonBlocking(true), // make warming up non blocking + conc.WithNonBlocking(false), ) warmupPool.Store(pool) diff --git a/internal/querynodev2/segments/reduce_test.go b/internal/querynodev2/segments/reduce_test.go index b1996385043ba..c1950d4b24599 100644 --- a/internal/querynodev2/segments/reduce_test.go +++ b/internal/querynodev2/segments/reduce_test.go @@ -90,6 +90,7 @@ func (suite *ReduceSuite) SetupTest() { InsertChannel: fmt.Sprintf("by-dev-rootcoord-dml_0_%dv0", suite.collectionID), Level: datapb.SegmentLevel_Legacy, }, + nil, ) suite.Require().NoError(err) diff --git a/internal/querynodev2/segments/retrieve_test.go b/internal/querynodev2/segments/retrieve_test.go index ea072d5ac387a..d29535029c0f2 100644 --- a/internal/querynodev2/segments/retrieve_test.go +++ b/internal/querynodev2/segments/retrieve_test.go @@ -96,6 +96,7 @@ func (suite *RetrieveSuite) SetupTest() { InsertChannel: fmt.Sprintf("by-dev-rootcoord-dml_0_%dv0", suite.collectionID), Level: datapb.SegmentLevel_Legacy, }, + nil, ) suite.Require().NoError(err) @@ -124,6 +125,7 @@ func (suite *RetrieveSuite) SetupTest() { InsertChannel: fmt.Sprintf("by-dev-rootcoord-dml_0_%dv0", suite.collectionID), Level: datapb.SegmentLevel_Legacy, }, + nil, ) suite.Require().NoError(err) diff --git a/internal/querynodev2/segments/search_test.go b/internal/querynodev2/segments/search_test.go index 74929b44d17ed..b948ec37f08f7 100644 --- a/internal/querynodev2/segments/search_test.go +++ b/internal/querynodev2/segments/search_test.go @@ -87,6 +87,7 @@ func (suite *SearchSuite) SetupTest() { InsertChannel: fmt.Sprintf("by-dev-rootcoord-dml_0_%dv0", suite.collectionID), Level: datapb.SegmentLevel_Legacy, }, + nil, ) suite.Require().NoError(err) @@ -115,6 +116,7 @@ func (suite *SearchSuite) SetupTest() { InsertChannel: fmt.Sprintf("by-dev-rootcoord-dml_0_%dv0", suite.collectionID), Level: datapb.SegmentLevel_Legacy, }, + nil, ) suite.Require().NoError(err) diff --git a/internal/querynodev2/segments/segment.go b/internal/querynodev2/segments/segment.go index 895557e626812..3b0237b27bfc7 100644 --- a/internal/querynodev2/segments/segment.go +++ b/internal/querynodev2/segments/segment.go @@ -32,6 +32,7 @@ import ( "io" "runtime" "strings" + "sync" "time" "unsafe" @@ -261,6 +262,7 @@ type LocalSegment struct { fields *typeutil.ConcurrentMap[int64, *FieldInfo] fieldIndexes *typeutil.ConcurrentMap[int64, *IndexedFieldInfo] space *milvus_storage.Space + warmupDispatcher *AsyncWarmupDispatcher } func NewSegment(ctx context.Context, @@ -268,6 +270,7 @@ func NewSegment(ctx context.Context, segmentType SegmentType, version int64, loadInfo *querypb.SegmentLoadInfo, + warmupDispatcher *AsyncWarmupDispatcher, ) (Segment, error) { log := log.Ctx(ctx) /* @@ -326,9 +329,10 @@ func NewSegment(ctx context.Context, fields: typeutil.NewConcurrentMap[int64, *FieldInfo](), fieldIndexes: typeutil.NewConcurrentMap[int64, *IndexedFieldInfo](), - memSize: atomic.NewInt64(-1), - rowNum: atomic.NewInt64(-1), - insertCount: atomic.NewInt64(0), + memSize: atomic.NewInt64(-1), + rowNum: atomic.NewInt64(-1), + insertCount: atomic.NewInt64(0), + warmupDispatcher: warmupDispatcher, } if err := segment.initializeSegment(); err != nil { @@ -1507,7 +1511,7 @@ func (s *LocalSegment) WarmupChunkCache(ctx context.Context, fieldID int64, mmap return nil, nil }).Await() case "async": - GetWarmupPool().Submit(func() (any, error) { + task := func() (any, error) { // failed to wait for state update, return directly if !s.ptrLock.BlockUntilDataLoadedOrReleased() { return nil, nil @@ -1527,7 +1531,8 @@ func (s *LocalSegment) WarmupChunkCache(ctx context.Context, fieldID int64, mmap } log.Info("warming up chunk cache asynchronously done") return nil, nil - }) + } + s.warmupDispatcher.AddTask(task) default: // no warming up } @@ -1666,3 +1671,55 @@ func (s *LocalSegment) indexNeedLoadRawData(schema *schemapb.CollectionSchema, i } return !typeutil.IsVectorType(fieldSchema.DataType) && s.HasRawData(indexInfo.IndexInfo.FieldID), nil } + +type ( + WarmupTask = func() (any, error) + AsyncWarmupDispatcher struct { + mu sync.RWMutex + tasks []WarmupTask + notify chan struct{} + } +) + +func NewWarmupDispatcher() *AsyncWarmupDispatcher { + return &AsyncWarmupDispatcher{ + notify: make(chan struct{}, 1), + } +} + +func (d *AsyncWarmupDispatcher) AddTask(task func() (any, error)) { + d.mu.Lock() + d.tasks = append(d.tasks, task) + d.mu.Unlock() + select { + case d.notify <- struct{}{}: + default: + } +} + +func (d *AsyncWarmupDispatcher) Run(ctx context.Context) { + for { + select { + case <-ctx.Done(): + return + case <-d.notify: + d.mu.RLock() + tasks := make([]WarmupTask, len(d.tasks)) + copy(tasks, d.tasks) + d.mu.RUnlock() + + for _, task := range tasks { + select { + case <-ctx.Done(): + return + default: + GetWarmupPool().Submit(task) + } + } + + d.mu.Lock() + d.tasks = d.tasks[len(tasks):] + d.mu.Unlock() + } + } +} diff --git a/internal/querynodev2/segments/segment_loader.go b/internal/querynodev2/segments/segment_loader.go index 98abdc1b39e07..5d70c2b6419bd 100644 --- a/internal/querynodev2/segments/segment_loader.go +++ b/internal/querynodev2/segments/segment_loader.go @@ -569,12 +569,15 @@ func NewLoader( duf := NewDiskUsageFetcher(ctx) go duf.Start() + warmupDispatcher := NewWarmupDispatcher() + go warmupDispatcher.Run(ctx) loader := &segmentLoader{ manager: manager, cm: cm, loadingSegments: typeutil.NewConcurrentMap[int64, *loadResult](), committedResourceNotifier: syncutil.NewVersionedNotifier(), duf: duf, + warmupDispatcher: warmupDispatcher, } return loader @@ -617,7 +620,8 @@ type segmentLoader struct { committedResource LoadResource committedResourceNotifier *syncutil.VersionedNotifier - duf *diskUsageFetcher + duf *diskUsageFetcher + warmupDispatcher *AsyncWarmupDispatcher } var _ Loader = (*segmentLoader)(nil) @@ -700,6 +704,7 @@ func (loader *segmentLoader) Load(ctx context.Context, segmentType, version, loadInfo, + loader.warmupDispatcher, ) if err != nil { log.Warn("load segment failed when create new segment", diff --git a/internal/querynodev2/segments/segment_test.go b/internal/querynodev2/segments/segment_test.go index 4042ff56ce330..0f1c31b653331 100644 --- a/internal/querynodev2/segments/segment_test.go +++ b/internal/querynodev2/segments/segment_test.go @@ -5,8 +5,11 @@ import ( "fmt" "path/filepath" "testing" + "time" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/suite" + "go.uber.org/atomic" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" storage "github.com/milvus-io/milvus/internal/storage" @@ -90,6 +93,7 @@ func (suite *SegmentSuite) SetupTest() { }, }, }, + nil, ) suite.Require().NoError(err) @@ -121,6 +125,7 @@ func (suite *SegmentSuite) SetupTest() { InsertChannel: fmt.Sprintf("by-dev-rootcoord-dml_0_%dv0", suite.collectionID), Level: datapb.SegmentLevel_Legacy, }, + nil, ) suite.Require().NoError(err) @@ -221,3 +226,22 @@ func (suite *SegmentSuite) TestSegmentReleased() { func TestSegment(t *testing.T) { suite.Run(t, new(SegmentSuite)) } + +func TestWarmupDispatcher(t *testing.T) { + d := NewWarmupDispatcher() + ctx := context.Background() + go d.Run(ctx) + + completed := atomic.NewInt64(0) + taskCnt := 10000 + for i := 0; i < taskCnt; i++ { + d.AddTask(func() (any, error) { + completed.Inc() + return nil, nil + }) + } + + assert.Eventually(t, func() bool { + return completed.Load() == int64(taskCnt) + }, 10*time.Second, time.Second) +} diff --git a/internal/querynodev2/server_test.go b/internal/querynodev2/server_test.go index 89779bdc1c04f..0b43797746b50 100644 --- a/internal/querynodev2/server_test.go +++ b/internal/querynodev2/server_test.go @@ -236,6 +236,7 @@ func (suite *QueryNodeSuite) TestStop() { Level: datapb.SegmentLevel_Legacy, InsertChannel: fmt.Sprintf("by-dev-rootcoord-dml_0_%dv0", 1), }, + nil, ) suite.NoError(err) suite.node.manager.Segment.Put(context.Background(), segments.SegmentTypeSealed, segment)