Skip to content

Commit

Permalink
enhance: [GoSDK] Add alter collection API & expose options (#37365)
Browse files Browse the repository at this point in the history
Related to #31293

This PR:

- Add `AlterCollection` API for collection property modification
- Expose hidden or missing option methods

---------

Signed-off-by: Congqi Xia <[email protected]>
  • Loading branch information
congqixia authored Nov 6, 2024
1 parent b3de4b0 commit c83b939
Show file tree
Hide file tree
Showing 6 changed files with 111 additions and 22 deletions.
9 changes: 9 additions & 0 deletions client/collection.go
Original file line number Diff line number Diff line change
Expand Up @@ -138,3 +138,12 @@ func (c *Client) RenameCollection(ctx context.Context, option RenameCollectionOp
return merr.CheckRPCCall(resp, err)
})
}

func (c *Client) AlterCollection(ctx context.Context, option AlterCollectionOption, callOptions ...grpc.CallOption) error {
req := option.Request()

return c.callService(func(milvusService milvuspb.MilvusServiceClient) error {
resp, err := milvusService.AlterCollection(ctx, req, callOptions...)
return merr.CheckRPCCall(resp, err)
})
}
47 changes: 47 additions & 0 deletions client/collection_options.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
package client

import (
"fmt"

"google.golang.org/protobuf/proto"

"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
Expand Down Expand Up @@ -56,6 +58,8 @@ type createCollectionOption struct {
// partition key
numPartitions int64

indexOptions []CreateIndexOption

// is fast create collection
isFast bool
// fast creation with index
Expand Down Expand Up @@ -83,6 +87,21 @@ func (opt *createCollectionOption) WithVarcharPK(varcharPK bool, maxLen int) *cr
return opt
}

func (opt *createCollectionOption) WithIndexOptions(indexOpts ...CreateIndexOption) *createCollectionOption {
opt.indexOptions = append(opt.indexOptions, indexOpts...)
return opt
}

func (opt *createCollectionOption) WithProperty(key string, value any) *createCollectionOption {
opt.properties[key] = fmt.Sprintf("%v", value)
return opt
}

func (opt *createCollectionOption) WithConsistencyLevel(cl entity.ConsistencyLevel) *createCollectionOption {
opt.consistencyLevel = cl
return opt
}

func (opt *createCollectionOption) Request() *milvuspb.CreateCollectionRequest {
// fast create collection
if opt.isFast {
Expand All @@ -103,6 +122,7 @@ func (opt *createCollectionOption) Request() *milvuspb.CreateCollectionRequest {

var schemaBytes []byte
if opt.schema != nil {
opt.schema.WithName(opt.name)
schemaProto := opt.schema.ProtoMessage()
schemaBytes, _ = proto.Marshal(schemaProto)
}
Expand Down Expand Up @@ -144,6 +164,7 @@ func SimpleCreateCollectionOptions(name string, dim int64) *createCollectionOpti
dim: dim,
enabledDynamicSchema: true,
consistencyLevel: entity.DefaultConsistencyLevel,
properties: make(map[string]string),

isFast: true,
metricType: entity.COSINE,
Expand All @@ -157,6 +178,7 @@ func NewCreateCollectionOption(name string, collectionSchema *entity.Schema) *cr
shardNum: 1,
schema: collectionSchema,
consistencyLevel: entity.DefaultConsistencyLevel,
properties: make(map[string]string),

metricType: entity.COSINE,
}
Expand Down Expand Up @@ -263,3 +285,28 @@ func NewRenameCollectionOption(oldName, newName string) *renameCollectionOption
newCollectionName: newName,
}
}

type AlterCollectionOption interface {
Request() *milvuspb.AlterCollectionRequest
}

type alterCollectionOption struct {
collectionName string
properties map[string]string
}

func (opt *alterCollectionOption) WithProperty(key string, value any) *alterCollectionOption {
opt.properties[key] = fmt.Sprintf("%v", value)
return opt
}

func (opt *alterCollectionOption) Request() *milvuspb.AlterCollectionRequest {
return &milvuspb.AlterCollectionRequest{
CollectionName: opt.collectionName,
Properties: entity.MapKvPairs(opt.properties),
}
}

func NewAlterCollectionOption(collection string) *alterCollectionOption {
return &alterCollectionOption{collectionName: collection, properties: make(map[string]string)}
}
43 changes: 42 additions & 1 deletion client/collection_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/client/v2/entity"
"github.com/milvus-io/milvus/pkg/common"
"github.com/milvus-io/milvus/pkg/util/merr"
)

