Skip to content

Commit

Permalink
enhance: [GoSDK] Add release methods & GPU indexes (milvus-io#34690)
Browse files Browse the repository at this point in the history
Related to milvus-io#31293

Signed-off-by: Congqi Xia <[email protected]>
  • Loading branch information
congqixia authored Jul 16, 2024
1 parent 7306d2d commit ceb138d
Show file tree
Hide file tree
Showing 4 changed files with 247 additions and 0 deletions.
130 changes: 130 additions & 0 deletions client/index/gpu.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package index

import "strconv"

var _ Index = gpuBruteForceIndex{}

type gpuBruteForceIndex struct {
baseIndex
}

func (idx gpuBruteForceIndex) Params() map[string]string {
return map[string]string{
// build meta
MetricTypeKey: string(idx.metricType),
IndexTypeKey: string(GPUBruteForce),
}
}

func NewGPUBruteForceIndex(metricType MetricType) Index {
return gpuBruteForceIndex{
baseIndex: baseIndex{
metricType: metricType,
},
}
}

var _ Index = gpuIVFFlatIndex{}

type gpuIVFFlatIndex struct {
baseIndex
nlist int
}

func (idx gpuIVFFlatIndex) Params() map[string]string {
return map[string]string{
// build meta
MetricTypeKey: string(idx.metricType),
IndexTypeKey: string(GPUIvfFlat),
// build param
ivfNlistKey: strconv.Itoa(idx.nlist),
}
}

func NewGPUIVPFlatIndex(metricType MetricType) Index {
return gpuIVFFlatIndex{
baseIndex: baseIndex{
metricType: metricType,
},
}
}

var _ Index = gpuIVFPQIndex{}

type gpuIVFPQIndex struct {
baseIndex
nlist int
m int
nbits int
}

func (idx gpuIVFPQIndex) Params() map[string]string {
return map[string]string{
// build meta
MetricTypeKey: string(idx.metricType),
IndexTypeKey: string(GPUIvfFlat),
// build params
ivfNlistKey: strconv.Itoa(idx.nlist),
ivfPQMKey: strconv.Itoa(idx.m),
ivfPQNbits: strconv.Itoa(idx.nbits),
}
}

func NewGPUIVPPQIndex(metricType MetricType) Index {
return gpuIVFPQIndex{
baseIndex: baseIndex{
metricType: metricType,
},
}
}

const (
cagraInterGraphDegreeKey = `intermediate_graph_degree`
cagraGraphDegreeKey = `"graph_degree"`
)

type gpuCagra struct {
baseIndex
intermediateGraphDegree int
graphDegree int
}

func (idx gpuCagra) Params() map[string]string {
return map[string]string{
// build meta
MetricTypeKey: string(idx.metricType),
IndexTypeKey: string(GPUIvfFlat),
// build params
cagraInterGraphDegreeKey: strconv.Itoa(idx.intermediateGraphDegree),
cagraGraphDegreeKey: strconv.Itoa(idx.graphDegree),
}
}

func NewGPUCagraIndex(metricType MetricType,
intermediateGraphDegree,
graphDegree int,
) Index {
return gpuCagra{
baseIndex: baseIndex{
metricType: metricType,
},
intermediateGraphDegree: intermediateGraphDegree,
graphDegree: graphDegree,
}
}
19 changes: 19 additions & 0 deletions client/maintenance.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,25 @@ func (c *Client) LoadPartitions(ctx context.Context, option LoadPartitionsOption
return task, err
}

func (c *Client) ReleaseCollection(ctx context.Context, option ReleaseCollectionOption, callOptions ...grpc.CallOption) error {
req := option.Request()

return c.callService(func(milvusService milvuspb.MilvusServiceClient) error {
resp, err := milvusService.ReleaseCollection(ctx, req, callOptions...)

return merr.CheckRPCCall(resp, err)
})
}

func (c *Client) ReleasePartitions(ctx context.Context, option ReleasePartitionsOption, callOptions ...grpc.CallOption) error {
req := option.Request()

return c.callService(func(milvusService milvuspb.MilvusServiceClient) error {
resp, err := milvusService.ReleasePartitions(ctx, req, callOptions...)
return merr.CheckRPCCall(resp, err)
})
}

type FlushTask struct {
client *Client
collectionName string
Expand Down
47 changes: 47 additions & 0 deletions client/maintenance_options.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,53 @@ func NewLoadPartitionsOption(collectionName string, partitionsNames []string) *l
}
}

