diff --git a/client/index/gpu.go b/client/index/gpu.go new file mode 100644 index 0000000000000..129d0e0071935 --- /dev/null +++ b/client/index/gpu.go @@ -0,0 +1,130 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package index + +import "strconv" + +var _ Index = gpuBruteForceIndex{} + +type gpuBruteForceIndex struct { + baseIndex +} + +func (idx gpuBruteForceIndex) Params() map[string]string { + return map[string]string{ + // build meta + MetricTypeKey: string(idx.metricType), + IndexTypeKey: string(GPUBruteForce), + } +} + +func NewGPUBruteForceIndex(metricType MetricType) Index { + return gpuBruteForceIndex{ + baseIndex: baseIndex{ + metricType: metricType, + }, + } +} + +var _ Index = gpuIVFFlatIndex{} + +type gpuIVFFlatIndex struct { + baseIndex + nlist int +} + +func (idx gpuIVFFlatIndex) Params() map[string]string { + return map[string]string{ + // build meta + MetricTypeKey: string(idx.metricType), + IndexTypeKey: string(GPUIvfFlat), + // build param + ivfNlistKey: strconv.Itoa(idx.nlist), + } +} + +func NewGPUIVPFlatIndex(metricType MetricType) Index { + return gpuIVFFlatIndex{ + baseIndex: baseIndex{ + metricType: metricType, + }, + } +} + +var _ Index = gpuIVFPQIndex{} + +type gpuIVFPQIndex struct { + baseIndex + nlist int + m int + nbits int +} + +func (idx gpuIVFPQIndex) Params() map[string]string { + return map[string]string{ + // build meta + MetricTypeKey: string(idx.metricType), + IndexTypeKey: string(GPUIvfFlat), + // build params + ivfNlistKey: strconv.Itoa(idx.nlist), + ivfPQMKey: strconv.Itoa(idx.m), + ivfPQNbits: strconv.Itoa(idx.nbits), + } +} + +func NewGPUIVPPQIndex(metricType MetricType) Index { + return gpuIVFPQIndex{ + baseIndex: baseIndex{ + metricType: metricType, + }, + } +} + +const ( + cagraInterGraphDegreeKey = `intermediate_graph_degree` + cagraGraphDegreeKey = `"graph_degree"` +) + +type gpuCagra struct { + baseIndex + intermediateGraphDegree int + graphDegree int +} + +func (idx gpuCagra) Params() map[string]string { + return map[string]string{ + // build meta + MetricTypeKey: string(idx.metricType), + IndexTypeKey: string(GPUIvfFlat), + // build params + cagraInterGraphDegreeKey: strconv.Itoa(idx.intermediateGraphDegree), + cagraGraphDegreeKey: strconv.Itoa(idx.graphDegree), + } +} + +func NewGPUCagraIndex(metricType MetricType, + intermediateGraphDegree, + graphDegree int, +) Index { + return gpuCagra{ + baseIndex: baseIndex{ + metricType: metricType, + }, + intermediateGraphDegree: intermediateGraphDegree, + graphDegree: graphDegree, + } +} diff --git a/client/maintenance.go b/client/maintenance.go index 98ec167b39de1..71471220bc3b9 100644 --- a/client/maintenance.go +++ b/client/maintenance.go @@ -106,6 +106,25 @@ func (c *Client) LoadPartitions(ctx context.Context, option LoadPartitionsOption return task, err } +func (c *Client) ReleaseCollection(ctx context.Context, option ReleaseCollectionOption, callOptions ...grpc.CallOption) error { + req := option.Request() + + return c.callService(func(milvusService milvuspb.MilvusServiceClient) error { + resp, err := milvusService.ReleaseCollection(ctx, req, callOptions...) + + return merr.CheckRPCCall(resp, err) + }) +} + +func (c *Client) ReleasePartitions(ctx context.Context, option ReleasePartitionsOption, callOptions ...grpc.CallOption) error { + req := option.Request() + + return c.callService(func(milvusService milvuspb.MilvusServiceClient) error { + resp, err := milvusService.ReleasePartitions(ctx, req, callOptions...) + return merr.CheckRPCCall(resp, err) + }) +} + type FlushTask struct { client *Client collectionName string diff --git a/client/maintenance_options.go b/client/maintenance_options.go index 37bd4423895fa..66c41c7529e11 100644 --- a/client/maintenance_options.go +++ b/client/maintenance_options.go @@ -92,6 +92,53 @@ func NewLoadPartitionsOption(collectionName string, partitionsNames []string) *l } } +type ReleaseCollectionOption interface { + Request() *milvuspb.ReleaseCollectionRequest +} + +var _ ReleaseCollectionOption = (*releaseCollectionOption)(nil) + +type releaseCollectionOption struct { + collectionName string +} + +func (opt *releaseCollectionOption) Request() *milvuspb.ReleaseCollectionRequest { + return &milvuspb.ReleaseCollectionRequest{ + CollectionName: opt.collectionName, + } +} + +func NewReleaseCollectionOption(collectionName string) *releaseCollectionOption { + return &releaseCollectionOption{ + collectionName: collectionName, + } +} + +type ReleasePartitionsOption interface { + Request() *milvuspb.ReleasePartitionsRequest +} + +var _ ReleasePartitionsOption = (*releasePartitionsOption)(nil) + +type releasePartitionsOption struct { + collectionName string + partitionNames []string +} + +func (opt *releasePartitionsOption) Request() *milvuspb.ReleasePartitionsRequest { + return &milvuspb.ReleasePartitionsRequest{ + CollectionName: opt.collectionName, + PartitionNames: opt.partitionNames, + } +} + +func NewReleasePartitionsOptions(collectionName string, partitionNames ...string) *releasePartitionsOption { + return &releasePartitionsOption{ + collectionName: collectionName, + partitionNames: partitionNames, + } +} + type FlushOption interface { Request() *milvuspb.FlushRequest CollectionName() string diff --git a/client/maintenance_test.go b/client/maintenance_test.go index 0efcd449dfc41..4ccac9bc84fce 100644 --- a/client/maintenance_test.go +++ b/client/maintenance_test.go @@ -162,6 +162,57 @@ func (s *MaintenanceSuite) TestLoadPartitions() { }) } +func (s *MaintenanceSuite) TestReleaseCollection() { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + s.Run("success", func() { + collectionName := fmt.Sprintf("coll_%s", s.randString(6)) + s.mock.EXPECT().ReleaseCollection(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, rcr *milvuspb.ReleaseCollectionRequest) (*commonpb.Status, error) { + s.Equal(collectionName, rcr.GetCollectionName()) + return merr.Success(), nil + }).Once() + + err := s.client.ReleaseCollection(ctx, NewReleaseCollectionOption(collectionName)) + s.NoError(err) + }) + + s.Run("failure", func() { + collectionName := fmt.Sprintf("coll_%s", s.randString(6)) + s.mock.EXPECT().ReleaseCollection(mock.Anything, mock.Anything).Return(nil, merr.WrapErrServiceInternal("mocked")).Once() + + err := s.client.ReleaseCollection(ctx, NewReleaseCollectionOption(collectionName)) + s.Error(err) + }) +} + +func (s *MaintenanceSuite) TestReleasePartitions() { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + s.Run("success", func() { + collectionName := fmt.Sprintf("coll_%s", s.randString(6)) + partitionName := fmt.Sprintf("part_%s", s.randString(6)) + s.mock.EXPECT().ReleasePartitions(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, rpr *milvuspb.ReleasePartitionsRequest) (*commonpb.Status, error) { + s.Equal(collectionName, rpr.GetCollectionName()) + s.ElementsMatch([]string{partitionName}, rpr.GetPartitionNames()) + return merr.Success(), nil + }).Once() + + err := s.client.ReleasePartitions(ctx, NewReleasePartitionsOptions(collectionName, partitionName)) + s.NoError(err) + }) + + s.Run("failure", func() { + collectionName := fmt.Sprintf("coll_%s", s.randString(6)) + partitionName := fmt.Sprintf("part_%s", s.randString(6)) + s.mock.EXPECT().ReleasePartitions(mock.Anything, mock.Anything).Return(nil, merr.WrapErrServiceInternal("mocked")).Once() + + err := s.client.ReleasePartitions(ctx, NewReleasePartitionsOptions(collectionName, partitionName)) + s.Error(err) + }) +} + func (s *MaintenanceSuite) TestFlush() { ctx, cancel := context.WithCancel(context.Background()) defer cancel()