Expand Down Expand Up @@ -117,11 +118,20 @@ func (s *CollectionSuite) TestCreateCollectionOptions() {
WithField(entity.NewField().WithName("int64").WithDataType(entity.FieldTypeInt64).WithIsPrimaryKey(true)).
WithField(entity.NewField().WithName("vector").WithDim(128).WithDataType(entity.FieldTypeFloatVector))

opt = NewCreateCollectionOption(collectionName, schema).WithShardNum(2)
opt = NewCreateCollectionOption(collectionName, schema).
WithShardNum(2).
WithConsistencyLevel(entity.ClEventually).
WithProperty(common.CollectionTTLConfigKey, 86400)

req = opt.Request()
s.Equal(collectionName, req.GetCollectionName())
s.EqualValues(2, req.GetShardsNum())
s.EqualValues(commonpb.ConsistencyLevel_Eventually, req.GetConsistencyLevel())
if s.Len(req.GetProperties(), 1) {
kv := req.GetProperties()[0]
s.Equal(common.CollectionTTLConfigKey, kv.GetKey())
s.Equal("86400", kv.GetValue())
}

collSchema = &schemapb.CollectionSchema{}
err = proto.Unmarshal(req.GetSchema(), collSchema)
Expand Down Expand Up @@ -274,6 +284,37 @@ func (s *CollectionSuite) TestRenameCollection() {
})
}

func (s *CollectionSuite) TestAlterCollection() {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

collName := fmt.Sprintf("test_collection_%s", s.randString(6))
key := s.randString(6)
value := s.randString(6)

s.Run("success", func() {
s.mock.EXPECT().AlterCollection(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, acr *milvuspb.AlterCollectionRequest) (*commonpb.Status, error) {
s.Equal(collName, acr.GetCollectionName())
if s.Len(acr.GetProperties(), 1) {
item := acr.GetProperties()[0]
s.Equal(key, item.GetKey())
s.Equal(value, item.GetValue())
}
return merr.Success(), nil
}).Once()

err := s.client.AlterCollection(ctx, NewAlterCollectionOption(collName).WithProperty(key, value))
s.NoError(err)
})

s.Run("failure", func() {
s.mock.EXPECT().AlterCollection(mock.Anything, mock.Anything).Return(nil, merr.WrapErrServiceInternal("mocked")).Once()

err := s.client.AlterCollection(ctx, NewAlterCollectionOption(collName).WithProperty(key, value))
s.Error(err)
})
}

func TestCollection(t *testing.T) {
suite.Run(t, new(CollectionSuite))
}
5 changes: 5 additions & 0 deletions client/read_options.go
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,11 @@ func (opt *searchOption) WithPartitions(partitionNames ...string) *searchOption
return opt
}

func (opt *searchOption) WithGroupByField(groupByField string) *searchOption {
opt.request.groupByField = groupByField
return opt
}

func NewSearchOption(collectionName string, limit int, vectors []entity.Vector) *searchOption {
return &searchOption{
collectionName: collectionName,
Expand Down
3 changes: 2 additions & 1 deletion client/read_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ func (s *ReadSuite) TestSearch() {
s.mock.EXPECT().Search(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, sr *milvuspb.SearchRequest) (*milvuspb.SearchResults, error) {
s.Equal(collectionName, sr.GetCollectionName())
s.ElementsMatch([]string{partitionName}, sr.GetPartitionNames())
// s.Equal(s)

return &milvuspb.SearchResults{
Status: merr.Success(),
Expand All @@ -71,7 +72,7 @@ func (s *ReadSuite) TestSearch() {
entity.FloatVector(lo.RepeatBy(128, func(_ int) float32 {
return rand.Float32()
})),
}).WithPartitions(partitionName))
}).WithPartitions(partitionName).WithGroupByField("group_by"))
s.NoError(err)
})

