diff --git a/client/bulkwriter/bulk_import.go b/client/bulkwriter/bulk_import.go new file mode 100644 index 0000000000000..16813600b5eb9 --- /dev/null +++ b/client/bulkwriter/bulk_import.go @@ -0,0 +1,320 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package bulkwriter + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" +) + +// ResponseBase is the common milvus restful response struct. +type ResponseBase struct { + Status int `json:"status"` + Message string `json:"message"` +} + +// CheckStatus checks the response status and return error if not ok. +func (b ResponseBase) CheckStatus() error { + if b.Status != 0 { + return fmt.Errorf("bulk import return error, status: %d, message: %s", b.Status, b.Message) + } + return nil +} + +type BulkImportOption struct { + // milvus params + URL string `json:"-"` + CollectionName string `json:"collectionName"` + // optional in cloud api, use object url instead + Files [][]string `json:"files,omitempty"` + // optional params + PartitionName string `json:"partitionName,omitempty"` + APIKey string `json:"-"` + // cloud extra params + ObjectURL string `json:"objectUrl,omitempty"` + ClusterID string `json:"clusterId,omitempty"` + AccessKey string `json:"accessKey,omitempty"` + SecretKey string `json:"secretKey,omitempty"` + + // reserved extra options + Options map[string]string `json:"options,omitempty"` +} + +func (opt *BulkImportOption) GetRequest() ([]byte, error) { + return json.Marshal(opt) +} + +func (opt *BulkImportOption) WithPartition(partitionName string) *BulkImportOption { + opt.PartitionName = partitionName + return opt +} + +func (opt *BulkImportOption) WithAPIKey(key string) *BulkImportOption { + opt.APIKey = key + return opt +} + +func (opt *BulkImportOption) WithOption(key, value string) *BulkImportOption { + if opt.Options == nil { + opt.Options = make(map[string]string) + } + opt.Options[key] = value + return opt +} + +// NewBulkImportOption returns BulkImportOption for Milvus bulk import API. +func NewBulkImportOption(uri string, + collectionName string, + files [][]string, +) *BulkImportOption { + return &BulkImportOption{ + URL: uri, + CollectionName: collectionName, + Files: files, + } +} + +// NewCloudBulkImportOption returns import option for cloud import API. +func NewCloudBulkImportOption(uri string, + collectionName string, + apiKey string, + objectURL string, + clusterID string, + accessKey string, + secretKey string, +) *BulkImportOption { + return &BulkImportOption{ + URL: uri, + CollectionName: collectionName, + APIKey: apiKey, + ObjectURL: objectURL, + ClusterID: clusterID, + AccessKey: accessKey, + SecretKey: secretKey, + } +} + +type BulkImportResponse struct { + ResponseBase + Data struct { + JobID string `json:"jobId"` + } `json:"data"` +} + +// BulkImport is the API wrapper for restful import API. +func BulkImport(ctx context.Context, option *BulkImportOption) (*BulkImportResponse, error) { + url := option.URL + "/v2/vectordb/jobs/import/create" + bs, err := option.GetRequest() + if err != nil { + return nil, err + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(bs)) + if err != nil { + return nil, err + } + if option.APIKey != "" { + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", option.APIKey)) + } + + result := &BulkImportResponse{} + err = doPostRequest(req, result) + if err != nil { + return nil, err + } + return result, result.CheckStatus() +} + +type ListImportJobsOption struct { + URL string `json:"-"` + CollectionName string `json:"collectionName"` + ClusterID string `json:"clusterId,omitempty"` + APIKey string `json:"-"` + PageSize int `json:"pageSize,omitempty"` + CurrentPage int `json:"currentPage,omitempty"` +} + +func (opt *ListImportJobsOption) WithAPIKey(key string) *ListImportJobsOption { + opt.APIKey = key + return opt +} + +func (opt *ListImportJobsOption) WithPageSize(pageSize int) *ListImportJobsOption { + opt.PageSize = pageSize + return opt +} + +func (opt *ListImportJobsOption) WithCurrentPage(currentPage int) *ListImportJobsOption { + opt.CurrentPage = currentPage + return opt +} + +func (opt *ListImportJobsOption) GetRequest() ([]byte, error) { + return json.Marshal(opt) +} + +func NewListImportJobsOption(uri string, collectionName string) *ListImportJobsOption { + return &ListImportJobsOption{ + URL: uri, + CollectionName: collectionName, + CurrentPage: 1, + PageSize: 10, + } +} + +type ListImportJobsResponse struct { + ResponseBase + Data *ListImportJobData `json:"data"` +} + +type ListImportJobData struct { + Records []*ImportJobRecord `json:"records"` +} + +type ImportJobRecord struct { + JobID string `json:"jobId"` + CollectionName string `json:"collectionName"` + State string `json:"state"` + Progress int64 `json:"progress"` + Reason string `json:"reason"` +} + +func ListImportJobs(ctx context.Context, option *ListImportJobsOption) (*ListImportJobsResponse, error) { + url := option.URL + "/v2/vectordb/jobs/import/list" + bs, err := option.GetRequest() + if err != nil { + return nil, err + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(bs)) + if err != nil { + return nil, err + } + if option.APIKey != "" { + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", option.APIKey)) + } + + result := &ListImportJobsResponse{} + if err := doPostRequest(req, result); err != nil { + return nil, err + } + + return result, result.CheckStatus() +} + +type GetImportProgressOption struct { + URL string `json:"-"` + JobID string `json:"jobId"` + // optional + ClusterID string `json:"clusterId"` + APIKey string `json:"-"` +} + +func (opt *GetImportProgressOption) GetRequest() ([]byte, error) { + return json.Marshal(opt) +} + +func (opt *GetImportProgressOption) WithAPIKey(key string) *GetImportProgressOption { + opt.APIKey = key + return opt +} + +func NewGetImportProgressOption(uri string, jobID string) *GetImportProgressOption { + return &GetImportProgressOption{ + URL: uri, + JobID: jobID, + } +} + +func NewCloudGetImportProgressOption(uri string, jobID string, apiKey string, clusterID string) *GetImportProgressOption { + return &GetImportProgressOption{ + URL: uri, + JobID: jobID, + APIKey: apiKey, + ClusterID: clusterID, + } +} + +type GetImportProgressResponse struct { + ResponseBase + Data *ImportProgressData `json:"data"` +} + +type ImportProgressData struct { + CollectionName string `json:"collectionName"` + JobID string `json:"jobId"` + CompleteTime string `json:"completeTime"` + State string `json:"state"` + Progress int64 `json:"progress"` + ImportedRows int64 `json:"importedRows"` + TotalRows int64 `json:"totalRows"` + Reason string `json:"reason"` + FileSize int64 `json:"fileSize"` + Details []*ImportProgressDetail `json:"details"` +} + +type ImportProgressDetail struct { + FileName string `json:"fileName"` + FileSize int64 `json:"fileSize"` + Progress int64 `json:"progress"` + CompleteTime string `json:"completeTime"` + State string `json:"state"` + ImportedRows int64 `json:"importedRows"` + TotalRows int64 `json:"totalRows"` +} + +func GetImportProgress(ctx context.Context, option *GetImportProgressOption) (*GetImportProgressResponse, error) { + url := option.URL + "/v2/vectordb/jobs/import/describe" + + bs, err := option.GetRequest() + if err != nil { + return nil, err + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewBuffer(bs)) + if err != nil { + return nil, err + } + if option.APIKey != "" { + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", option.APIKey)) + } + + result := &GetImportProgressResponse{} + if err := doPostRequest(req, result); err != nil { + return nil, err + } + return result, result.CheckStatus() +} + +func doPostRequest(req *http.Request, response any) error { + client := &http.Client{} + resp, err := client.Do(req) + if err != nil { + return err + } + defer resp.Body.Close() + + respData, err := io.ReadAll(resp.Body) + if err != nil { + return err + } + return json.Unmarshal(respData, response) +} diff --git a/client/bulkwriter/bulk_import_test.go b/client/bulkwriter/bulk_import_test.go new file mode 100644 index 0000000000000..7f3fb26d4f797 --- /dev/null +++ b/client/bulkwriter/bulk_import_test.go @@ -0,0 +1,168 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package bulkwriter + +import ( + "context" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/stretchr/testify/suite" +) + +type BulkImportSuite struct { + suite.Suite +} + +func (s *BulkImportSuite) TestBulkImport() { + s.Run("normal_case", func() { + svr := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + authHeader := req.Header.Get("Authorization") + s.Equal("Bearer root:Milvus", authHeader) + s.True(strings.Contains(req.URL.Path, "/v2/vectordb/jobs/import/create")) + rw.Write([]byte(`{"status":0, "data":{"jobId": "123"}}`)) + })) + defer svr.Close() + + resp, err := BulkImport(context.Background(), + NewBulkImportOption(svr.URL, "hello_milvus", [][]string{{"files/a.json", "files/b.json"}}). + WithPartition("_default"). + WithOption("backup", "true"). + WithAPIKey("root:Milvus"), + ) + s.NoError(err) + s.EqualValues(0, resp.Status) + s.Equal("123", resp.Data.JobID) + }) + + s.Run("svr_error", func() { + svr := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + // rw. + rw.WriteHeader(http.StatusInternalServerError) + rw.Write([]byte(`interal server error`)) + })) + defer svr.Close() + + _, err := BulkImport(context.Background(), NewBulkImportOption(svr.URL, "hello_milvus", [][]string{{"files/a.json", "files/b.json"}})) + s.Error(err) + }) + + s.Run("status_error", func() { + svr := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + s.True(strings.Contains(req.URL.Path, "/v2/vectordb/jobs/import/create")) + rw.Write([]byte(`{"status":1100, "message": "import job failed"}`)) + })) + defer svr.Close() + + _, err := BulkImport(context.Background(), NewBulkImportOption(svr.URL, "hello_milvus", [][]string{{"files/a.json", "files/b.json"}})) + s.Error(err) + }) + + s.Run("server_closed", func() { + svr2 := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {})) + svr2.Close() + _, err := BulkImport(context.Background(), NewBulkImportOption(svr2.URL, "hello_milvus", [][]string{{"files/a.json", "files/b.json"}})) + s.Error(err) + }) +} + +func (s *BulkImportSuite) TestListImportJobs() { + s.Run("normal_case", func() { + svr := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + authHeader := req.Header.Get("Authorization") + s.Equal("Bearer root:Milvus", authHeader) + s.True(strings.Contains(req.URL.Path, "/v2/vectordb/jobs/import/list")) + rw.Write([]byte(`{"status":0, "data":{"records": [{"jobID": "abc", "collectionName": "hello_milvus", "state":"Importing", "progress": 50}]}}`)) + })) + defer svr.Close() + + resp, err := ListImportJobs(context.Background(), + NewListImportJobsOption(svr.URL, "hello_milvus"). + WithPageSize(10). + WithCurrentPage(1). + WithAPIKey("root:Milvus"), + ) + s.NoError(err) + s.EqualValues(0, resp.Status) + if s.Len(resp.Data.Records, 1) { + record := resp.Data.Records[0] + s.Equal("abc", record.JobID) + s.Equal("hello_milvus", record.CollectionName) + s.Equal("Importing", record.State) + s.EqualValues(50, record.Progress) + } + }) + + s.Run("svr_error", func() { + svr := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + rw.WriteHeader(http.StatusInternalServerError) + })) + defer svr.Close() + + _, err := ListImportJobs(context.Background(), NewListImportJobsOption(svr.URL, "hello_milvus")) + s.Error(err) + }) +} + +func (s *BulkImportSuite) TestGetImportProgress() { + s.Run("normal_case", func() { + svr := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + authHeader := req.Header.Get("Authorization") + s.Equal("Bearer root:Milvus", authHeader) + s.True(strings.Contains(req.URL.Path, "/v2/vectordb/jobs/import/describe")) + rw.Write([]byte(`{"status":0, "data":{"collectionName": "hello_milvus","jobId":"abc", "state":"Importing", "progress": 50, "importedRows": 20000,"totalRows": 40000, "details":[{"fileName": "files/a.json", "fileSize": 64312, "progress": 100, "state": "Completed"}, {"fileName":"files/b.json", "fileSize":52912, "progress":0, "state":"Importing"}]}}`)) + })) + defer svr.Close() + + resp, err := GetImportProgress(context.Background(), + NewGetImportProgressOption(svr.URL, "abc"). + WithAPIKey("root:Milvus"), + ) + s.NoError(err) + s.EqualValues(0, resp.Status) + s.Equal("hello_milvus", resp.Data.CollectionName) + s.Equal("abc", resp.Data.JobID) + s.Equal("Importing", resp.Data.State) + s.EqualValues(50, resp.Data.Progress) + if s.Len(resp.Data.Details, 2) { + detail1 := resp.Data.Details[0] + s.Equal("files/a.json", detail1.FileName) + s.Equal("Completed", detail1.State) + s.EqualValues(100, detail1.Progress) + detail2 := resp.Data.Details[1] + s.Equal("files/b.json", detail2.FileName) + s.Equal("Importing", detail2.State) + s.EqualValues(0, detail2.Progress) + } + }) + + s.Run("svr_error", func() { + svr := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + rw.WriteHeader(http.StatusInternalServerError) + })) + defer svr.Close() + + _, err := GetImportProgress(context.Background(), NewGetImportProgressOption(svr.URL, "abc")) + s.Error(err) + }) +} + +func TestBulkImportAPIs(t *testing.T) { + suite.Run(t, new(BulkImportSuite)) +}