type ReleaseCollectionOption interface {
Request() *milvuspb.ReleaseCollectionRequest
}

var _ ReleaseCollectionOption = (*releaseCollectionOption)(nil)

type releaseCollectionOption struct {
collectionName string
}

func (opt *releaseCollectionOption) Request() *milvuspb.ReleaseCollectionRequest {
return &milvuspb.ReleaseCollectionRequest{
CollectionName: opt.collectionName,
}
}

func NewReleaseCollectionOption(collectionName string) *releaseCollectionOption {
return &releaseCollectionOption{
collectionName: collectionName,
}
}

type ReleasePartitionsOption interface {
Request() *milvuspb.ReleasePartitionsRequest
}

var _ ReleasePartitionsOption = (*releasePartitionsOption)(nil)

type releasePartitionsOption struct {
collectionName string
partitionNames []string
}

func (opt *releasePartitionsOption) Request() *milvuspb.ReleasePartitionsRequest {
return &milvuspb.ReleasePartitionsRequest{
CollectionName: opt.collectionName,
PartitionNames: opt.partitionNames,
}
}

func NewReleasePartitionsOptions(collectionName string, partitionNames ...string) *releasePartitionsOption {
return &releasePartitionsOption{
collectionName: collectionName,
partitionNames: partitionNames,
}
}

type FlushOption interface {
Request() *milvuspb.FlushRequest
CollectionName() string
Expand Down
51 changes: 51 additions & 0 deletions client/maintenance_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,57 @@ func (s *MaintenanceSuite) TestLoadPartitions() {
})
}

func (s *MaintenanceSuite) TestReleaseCollection() {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

s.Run("success", func() {
collectionName := fmt.Sprintf("coll_%s", s.randString(6))
s.mock.EXPECT().ReleaseCollection(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, rcr *milvuspb.ReleaseCollectionRequest) (*commonpb.Status, error) {
s.Equal(collectionName, rcr.GetCollectionName())
return merr.Success(), nil
}).Once()

err := s.client.ReleaseCollection(ctx, NewReleaseCollectionOption(collectionName))
s.NoError(err)
})

s.Run("failure", func() {
collectionName := fmt.Sprintf("coll_%s", s.randString(6))
s.mock.EXPECT().ReleaseCollection(mock.Anything, mock.Anything).Return(nil, merr.WrapErrServiceInternal("mocked")).Once()

err := s.client.ReleaseCollection(ctx, NewReleaseCollectionOption(collectionName))
s.Error(err)
})
}

func (s *MaintenanceSuite) TestReleasePartitions() {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

s.Run("success", func() {
collectionName := fmt.Sprintf("coll_%s", s.randString(6))
partitionName := fmt.Sprintf("part_%s", s.randString(6))
s.mock.EXPECT().ReleasePartitions(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, rpr *milvuspb.ReleasePartitionsRequest) (*commonpb.Status, error) {
s.Equal(collectionName, rpr.GetCollectionName())
s.ElementsMatch([]string{partitionName}, rpr.GetPartitionNames())
return merr.Success(), nil
}).Once()

err := s.client.ReleasePartitions(ctx, NewReleasePartitionsOptions(collectionName, partitionName))
s.NoError(err)
})

s.Run("failure", func() {
collectionName := fmt.Sprintf("coll_%s", s.randString(6))
partitionName := fmt.Sprintf("part_%s", s.randString(6))
s.mock.EXPECT().ReleasePartitions(mock.Anything, mock.Anything).Return(nil, merr.WrapErrServiceInternal("mocked")).Once()

err := s.client.ReleasePartitions(ctx, NewReleasePartitionsOptions(collectionName, partitionName))
s.Error(err)
})
}

func (s *MaintenanceSuite) TestFlush() {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
Expand Down

0 comments on commit ceb138d

Please sign in to comment.