From 139787371efeaa9d8f5bf4e33eaf40a522218eab Mon Sep 17 00:00:00 2001 From: aoiasd <45024769+aoiasd@users.noreply.github.com> Date: Thu, 19 Sep 2024 10:57:12 +0800 Subject: [PATCH] feat: support embedding bm25 sparse vector and flush bm25 stats log (#36036) relate: https://github.com/milvus-io/milvus/issues/35853 --------- Signed-off-by: aoiasd --- internal/datacoord/compaction_task_l0.go | 2 +- internal/datacoord/meta.go | 3 +- internal/datacoord/meta_test.go | 4 +- internal/datacoord/services.go | 2 +- internal/datanode/compaction/load_stats.go | 47 +++ .../datanode/compaction/mix_compactor_test.go | 2 +- internal/datanode/importv2/util.go | 4 +- internal/datanode/services_test.go | 40 +-- internal/flushcommon/metacache/actions.go | 6 + internal/flushcommon/metacache/bm25_stats.go | 74 +++++ .../flushcommon/metacache/bm25_stats_test.go | 61 ++++ internal/flushcommon/metacache/meta_cache.go | 29 +- .../flushcommon/metacache/meta_cache_test.go | 8 +- .../flushcommon/metacache/mock_meta_cache.go | 23 +- internal/flushcommon/metacache/segment.go | 9 +- .../flushcommon/metacache/segment_test.go | 6 +- .../flushcommon/pipeline/data_sync_service.go | 56 +++- .../pipeline/flow_graph_embedding_node.go | 164 +++++++++++ .../flow_graph_embedding_node_test.go | 175 +++++++++++ .../pipeline/flow_graph_message.go | 4 + .../pipeline/flow_graph_time_tick_node.go | 1 - .../pipeline/flow_graph_write_node.go | 30 +- internal/flushcommon/syncmgr/meta_writer.go | 10 +- .../flushcommon/syncmgr/meta_writer_test.go | 4 +- internal/flushcommon/syncmgr/options.go | 1 + internal/flushcommon/syncmgr/serializer.go | 7 + .../flushcommon/syncmgr/storage_serializer.go | 68 ++++- .../syncmgr/storage_serializer_test.go | 2 +- .../flushcommon/syncmgr/sync_manager_test.go | 4 +- internal/flushcommon/syncmgr/task.go | 65 ++++- internal/flushcommon/syncmgr/task_test.go | 6 +- .../writebuffer/bf_write_buffer.go | 15 +- .../writebuffer/bf_write_buffer_test.go | 95 +++--- 
.../flushcommon/writebuffer/insert_buffer.go | 21 +- .../writebuffer/insert_buffer_test.go | 12 +- .../writebuffer/l0_write_buffer.go | 21 +- .../writebuffer/l0_write_buffer_test.go | 37 +-- internal/flushcommon/writebuffer/manager.go | 6 +- .../flushcommon/writebuffer/mock_manager.go | 22 +- .../writebuffer/mock_write_buffer.go | 12 +- .../flushcommon/writebuffer/segment_buffer.go | 7 +- .../flushcommon/writebuffer/stats_buffer.go | 48 ++++ .../flushcommon/writebuffer/write_buffer.go | 272 +++++++++++------- .../writebuffer/write_buffer_test.go | 4 +- internal/metastore/kv/binlog/binlog.go | 11 + internal/metastore/kv/datacoord/constant.go | 1 + internal/metastore/kv/datacoord/kv_catalog.go | 21 +- .../metastore/kv/datacoord/kv_catalog_test.go | 29 +- internal/metastore/kv/datacoord/util.go | 55 ++-- internal/proto/data_coord.proto | 3 + internal/storage/binlog_writer.go | 2 + internal/storage/stats.go | 174 ++++++++++- internal/storage/utils.go | 10 +- internal/util/function/bm25_function.go | 159 ++++++++++ internal/util/function/function.go | 41 +++ internal/util/function/function_test.go | 82 ++++++ pkg/common/common.go | 3 + pkg/util/metautil/binlog.go | 5 + pkg/util/typeutil/schema.go | 12 + 59 files changed, 1718 insertions(+), 379 deletions(-) create mode 100644 internal/flushcommon/metacache/bm25_stats.go create mode 100644 internal/flushcommon/metacache/bm25_stats_test.go create mode 100644 internal/flushcommon/pipeline/flow_graph_embedding_node.go create mode 100644 internal/flushcommon/pipeline/flow_graph_embedding_node_test.go create mode 100644 internal/flushcommon/writebuffer/stats_buffer.go create mode 100644 internal/util/function/bm25_function.go create mode 100644 internal/util/function/function.go create mode 100644 internal/util/function/function_test.go diff --git a/internal/datacoord/compaction_task_l0.go b/internal/datacoord/compaction_task_l0.go index 7c3eae1697bab..2a4d12339607a 100644 --- a/internal/datacoord/compaction_task_l0.go +++ 
b/internal/datacoord/compaction_task_l0.go @@ -411,7 +411,7 @@ func (t *l0CompactionTask) saveSegmentMeta() error { result := t.result var operators []UpdateOperator for _, seg := range result.GetSegments() { - operators = append(operators, AddBinlogsOperator(seg.GetSegmentID(), nil, nil, seg.GetDeltalogs())) + operators = append(operators, AddBinlogsOperator(seg.GetSegmentID(), nil, nil, seg.GetDeltalogs(), nil)) } for _, segID := range t.InputSegments { diff --git a/internal/datacoord/meta.go b/internal/datacoord/meta.go index 822bc9b0497a5..e9b18e673327b 100644 --- a/internal/datacoord/meta.go +++ b/internal/datacoord/meta.go @@ -863,7 +863,7 @@ func RevertSegmentPartitionStatsVersionOperator(segmentID int64) UpdateOperator } // Add binlogs in segmentInfo -func AddBinlogsOperator(segmentID int64, binlogs, statslogs, deltalogs []*datapb.FieldBinlog) UpdateOperator { +func AddBinlogsOperator(segmentID int64, binlogs, statslogs, deltalogs, bm25logs []*datapb.FieldBinlog) UpdateOperator { return func(modPack *updateSegmentPack) bool { segment := modPack.Get(segmentID) if segment == nil { @@ -875,6 +875,7 @@ func AddBinlogsOperator(segmentID int64, binlogs, statslogs, deltalogs []*datapb segment.Binlogs = mergeFieldBinlogs(segment.GetBinlogs(), binlogs) segment.Statslogs = mergeFieldBinlogs(segment.GetStatslogs(), statslogs) segment.Deltalogs = mergeFieldBinlogs(segment.GetDeltalogs(), deltalogs) + segment.Bm25Statslogs = mergeFieldBinlogs(segment.GetBm25Statslogs(), bm25logs) modPack.increments[segmentID] = metastore.BinlogsIncrement{ Segment: segment.SegmentInfo, } diff --git a/internal/datacoord/meta_test.go b/internal/datacoord/meta_test.go index 6d5fb2c39e174..78bb0078ac9bb 100644 --- a/internal/datacoord/meta_test.go +++ b/internal/datacoord/meta_test.go @@ -675,6 +675,7 @@ func TestUpdateSegmentsInfo(t *testing.T) { []*datapb.FieldBinlog{getFieldBinlogIDsWithEntry(1, 10, 1)}, []*datapb.FieldBinlog{getFieldBinlogIDs(1, 1)}, []*datapb.FieldBinlog{{Binlogs: 
[]*datapb.Binlog{{EntriesNum: 1, TimestampFrom: 100, TimestampTo: 200, LogSize: 1000, LogPath: "", LogID: 2}}}}, + []*datapb.FieldBinlog{}, ), UpdateStartPosition([]*datapb.SegmentStartPosition{{SegmentID: 1, StartPosition: &msgpb.MsgPosition{MsgID: []byte{1, 2, 3}}}}), UpdateCheckPointOperator(1, []*datapb.CheckPoint{{SegmentID: 1, NumOfRows: 10}}), @@ -735,7 +736,7 @@ func TestUpdateSegmentsInfo(t *testing.T) { assert.NoError(t, err) err = meta.UpdateSegmentsInfo( - AddBinlogsOperator(1, nil, nil, nil), + AddBinlogsOperator(1, nil, nil, nil, nil), ) assert.NoError(t, err) @@ -816,6 +817,7 @@ func TestUpdateSegmentsInfo(t *testing.T) { []*datapb.FieldBinlog{getFieldBinlogIDs(1, 2)}, []*datapb.FieldBinlog{getFieldBinlogIDs(1, 2)}, []*datapb.FieldBinlog{{Binlogs: []*datapb.Binlog{{EntriesNum: 1, TimestampFrom: 100, TimestampTo: 200, LogSize: 1000, LogPath: "", LogID: 2}}}}, + []*datapb.FieldBinlog{}, ), UpdateStartPosition([]*datapb.SegmentStartPosition{{SegmentID: 1, StartPosition: &msgpb.MsgPosition{MsgID: []byte{1, 2, 3}}}}), UpdateCheckPointOperator(1, []*datapb.CheckPoint{{SegmentID: 1, NumOfRows: 10}}), diff --git a/internal/datacoord/services.go b/internal/datacoord/services.go index 2e1c94c2a4278..779b77c4e14e1 100644 --- a/internal/datacoord/services.go +++ b/internal/datacoord/services.go @@ -549,7 +549,7 @@ func (s *Server) SaveBinlogPaths(ctx context.Context, req *datapb.SaveBinlogPath // save binlogs, start positions and checkpoints operators = append(operators, - AddBinlogsOperator(req.GetSegmentID(), req.GetField2BinlogPaths(), req.GetField2StatslogPaths(), req.GetDeltalogs()), + AddBinlogsOperator(req.GetSegmentID(), req.GetField2BinlogPaths(), req.GetField2StatslogPaths(), req.GetDeltalogs(), req.GetField2Bm25LogPaths()), UpdateStartPosition(req.GetStartPositions()), UpdateCheckPointOperator(req.GetSegmentID(), req.GetCheckPoints()), ) diff --git a/internal/datanode/compaction/load_stats.go b/internal/datanode/compaction/load_stats.go index 
9961ba2c179b8..f96d2861dda4e 100644 --- a/internal/datanode/compaction/load_stats.go +++ b/internal/datanode/compaction/load_stats.go @@ -21,6 +21,7 @@ import ( "path" "time" + "github.com/samber/lo" "go.uber.org/zap" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" @@ -30,6 +31,52 @@ import ( "github.com/milvus-io/milvus/pkg/util/typeutil" ) +func LoadBM25Stats(ctx context.Context, chunkManager storage.ChunkManager, schema *schemapb.CollectionSchema, segmentID int64, statsBinlogs []*datapb.FieldBinlog) (map[int64]*storage.BM25Stats, error) { + startTs := time.Now() + log := log.With(zap.Int64("segmentID", segmentID)) + log.Info("begin to reload history BM25 stats", zap.Int("statsBinLogsLen", len(statsBinlogs))) + + fieldList, fieldOffset := make([]int64, len(statsBinlogs)), make([]int, len(statsBinlogs)) + logpaths := make([]string, 0) + for i, binlog := range statsBinlogs { + fieldList[i] = binlog.FieldID + fieldOffset[i] = len(binlog.Binlogs) + logpaths = append(logpaths, lo.Map(binlog.Binlogs, func(log *datapb.Binlog, _ int) string { return log.GetLogPath() })...) 
+ } + + if len(logpaths) == 0 { + log.Warn("no BM25 stats to load") + return nil, nil + } + + values, err := chunkManager.MultiRead(ctx, logpaths) + if err != nil { + log.Warn("failed to load BM25 stats files", zap.Error(err)) + return nil, err + } + + result := make(map[int64]*storage.BM25Stats) + cnt := 0 + for i, fieldID := range fieldList { + for offset := 0; offset < fieldOffset[i]; offset++ { + stats, ok := result[fieldID] + if !ok { + stats = storage.NewBM25Stats() + result[fieldID] = stats + } + err := stats.Deserialize(values[cnt+offset]) + if err != nil { + return nil, err + } + } + cnt += fieldOffset[i] + } + + // TODO ADD METRIC FOR LOAD BM25 TIME + log.Info("Successfully load BM25 stats", zap.Any("time", time.Since(startTs))) + return result, nil +} + func LoadStats(ctx context.Context, chunkManager storage.ChunkManager, schema *schemapb.CollectionSchema, segmentID int64, statsBinlogs []*datapb.FieldBinlog) ([]*storage.PkStatistics, error) { startTs := time.Now() log := log.With(zap.Int64("segmentID", segmentID)) diff --git a/internal/datanode/compaction/mix_compactor_test.go b/internal/datanode/compaction/mix_compactor_test.go index 36a0a1ceab2a7..3405bfcba3787 100644 --- a/internal/datanode/compaction/mix_compactor_test.go +++ b/internal/datanode/compaction/mix_compactor_test.go @@ -189,7 +189,7 @@ func (s *MixCompactionTaskSuite) TestCompactTwoToOne() { PartitionID: PartitionID, ID: 99999, NumOfRows: 0, - }, pkoracle.NewBloomFilterSet()) + }, pkoracle.NewBloomFilterSet(), nil) s.plan.SegmentBinlogs = append(s.plan.SegmentBinlogs, &datapb.CompactionSegmentBinlogs{ SegmentID: seg.SegmentID(), diff --git a/internal/datanode/importv2/util.go b/internal/datanode/importv2/util.go index 829283f4f0ad8..7dd3c3eda1004 100644 --- a/internal/datanode/importv2/util.go +++ b/internal/datanode/importv2/util.go @@ -63,7 +63,7 @@ func NewSyncTask(ctx context.Context, }, func(info *datapb.SegmentInfo) pkoracle.PkStat { bfs := pkoracle.NewBloomFilterSet() return bfs - 
}) + }, metacache.NewBM25StatsFactory) } var serializer syncmgr.Serializer @@ -248,7 +248,7 @@ func NewMetaCache(req *datapb.ImportRequest) map[string]metacache.MetaCache { } metaCache := metacache.NewMetaCache(info, func(segment *datapb.SegmentInfo) pkoracle.PkStat { return pkoracle.NewBloomFilterSet() - }) + }, metacache.NoneBm25StatsFactory) metaCaches[channel] = metaCache } return metaCaches diff --git a/internal/datanode/services_test.go b/internal/datanode/services_test.go index c776b675999a5..13f65fb99d80b 100644 --- a/internal/datanode/services_test.go +++ b/internal/datanode/services_test.go @@ -373,7 +373,7 @@ func (s *DataNodeServicesSuite) TestFlushSegments() { PartitionID: 2, State: commonpb.SegmentState_Growing, StartPosition: &msgpb.MsgPosition{}, - }, func(_ *datapb.SegmentInfo) pkoracle.PkStat { return pkoracle.NewBloomFilterSet() }) + }, func(_ *datapb.SegmentInfo) pkoracle.PkStat { return pkoracle.NewBloomFilterSet() }, metacache.NoneBm25StatsFactory) s.Run("service_not_ready", func() { ctx, cancel := context.WithCancel(context.Background()) @@ -637,7 +637,7 @@ func (s *DataNodeServicesSuite) TestSyncSegments() { Vchan: &datapb.VchannelInfo{}, }, func(*datapb.SegmentInfo) pkoracle.PkStat { return pkoracle.NewBloomFilterSet() - }) + }, metacache.NoneBm25StatsFactory) cache.AddSegment(&datapb.SegmentInfo{ ID: 100, CollectionID: 1, @@ -648,7 +648,7 @@ func (s *DataNodeServicesSuite) TestSyncSegments() { Level: datapb.SegmentLevel_L0, }, func(*datapb.SegmentInfo) pkoracle.PkStat { return pkoracle.NewBloomFilterSet() - }) + }, metacache.NoneBm25StatsFactory) cache.AddSegment(&datapb.SegmentInfo{ ID: 101, CollectionID: 1, @@ -659,7 +659,7 @@ func (s *DataNodeServicesSuite) TestSyncSegments() { Level: datapb.SegmentLevel_L1, }, func(*datapb.SegmentInfo) pkoracle.PkStat { return pkoracle.NewBloomFilterSet() - }) + }, metacache.NoneBm25StatsFactory) cache.AddSegment(&datapb.SegmentInfo{ ID: 102, CollectionID: 1, @@ -670,7 +670,7 @@ func (s 
*DataNodeServicesSuite) TestSyncSegments() { Level: datapb.SegmentLevel_L0, }, func(*datapb.SegmentInfo) pkoracle.PkStat { return pkoracle.NewBloomFilterSet() - }) + }, metacache.NoneBm25StatsFactory) cache.AddSegment(&datapb.SegmentInfo{ ID: 103, CollectionID: 1, @@ -681,7 +681,7 @@ func (s *DataNodeServicesSuite) TestSyncSegments() { Level: datapb.SegmentLevel_L0, }, func(*datapb.SegmentInfo) pkoracle.PkStat { return pkoracle.NewBloomFilterSet() - }) + }, metacache.NoneBm25StatsFactory) mockFlowgraphManager := pipeline.NewMockFlowgraphManager(s.T()) mockFlowgraphManager.EXPECT().GetFlowgraphService(mock.Anything). Return(pipeline.NewDataSyncServiceWithMetaCache(cache), true) @@ -759,7 +759,7 @@ func (s *DataNodeServicesSuite) TestSyncSegments() { Vchan: &datapb.VchannelInfo{}, }, func(*datapb.SegmentInfo) pkoracle.PkStat { return pkoracle.NewBloomFilterSet() - }) + }, metacache.NoneBm25StatsFactory) cache.AddSegment(&datapb.SegmentInfo{ ID: 100, CollectionID: 1, @@ -770,7 +770,7 @@ func (s *DataNodeServicesSuite) TestSyncSegments() { Level: datapb.SegmentLevel_L1, }, func(*datapb.SegmentInfo) pkoracle.PkStat { return pkoracle.NewBloomFilterSet() - }) + }, metacache.NoneBm25StatsFactory) cache.AddSegment(&datapb.SegmentInfo{ ID: 101, CollectionID: 1, @@ -781,7 +781,7 @@ func (s *DataNodeServicesSuite) TestSyncSegments() { Level: datapb.SegmentLevel_L1, }, func(*datapb.SegmentInfo) pkoracle.PkStat { return pkoracle.NewBloomFilterSet() - }) + }, metacache.NoneBm25StatsFactory) mockFlowgraphManager := pipeline.NewMockFlowgraphManager(s.T()) mockFlowgraphManager.EXPECT().GetFlowgraphService(mock.Anything). 
Return(pipeline.NewDataSyncServiceWithMetaCache(cache), true) @@ -847,7 +847,7 @@ func (s *DataNodeServicesSuite) TestSyncSegments() { Vchan: &datapb.VchannelInfo{}, }, func(*datapb.SegmentInfo) pkoracle.PkStat { return pkoracle.NewBloomFilterSet() - }) + }, metacache.NoneBm25StatsFactory) cache.AddSegment(&datapb.SegmentInfo{ ID: 100, CollectionID: 1, @@ -858,7 +858,7 @@ func (s *DataNodeServicesSuite) TestSyncSegments() { Level: datapb.SegmentLevel_L1, }, func(*datapb.SegmentInfo) pkoracle.PkStat { return pkoracle.NewBloomFilterSet() - }) + }, metacache.NoneBm25StatsFactory) cache.AddSegment(&datapb.SegmentInfo{ ID: 101, CollectionID: 1, @@ -869,7 +869,7 @@ func (s *DataNodeServicesSuite) TestSyncSegments() { Level: datapb.SegmentLevel_L1, }, func(*datapb.SegmentInfo) pkoracle.PkStat { return pkoracle.NewBloomFilterSet() - }) + }, metacache.NoneBm25StatsFactory) mockFlowgraphManager := pipeline.NewMockFlowgraphManager(s.T()) mockFlowgraphManager.EXPECT().GetFlowgraphService(mock.Anything). 
Return(pipeline.NewDataSyncServiceWithMetaCache(cache), true) @@ -935,7 +935,7 @@ func (s *DataNodeServicesSuite) TestSyncSegments() { Vchan: &datapb.VchannelInfo{}, }, func(*datapb.SegmentInfo) pkoracle.PkStat { return pkoracle.NewBloomFilterSet() - }) + }, metacache.NoneBm25StatsFactory) cache.AddSegment(&datapb.SegmentInfo{ ID: 100, CollectionID: 1, @@ -946,7 +946,7 @@ func (s *DataNodeServicesSuite) TestSyncSegments() { Level: datapb.SegmentLevel_L1, }, func(*datapb.SegmentInfo) pkoracle.PkStat { return pkoracle.NewBloomFilterSet() - }) + }, metacache.NoneBm25StatsFactory) cache.AddSegment(&datapb.SegmentInfo{ ID: 101, CollectionID: 1, @@ -957,7 +957,7 @@ func (s *DataNodeServicesSuite) TestSyncSegments() { Level: datapb.SegmentLevel_L1, }, func(*datapb.SegmentInfo) pkoracle.PkStat { return pkoracle.NewBloomFilterSet() - }) + }, metacache.NoneBm25StatsFactory) cache.AddSegment(&datapb.SegmentInfo{ ID: 102, CollectionID: 1, @@ -968,7 +968,7 @@ func (s *DataNodeServicesSuite) TestSyncSegments() { Level: datapb.SegmentLevel_L1, }, func(*datapb.SegmentInfo) pkoracle.PkStat { return pkoracle.NewBloomFilterSet() - }) + }, metacache.NoneBm25StatsFactory) mockFlowgraphManager := pipeline.NewMockFlowgraphManager(s.T()) mockFlowgraphManager.EXPECT().GetFlowgraphService(mock.Anything). 
Return(pipeline.NewDataSyncServiceWithMetaCache(cache), true) @@ -1028,7 +1028,7 @@ func (s *DataNodeServicesSuite) TestSyncSegments() { Vchan: &datapb.VchannelInfo{}, }, func(*datapb.SegmentInfo) pkoracle.PkStat { return pkoracle.NewBloomFilterSet() - }) + }, metacache.NoneBm25StatsFactory) cache.AddSegment(&datapb.SegmentInfo{ ID: 100, CollectionID: 1, @@ -1039,7 +1039,7 @@ func (s *DataNodeServicesSuite) TestSyncSegments() { Level: datapb.SegmentLevel_L0, }, func(*datapb.SegmentInfo) pkoracle.PkStat { return pkoracle.NewBloomFilterSet() - }) + }, metacache.NoneBm25StatsFactory) cache.AddSegment(&datapb.SegmentInfo{ ID: 101, CollectionID: 1, @@ -1050,7 +1050,7 @@ func (s *DataNodeServicesSuite) TestSyncSegments() { Level: datapb.SegmentLevel_L1, }, func(*datapb.SegmentInfo) pkoracle.PkStat { return pkoracle.NewBloomFilterSet() - }) + }, metacache.NoneBm25StatsFactory) mockFlowgraphManager := pipeline.NewMockFlowgraphManager(s.T()) mockFlowgraphManager.EXPECT().GetFlowgraphService(mock.Anything). Return(pipeline.NewDataSyncServiceWithMetaCache(cache), true) @@ -1110,7 +1110,7 @@ func (s *DataNodeServicesSuite) TestSyncSegments() { Vchan: &datapb.VchannelInfo{}, }, func(*datapb.SegmentInfo) pkoracle.PkStat { return pkoracle.NewBloomFilterSet() - }) + }, metacache.NoneBm25StatsFactory) mockFlowgraphManager := pipeline.NewMockFlowgraphManager(s.T()) mockFlowgraphManager.EXPECT().GetFlowgraphService(mock.Anything). 
Return(pipeline.NewDataSyncServiceWithMetaCache(cache), true) diff --git a/internal/flushcommon/metacache/actions.go b/internal/flushcommon/metacache/actions.go index 60eb2ee8699f7..51b780100c690 100644 --- a/internal/flushcommon/metacache/actions.go +++ b/internal/flushcommon/metacache/actions.go @@ -159,6 +159,12 @@ func RollStats(newStats ...*storage.PrimaryKeyStats) SegmentAction { } } +func MergeBm25Stats(newStats map[int64]*storage.BM25Stats) SegmentAction { + return func(info *SegmentInfo) { + info.bm25stats.Merge(newStats) + } +} + func StartSyncing(batchSize int64) SegmentAction { return func(info *SegmentInfo) { info.syncingRows += batchSize diff --git a/internal/flushcommon/metacache/bm25_stats.go b/internal/flushcommon/metacache/bm25_stats.go new file mode 100644 index 0000000000000..97c36db824c33 --- /dev/null +++ b/internal/flushcommon/metacache/bm25_stats.go @@ -0,0 +1,74 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package metacache + +import ( + "sync" + + "github.com/pingcap/log" + "go.uber.org/zap" + + "github.com/milvus-io/milvus/internal/storage" +) + +type SegmentBM25Stats struct { + mut sync.RWMutex + stats map[int64]*storage.BM25Stats +} + +func (s *SegmentBM25Stats) Merge(stats map[int64]*storage.BM25Stats) { + s.mut.Lock() + defer s.mut.Unlock() + + for fieldID, current := range stats { + if history, ok := s.stats[fieldID]; !ok { + s.stats[fieldID] = current.Clone() + } else { + history.Merge(current) + } + } +} + +func (s *SegmentBM25Stats) Serialize() (map[int64][]byte, map[int64]int64, error) { + s.mut.Lock() + defer s.mut.Unlock() + + result := make(map[int64][]byte) + numRow := make(map[int64]int64) + for fieldID, stats := range s.stats { + bytes, err := stats.Serialize() + if err != nil { + log.Warn("serialize history bm25 stats failed", zap.Int64("fieldID", fieldID)) + return nil, nil, err + } + result[fieldID] = bytes + numRow[fieldID] = stats.NumRow() + } + return result, numRow, nil +} + +func NewEmptySegmentBM25Stats() *SegmentBM25Stats { + return &SegmentBM25Stats{ + stats: make(map[int64]*storage.BM25Stats), + } +} + +func NewSegmentBM25Stats(stats map[int64]*storage.BM25Stats) *SegmentBM25Stats { + return &SegmentBM25Stats{ + stats: stats, + } +} diff --git a/internal/flushcommon/metacache/bm25_stats_test.go b/internal/flushcommon/metacache/bm25_stats_test.go new file mode 100644 index 0000000000000..bf4fb790530d2 --- /dev/null +++ b/internal/flushcommon/metacache/bm25_stats_test.go @@ -0,0 +1,61 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package metacache + +import ( + "github.com/stretchr/testify/suite" + + "github.com/milvus-io/milvus/internal/storage" + "github.com/milvus-io/milvus/pkg/util/paramtable" +) + +type BM25StatsSetSuite struct { + suite.Suite + stats *SegmentBM25Stats + bm25FieldIDs []int64 +} + +func (s *BM25StatsSetSuite) SetupTest() { + paramtable.Init() + s.stats = NewEmptySegmentBM25Stats() + s.bm25FieldIDs = []int64{101, 102} +} + +func (suite *BM25StatsSetSuite) TestMergeAndSeralize() { + statsA := map[int64]*storage.BM25Stats{ + 101: {}, + } + statsA[101].Append(map[uint32]float32{1: 1, 2: 2}) + + statsB := map[int64]*storage.BM25Stats{ + 101: {}, + } + statsB[101].Append(map[uint32]float32{1: 1, 2: 2}) + + suite.stats.Merge(statsA) + suite.stats.Merge(statsB) + + blobs, numrows, err := suite.stats.Serialize() + suite.NoError(err) + suite.Equal(numrows[101], int64(2)) + + storageStats := storage.NewBM25Stats() + err = storageStats.Deserialize(blobs[101]) + suite.NoError(err) + + suite.Equal(storageStats.NumRow(), int64(2)) +} diff --git a/internal/flushcommon/metacache/meta_cache.go b/internal/flushcommon/metacache/meta_cache.go index fc8338750881b..cb75b3ddf1773 100644 --- a/internal/flushcommon/metacache/meta_cache.go +++ b/internal/flushcommon/metacache/meta_cache.go @@ -37,7 +37,7 @@ type MetaCache interface { // Schema returns collection schema. Schema() *schemapb.CollectionSchema // AddSegment adds a segment from segment info. 
- AddSegment(segInfo *datapb.SegmentInfo, factory PkStatsFactory, actions ...SegmentAction) + AddSegment(segInfo *datapb.SegmentInfo, pkFactory PkStatsFactory, bmFactory BM25StatsFactory, actions ...SegmentAction) // UpdateSegments applies action to segment(s) satisfy the provided filters. UpdateSegments(action SegmentAction, filters ...SegmentFilter) // RemoveSegments removes segments matches the provided filter. @@ -58,7 +58,18 @@ type MetaCache interface { var _ MetaCache = (*metaCacheImpl)(nil) -type PkStatsFactory func(vchannel *datapb.SegmentInfo) pkoracle.PkStat +type ( + PkStatsFactory func(vchannel *datapb.SegmentInfo) pkoracle.PkStat + BM25StatsFactory func(vchannel *datapb.SegmentInfo) *SegmentBM25Stats +) + +func NoneBm25StatsFactory(vchannel *datapb.SegmentInfo) *SegmentBM25Stats { + return nil +} + +func NewBM25StatsFactory(vchannel *datapb.SegmentInfo) *SegmentBM25Stats { + return NewEmptySegmentBM25Stats() +} type metaCacheImpl struct { collectionID int64 @@ -70,7 +81,7 @@ type metaCacheImpl struct { stateSegments map[commonpb.SegmentState]map[int64]*SegmentInfo } -func NewMetaCache(info *datapb.ChannelWatchInfo, factory PkStatsFactory) MetaCache { +func NewMetaCache(info *datapb.ChannelWatchInfo, pkFactory PkStatsFactory, bmFactor BM25StatsFactory) MetaCache { vchannel := info.GetVchan() cache := &metaCacheImpl{ collectionID: vchannel.GetCollectionID(), @@ -91,19 +102,19 @@ func NewMetaCache(info *datapb.ChannelWatchInfo, factory PkStatsFactory) MetaCac cache.stateSegments[state] = make(map[int64]*SegmentInfo) } - cache.init(vchannel, factory) + cache.init(vchannel, pkFactory, bmFactor) return cache } -func (c *metaCacheImpl) init(vchannel *datapb.VchannelInfo, factory PkStatsFactory) { +func (c *metaCacheImpl) init(vchannel *datapb.VchannelInfo, pkFactory PkStatsFactory, bmFactor BM25StatsFactory) { for _, seg := range vchannel.FlushedSegments { - c.addSegment(NewSegmentInfo(seg, factory(seg))) + c.addSegment(NewSegmentInfo(seg, pkFactory(seg), 
bmFactor(seg))) } for _, seg := range vchannel.UnflushedSegments { // segment state could be sealed for growing segment if flush request processed before datanode watch seg.State = commonpb.SegmentState_Growing - c.addSegment(NewSegmentInfo(seg, factory(seg))) + c.addSegment(NewSegmentInfo(seg, pkFactory(seg), bmFactor(seg))) } } @@ -118,8 +129,8 @@ func (c *metaCacheImpl) Schema() *schemapb.CollectionSchema { } // AddSegment adds a segment from segment info. -func (c *metaCacheImpl) AddSegment(segInfo *datapb.SegmentInfo, factory PkStatsFactory, actions ...SegmentAction) { - segment := NewSegmentInfo(segInfo, factory(segInfo)) +func (c *metaCacheImpl) AddSegment(segInfo *datapb.SegmentInfo, pkFactory PkStatsFactory, bmFactory BM25StatsFactory, actions ...SegmentAction) { + segment := NewSegmentInfo(segInfo, pkFactory(segInfo), bmFactory(segInfo)) for _, action := range actions { action(segment) diff --git a/internal/flushcommon/metacache/meta_cache_test.go b/internal/flushcommon/metacache/meta_cache_test.go index 06b933b7d5e15..7f539b7e76a81 100644 --- a/internal/flushcommon/metacache/meta_cache_test.go +++ b/internal/flushcommon/metacache/meta_cache_test.go @@ -96,7 +96,7 @@ func (s *MetaCacheSuite) SetupTest() { FlushedSegments: flushSegmentInfos, UnflushedSegments: growingSegmentInfos, }, - }, s.bfsFactory) + }, s.bfsFactory, NoneBm25StatsFactory) } func (s *MetaCacheSuite) TestMetaInfo() { @@ -113,7 +113,7 @@ func (s *MetaCacheSuite) TestAddSegment() { } s.cache.AddSegment(info, func(info *datapb.SegmentInfo) pkoracle.PkStat { return pkoracle.NewBloomFilterSet() - }, UpdateState(commonpb.SegmentState_Flushed)) + }, NoneBm25StatsFactory, UpdateState(commonpb.SegmentState_Flushed)) } segments := s.cache.GetSegmentsBy(WithSegmentIDs(testSegs...)) @@ -262,7 +262,7 @@ func BenchmarkGetSegmentsBy(b *testing.B) { }, }, func(*datapb.SegmentInfo) pkoracle.PkStat { return pkoracle.NewBloomFilterSet() - }) + }, NoneBm25StatsFactory) b.ResetTimer() for i := 0; i < b.N; 
i++ { filter := WithSegmentIDs(0) @@ -294,7 +294,7 @@ func BenchmarkGetSegmentsByWithoutIDs(b *testing.B) { }, }, func(*datapb.SegmentInfo) pkoracle.PkStat { return pkoracle.NewBloomFilterSet() - }) + }, NoneBm25StatsFactory) b.ResetTimer() for i := 0; i < b.N; i++ { // use old func filter diff --git a/internal/flushcommon/metacache/mock_meta_cache.go b/internal/flushcommon/metacache/mock_meta_cache.go index 8a25577759c94..e88b6ae8c40a2 100644 --- a/internal/flushcommon/metacache/mock_meta_cache.go +++ b/internal/flushcommon/metacache/mock_meta_cache.go @@ -26,14 +26,14 @@ func (_m *MockMetaCache) EXPECT() *MockMetaCache_Expecter { return &MockMetaCache_Expecter{mock: &_m.Mock} } -// AddSegment provides a mock function with given fields: segInfo, factory, actions -func (_m *MockMetaCache) AddSegment(segInfo *datapb.SegmentInfo, factory PkStatsFactory, actions ...SegmentAction) { +// AddSegment provides a mock function with given fields: segInfo, pkFactory, bmFactory, actions +func (_m *MockMetaCache) AddSegment(segInfo *datapb.SegmentInfo, pkFactory PkStatsFactory, bmFactory BM25StatsFactory, actions ...SegmentAction) { _va := make([]interface{}, len(actions)) for _i := range actions { _va[_i] = actions[_i] } var _ca []interface{} - _ca = append(_ca, segInfo, factory) + _ca = append(_ca, segInfo, pkFactory, bmFactory) _ca = append(_ca, _va...) _m.Called(_ca...) 
} @@ -45,22 +45,23 @@ type MockMetaCache_AddSegment_Call struct { // AddSegment is a helper method to define mock.On call // - segInfo *datapb.SegmentInfo -// - factory PkStatsFactory +// - pkFactory PkStatsFactory +// - bmFactory BM25StatsFactory // - actions ...SegmentAction -func (_e *MockMetaCache_Expecter) AddSegment(segInfo interface{}, factory interface{}, actions ...interface{}) *MockMetaCache_AddSegment_Call { +func (_e *MockMetaCache_Expecter) AddSegment(segInfo interface{}, pkFactory interface{}, bmFactory interface{}, actions ...interface{}) *MockMetaCache_AddSegment_Call { return &MockMetaCache_AddSegment_Call{Call: _e.mock.On("AddSegment", - append([]interface{}{segInfo, factory}, actions...)...)} + append([]interface{}{segInfo, pkFactory, bmFactory}, actions...)...)} } -func (_c *MockMetaCache_AddSegment_Call) Run(run func(segInfo *datapb.SegmentInfo, factory PkStatsFactory, actions ...SegmentAction)) *MockMetaCache_AddSegment_Call { +func (_c *MockMetaCache_AddSegment_Call) Run(run func(segInfo *datapb.SegmentInfo, pkFactory PkStatsFactory, bmFactory BM25StatsFactory, actions ...SegmentAction)) *MockMetaCache_AddSegment_Call { _c.Call.Run(func(args mock.Arguments) { - variadicArgs := make([]SegmentAction, len(args)-2) - for i, a := range args[2:] { + variadicArgs := make([]SegmentAction, len(args)-3) + for i, a := range args[3:] { if a != nil { variadicArgs[i] = a.(SegmentAction) } } - run(args[0].(*datapb.SegmentInfo), args[1].(PkStatsFactory), variadicArgs...) + run(args[0].(*datapb.SegmentInfo), args[1].(PkStatsFactory), args[2].(BM25StatsFactory), variadicArgs...) 
}) return _c } @@ -70,7 +71,7 @@ func (_c *MockMetaCache_AddSegment_Call) Return() *MockMetaCache_AddSegment_Call return _c } -func (_c *MockMetaCache_AddSegment_Call) RunAndReturn(run func(*datapb.SegmentInfo, PkStatsFactory, ...SegmentAction)) *MockMetaCache_AddSegment_Call { +func (_c *MockMetaCache_AddSegment_Call) RunAndReturn(run func(*datapb.SegmentInfo, PkStatsFactory, BM25StatsFactory, ...SegmentAction)) *MockMetaCache_AddSegment_Call { _c.Call.Return(run) return _c } diff --git a/internal/flushcommon/metacache/segment.go b/internal/flushcommon/metacache/segment.go index cb0ac3a70a48e..8c4906ff7201e 100644 --- a/internal/flushcommon/metacache/segment.go +++ b/internal/flushcommon/metacache/segment.go @@ -35,6 +35,7 @@ type SegmentInfo struct { bufferRows int64 syncingRows int64 bfs pkoracle.PkStat + bm25stats *SegmentBM25Stats level datapb.SegmentLevel syncingTasks int32 } @@ -78,6 +79,10 @@ func (s *SegmentInfo) GetBloomFilterSet() pkoracle.PkStat { return s.bfs } +func (s *SegmentInfo) GetBM25Stats() *SegmentBM25Stats { + return s.bm25stats +} + func (s *SegmentInfo) Level() datapb.SegmentLevel { return s.level } @@ -96,10 +101,11 @@ func (s *SegmentInfo) Clone() *SegmentInfo { bfs: s.bfs, level: s.level, syncingTasks: s.syncingTasks, + bm25stats: s.bm25stats, } } -func NewSegmentInfo(info *datapb.SegmentInfo, bfs pkoracle.PkStat) *SegmentInfo { +func NewSegmentInfo(info *datapb.SegmentInfo, bfs pkoracle.PkStat, bm25Stats *SegmentBM25Stats) *SegmentInfo { level := info.GetLevel() if level == datapb.SegmentLevel_Legacy { level = datapb.SegmentLevel_L1 @@ -114,5 +120,6 @@ func NewSegmentInfo(info *datapb.SegmentInfo, bfs pkoracle.PkStat) *SegmentInfo startPosRecorded: true, level: level, bfs: bfs, + bm25stats: bm25Stats, } } diff --git a/internal/flushcommon/metacache/segment_test.go b/internal/flushcommon/metacache/segment_test.go index 5e067fdd72047..cc01ae0257ea7 100644 --- a/internal/flushcommon/metacache/segment_test.go +++ 
b/internal/flushcommon/metacache/segment_test.go @@ -33,7 +33,8 @@ type SegmentSuite struct { func (s *SegmentSuite) TestBasic() { bfs := pkoracle.NewBloomFilterSet() - segment := NewSegmentInfo(s.info, bfs) + stats := NewEmptySegmentBM25Stats() + segment := NewSegmentInfo(s.info, bfs, stats) s.Equal(s.info.GetID(), segment.SegmentID()) s.Equal(s.info.GetPartitionID(), segment.PartitionID()) s.Equal(s.info.GetNumOfRows(), segment.NumOfRows()) @@ -45,7 +46,8 @@ func (s *SegmentSuite) TestBasic() { func (s *SegmentSuite) TestClone() { bfs := pkoracle.NewBloomFilterSet() - segment := NewSegmentInfo(s.info, bfs) + stats := NewEmptySegmentBM25Stats() + segment := NewSegmentInfo(s.info, bfs, stats) cloned := segment.Clone() s.Equal(segment.SegmentID(), cloned.SegmentID()) s.Equal(segment.PartitionID(), cloned.PartitionID()) diff --git a/internal/flushcommon/pipeline/data_sync_service.go b/internal/flushcommon/pipeline/data_sync_service.go index d97bd4e6fc0c5..2a38ac2745cb3 100644 --- a/internal/flushcommon/pipeline/data_sync_service.go +++ b/internal/flushcommon/pipeline/data_sync_service.go @@ -142,6 +142,7 @@ func initMetaCache(initCtx context.Context, chunkManager storage.ChunkManager, i futures := make([]*conc.Future[any], 0, len(unflushed)+len(flushed)) // segmentPks := typeutil.NewConcurrentMap[int64, []*storage.PkStatistics]() segmentPks := typeutil.NewConcurrentMap[int64, pkoracle.PkStat]() + segmentBm25 := typeutil.NewConcurrentMap[int64, map[int64]*storage.BM25Stats]() loadSegmentStats := func(segType string, segments []*datapb.SegmentInfo) { for _, item := range segments { @@ -164,6 +165,14 @@ func initMetaCache(initCtx context.Context, chunkManager storage.ChunkManager, i tickler.Inc() } + if segType == "growing" && len(segment.GetBm25Statslogs()) > 0 { + bm25stats, err := compaction.LoadBM25Stats(initCtx, chunkManager, info.GetSchema(), segment.GetID(), segment.GetBm25Statslogs()) + if err != nil { + return nil, err + } + segmentBm25.Insert(segment.GetID(), 
bm25stats) + } + return struct{}{}, nil }) @@ -220,10 +229,21 @@ func initMetaCache(initCtx context.Context, chunkManager storage.ChunkManager, i } // return channel, nil - metacache := metacache.NewMetaCache(info, func(segment *datapb.SegmentInfo) pkoracle.PkStat { + pkStatsFactory := func(segment *datapb.SegmentInfo) pkoracle.PkStat { pkStat, _ := segmentPks.Get(segment.GetID()) return pkStat - }) + } + + bm25StatsFactor := func(segment *datapb.SegmentInfo) *metacache.SegmentBM25Stats { + stats, ok := segmentBm25.Get(segment.GetID()) + if !ok { + return nil + } + segmentStats := metacache.NewSegmentBM25Stats(stats) + return segmentStats + } + // return channel, nil + metacache := metacache.NewMetaCache(info, pkStatsFactory, bm25StatsFactor) return metacache, nil } @@ -286,15 +306,15 @@ func getServiceWithChannel(initCtx context.Context, params *util.PipelineParams, // init flowgraph fg := flowgraph.NewTimeTickedFlowGraph(params.Ctx) + nodeList := []flowgraph.Node{} - var dmStreamNode *flowgraph.InputNode - dmStreamNode, err = newDmInputNode(initCtx, params.DispClient, info.GetVchan().GetSeekPosition(), config, input) + dmStreamNode, err := newDmInputNode(initCtx, params.DispClient, info.GetVchan().GetSeekPosition(), config, input) if err != nil { return nil, err } + nodeList = append(nodeList, dmStreamNode) - var ddNode *ddNode - ddNode, err = newDDNode( + ddNode, err := newDDNode( params.Ctx, collectionID, channelName, @@ -307,15 +327,29 @@ func getServiceWithChannel(initCtx context.Context, params *util.PipelineParams, if err != nil { return nil, err } + nodeList = append(nodeList, ddNode) + + if len(info.GetSchema().GetFunctions()) > 0 { + emNode, err := newEmbeddingNode(channelName, info.GetSchema()) + if err != nil { + return nil, err + } + nodeList = append(nodeList, emNode) + } + + writeNode, err := newWriteNode(params.Ctx, params.WriteBufferManager, ds.timetickSender, config) + if err != nil { + return nil, err + } + nodeList = append(nodeList, writeNode) 
- writeNode := newWriteNode(params.Ctx, params.WriteBufferManager, ds.timetickSender, config) - var ttNode *ttNode - ttNode, err = newTTNode(config, params.WriteBufferManager, params.CheckpointUpdater) + ttNode, err := newTTNode(config, params.WriteBufferManager, params.CheckpointUpdater) if err != nil { return nil, err } + nodeList = append(nodeList, ttNode) - if err := fg.AssembleNodes(dmStreamNode, ddNode, writeNode, ttNode); err != nil { + if err := fg.AssembleNodes(nodeList...); err != nil { return nil, err } ds.fg = fg @@ -371,7 +405,7 @@ func NewStreamingNodeDataSyncService(initCtx context.Context, pipelineParams *ut info.Vchan.UnflushedSegments = unflushedSegmentInfos metaCache := metacache.NewMetaCache(info, func(segment *datapb.SegmentInfo) pkoracle.PkStat { return pkoracle.NewBloomFilterSet() - }) + }, metacache.NoneBm25StatsFactory) return getServiceWithChannel(initCtx, pipelineParams, info, metaCache, unflushedSegmentInfos, flushedSegmentInfos, input) } diff --git a/internal/flushcommon/pipeline/flow_graph_embedding_node.go b/internal/flushcommon/pipeline/flow_graph_embedding_node.go new file mode 100644 index 0000000000000..80de77d7bedd5 --- /dev/null +++ b/internal/flushcommon/pipeline/flow_graph_embedding_node.go @@ -0,0 +1,164 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +package pipeline + +import ( + "fmt" + + "go.uber.org/zap" + + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/internal/flushcommon/writebuffer" + "github.com/milvus-io/milvus/internal/storage" + "github.com/milvus-io/milvus/internal/util/function" + "github.com/milvus-io/milvus/pkg/log" + "github.com/milvus-io/milvus/pkg/util/paramtable" +) + +// TODO: support configuring the EmbeddingType +// type EmbeddingType int32 + +type embeddingNode struct { + BaseNode + + schema *schemapb.CollectionSchema + pkField *schemapb.FieldSchema // first schema field flagged as primary key; stays nil when the schema has none + channelName string + + // embeddingType EmbeddingType + functionRunners map[int64]function.FunctionRunner // keyed by the function ID (tf.GetId()) of each schema function +} + +func newEmbeddingNode(channelName string, schema *schemapb.CollectionSchema) (*embeddingNode, error) { + baseNode := BaseNode{} + baseNode.SetMaxQueueLength(paramtable.Get().DataNodeCfg.FlowGraphMaxQueueLength.GetAsInt32()) + baseNode.SetMaxParallelism(paramtable.Get().DataNodeCfg.FlowGraphMaxParallelism.GetAsInt32()) + + node := &embeddingNode{ + BaseNode: baseNode, + channelName: channelName, + schema: schema, + functionRunners: make(map[int64]function.FunctionRunner), + } + + for _, field := range schema.GetFields() { + if field.GetIsPrimaryKey() { + node.pkField = field + break + } + } + + for _, tf := range schema.GetFunctions() { + functionRunner, err := function.NewFunctionRunner(schema, tf) + if err != nil { + return nil, err + } + node.functionRunners[tf.GetId()] = functionRunner + } + return node, nil +} + +func (eNode *embeddingNode) Name() string { + return fmt.Sprintf("embeddingNode-%s-%s", "BM25test", eNode.channelName) // NOTE(review): "BM25test" looks like a leftover debug tag in the node name — confirm before release +} + +func (eNode *embeddingNode) bm25Embedding(runner function.FunctionRunner, inputFieldId, outputFieldId int64, data *storage.InsertData, meta map[int64]*storage.BM25Stats) error { + if _, ok := meta[outputFieldId]; !ok { + meta[outputFieldId] = 
storage.NewBM25Stats() + } + + embeddingData, ok := data.Data[inputFieldId].GetDataRows().([]string) + if !ok { + return fmt.Errorf("BM25 embedding failed: input field data not varchar") + } + + output, err := runner.BatchRun(embeddingData) + if err != nil { + return err + } + + sparseArray, ok := output[0].(*schemapb.SparseFloatArray) + if !ok { + return fmt.Errorf("BM25 embedding failed: BM25 runner output not sparse map") + } + + meta[outputFieldId].AppendBytes(sparseArray.GetContents()...) + data.Data[outputFieldId] = BuildSparseFieldData(sparseArray) + return nil +} + +func (eNode *embeddingNode) embedding(datas []*storage.InsertData) (map[int64]*storage.BM25Stats, error) { + meta := make(map[int64]*storage.BM25Stats) + for _, data := range datas { + for _, functionRunner := range eNode.functionRunners { + functionSchema := functionRunner.GetSchema() + switch functionSchema.GetType() { + case schemapb.FunctionType_BM25: + err := eNode.bm25Embedding(functionRunner, functionSchema.GetInputFieldIds()[0], functionSchema.GetOutputFieldIds()[0], data, meta) + if err != nil { + return nil, err + } + default: + return nil, fmt.Errorf("unknown function type %s", functionSchema.Type) + } + } + } + return meta, nil +} + +func (eNode *embeddingNode) Embedding(datas []*writebuffer.InsertData) error { + for _, data := range datas { + stats, err := eNode.embedding(data.GetDatas()) + if err != nil { + return err + } + data.SetBM25Stats(stats) + } + return nil +} + +func (eNode *embeddingNode) Operate(in []Msg) []Msg { + fgMsg := in[0].(*FlowGraphMsg) + + if fgMsg.IsCloseMsg() { + return []Msg{fgMsg} + } + + insertData, err := writebuffer.PrepareInsert(eNode.schema, eNode.pkField, fgMsg.InsertMessages) + if err != nil { + log.Error("failed to prepare insert data", zap.Error(err)) + panic(err) + } + + err = eNode.Embedding(insertData) + if err != nil { + log.Warn("failed to embedding insert data", zap.Error(err)) + panic(err) + } + + fgMsg.InsertData = insertData + return 
[]Msg{fgMsg} +} + +func BuildSparseFieldData(array *schemapb.SparseFloatArray) storage.FieldData { + return &storage.SparseFloatVectorFieldData{ + SparseFloatArray: schemapb.SparseFloatArray{ + Contents: array.GetContents(), + Dim: array.GetDim(), + }, + } +} diff --git a/internal/flushcommon/pipeline/flow_graph_embedding_node_test.go b/internal/flushcommon/pipeline/flow_graph_embedding_node_test.go new file mode 100644 index 0000000000000..08d8b743becbf --- /dev/null +++ b/internal/flushcommon/pipeline/flow_graph_embedding_node_test.go @@ -0,0 +1,175 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package pipeline + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/milvus-io/milvus-proto/go-api/v2/msgpb" + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/internal/util/flowgraph" + "github.com/milvus-io/milvus/pkg/common" + "github.com/milvus-io/milvus/pkg/mq/msgstream" +) + +func TestEmbeddingNode_BM25_Operator(t *testing.T) { + collSchema := &schemapb.CollectionSchema{ + Fields: []*schemapb.FieldSchema{ + { + FieldID: common.TimeStampField, + Name: common.TimeStampFieldName, + DataType: schemapb.DataType_Int64, + }, { + Name: "pk", + FieldID: 100, + IsPrimaryKey: true, + DataType: schemapb.DataType_Int64, + }, { + Name: "text", + FieldID: 101, + DataType: schemapb.DataType_VarChar, + }, { + Name: "sparse", + FieldID: 102, + DataType: schemapb.DataType_SparseFloatVector, + IsFunctionOutput: true, + }, + }, + Functions: []*schemapb.FunctionSchema{{ + Name: "BM25", + Type: schemapb.FunctionType_BM25, + InputFieldIds: []int64{101}, + OutputFieldIds: []int64{102}, + }}, + } + + t.Run("normal case", func(t *testing.T) { + node, err := newEmbeddingNode("test-channel", collSchema) + assert.NoError(t, err) + + var output []Msg + assert.NotPanics(t, func() { + output = node.Operate([]Msg{ + &FlowGraphMsg{ + BaseMsg: flowgraph.NewBaseMsg(false), + InsertMessages: []*msgstream.InsertMsg{{ + BaseMsg: msgstream.BaseMsg{}, + InsertRequest: &msgpb.InsertRequest{ + SegmentID: 1, + Version: msgpb.InsertDataVersion_ColumnBased, + Timestamps: []uint64{1, 1, 1}, + FieldsData: []*schemapb.FieldData{ + { + FieldId: 100, + Field: &schemapb.FieldData_Scalars{ + Scalars: &schemapb.ScalarField{Data: &schemapb.ScalarField_LongData{LongData: &schemapb.LongArray{Data: []int64{1, 2, 3}}}}, + }, + }, { + FieldId: 101, + Field: &schemapb.FieldData_Scalars{ + Scalars: &schemapb.ScalarField{Data: &schemapb.ScalarField_StringData{StringData: &schemapb.StringArray{Data: []string{"test1", "test2", "test3"}}}}, + }, + }, + }, + 
}, + }}, + }, + }) + }) + + assert.Equal(t, 1, len(output)) + + msg, ok := output[0].(*FlowGraphMsg) + assert.True(t, ok) + assert.NotNil(t, msg.InsertData) + }) + + t.Run("with close msg", func(t *testing.T) { + node, err := newEmbeddingNode("test-channel", collSchema) + assert.NoError(t, err) + + var output []Msg + + assert.NotPanics(t, func() { + output = node.Operate([]Msg{ + &FlowGraphMsg{ + BaseMsg: flowgraph.NewBaseMsg(true), + }, + }) + }) + + assert.Equal(t, 1, len(output)) + }) + + t.Run("prepare insert failed", func(t *testing.T) { + node, err := newEmbeddingNode("test-channel", collSchema) + assert.NoError(t, err) + + assert.Panics(t, func() { + node.Operate([]Msg{ + &FlowGraphMsg{ + BaseMsg: flowgraph.NewBaseMsg(false), + InsertMessages: []*msgstream.InsertMsg{{ + BaseMsg: msgstream.BaseMsg{}, + InsertRequest: &msgpb.InsertRequest{ + FieldsData: []*schemapb.FieldData{{ + FieldId: 1100, // invalid fieldID + }}, + }, + }}, + }, + }) + }) + }) + + t.Run("embedding failed", func(t *testing.T) { + node, err := newEmbeddingNode("test-channel", collSchema) + assert.NoError(t, err) + + node.functionRunners[0].GetSchema().Type = 0 + assert.Panics(t, func() { + node.Operate([]Msg{ + &FlowGraphMsg{ + BaseMsg: flowgraph.NewBaseMsg(false), + InsertMessages: []*msgstream.InsertMsg{{ + BaseMsg: msgstream.BaseMsg{}, + InsertRequest: &msgpb.InsertRequest{ + SegmentID: 1, + Version: msgpb.InsertDataVersion_ColumnBased, + Timestamps: []uint64{1, 1, 1}, + FieldsData: []*schemapb.FieldData{ + { + FieldId: 100, + Field: &schemapb.FieldData_Scalars{ + Scalars: &schemapb.ScalarField{Data: &schemapb.ScalarField_LongData{LongData: &schemapb.LongArray{Data: []int64{1, 2, 3}}}}, + }, + }, { + FieldId: 101, + Field: &schemapb.FieldData_Scalars{ + Scalars: &schemapb.ScalarField{Data: &schemapb.ScalarField_StringData{StringData: &schemapb.StringArray{Data: []string{"test1", "test2", "test3"}}}}, + }, + }, + }, + }, + }}, + }, + }) + }) + }) +} diff --git 
a/internal/flushcommon/pipeline/flow_graph_message.go b/internal/flushcommon/pipeline/flow_graph_message.go index 1222cfd2705c2..18877e5b124f5 100644 --- a/internal/flushcommon/pipeline/flow_graph_message.go +++ b/internal/flushcommon/pipeline/flow_graph_message.go @@ -19,6 +19,7 @@ package pipeline import ( "github.com/milvus-io/milvus-proto/go-api/v2/msgpb" "github.com/milvus-io/milvus/internal/flushcommon/util" + "github.com/milvus-io/milvus/internal/flushcommon/writebuffer" "github.com/milvus-io/milvus/internal/storage" "github.com/milvus-io/milvus/internal/util/flowgraph" "github.com/milvus-io/milvus/pkg/mq/msgstream" @@ -47,7 +48,10 @@ type ( type FlowGraphMsg struct { BaseMsg InsertMessages []*msgstream.InsertMsg + InsertData []*writebuffer.InsertData + DeleteMessages []*msgstream.DeleteMsg + TimeRange util.TimeRange StartPositions []*msgpb.MsgPosition EndPositions []*msgpb.MsgPosition diff --git a/internal/flushcommon/pipeline/flow_graph_time_tick_node.go b/internal/flushcommon/pipeline/flow_graph_time_tick_node.go index 1fcfeb242f690..985dd39c2a06b 100644 --- a/internal/flushcommon/pipeline/flow_graph_time_tick_node.go +++ b/internal/flushcommon/pipeline/flow_graph_time_tick_node.go @@ -111,7 +111,6 @@ func (ttn *ttNode) Operate(in []Msg) []Msg { if needUpdate { ttn.updateChannelCP(channelPos, curTs, true) } - return []Msg{} } diff --git a/internal/flushcommon/pipeline/flow_graph_write_node.go b/internal/flushcommon/pipeline/flow_graph_write_node.go index a0f6048f08edb..46e1a5a4aad74 100644 --- a/internal/flushcommon/pipeline/flow_graph_write_node.go +++ b/internal/flushcommon/pipeline/flow_graph_write_node.go @@ -11,13 +11,14 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/msgpb" + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus/internal/flushcommon/metacache" "github.com/milvus-io/milvus/internal/flushcommon/util" 
"github.com/milvus-io/milvus/internal/flushcommon/writebuffer" "github.com/milvus-io/milvus/internal/util/streamingutil" "github.com/milvus-io/milvus/pkg/log" - "github.com/milvus-io/milvus/pkg/mq/msgstream" "github.com/milvus-io/milvus/pkg/util/paramtable" + "github.com/milvus-io/milvus/pkg/util/typeutil" ) type writeNode struct { @@ -27,6 +28,8 @@ type writeNode struct { wbManager writebuffer.BufferManager updater util.StatsUpdater metacache metacache.MetaCache + collSchema *schemapb.CollectionSchema + pkField *schemapb.FieldSchema } // Name returns node name, implementing flowgraph.Node @@ -79,14 +82,23 @@ func (wNode *writeNode) Operate(in []Msg) []Msg { start, end := fgMsg.StartPositions[0], fgMsg.EndPositions[0] - err := wNode.wbManager.BufferData(wNode.channelName, fgMsg.InsertMessages, fgMsg.DeleteMessages, start, end) + if fgMsg.InsertData == nil { + insertData, err := writebuffer.PrepareInsert(wNode.collSchema, wNode.pkField, fgMsg.InsertMessages) + if err != nil { + log.Error("failed to prepare data", zap.Error(err)) + panic(err) + } + fgMsg.InsertData = insertData + } + + err := wNode.wbManager.BufferData(wNode.channelName, fgMsg.InsertData, fgMsg.DeleteMessages, start, end) if err != nil { log.Error("failed to buffer data", zap.Error(err)) panic(err) } stats := lo.FilterMap( - lo.Keys(lo.SliceToMap(fgMsg.InsertMessages, func(msg *msgstream.InsertMsg) (int64, struct{}) { return msg.GetSegmentID(), struct{}{} })), + lo.Keys(lo.SliceToMap(fgMsg.InsertData, func(data *writebuffer.InsertData) (int64, struct{}) { return data.GetSegmentID(), struct{}{} })), func(id int64, _ int) (*commonpb.SegmentStats, bool) { segInfo, ok := wNode.metacache.GetSegmentByID(id) if !ok { @@ -127,16 +139,24 @@ func newWriteNode( writeBufferManager writebuffer.BufferManager, updater util.StatsUpdater, config *nodeConfig, -) *writeNode { +) (*writeNode, error) { baseNode := BaseNode{} baseNode.SetMaxQueueLength(paramtable.Get().DataNodeCfg.FlowGraphMaxQueueLength.GetAsInt32()) 
baseNode.SetMaxParallelism(paramtable.Get().DataNodeCfg.FlowGraphMaxParallelism.GetAsInt32()) + collSchema := config.metacache.Schema() + pkField, err := typeutil.GetPrimaryFieldSchema(collSchema) + if err != nil { + return nil, err + } + return &writeNode{ BaseNode: baseNode, channelName: config.vChannelName, wbManager: writeBufferManager, updater: updater, metacache: config.metacache, - } + collSchema: collSchema, + pkField: pkField, + }, nil } diff --git a/internal/flushcommon/syncmgr/meta_writer.go b/internal/flushcommon/syncmgr/meta_writer.go index f558a7671719f..7ffb04de332b7 100644 --- a/internal/flushcommon/syncmgr/meta_writer.go +++ b/internal/flushcommon/syncmgr/meta_writer.go @@ -39,8 +39,9 @@ func BrokerMetaWriter(broker broker.Broker, serverID int64, opts ...retry.Option func (b *brokerMetaWriter) UpdateSync(ctx context.Context, pack *SyncTask) error { var ( - checkPoints = []*datapb.CheckPoint{} - deltaFieldBinlogs = []*datapb.FieldBinlog{} + checkPoints = []*datapb.CheckPoint{} + deltaFieldBinlogs = []*datapb.FieldBinlog{} + deltaBm25StatsBinlogs []*datapb.FieldBinlog = nil ) insertFieldBinlogs := lo.MapToSlice(pack.insertBinlogs, func(_ int64, fieldBinlog *datapb.FieldBinlog) *datapb.FieldBinlog { return fieldBinlog }) @@ -49,6 +50,9 @@ func (b *brokerMetaWriter) UpdateSync(ctx context.Context, pack *SyncTask) error deltaFieldBinlogs = append(deltaFieldBinlogs, pack.deltaBinlog) } + if len(pack.bm25Binlogs) > 0 { + deltaBm25StatsBinlogs = lo.MapToSlice(pack.bm25Binlogs, func(_ int64, fieldBinlog *datapb.FieldBinlog) *datapb.FieldBinlog { return fieldBinlog }) + } // only current segment checkpoint info segment, ok := pack.metacache.GetSegmentByID(pack.segmentID) if !ok { @@ -77,6 +81,7 @@ func (b *brokerMetaWriter) UpdateSync(ctx context.Context, pack *SyncTask) error zap.Int("binlogNum", lo.SumBy(insertFieldBinlogs, getBinlogNum)), zap.Int("statslogNum", lo.SumBy(statsFieldBinlogs, getBinlogNum)), zap.Int("deltalogNum", lo.SumBy(deltaFieldBinlogs, 
getBinlogNum)), + zap.Int("bm25logNum", lo.SumBy(deltaBm25StatsBinlogs, getBinlogNum)), zap.String("vChannelName", pack.channelName), ) @@ -91,6 +96,7 @@ func (b *brokerMetaWriter) UpdateSync(ctx context.Context, pack *SyncTask) error PartitionID: pack.partitionID, Field2BinlogPaths: insertFieldBinlogs, Field2StatslogPaths: statsFieldBinlogs, + Field2Bm25LogPaths: deltaBm25StatsBinlogs, Deltalogs: deltaFieldBinlogs, CheckPoints: checkPoints, diff --git a/internal/flushcommon/syncmgr/meta_writer_test.go b/internal/flushcommon/syncmgr/meta_writer_test.go index 5db8e1d6a2d6a..32890339d2637 100644 --- a/internal/flushcommon/syncmgr/meta_writer_test.go +++ b/internal/flushcommon/syncmgr/meta_writer_test.go @@ -41,7 +41,7 @@ func (s *MetaWriterSuite) TestNormalSave() { s.broker.EXPECT().SaveBinlogPaths(mock.Anything, mock.Anything).Return(nil) bfs := pkoracle.NewBloomFilterSet() - seg := metacache.NewSegmentInfo(&datapb.SegmentInfo{}, bfs) + seg := metacache.NewSegmentInfo(&datapb.SegmentInfo{}, bfs, nil) metacache.UpdateNumOfRows(1000)(seg) s.metacache.EXPECT().GetSegmentsBy(mock.Anything, mock.Anything).Return([]*metacache.SegmentInfo{seg}) s.metacache.EXPECT().GetSegmentByID(mock.Anything).Return(seg, true) @@ -58,7 +58,7 @@ func (s *MetaWriterSuite) TestReturnError() { s.broker.EXPECT().SaveBinlogPaths(mock.Anything, mock.Anything).Return(errors.New("mocked")) bfs := pkoracle.NewBloomFilterSet() - seg := metacache.NewSegmentInfo(&datapb.SegmentInfo{}, bfs) + seg := metacache.NewSegmentInfo(&datapb.SegmentInfo{}, bfs, nil) metacache.UpdateNumOfRows(1000)(seg) s.metacache.EXPECT().GetSegmentByID(mock.Anything).Return(seg, true) s.metacache.EXPECT().GetSegmentsBy(mock.Anything, mock.Anything).Return([]*metacache.SegmentInfo{seg}) diff --git a/internal/flushcommon/syncmgr/options.go b/internal/flushcommon/syncmgr/options.go index efed7fbbab70a..ec1eff91c9a64 100644 --- a/internal/flushcommon/syncmgr/options.go +++ b/internal/flushcommon/syncmgr/options.go @@ -19,6 +19,7 
@@ func NewSyncTask() *SyncTask { insertBinlogs: make(map[int64]*datapb.FieldBinlog), statsBinlogs: make(map[int64]*datapb.FieldBinlog), deltaBinlog: &datapb.FieldBinlog{}, + bm25Binlogs: make(map[int64]*datapb.FieldBinlog), segmentData: make(map[string][]byte), binlogBlobs: make(map[int64]*storage.Blob), } diff --git a/internal/flushcommon/syncmgr/serializer.go b/internal/flushcommon/syncmgr/serializer.go index 621ced1bb33c5..7d8d64ad5f46f 100644 --- a/internal/flushcommon/syncmgr/serializer.go +++ b/internal/flushcommon/syncmgr/serializer.go @@ -41,6 +41,8 @@ type SyncPack struct { // data insertData []*storage.InsertData deltaData *storage.DeleteData + bm25Stats map[int64]*storage.BM25Stats + // statistics tsFrom typeutil.Timestamp tsTo typeutil.Timestamp @@ -71,6 +73,11 @@ func (p *SyncPack) WithDeleteData(deltaData *storage.DeleteData) *SyncPack { return p } +func (p *SyncPack) WithBM25Stats(stats map[int64]*storage.BM25Stats) *SyncPack { + p.bm25Stats = stats + return p +} + func (p *SyncPack) WithStartPosition(start *msgpb.MsgPosition) *SyncPack { p.startPosition = start return p diff --git a/internal/flushcommon/syncmgr/storage_serializer.go b/internal/flushcommon/syncmgr/storage_serializer.go index 1a3623af435eb..e22cdc206466f 100644 --- a/internal/flushcommon/syncmgr/storage_serializer.go +++ b/internal/flushcommon/syncmgr/storage_serializer.go @@ -99,6 +99,7 @@ func (s *storageV1Serializer) EncodeBuffer(ctx context.Context, pack *SyncPack) } task.binlogBlobs = binlogBlobs + actions := []metacache.SegmentAction{} singlePKStats, batchStatsBlob, err := s.serializeStatslog(pack) if err != nil { log.Warn("failed to serialized statslog", zap.Error(err)) @@ -106,7 +107,19 @@ func (s *storageV1Serializer) EncodeBuffer(ctx context.Context, pack *SyncPack) } task.batchStatsBlob = batchStatsBlob - s.metacache.UpdateSegments(metacache.RollStats(singlePKStats), metacache.WithSegmentIDs(pack.segmentID)) + actions = append(actions, metacache.RollStats(singlePKStats)) + 
+ if len(pack.bm25Stats) > 0 { + statsBlobs, err := s.serializeBM25Stats(pack) + if err != nil { + return nil, err + } + + task.bm25Blobs = statsBlobs + actions = append(actions, metacache.MergeBm25Stats(pack.bm25Stats)) + } + + s.metacache.UpdateSegments(metacache.MergeSegmentAction(actions...), metacache.WithSegmentIDs(pack.segmentID)) } if pack.isFlush { @@ -117,6 +130,15 @@ func (s *storageV1Serializer) EncodeBuffer(ctx context.Context, pack *SyncPack) return nil, err } task.mergedStatsBlob = mergedStatsBlob + + if len(pack.bm25Stats) > 0 { + mergedBM25Blob, err := s.serializeMergedBM25Stats(pack) + if err != nil { + log.Warn("failed to serialize merged bm25 stats log", zap.Error(err)) + return nil, err + } + task.mergedBm25Blob = mergedBM25Blob + } } task.WithFlush() @@ -178,6 +200,23 @@ func (s *storageV1Serializer) serializeBinlog(ctx context.Context, pack *SyncPac return result, nil } +func (s *storageV1Serializer) serializeBM25Stats(pack *SyncPack) (map[int64]*storage.Blob, error) { + blobs := make(map[int64]*storage.Blob) + for fieldID, stats := range pack.bm25Stats { + bytes, err := stats.Serialize() + if err != nil { + return nil, err + } + + blobs[fieldID] = &storage.Blob{ + Value: bytes, + MemorySize: int64(len(bytes)), + RowNum: stats.NumRow(), + } + } + return blobs, nil +} + func (s *storageV1Serializer) serializeStatslog(pack *SyncPack) (*storage.PrimaryKeyStats, *storage.Blob, error) { var rowNum int64 var pkFieldData []storage.FieldData @@ -220,6 +259,33 @@ func (s *storageV1Serializer) serializeMergedPkStats(pack *SyncPack) (*storage.B }), segment.NumOfRows()) } +func (s *storageV1Serializer) serializeMergedBM25Stats(pack *SyncPack) (map[int64]*storage.Blob, error) { + segment, ok := s.metacache.GetSegmentByID(pack.segmentID) + if !ok { + return nil, merr.WrapErrSegmentNotFound(pack.segmentID) + } + + stats := segment.GetBM25Stats() + if stats == nil { + return nil, fmt.Errorf("searalize empty bm25 stats") + } + + fieldBytes, numRow, err := 
stats.Serialize() + if err != nil { + return nil, err + } + + blobs := make(map[int64]*storage.Blob) + for fieldID, bytes := range fieldBytes { + blobs[fieldID] = &storage.Blob{ + Value: bytes, + MemorySize: int64(len(bytes)), + RowNum: numRow[fieldID], + } + } + return blobs, nil +} + func (s *storageV1Serializer) serializeDeltalog(pack *SyncPack) (*storage.Blob, error) { if len(pack.deltaData.Pks) == 0 { return &storage.Blob{}, nil diff --git a/internal/flushcommon/syncmgr/storage_serializer_test.go b/internal/flushcommon/syncmgr/storage_serializer_test.go index 3f270af792285..2be0b0788eb17 100644 --- a/internal/flushcommon/syncmgr/storage_serializer_test.go +++ b/internal/flushcommon/syncmgr/storage_serializer_test.go @@ -245,7 +245,7 @@ func (s *StorageV1SerializerSuite) TestSerializeInsert() { pack.WithFlush() bfs := s.getBfs() - segInfo := metacache.NewSegmentInfo(&datapb.SegmentInfo{}, bfs) + segInfo := metacache.NewSegmentInfo(&datapb.SegmentInfo{}, bfs, nil) metacache.UpdateNumOfRows(1000)(segInfo) s.mockCache.EXPECT().UpdateSegments(mock.Anything, mock.Anything).Run(func(action metacache.SegmentAction, filters ...metacache.SegmentFilter) { action(segInfo) diff --git a/internal/flushcommon/syncmgr/sync_manager_test.go b/internal/flushcommon/syncmgr/sync_manager_test.go index a47efd4630aa3..f05ccf554e2eb 100644 --- a/internal/flushcommon/syncmgr/sync_manager_test.go +++ b/internal/flushcommon/syncmgr/sync_manager_test.go @@ -155,7 +155,7 @@ func (s *SyncManagerSuite) getSuiteSyncTask() *SyncTask { func (s *SyncManagerSuite) TestSubmit() { s.broker.EXPECT().SaveBinlogPaths(mock.Anything, mock.Anything).Return(nil) bfs := pkoracle.NewBloomFilterSet() - seg := metacache.NewSegmentInfo(&datapb.SegmentInfo{}, bfs) + seg := metacache.NewSegmentInfo(&datapb.SegmentInfo{}, bfs, nil) metacache.UpdateNumOfRows(1000)(seg) s.metacache.EXPECT().GetSegmentByID(s.segmentID).Return(seg, true) s.metacache.EXPECT().GetSegmentsBy(mock.Anything, 
mock.Anything).Return([]*metacache.SegmentInfo{seg}) @@ -184,7 +184,7 @@ func (s *SyncManagerSuite) TestCompacted() { segmentID.Store(req.GetSegmentID()) }).Return(nil) bfs := pkoracle.NewBloomFilterSet() - seg := metacache.NewSegmentInfo(&datapb.SegmentInfo{}, bfs) + seg := metacache.NewSegmentInfo(&datapb.SegmentInfo{}, bfs, nil) metacache.UpdateNumOfRows(1000)(seg) s.metacache.EXPECT().GetSegmentByID(s.segmentID).Return(seg, true) s.metacache.EXPECT().GetSegmentsBy(mock.Anything, mock.Anything).Return([]*metacache.SegmentInfo{seg}) diff --git a/internal/flushcommon/syncmgr/task.go b/internal/flushcommon/syncmgr/task.go index e8dd6054ef5af..9e910b9cbcddb 100644 --- a/internal/flushcommon/syncmgr/task.go +++ b/internal/flushcommon/syncmgr/task.go @@ -71,14 +71,20 @@ type SyncTask struct { insertBinlogs map[int64]*datapb.FieldBinlog // map[int64]*datapb.Binlog statsBinlogs map[int64]*datapb.FieldBinlog // map[int64]*datapb.Binlog + bm25Binlogs map[int64]*datapb.FieldBinlog deltaBinlog *datapb.FieldBinlog - binlogBlobs map[int64]*storage.Blob // fieldID => blob - binlogMemsize map[int64]int64 // memory size + binlogBlobs map[int64]*storage.Blob // fieldID => blob + binlogMemsize map[int64]int64 // memory size + + bm25Blobs map[int64]*storage.Blob + mergedBm25Blob map[int64]*storage.Blob + batchStatsBlob *storage.Blob mergedStatsBlob *storage.Blob - deltaBlob *storage.Blob - deltaRowCount int64 + + deltaBlob *storage.Blob + deltaRowCount int64 // prefetched log ids ids []int64 @@ -145,6 +151,10 @@ func (t *SyncTask) Run(ctx context.Context) (err error) { t.processStatsBlob() t.processDeltaBlob() + if len(t.bm25Blobs) > 0 || len(t.mergedBm25Blob) > 0 { // gate on the input blobs: bm25Binlogs is the output map that processBM25StastBlob itself fills, so it is still empty here + t.processBM25StastBlob() + } + err = t.writeLogs(ctx) if err != nil { log.Warn("failed to save serialized data into storage", zap.Error(err)) @@ -182,7 +192,7 @@ func (t *SyncTask) Run(ctx context.Context) (err error) { log.Info("segment removed", zap.Int64("segmentID", t.segment.SegmentID()), zap.String("channel", 
t.channelName)) } - log.Info("task done", zap.Float64("flushedSize", totalSize)) + log.Info("task done", zap.Float64("flushedSize", totalSize), zap.Duration("interval", t.tr.RecordSpan())) if !t.isFlush { metrics.DataNodeAutoFlushBufferCount.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.SuccessLabel, t.level.String()).Inc() @@ -207,6 +217,10 @@ func (t *SyncTask) prefetchIDs() error { if t.deltaBlob != nil { totalIDCount++ } + if t.bm25Blobs != nil { + totalIDCount += len(t.bm25Blobs) + } + start, _, err := t.allocator.Alloc(uint32(totalIDCount)) if err != nil { return err @@ -240,6 +254,36 @@ func (t *SyncTask) processInsertBlobs() { } } +func (t *SyncTask) processBM25StastBlob() { + for fieldID, blob := range t.bm25Blobs { + k := metautil.JoinIDPath(t.collectionID, t.partitionID, t.segmentID, fieldID, t.nextID()) + key := path.Join(t.chunkManager.RootPath(), common.SegmentBm25LogPath, k) + t.segmentData[key] = blob.GetValue() + t.appendBM25Statslog(fieldID, &datapb.Binlog{ + EntriesNum: blob.RowNum, + TimestampFrom: t.tsFrom, + TimestampTo: t.tsTo, + LogPath: key, + LogSize: int64(len(blob.GetValue())), + MemorySize: blob.MemorySize, + }) + } + + for fieldID, blob := range t.mergedBm25Blob { + k := metautil.JoinIDPath(t.collectionID, t.partitionID, t.segmentID, fieldID, int64(storage.CompoundStatsType)) + key := path.Join(t.chunkManager.RootPath(), common.SegmentBm25LogPath, k) + t.segmentData[key] = blob.GetValue() + t.appendBM25Statslog(fieldID, &datapb.Binlog{ + EntriesNum: blob.RowNum, + TimestampFrom: t.tsFrom, + TimestampTo: t.tsTo, + LogPath: key, + LogSize: int64(len(blob.GetValue())), + MemorySize: blob.MemorySize, + }) + } +} + func (t *SyncTask) processStatsBlob() { if t.batchStatsBlob != nil { t.convertBlob2StatsBinlog(t.batchStatsBlob, t.pkField.GetFieldID(), t.nextID(), t.batchSize) @@ -297,6 +341,17 @@ func (t *SyncTask) appendBinlog(fieldID int64, binlog *datapb.Binlog) { fieldBinlog.Binlogs = append(fieldBinlog.Binlogs, binlog) } 
+func (t *SyncTask) appendBM25Statslog(fieldID int64, log *datapb.Binlog) { + fieldBinlog, ok := t.bm25Binlogs[fieldID] + if !ok { + fieldBinlog = &datapb.FieldBinlog{ + FieldID: fieldID, + } + t.bm25Binlogs[fieldID] = fieldBinlog + } + fieldBinlog.Binlogs = append(fieldBinlog.Binlogs, log) +} + func (t *SyncTask) appendStatslog(fieldID int64, statlog *datapb.Binlog) { fieldBinlog, ok := t.statsBinlogs[fieldID] if !ok { diff --git a/internal/flushcommon/syncmgr/task_test.go b/internal/flushcommon/syncmgr/task_test.go index 4cf9275ea3014..b4fc64ae7e73a 100644 --- a/internal/flushcommon/syncmgr/task_test.go +++ b/internal/flushcommon/syncmgr/task_test.go @@ -185,7 +185,7 @@ func (s *SyncTaskSuite) TestRunNormal() { } bfs.UpdatePKRange(fd) - seg := metacache.NewSegmentInfo(&datapb.SegmentInfo{}, bfs) + seg := metacache.NewSegmentInfo(&datapb.SegmentInfo{}, bfs, nil) metacache.UpdateNumOfRows(1000)(seg) seg.GetBloomFilterSet().Roll() s.metacache.EXPECT().GetSegmentByID(s.segmentID).Return(seg, true) @@ -273,7 +273,7 @@ func (s *SyncTaskSuite) TestRunL0Segment() { defer cancel() s.broker.EXPECT().SaveBinlogPaths(mock.Anything, mock.Anything).Return(nil) bfs := pkoracle.NewBloomFilterSet() - seg := metacache.NewSegmentInfo(&datapb.SegmentInfo{Level: datapb.SegmentLevel_L0}, bfs) + seg := metacache.NewSegmentInfo(&datapb.SegmentInfo{Level: datapb.SegmentLevel_L0}, bfs, nil) s.metacache.EXPECT().GetSegmentByID(s.segmentID).Return(seg, true) s.metacache.EXPECT().GetSegmentsBy(mock.Anything, mock.Anything).Return([]*metacache.SegmentInfo{seg}) s.metacache.EXPECT().UpdateSegments(mock.Anything, mock.Anything).Return() @@ -314,7 +314,7 @@ func (s *SyncTaskSuite) TestRunError() { }) s.metacache.ExpectedCalls = nil - seg := metacache.NewSegmentInfo(&datapb.SegmentInfo{}, pkoracle.NewBloomFilterSet()) + seg := metacache.NewSegmentInfo(&datapb.SegmentInfo{}, pkoracle.NewBloomFilterSet(), nil) metacache.UpdateNumOfRows(1000)(seg) 
s.metacache.EXPECT().GetSegmentByID(s.segmentID).Return(seg, true) s.metacache.EXPECT().GetSegmentsBy(mock.Anything, mock.Anything).Return([]*metacache.SegmentInfo{seg}) diff --git a/internal/flushcommon/writebuffer/bf_write_buffer.go b/internal/flushcommon/writebuffer/bf_write_buffer.go index 3fcb5df30dc34..b4541d9902dd0 100644 --- a/internal/flushcommon/writebuffer/bf_write_buffer.go +++ b/internal/flushcommon/writebuffer/bf_write_buffer.go @@ -30,7 +30,7 @@ func NewBFWriteBuffer(channel string, metacache metacache.MetaCache, syncMgr syn }, nil } -func (wb *bfWriteBuffer) dispatchDeleteMsgs(groups []*inData, deleteMsgs []*msgstream.DeleteMsg, startPos, endPos *msgpb.MsgPosition) { +func (wb *bfWriteBuffer) dispatchDeleteMsgs(groups []*InsertData, deleteMsgs []*msgstream.DeleteMsg, startPos, endPos *msgpb.MsgPosition) { batchSize := paramtable.Get().CommonCfg.BloomFilterApplyBatchSize.GetAsInt() split := func(pks []storage.PrimaryKey, pkTss []uint64, segments []*metacache.SegmentInfo) { @@ -86,17 +86,12 @@ func (wb *bfWriteBuffer) dispatchDeleteMsgs(groups []*inData, deleteMsgs []*msgs } } -func (wb *bfWriteBuffer) BufferData(insertMsgs []*msgstream.InsertMsg, deleteMsgs []*msgstream.DeleteMsg, startPos, endPos *msgpb.MsgPosition) error { +func (wb *bfWriteBuffer) BufferData(insertData []*InsertData, deleteMsgs []*msgstream.DeleteMsg, startPos, endPos *msgpb.MsgPosition) error { wb.mut.Lock() defer wb.mut.Unlock() - groups, err := wb.prepareInsert(insertMsgs) - if err != nil { - return err - } - // buffer insert data and add segment if not exists - for _, inData := range groups { + for _, inData := range insertData { err := wb.bufferInsert(inData, startPos, endPos) if err != nil { return err @@ -105,10 +100,10 @@ func (wb *bfWriteBuffer) BufferData(insertMsgs []*msgstream.InsertMsg, deleteMsg // distribute delete msg // bf write buffer check bloom filter of segment and current insert batch to decide which segment to write delete data - 
wb.dispatchDeleteMsgs(groups, deleteMsgs, startPos, endPos) + wb.dispatchDeleteMsgs(insertData, deleteMsgs, startPos, endPos) // update pk oracle - for _, inData := range groups { + for _, inData := range insertData { // segment shall always exists after buffer insert segments := wb.metaCache.GetSegmentsBy( metacache.WithSegmentIDs(inData.segmentID)) diff --git a/internal/flushcommon/writebuffer/bf_write_buffer_test.go b/internal/flushcommon/writebuffer/bf_write_buffer_test.go index e50420170bb63..a9c13e5c3b568 100644 --- a/internal/flushcommon/writebuffer/bf_write_buffer_test.go +++ b/internal/flushcommon/writebuffer/bf_write_buffer_test.go @@ -29,14 +29,18 @@ import ( type BFWriteBufferSuite struct { testutils.PromMetricsSuite - collID int64 - channelName string - collInt64Schema *schemapb.CollectionSchema - collVarcharSchema *schemapb.CollectionSchema - syncMgr *syncmgr.MockSyncManager - metacacheInt64 *metacache.MockMetaCache - metacacheVarchar *metacache.MockMetaCache - broker *broker.MockBroker + collID int64 + channelName string + syncMgr *syncmgr.MockSyncManager + metacacheInt64 *metacache.MockMetaCache + metacacheVarchar *metacache.MockMetaCache + broker *broker.MockBroker + + collInt64Schema *schemapb.CollectionSchema + collInt64PkField *schemapb.FieldSchema + + collVarcharSchema *schemapb.CollectionSchema + collVarcharPkField *schemapb.FieldSchema } func (s *BFWriteBufferSuite) SetupSuite() { @@ -62,6 +66,11 @@ func (s *BFWriteBufferSuite) SetupSuite() { }, }, } + + s.collInt64PkField = &schemapb.FieldSchema{ + FieldID: 100, Name: "pk", DataType: schemapb.DataType_Int64, IsPrimaryKey: true, + } + s.collVarcharSchema = &schemapb.CollectionSchema{ Name: "test_collection", Fields: []*schemapb.FieldSchema{ @@ -84,6 +93,11 @@ func (s *BFWriteBufferSuite) SetupSuite() { }, }, } + s.collVarcharPkField = &schemapb.FieldSchema{ + FieldID: 100, Name: "pk", DataType: schemapb.DataType_VarChar, IsPrimaryKey: true, TypeParams: []*commonpb.KeyValuePair{ + {Key: 
common.MaxLengthKey, Value: "100"}, + }, + } } func (s *BFWriteBufferSuite) composeInsertMsg(segmentID int64, rowCount int, dim int, pkType schemapb.DataType) ([]int64, *msgstream.InsertMsg) { @@ -199,17 +213,20 @@ func (s *BFWriteBufferSuite) TestBufferData() { wb, err := NewBFWriteBuffer(s.channelName, s.metacacheInt64, s.syncMgr, &writeBufferOption{}) s.NoError(err) - seg := metacache.NewSegmentInfo(&datapb.SegmentInfo{ID: 1000}, pkoracle.NewBloomFilterSet()) + seg := metacache.NewSegmentInfo(&datapb.SegmentInfo{ID: 1000}, pkoracle.NewBloomFilterSet(), nil) s.metacacheInt64.EXPECT().GetSegmentsBy(mock.Anything, mock.Anything).Return([]*metacache.SegmentInfo{seg}) s.metacacheInt64.EXPECT().GetSegmentByID(int64(1000)).Return(nil, false) - s.metacacheInt64.EXPECT().AddSegment(mock.Anything, mock.Anything, mock.Anything).Return() + s.metacacheInt64.EXPECT().AddSegment(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return() s.metacacheInt64.EXPECT().UpdateSegments(mock.Anything, mock.Anything).Return() pks, msg := s.composeInsertMsg(1000, 10, 128, schemapb.DataType_Int64) delMsg := s.composeDeleteMsg(lo.Map(pks, func(id int64, _ int) storage.PrimaryKey { return storage.NewInt64PrimaryKey(id) })) metrics.DataNodeFlowGraphBufferDataSize.Reset() - err = wb.BufferData([]*msgstream.InsertMsg{msg}, []*msgstream.DeleteMsg{delMsg}, &msgpb.MsgPosition{Timestamp: 100}, &msgpb.MsgPosition{Timestamp: 200}) + insertData, err := PrepareInsert(s.collInt64Schema, s.collInt64PkField, []*msgstream.InsertMsg{msg}) + s.NoError(err) + + err = wb.BufferData(insertData, []*msgstream.DeleteMsg{delMsg}, &msgpb.MsgPosition{Timestamp: 100}, &msgpb.MsgPosition{Timestamp: 200}) s.NoError(err) value, err := metrics.DataNodeFlowGraphBufferDataSize.GetMetricWithLabelValues(fmt.Sprint(paramtable.GetNodeID()), fmt.Sprint(s.metacacheInt64.Collection())) @@ -217,7 +234,7 @@ func (s *BFWriteBufferSuite) TestBufferData() { s.MetricsEqual(value, 5607) delMsg = s.composeDeleteMsg(lo.Map(pks, 
func(id int64, _ int) storage.PrimaryKey { return storage.NewInt64PrimaryKey(id) })) - err = wb.BufferData([]*msgstream.InsertMsg{}, []*msgstream.DeleteMsg{delMsg}, &msgpb.MsgPosition{Timestamp: 100}, &msgpb.MsgPosition{Timestamp: 200}) + err = wb.BufferData([]*InsertData{}, []*msgstream.DeleteMsg{delMsg}, &msgpb.MsgPosition{Timestamp: 100}, &msgpb.MsgPosition{Timestamp: 200}) s.NoError(err) s.MetricsEqual(value, 5847) }) @@ -226,57 +243,38 @@ func (s *BFWriteBufferSuite) TestBufferData() { wb, err := NewBFWriteBuffer(s.channelName, s.metacacheVarchar, s.syncMgr, &writeBufferOption{}) s.NoError(err) - seg := metacache.NewSegmentInfo(&datapb.SegmentInfo{ID: 1000}, pkoracle.NewBloomFilterSet()) + seg := metacache.NewSegmentInfo(&datapb.SegmentInfo{ID: 1000}, pkoracle.NewBloomFilterSet(), nil) s.metacacheVarchar.EXPECT().GetSegmentsBy(mock.Anything, mock.Anything).Return([]*metacache.SegmentInfo{seg}) s.metacacheVarchar.EXPECT().GetSegmentByID(int64(1000)).Return(nil, false) - s.metacacheVarchar.EXPECT().AddSegment(mock.Anything, mock.Anything, mock.Anything).Return() + s.metacacheVarchar.EXPECT().AddSegment(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return() s.metacacheVarchar.EXPECT().UpdateSegments(mock.Anything, mock.Anything).Return() pks, msg := s.composeInsertMsg(1000, 10, 128, schemapb.DataType_VarChar) delMsg := s.composeDeleteMsg(lo.Map(pks, func(id int64, _ int) storage.PrimaryKey { return storage.NewVarCharPrimaryKey(fmt.Sprintf("%v", id)) })) metrics.DataNodeFlowGraphBufferDataSize.Reset() - err = wb.BufferData([]*msgstream.InsertMsg{msg}, []*msgstream.DeleteMsg{delMsg}, &msgpb.MsgPosition{Timestamp: 100}, &msgpb.MsgPosition{Timestamp: 200}) + insertData, err := PrepareInsert(s.collVarcharSchema, s.collVarcharPkField, []*msgstream.InsertMsg{msg}) + s.NoError(err) + + err = wb.BufferData(insertData, []*msgstream.DeleteMsg{delMsg}, &msgpb.MsgPosition{Timestamp: 100}, &msgpb.MsgPosition{Timestamp: 200}) s.NoError(err) value, err := 
metrics.DataNodeFlowGraphBufferDataSize.GetMetricWithLabelValues(fmt.Sprint(paramtable.GetNodeID()), fmt.Sprint(s.metacacheInt64.Collection())) s.NoError(err) s.MetricsEqual(value, 7227) }) +} +func (s *BFWriteBufferSuite) TestPrepareInsert() { s.Run("int_pk_type_not_match", func() { - wb, err := NewBFWriteBuffer(s.channelName, s.metacacheInt64, s.syncMgr, &writeBufferOption{}) - s.NoError(err) - - seg := metacache.NewSegmentInfo(&datapb.SegmentInfo{ID: 1000}, pkoracle.NewBloomFilterSet()) - s.metacacheInt64.EXPECT().GetSegmentsBy(mock.Anything, mock.Anything).Return([]*metacache.SegmentInfo{seg}) - s.metacacheInt64.EXPECT().GetSegmentByID(int64(1000)).Return(nil, false) - s.metacacheInt64.EXPECT().AddSegment(mock.Anything, mock.Anything, mock.Anything).Return() - s.metacacheInt64.EXPECT().UpdateSegments(mock.Anything, mock.Anything).Return() - - pks, msg := s.composeInsertMsg(1000, 10, 128, schemapb.DataType_VarChar) - delMsg := s.composeDeleteMsg(lo.Map(pks, func(id int64, _ int) storage.PrimaryKey { return storage.NewInt64PrimaryKey(id) })) - - metrics.DataNodeFlowGraphBufferDataSize.Reset() - err = wb.BufferData([]*msgstream.InsertMsg{msg}, []*msgstream.DeleteMsg{delMsg}, &msgpb.MsgPosition{Timestamp: 100}, &msgpb.MsgPosition{Timestamp: 200}) + _, msg := s.composeInsertMsg(1000, 10, 128, schemapb.DataType_VarChar) + _, err := PrepareInsert(s.collInt64Schema, s.collInt64PkField, []*msgstream.InsertMsg{msg}) s.Error(err) }) s.Run("varchar_pk_not_match", func() { - wb, err := NewBFWriteBuffer(s.channelName, s.metacacheVarchar, s.syncMgr, &writeBufferOption{}) - s.NoError(err) - - seg := metacache.NewSegmentInfo(&datapb.SegmentInfo{ID: 1000}, pkoracle.NewBloomFilterSet()) - s.metacacheVarchar.EXPECT().GetSegmentsBy(mock.Anything, mock.Anything).Return([]*metacache.SegmentInfo{seg}) - s.metacacheVarchar.EXPECT().GetSegmentByID(int64(1000)).Return(nil, false) - s.metacacheVarchar.EXPECT().AddSegment(mock.Anything, mock.Anything, mock.Anything).Return() - 
s.metacacheVarchar.EXPECT().UpdateSegments(mock.Anything, mock.Anything).Return() - - pks, msg := s.composeInsertMsg(1000, 10, 128, schemapb.DataType_Int64) - delMsg := s.composeDeleteMsg(lo.Map(pks, func(id int64, _ int) storage.PrimaryKey { return storage.NewInt64PrimaryKey(id) })) - - metrics.DataNodeFlowGraphBufferDataSize.Reset() - err = wb.BufferData([]*msgstream.InsertMsg{msg}, []*msgstream.DeleteMsg{delMsg}, &msgpb.MsgPosition{Timestamp: 100}, &msgpb.MsgPosition{Timestamp: 200}) + _, msg := s.composeInsertMsg(1000, 10, 128, schemapb.DataType_Int64) + _, err := PrepareInsert(s.collVarcharSchema, s.collVarcharPkField, []*msgstream.InsertMsg{msg}) s.Error(err) }) } @@ -294,15 +292,15 @@ func (s *BFWriteBufferSuite) TestAutoSync() { }) s.NoError(err) - seg := metacache.NewSegmentInfo(&datapb.SegmentInfo{ID: 1000}, pkoracle.NewBloomFilterSet()) - seg1 := metacache.NewSegmentInfo(&datapb.SegmentInfo{ID: 1002}, pkoracle.NewBloomFilterSet()) + seg := metacache.NewSegmentInfo(&datapb.SegmentInfo{ID: 1000}, pkoracle.NewBloomFilterSet(), nil) + seg1 := metacache.NewSegmentInfo(&datapb.SegmentInfo{ID: 1002}, pkoracle.NewBloomFilterSet(), nil) s.metacacheInt64.EXPECT().GetSegmentsBy(mock.Anything, mock.Anything).Return([]*metacache.SegmentInfo{seg}) s.metacacheInt64.EXPECT().GetSegmentByID(int64(1000)).Return(nil, false).Once() s.metacacheInt64.EXPECT().GetSegmentByID(int64(1000)).Return(seg, true).Once() s.metacacheInt64.EXPECT().GetSegmentByID(int64(1002)).Return(seg1, true) s.metacacheInt64.EXPECT().GetSegmentIDsBy(mock.Anything).Return([]int64{1002}) s.metacacheInt64.EXPECT().GetSegmentIDsBy(mock.Anything, mock.Anything, mock.Anything).Return([]int64{}) - s.metacacheInt64.EXPECT().AddSegment(mock.Anything, mock.Anything, mock.Anything).Return() + s.metacacheInt64.EXPECT().AddSegment(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return() s.metacacheInt64.EXPECT().UpdateSegments(mock.Anything, mock.Anything).Return() 
s.metacacheInt64.EXPECT().UpdateSegments(mock.Anything, mock.Anything, mock.Anything).Return() s.syncMgr.EXPECT().SyncData(mock.Anything, mock.Anything, mock.Anything).Return(nil) @@ -310,8 +308,11 @@ func (s *BFWriteBufferSuite) TestAutoSync() { pks, msg := s.composeInsertMsg(1000, 10, 128, schemapb.DataType_Int64) delMsg := s.composeDeleteMsg(lo.Map(pks, func(id int64, _ int) storage.PrimaryKey { return storage.NewInt64PrimaryKey(id) })) + insertData, err := PrepareInsert(s.collInt64Schema, s.collInt64PkField, []*msgstream.InsertMsg{msg}) + s.NoError(err) + metrics.DataNodeFlowGraphBufferDataSize.Reset() - err = wb.BufferData([]*msgstream.InsertMsg{msg}, []*msgstream.DeleteMsg{delMsg}, &msgpb.MsgPosition{Timestamp: 100}, &msgpb.MsgPosition{Timestamp: 200}) + err = wb.BufferData(insertData, []*msgstream.DeleteMsg{delMsg}, &msgpb.MsgPosition{Timestamp: 100}, &msgpb.MsgPosition{Timestamp: 200}) s.NoError(err) value, err := metrics.DataNodeFlowGraphBufferDataSize.GetMetricWithLabelValues(fmt.Sprint(paramtable.GetNodeID()), fmt.Sprint(s.metacacheInt64.Collection())) diff --git a/internal/flushcommon/writebuffer/insert_buffer.go b/internal/flushcommon/writebuffer/insert_buffer.go index b7f496e83ada1..6023311dfb559 100644 --- a/internal/flushcommon/writebuffer/insert_buffer.go +++ b/internal/flushcommon/writebuffer/insert_buffer.go @@ -74,7 +74,8 @@ type InsertBuffer struct { BufferBase collSchema *schemapb.CollectionSchema - buffers []*storage.InsertData + buffers []*storage.InsertData + statsBuffer *statsBuffer } func NewInsertBuffer(sch *schemapb.CollectionSchema) (*InsertBuffer, error) { @@ -100,6 +101,9 @@ func NewInsertBuffer(sch *schemapb.CollectionSchema) (*InsertBuffer, error) { collSchema: sch, } + if len(sch.GetFunctions()) > 0 { + ib.statsBuffer = newStatsBuffer() + } return ib, nil } @@ -116,17 +120,28 @@ func (ib *InsertBuffer) Yield() []*storage.InsertData { return result } -func (ib *InsertBuffer) Buffer(inData *inData, startPos, endPos 
*msgpb.MsgPosition) int64 { +func (ib *InsertBuffer) YieldStats() map[int64]*storage.BM25Stats { + if ib.statsBuffer == nil { + return nil + } + return ib.statsBuffer.yieldBuffer() +} + +func (ib *InsertBuffer) Buffer(inData *InsertData, startPos, endPos *msgpb.MsgPosition) int64 { bufferedSize := int64(0) for idx, data := range inData.data { tsData := inData.tsField[idx] + tr := ib.getTimestampRange(tsData) ib.buffer(data, tr, startPos, endPos) - // update buffer size ib.UpdateStatistics(int64(data.GetRowNum()), int64(data.GetMemorySize()), tr, startPos, endPos) bufferedSize += int64(data.GetMemorySize()) } + if inData.bm25Stats != nil { + ib.statsBuffer.Buffer(inData.bm25Stats) + } + return bufferedSize } diff --git a/internal/flushcommon/writebuffer/insert_buffer_test.go b/internal/flushcommon/writebuffer/insert_buffer_test.go index c7ac20d215343..a50016b6c905b 100644 --- a/internal/flushcommon/writebuffer/insert_buffer_test.go +++ b/internal/flushcommon/writebuffer/insert_buffer_test.go @@ -20,6 +20,7 @@ import ( type InsertBufferSuite struct { suite.Suite collSchema *schemapb.CollectionSchema + pkField *schemapb.FieldSchema } func (s *InsertBufferSuite) SetupSuite() { @@ -44,6 +45,7 @@ func (s *InsertBufferSuite) SetupSuite() { }, }, } + s.pkField = &schemapb.FieldSchema{FieldID: 100, Name: "pk", DataType: schemapb.DataType_Int64, IsPrimaryKey: true} } func (s *InsertBufferSuite) composeInsertMsg(rowCount int, dim int) ([]int64, *msgstream.InsertMsg) { @@ -127,15 +129,12 @@ func (s *InsertBufferSuite) TestBasic() { } func (s *InsertBufferSuite) TestBuffer() { - wb := &writeBufferBase{ - collSchema: s.collSchema, - } _, insertMsg := s.composeInsertMsg(10, 128) insertBuffer, err := NewInsertBuffer(s.collSchema) s.Require().NoError(err) - groups, err := wb.prepareInsert([]*msgstream.InsertMsg{insertMsg}) + groups, err := PrepareInsert(s.collSchema, s.pkField, []*msgstream.InsertMsg{insertMsg}) s.Require().NoError(err) s.Require().Len(groups, 1) @@ -146,9 +145,6 
@@ func (s *InsertBufferSuite) TestBuffer() { } func (s *InsertBufferSuite) TestYield() { - wb := &writeBufferBase{ - collSchema: s.collSchema, - } insertBuffer, err := NewInsertBuffer(s.collSchema) s.Require().NoError(err) @@ -159,7 +155,7 @@ func (s *InsertBufferSuite) TestYield() { s.Require().NoError(err) pks, insertMsg := s.composeInsertMsg(10, 128) - groups, err := wb.prepareInsert([]*msgstream.InsertMsg{insertMsg}) + groups, err := PrepareInsert(s.collSchema, s.pkField, []*msgstream.InsertMsg{insertMsg}) s.Require().NoError(err) s.Require().Len(groups, 1) diff --git a/internal/flushcommon/writebuffer/l0_write_buffer.go b/internal/flushcommon/writebuffer/l0_write_buffer.go index 45adda5f5011b..3f9b458b5ffa9 100644 --- a/internal/flushcommon/writebuffer/l0_write_buffer.go +++ b/internal/flushcommon/writebuffer/l0_write_buffer.go @@ -52,9 +52,9 @@ func NewL0WriteBuffer(channel string, metacache metacache.MetaCache, syncMgr syn }, nil } -func (wb *l0WriteBuffer) dispatchDeleteMsgs(groups []*inData, deleteMsgs []*msgstream.DeleteMsg, startPos, endPos *msgpb.MsgPosition) { +func (wb *l0WriteBuffer) dispatchDeleteMsgs(groups []*InsertData, deleteMsgs []*msgstream.DeleteMsg, startPos, endPos *msgpb.MsgPosition) { batchSize := paramtable.Get().CommonCfg.BloomFilterApplyBatchSize.GetAsInt() - split := func(pks []storage.PrimaryKey, pkTss []uint64, partitionSegments []*metacache.SegmentInfo, partitionGroups []*inData) []bool { + split := func(pks []storage.PrimaryKey, pkTss []uint64, partitionSegments []*metacache.SegmentInfo, partitionGroups []*InsertData) []bool { lc := storage.NewBatchLocationsCache(pks) // use hits to cache result @@ -93,7 +93,7 @@ func (wb *l0WriteBuffer) dispatchDeleteMsgs(groups []*inData, deleteMsgs []*msgs pkTss := delMsg.GetTimestamps() partitionSegments := wb.metaCache.GetSegmentsBy(metacache.WithPartitionID(delMsg.PartitionID), metacache.WithSegmentState(commonpb.SegmentState_Growing, commonpb.SegmentState_Sealed, 
commonpb.SegmentState_Flushing, commonpb.SegmentState_Flushed)) - partitionGroups := lo.Filter(groups, func(inData *inData, _ int) bool { + partitionGroups := lo.Filter(groups, func(inData *InsertData, _ int) bool { return delMsg.GetPartitionID() == common.AllPartitionsID || delMsg.GetPartitionID() == inData.partitionID }) @@ -151,17 +151,12 @@ func (wb *l0WriteBuffer) dispatchDeleteMsgsWithoutFilter(deleteMsgs []*msgstream } } -func (wb *l0WriteBuffer) BufferData(insertMsgs []*msgstream.InsertMsg, deleteMsgs []*msgstream.DeleteMsg, startPos, endPos *msgpb.MsgPosition) error { +func (wb *l0WriteBuffer) BufferData(insertData []*InsertData, deleteMsgs []*msgstream.DeleteMsg, startPos, endPos *msgpb.MsgPosition) error { wb.mut.Lock() defer wb.mut.Unlock() - groups, err := wb.prepareInsert(insertMsgs) - if err != nil { - return err - } - // buffer insert data and add segment if not exists - for _, inData := range groups { + for _, inData := range insertData { err := wb.bufferInsert(inData, startPos, endPos) if err != nil { return err @@ -175,11 +170,11 @@ func (wb *l0WriteBuffer) BufferData(insertMsgs []*msgstream.InsertMsg, deleteMsg } else { // distribute delete msg // bf write buffer check bloom filter of segment and current insert batch to decide which segment to write delete data - wb.dispatchDeleteMsgs(groups, deleteMsgs, startPos, endPos) + wb.dispatchDeleteMsgs(insertData, deleteMsgs, startPos, endPos) } // update pk oracle - for _, inData := range groups { + for _, inData := range insertData { // segment shall always exists after buffer insert segments := wb.metaCache.GetSegmentsBy(metacache.WithSegmentIDs(inData.segmentID)) for _, segment := range segments { @@ -230,7 +225,7 @@ func (wb *l0WriteBuffer) getL0SegmentID(partitionID int64, startPos *msgpb.MsgPo StartPosition: startPos, State: commonpb.SegmentState_Growing, Level: datapb.SegmentLevel_L0, - }, func(_ *datapb.SegmentInfo) pkoracle.PkStat { return pkoracle.NewBloomFilterSet() }, 
metacache.SetStartPosRecorded(false)) + }, func(_ *datapb.SegmentInfo) pkoracle.PkStat { return pkoracle.NewBloomFilterSet() }, metacache.NoneBm25StatsFactory, metacache.SetStartPosRecorded(false)) log.Info("Add a new level zero segment", zap.Int64("segmentID", segmentID), zap.String("level", datapb.SegmentLevel_L0.String()), diff --git a/internal/flushcommon/writebuffer/l0_write_buffer_test.go b/internal/flushcommon/writebuffer/l0_write_buffer_test.go index e86a9fa6d8cf1..53c911544a342 100644 --- a/internal/flushcommon/writebuffer/l0_write_buffer_test.go +++ b/internal/flushcommon/writebuffer/l0_write_buffer_test.go @@ -32,6 +32,7 @@ type L0WriteBufferSuite struct { channelName string collID int64 collSchema *schemapb.CollectionSchema + pkSchema *schemapb.FieldSchema syncMgr *syncmgr.MockSyncManager metacache *metacache.MockMetaCache allocator *allocator.MockGIDAllocator @@ -60,6 +61,13 @@ func (s *L0WriteBufferSuite) SetupSuite() { }, }, } + + for _, field := range s.collSchema.Fields { + if field.GetIsPrimaryKey() { + s.pkSchema = field + break + } + } s.channelName = "by-dev-rootcoord-dml_0v0" } @@ -177,14 +185,16 @@ func (s *L0WriteBufferSuite) TestBufferData() { pks, msg := s.composeInsertMsg(1000, 10, 128, schemapb.DataType_Int64) delMsg := s.composeDeleteMsg(lo.Map(pks, func(id int64, _ int) storage.PrimaryKey { return storage.NewInt64PrimaryKey(id) })) - seg := metacache.NewSegmentInfo(&datapb.SegmentInfo{ID: 1000}, pkoracle.NewBloomFilterSet()) + seg := metacache.NewSegmentInfo(&datapb.SegmentInfo{ID: 1000}, pkoracle.NewBloomFilterSet(), nil) s.metacache.EXPECT().GetSegmentsBy(mock.Anything, mock.Anything).Return([]*metacache.SegmentInfo{seg}) s.metacache.EXPECT().GetSegmentByID(int64(1000)).Return(nil, false).Once() - s.metacache.EXPECT().AddSegment(mock.Anything, mock.Anything, mock.Anything).Return() + s.metacache.EXPECT().AddSegment(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return() s.metacache.EXPECT().UpdateSegments(mock.Anything, 
mock.Anything).Return() metrics.DataNodeFlowGraphBufferDataSize.Reset() - err = wb.BufferData([]*msgstream.InsertMsg{msg}, []*msgstream.DeleteMsg{delMsg}, &msgpb.MsgPosition{Timestamp: 100}, &msgpb.MsgPosition{Timestamp: 200}) + insertData, err := PrepareInsert(s.collSchema, s.pkSchema, []*msgstream.InsertMsg{msg}) + s.NoError(err) + err = wb.BufferData(insertData, []*msgstream.DeleteMsg{delMsg}, &msgpb.MsgPosition{Timestamp: 100}, &msgpb.MsgPosition{Timestamp: 200}) s.NoError(err) value, err := metrics.DataNodeFlowGraphBufferDataSize.GetMetricWithLabelValues(fmt.Sprint(paramtable.GetNodeID()), fmt.Sprint(s.metacache.Collection())) @@ -192,29 +202,10 @@ func (s *L0WriteBufferSuite) TestBufferData() { s.MetricsEqual(value, 5607) delMsg = s.composeDeleteMsg(lo.Map(pks, func(id int64, _ int) storage.PrimaryKey { return storage.NewInt64PrimaryKey(id) })) - err = wb.BufferData([]*msgstream.InsertMsg{}, []*msgstream.DeleteMsg{delMsg}, &msgpb.MsgPosition{Timestamp: 100}, &msgpb.MsgPosition{Timestamp: 200}) + err = wb.BufferData([]*InsertData{}, []*msgstream.DeleteMsg{delMsg}, &msgpb.MsgPosition{Timestamp: 100}, &msgpb.MsgPosition{Timestamp: 200}) s.NoError(err) s.MetricsEqual(value, 5847) }) - - s.Run("pk_type_not_match", func() { - wb, err := NewL0WriteBuffer(s.channelName, s.metacache, s.syncMgr, &writeBufferOption{ - idAllocator: s.allocator, - }) - s.NoError(err) - - pks, msg := s.composeInsertMsg(1000, 10, 128, schemapb.DataType_VarChar) - delMsg := s.composeDeleteMsg(lo.Map(pks, func(id int64, _ int) storage.PrimaryKey { return storage.NewInt64PrimaryKey(id) })) - - seg := metacache.NewSegmentInfo(&datapb.SegmentInfo{ID: 1000}, pkoracle.NewBloomFilterSet()) - s.metacache.EXPECT().GetSegmentsBy(mock.Anything, mock.Anything).Return([]*metacache.SegmentInfo{seg}) - s.metacache.EXPECT().AddSegment(mock.Anything, mock.Anything, mock.Anything).Return() - s.metacache.EXPECT().UpdateSegments(mock.Anything, mock.Anything).Return() - - 
metrics.DataNodeFlowGraphBufferDataSize.Reset() - err = wb.BufferData([]*msgstream.InsertMsg{msg}, []*msgstream.DeleteMsg{delMsg}, &msgpb.MsgPosition{Timestamp: 100}, &msgpb.MsgPosition{Timestamp: 200}) - s.Error(err) - }) } func (s *L0WriteBufferSuite) TestCreateFailure() { diff --git a/internal/flushcommon/writebuffer/manager.go b/internal/flushcommon/writebuffer/manager.go index 028c8e5503e83..cd46c68c8b0f3 100644 --- a/internal/flushcommon/writebuffer/manager.go +++ b/internal/flushcommon/writebuffer/manager.go @@ -35,7 +35,7 @@ type BufferManager interface { DropChannel(channel string) DropPartitions(channel string, partitionIDs []int64) // BufferData put data into channel write buffer. - BufferData(channel string, insertMsgs []*msgstream.InsertMsg, deleteMsgs []*msgstream.DeleteMsg, startPos, endPos *msgpb.MsgPosition) error + BufferData(channel string, insertData []*InsertData, deleteMsgs []*msgstream.DeleteMsg, startPos, endPos *msgpb.MsgPosition) error // GetCheckpoint returns checkpoint for provided channel. GetCheckpoint(channel string) (*msgpb.MsgPosition, bool, error) // NotifyCheckpointUpdated notify write buffer checkpoint updated to reset flushTs. @@ -188,7 +188,7 @@ func (m *bufferManager) FlushChannel(ctx context.Context, channel string, flushT } // BufferData put data into channel write buffer. 
-func (m *bufferManager) BufferData(channel string, insertMsgs []*msgstream.InsertMsg, deleteMsgs []*msgstream.DeleteMsg, startPos, endPos *msgpb.MsgPosition) error { +func (m *bufferManager) BufferData(channel string, insertData []*InsertData, deleteMsgs []*msgstream.DeleteMsg, startPos, endPos *msgpb.MsgPosition) error { m.mut.RLock() buf, ok := m.buffers[channel] m.mut.RUnlock() @@ -199,7 +199,7 @@ func (m *bufferManager) BufferData(channel string, insertMsgs []*msgstream.Inser return merr.WrapErrChannelNotFound(channel) } - return buf.BufferData(insertMsgs, deleteMsgs, startPos, endPos) + return buf.BufferData(insertData, deleteMsgs, startPos, endPos) } // GetCheckpoint returns checkpoint for provided channel. diff --git a/internal/flushcommon/writebuffer/mock_manager.go b/internal/flushcommon/writebuffer/mock_manager.go index d58830cc1ec28..4b9bde855779a 100644 --- a/internal/flushcommon/writebuffer/mock_manager.go +++ b/internal/flushcommon/writebuffer/mock_manager.go @@ -26,13 +26,13 @@ func (_m *MockBufferManager) EXPECT() *MockBufferManager_Expecter { return &MockBufferManager_Expecter{mock: &_m.Mock} } -// BufferData provides a mock function with given fields: channel, insertMsgs, deleteMsgs, startPos, endPos -func (_m *MockBufferManager) BufferData(channel string, insertMsgs []*msgstream.InsertMsg, deleteMsgs []*msgstream.DeleteMsg, startPos *msgpb.MsgPosition, endPos *msgpb.MsgPosition) error { - ret := _m.Called(channel, insertMsgs, deleteMsgs, startPos, endPos) +// BufferData provides a mock function with given fields: channel, insertData, deleteMsgs, startPos, endPos +func (_m *MockBufferManager) BufferData(channel string, insertData []*InsertData, deleteMsgs []*msgstream.DeleteMsg, startPos *msgpb.MsgPosition, endPos *msgpb.MsgPosition) error { + ret := _m.Called(channel, insertData, deleteMsgs, startPos, endPos) var r0 error - if rf, ok := ret.Get(0).(func(string, []*msgstream.InsertMsg, []*msgstream.DeleteMsg, *msgpb.MsgPosition, 
*msgpb.MsgPosition) error); ok { - r0 = rf(channel, insertMsgs, deleteMsgs, startPos, endPos) + if rf, ok := ret.Get(0).(func(string, []*InsertData, []*msgstream.DeleteMsg, *msgpb.MsgPosition, *msgpb.MsgPosition) error); ok { + r0 = rf(channel, insertData, deleteMsgs, startPos, endPos) } else { r0 = ret.Error(0) } @@ -47,17 +47,17 @@ type MockBufferManager_BufferData_Call struct { // BufferData is a helper method to define mock.On call // - channel string -// - insertMsgs []*msgstream.InsertMsg +// - insertData []*InsertData // - deleteMsgs []*msgstream.DeleteMsg // - startPos *msgpb.MsgPosition // - endPos *msgpb.MsgPosition -func (_e *MockBufferManager_Expecter) BufferData(channel interface{}, insertMsgs interface{}, deleteMsgs interface{}, startPos interface{}, endPos interface{}) *MockBufferManager_BufferData_Call { - return &MockBufferManager_BufferData_Call{Call: _e.mock.On("BufferData", channel, insertMsgs, deleteMsgs, startPos, endPos)} +func (_e *MockBufferManager_Expecter) BufferData(channel interface{}, insertData interface{}, deleteMsgs interface{}, startPos interface{}, endPos interface{}) *MockBufferManager_BufferData_Call { + return &MockBufferManager_BufferData_Call{Call: _e.mock.On("BufferData", channel, insertData, deleteMsgs, startPos, endPos)} } -func (_c *MockBufferManager_BufferData_Call) Run(run func(channel string, insertMsgs []*msgstream.InsertMsg, deleteMsgs []*msgstream.DeleteMsg, startPos *msgpb.MsgPosition, endPos *msgpb.MsgPosition)) *MockBufferManager_BufferData_Call { +func (_c *MockBufferManager_BufferData_Call) Run(run func(channel string, insertData []*InsertData, deleteMsgs []*msgstream.DeleteMsg, startPos *msgpb.MsgPosition, endPos *msgpb.MsgPosition)) *MockBufferManager_BufferData_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(string), args[1].([]*msgstream.InsertMsg), args[2].([]*msgstream.DeleteMsg), args[3].(*msgpb.MsgPosition), args[4].(*msgpb.MsgPosition)) + run(args[0].(string), args[1].([]*InsertData), 
args[2].([]*msgstream.DeleteMsg), args[3].(*msgpb.MsgPosition), args[4].(*msgpb.MsgPosition)) }) return _c } @@ -67,7 +67,7 @@ func (_c *MockBufferManager_BufferData_Call) Return(_a0 error) *MockBufferManage return _c } -func (_c *MockBufferManager_BufferData_Call) RunAndReturn(run func(string, []*msgstream.InsertMsg, []*msgstream.DeleteMsg, *msgpb.MsgPosition, *msgpb.MsgPosition) error) *MockBufferManager_BufferData_Call { +func (_c *MockBufferManager_BufferData_Call) RunAndReturn(run func(string, []*InsertData, []*msgstream.DeleteMsg, *msgpb.MsgPosition, *msgpb.MsgPosition) error) *MockBufferManager_BufferData_Call { _c.Call.Return(run) return _c } diff --git a/internal/flushcommon/writebuffer/mock_write_buffer.go b/internal/flushcommon/writebuffer/mock_write_buffer.go index 9b85350e27dd6..93635c4178c28 100644 --- a/internal/flushcommon/writebuffer/mock_write_buffer.go +++ b/internal/flushcommon/writebuffer/mock_write_buffer.go @@ -25,11 +25,11 @@ func (_m *MockWriteBuffer) EXPECT() *MockWriteBuffer_Expecter { } // BufferData provides a mock function with given fields: insertMsgs, deleteMsgs, startPos, endPos -func (_m *MockWriteBuffer) BufferData(insertMsgs []*msgstream.InsertMsg, deleteMsgs []*msgstream.DeleteMsg, startPos *msgpb.MsgPosition, endPos *msgpb.MsgPosition) error { +func (_m *MockWriteBuffer) BufferData(insertMsgs []*InsertData, deleteMsgs []*msgstream.DeleteMsg, startPos *msgpb.MsgPosition, endPos *msgpb.MsgPosition) error { ret := _m.Called(insertMsgs, deleteMsgs, startPos, endPos) var r0 error - if rf, ok := ret.Get(0).(func([]*msgstream.InsertMsg, []*msgstream.DeleteMsg, *msgpb.MsgPosition, *msgpb.MsgPosition) error); ok { + if rf, ok := ret.Get(0).(func([]*InsertData, []*msgstream.DeleteMsg, *msgpb.MsgPosition, *msgpb.MsgPosition) error); ok { r0 = rf(insertMsgs, deleteMsgs, startPos, endPos) } else { r0 = ret.Error(0) @@ -44,7 +44,7 @@ type MockWriteBuffer_BufferData_Call struct { } // BufferData is a helper method to define mock.On call -// - 
insertMsgs []*msgstream.InsertMsg +// - insertMsgs []*InsertData // - deleteMsgs []*msgstream.DeleteMsg // - startPos *msgpb.MsgPosition // - endPos *msgpb.MsgPosition @@ -52,9 +52,9 @@ func (_e *MockWriteBuffer_Expecter) BufferData(insertMsgs interface{}, deleteMsg return &MockWriteBuffer_BufferData_Call{Call: _e.mock.On("BufferData", insertMsgs, deleteMsgs, startPos, endPos)} } -func (_c *MockWriteBuffer_BufferData_Call) Run(run func(insertMsgs []*msgstream.InsertMsg, deleteMsgs []*msgstream.DeleteMsg, startPos *msgpb.MsgPosition, endPos *msgpb.MsgPosition)) *MockWriteBuffer_BufferData_Call { +func (_c *MockWriteBuffer_BufferData_Call) Run(run func(insertMsgs []*InsertData, deleteMsgs []*msgstream.DeleteMsg, startPos *msgpb.MsgPosition, endPos *msgpb.MsgPosition)) *MockWriteBuffer_BufferData_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].([]*msgstream.InsertMsg), args[1].([]*msgstream.DeleteMsg), args[2].(*msgpb.MsgPosition), args[3].(*msgpb.MsgPosition)) + run(args[0].([]*InsertData), args[1].([]*msgstream.DeleteMsg), args[2].(*msgpb.MsgPosition), args[3].(*msgpb.MsgPosition)) }) return _c } @@ -64,7 +64,7 @@ func (_c *MockWriteBuffer_BufferData_Call) Return(_a0 error) *MockWriteBuffer_Bu return _c } -func (_c *MockWriteBuffer_BufferData_Call) RunAndReturn(run func([]*msgstream.InsertMsg, []*msgstream.DeleteMsg, *msgpb.MsgPosition, *msgpb.MsgPosition) error) *MockWriteBuffer_BufferData_Call { +func (_c *MockWriteBuffer_BufferData_Call) RunAndReturn(run func([]*InsertData, []*msgstream.DeleteMsg, *msgpb.MsgPosition, *msgpb.MsgPosition) error) *MockWriteBuffer_BufferData_Call { _c.Call.Return(run) return _c } diff --git a/internal/flushcommon/writebuffer/segment_buffer.go b/internal/flushcommon/writebuffer/segment_buffer.go index 6afd64fff7fa4..8b17e98703b31 100644 --- a/internal/flushcommon/writebuffer/segment_buffer.go +++ b/internal/flushcommon/writebuffer/segment_buffer.go @@ -32,8 +32,11 @@ func (buf *segmentBuffer) IsFull() bool { return 
buf.insertBuffer.IsFull() || buf.deltaBuffer.IsFull() } -func (buf *segmentBuffer) Yield() (insert []*storage.InsertData, delete *storage.DeleteData) { - return buf.insertBuffer.Yield(), buf.deltaBuffer.Yield() +func (buf *segmentBuffer) Yield() (insert []*storage.InsertData, bm25stats map[int64]*storage.BM25Stats, delete *storage.DeleteData) { + insert = buf.insertBuffer.Yield() + bm25stats = buf.insertBuffer.YieldStats() + delete = buf.deltaBuffer.Yield() + return } func (buf *segmentBuffer) MinTimestamp() typeutil.Timestamp { diff --git a/internal/flushcommon/writebuffer/stats_buffer.go b/internal/flushcommon/writebuffer/stats_buffer.go new file mode 100644 index 0000000000000..1cd2330df52bf --- /dev/null +++ b/internal/flushcommon/writebuffer/stats_buffer.go @@ -0,0 +1,48 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package writebuffer + +import ( + "github.com/milvus-io/milvus/internal/storage" +) + +// stats buffer used for bm25 stats +type statsBuffer struct { + bm25Stats map[int64]*storage.BM25Stats +} + +func (b *statsBuffer) Buffer(stats map[int64]*storage.BM25Stats) { + for fieldID, stat := range stats { + if fieldMeta, ok := b.bm25Stats[fieldID]; ok { + fieldMeta.Merge(stat) + } else { + b.bm25Stats[fieldID] = stat + } + } +} + +func (b *statsBuffer) yieldBuffer() map[int64]*storage.BM25Stats { + result := b.bm25Stats + b.bm25Stats = make(map[int64]*storage.BM25Stats) + return result +} + +func newStatsBuffer() *statsBuffer { + return &statsBuffer{ + bm25Stats: make(map[int64]*storage.BM25Stats), + } +} diff --git a/internal/flushcommon/writebuffer/write_buffer.go b/internal/flushcommon/writebuffer/write_buffer.go index 48589461e2610..020425e09debb 100644 --- a/internal/flushcommon/writebuffer/write_buffer.go +++ b/internal/flushcommon/writebuffer/write_buffer.go @@ -38,7 +38,7 @@ type WriteBuffer interface { // HasSegment checks whether certain segment exists in this buffer. HasSegment(segmentID int64) bool // BufferData is the method to buffer dml data msgs. 
- BufferData(insertMsgs []*msgstream.InsertMsg, deleteMsgs []*msgstream.DeleteMsg, startPos, endPos *msgpb.MsgPosition) error + BufferData(insertMsgs []*InsertData, deleteMsgs []*msgstream.DeleteMsg, startPos, endPos *msgpb.MsgPosition) error // FlushTimestamp set flush timestamp for write buffer SetFlushTimestamp(flushTs uint64) // GetFlushTimestamp get current flush timestamp @@ -82,12 +82,24 @@ func (c *checkpointCandidates) Remove(segmentID int64, timestamp uint64) { delete(c.candidates, fmt.Sprintf("%d-%d", segmentID, timestamp)) } +func (c *checkpointCandidates) RemoveChannel(channel string, timestamp uint64) { + c.mu.Lock() + defer c.mu.Unlock() + delete(c.candidates, fmt.Sprintf("%s-%d", channel, timestamp)) +} + func (c *checkpointCandidates) Add(segmentID int64, position *msgpb.MsgPosition, source string) { c.mu.Lock() defer c.mu.Unlock() c.candidates[fmt.Sprintf("%d-%d", segmentID, position.GetTimestamp())] = &checkpointCandidate{segmentID, position, source} } +func (c *checkpointCandidates) AddChannel(channel string, position *msgpb.MsgPosition, source string) { + c.mu.Lock() + defer c.mu.Unlock() + c.candidates[fmt.Sprintf("%s-%d", channel, position.GetTimestamp())] = &checkpointCandidate{-1, position, source} +} + func (c *checkpointCandidates) GetEarliestWithDefault(def *checkpointCandidate) *checkpointCandidate { c.mu.RLock() defer c.mu.RUnlock() @@ -126,8 +138,6 @@ type writeBufferBase struct { metaWriter syncmgr.MetaWriter collSchema *schemapb.CollectionSchema - helper *typeutil.SchemaHelper - pkField *schemapb.FieldSchema estSizePerRecord int metaCache metacache.MetaCache @@ -169,21 +179,11 @@ func newWriteBufferBase(channel string, metacache metacache.MetaCache, syncMgr s if err != nil { return nil, err } - helper, err := typeutil.CreateSchemaHelper(schema) - if err != nil { - return nil, err - } - pkField, err := helper.GetPrimaryKeyField() - if err != nil { - return nil, err - } wb := &writeBufferBase{ channelName: channel, collectionID: 
metacache.Collection(), collSchema: schema, - helper: helper, - pkField: pkField, estSizePerRecord: estSize, syncMgr: syncMgr, metaWriter: option.metaWriter, @@ -391,34 +391,96 @@ func (wb *writeBufferBase) getOrCreateBuffer(segmentID int64) *segmentBuffer { return buffer } -func (wb *writeBufferBase) yieldBuffer(segmentID int64) ([]*storage.InsertData, *storage.DeleteData, *TimeRange, *msgpb.MsgPosition) { +func (wb *writeBufferBase) yieldBuffer(segmentID int64) ([]*storage.InsertData, map[int64]*storage.BM25Stats, *storage.DeleteData, *TimeRange, *msgpb.MsgPosition) { buffer, ok := wb.buffers[segmentID] if !ok { - return nil, nil, nil, nil + return nil, nil, nil, nil, nil } // remove buffer and move it to sync manager delete(wb.buffers, segmentID) start := buffer.EarliestPosition() timeRange := buffer.GetTimeRange() - insert, delta := buffer.Yield() + insert, bm25, delta := buffer.Yield() - return insert, delta, timeRange, start + return insert, bm25, delta, timeRange, start } -type inData struct { +type InsertData struct { segmentID int64 partitionID int64 data []*storage.InsertData - pkField []storage.FieldData - tsField []*storage.Int64FieldData - rowNum int64 + bm25Stats map[int64]*storage.BM25Stats + + pkField []storage.FieldData + pkType schemapb.DataType + + tsField []*storage.Int64FieldData + rowNum int64 intPKTs map[int64]int64 strPKTs map[string]int64 } -func (id *inData) pkExists(pk storage.PrimaryKey, ts uint64) bool { +func NewInsertData(segmentID, partitionID int64, cap int, pkType schemapb.DataType) *InsertData { + data := &InsertData{ + segmentID: segmentID, + partitionID: partitionID, + data: make([]*storage.InsertData, 0, cap), + pkField: make([]storage.FieldData, 0, cap), + pkType: pkType, + } + + switch pkType { + case schemapb.DataType_Int64: + data.intPKTs = make(map[int64]int64) + case schemapb.DataType_VarChar: + data.strPKTs = make(map[string]int64) + } + + return data +} + +func (id *InsertData) Append(data *storage.InsertData, 
pkFieldData storage.FieldData, tsFieldData *storage.Int64FieldData) { + id.data = append(id.data, data) + id.pkField = append(id.pkField, pkFieldData) + id.tsField = append(id.tsField, tsFieldData) + id.rowNum += int64(data.GetRowNum()) + + timestamps := tsFieldData.GetDataRows().([]int64) + switch id.pkType { + case schemapb.DataType_Int64: + pks := pkFieldData.GetDataRows().([]int64) + for idx, pk := range pks { + ts, ok := id.intPKTs[pk] + if !ok || timestamps[idx] < ts { + id.intPKTs[pk] = timestamps[idx] + } + } + case schemapb.DataType_VarChar: + pks := pkFieldData.GetDataRows().([]string) + for idx, pk := range pks { + ts, ok := id.strPKTs[pk] + if !ok || timestamps[idx] < ts { + id.strPKTs[pk] = timestamps[idx] + } + } + } +} + +func (id *InsertData) GetSegmentID() int64 { + return id.segmentID +} + +func (id *InsertData) SetBM25Stats(bm25Stats map[int64]*storage.BM25Stats) { + id.bm25Stats = bm25Stats +} + +func (id *InsertData) GetDatas() []*storage.InsertData { + return id.data +} + +func (id *InsertData) pkExists(pk storage.PrimaryKey, ts uint64) bool { var ok bool var minTs int64 switch pk.Type() { @@ -431,7 +493,7 @@ func (id *inData) pkExists(pk storage.PrimaryKey, ts uint64) bool { return ok && ts > uint64(minTs) } -func (id *inData) batchPkExists(pks []storage.PrimaryKey, tss []uint64, hits []bool) []bool { +func (id *InsertData) batchPkExists(pks []storage.PrimaryKey, tss []uint64, hits []bool) []bool { if len(pks) == 0 { return nil } @@ -457,84 +519,8 @@ func (id *inData) batchPkExists(pks []storage.PrimaryKey, tss []uint64, hits []b return hits } -// prepareInsert transfers InsertMsg into organized InsertData grouped by segmentID -// also returns primary key field data -func (wb *writeBufferBase) prepareInsert(insertMsgs []*msgstream.InsertMsg) ([]*inData, error) { - groups := lo.GroupBy(insertMsgs, func(msg *msgstream.InsertMsg) int64 { return msg.SegmentID }) - segmentPartition := lo.SliceToMap(insertMsgs, func(msg *msgstream.InsertMsg) 
(int64, int64) { return msg.GetSegmentID(), msg.GetPartitionID() }) - - result := make([]*inData, 0, len(groups)) - for segment, msgs := range groups { - inData := &inData{ - segmentID: segment, - partitionID: segmentPartition[segment], - data: make([]*storage.InsertData, 0, len(msgs)), - pkField: make([]storage.FieldData, 0, len(msgs)), - } - switch wb.pkField.GetDataType() { - case schemapb.DataType_Int64: - inData.intPKTs = make(map[int64]int64) - case schemapb.DataType_VarChar: - inData.strPKTs = make(map[string]int64) - } - - for _, msg := range msgs { - data, err := storage.InsertMsgToInsertData(msg, wb.collSchema) - if err != nil { - log.Warn("failed to transfer insert msg to insert data", zap.Error(err)) - return nil, err - } - - pkFieldData, err := storage.GetPkFromInsertData(wb.collSchema, data) - if err != nil { - return nil, err - } - if pkFieldData.RowNum() != data.GetRowNum() { - return nil, merr.WrapErrServiceInternal("pk column row num not match") - } - - tsFieldData, err := storage.GetTimestampFromInsertData(data) - if err != nil { - return nil, err - } - if tsFieldData.RowNum() != data.GetRowNum() { - return nil, merr.WrapErrServiceInternal("timestamp column row num not match") - } - - timestamps := tsFieldData.GetDataRows().([]int64) - - switch wb.pkField.GetDataType() { - case schemapb.DataType_Int64: - pks := pkFieldData.GetDataRows().([]int64) - for idx, pk := range pks { - ts, ok := inData.intPKTs[pk] - if !ok || timestamps[idx] < ts { - inData.intPKTs[pk] = timestamps[idx] - } - } - case schemapb.DataType_VarChar: - pks := pkFieldData.GetDataRows().([]string) - for idx, pk := range pks { - ts, ok := inData.strPKTs[pk] - if !ok || timestamps[idx] < ts { - inData.strPKTs[pk] = timestamps[idx] - } - } - } - - inData.data = append(inData.data, data) - inData.pkField = append(inData.pkField, pkFieldData) - inData.tsField = append(inData.tsField, tsFieldData) - inData.rowNum += int64(data.GetRowNum()) - } - result = append(result, inData) - } - - 
return result, nil -} - -// bufferInsert transform InsertMsg into bufferred InsertData and returns primary key field data for future usage. -func (wb *writeBufferBase) bufferInsert(inData *inData, startPos, endPos *msgpb.MsgPosition) error { +// bufferInsert buffers the prepared InsertData of a segment, adding the segment to the meta cache as growing if it is not found there. +func (wb *writeBufferBase) bufferInsert(inData *InsertData, startPos, endPos *msgpb.MsgPosition) error { _, ok := wb.metaCache.GetSegmentByID(inData.segmentID) // new segment if !ok { @@ -547,7 +533,7 @@ func (wb *writeBufferBase) bufferInsert(inData *inData, startPos, endPos *msgpb. State: commonpb.SegmentState_Growing, }, func(_ *datapb.SegmentInfo) pkoracle.PkStat { return pkoracle.NewBloomFilterSetWithBatchSize(wb.getEstBatchSize()) - }, metacache.SetStartPosRecorded(false)) + }, metacache.NewBM25StatsFactory, metacache.SetStartPosRecorded(false)) log.Info("add growing segment", zap.Int64("segmentID", inData.segmentID), zap.String("channel", wb.channelName)) } @@ -582,7 +568,7 @@ func (wb *writeBufferBase) getSyncTask(ctx context.Context, segmentID int64) (sy var totalMemSize float64 = 0 var tsFrom, tsTo uint64 - insert, delta, timeRange, startPos := wb.yieldBuffer(segmentID) + insert, bm25, delta, timeRange, startPos := wb.yieldBuffer(segmentID) if timeRange != nil { tsFrom, tsTo = timeRange.timestampMin, timeRange.timestampMax } @@ -619,6 +605,10 @@ func (wb *writeBufferBase) getSyncTask(ctx context.Context, segmentID int64) (sy WithBatchSize(batchSize).
WithErrorHandler(wb.errHandler) + if len(bm25) != 0 { + pack.WithBM25Stats(bm25) + } + if segmentInfo.State() == commonpb.SegmentState_Flushing || segmentInfo.Level() == datapb.SegmentLevel_L0 { // Level zero segment will always be sync as flushed pack.WithFlush() @@ -685,3 +675,79 @@ func (wb *writeBufferBase) Close(ctx context.Context, drop bool) { panic(err) } } + +// PrepareInsert transfers InsertMsg into organized InsertData grouped by segmentID, +// keeping the primary key and timestamp field data of every message. +func PrepareInsert(collSchema *schemapb.CollectionSchema, pkField *schemapb.FieldSchema, insertMsgs []*msgstream.InsertMsg) ([]*InsertData, error) { + groups := lo.GroupBy(insertMsgs, func(msg *msgstream.InsertMsg) int64 { return msg.SegmentID }) + segmentPartition := lo.SliceToMap(insertMsgs, func(msg *msgstream.InsertMsg) (int64, int64) { return msg.GetSegmentID(), msg.GetPartitionID() }) + + result := make([]*InsertData, 0, len(groups)) + for segment, msgs := range groups { + inData := &InsertData{ + segmentID: segment, + partitionID: segmentPartition[segment], + data: make([]*storage.InsertData, 0, len(msgs)), + pkField: make([]storage.FieldData, 0, len(msgs)), + } + switch pkField.GetDataType() { + case schemapb.DataType_Int64: + inData.intPKTs = make(map[int64]int64) + case schemapb.DataType_VarChar: + inData.strPKTs = make(map[string]int64) + } + + for _, msg := range msgs { + data, err := storage.InsertMsgToInsertData(msg, collSchema) + if err != nil { + log.Warn("failed to transfer insert msg to insert data", zap.Error(err)) + return nil, err + } + + pkFieldData, err := storage.GetPkFromInsertData(collSchema, data) + if err != nil { + return nil, err + } + if pkFieldData.RowNum() != data.GetRowNum() { + return nil, merr.WrapErrServiceInternal("pk column row num not match") + } + + tsFieldData, err := storage.GetTimestampFromInsertData(data) + if err != nil { + return nil, err + } + if tsFieldData.RowNum() != data.GetRowNum() { + return nil,
merr.WrapErrServiceInternal("timestamp column row num not match") + } + + timestamps := tsFieldData.GetDataRows().([]int64) + + switch pkField.GetDataType() { + case schemapb.DataType_Int64: + pks := pkFieldData.GetDataRows().([]int64) + for idx, pk := range pks { + ts, ok := inData.intPKTs[pk] + if !ok || timestamps[idx] < ts { + inData.intPKTs[pk] = timestamps[idx] + } + } + case schemapb.DataType_VarChar: + pks := pkFieldData.GetDataRows().([]string) + for idx, pk := range pks { + ts, ok := inData.strPKTs[pk] + if !ok || timestamps[idx] < ts { + inData.strPKTs[pk] = timestamps[idx] + } + } + } + + inData.data = append(inData.data, data) + inData.pkField = append(inData.pkField, pkFieldData) + inData.tsField = append(inData.tsField, tsFieldData) + inData.rowNum += int64(data.GetRowNum()) + } + result = append(result, inData) + } + + return result, nil +} diff --git a/internal/flushcommon/writebuffer/write_buffer_test.go b/internal/flushcommon/writebuffer/write_buffer_test.go index 5029be5cec4b0..7363fa2ad93c1 100644 --- a/internal/flushcommon/writebuffer/write_buffer_test.go +++ b/internal/flushcommon/writebuffer/write_buffer_test.go @@ -276,7 +276,7 @@ func (s *WriteBufferSuite) TestSyncSegmentsError() { segment := metacache.NewSegmentInfo(&datapb.SegmentInfo{ ID: 1, - }, nil) + }, nil, nil) s.metacache.EXPECT().GetSegmentByID(int64(1)).Return(segment, true) s.metacache.EXPECT().UpdateSegments(mock.Anything, mock.Anything).Return() @@ -348,7 +348,7 @@ func (s *WriteBufferSuite) TestEvictBuffer() { segment := metacache.NewSegmentInfo(&datapb.SegmentInfo{ ID: 2, - }, nil) + }, nil, nil) s.metacache.EXPECT().GetSegmentByID(int64(2)).Return(segment, true) s.metacache.EXPECT().UpdateSegments(mock.Anything, mock.Anything).Return() serializer.EXPECT().EncodeBuffer(mock.Anything, mock.Anything).Return(syncmgr.NewSyncTask(), nil) diff --git a/internal/metastore/kv/binlog/binlog.go b/internal/metastore/kv/binlog/binlog.go index f0dbe45c54128..49abca78187f8 100644 --- 
a/internal/metastore/kv/binlog/binlog.go +++ b/internal/metastore/kv/binlog/binlog.go @@ -42,6 +42,10 @@ func CompressSaveBinlogPaths(req *datapb.SaveBinlogPathsRequest) error { if err != nil { return err } + err = CompressFieldBinlogs(req.GetField2Bm25LogPaths()) + if err != nil { + return err + } return nil } @@ -133,6 +137,11 @@ func DecompressBinLogs(s *datapb.SegmentInfo) error { if err != nil { return err } + + err = DecompressBinLog(storage.BM25Binlog, collectionID, partitionID, segmentID, s.GetBm25Statslogs()) + if err != nil { + return err + } return nil } @@ -167,6 +176,8 @@ func BuildLogPath(binlogType storage.BinlogType, collectionID, partitionID, segm return metautil.BuildDeltaLogPath(chunkManagerRootPath, collectionID, partitionID, segmentID, logID), nil case storage.StatsBinlog: return metautil.BuildStatsLogPath(chunkManagerRootPath, collectionID, partitionID, segmentID, fieldID, logID), nil + case storage.BM25Binlog: + return metautil.BuildBm25LogPath(chunkManagerRootPath, collectionID, partitionID, segmentID, fieldID, logID), nil } // should not happen return "", merr.WrapErrParameterInvalidMsg("invalid binlog type") diff --git a/internal/metastore/kv/datacoord/constant.go b/internal/metastore/kv/datacoord/constant.go index 6b4083c3cd735..7e4a44cb88e3c 100644 --- a/internal/metastore/kv/datacoord/constant.go +++ b/internal/metastore/kv/datacoord/constant.go @@ -22,6 +22,7 @@ const ( SegmentBinlogPathPrefix = MetaPrefix + "/binlog" SegmentDeltalogPathPrefix = MetaPrefix + "/deltalog" SegmentStatslogPathPrefix = MetaPrefix + "/statslog" + SegmentBM25logPathPrefix = MetaPrefix + "/bm25log" ChannelRemovePrefix = MetaPrefix + "/channel-removal" ChannelCheckpointPrefix = MetaPrefix + "/channel-cp" ImportJobPrefix = MetaPrefix + "/import-job" diff --git a/internal/metastore/kv/datacoord/kv_catalog.go b/internal/metastore/kv/datacoord/kv_catalog.go index 88cd39a4d36db..a0aae7c2bff4f 100644 --- a/internal/metastore/kv/datacoord/kv_catalog.go +++ 
b/internal/metastore/kv/datacoord/kv_catalog.go @@ -64,6 +64,7 @@ func (kc *Catalog) ListSegments(ctx context.Context) ([]*datapb.SegmentInfo, err insertLogs := make(map[typeutil.UniqueID][]*datapb.FieldBinlog, 1) deltaLogs := make(map[typeutil.UniqueID][]*datapb.FieldBinlog, 1) statsLogs := make(map[typeutil.UniqueID][]*datapb.FieldBinlog, 1) + bm25Logs := make(map[typeutil.UniqueID][]*datapb.FieldBinlog, 1) executeFn := func(binlogType storage.BinlogType, result map[typeutil.UniqueID][]*datapb.FieldBinlog) { group.Go(func() error { @@ -81,6 +82,7 @@ func (kc *Catalog) ListSegments(ctx context.Context) ([]*datapb.SegmentInfo, err executeFn(storage.InsertBinlog, insertLogs) executeFn(storage.DeleteBinlog, deltaLogs) executeFn(storage.StatsBinlog, statsLogs) + executeFn(storage.BM25Binlog, bm25Logs) group.Go(func() error { ret, err := kc.listSegments() if err != nil { @@ -95,7 +97,7 @@ func (kc *Catalog) ListSegments(ctx context.Context) ([]*datapb.SegmentInfo, err return nil, err } - err = kc.applyBinlogInfo(segments, insertLogs, deltaLogs, statsLogs) + err = kc.applyBinlogInfo(segments, insertLogs, deltaLogs, statsLogs, bm25Logs) if err != nil { return nil, err } @@ -172,6 +174,8 @@ func (kc *Catalog) listBinlogs(binlogType storage.BinlogType) (map[typeutil.Uniq logPathPrefix = SegmentDeltalogPathPrefix case storage.StatsBinlog: logPathPrefix = SegmentStatslogPathPrefix + case storage.BM25Binlog: + logPathPrefix = SegmentBM25logPathPrefix default: err = fmt.Errorf("invalid binlog type: %d", binlogType) } @@ -218,7 +222,7 @@ func (kc *Catalog) listBinlogs(binlogType storage.BinlogType) (map[typeutil.Uniq } func (kc *Catalog) applyBinlogInfo(segments []*datapb.SegmentInfo, insertLogs, deltaLogs, - statsLogs map[typeutil.UniqueID][]*datapb.FieldBinlog, + statsLogs, bm25Logs map[typeutil.UniqueID][]*datapb.FieldBinlog, ) error { var err error for _, segmentInfo := range segments { @@ -242,6 +246,13 @@ func (kc *Catalog) applyBinlogInfo(segments []*datapb.SegmentInfo, 
insertLogs, d if err = binlog.CompressFieldBinlogs(segmentInfo.Statslogs); err != nil { return err } + + if len(segmentInfo.Bm25Statslogs) == 0 { + segmentInfo.Bm25Statslogs = bm25Logs[segmentInfo.ID] + } + if err = binlog.CompressFieldBinlogs(segmentInfo.Bm25Statslogs); err != nil { + return err + } } return nil } @@ -309,7 +320,7 @@ func (kc *Catalog) AlterSegments(ctx context.Context, segments []*datapb.Segment segment := b.Segment binlogKvs, err := buildBinlogKvsWithLogID(segment.GetCollectionID(), segment.GetPartitionID(), segment.GetID(), - cloneLogs(segment.GetBinlogs()), cloneLogs(segment.GetDeltalogs()), cloneLogs(segment.GetStatslogs())) + cloneLogs(segment.GetBinlogs()), cloneLogs(segment.GetDeltalogs()), cloneLogs(segment.GetStatslogs()), cloneLogs(segment.GetBm25Statslogs())) if err != nil { return err } @@ -328,7 +339,7 @@ func (kc *Catalog) handleDroppedSegment(segment *datapb.SegmentInfo) (kvs map[st } // To be compatible with previous implementation, we have to write binlogs on etcd for correct gc. if !has { - kvs, err = buildBinlogKvsWithLogID(segment.GetCollectionID(), segment.GetPartitionID(), segment.GetID(), cloneLogs(segment.GetBinlogs()), cloneLogs(segment.GetDeltalogs()), cloneLogs(segment.GetStatslogs())) + kvs, err = buildBinlogKvsWithLogID(segment.GetCollectionID(), segment.GetPartitionID(), segment.GetID(), cloneLogs(segment.GetBinlogs()), cloneLogs(segment.GetDeltalogs()), cloneLogs(segment.GetStatslogs()), cloneLogs(segment.GetBm25Statslogs())) if err != nil { return } @@ -398,7 +409,7 @@ func (kc *Catalog) SaveDroppedSegmentsInBatch(ctx context.Context, segments []*d kvs := make(map[string]string) for _, s := range segments { key := buildSegmentPath(s.GetCollectionID(), s.GetPartitionID(), s.GetID()) - noBinlogsSegment, _, _, _ := CloneSegmentWithExcludeBinlogs(s) + noBinlogsSegment, _, _, _, _ := CloneSegmentWithExcludeBinlogs(s) // `s` is not mutated above. Also, `noBinlogsSegment` is a cloned version of `s`. 
segmentutil.ReCalcRowCount(s, noBinlogsSegment) segBytes, err := proto.Marshal(noBinlogsSegment) diff --git a/internal/metastore/kv/datacoord/kv_catalog_test.go b/internal/metastore/kv/datacoord/kv_catalog_test.go index 085cbf4876d28..567b82f23cca5 100644 --- a/internal/metastore/kv/datacoord/kv_catalog_test.go +++ b/internal/metastore/kv/datacoord/kv_catalog_test.go @@ -130,20 +130,6 @@ var ( }, } - getlogs = func(id int64) []*datapb.FieldBinlog { - return []*datapb.FieldBinlog{ - { - FieldID: 1, - Binlogs: []*datapb.Binlog{ - { - EntriesNum: 5, - LogID: id, - }, - }, - }, - } - } - segment1 = &datapb.SegmentInfo{ ID: segmentID, CollectionID: collectionID, @@ -154,17 +140,6 @@ var ( Deltalogs: deltalogs, Statslogs: statslogs, } - - droppedSegment = &datapb.SegmentInfo{ - ID: segmentID2, - CollectionID: collectionID, - PartitionID: partitionID, - NumOfRows: 100, - State: commonpb.SegmentState_Dropped, - Binlogs: getlogs(logID), - Deltalogs: getlogs(logID), - Statslogs: getlogs(logID), - } ) func Test_ListSegments(t *testing.T) { @@ -255,6 +230,10 @@ func Test_ListSegments(t *testing.T) { if strings.HasPrefix(k3, s) { return f([]byte(k3), []byte(savedKvs[k3])) } + // return empty bm25log list + if strings.HasPrefix(s, SegmentBM25logPathPrefix) { + return nil + } return errors.New("should not reach here") }) diff --git a/internal/metastore/kv/datacoord/util.go b/internal/metastore/kv/datacoord/util.go index df67aa3ddaf27..2f7262066d2af 100644 --- a/internal/metastore/kv/datacoord/util.go +++ b/internal/metastore/kv/datacoord/util.go @@ -93,10 +93,10 @@ func hasSpecialStatslog(segment *datapb.SegmentInfo) bool { } func buildBinlogKvsWithLogID(collectionID, partitionID, segmentID typeutil.UniqueID, - binlogs, deltalogs, statslogs []*datapb.FieldBinlog, + binlogs, deltalogs, statslogs, bm25logs []*datapb.FieldBinlog, ) (map[string]string, error) { // all the FieldBinlog will only have logid - kvs, err := buildBinlogKvs(collectionID, partitionID, segmentID, binlogs, 
deltalogs, statslogs) + kvs, err := buildBinlogKvs(collectionID, partitionID, segmentID, binlogs, deltalogs, statslogs, bm25logs) if err != nil { return nil, err } @@ -105,12 +105,12 @@ func buildBinlogKvsWithLogID(collectionID, partitionID, segmentID typeutil.Uniqu } func buildSegmentAndBinlogsKvs(segment *datapb.SegmentInfo) (map[string]string, error) { - noBinlogsSegment, binlogs, deltalogs, statslogs := CloneSegmentWithExcludeBinlogs(segment) + noBinlogsSegment, binlogs, deltalogs, statslogs, bm25logs := CloneSegmentWithExcludeBinlogs(segment) // `segment` is not mutated above. Also, `noBinlogsSegment` is a cloned version of `segment`. segmentutil.ReCalcRowCount(segment, noBinlogsSegment) // save binlogs separately - kvs, err := buildBinlogKvsWithLogID(noBinlogsSegment.CollectionID, noBinlogsSegment.PartitionID, noBinlogsSegment.ID, binlogs, deltalogs, statslogs) + kvs, err := buildBinlogKvsWithLogID(noBinlogsSegment.CollectionID, noBinlogsSegment.PartitionID, noBinlogsSegment.ID, binlogs, deltalogs, statslogs, bm25logs) if err != nil { return nil, err } @@ -125,32 +125,11 @@ func buildSegmentAndBinlogsKvs(segment *datapb.SegmentInfo) (map[string]string, return kvs, nil } -func buildBinlogKeys(segment *datapb.SegmentInfo) []string { - var keys []string - // binlog - for _, binlog := range segment.Binlogs { - key := buildFieldBinlogPath(segment.CollectionID, segment.PartitionID, segment.ID, binlog.FieldID) - keys = append(keys, key) - } - - // deltalog - for _, deltalog := range segment.Deltalogs { - key := buildFieldDeltalogPath(segment.CollectionID, segment.PartitionID, segment.ID, deltalog.FieldID) - keys = append(keys, key) - } - - // statslog - for _, statslog := range segment.Statslogs { - key := buildFieldStatslogPath(segment.CollectionID, segment.PartitionID, segment.ID, statslog.FieldID) - keys = append(keys, key) - } - return keys -} - func resetBinlogFields(segment *datapb.SegmentInfo) { segment.Binlogs = nil segment.Deltalogs = nil segment.Statslogs 
= nil + segment.Bm25Statslogs = nil } func cloneLogs(binlogs []*datapb.FieldBinlog) []*datapb.FieldBinlog { @@ -161,7 +140,7 @@ func cloneLogs(binlogs []*datapb.FieldBinlog) []*datapb.FieldBinlog { return res } -func buildBinlogKvs(collectionID, partitionID, segmentID typeutil.UniqueID, binlogs, deltalogs, statslogs []*datapb.FieldBinlog) (map[string]string, error) { +func buildBinlogKvs(collectionID, partitionID, segmentID typeutil.UniqueID, binlogs, deltalogs, statslogs, bm25logs []*datapb.FieldBinlog) (map[string]string, error) { kv := make(map[string]string) checkLogID := func(fieldBinlog *datapb.FieldBinlog) error { @@ -215,19 +194,33 @@ func buildBinlogKvs(collectionID, partitionID, segmentID typeutil.UniqueID, binl kv[key] = string(binlogBytes) } + for _, bm25log := range bm25logs { + if err := checkLogID(bm25log); err != nil { + return nil, err + } + binlogBytes, err := proto.Marshal(bm25log) + if err != nil { + return nil, fmt.Errorf("marshal bm25log failed, collectionID:%d, segmentID:%d, fieldID:%d, error:%w", collectionID, segmentID, bm25log.FieldID, err) + } + key := buildFieldBM25StatslogPath(collectionID, partitionID, segmentID, bm25log.FieldID) + kv[key] = string(binlogBytes) + } + return kv, nil } -func CloneSegmentWithExcludeBinlogs(segment *datapb.SegmentInfo) (*datapb.SegmentInfo, []*datapb.FieldBinlog, []*datapb.FieldBinlog, []*datapb.FieldBinlog) { +func CloneSegmentWithExcludeBinlogs(segment *datapb.SegmentInfo) (*datapb.SegmentInfo, []*datapb.FieldBinlog, []*datapb.FieldBinlog, []*datapb.FieldBinlog, []*datapb.FieldBinlog) { clonedSegment := proto.Clone(segment).(*datapb.SegmentInfo) binlogs := clonedSegment.Binlogs deltalogs := clonedSegment.Deltalogs statlogs := clonedSegment.Statslogs + bm25logs := clonedSegment.Bm25Statslogs clonedSegment.Binlogs = nil clonedSegment.Deltalogs = nil clonedSegment.Statslogs = nil - return clonedSegment, binlogs, deltalogs, statlogs + clonedSegment.Bm25Statslogs = nil + return clonedSegment, binlogs, 
deltalogs, statlogs, bm25logs } func marshalSegmentInfo(segment *datapb.SegmentInfo) (string, error) { @@ -298,6 +291,10 @@ func buildFieldStatslogPath(collectionID typeutil.UniqueID, partitionID typeutil return fmt.Sprintf("%s/%d/%d/%d/%d", SegmentStatslogPathPrefix, collectionID, partitionID, segmentID, fieldID) } +func buildFieldBM25StatslogPath(collectionID typeutil.UniqueID, partitionID typeutil.UniqueID, segmentID typeutil.UniqueID, fieldID typeutil.UniqueID) string { + return fmt.Sprintf("%s/%d/%d/%d/%d", SegmentBM25logPathPrefix, collectionID, partitionID, segmentID, fieldID) +} + func buildFieldBinlogPathPrefix(collectionID typeutil.UniqueID, partitionID typeutil.UniqueID, segmentID typeutil.UniqueID) string { return fmt.Sprintf("%s/%d/%d/%d", SegmentBinlogPathPrefix, collectionID, partitionID, segmentID) } diff --git a/internal/proto/data_coord.proto b/internal/proto/data_coord.proto index 1cd55e9c132f2..52b1f3e3a344b 100644 --- a/internal/proto/data_coord.proto +++ b/internal/proto/data_coord.proto @@ -356,6 +356,7 @@ message SegmentInfo { // textStatsLogs is used to record tokenization index for fields. 
map textStatsLogs = 26; + repeated FieldBinlog bm25statslogs = 27; } message SegmentStartPosition { @@ -379,6 +380,7 @@ message SaveBinlogPathsRequest { SegmentLevel seg_level =13; int64 partitionID =14; // report partitionID for create L0 segment int64 storageVersion = 15; + repeated FieldBinlog field2Bm25logPaths = 16; } message CheckPoint { @@ -621,6 +623,7 @@ message CompactionSegment { repeated FieldBinlog deltalogs = 6; string channel = 7; bool is_sorted = 8; + repeated FieldBinlog bm25logs = 9; } message CompactionPlanResult { diff --git a/internal/storage/binlog_writer.go b/internal/storage/binlog_writer.go index 173aae219e7b8..4b3d1d31b9ad0 100644 --- a/internal/storage/binlog_writer.go +++ b/internal/storage/binlog_writer.go @@ -39,6 +39,8 @@ const ( IndexFileBinlog // StatsBinlog BinlogType for stats data StatsBinlog + // BM25Binlog BinlogType for bm25 stats data + BM25Binlog ) const ( diff --git a/internal/storage/stats.go b/internal/storage/stats.go index 75da19ab5ecd6..f71930b9358e7 100644 --- a/internal/storage/stats.go +++ b/internal/storage/stats.go @@ -17,10 +17,14 @@ package storage import ( + "bytes" + "encoding/binary" "encoding/json" "fmt" + "math" "go.uber.org/zap" + "golang.org/x/exp/maps" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus/internal/util/bloomfilter" @@ -28,9 +32,10 @@ import ( "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/paramtable" + "github.com/milvus-io/milvus/pkg/util/typeutil" ) -// PrimaryKeyStats contains statistics data for pk column +// PrimaryKeyStats contains statistics data for the pk column type PrimaryKeyStats struct { FieldID int64 `json:"fieldID"` Max int64 `json:"max"` // useless, will delete @@ -299,6 +304,173 @@ func (sr *StatsReader) GetPrimaryKeyStatsList() ([]*PrimaryKeyStats, error) { return stats, nil } +type BM25Stats struct { + rowsWithToken map[uint32]int32 // mapping token => row num include token
+ numRow int64 // total row num + numToken int64 // total token num +} + +const BM25VERSION int32 = 0 + +func NewBM25Stats() *BM25Stats { + return &BM25Stats{ + rowsWithToken: map[uint32]int32{}, + } +} + +func NewBM25StatsWithBytes(bytes []byte) (*BM25Stats, error) { + stats := NewBM25Stats() + err := stats.Deserialize(bytes) + if err != nil { + return nil, err + } + return stats, nil +} + +func (m *BM25Stats) Append(rows ...map[uint32]float32) { + for _, row := range rows { + for key, value := range row { + m.rowsWithToken[key] += 1 + m.numToken += int64(value) + } + + m.numRow += 1 + } +} + +func (m *BM25Stats) AppendFieldData(datas ...*SparseFloatVectorFieldData) { + for _, data := range datas { + m.AppendBytes(data.GetContents()...) + } +} + +// Update BM25Stats by sparse vector bytes +func (m *BM25Stats) AppendBytes(datas ...[]byte) { + for _, data := range datas { + dim := len(data) / 8 + for i := 0; i < dim; i++ { + index := typeutil.SparseFloatRowIndexAt(data, i) + value := typeutil.SparseFloatRowValueAt(data, i) + m.rowsWithToken[index] += 1 + m.numToken += int64(value) + } + m.numRow += 1 + } +} + +func (m *BM25Stats) NumRow() int64 { + return m.numRow +} + +func (m *BM25Stats) NumToken() int64 { + return m.numToken +} + +func (m *BM25Stats) Merge(meta *BM25Stats) { + for key, value := range meta.rowsWithToken { + m.rowsWithToken[key] += value + } + m.numRow += meta.NumRow() + m.numToken += meta.numToken +} + +func (m *BM25Stats) Minus(meta *BM25Stats) { + for key, value := range meta.rowsWithToken { + m.rowsWithToken[key] -= value + } + m.numRow -= meta.numRow + m.numToken -= meta.numToken +} + +func (m *BM25Stats) Clone() *BM25Stats { + return &BM25Stats{ + rowsWithToken: maps.Clone(m.rowsWithToken), + numRow: m.numRow, + numToken: m.numToken, + } +} + +func (m *BM25Stats) Serialize() ([]byte, error) { + buffer := bytes.NewBuffer(make([]byte, 0, len(m.rowsWithToken)*8+20)) + + if err := binary.Write(buffer, common.Endian, BM25VERSION); err != nil { + 
return nil, err + } + + if err := binary.Write(buffer, common.Endian, m.numRow); err != nil { + return nil, err + } + + if err := binary.Write(buffer, common.Endian, m.numToken); err != nil { + return nil, err + } + + for key, value := range m.rowsWithToken { + if err := binary.Write(buffer, common.Endian, key); err != nil { + return nil, err + } + + if err := binary.Write(buffer, common.Endian, value); err != nil { + return nil, err + } + } + + // TODO ADD Serialize Time Metric + return buffer.Bytes(), nil +} + +func (m *BM25Stats) Deserialize(bs []byte) error { + buffer := bytes.NewBuffer(bs) + dim := (len(bs) - 20) / 8 + var numRow, tokenNum int64 + var version int32 + if err := binary.Read(buffer, common.Endian, &version); err != nil { + return err + } + + if err := binary.Read(buffer, common.Endian, &numRow); err != nil { + return err + } + + if err := binary.Read(buffer, common.Endian, &tokenNum); err != nil { + return err + } + + var keys []uint32 = make([]uint32, dim) + var values []int32 = make([]int32, dim) + for i := 0; i < dim; i++ { + if err := binary.Read(buffer, common.Endian, &keys[i]); err != nil { + return err + } + + if err := binary.Read(buffer, common.Endian, &values[i]); err != nil { + return err + } + } + + m.numRow += numRow + m.numToken += tokenNum + for i := 0; i < dim; i++ { + m.rowsWithToken[keys[i]] += values[i] + } + + log.Info("test-- deserialize", zap.Int64("numrow", m.numRow), zap.Int64("tokenNum", m.numToken)) + return nil +} + +func (m *BM25Stats) BuildIDF(tf map[uint32]float32) map[uint32]float32 { + vector := make(map[uint32]float32) + for key, value := range tf { + nq := m.rowsWithToken[key] + vector[key] = value * float32(math.Log(1+(float64(m.numRow)-float64(nq)+0.5)/(float64(nq)+0.5))) + } + return vector +} + +func (m *BM25Stats) GetAvgdl() float64 { + return float64(m.numToken) / float64(m.numRow) +} + // DeserializeStats deserialize @blobs as []*PrimaryKeyStats func DeserializeStats(blobs []*Blob) ([]*PrimaryKeyStats, 
error) { results := make([]*PrimaryKeyStats, 0, len(blobs)) diff --git a/internal/storage/utils.go b/internal/storage/utils.go index f6566402af00d..50210431827f8 100644 --- a/internal/storage/utils.go +++ b/internal/storage/utils.go @@ -371,6 +371,10 @@ func RowBasedInsertMsgToInsertData(msg *msgstream.InsertMsg, collSchema *schemap } for _, field := range collSchema.Fields { + if field.GetIsFunctionOutput() { + continue + } + switch field.DataType { case schemapb.DataType_FloatVector: dim, err := GetDimFromParams(field.TypeParams) @@ -482,7 +486,7 @@ func RowBasedInsertMsgToInsertData(msg *msgstream.InsertMsg, collSchema *schemap // ColumnBasedInsertMsgToInsertData converts an InsertMsg msg into InsertData based // on provided CollectionSchema collSchema. // -// This function checks whether all fields are provided in the collSchema.Fields. +// This function checks whether all fields are provided in the collSchema.Fields and not function output. // If any field is missing in the msg, an error will be returned. // // This funcion also checks the length of each column. All columns shall have the same length. @@ -499,6 +503,10 @@ func ColumnBasedInsertMsgToInsertData(msg *msgstream.InsertMsg, collSchema *sche } length := 0 for _, field := range collSchema.Fields { + if field.GetIsFunctionOutput() { + continue + } + srcField, ok := srcFields[field.GetFieldID()] if !ok && field.GetFieldID() >= common.StartOfUserFieldID { return nil, merr.WrapErrFieldNotFound(field.GetFieldID(), fmt.Sprintf("field %s not found when converting insert msg to insert data", field.GetName())) diff --git a/internal/util/function/bm25_function.go b/internal/util/function/bm25_function.go new file mode 100644 index 0000000000000..275be8e412f29 --- /dev/null +++ b/internal/util/function/bm25_function.go @@ -0,0 +1,159 @@ +/* + * # Licensed to the LF AI & Data foundation under one + * # or more contributor license agreements. 
See the NOTICE file + * # distributed with this work for additional information + * # regarding copyright ownership. The ASF licenses this file + * # to you under the Apache License, Version 2.0 (the + * # "License"); you may not use this file except in compliance + * # with the License. You may obtain a copy of the License at + * # + * # http://www.apache.org/licenses/LICENSE-2.0 + * # + * # Unless required by applicable law or agreed to in writing, software + * # distributed under the License is distributed on an "AS IS" BASIS, + * # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * # See the License for the specific language governing permissions and + * # limitations under the License. + */ + +package function + +import ( + "fmt" + "sync" + + "github.com/samber/lo" + + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/internal/util/ctokenizer" + "github.com/milvus-io/milvus/internal/util/tokenizerapi" + "github.com/milvus-io/milvus/pkg/util/typeutil" +) + +// BM25 Runner +// Input: string +// Output: map[uint32]float32 +type BM25FunctionRunner struct { + tokenizer tokenizerapi.Tokenizer + schema *schemapb.FunctionSchema + outputField *schemapb.FieldSchema + concurrency int +} + +func NewBM25FunctionRunner(coll *schemapb.CollectionSchema, schema *schemapb.FunctionSchema) (*BM25FunctionRunner, error) { + if len(schema.GetOutputFieldIds()) != 1 { + return nil, fmt.Errorf("bm25 function should only have one output field, but now %d", len(schema.GetOutputFieldIds())) + } + + runner := &BM25FunctionRunner{ + schema: schema, + concurrency: 8, + } + for _, field := range coll.GetFields() { + if field.GetFieldID() == schema.GetOutputFieldIds()[0] { + runner.outputField = field + break + } + } + + if runner.outputField == nil { + return nil, fmt.Errorf("no output field") + } + tokenizer, err := ctokenizer.NewTokenizer(map[string]string{}) + if err != nil { + return nil, err + } + + runner.tokenizer = 
tokenizer + return runner, nil +} + +func (v *BM25FunctionRunner) run(data []string, dst []map[uint32]float32) error { + // TODO AOIASD Support single Tokenizer concurrency + tokenizer, err := ctokenizer.NewTokenizer(map[string]string{}) + if err != nil { + return err + } + defer tokenizer.Destroy() + + for i := 0; i < len(data); i++ { + embeddingMap := map[uint32]float32{} + tokenStream := tokenizer.NewTokenStream(data[i]) + defer tokenStream.Destroy() + for tokenStream.Advance() { + token := tokenStream.Token() + // TODO More Hash Option + hash := typeutil.HashString2Uint32(token) + embeddingMap[hash] += 1 + } + dst[i] = embeddingMap + } + return nil +} + +func (v *BM25FunctionRunner) BatchRun(inputs ...any) ([]any, error) { + if len(inputs) > 1 { + return nil, fmt.Errorf("BM25 function receieve more than one input") + } + + text, ok := inputs[0].([]string) + if !ok { + return nil, fmt.Errorf("BM25 function batch input not string list") + } + + rowNum := len(text) + embedData := make([]map[uint32]float32, rowNum) + wg := sync.WaitGroup{} + + errCh := make(chan error, v.concurrency) + for i, j := 0, 0; i < v.concurrency && j < rowNum; i++ { + start := j + end := start + rowNum/v.concurrency + if i < rowNum%v.concurrency { + end += 1 + } + wg.Add(1) + go func() { + defer wg.Done() + err := v.run(text[start:end], embedData[start:end]) + if err != nil { + errCh <- err + return + } + }() + j = end + } + + wg.Wait() + close(errCh) + for err := range errCh { + if err != nil { + return nil, err + } + } + + return []any{buildSparseFloatArray(embedData)}, nil +} + +func (v *BM25FunctionRunner) GetSchema() *schemapb.FunctionSchema { + return v.schema +} + +func (v *BM25FunctionRunner) GetOutputFields() []*schemapb.FieldSchema { + return []*schemapb.FieldSchema{v.outputField} +} + +func buildSparseFloatArray(mapdata []map[uint32]float32) *schemapb.SparseFloatArray { + dim := 0 + bytes := lo.Map(mapdata, func(sparseMap map[uint32]float32, _ int) []byte { + if len(sparseMap) > 
dim { + dim = len(sparseMap) + } + return typeutil.CreateAndSortSparseFloatRow(sparseMap) + }) + + return &schemapb.SparseFloatArray{ + Contents: bytes, + Dim: int64(dim), + } +} diff --git a/internal/util/function/function.go b/internal/util/function/function.go new file mode 100644 index 0000000000000..a9056af41298d --- /dev/null +++ b/internal/util/function/function.go @@ -0,0 +1,41 @@ +/* + * # Licensed to the LF AI & Data foundation under one + * # or more contributor license agreements. See the NOTICE file + * # distributed with this work for additional information + * # regarding copyright ownership. The ASF licenses this file + * # to you under the Apache License, Version 2.0 (the + * # "License"); you may not use this file except in compliance + * # with the License. You may obtain a copy of the License at + * # + * # http://www.apache.org/licenses/LICENSE-2.0 + * # + * # Unless required by applicable law or agreed to in writing, software + * # distributed under the License is distributed on an "AS IS" BASIS, + * # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * # See the License for the specific language governing permissions and + * # limitations under the License. 
+ */ + +package function + +import ( + "fmt" + + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" +) + +type FunctionRunner interface { + BatchRun(inputs ...any) ([]any, error) + + GetSchema() *schemapb.FunctionSchema + GetOutputFields() []*schemapb.FieldSchema +} + +func NewFunctionRunner(coll *schemapb.CollectionSchema, schema *schemapb.FunctionSchema) (FunctionRunner, error) { + switch schema.GetType() { + case schemapb.FunctionType_BM25: + return NewBM25FunctionRunner(coll, schema) + default: + return nil, fmt.Errorf("unknown functionRunner type %s", schema.GetType().String()) + } +} diff --git a/internal/util/function/function_test.go b/internal/util/function/function_test.go new file mode 100644 index 0000000000000..964ec3fc8edab --- /dev/null +++ b/internal/util/function/function_test.go @@ -0,0 +1,82 @@ +/* + * # Licensed to the LF AI & Data foundation under one + * # or more contributor license agreements. See the NOTICE file + * # distributed with this work for additional information + * # regarding copyright ownership. The ASF licenses this file + * # to you under the Apache License, Version 2.0 (the + * # "License"); you may not use this file except in compliance + * # with the License. You may obtain a copy of the License at + * # + * # http://www.apache.org/licenses/LICENSE-2.0 + * # + * # Unless required by applicable law or agreed to in writing, software + * # distributed under the License is distributed on an "AS IS" BASIS, + * # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * # See the License for the specific language governing permissions and + * # limitations under the License. 
+ */ + +package function + +import ( + "testing" + + "github.com/stretchr/testify/suite" + + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" +) + +func TestFunctionRunnerSuite(t *testing.T) { + suite.Run(t, new(FunctionRunnerSuite)) +} + +type FunctionRunnerSuite struct { + suite.Suite + schema *schemapb.CollectionSchema +} + +func (s *FunctionRunnerSuite) SetupTest() { + s.schema = &schemapb.CollectionSchema{ + Name: "test", + Fields: []*schemapb.FieldSchema{ + {FieldID: 100, Name: "int64", DataType: schemapb.DataType_Int64}, + {FieldID: 101, Name: "text", DataType: schemapb.DataType_VarChar}, + {FieldID: 102, Name: "sparse", DataType: schemapb.DataType_SparseFloatVector}, + }, + } +} + +func (s *FunctionRunnerSuite) TestBM25() { + _, err := NewFunctionRunner(s.schema, &schemapb.FunctionSchema{ + Name: "test", + Type: schemapb.FunctionType_BM25, + InputFieldIds: []int64{101}, + }) + s.Error(err) + + runner, err := NewFunctionRunner(s.schema, &schemapb.FunctionSchema{ + Name: "test", + Type: schemapb.FunctionType_BM25, + InputFieldIds: []int64{101}, + OutputFieldIds: []int64{102}, + }) + + s.NoError(err) + + // test batch function run + output, err := runner.BatchRun([]string{"test string", "test string 2"}) + s.NoError(err) + + s.Equal(1, len(output)) + result, ok := output[0].(*schemapb.SparseFloatArray) + s.True(ok) + s.Equal(2, len(result.GetContents())) + + // return error because receive more than one field input + _, err = runner.BatchRun([]string{}, []string{}) + s.Error(err) + + // return error because field not string + _, err = runner.BatchRun([]int64{}) + s.Error(err) +} diff --git a/pkg/common/common.go b/pkg/common/common.go index 94f361da4a316..22b7b873b6523 100644 --- a/pkg/common/common.go +++ b/pkg/common/common.go @@ -98,6 +98,9 @@ const ( // SegmentIndexPath storage path const for segment index files. 
SegmentIndexPath = `index_files` + // SegmentBm25LogPath storage path const for bm25 statistic + SegmentBm25LogPath = `bm25_stats` + // PartitionStatsPath storage path const for partition stats files PartitionStatsPath = `part_stats` diff --git a/pkg/util/metautil/binlog.go b/pkg/util/metautil/binlog.go index 0394c1fcc8639..28887829480e2 100644 --- a/pkg/util/metautil/binlog.go +++ b/pkg/util/metautil/binlog.go @@ -52,6 +52,11 @@ func BuildStatsLogPath(rootPath string, collectionID, partitionID, segmentID, fi return path.Join(rootPath, common.SegmentStatslogPath, k) } +func BuildBm25LogPath(rootPath string, collectionID, partitionID, segmentID, fieldID, logID typeutil.UniqueID) string { + k := JoinIDPath(collectionID, partitionID, segmentID, fieldID, logID) + return path.Join(rootPath, common.SegmentBm25LogPath, k) +} + func GetSegmentIDFromStatsLogPath(logPath string) typeutil.UniqueID { return getSegmentIDFromPath(logPath, 3) } diff --git a/pkg/util/typeutil/schema.go b/pkg/util/typeutil/schema.go index e530100b38064..cdeef109da640 100644 --- a/pkg/util/typeutil/schema.go +++ b/pkg/util/typeutil/schema.go @@ -1694,6 +1694,18 @@ func SortSparseFloatRow(indices []uint32, values []float32) ([]uint32, []float32 return sortedIndices, sortedValues } +func CreateAndSortSparseFloatRow(sparse map[uint32]float32) []byte { + row := make([]byte, len(sparse)*8) + data := lo.MapToSlice(sparse, func(indices uint32, value float32) Pair[uint32, float32] { + return Pair[uint32, float32]{indices, value} + }) + sort.Slice(data, func(i, j int) bool { return data[i].A < data[j].A }) + for i := 0; i < len(data); i++ { + SparseFloatRowSetAt(row, i, data[i].A, data[i].B) + } + return row +} + func CreateSparseFloatRow(indices []uint32, values []float32) []byte { row := make([]byte, len(indices)*8) for i := 0; i < len(indices); i++ {