diff --git a/internal/distributed/streaming/streaming.go b/internal/distributed/streaming/streaming.go index e08ee103a1176..39d4cb31980cf 100644 --- a/internal/distributed/streaming/streaming.go +++ b/internal/distributed/streaming/streaming.go @@ -10,10 +10,6 @@ import ( var singleton WALAccesser = nil -func SetWAL(w WALAccesser) { - singleton = w -} - // Init initializes the wal accesser with the given etcd client. // should be called before any other operations. func Init() { @@ -23,9 +19,7 @@ func Init() { // Release releases the resources of the wal accesser. func Release() { - if w, ok := singleton.(*walAccesserImpl); ok && w != nil { - w.Close() - } + singleton.Close() } // WAL is the entrance to interact with the milvus write ahead log. @@ -67,4 +61,7 @@ type WALAccesser interface { // Read returns a scanner for reading records from the wal. Read(ctx context.Context, opts ReadOption) Scanner + + // Close closes the wal accesser + Close() } diff --git a/internal/distributed/streaming/test_streaming.go b/internal/distributed/streaming/test_streaming.go new file mode 100644 index 0000000000000..bf878d86fa5f3 --- /dev/null +++ b/internal/distributed/streaming/test_streaming.go @@ -0,0 +1,24 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//go:build test + +package streaming + +// SetWALForTest initializes the singleton of wal for test. +func SetWALForTest(w WALAccesser) { + singleton = w +} diff --git a/internal/mocks/distributed/mock_streaming/mock_WALAccesser.go b/internal/mocks/distributed/mock_streaming/mock_WALAccesser.go index 72234f9a33846..733381f1bd8ec 100644 --- a/internal/mocks/distributed/mock_streaming/mock_WALAccesser.go +++ b/internal/mocks/distributed/mock_streaming/mock_WALAccesser.go @@ -81,6 +81,38 @@ func (_c *MockWALAccesser_Append_Call) RunAndReturn(run func(context.Context, .. return _c } +// Close provides a mock function with given fields: +func (_m *MockWALAccesser) Close() { + _m.Called() +} + +// MockWALAccesser_Close_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Close' +type MockWALAccesser_Close_Call struct { + *mock.Call +} + +// Close is a helper method to define mock.On call +func (_e *MockWALAccesser_Expecter) Close() *MockWALAccesser_Close_Call { + return &MockWALAccesser_Close_Call{Call: _e.mock.On("Close")} +} + +func (_c *MockWALAccesser_Close_Call) Run(run func()) *MockWALAccesser_Close_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockWALAccesser_Close_Call) Return() *MockWALAccesser_Close_Call { + _c.Call.Return() + return _c +} + +func (_c *MockWALAccesser_Close_Call) RunAndReturn(run func()) *MockWALAccesser_Close_Call { + _c.Call.Return(run) + return _c +} + // Read provides a mock function with given fields: ctx, opts func (_m *MockWALAccesser) Read(ctx context.Context, opts streaming.ReadOption) streaming.Scanner { ret := _m.Called(ctx, opts) diff --git a/internal/rootcoord/create_partition_task.go b/internal/rootcoord/create_partition_task.go index 4f108beaa8d89..6d7ceaa73df3d 100644 --- a/internal/rootcoord/create_partition_task.go +++ b/internal/rootcoord/create_partition_task.go @@ -26,6 +26,7 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus/internal/metastore/model" pb "github.com/milvus-io/milvus/internal/proto/etcdpb" + "github.com/milvus-io/milvus/internal/util/streamingutil" "github.com/milvus-io/milvus/pkg/log" ) @@ -96,6 +97,14 @@ func (t *createPartitionTask) Execute(ctx context.Context) error { ts: t.GetTs(), }) + if streamingutil.IsStreamingServiceEnabled() { + undoTask.AddStep(&broadcastCreatePartitionMsgStep{ + baseStep: baseStep{core: t.core}, + vchannels: t.collMeta.VirtualChannelNames, + partition: partition, + }, &nullStep{}) + } + undoTask.AddStep(&nullStep{}, &releasePartitionsStep{ baseStep: baseStep{core: t.core}, collectionID: t.collMeta.CollectionID, diff --git a/internal/rootcoord/garbage_collector_test.go b/internal/rootcoord/garbage_collector_test.go index a42667e06cf95..d63a066fb7392 100644 --- a/internal/rootcoord/garbage_collector_test.go +++ b/internal/rootcoord/garbage_collector_test.go @@ -548,7 +548,7 @@ func TestGcPartitionData(t *testing.T) { wal := mock_streaming.NewMockWALAccesser(t) wal.EXPECT().Append(mock.Anything, mock.Anything, mock.Anything).Return(streaming.AppendResponses{}) - streaming.SetWAL(wal) + streaming.SetWALForTest(wal) tsoAllocator := mocktso.NewAllocator(t) tsoAllocator.EXPECT().GenerateTSO(mock.Anything).Return(1000, nil) diff --git a/internal/rootcoord/step.go b/internal/rootcoord/step.go index feba1b3db42e5..742fb4bdc25c1 100644 --- a/internal/rootcoord/step.go +++ b/internal/rootcoord/step.go @@ -21,10 +21,15 @@ import ( "fmt" "time" + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" + "github.com/milvus-io/milvus-proto/go-api/v2/msgpb" + "github.com/milvus-io/milvus/internal/distributed/streaming" "github.com/milvus-io/milvus/internal/metastore/model" pb "github.com/milvus-io/milvus/internal/proto/etcdpb" "github.com/milvus-io/milvus/internal/util/proxyutil" + "github.com/milvus-io/milvus/pkg/streaming/util/message" + "github.com/milvus-io/milvus/pkg/util/commonpbutil" ) type stepPriority int @@ -374,6 +379,49 @@ func (s *addPartitionMetaStep) Desc() string { return fmt.Sprintf("add partition to meta table, collection: %d, partition: %d", s.partition.CollectionID, s.partition.PartitionID) } +type broadcastCreatePartitionMsgStep struct { + baseStep + vchannels []string + partition *model.Partition +} + +func (s *broadcastCreatePartitionMsgStep) Execute(ctx context.Context) ([]nestedStep, error) { + req := &msgpb.CreatePartitionRequest{ + Base: commonpbutil.NewMsgBase( + commonpbutil.WithMsgType(commonpb.MsgType_CreatePartition), + commonpbutil.WithTimeStamp(0), // ts is given by streamingnode. + ), + PartitionName: s.partition.PartitionName, + CollectionID: s.partition.CollectionID, + PartitionID: s.partition.PartitionID, + } + + msgs := make([]message.MutableMessage, 0, len(s.vchannels)) + for _, vchannel := range s.vchannels { + msg, err := message.NewCreatePartitionMessageBuilderV1(). + WithVChannel(vchannel). + WithHeader(&message.CreatePartitionMessageHeader{ + CollectionId: s.partition.CollectionID, + PartitionId: s.partition.PartitionID, + }). + WithBody(req). + BuildMutable() + if err != nil { + return nil, err + } + msgs = append(msgs, msg) + } + resp := streaming.WAL().Append(ctx, msgs...) + if err := resp.IsAnyError(); err != nil { + return nil, err + } + return nil, nil +} + +func (s *broadcastCreatePartitionMsgStep) Desc() string { + return fmt.Sprintf("broadcast create partition message to mq, collection: %d, partition: %d", s.partition.CollectionID, s.partition.PartitionID) +} + type changePartitionStateStep struct { baseStep collectionID UniqueID diff --git a/internal/rootcoord/step_test.go b/internal/rootcoord/step_test.go index ef1315e54b0df..c59c50b1d3495 100644 --- a/internal/rootcoord/step_test.go +++ b/internal/rootcoord/step_test.go @@ -22,6 +22,11 @@ import ( "time" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + + "github.com/milvus-io/milvus/internal/distributed/streaming" + "github.com/milvus-io/milvus/internal/metastore/model" + "github.com/milvus-io/milvus/internal/mocks/distributed/mock_streaming" ) func Test_waitForTsSyncedStep_Execute(t *testing.T) { @@ -115,3 +120,21 @@ func TestSkip(t *testing.T) { assert.NoError(t, err) } } + +func TestBroadcastCreatePartitionMsgStep(t *testing.T) { + wal := mock_streaming.NewMockWALAccesser(t) + wal.EXPECT().Append(mock.Anything, mock.Anything, mock.Anything).Return(streaming.AppendResponses{}) + streaming.SetWALForTest(wal) + + step := &broadcastCreatePartitionMsgStep{ + baseStep: baseStep{core: nil}, + vchannels: []string{"ch-0", "ch-1"}, + partition: &model.Partition{ + CollectionID: 1, + PartitionID: 2, + }, + } + t.Logf("%v\n", step.Desc()) + _, err := step.Execute(context.Background()) + assert.NoError(t, err) +}