Expand Down
26 changes: 6 additions & 20 deletions tests/go_client/testcases/collection_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -445,35 +445,22 @@ func TestCreateCollectionWithInvalidCollectionName(t *testing.T) {
// connect
ctx := hp.CreateContext(t, time.Second*common.DefaultTimeout)
mc := createDefaultMilvusClient(ctx, t)
collName := common.GenRandomString(prefix, 6)

// create collection and schema no name
schema := genDefaultSchema()
err2 := mc.CreateCollection(ctx, client.NewCreateCollectionOption(collName, schema))
err2 := mc.CreateCollection(ctx, client.NewCreateCollectionOption("", schema))
common.CheckErr(t, err2, false, "collection name should not be empty")

// create collection with invalid schema name
for _, invalidName := range common.GenInvalidNames() {
log.Debug("TestCreateCollectionWithInvalidCollectionName", zap.String("collectionName", invalidName))

// schema has invalid name
schema.WithName(invalidName)
err := mc.CreateCollection(ctx, client.NewCreateCollectionOption(collName, schema))
err := mc.CreateCollection(ctx, client.NewCreateCollectionOption(invalidName, schema))
common.CheckErr(t, err, false, "collection name should not be empty",
"the first character of a collection name must be an underscore or letter",
"collection name can only contain numbers, letters and underscores",
fmt.Sprintf("the length of a collection name must be less than %d characters", common.MaxCollectionNameLen))

// collection option has invalid name
schema.WithName(collName)
err2 := mc.CreateCollection(ctx, client.NewCreateCollectionOption(invalidName, schema))
common.CheckErr(t, err2, false, "collection name matches schema name")
}

// collection name not equal to schema name
schema.WithName(collName)
err3 := mc.CreateCollection(ctx, client.NewCreateCollectionOption(common.GenRandomString("pre", 4), schema))
common.CheckErr(t, err3, false, "collection name matches schema name")
}

// create collection missing pk field or vector field
Expand Down Expand Up @@ -937,11 +924,10 @@ func TestCreateCollectionInvalid(t *testing.T) {
vecField := entity.NewField().WithName("vec").WithDataType(entity.FieldTypeFloatVector).WithDim(8)
mSchemaErrs := []mSchemaErr{
{schema: nil, errMsg: "schema does not contain vector field"},
{schema: entity.NewSchema().WithField(vecField), errMsg: "collection name should not be empty"}, // no collection name
{schema: entity.NewSchema().WithName("aaa").WithField(vecField), errMsg: "primary key is not specified"}, // no pk field
{schema: entity.NewSchema().WithName("aaa").WithField(vecField).WithField(entity.NewField()), errMsg: "primary key is not specified"},
{schema: entity.NewSchema().WithName("aaa").WithField(vecField).WithField(entity.NewField().WithIsPrimaryKey(true)), errMsg: "the data type of primary key should be Int64 or VarChar"},
{schema: entity.NewSchema().WithName("aaa").WithField(vecField).WithField(entity.NewField().WithIsPrimaryKey(true).WithDataType(entity.FieldTypeVarChar)), errMsg: "field name should not be empty"},
{schema: entity.NewSchema().WithField(vecField), errMsg: "primary key is not specified"}, // no pk field
{schema: entity.NewSchema().WithField(vecField).WithField(entity.NewField()), errMsg: "primary key is not specified"},
{schema: entity.NewSchema().WithField(vecField).WithField(entity.NewField().WithIsPrimaryKey(true)), errMsg: "the data type of primary key should be Int64 or VarChar"},
{schema: entity.NewSchema().WithField(vecField).WithField(entity.NewField().WithIsPrimaryKey(true).WithDataType(entity.FieldTypeVarChar)), errMsg: "field name should not be empty"},
}
for _, mSchema := range mSchemaErrs {
err := mc.CreateCollection(ctx, client.NewCreateCollectionOption(collName, mSchema.schema))
Expand Down

0 comments on commit c83b939

Please sign in to comment.