Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support the replicate message api #622

Merged
merged 1 commit into from
Nov 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 18 additions & 10 deletions client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ import (

"google.golang.org/grpc"

"github.com/milvus-io/milvus-proto/go-api/v2/msgpb"

"github.com/milvus-io/milvus-sdk-go/v2/entity"
)

Expand All @@ -38,9 +40,9 @@ type Client interface {
// ListDatabases list all database in milvus cluster.
ListDatabases(ctx context.Context) ([]entity.Database, error)
// CreateDatabase create database with the given name.
CreateDatabase(ctx context.Context, dbName string) error
CreateDatabase(ctx context.Context, dbName string, opts ...CreateDatabaseOption) error
// DropDatabase drop database with the given db name.
DropDatabase(ctx context.Context, dbName string) error
DropDatabase(ctx context.Context, dbName string, opts ...DropDatabaseOption) error

// -- collection --

Expand All @@ -53,13 +55,13 @@ type Client interface {
// DescribeCollection describe collection meta
DescribeCollection(ctx context.Context, collName string) (*entity.Collection, error)
// DropCollection drop the specified collection
DropCollection(ctx context.Context, collName string) error
DropCollection(ctx context.Context, collName string, opts ...DropCollectionOption) error
// GetCollectionStatistics get collection statistics
GetCollectionStatistics(ctx context.Context, collName string) (map[string]string, error)
// LoadCollection load collection into memory
LoadCollection(ctx context.Context, collName string, async bool, opts ...LoadCollectionOption) error
// ReleaseCollection release loaded collection
ReleaseCollection(ctx context.Context, collName string) error
ReleaseCollection(ctx context.Context, collName string, opts ...ReleaseCollectionOption) error
// HasCollection check whether collection exists
HasCollection(ctx context.Context, collName string) (bool, error)
// RenameCollection performs renaming for provided collection.
Expand Down Expand Up @@ -91,17 +93,17 @@ type Client interface {
// -- partition --

// CreatePartition create partition for collection
CreatePartition(ctx context.Context, collName string, partitionName string) error
CreatePartition(ctx context.Context, collName string, partitionName string, opts ...CreatePartitionOption) error
// DropPartition drop partition from collection
DropPartition(ctx context.Context, collName string, partitionName string) error
DropPartition(ctx context.Context, collName string, partitionName string, opts ...DropPartitionOption) error
// ShowPartitions list all partitions from collection
ShowPartitions(ctx context.Context, collName string) ([]*entity.Partition, error)
// HasPartition check whether partition exists in collection
HasPartition(ctx context.Context, collName string, partitionName string) (bool, error)
// LoadPartitions load partitions into memory
LoadPartitions(ctx context.Context, collName string, partitionNames []string, async bool) error
LoadPartitions(ctx context.Context, collName string, partitionNames []string, async bool, opts ...LoadPartitionsOption) error
// ReleasePartitions release partitions
ReleasePartitions(ctx context.Context, collName string, partitionNames []string) error
ReleasePartitions(ctx context.Context, collName string, partitionNames []string, opts ...ReleasePartitionsOption) error

// -- segment --
GetPersistentSegmentInfo(ctx context.Context, collName string) ([]*entity.Segment, error)
Expand All @@ -124,10 +126,10 @@ type Client interface {
// Insert column-based data into collection, returns id column values
Insert(ctx context.Context, collName string, partitionName string, columns ...entity.Column) (entity.Column, error)
// Flush collection, specified
Flush(ctx context.Context, collName string, async bool) error
Flush(ctx context.Context, collName string, async bool, opts ...FlushOption) error
// FlushV2 flush collection, specified, return newly sealed segmentIds, all flushed segmentIds of the collection, seal time and error
// currently it is only used in milvus-backup(https://github.com/zilliztech/milvus-backup)
FlushV2(ctx context.Context, collName string, async bool) ([]int64, []int64, int64, error)
FlushV2(ctx context.Context, collName string, async bool, opts ...FlushOption) ([]int64, []int64, int64, error)
// DeleteByPks deletes entries related to provided primary keys
DeleteByPks(ctx context.Context, collName string, partitionName string, ids entity.Column) error
// Delete deletes entries match expression
Expand Down Expand Up @@ -211,6 +213,12 @@ type Client interface {
GetVersion(ctx context.Context) (string, error)
// CheckHealth returns milvus state
CheckHealth(ctx context.Context) (*entity.MilvusState, error)

ReplicateMessage(ctx context.Context,
channelName string, beginTs, endTs uint64,
msgsBytes [][]byte, startPositions, endPositions []*msgpb.MsgPosition,
opts ...ReplicateMessageOption,
) (*entity.MessageInfo, error)
}

// NewClient create a client connected to remote milvus cluster.
Expand Down
12 changes: 10 additions & 2 deletions client/client_mock_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,8 @@ const (
MListDatabase ServiceMethod = 1000
MCreateDatabase ServiceMethod = 1001
MDropDatabase ServiceMethod = 1002

MReplicateMessage ServiceMethod = 1100
)

// injection function definition
Expand Down Expand Up @@ -924,8 +926,14 @@ func (m *MockServer) AllocTimestamp(_ context.Context, _ *milvuspb.AllocTimestam
panic("not implemented")
}

func (m *MockServer) ReplicateMessage(_ context.Context, _ *milvuspb.ReplicateMessageRequest) (*milvuspb.ReplicateMessageResponse, error) {
panic("not implemented")
func (m *MockServer) ReplicateMessage(ctx context.Context, req *milvuspb.ReplicateMessageRequest) (*milvuspb.ReplicateMessageResponse, error) {
f := m.GetInjection(MReplicateMessage)
if f != nil {
r, err := f(ctx, req)
return r.(*milvuspb.ReplicateMessageResponse), err
}
s, err := SuccessStatus()
return &milvuspb.ReplicateMessageResponse{Status: s}, err
}

func (m *MockServer) Connect(_ context.Context, _ *milvuspb.ConnectRequest) (*milvuspb.ConnectResponse, error) {
Expand Down
12 changes: 10 additions & 2 deletions client/collection.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"github.com/cockroachdb/errors"

"github.com/golang/protobuf/proto"

"github.com/milvus-io/milvus-sdk-go/v2/entity"

"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
Expand Down Expand Up @@ -151,6 +152,7 @@ func (c *GrpcClient) requestCreateCollection(ctx context.Context, sch *entity.Sc
}

req := &milvuspb.CreateCollectionRequest{
Base: opt.MsgBase,
DbName: "", // reserved fields, not used for now
CollectionName: sch.CollectionName,
Schema: bs,
Expand Down Expand Up @@ -279,7 +281,7 @@ func (c *GrpcClient) DescribeCollection(ctx context.Context, collName string) (*
}

// DropCollection drop collection by name
func (c *GrpcClient) DropCollection(ctx context.Context, collName string) error {
func (c *GrpcClient) DropCollection(ctx context.Context, collName string, opts ...DropCollectionOption) error {
if c.Service == nil {
return ErrClientNotReady
}
Expand All @@ -290,6 +292,9 @@ func (c *GrpcClient) DropCollection(ctx context.Context, collName string) error
req := &milvuspb.DropCollectionRequest{
CollectionName: collName,
}
for _, opt := range opts {
opt(req)
}
resp, err := c.Service.DropCollection(ctx, req)
if err != nil {
return err
Expand Down Expand Up @@ -447,7 +452,7 @@ func (c *GrpcClient) LoadCollection(ctx context.Context, collName string, async
}

// ReleaseCollection release loaded collection
func (c *GrpcClient) ReleaseCollection(ctx context.Context, collName string) error {
func (c *GrpcClient) ReleaseCollection(ctx context.Context, collName string, opts ...ReleaseCollectionOption) error {
if c.Service == nil {
return ErrClientNotReady
}
Expand All @@ -459,6 +464,9 @@ func (c *GrpcClient) ReleaseCollection(ctx context.Context, collName string) err
DbName: "", // reserved
CollectionName: collName,
}
for _, opt := range opts {
opt(req)
}
resp, err := c.Service.ReleaseCollection(ctx, req)
if err != nil {
return err
Expand Down
8 changes: 4 additions & 4 deletions client/collection_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ func (s *CollectionSuite) TestCreateCollection() {
Return(&commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, nil)
s.mock.EXPECT().HasCollection(mock.Anything, &milvuspb.HasCollectionRequest{CollectionName: testCollectionName}).Return(&milvuspb.BoolResponse{Status: &commonpb.Status{}, Value: false}, nil)

err := c.CreateCollection(ctx, ds, shardsNum)
err := c.CreateCollection(ctx, ds, shardsNum, WithCreateCollectionMsgBase(&commonpb.MsgBase{}))
s.NoError(err)
})

Expand Down Expand Up @@ -514,7 +514,7 @@ func (s *CollectionSuite) TestLoadCollection() {

s.mock.EXPECT().LoadCollection(mock.Anything, mock.Anything).Return(&commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, nil)

err := c.LoadCollection(ctx, testCollectionName, true)
err := c.LoadCollection(ctx, testCollectionName, true, WithLoadCollectionMsgBase(&commonpb.MsgBase{}))
s.NoError(err)
})

Expand Down Expand Up @@ -663,7 +663,7 @@ func TestGrpcClientDropCollection(t *testing.T) {
})

t.Run("Test Normal drop", func(t *testing.T) {
assert.Nil(t, c.DropCollection(ctx, testCollectionName))
assert.Nil(t, c.DropCollection(ctx, testCollectionName, WithDropCollectionMsgBase(&commonpb.MsgBase{})))
})

t.Run("Test drop non-existing collection", func(t *testing.T) {
Expand All @@ -685,7 +685,7 @@ func TestReleaseCollection(t *testing.T) {
return SuccessStatus()
})

c.ReleaseCollection(ctx, testCollectionName)
c.ReleaseCollection(ctx, testCollectionName, WithReleaseCollectionMsgBase(&commonpb.MsgBase{}))
}

func TestGrpcClientHasCollection(t *testing.T) {
Expand Down
2 changes: 1 addition & 1 deletion client/data_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ func TestGrpcClientFlush(t *testing.T) {
c := testClient(ctx, t)

t.Run("test async flush", func(t *testing.T) {
assert.Nil(t, c.Flush(ctx, testCollectionName, true))
assert.Nil(t, c.Flush(ctx, testCollectionName, true, WithFlushMsgBase(&commonpb.MsgBase{})))
})

t.Run("test sync flush", func(t *testing.T) {
Expand Down
10 changes: 8 additions & 2 deletions client/database.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ func (c *GrpcClient) UsingDatabase(ctx context.Context, dbName string) error {

// CreateDatabase creates a new database for remote Milvus cluster.
// TODO:New options can be added as expanding parameters.
func (c *GrpcClient) CreateDatabase(ctx context.Context, dbName string) error {
func (c *GrpcClient) CreateDatabase(ctx context.Context, dbName string, opts ...CreateDatabaseOption) error {
if c.Service == nil {
return ErrClientNotReady
}
Expand All @@ -50,6 +50,9 @@ func (c *GrpcClient) CreateDatabase(ctx context.Context, dbName string) error {
req := &milvuspb.CreateDatabaseRequest{
DbName: dbName,
}
for _, opt := range opts {
opt(req)
}
resp, err := c.Service.CreateDatabase(ctx, req)
if err != nil {
return err
Expand Down Expand Up @@ -84,7 +87,7 @@ func (c *GrpcClient) ListDatabases(ctx context.Context) ([]entity.Database, erro
}

// DropDatabase drop all database in milvus cluster.
func (c *GrpcClient) DropDatabase(ctx context.Context, dbName string) error {
func (c *GrpcClient) DropDatabase(ctx context.Context, dbName string, opts ...DropDatabaseOption) error {
if c.Service == nil {
return ErrClientNotReady
}
Expand All @@ -95,6 +98,9 @@ func (c *GrpcClient) DropDatabase(ctx context.Context, dbName string) error {
req := &milvuspb.DropDatabaseRequest{
DbName: dbName,
}
for _, opt := range opts {
opt(req)
}
resp, err := c.Service.DropDatabase(ctx, req)
if err != nil {
return err
Expand Down
5 changes: 3 additions & 2 deletions client/database_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"github.com/go-faker/faker/v4"
"github.com/go-faker/faker/v4/pkg/options"
"github.com/golang/protobuf/proto"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/stretchr/testify/assert"
)
Expand Down Expand Up @@ -45,7 +46,7 @@ func TestGrpcClientCreateDatabase(t *testing.T) {
mockServer.SetInjection(MCreateDatabase, func(ctx context.Context, m proto.Message) (proto.Message, error) {
return SuccessStatus()
})
err := c.CreateDatabase(ctx, "a")
err := c.CreateDatabase(ctx, "a", WithCreateDatabaseMsgBase(&commonpb.MsgBase{}))
assert.Nil(t, err)
}

Expand All @@ -55,6 +56,6 @@ func TestGrpcClientDropDatabase(t *testing.T) {
mockServer.SetInjection(MDropDatabase, func(ctx context.Context, m proto.Message) (proto.Message, error) {
return SuccessStatus()
})
err := c.DropDatabase(ctx, "a")
err := c.DropDatabase(ctx, "a", WithDropDatabaseMsgBase(&commonpb.MsgBase{}))
assert.Nil(t, err)
}
9 changes: 9 additions & 0 deletions client/index.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ type indexDef struct {
name string
fieldName string
collectionName string
MsgBase *commonpb.MsgBase
}

// IndexOption is the predefined function to alter index def.
Expand All @@ -71,6 +72,12 @@ func WithIndexName(name string) IndexOption {
}
}

func WithIndexMsgBase(msgBase *commonpb.MsgBase) IndexOption {
return func(def *indexDef) {
def.MsgBase = msgBase
}
}

func getIndexDef(opts ...IndexOption) indexDef {
idxDef := indexDef{}
for _, opt := range opts {
Expand All @@ -93,6 +100,7 @@ func (c *GrpcClient) CreateIndex(ctx context.Context, collName string, fieldName
idxDef := getIndexDef(opts...)

req := &milvuspb.CreateIndexRequest{
Base: idxDef.MsgBase,
DbName: "", // reserved
CollectionName: collName,
FieldName: fieldName,
Expand Down Expand Up @@ -167,6 +175,7 @@ func (c *GrpcClient) DropIndex(ctx context.Context, collName string, fieldName s

idxDef := getIndexDef(opts...)
req := &milvuspb.DropIndexRequest{
Base: idxDef.MsgBase,
DbName: "", //reserved,
CollectionName: collName,
FieldName: fieldName,
Expand Down
4 changes: 2 additions & 2 deletions client/index_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ func TestGrpcClientCreateIndex(t *testing.T) {
})

t.Run("test async create index", func(t *testing.T) {
assert.Nil(t, c.CreateIndex(ctx, testCollectionName, fieldName, idx, true))
assert.Nil(t, c.CreateIndex(ctx, testCollectionName, fieldName, idx, true, WithIndexMsgBase(&commonpb.MsgBase{})))
})

t.Run("test sync create index", func(t *testing.T) {
Expand Down Expand Up @@ -85,7 +85,7 @@ func TestGrpcClientDropIndex(t *testing.T) {
c := testClient(ctx, t)
mockServer.SetInjection(MHasCollection, hasCollectionDefault)
mockServer.SetInjection(MDescribeCollection, describeCollectionInjection(t, 0, testCollectionName, defaultSchema()))
assert.Nil(t, c.DropIndex(ctx, testCollectionName, "vector"))
assert.Nil(t, c.DropIndex(ctx, testCollectionName, "vector", WithIndexMsgBase(&commonpb.MsgBase{})))
}

func TestGrpcClientDescribeIndex(t *testing.T) {
Expand Down
9 changes: 6 additions & 3 deletions client/insert.go
Original file line number Diff line number Diff line change
Expand Up @@ -190,14 +190,14 @@ func (c *GrpcClient) mergeDynamicColumns(dynamicName string, rowSize int, column

// Flush force collection to flush memory records into storage
// in sync mode, flush will wait all segments to be flushed
func (c *GrpcClient) Flush(ctx context.Context, collName string, async bool) error {
_, _, _, err := c.FlushV2(ctx, collName, async)
func (c *GrpcClient) Flush(ctx context.Context, collName string, async bool, opts ...FlushOption) error {
_, _, _, err := c.FlushV2(ctx, collName, async, opts...)
return err
}

// Flush force collection to flush memory records into storage
// in sync mode, flush will wait all segments to be flushed
func (c *GrpcClient) FlushV2(ctx context.Context, collName string, async bool) ([]int64, []int64, int64, error) {
func (c *GrpcClient) FlushV2(ctx context.Context, collName string, async bool, opts ...FlushOption) ([]int64, []int64, int64, error) {
if c.Service == nil {
return nil, nil, 0, ErrClientNotReady
}
Expand All @@ -208,6 +208,9 @@ func (c *GrpcClient) FlushV2(ctx context.Context, collName string, async bool) (
DbName: "", // reserved,
CollectionNames: []string{collName},
}
for _, opt := range opts {
opt(req)
}
resp, err := c.Service.Flush(ctx, req)
if err != nil {
return nil, nil, 0, err
Expand Down
41 changes: 41 additions & 0 deletions client/mq_message.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
package client

import (
"context"

"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus-proto/go-api/v2/msgpb"
"github.com/milvus-io/milvus-sdk-go/v2/entity"
)

func (c *GrpcClient) ReplicateMessage(ctx context.Context,
channelName string, beginTs, endTs uint64,
msgsBytes [][]byte, startPositions, endPositions []*msgpb.MsgPosition,
opts ...ReplicateMessageOption) (*entity.MessageInfo, error) {

if c.Service == nil {
return nil, ErrClientNotReady
}
req := &milvuspb.ReplicateMessageRequest{
ChannelName: channelName,
BeginTs: beginTs,
EndTs: endTs,
Msgs: msgsBytes,
StartPositions: startPositions,
EndPositions: endPositions,
}
for _, opt := range opts {
opt(req)
}
resp, err := c.Service.ReplicateMessage(ctx, req)
if err != nil {
return nil, err
}
err = handleRespStatus(resp.GetStatus())
if err != nil {
return nil, err
}
return &entity.MessageInfo{
Position: resp.GetPosition(),
}, nil
}
Loading
Loading