From e7e5e7cef69f3d342f67b4e0b7b0df4fc4e20e21 Mon Sep 17 00:00:00 2001 From: "junjie.jiang" Date: Tue, 17 Dec 2024 17:32:44 +0800 Subject: [PATCH] Add vertexai Signed-off-by: junjie.jiang --- internal/models/openai/openai_embedding.go | 6 +- .../models/openai/openai_embedding_test.go | 62 ++++- internal/models/utils/embedding_util.go | 6 +- .../vertexai/vertexai_text_embedding.go | 163 +++++++++++++ .../vertexai/vertexai_text_embedding_test.go | 90 +++++++ internal/proxy/task_search.go | 1 - internal/proxy/task_test.go | 44 ++++ internal/proxy/util.go | 5 + .../util/function/ali_embedding_provider.go | 27 ++- .../alitext_embedding_provider_test.go | 8 +- .../function/bedrock_embedding_provider.go | 10 +- .../bedrock_text_embedding_provider_test.go | 6 +- internal/util/function/common.go | 32 ++- internal/util/function/function_base.go | 10 +- internal/util/function/function_executor.go | 9 + .../util/function/mock_embedding_service.go | 35 +++ .../function/openai_embedding_provider.go | 17 +- .../openai_text_embedding_provider_test.go | 8 +- .../util/function/text_embedding_function.go | 19 +- .../function/vertexai_embedding_provider.go | 221 ++++++++++++++++++ .../vertexai_embedding_provider_test.go | 170 ++++++++++++++ 21 files changed, 892 insertions(+), 57 deletions(-) create mode 100644 internal/models/vertexai/vertexai_text_embedding.go create mode 100644 internal/models/vertexai/vertexai_text_embedding_test.go create mode 100644 internal/util/function/vertexai_embedding_provider.go create mode 100644 internal/util/function/vertexai_embedding_provider_test.go diff --git a/internal/models/openai/openai_embedding.go b/internal/models/openai/openai_embedding.go index ee1b7cc03330b..bb6f88be0cd18 100644 --- a/internal/models/openai/openai_embedding.go +++ b/internal/models/openai/openai_embedding.go @@ -215,14 +215,14 @@ func (c *AzureOpenAIEmbeddingClient) Embedding(modelName string, texts []string, params.Add("api-version", c.apiVersion) base.RawQuery = 
params.Encode() - ctx, cancel := context.WithTimeout(context.Background(), timeoutSec) + ctx, cancel := context.WithTimeout(context.Background(), timeoutSec*time.Second) defer cancel() - req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.url, bytes.NewBuffer(data)) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, base.String(), bytes.NewBuffer(data)) if err != nil { return nil, err } req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.apiKey)) + req.Header.Set("api-key", c.apiKey) body, err := utils.RetrySend(req, 3) if err != nil { return nil, err diff --git a/internal/models/openai/openai_embedding_test.go b/internal/models/openai/openai_embedding_test.go index c935e7c2cfbff..87f44b4ea6308 100644 --- a/internal/models/openai/openai_embedding_test.go +++ b/internal/models/openai/openai_embedding_test.go @@ -72,7 +72,20 @@ func TestEmbeddingOK(t *testing.T) { } ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusOK) + if r.URL.Path == "/" { + if r.Header["Authorization"][0] != "" { + w.WriteHeader(http.StatusOK) + } else { + w.WriteHeader(http.StatusBadRequest) + } + } else { + if r.Header["Api-Key"][0] != "" { + w.WriteHeader(http.StatusOK) + } else { + w.WriteHeader(http.StatusBadRequest) + } + } + data, _ := json.Marshal(res) w.Write(data) })) @@ -84,7 +97,15 @@ func TestEmbeddingOK(t *testing.T) { c := NewOpenAIEmbeddingClient("mock_key", url) err := c.Check() assert.True(t, err == nil) - _, err = c.Embedding("text-embedding-3-small", []string{"sentence"}, 0, "", 0) + ret, err := c.Embedding("text-embedding-3-small", []string{"sentence"}, 0, "", 0) + assert.True(t, err == nil) + assert.Equal(t, ret.Data[0].Index, 0) + assert.Equal(t, ret.Data[1].Index, 1) + } + { + c := NewAzureOpenAIEmbeddingClient("mock_key", url) + err := c.Check() + assert.True(t, err == nil) ret, err := c.Embedding("text-embedding-3-small", 
[]string{"sentence"}, 0, "", 0) assert.True(t, err == nil) assert.Equal(t, ret.Data[0].Index, 0) @@ -148,6 +169,20 @@ func TestEmbeddingRetry(t *testing.T) { assert.Equal(t, ret.Data[2], res.Data[0]) assert.Equal(t, atomic.LoadInt32(&count), int32(2)) } + { + c := NewAzureOpenAIEmbeddingClient("mock_key", url) + err := c.Check() + assert.True(t, err == nil) + ret, err := c.Embedding("text-embedding-3-small", []string{"sentence"}, 0, "", 0) + assert.True(t, err == nil) + assert.Equal(t, ret.Usage, res.Usage) + assert.Equal(t, ret.Object, res.Object) + assert.Equal(t, ret.Model, res.Model) + assert.Equal(t, ret.Data[0], res.Data[1]) + assert.Equal(t, ret.Data[1], res.Data[2]) + assert.Equal(t, ret.Data[2], res.Data[0]) + assert.Equal(t, atomic.LoadInt32(&count), int32(2)) + } } func TestEmbeddingFailed(t *testing.T) { @@ -161,6 +196,7 @@ func TestEmbeddingFailed(t *testing.T) { url := ts.URL { + atomic.StoreInt32(&count, 0) c := NewOpenAIEmbeddingClient("mock_key", url) err := c.Check() assert.True(t, err == nil) @@ -168,6 +204,15 @@ func TestEmbeddingFailed(t *testing.T) { assert.True(t, err != nil) assert.Equal(t, atomic.LoadInt32(&count), int32(3)) } + { + atomic.StoreInt32(&count, 0) + c := NewAzureOpenAIEmbeddingClient("mock_key", url) + err := c.Check() + assert.True(t, err == nil) + _, err = c.Embedding("text-embedding-3-small", []string{"sentence"}, 0, "", 0) + assert.True(t, err != nil) + assert.Equal(t, atomic.LoadInt32(&count), int32(3)) + } } func TestTimeout(t *testing.T) { @@ -182,6 +227,7 @@ func TestTimeout(t *testing.T) { url := ts.URL { + atomic.StoreInt32(&st, 0) c := NewOpenAIEmbeddingClient("mock_key", url) err := c.Check() assert.True(t, err == nil) @@ -191,4 +237,16 @@ func TestTimeout(t *testing.T) { time.Sleep(3 * time.Second) assert.Equal(t, atomic.LoadInt32(&st), int32(1)) } + + { + atomic.StoreInt32(&st, 0) + c := NewAzureOpenAIEmbeddingClient("mock_key", url) + err := c.Check() + assert.True(t, err == nil) + _, err = 
c.Embedding("text-embedding-3-small", []string{"sentence"}, 0, "", 1) + assert.True(t, err != nil) + assert.Equal(t, atomic.LoadInt32(&st), int32(0)) + time.Sleep(3 * time.Second) + assert.Equal(t, atomic.LoadInt32(&st), int32(1)) + } } diff --git a/internal/models/utils/embedding_util.go b/internal/models/utils/embedding_util.go index e67dcf0a4bd9c..1d6e7d916cab2 100644 --- a/internal/models/utils/embedding_util.go +++ b/internal/models/utils/embedding_util.go @@ -41,11 +41,13 @@ func send(req *http.Request) ([]byte, error) { } func RetrySend(req *http.Request, maxRetries int) ([]byte, error) { + var err error + var res []byte for i := 0; i < maxRetries; i++ { - res, err := send(req) + res, err = send(req) if err == nil { return res, nil } } - return nil, nil + return nil, err } diff --git a/internal/models/vertexai/vertexai_text_embedding.go b/internal/models/vertexai/vertexai_text_embedding.go new file mode 100644 index 0000000000000..3842824616214 --- /dev/null +++ b/internal/models/vertexai/vertexai_text_embedding.go @@ -0,0 +1,163 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package vertexai + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "net/http" + "time" + + "github.com/milvus-io/milvus/internal/models/utils" + + "golang.org/x/oauth2/google" +) + +type Instance struct { + TaskType string `json:"task_type,omitempty"` + Content string `json:"content"` +} + +type Parameters struct { + OutputDimensionality int64 `json:"outputDimensionality,omitempty"` +} + +type EmbeddingRequest struct { + Instances []Instance `json:"instances"` + Parameters Parameters `json:"parameters,omitempty"` +} + +type Statistics struct { + Truncated bool `json:"truncated"` + TokenCount int `json:"token_count"` +} + +type Embeddings struct { + Statistics Statistics `json:"statistics"` + Values []float32 `json:"values"` +} + +type Prediction struct { + Embeddings Embeddings `json:"embeddings"` +} + +type Metadata struct { + BillableCharacterCount int `json:"billableCharacterCount"` +} + +type EmbeddingResponse struct { + Predictions []Prediction `json:"predictions"` + Metadata Metadata `json:"metadata"` +} + +type ErrorInfo struct { + Code string `json:"code"` + Message string `json:"message"` + RequestID string `json:"request_id"` +} + +type VertexAIEmbedding struct { + url string + jsonKey []byte + scopes string + token string +} + +func NewVertexAIEmbedding(url string, jsonKey []byte, scopes string, token string) *VertexAIEmbedding { + return &VertexAIEmbedding{ + url: url, + jsonKey: jsonKey, + scopes: scopes, + token: token, + } +} + +func (c *VertexAIEmbedding) Check() error { + if c.url == "" { + return fmt.Errorf("VertexAI embedding url is empty") + } + if len(c.jsonKey) == 0 { + return fmt.Errorf("jsonKey is empty") + } + if c.scopes == "" { + return fmt.Errorf("Scopes param is empty") + } + return nil +} + +func (c *VertexAIEmbedding) getAccessToken() (string, error) { + ctx := context.Background() + creds, err := google.CredentialsFromJSON(ctx, c.jsonKey, c.scopes) + if err != nil { + return "", fmt.Errorf("Failed to find credentials: 
%v", err) + } + token, err := creds.TokenSource.Token() + if err != nil { + return "", fmt.Errorf("Failed to get token: %v", err) + } + return token.AccessToken, nil +} + +func (c *VertexAIEmbedding) Embedding(modelName string, texts []string, dim int64, taskType string, timeoutSec time.Duration) (*EmbeddingResponse, error) { + var r EmbeddingRequest + for _, text := range texts { + r.Instances = append(r.Instances, Instance{TaskType: taskType, Content: text}) + } + if dim != 0 { + r.Parameters.OutputDimensionality = dim + } + + data, err := json.Marshal(r) + if err != nil { + return nil, err + } + + if timeoutSec <= 0 { + timeoutSec = 30 + } + + ctx, cancel := context.WithTimeout(context.Background(), timeoutSec*time.Second) + defer cancel() + req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.url, bytes.NewBuffer(data)) + if err != nil { + return nil, err + } + var token string + if c.token != "" { + token = c.token + } else { + token, err = c.getAccessToken() + if err != nil { + return nil, err + } + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token)) + body, err := utils.RetrySend(req, 3) + if err != nil { + return nil, err + } + var res EmbeddingResponse + err = json.Unmarshal(body, &res) + if err != nil { + return nil, err + } + return &res, err +} diff --git a/internal/models/vertexai/vertexai_text_embedding_test.go b/internal/models/vertexai/vertexai_text_embedding_test.go new file mode 100644 index 0000000000000..f138d659a3ea4 --- /dev/null +++ b/internal/models/vertexai/vertexai_text_embedding_test.go @@ -0,0 +1,90 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package vertexai + +import ( + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestEmbeddingClientCheck(t *testing.T) { + mockJsonKey := []byte{1, 2, 3} + { + c := NewVertexAIEmbedding("mock_url", []byte{}, "mock_scopes", "") + err := c.Check() + assert.True(t, err != nil) + fmt.Println(err) + } + + { + c := NewVertexAIEmbedding("", mockJsonKey, "", "") + err := c.Check() + assert.True(t, err != nil) + fmt.Println(err) + } + + { + c := NewVertexAIEmbedding("mock_url", mockJsonKey, "mock_scopes", "") + err := c.Check() + assert.True(t, err == nil) + } +} + +func TestEmbeddingOK(t *testing.T) { + var res EmbeddingResponse + repStr := `{"predictions": [{"embeddings": {"statistics": {"truncated": false, "token_count": 4}, "values": [-0.028420744463801384, 0.037183016538619995]}}, {"embeddings": {"statistics": {"truncated": false, "token_count": 8}, "values": [-0.04367655888199806, 0.03777721896767616, 0.0158217903226614]}}], "metadata": {"billableCharacterCount": 27}}` + err := json.Unmarshal([]byte(repStr), &res) + assert.NoError(t, err) + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + data, _ := json.Marshal(res) + w.Write(data) + })) + + defer ts.Close() + url := ts.URL + + { + c := NewVertexAIEmbedding(url, []byte{1, 2, 3}, "mock_scopes", 
"mock_token") + err := c.Check() + assert.True(t, err == nil) + _, err = c.Embedding("text-embedding-005", []string{"sentence"}, 0, "query", 0) + assert.True(t, err == nil) + } +} + +func TestEmbeddingFailed(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusUnauthorized) + })) + + defer ts.Close() + url := ts.URL + + { + c := NewVertexAIEmbedding(url, []byte{1, 2, 3}, "mock_scopes", "mock_token") + err := c.Check() + assert.True(t, err == nil) + _, err = c.Embedding("text-embedding-v2", []string{"sentence"}, 0, "query", 0) + assert.True(t, err != nil) + } +} diff --git a/internal/proxy/task_search.go b/internal/proxy/task_search.go index 8ccb0fb502301..e59f831186f3c 100644 --- a/internal/proxy/task_search.go +++ b/internal/proxy/task_search.go @@ -438,7 +438,6 @@ func (t *searchTask) initAdvancedSearchRequest(ctx context.Context) error { t.SearchRequest.PartitionIDs = t.partitionIDsSet.Collect() } - var err error t.reScorers, err = NewReScorers(ctx, len(t.request.GetSubReqs()), t.request.GetSearchParams()) if err != nil { log.Info("generate reScorer failed", zap.Any("params", t.request.GetSearchParams()), zap.Error(err)) diff --git a/internal/proxy/task_test.go b/internal/proxy/task_test.go index 019063bbbc9f1..7982f36731dbe 100644 --- a/internal/proxy/task_test.go +++ b/internal/proxy/task_test.go @@ -20,6 +20,7 @@ import ( "bytes" "context" "encoding/binary" + "fmt" "math/rand" "strconv" "testing" @@ -1022,6 +1023,49 @@ func TestCreateCollectionTask(t *testing.T) { err = task2.PreExecute(ctx) assert.Error(t, err) }) + + t.Run("collection with embedding function ", func(t *testing.T) { + fmt.Println(schema) + schema.Functions = []*schemapb.FunctionSchema{ + { + Name: "test", + Type: schemapb.FunctionType_TextEmbedding, + InputFieldNames: []string{varCharField}, + OutputFieldNames: []string{floatVecField}, + InputFieldIds: []int64{101}, + OutputFieldIds: []int64{102}, + Params: 
[]*commonpb.KeyValuePair{ + {Key: "provider", Value: "openai"}, + {Key: "model_name", Value: "text-embedding-ada-002"}, + {Key: "api_key", Value: "mock"}, + }, + }, + } + + marshaledSchema, err := proto.Marshal(schema) + assert.NoError(t, err) + + task2 := &createCollectionTask{ + Condition: NewTaskCondition(ctx), + CreateCollectionRequest: &milvuspb.CreateCollectionRequest{ + Base: nil, + DbName: dbName, + CollectionName: collectionName, + Schema: marshaledSchema, + ShardsNum: shardsNum, + }, + ctx: ctx, + rootCoord: rc, + result: nil, + schema: nil, + } + + err = task2.OnEnqueue() + assert.NoError(t, err) + + err = task2.PreExecute(ctx) + assert.NoError(t, err) + }) } func TestHasCollectionTask(t *testing.T) { diff --git a/internal/proxy/util.go b/internal/proxy/util.go index 63b2a9d8f74f2..2785079dbdfb8 100644 --- a/internal/proxy/util.go +++ b/internal/proxy/util.go @@ -39,6 +39,7 @@ import ( "github.com/milvus-io/milvus/internal/proto/planpb" "github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/internal/types" + "github.com/milvus-io/milvus/internal/util/function" "github.com/milvus-io/milvus/internal/util/hookutil" "github.com/milvus-io/milvus/internal/util/indexparamcheck" typeutil2 "github.com/milvus-io/milvus/internal/util/typeutil" @@ -705,6 +706,10 @@ func validateFunction(coll *schemapb.CollectionSchema) error { return err } } + + if err := function.CheckFunctions(coll); err != nil { + return err + } return nil } diff --git a/internal/util/function/ali_embedding_provider.go b/internal/util/function/ali_embedding_provider.go index d106426f5c020..920041afadbc1 100644 --- a/internal/util/function/ali_embedding_provider.go +++ b/internal/util/function/ali_embedding_provider.go @@ -36,6 +36,7 @@ type AliEmbeddingProvider struct { client *ali.AliDashScopeEmbedding modelName string embedDimParam int64 + outputType string maxBatch int timeoutSec int @@ -43,19 +44,15 @@ type AliEmbeddingProvider struct { func 
createAliEmbeddingClient(apiKey string, url string) (*ali.AliDashScopeEmbedding, error) { if apiKey == "" { - apiKey = os.Getenv("DASHSCOPE_API_KEY") + apiKey = os.Getenv(dashscopeApiKey) } if apiKey == "" { - return nil, fmt.Errorf("Missing credentials. Please pass `api_key`, or configure the DASHSCOPE_API_KEY environment variable in the Milvus service.") + return nil, fmt.Errorf("Missing credentials. Please pass `api_key`, or configure the %s environment variable in the Milvus service.", dashscopeApiKey) } if url == "" { url = "https://dashscope.aliyuncs.com/api/v1/services/embeddings/text-embedding/text-embedding" } - if url == "" { - return nil, fmt.Errorf("Must provide `url` arguments or configure the DASHSCOPE_ENDPOINT environment variable in the Milvus service") - } - c := ali.NewAliDashScopeEmbeddingClient(apiKey, url) return c, nil } @@ -93,6 +90,7 @@ func NewAliDashScopeEmbeddingProvider(fieldSchema *schemapb.FieldSchema, functio return nil, fmt.Errorf("Unsupported model: %s, only support [%s, %s, %s]", modelName, TextEmbeddingV1, TextEmbeddingV2, TextEmbeddingV3) } + c, err := createAliEmbeddingClient(apiKey, url) if err != nil { return nil, err @@ -102,8 +100,10 @@ func NewAliDashScopeEmbeddingProvider(fieldSchema *schemapb.FieldSchema, functio fieldDim: fieldDim, modelName: modelName, embedDimParam: dim, - maxBatch: 25, - timeoutSec: 30, + // TextEmbedding only supports dense embedding + outputType: "dense", + maxBatch: 25, + timeoutSec: 30, } return &provider, nil } @@ -116,19 +116,24 @@ func (provider *AliEmbeddingProvider) FieldDim() int64 { return provider.fieldDim } -func (provider *AliEmbeddingProvider) CallEmbedding(texts []string, batchLimit bool) ([][]float32, error) { +func (provider *AliEmbeddingProvider) CallEmbedding(texts []string, batchLimit bool, mode string) ([][]float32, error) { numRows := len(texts) if batchLimit && numRows > provider.MaxBatch() { return nil, fmt.Errorf("Ali text embedding supports up to [%d] pieces of data at a 
time, got [%d]", provider.MaxBatch(), numRows) } - + var textType string + if mode == SearchMode { + textType = "query" + } else { + textType = "document" + } data := make([][]float32, 0, numRows) for i := 0; i < numRows; i += provider.maxBatch { end := i + provider.maxBatch if end > numRows { end = numRows } - resp, err := provider.client.Embedding(provider.modelName, texts[i:end], int(provider.embedDimParam), "query", "dense", time.Duration(provider.timeoutSec)) + resp, err := provider.client.Embedding(provider.modelName, texts[i:end], int(provider.embedDimParam), textType, provider.outputType, time.Duration(provider.timeoutSec)) if err != nil { return nil, err } diff --git a/internal/util/function/alitext_embedding_provider_test.go b/internal/util/function/alitext_embedding_provider_test.go index 100e42c31e918..f4a36bd2635a4 100644 --- a/internal/util/function/alitext_embedding_provider_test.go +++ b/internal/util/function/alitext_embedding_provider_test.go @@ -88,7 +88,7 @@ func (s *AliTextEmbeddingProviderSuite) TestEmbedding() { s.NoError(err) { data := []string{"sentence"} - ret, err2 := provder.CallEmbedding(data, false) + ret, err2 := provder.CallEmbedding(data, false, InsertMode) s.NoError(err2) s.Equal(1, len(ret)) s.Equal(4, len(ret[0])) @@ -96,7 +96,7 @@ func (s *AliTextEmbeddingProviderSuite) TestEmbedding() { } { data := []string{"sentence 1", "sentence 2", "sentence 3"} - ret, _ := provder.CallEmbedding(data, false) + ret, _ := provder.CallEmbedding(data, false, SearchMode) s.Equal([][]float32{{0.0, 0.1, 0.2, 0.3}, {1.0, 1.1, 1.2, 1.3}, {2.0, 2.1, 2.2, 2.3}}, ret) } @@ -130,7 +130,7 @@ func (s *AliTextEmbeddingProviderSuite) TestEmbeddingDimNotMatch() { // embedding dim not match data := []string{"sentence", "sentence"} - _, err2 := provder.CallEmbedding(data, false) + _, err2 := provder.CallEmbedding(data, false, InsertMode) s.Error(err2) } @@ -159,7 +159,7 @@ func (s *AliTextEmbeddingProviderSuite) TestEmbeddingNubmerNotMatch() { // embedding dim 
not match data := []string{"sentence", "sentence2"} - _, err2 := provder.CallEmbedding(data, false) + _, err2 := provder.CallEmbedding(data, false, InsertMode) s.Error(err2) } diff --git a/internal/util/function/bedrock_embedding_provider.go b/internal/util/function/bedrock_embedding_provider.go index 5a8367fe23e49..a9a6d56e95be9 100644 --- a/internal/util/function/bedrock_embedding_provider.go +++ b/internal/util/function/bedrock_embedding_provider.go @@ -53,17 +53,17 @@ type BedrockEmbeddingProvider struct { func createBedRockEmbeddingClient(awsAccessKeyId string, awsSecretAccessKey string, region string) (*bedrockruntime.Client, error) { if awsAccessKeyId == "" { - awsAccessKeyId = os.Getenv("BEDROCK_ACCESS_KEY_ID") + awsAccessKeyId = os.Getenv(bedrockAccessKeyId) } if awsAccessKeyId == "" { - return nil, fmt.Errorf("Missing credentials. Please pass `aws_access_key_id`, or configure the BEDROCK_ACCESS_KEY_ID environment variable in the Milvus service.") + return nil, fmt.Errorf("Missing credentials. Please pass `aws_access_key_id`, or configure the %s environment variable in the Milvus service.", bedrockAccessKeyId) } if awsSecretAccessKey == "" { - awsSecretAccessKey = os.Getenv("BEDROCK_SECRET_ACCESS_KEY") + awsSecretAccessKey = os.Getenv(bedrockSecretAccessKey) } if awsSecretAccessKey == "" { - return nil, fmt.Errorf("Missing credentials. Please pass `aws_secret_access_key`, or configure the BEDROCK_SECRET_ACCESS_KEY environment variable in the Milvus service.") + return nil, fmt.Errorf("Missing credentials. Please pass `aws_secret_access_key`, or configure the %s environment variable in the Milvus service.", bedrockSecretAccessKey) } if region == "" { return nil, fmt.Errorf("Missing region. 
Please pass `region` param.") @@ -154,7 +154,7 @@ func (provider *BedrockEmbeddingProvider) FieldDim() int64 { return 5 * provider.fieldDim } -func (provider *BedrockEmbeddingProvider) CallEmbedding(texts []string, batchLimit bool) ([][]float32, error) { +func (provider *BedrockEmbeddingProvider) CallEmbedding(texts []string, batchLimit bool, _ string) ([][]float32, error) { numRows := len(texts) if batchLimit && numRows > provider.MaxBatch() { return nil, fmt.Errorf("Bedrock text embedding supports up to [%d] pieces of data at a time, got [%d]", provider.MaxBatch(), numRows) diff --git a/internal/util/function/bedrock_text_embedding_provider_test.go b/internal/util/function/bedrock_text_embedding_provider_test.go index 8ba9b7763167f..9d74f7e2604cc 100644 --- a/internal/util/function/bedrock_text_embedding_provider_test.go +++ b/internal/util/function/bedrock_text_embedding_provider_test.go @@ -79,7 +79,7 @@ func (s *BedrockTextEmbeddingProviderSuite) TestEmbedding() { s.NoError(err) { data := []string{"sentence"} - ret, err2 := provder.CallEmbedding(data, false) + ret, err2 := provder.CallEmbedding(data, false, InsertMode) s.NoError(err2) s.Equal(1, len(ret)) s.Equal(4, len(ret[0])) @@ -87,7 +87,7 @@ func (s *BedrockTextEmbeddingProviderSuite) TestEmbedding() { } { data := []string{"sentence 1", "sentence 2", "sentence 3"} - ret, _ := provder.CallEmbedding(data, false) + ret, _ := provder.CallEmbedding(data, false, SearchMode) s.Equal([][]float32{{0.0, 0.1, 0.2, 0.3}, {0.0, 0.1, 0.2, 0.3}, {0.0, 0.1, 0.2, 0.3}}, ret) } @@ -101,7 +101,7 @@ func (s *BedrockTextEmbeddingProviderSuite) TestEmbeddingDimNotMatch() { // embedding dim not match data := []string{"sentence", "sentence"} - _, err2 := provder.CallEmbedding(data, false) + _, err2 := provder.CallEmbedding(data, false, InsertMode) s.Error(err2) } diff --git a/internal/util/function/common.go b/internal/util/function/common.go index 5063d895283ff..56da30e5ed42f 100644 --- a/internal/util/function/common.go +++ 
b/internal/util/function/common.go @@ -18,6 +18,11 @@ package function +const ( + InsertMode string = "Insert" + SearchMode string = "Search" +) + // common params const ( modelNameParamKey string = "model_name" @@ -30,7 +35,9 @@ const ( const ( TextEmbeddingV1 string = "text-embedding-v1" TextEmbeddingV2 string = "text-embedding-v2" - TextEmbeddingV3 string = "text-embedding-v1" + TextEmbeddingV3 string = "text-embedding-v3" + + dashscopeApiKey string = "MILVUS_DASHSCOPE_API_KEY" ) // openai/azure text embedding @@ -39,9 +46,12 @@ const ( TextEmbeddingAda002 string = "text-embedding-ada-002" TextEmbedding3Small string = "text-embedding-3-small" TextEmbedding3Large string = "text-embedding-3-large" -) -const ( + openaiApiKey string = "MILVUSAI_OPENAI_API_KEY" + + azureOpenaiApiKey string = "MILVUSAI_AZURE_OPENAI_API_KEY" + azureOpenaiEndpoint string = "MILVUSAI_AZURE_OPENAI_ENDPOINT" + userParamKey string = "user" ) @@ -53,4 +63,20 @@ const ( awsSecretAccessKeyParamKey string = "aws_secret_access_key" regionParamKey string = "regin" normalizeParamKey string = "normalize" + + bedrockAccessKeyId string = "MILVUSAI_BEDROCK_ACCESS_KEY_ID" + bedrockSecretAccessKey string = "MILVUSAI_BEDROCK_SECRET_ACCESS_KEY" +) + +// vertexAI + +const ( + locationParamKey string = "location" + projectIDParamKey string = "projectid" + taskTypeParamKey string = "task" + + textEmbedding005 string = "text-embedding-005" + textMultilingualEmbedding002 string = "text-multilingual-embedding-002" + + vertexServiceAccountJSONEnv string = "MILVUSAI_GOOGLE_APPLICATION_CREDENTIALS" ) diff --git a/internal/util/function/function_base.go b/internal/util/function/function_base.go index ac501a96dcc95..209b4788180dc 100644 --- a/internal/util/function/function_base.go +++ b/internal/util/function/function_base.go @@ -29,10 +29,10 @@ type FunctionBase struct { outputFields []*schemapb.FieldSchema } -func NewFunctionBase(coll *schemapb.CollectionSchema, schema *schemapb.FunctionSchema) (*FunctionBase, 
error) { +func NewFunctionBase(coll *schemapb.CollectionSchema, f_schema *schemapb.FunctionSchema) (*FunctionBase, error) { var base FunctionBase - base.schema = schema - for _, field_id := range schema.GetOutputFieldIds() { + base.schema = f_schema + for _, field_id := range f_schema.GetOutputFieldIds() { for _, field := range coll.GetFields() { if field.GetFieldID() == field_id { base.outputFields = append(base.outputFields, field) @@ -41,9 +41,9 @@ func NewFunctionBase(coll *schemapb.CollectionSchema, schema *schemapb.FunctionS } } - if len(base.outputFields) != len(schema.GetOutputFieldIds()) { + if len(base.outputFields) != len(f_schema.GetOutputFieldIds()) { return &base, fmt.Errorf("The collection [%s]'s information is wrong, function [%s]'s outputs does not match the schema", - coll.Name, schema.Name) + coll.Name, f_schema.Name) } return &base, nil } diff --git a/internal/util/function/function_executor.go b/internal/util/function/function_executor.go index 221826e0b7378..6f2469cca9173 100644 --- a/internal/util/function/function_executor.go +++ b/internal/util/function/function_executor.go @@ -61,6 +61,15 @@ func createFunction(coll *schemapb.CollectionSchema, schema *schemapb.FunctionSc } } +func CheckFunctions(schema *schemapb.CollectionSchema) error { + for _, f_schema := range schema.Functions { + if _, err := createFunction(schema, f_schema); err != nil { + return err + } + } + return nil +} + func NewFunctionExecutor(schema *schemapb.CollectionSchema) (*FunctionExecutor, error) { // If the function's outputs exists in outputIDs, then create the function // when outputIDs is empty, create all functions diff --git a/internal/util/function/mock_embedding_service.go b/internal/util/function/mock_embedding_service.go index f48315c7eff41..4cb181a7a0c4f 100644 --- a/internal/util/function/mock_embedding_service.go +++ b/internal/util/function/mock_embedding_service.go @@ -28,6 +28,7 @@ import ( "github.com/aws/aws-sdk-go-v2/service/bedrockruntime" 
"github.com/milvus-io/milvus/internal/models/ali" "github.com/milvus-io/milvus/internal/models/openai" + "github.com/milvus-io/milvus/internal/models/vertexai" ) func mockEmbedding(texts []string, dim int) [][]float32 { @@ -94,6 +95,40 @@ func CreateAliEmbeddingServer() *httptest.Server { w.WriteHeader(http.StatusOK) data, _ := json.Marshal(res) w.Write(data) + })) + return ts +} + +func CreateVertexAIEmbeddingServer() *httptest.Server { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var req vertexai.EmbeddingRequest + body, _ := io.ReadAll(r.Body) + defer r.Body.Close() + json.Unmarshal(body, &req) + var texts []string + for _, item := range req.Instances { + texts = append(texts, item.Content) + } + embs := mockEmbedding(texts, int(req.Parameters.OutputDimensionality)) + var res vertexai.EmbeddingResponse + for i := 0; i < len(req.Instances); i++ { + res.Predictions = append(res.Predictions, vertexai.Prediction{ + Embeddings: vertexai.Embeddings{ + Statistics: vertexai.Statistics{ + Truncated: false, + TokenCount: 10, + }, + Values: embs[i], + }, + }) + } + + res.Metadata = vertexai.Metadata{ + BillableCharacterCount: 100, + } + w.WriteHeader(http.StatusOK) + data, _ := json.Marshal(res) + w.Write(data) })) return ts diff --git a/internal/util/function/openai_embedding_provider.go b/internal/util/function/openai_embedding_provider.go index f9fa78f64c8b6..32cfb945509f7 100644 --- a/internal/util/function/openai_embedding_provider.go +++ b/internal/util/function/openai_embedding_provider.go @@ -44,18 +44,15 @@ type OpenAIEmbeddingProvider struct { func createOpenAIEmbeddingClient(apiKey string, url string) (*openai.OpenAIEmbeddingClient, error) { if apiKey == "" { - apiKey = os.Getenv("OPENAI_API_KEY") + apiKey = os.Getenv(openaiApiKey) } if apiKey == "" { - return nil, fmt.Errorf("Missing credentials. 
Please pass `api_key`, or configure the OPENAI_API_KEY environment variable in the Milvus service.") + return nil, fmt.Errorf("Missing credentials. Please pass `api_key`, or configure the %s environment variable in the Milvus service.", openaiApiKey) } if url == "" { url = "https://api.openai.com/v1/embeddings" } - if url == "" { - return nil, fmt.Errorf("Must provide `url` arguments or configure the OPENAI_ENDPOINT environment variable in the Milvus service") - } c := openai.NewOpenAIEmbeddingClient(apiKey, url) return c, nil @@ -63,17 +60,17 @@ func createOpenAIEmbeddingClient(apiKey string, url string) (*openai.OpenAIEmbed func createAzureOpenAIEmbeddingClient(apiKey string, url string) (*openai.AzureOpenAIEmbeddingClient, error) { if apiKey == "" { - apiKey = os.Getenv("AZURE_OPENAI_API_KEY") + apiKey = os.Getenv(azureOpenaiApiKey) } if apiKey == "" { - return nil, fmt.Errorf("Missing credentials. Please pass `api_key`, or configure the AZURE_OPENAI_API_KEY environment variable in the Milvus service") + return nil, fmt.Errorf("Missing credentials. 
Please pass `api_key`, or configure the %s environment variable in the Milvus service", azureOpenaiApiKey) } if url == "" { - url = os.Getenv("AZURE_OPENAI_ENDPOINT") + url = os.Getenv(azureOpenaiEndpoint) } if url == "" { - return nil, fmt.Errorf("Must provide `url` arguments or configure the AZURE_OPENAI_ENDPOINT environment variable in the Milvus service") + return nil, fmt.Errorf("Must provide `url` arguments or configure the %s environment variable in the Milvus service", azureOpenaiEndpoint) } c := openai.NewAzureOpenAIEmbeddingClient(apiKey, url) return c, nil @@ -156,7 +153,7 @@ func (provider *OpenAIEmbeddingProvider) FieldDim() int64 { return provider.fieldDim } -func (provider *OpenAIEmbeddingProvider) CallEmbedding(texts []string, batchLimit bool) ([][]float32, error) { +func (provider *OpenAIEmbeddingProvider) CallEmbedding(texts []string, batchLimit bool, _ string) ([][]float32, error) { numRows := len(texts) if batchLimit && numRows > provider.MaxBatch() { return nil, fmt.Errorf("OpenAI embedding supports up to [%d] pieces of data at a time, got [%d]", provider.MaxBatch(), numRows) diff --git a/internal/util/function/openai_text_embedding_provider_test.go b/internal/util/function/openai_text_embedding_provider_test.go index 7681161f9c3e5..7c3667822956f 100644 --- a/internal/util/function/openai_text_embedding_provider_test.go +++ b/internal/util/function/openai_text_embedding_provider_test.go @@ -90,7 +90,7 @@ func (s *OpenAITextEmbeddingProviderSuite) TestEmbedding() { s.NoError(err) { data := []string{"sentence"} - ret, err2 := provder.CallEmbedding(data, false) + ret, err2 := provder.CallEmbedding(data, false, InsertMode) s.NoError(err2) s.Equal(1, len(ret)) s.Equal(4, len(ret[0])) @@ -98,7 +98,7 @@ func (s *OpenAITextEmbeddingProviderSuite) TestEmbedding() { } { data := []string{"sentence 1", "sentence 2", "sentence 3"} - ret, _ := provder.CallEmbedding(data, false) + ret, _ := provder.CallEmbedding(data, false, SearchMode) 
s.Equal([][]float32{{0.0, 0.1, 0.2, 0.3}, {1.0, 1.1, 1.2, 1.3}, {2.0, 2.1, 2.2, 2.3}}, ret) } @@ -137,7 +137,7 @@ func (s *OpenAITextEmbeddingProviderSuite) TestEmbeddingDimNotMatch() { // embedding dim not match data := []string{"sentence", "sentence"} - _, err2 := provder.CallEmbedding(data, false) + _, err2 := provder.CallEmbedding(data, false, InsertMode) s.Error(err2) } @@ -170,7 +170,7 @@ func (s *OpenAITextEmbeddingProviderSuite) TestEmbeddingNubmerNotMatch() { // embedding dim not match data := []string{"sentence", "sentence2"} - _, err2 := provder.CallEmbedding(data, false) + _, err2 := provder.CallEmbedding(data, false, InsertMode) s.Error(err2) } diff --git a/internal/util/function/text_embedding_function.go b/internal/util/function/text_embedding_function.go index 274ddde0b8ba9..8ddf4893dfc44 100644 --- a/internal/util/function/text_embedding_function.go +++ b/internal/util/function/text_embedding_function.go @@ -38,11 +38,13 @@ const ( AzureOpenAIProvider string = "azure_openai" AliDashScopeProvider string = "dashscope" BedrockProvider string = "bedrock" + VertexAIProvider string = "vertexai" ) +// Text embedding for retrieval task type TextEmbeddingProvider interface { MaxBatch() int - CallEmbedding(texts []string, batchLimit bool) ([][]float32, error) + CallEmbedding(texts []string, batchLimit bool, mode string) ([][]float32, error) FieldDim() int64 } @@ -120,6 +122,15 @@ func NewTextEmbeddingFunction(coll *schemapb.CollectionSchema, functionSchema *s FunctionBase: *base, embProvider: embP, }, nil + case VertexAIProvider: + embP, err := NewVertextAIEmbeddingProvider(base.outputFields[0], functionSchema, nil) + if err != nil { + return nil, err + } + return &TextEmebddingFunction{ + FunctionBase: *base, + embProvider: embP, + }, nil default: return nil, fmt.Errorf("Unsupported embedding service provider: [%s] , list of supported [%s, %s, %s, %s]", provider, OpenAIProvider, AzureOpenAIProvider, AliDashScopeProvider, BedrockProvider) } @@ -144,7 +155,7 
@@ func (runner *TextEmebddingFunction) ProcessInsert(inputs []*schemapb.FieldData) return nil, fmt.Errorf("Input texts is empty") } - embds, err := runner.embProvider.CallEmbedding(texts, true) + embds, err := runner.embProvider.CallEmbedding(texts, true, InsertMode) if err != nil { return nil, err } @@ -173,7 +184,7 @@ func (runner *TextEmebddingFunction) ProcessInsert(inputs []*schemapb.FieldData) func (runner *TextEmebddingFunction) ProcessSearch(placeholderGroup *commonpb.PlaceholderGroup) (*commonpb.PlaceholderGroup, error) { texts := funcutil.GetVarCharFromPlaceholder(placeholderGroup.Placeholders[0]) // Already checked externally - embds, err := runner.embProvider.CallEmbedding(texts, true) + embds, err := runner.embProvider.CallEmbedding(texts, true, SearchMode) if err != nil { return nil, err } @@ -194,7 +205,7 @@ func (runner *TextEmebddingFunction) ProcessBulkInsert(inputs []storage.FieldDat return nil, fmt.Errorf("Input texts is empty") } - embds, err := runner.embProvider.CallEmbedding(texts, false) + embds, err := runner.embProvider.CallEmbedding(texts, false, InsertMode) if err != nil { return nil, err } diff --git a/internal/util/function/vertexai_embedding_provider.go b/internal/util/function/vertexai_embedding_provider.go new file mode 100644 index 0000000000000..1d9c997571dcf --- /dev/null +++ b/internal/util/function/vertexai_embedding_provider.go @@ -0,0 +1,221 @@ +/* + * # Licensed to the LF AI & Data foundation under one + * # or more contributor license agreements. See the NOTICE file + * # distributed with this work for additional information + * # regarding copyright ownership. The ASF licenses this file + * # to you under the Apache License, Version 2.0 (the + * # "License"); you may not use this file except in compliance + * # with the License. 
You may obtain a copy of the License at
+ * #
+ * # http://www.apache.org/licenses/LICENSE-2.0
+ * #
+ * # Unless required by applicable law or agreed to in writing, software
+ * # distributed under the License is distributed on an "AS IS" BASIS,
+ * # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * # See the License for the specific language governing permissions and
+ * # limitations under the License.
+ */
+
+package function
+
+import (
+	"fmt"
+	"os"
+	"strconv"
+	"strings"
+	"sync"
+	"time"
+
+	"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
+	"github.com/milvus-io/milvus/internal/models/vertexai"
+	"github.com/milvus-io/milvus/pkg/util/typeutil"
+)
+
+type vertexAIJsonKey struct {
+	jsonKey []byte
+	once    sync.Once
+	initErr error
+}
+
+var vtxKey vertexAIJsonKey
+
+func getVertexAIJsonKey() ([]byte, error) {
+	vtxKey.once.Do(func() {
+		jsonKeyPath := os.Getenv(vertexServiceAccountJSONEnv)
+		jsonKey, err := os.ReadFile(jsonKeyPath)
+		if err != nil {
+			vtxKey.initErr = fmt.Errorf("Read service account json file failed, %v", err)
+			return
+		}
+		vtxKey.jsonKey = jsonKey
+	})
+	return vtxKey.jsonKey, vtxKey.initErr
+}
+
+const (
+	vertexAIDocRetrival  string = "DOC_RETRIEVAL"
+	vertexAICodeRetrival string = "CODE_RETRIEVAL"
+	vertexAISTS          string = "STS"
+)
+
+func checkTask(modelName string, task string) error {
+	if task != vertexAIDocRetrival && task != vertexAICodeRetrival && task != vertexAISTS {
+		return fmt.Errorf("Unsupported task %s, the supported list: [%s, %s, %s]", task, vertexAIDocRetrival, vertexAICodeRetrival, vertexAISTS)
+	}
+	if modelName == textMultilingualEmbedding002 && task == vertexAICodeRetrival {
+		return fmt.Errorf("Model %s doesn't support %s task", textMultilingualEmbedding002, vertexAICodeRetrival)
+	}
+	return nil
+}
+
+type VertextAIEmbeddingProvider struct {
+	fieldDim int64
+
+	client        *vertexai.VertexAIEmbedding
+	modelName     string
+	embedDimParam int64
+	task          string
+
+	maxBatch   int
+	timeoutSec int
+}
+
+func createVertextAIEmbeddingClient(url string) (*vertexai.VertexAIEmbedding, error) {
+	jsonKey, err := getVertexAIJsonKey()
+	if err != nil {
+		return nil, err
+	}
+	c := vertexai.NewVertexAIEmbedding(url, jsonKey, "https://www.googleapis.com/auth/cloud-platform", "")
+	return c, nil
+}
+
+func NewVertextAIEmbeddingProvider(fieldSchema *schemapb.FieldSchema, functionSchema *schemapb.FunctionSchema, c *vertexai.VertexAIEmbedding) (*VertextAIEmbeddingProvider, error) {
+	fieldDim, err := typeutil.GetDim(fieldSchema)
+	if err != nil {
+		return nil, err
+	}
+	var location, projectID, task, modelName string
+	var dim int64
+
+	for _, param := range functionSchema.Params {
+		switch strings.ToLower(param.Key) {
+		case modelNameParamKey:
+			modelName = param.Value
+		case dimParamKey:
+			dim, err = strconv.ParseInt(param.Value, 10, 64)
+			if err != nil {
+				return nil, fmt.Errorf("dim [%s] is not int", param.Value)
+			}
+
+			if dim != 0 && dim != fieldDim {
+				return nil, fmt.Errorf("Field %s's dim is [%d], but embedding's dim is [%d]", functionSchema.Name, fieldDim, dim)
+			}
+		case locationParamKey:
+			location = param.Value
+		case projectIDParamKey:
+			projectID = param.Value
+		case taskTypeParamKey:
+			task = param.Value
+		default:
+		}
+	}
+
+	if task == "" {
+		task = vertexAIDocRetrival
+	}
+	if err := checkTask(modelName, task); err != nil {
+		return nil, err
+	}
+
+	if location == "" {
+		location = "us-central1"
+	}
+
+	if modelName != textEmbedding005 && modelName != textMultilingualEmbedding002 {
+		return nil, fmt.Errorf("Unsupported model: %s, only support [%s, %s]",
+			modelName, textEmbedding005, textMultilingualEmbedding002)
+	}
+	url := fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:predict", location, projectID, location, modelName)
+	var client *vertexai.VertexAIEmbedding
+	if c == nil {
+		client, err = createVertextAIEmbeddingClient(url)
+		if err != nil {
+			return nil, err
+		}
+	} else {
+		client = c
+	}
+
+	provider := VertextAIEmbeddingProvider{
+		fieldDim:      fieldDim,
+		client:        client,
+		modelName:     modelName,
+		embedDimParam: dim,
+		task:          task,
+		maxBatch:      128,
+		timeoutSec:    30,
+	}
+	return &provider, nil
+}
+
+func (provider *VertextAIEmbeddingProvider) MaxBatch() int {
+	return 5 * provider.maxBatch
+}
+
+func (provider *VertextAIEmbeddingProvider) FieldDim() int64 {
+	return provider.fieldDim
+}
+
+func (provider *VertextAIEmbeddingProvider) getTaskType(mode string) string {
+	if mode == SearchMode {
+		switch provider.task {
+		case vertexAIDocRetrival:
+			return "RETRIEVAL_QUERY"
+		case vertexAICodeRetrival:
+			return "CODE_RETRIEVAL_QUERY"
+		case vertexAISTS:
+			return "SEMANTIC_SIMILARITY"
+		}
+	} else {
+		switch provider.task {
+		case vertexAIDocRetrival:
+			return "RETRIEVAL_DOCUMENT"
+		case vertexAICodeRetrival:
+			return "RETRIEVAL_DOCUMENT"
+		case vertexAISTS:
+			return "SEMANTIC_SIMILARITY"
+		}
+	}
+	return ""
+}
+
+func (provider *VertextAIEmbeddingProvider) CallEmbedding(texts []string, batchLimit bool, mode string) ([][]float32, error) {
+	numRows := len(texts)
+	if batchLimit && numRows > provider.MaxBatch() {
+		return nil, fmt.Errorf("VertexAI text embedding supports up to [%d] pieces of data at a time, got [%d]", provider.MaxBatch(), numRows)
+	}
+
+	taskType := provider.getTaskType(mode)
+	data := make([][]float32, 0, numRows)
+	for i := 0; i < numRows; i += provider.maxBatch {
+		end := i + provider.maxBatch
+		if end > numRows {
+			end = numRows
+		}
+		resp, err := provider.client.Embedding(provider.modelName, texts[i:end], provider.embedDimParam, taskType, time.Duration(provider.timeoutSec))
+		if err != nil {
+			return nil, err
+		}
+		if end-i != len(resp.Predictions) {
+			return nil, fmt.Errorf("Get embedding failed. The number of texts and embeddings does not match text:[%d], embedding:[%d]", end-i, len(resp.Predictions))
+		}
+		for _, item := range resp.Predictions {
+			if len(item.Embeddings.Values) != int(provider.fieldDim) {
+				return nil, fmt.Errorf("The required embedding dim is [%d], but the embedding obtained from the model is [%d]",
+					provider.fieldDim, len(item.Embeddings.Values))
+			}
+			data = append(data, item.Embeddings.Values)
+		}
+	}
+	return data, nil
+}
diff --git a/internal/util/function/vertexai_embedding_provider_test.go b/internal/util/function/vertexai_embedding_provider_test.go
new file mode 100644
index 0000000000000..321a531a1ac0e
--- /dev/null
+++ b/internal/util/function/vertexai_embedding_provider_test.go
@@ -0,0 +1,170 @@
+/*
+ * # Licensed to the LF AI & Data foundation under one
+ * # or more contributor license agreements. See the NOTICE file
+ * # distributed with this work for additional information
+ * # regarding copyright ownership. The ASF licenses this file
+ * # to you under the Apache License, Version 2.0 (the
+ * # "License"); you may not use this file except in compliance
+ * # with the License. You may obtain a copy of the License at
+ * #
+ * # http://www.apache.org/licenses/LICENSE-2.0
+ * #
+ * # Unless required by applicable law or agreed to in writing, software
+ * # distributed under the License is distributed on an "AS IS" BASIS,
+ * # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * # See the License for the specific language governing permissions and
+ * # limitations under the License.
+ */ + +package function + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/suite" + + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/internal/models/vertexai" +) + +func TestVertextAITextEmbeddingProvider(t *testing.T) { + suite.Run(t, new(VertextAITextEmbeddingProviderSuite)) +} + +type VertextAITextEmbeddingProviderSuite struct { + suite.Suite + schema *schemapb.CollectionSchema + providers []string +} + +func (s *VertextAITextEmbeddingProviderSuite) SetupTest() { + s.schema = &schemapb.CollectionSchema{ + Name: "test", + Fields: []*schemapb.FieldSchema{ + {FieldID: 100, Name: "int64", DataType: schemapb.DataType_Int64}, + {FieldID: 101, Name: "text", DataType: schemapb.DataType_VarChar}, + {FieldID: 102, Name: "vector", DataType: schemapb.DataType_FloatVector, + TypeParams: []*commonpb.KeyValuePair{ + {Key: "dim", Value: "4"}, + }}, + }, + } +} + +func createVertextAIProvider(url string, schema *schemapb.FieldSchema) (TextEmbeddingProvider, error) { + functionSchema := &schemapb.FunctionSchema{ + Name: "test", + Type: schemapb.FunctionType_Unknown, + InputFieldIds: []int64{101}, + OutputFieldIds: []int64{102}, + Params: []*commonpb.KeyValuePair{ + {Key: modelNameParamKey, Value: textEmbedding005}, + {Key: locationParamKey, Value: "mock_local"}, + {Key: projectIDParamKey, Value: "mock_id"}, + {Key: taskTypeParamKey, Value: vertexAICodeRetrival}, + {Key: embeddingUrlParamKey, Value: url}, + {Key: dimParamKey, Value: "4"}, + }, + } + mockClient := vertexai.NewVertexAIEmbedding(url, []byte{1, 2, 3}, "mock scope", "mock token") + return NewVertextAIEmbeddingProvider(schema, functionSchema, mockClient) +} + +func (s *VertextAITextEmbeddingProviderSuite) TestEmbedding() { + ts := CreateVertexAIEmbeddingServer() + + defer ts.Close() + provder, err := createVertextAIProvider(ts.URL, s.schema.Fields[2]) + 
s.NoError(err) + { + data := []string{"sentence"} + ret, err2 := provder.CallEmbedding(data, false, InsertMode) + s.NoError(err2) + s.Equal(1, len(ret)) + s.Equal(4, len(ret[0])) + s.Equal([]float32{0.0, 0.1, 0.2, 0.3}, ret[0]) + } + { + data := []string{"sentence 1", "sentence 2", "sentence 3"} + ret, _ := provder.CallEmbedding(data, false, SearchMode) + s.Equal([][]float32{{0.0, 0.1, 0.2, 0.3}, {1.0, 1.1, 1.2, 1.3}, {2.0, 2.1, 2.2, 2.3}}, ret) + } + +} + +func (s *VertextAITextEmbeddingProviderSuite) TestEmbeddingDimNotMatch() { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var res vertexai.EmbeddingResponse + res.Predictions = append(res.Predictions, vertexai.Prediction{ + Embeddings: vertexai.Embeddings{ + Statistics: vertexai.Statistics{ + Truncated: false, + TokenCount: 10, + }, + Values: []float32{1.0, 1.0, 1.0, 1.0}, + }, + }) + res.Predictions = append(res.Predictions, vertexai.Prediction{ + Embeddings: vertexai.Embeddings{ + Statistics: vertexai.Statistics{ + Truncated: false, + TokenCount: 10, + }, + Values: []float32{1.0, 1.0}, + }, + }) + + res.Metadata = vertexai.Metadata{ + BillableCharacterCount: 100, + } + + w.WriteHeader(http.StatusOK) + data, _ := json.Marshal(res) + w.Write(data) + })) + + defer ts.Close() + provder, err := createVertextAIProvider(ts.URL, s.schema.Fields[2]) + s.NoError(err) + + // embedding dim not match + data := []string{"sentence", "sentence"} + _, err2 := provder.CallEmbedding(data, false, InsertMode) + s.Error(err2) +} + +func (s *VertextAITextEmbeddingProviderSuite) TestEmbeddingNubmerNotMatch() { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var res vertexai.EmbeddingResponse + res.Predictions = append(res.Predictions, vertexai.Prediction{ + Embeddings: vertexai.Embeddings{ + Statistics: vertexai.Statistics{ + Truncated: false, + TokenCount: 10, + }, + Values: []float32{1.0, 1.0, 1.0, 1.0}, + }, + }) + res.Metadata = 
vertexai.Metadata{
+		BillableCharacterCount: 100,
+	}
+
+	w.WriteHeader(http.StatusOK)
+	data, _ := json.Marshal(res)
+	w.Write(data)
+	}))
+
+	defer ts.Close()
+	provder, err := createVertextAIProvider(ts.URL, s.schema.Fields[2])
+
+	s.NoError(err)
+
+	// embedding number not match
+	data := []string{"sentence", "sentence2"}
+	_, err2 := provder.CallEmbedding(data, false, InsertMode)
+	s.Error(err2)
+}