diff --git a/go.mod b/go.mod index 5c8ff76d9e3cc..fdfc3e2bd43e9 100644 --- a/go.mod +++ b/go.mod @@ -59,6 +59,10 @@ require ( require ( cloud.google.com/go/storage v1.43.0 github.com/antlr4-go/antlr/v4 v4.13.1 + github.com/aws/aws-sdk-go-v2 v1.32.6 + github.com/aws/aws-sdk-go-v2/config v1.28.6 + github.com/aws/aws-sdk-go-v2/credentials v1.17.47 + github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.23.0 github.com/bits-and-blooms/bitset v1.10.0 github.com/bytedance/sonic v1.12.2 github.com/cenkalti/backoff/v4 v4.2.1 @@ -101,6 +105,17 @@ require ( github.com/apache/pulsar-client-go v0.6.1-0.20210728062540-29414db801a7 // indirect github.com/apache/thrift v0.18.1 // indirect github.com/ardielle/ardielle-go v1.5.2 // indirect + github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.7 // indirect + github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.21 // indirect + github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.25 // indirect + github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.25 // indirect + github.com/aws/aws-sdk-go-v2/internal/ini v1.8.1 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.1 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.6 // indirect + github.com/aws/aws-sdk-go-v2/service/sso v1.24.7 // indirect + github.com/aws/aws-sdk-go-v2/service/ssooidc v1.28.6 // indirect + github.com/aws/aws-sdk-go-v2/service/sts v1.33.2 // indirect + github.com/aws/smithy-go v1.22.1 // indirect github.com/benesch/cgosymbolizer v0.0.0-20190515212042-bec6fe6e597b // indirect github.com/beorn7/perks v1.0.1 // indirect github.com/bytedance/sonic/loader v0.2.0 // indirect diff --git a/go.sum b/go.sum index 228644055077f..52dd7ac9b51bc 100644 --- a/go.sum +++ b/go.sum @@ -120,6 +120,36 @@ github.com/armon/consul-api v0.0.0-20180202201655-eb2c6b5be1b6/go.mod h1:grANhF5 github.com/armon/go-metrics v0.0.0-20180917152333-f0300d1749da/go.mod h1:Q73ZrmVTwzkszR9V5SSuryQ31EELlFMUz1kKyl939pY= github.com/armon/go-radix v0.0.0-20180808171621-7fddfc383310/go.mod h1:ufUuZ+zHj4x4TnLV4JWEpy2hxWSpsRywHrMgIH9cCH8= github.com/aws/aws-sdk-go v1.32.6/go.mod h1:5zCpMtNQVjRREroY7sYe8lOMRSxkhG6MZveU8YkpAk0= +github.com/aws/aws-sdk-go-v2 v1.32.6 h1:7BokKRgRPuGmKkFMhEg/jSul+tB9VvXhcViILtfG8b4= +github.com/aws/aws-sdk-go-v2 v1.32.6/go.mod h1:P5WJBrYqqbWVaOxgH0X/FYYD47/nooaPOZPlQdmiN2U= +github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.7 h1:lL7IfaFzngfx0ZwUGOZdsFFnQ5uLvR0hWqqhyE7Q9M8= +github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.7/go.mod h1:QraP0UcVlQJsmHfioCrveWOC1nbiWUl3ej08h4mXWoc= +github.com/aws/aws-sdk-go-v2/config v1.28.6 h1:D89IKtGrs/I3QXOLNTH93NJYtDhm8SYa9Q5CsPShmyo= +github.com/aws/aws-sdk-go-v2/config v1.28.6/go.mod h1:GDzxJ5wyyFSCoLkS+UhGB0dArhb9mI+Co4dHtoTxbko= +github.com/aws/aws-sdk-go-v2/credentials v1.17.47 h1:48bA+3/fCdi2yAwVt+3COvmatZ6jUDNkDTIsqDiMUdw= +github.com/aws/aws-sdk-go-v2/credentials v1.17.47/go.mod h1:+KdckOejLW3Ks3b0E3b5rHsr2f9yuORBum0WPnE5o5w= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.21 h1:AmoU1pziydclFT/xRV+xXE/Vb8fttJCLRPv8oAkprc0= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.21/go.mod h1:AjUdLYe4Tgs6kpH4Bv7uMZo7pottoyHMn4eTcIcneaY= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.25 h1:s/fF4+yDQDoElYhfIVvSNyeCydfbuTKzhxSXDXCPasU= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.25/go.mod h1:IgPfDv5jqFIzQSNbUEMoitNooSMXjRSDkhXv8jiROvU= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.25 h1:ZntTCl5EsYnhN/IygQEUugpdwbhdkom9uHcbCftiGgA= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.25/go.mod h1:DBdPrgeocww+CSl1C8cEV8PN1mHMBhuCDLpXezyvWkE= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.1 h1:VaRN3TlFdd6KxX1x3ILT5ynH6HvKgqdiXoTxAF4HQcQ= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.1/go.mod h1:FbtygfRFze9usAadmnGJNc8KsP346kEe+y2/oyhGAGc= +github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.23.0 h1:mfV5tcLXeRLbiyI4EHoHWH1sIU7JvbfXVvymUCIgZEo= +github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.23.0/go.mod h1:YSSgYnasDKm5OjU3bOPkaz+2PFO6WjEQGIA6KQNsR3Q= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.1 h1:iXtILhvDxB6kPvEXgsDhGaZCSC6LQET5ZHSdJozeI0Y= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.1/go.mod h1:9nu0fVANtYiAePIBh2/pFUSwtJ402hLnp854CNoDOeE= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.6 h1:50+XsN70RS7dwJ2CkVNXzj7U2L1HKP8nqTd3XWEXBN4= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.6/go.mod h1:WqgLmwY7so32kG01zD8CPTJWVWM+TzJoOVHwTg4aPug= +github.com/aws/aws-sdk-go-v2/service/sso v1.24.7 h1:rLnYAfXQ3YAccocshIH5mzNNwZBkBo+bP6EhIxak6Hw= +github.com/aws/aws-sdk-go-v2/service/sso v1.24.7/go.mod h1:ZHtuQJ6t9A/+YDuxOLnbryAmITtr8UysSny3qcyvJTc= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.28.6 h1:JnhTZR3PiYDNKlXy50/pNeix9aGMo6lLpXwJ1mw8MD4= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.28.6/go.mod h1:URronUEGfXZN1VpdktPSD1EkAL9mfrV+2F4sjH38qOY= +github.com/aws/aws-sdk-go-v2/service/sts v1.33.2 h1:s4074ZO1Hk8qv65GqNXqDjmkf4HSQqJukaLuuW0TpDA= +github.com/aws/aws-sdk-go-v2/service/sts v1.33.2/go.mod h1:mVggCnIWoM09jP71Wh+ea7+5gAp53q+49wDFs1SW5z8= +github.com/aws/smithy-go v1.22.1 h1:/HPHZQ0g7f4eUeK6HKglFz8uwVfZKgoI25rb/J+dnro= +github.com/aws/smithy-go v1.22.1/go.mod h1:irrKGvNn1InZwb2d7fkIRNucdfwR8R+Ts3wxYa/cJHg= github.com/aymerick/raymond v2.0.3-0.20180322193309-b565731e1464+incompatible/go.mod h1:osfaiScAUVup+UC9Nfq76eWqDhXlp+4UYaA8uhTBO6g= github.com/benbjohnson/clock v1.1.0/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZxNJlLklBHA= github.com/benesch/cgosymbolizer v0.0.0-20190515212042-bec6fe6e597b h1:5JgaFtHFRnOPReItxvhMDXbvuBkjSWE+9glJyF466yw= diff --git a/internal/datanode/importv2/scheduler_test.go b/internal/datanode/importv2/scheduler_test.go index 7752c382187d1..a7c56a4b2561b 100644 --- a/internal/datanode/importv2/scheduler_test.go +++ b/internal/datanode/importv2/scheduler_test.go @@ -37,6 +37,7 @@ import ( "github.com/milvus-io/milvus/internal/proto/datapb" "github.com/milvus-io/milvus/internal/proto/internalpb" "github.com/milvus-io/milvus/internal/storage" + "github.com/milvus-io/milvus/internal/util/function" "github.com/milvus-io/milvus/internal/util/importutilv2" "github.com/milvus-io/milvus/internal/util/testutil" "github.com/milvus-io/milvus/pkg/common" @@ -435,6 +436,107 @@ func (s *SchedulerSuite) TestScheduler_ImportFile() { s.NoError(err) } +func (s *SchedulerSuite) TestScheduler_ImportFileWithFunction() { + s.syncMgr.EXPECT().SyncData(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, task syncmgr.Task, callbacks ...func(error) error) *conc.Future[struct{}] { + future := conc.Go(func() (struct{}, error) { + return struct{}{}, nil + }) + return future + }) + ts := function.CreateOpenAIEmbeddingServer() + defer ts.Close() + schema := &schemapb.CollectionSchema{ + Fields: []*schemapb.FieldSchema{ + { + FieldID: 100, + Name: "pk", + IsPrimaryKey: true, + DataType: schemapb.DataType_VarChar, + TypeParams: []*commonpb.KeyValuePair{ + {Key: common.MaxLengthKey, Value: "128"}, + }, + }, + { + FieldID: 101, + Name: "vec", + DataType: schemapb.DataType_FloatVector, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: common.DimKey, + Value: "4", + }, + }, + }, + { + FieldID: 102, + Name: "int64", + DataType: schemapb.DataType_Int64, + }, + }, + Functions: []*schemapb.FunctionSchema{ + { + Name: "test", + Type: schemapb.FunctionType_TextEmbedding, + InputFieldIds: []int64{100}, + InputFieldNames: []string{"text"}, + OutputFieldIds: []int64{101}, + OutputFieldNames: []string{"vec"}, + Params: []*commonpb.KeyValuePair{ + {Key: function.Provider, Value: function.OpenAIProvider}, + {Key: "model_name", Value: "text-embedding-ada-002"}, + {Key: "api_key", Value: "mock"}, + {Key: "url", Value: ts.URL}, + {Key: "dim", Value: "4"}, + }, + }, + }, + } + + var once sync.Once + data, err := testutil.CreateInsertData(schema, s.numRows) + s.NoError(err) + s.reader = importutilv2.NewMockReader(s.T()) + s.reader.EXPECT().Read().RunAndReturn(func() (*storage.InsertData, error) { + var res *storage.InsertData + once.Do(func() { + res = data + }) + if res != nil { + return res, nil + } + return nil, io.EOF + }) + importReq := &datapb.ImportRequest{ + JobID: 10, + TaskID: 11, + CollectionID: 12, + PartitionIDs: []int64{13}, + Vchannels: []string{"v0"}, + Schema: schema, + Files: []*internalpb.ImportFile{ + { + Paths: []string{"dummy.json"}, + }, + }, + Ts: 1000, + IDRange: &datapb.IDRange{ + Begin: 0, + End: int64(s.numRows), + }, + RequestSegments: []*datapb.ImportRequestSegment{ + { + SegmentID: 14, + PartitionID: 13, + Vchannel: "v0", + }, + }, + } + importTask := NewImportTask(importReq, s.manager, s.syncMgr, s.cm) + s.manager.Add(importTask) + err = importTask.(*ImportTask).importFile(s.reader) + s.NoError(err) +} + func TestScheduler(t *testing.T) { suite.Run(t, new(SchedulerSuite)) } diff --git a/internal/datanode/importv2/util.go b/internal/datanode/importv2/util.go index eb6c592f85b12..9e683a6f65549 100644 --- a/internal/datanode/importv2/util.go +++ b/internal/datanode/importv2/util.go @@ -208,12 +208,39 @@ func AppendSystemFieldsData(task *ImportTask, data *storage.InsertData) error { } func RunEmbeddingFunction(task *ImportTask, data *storage.InsertData) error { + if err := RunBm25Function(task, data); err != nil { + return err + } + if err := RunDenseEmbedding(task, data); err != nil { + return err + } + return nil +} + +func RunDenseEmbedding(task *ImportTask, data *storage.InsertData) error { + schema := task.GetSchema() + if function.HasFunctions(schema.Functions, []int64{}) { + exec, err := function.NewFunctionExecutor(schema) + if err != nil { + return err + } + if err := exec.ProcessBulkInsert(data); err != nil { + return err + } + } + return nil +} + +func RunBm25Function(task *ImportTask, data *storage.InsertData) error { fns := task.GetSchema().GetFunctions() for _, fn := range fns { runner, err := function.NewFunctionRunner(task.GetSchema(), fn) if err != nil { return err } + if runner == nil { + continue + } inputDatas := make([]any, 0, len(fn.InputFieldIds)) for _, inputFieldID := range fn.InputFieldIds { inputDatas = append(inputDatas, data.Data[inputFieldID].GetDataRows()) diff --git a/internal/distributed/proxy/httpserver/handler_v2.go b/internal/distributed/proxy/httpserver/handler_v2.go index dbf76a8bfe17c..8525effd00c02 100644 --- a/internal/distributed/proxy/httpserver/handler_v2.go +++ b/internal/distributed/proxy/httpserver/handler_v2.go @@ -947,8 +947,8 @@ func generatePlaceholderGroup(ctx context.Context, body string, collSchema *sche if vectorField.GetIsFunctionOutput() { for _, function := range collSchema.Functions { - if function.Type == schemapb.FunctionType_BM25 { - // TODO: currently only BM25 function is supported, thus guarantees one input field to one output field + if function.Type == schemapb.FunctionType_BM25 || function.Type == schemapb.FunctionType_TextEmbedding { + // TODO: currently only BM25 & text embedding function is supported, thus guarantees one input field to one output field if function.OutputFieldNames[0] == vectorField.Name { dataType = schemapb.DataType_VarChar } diff --git a/internal/flushcommon/pipeline/flow_graph_embedding_node.go b/internal/flushcommon/pipeline/flow_graph_embedding_node.go index bf809c49aeb7a..f264b8b39aa8e 100644 --- a/internal/flushcommon/pipeline/flow_graph_embedding_node.go +++ b/internal/flushcommon/pipeline/flow_graph_embedding_node.go @@ -67,6 +67,9 @@ func newEmbeddingNode(channelName string, schema *schemapb.CollectionSchema) (*e if err != nil { return nil, err } + if functionRunner == nil { + continue + } node.functionRunners[tf.GetId()] = functionRunner } return node, nil diff --git a/internal/models/ali/ali_dashscope_text_embedding.go b/internal/models/ali/ali_dashscope_text_embedding.go new file mode 100644 index 0000000000000..ee412c6e992f6 --- /dev/null +++ b/internal/models/ali/ali_dashscope_text_embedding.go @@ -0,0 +1,156 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ali + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "net/http" + "sort" + "time" + + "github.com/milvus-io/milvus/internal/models/utils" +) + +type Input struct { + Texts []string `json:"texts"` +} + +type Parameters struct { + TextType string `json:"text_type,omitempty"` + Dimension int `json:"dimension,omitempty"` + OutputType string `json:"output_type,omitempty"` +} + +type EmbeddingRequest struct { + // ID of the model to use. + Model string `json:"model"` + + // Input text to embed, encoded as a string. + Input Input `json:"input"` + + Parameters Parameters `json:"parameters,omitempty"` +} + +type Usage struct { + // The total number of tokens used by the request. + TotalTokens int `json:"total_tokens"` +} + +type SparseEmbedding struct { + Index int `json:"index"` + Value float32 `json:"value"` + Token string `json:"token"` +} + +type Embeddings struct { + TextIndex int `json:"text_index"` + Embedding []float32 `json:"embedding,omitempty"` + SparseEmbedding []SparseEmbedding `json:"sparse_embedding,omitempty"` +} + +type Output struct { + Embeddings []Embeddings `json:"embeddings"` +} + +type EmbeddingResponse struct { + Output Output `json:"output"` + Usage Usage `json:"usage"` + RequestID string `json:"request_id"` +} + +type ByIndex struct { + resp *EmbeddingResponse +} + +func (eb *ByIndex) Len() int { return len(eb.resp.Output.Embeddings) } +func (eb *ByIndex) Swap(i, j int) { + eb.resp.Output.Embeddings[i], eb.resp.Output.Embeddings[j] = eb.resp.Output.Embeddings[j], eb.resp.Output.Embeddings[i] +} + +func (eb *ByIndex) Less(i, j int) bool { + return eb.resp.Output.Embeddings[i].TextIndex < eb.resp.Output.Embeddings[j].TextIndex +} + +type ErrorInfo struct { + Code string `json:"code"` + Message string `json:"message"` + RequestID string `json:"request_id"` +} + +type AliDashScopeEmbedding struct { + apiKey string + url string +} + +func NewAliDashScopeEmbeddingClient(apiKey string, url string) *AliDashScopeEmbedding { + return &AliDashScopeEmbedding{ + apiKey: apiKey, + url: url, + } +} + +func (c *AliDashScopeEmbedding) Check() error { + if c.apiKey == "" { + return fmt.Errorf("api key is empty") + } + + if c.url == "" { + return fmt.Errorf("url is empty") + } + return nil +} + +func (c *AliDashScopeEmbedding) Embedding(modelName string, texts []string, dim int, textType string, outputType string, timeoutSec time.Duration) (*EmbeddingResponse, error) { + var r EmbeddingRequest + r.Model = modelName + r.Input = Input{texts} + r.Parameters.Dimension = dim + r.Parameters.TextType = textType + r.Parameters.OutputType = outputType + data, err := json.Marshal(r) + if err != nil { + return nil, err + } + + if timeoutSec <= 0 { + timeoutSec = utils.DefaultTimeout + } + + ctx, cancel := context.WithTimeout(context.Background(), timeoutSec*time.Second) + defer cancel() + req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.url, bytes.NewBuffer(data)) + if err != nil { + return nil, err + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.apiKey)) + body, err := utils.RetrySend(req, 3) + if err != nil { + return nil, err + } + var res EmbeddingResponse + err = json.Unmarshal(body, &res) + if err != nil { + return nil, err + } + sort.Sort(&ByIndex{&res}) + return &res, err +} diff --git a/internal/models/ali/ali_dashscope_text_embedding_test.go b/internal/models/ali/ali_dashscope_text_embedding_test.go new file mode 100644 index 0000000000000..6fb7cd04e5ac4 --- /dev/null +++ b/internal/models/ali/ali_dashscope_text_embedding_test.go @@ -0,0 +1,116 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ali + +import ( + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestEmbeddingClientCheck(t *testing.T) { + { + c := NewAliDashScopeEmbeddingClient("", "mock_uri") + err := c.Check() + assert.True(t, err != nil) + fmt.Println(err) + } + + { + c := NewAliDashScopeEmbeddingClient("mock_key", "") + err := c.Check() + assert.True(t, err != nil) + fmt.Println(err) + } + + { + c := NewAliDashScopeEmbeddingClient("mock_key", "mock_uri") + err := c.Check() + assert.True(t, err == nil) + } +} + +func TestEmbeddingOK(t *testing.T) { + var res EmbeddingResponse + repStr := `{ +"output": { +"embeddings": [ +{ +"text_index": 1, +"embedding": [0.1] +}, +{ +"text_index": 0, +"embedding": [0.0] +}, +{ +"text_index": 2, +"embedding": [0.2] +} +] +}, +"usage": { +"total_tokens": 100 +}, +"request_id": "0000000000000" +}` + err := json.Unmarshal([]byte(repStr), &res) + assert.NoError(t, err) + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + data, _ := json.Marshal(res) + w.Write(data) + })) + + defer ts.Close() + url := ts.URL + + { + c := NewAliDashScopeEmbeddingClient("mock_key", url) + err := c.Check() + assert.True(t, err == nil) + ret, err := c.Embedding("text-embedding-v2", []string{"sentence"}, 0, "query", "dense", 0) + assert.True(t, err == nil) + assert.Equal(t, ret.Output.Embeddings[0].TextIndex, 0) + assert.Equal(t, ret.Output.Embeddings[1].TextIndex, 1) + assert.Equal(t, ret.Output.Embeddings[2].TextIndex, 2) + assert.Equal(t, ret.Output.Embeddings[0].Embedding, []float32{0.0}) + assert.Equal(t, ret.Output.Embeddings[1].Embedding, []float32{0.1}) + assert.Equal(t, ret.Output.Embeddings[2].Embedding, []float32{0.2}) + } +} + +func TestEmbeddingFailed(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusUnauthorized) + })) + + defer ts.Close() + url := ts.URL + + { + c := NewAliDashScopeEmbeddingClient("mock_key", url) + err := c.Check() + assert.True(t, err == nil) + _, err = c.Embedding("text-embedding-v2", []string{"sentence"}, 0, "query", "dense", 0) + assert.True(t, err != nil) + } +} diff --git a/internal/models/openai/openai_embedding.go b/internal/models/openai/openai_embedding.go new file mode 100644 index 0000000000000..433a95a2e32a7 --- /dev/null +++ b/internal/models/openai/openai_embedding.go @@ -0,0 +1,225 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package openai + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "net/http" + "net/url" + "sort" + "time" + + "github.com/milvus-io/milvus/internal/models/utils" +) + +type EmbeddingRequest struct { + // ID of the model to use. + Model string `json:"model"` + + // Input text to embed, encoded as a string. + Input []string `json:"input"` + + // A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse. + User string `json:"user,omitempty"` + + // The format to return the embeddings in. Can be either float or base64. + EncodingFormat string `json:"encoding_format,omitempty"` + + // The number of dimensions the resulting output embeddings should have. Only supported in text-embedding-3 and later models. + Dimensions int `json:"dimensions,omitempty"` +} + +type Usage struct { + // The number of tokens used by the prompt. + PromptTokens int `json:"prompt_tokens"` + + // The total number of tokens used by the request. + TotalTokens int `json:"total_tokens"` +} + +type EmbeddingData struct { + // The object type, which is always "embedding". + Object string `json:"object"` + + // The embedding vector, which is a list of floats. + Embedding []float32 `json:"embedding"` + + // The index of the embedding in the list of embeddings. + Index int `json:"index"` +} + +type EmbeddingResponse struct { + // The object type, which is always "list". + Object string `json:"object"` + + // The list of embeddings generated by the model. + Data []EmbeddingData `json:"data"` + + // The name of the model used to generate the embedding. + Model string `json:"model"` + + // The usage information for the request. + Usage Usage `json:"usage"` +} + +type ByIndex struct { + resp *EmbeddingResponse +} + +func (eb *ByIndex) Len() int { return len(eb.resp.Data) } +func (eb *ByIndex) Swap(i, j int) { + eb.resp.Data[i], eb.resp.Data[j] = eb.resp.Data[j], eb.resp.Data[i] +} +func (eb *ByIndex) Less(i, j int) bool { return eb.resp.Data[i].Index < eb.resp.Data[j].Index } + +type ErrorInfo struct { + Code string `json:"code"` + Message string `json:"message"` + Param string `json:"param,omitempty"` + Type string `json:"type"` +} + +type EmbedddingError struct { + Error ErrorInfo `json:"error"` +} + +type OpenAIEmbeddingInterface interface { + Check() error + Embedding(modelName string, texts []string, dim int, user string, timeoutSec time.Duration) (*EmbeddingResponse, error) +} + +type openAIBase struct { + apiKey string + url string +} + +func (c *openAIBase) Check() error { + if c.apiKey == "" { + return fmt.Errorf("api key is empty") + } + + if c.url == "" { + return fmt.Errorf("url is empty") + } + return nil +} + +func (c *openAIBase) genReq(modelName string, texts []string, dim int, user string) *EmbeddingRequest { + var r EmbeddingRequest + r.Model = modelName + r.Input = texts + r.EncodingFormat = "float" + if user != "" { + r.User = user + } + if dim != 0 { + r.Dimensions = dim + } + return &r +} + +func (c *openAIBase) embedding(url string, headers map[string]string, modelName string, texts []string, dim int, user string, timeoutSec time.Duration) (*EmbeddingResponse, error) { + r := c.genReq(modelName, texts, dim, user) + data, err := json.Marshal(r) + if err != nil { + return nil, err + } + + if timeoutSec <= 0 { + timeoutSec = utils.DefaultTimeout + } + + ctx, cancel := context.WithTimeout(context.Background(), timeoutSec*time.Second) + defer cancel() + req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewBuffer(data)) + if err != nil { + return nil, err + } + for key, value := range headers { + req.Header.Set(key, value) + } + body, err := utils.RetrySend(req, 3) + if err != nil { + return nil, err + } + + var res EmbeddingResponse + err = json.Unmarshal(body, &res) + if err != nil { + return nil, err + } + sort.Sort(&ByIndex{&res}) + return &res, err +} + +type OpenAIEmbeddingClient struct { + openAIBase +} + +func NewOpenAIEmbeddingClient(apiKey string, url string) *OpenAIEmbeddingClient { + return &OpenAIEmbeddingClient{ + openAIBase{ + apiKey: apiKey, + url: url, + }, + } +} + +func (c *OpenAIEmbeddingClient) Embedding(modelName string, texts []string, dim int, user string, timeoutSec time.Duration) (*EmbeddingResponse, error) { + headers := map[string]string{ + "Content-Type": "application/json", + "Authorization": fmt.Sprintf("Bearer %s", c.apiKey), + } + return c.embedding(c.url, headers, modelName, texts, dim, user, timeoutSec) +} + +type AzureOpenAIEmbeddingClient struct { + openAIBase + apiVersion string +} + +func NewAzureOpenAIEmbeddingClient(apiKey string, url string) *AzureOpenAIEmbeddingClient { + return &AzureOpenAIEmbeddingClient{ + openAIBase: openAIBase{ + apiKey: apiKey, + url: url, + }, + apiVersion: "2024-06-01", + } +} + +func (c *AzureOpenAIEmbeddingClient) Embedding(modelName string, texts []string, dim int, user string, timeoutSec time.Duration) (*EmbeddingResponse, error) { + base, err := url.Parse(c.url) + if err != nil { + return nil, err + } + path := fmt.Sprintf("/openai/deployments/%s/embeddings", modelName) + base.Path = path + params := url.Values{} + params.Add("api-version", c.apiVersion) + base.RawQuery = params.Encode() + url := base.String() + + headers := map[string]string{ + "Content-Type": "application/json", + "api-key": c.apiKey, + } + return c.embedding(url, headers, modelName, texts, dim, user, timeoutSec) +} diff --git a/internal/models/openai/openai_embedding_test.go b/internal/models/openai/openai_embedding_test.go new file mode 100644 index 0000000000000..87f44b4ea6308 --- /dev/null +++ b/internal/models/openai/openai_embedding_test.go @@ -0,0 +1,252 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package openai + +import ( + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestEmbeddingClientCheck(t *testing.T) { + { + c := NewOpenAIEmbeddingClient("", "mock_uri") + err := c.Check() + assert.True(t, err != nil) + fmt.Println(err) + } + + { + c := NewOpenAIEmbeddingClient("mock_key", "") + err := c.Check() + assert.True(t, err != nil) + fmt.Println(err) + } + + { + c := NewOpenAIEmbeddingClient("mock_key", "mock_uri") + err := c.Check() + assert.True(t, err == nil) + } +} + +func TestEmbeddingOK(t *testing.T) { + var res EmbeddingResponse + res.Object = "list" + res.Model = "text-embedding-3-small" + res.Data = []EmbeddingData{ + { + Object: "embedding", + Embedding: []float32{1.1, 2.2, 3.3, 4.4}, + Index: 1, + }, + { + Object: "embedding", + Embedding: []float32{1.1, 2.2, 3.3, 4.4}, + Index: 0, + }, + } + res.Usage = Usage{ + PromptTokens: 1, + TotalTokens: 100, + } + + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/" { + if r.Header["Authorization"][0] != "" { + w.WriteHeader(http.StatusOK) + } else { + w.WriteHeader(http.StatusBadRequest) + } + } else { + if r.Header["Api-Key"][0] != "" { + w.WriteHeader(http.StatusOK) + } else { + w.WriteHeader(http.StatusBadRequest) + } + } + + data, _ := json.Marshal(res) + w.Write(data) + })) + + defer ts.Close() + url := ts.URL + + { + c := NewOpenAIEmbeddingClient("mock_key", url) + err := c.Check() + assert.True(t, err == nil) + ret, err := c.Embedding("text-embedding-3-small", []string{"sentence"}, 0, "", 0) + assert.True(t, err == nil) + assert.Equal(t, ret.Data[0].Index, 0) + assert.Equal(t, ret.Data[1].Index, 1) + } + { + c := NewAzureOpenAIEmbeddingClient("mock_key", url) + err := c.Check() + assert.True(t, err == nil) + ret, err := c.Embedding("text-embedding-3-small", []string{"sentence"}, 0, "", 0) + assert.True(t, err == nil) + assert.Equal(t, ret.Data[0].Index, 0) + assert.Equal(t, ret.Data[1].Index, 1) + } +} + +func TestEmbeddingRetry(t *testing.T) { + var res EmbeddingResponse + res.Object = "list" + res.Model = "text-embedding-3-small" + res.Data = []EmbeddingData{ + { + Object: "embedding", + Embedding: []float32{1.1, 2.2, 3.2, 4.5}, + Index: 2, + }, + { + Object: "embedding", + Embedding: []float32{1.1, 2.2, 3.3, 4.4}, + Index: 0, + }, + { + Object: "embedding", + Embedding: []float32{1.1, 2.2, 3.2, 4.3}, + Index: 1, + }, + } + res.Usage = Usage{ + PromptTokens: 1, + TotalTokens: 100, + } + + var count int32 = 0 + + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if atomic.LoadInt32(&count) < 2 { + atomic.AddInt32(&count, 1) + w.WriteHeader(http.StatusUnauthorized) + } else { + w.WriteHeader(http.StatusOK) + data, _ := json.Marshal(res) + w.Write(data) + } + })) + + defer ts.Close() + url := ts.URL + + { + c := NewOpenAIEmbeddingClient("mock_key", url) + err := c.Check() + assert.True(t, err == nil) + ret, err := c.Embedding("text-embedding-3-small", []string{"sentence"}, 0, "", 0) + assert.True(t, err == nil) + assert.Equal(t, ret.Usage, res.Usage) + assert.Equal(t, ret.Object, res.Object) + assert.Equal(t, ret.Model, res.Model) + assert.Equal(t, ret.Data[0], res.Data[1]) + assert.Equal(t, ret.Data[1], res.Data[2]) + assert.Equal(t, ret.Data[2], res.Data[0]) + assert.Equal(t, atomic.LoadInt32(&count), int32(2)) + } + { + c := NewAzureOpenAIEmbeddingClient("mock_key", url) + err := c.Check() + assert.True(t, err == nil) + ret, err := c.Embedding("text-embedding-3-small", []string{"sentence"}, 0, "", 0) + assert.True(t, err == nil) + assert.Equal(t, ret.Usage, res.Usage) + assert.Equal(t, ret.Object, res.Object) + assert.Equal(t, ret.Model, res.Model) + assert.Equal(t, ret.Data[0], res.Data[1]) + assert.Equal(t, ret.Data[1], res.Data[2]) + assert.Equal(t, ret.Data[2], res.Data[0]) + assert.Equal(t, atomic.LoadInt32(&count), int32(2)) + } +} + +func TestEmbeddingFailed(t *testing.T) { + var count int32 = 0 + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + atomic.AddInt32(&count, 1) + w.WriteHeader(http.StatusUnauthorized) + })) + + defer ts.Close() + url := ts.URL + + { + atomic.StoreInt32(&count, 0) + c := NewOpenAIEmbeddingClient("mock_key", url) + err := c.Check() + assert.True(t, err == nil) + _, err = c.Embedding("text-embedding-3-small", []string{"sentence"}, 0, "", 0) + assert.True(t, err != nil) + assert.Equal(t, atomic.LoadInt32(&count), int32(3)) + } + { + atomic.StoreInt32(&count, 0) + c := NewAzureOpenAIEmbeddingClient("mock_key", url) + err := c.Check() + assert.True(t, err == nil) + _, err = c.Embedding("text-embedding-3-small", []string{"sentence"}, 0, "", 0) + assert.True(t, err != nil) + assert.Equal(t, atomic.LoadInt32(&count), int32(3)) + } +} + +func TestTimeout(t *testing.T) { + var st int32 = 0 + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + time.Sleep(3 * time.Second) + atomic.AddInt32(&st, 1) + w.WriteHeader(http.StatusUnauthorized) + })) + + defer ts.Close() + url := ts.URL + + { + atomic.StoreInt32(&st, 0) + c := NewOpenAIEmbeddingClient("mock_key", url) + err := c.Check() + assert.True(t, err == nil) + _, err = c.Embedding("text-embedding-3-small", []string{"sentence"}, 0, "", 1) + assert.True(t, err != nil) + assert.Equal(t, atomic.LoadInt32(&st), int32(0)) + time.Sleep(3 * time.Second) + assert.Equal(t, atomic.LoadInt32(&st), int32(1)) + } + + { + atomic.StoreInt32(&st, 0) + c := NewAzureOpenAIEmbeddingClient("mock_key", url) + err := c.Check() + assert.True(t, err == nil) + _, err = c.Embedding("text-embedding-3-small", []string{"sentence"}, 0, "", 1) + assert.True(t, err != nil) + assert.Equal(t, atomic.LoadInt32(&st), int32(0)) + time.Sleep(3 * time.Second) + assert.Equal(t, atomic.LoadInt32(&st), int32(1)) + } +} diff --git a/internal/models/utils/embedding_util.go b/internal/models/utils/embedding_util.go new file mode 100644 index 0000000000000..1383d5740e814 --- /dev/null +++ b/internal/models/utils/embedding_util.go @@ -0,0 +1,56 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package utils + +import ( + "fmt" + "io" + "net/http" + "time" +) + +const DefaultTimeout time.Duration = 30 + +func send(req *http.Request) ([]byte, error) { + resp, err := http.DefaultClient.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, err + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf(string(body)) + } + return body, nil +} + +func RetrySend(req *http.Request, maxRetries int) ([]byte, error) { + var err error + var res []byte + for i := 0; i < maxRetries; i++ { + res, err = send(req) + if err == nil { + return res, nil + } + } + return nil, err +} diff --git a/internal/models/vertexai/vertexai_text_embedding.go b/internal/models/vertexai/vertexai_text_embedding.go new file mode 100644 index 0000000000000..1a63c59961f81 --- /dev/null +++ b/internal/models/vertexai/vertexai_text_embedding.go @@ -0,0 +1,163 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package vertexai + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "net/http" + "time" + + "golang.org/x/oauth2/google" + + "github.com/milvus-io/milvus/internal/models/utils" +) + +type Instance struct { + TaskType string `json:"task_type,omitempty"` + Content string `json:"content"` +} + +type Parameters struct { + OutputDimensionality int64 `json:"outputDimensionality,omitempty"` +} + +type EmbeddingRequest struct { + Instances []Instance `json:"instances"` + Parameters Parameters `json:"parameters,omitempty"` +} + +type Statistics struct { + Truncated bool `json:"truncated"` + TokenCount int `json:"token_count"` +} + +type Embeddings struct { + Statistics Statistics `json:"statistics"` + Values []float32 `json:"values"` +} + +type Prediction struct { + Embeddings Embeddings `json:"embeddings"` +} + +type Metadata struct { + BillableCharacterCount int `json:"billableCharacterCount"` +} + +type EmbeddingResponse struct { + Predictions []Prediction `json:"predictions"` + Metadata Metadata `json:"metadata"` +} + +type ErrorInfo struct { + Code string `json:"code"` + Message string `json:"message"` + RequestID string `json:"request_id"` +} + +type VertexAIEmbedding struct { + url string + jsonKey []byte + scopes string + token string +} + +func NewVertexAIEmbedding(url string, jsonKey []byte, scopes string, token string) *VertexAIEmbedding { + return &VertexAIEmbedding{ + url: url, + jsonKey: jsonKey, + scopes: scopes, + token: token, + } +} + +func (c *VertexAIEmbedding) Check() error { + if c.url == "" { + return fmt.Errorf("VertexAI embedding url is empty") + } + if len(c.jsonKey) == 0 { + return fmt.Errorf("jsonKey is empty") + } + if c.scopes == "" { + return fmt.Errorf("Scopes param is empty") + } + return nil +} + +func (c *VertexAIEmbedding) getAccessToken() (string, error) { + ctx := context.Background() + creds, err := google.CredentialsFromJSON(ctx, c.jsonKey, c.scopes) + if err != nil { + return "", fmt.Errorf("Failed to find credentials: %v", err) + } + token, err := creds.TokenSource.Token() + if err != nil { + return "", fmt.Errorf("Failed to get token: %v", err) + } + return token.AccessToken, nil +} + +func (c *VertexAIEmbedding) Embedding(modelName string, texts []string, dim int64, taskType string, timeoutSec time.Duration) (*EmbeddingResponse, error) { + var r EmbeddingRequest + for _, text := range texts { + r.Instances = append(r.Instances, Instance{TaskType: taskType, Content: text}) + } + if dim != 0 { + r.Parameters.OutputDimensionality = dim + } + + data, err := json.Marshal(r) + if err != nil { + return nil, err + } + + if timeoutSec <= 0 { + timeoutSec = utils.DefaultTimeout + } + + ctx, cancel := context.WithTimeout(context.Background(), timeoutSec*time.Second) + defer cancel() + req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.url, bytes.NewBuffer(data)) + if err != nil { + return nil, err + } + var token string + if c.token != "" { + token = c.token + } else { + token, err = c.getAccessToken() + if err != nil { + return nil, err + } + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token)) + body, err := utils.RetrySend(req, 3) + if err != nil { + return nil, err + } + var res EmbeddingResponse + err = json.Unmarshal(body, &res) + if err != nil { + return nil, err + } + return &res, err +} diff --git a/internal/models/vertexai/vertexai_text_embedding_test.go b/internal/models/vertexai/vertexai_text_embedding_test.go new file mode 100644 index 0000000000000..83b26ac4d4634 --- /dev/null +++ b/internal/models/vertexai/vertexai_text_embedding_test.go @@ -0,0 +1,90 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package vertexai + +import ( + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestEmbeddingClientCheck(t *testing.T) { + mockJSONKey := []byte{1, 2, 3} + { + c := NewVertexAIEmbedding("mock_url", []byte{}, "mock_scopes", "") + err := c.Check() + assert.True(t, err != nil) + fmt.Println(err) + } + + { + c := NewVertexAIEmbedding("", mockJSONKey, "", "") + err := c.Check() + assert.True(t, err != nil) + fmt.Println(err) + } + + { + c := NewVertexAIEmbedding("mock_url", mockJSONKey, "mock_scopes", "") + err := c.Check() + assert.True(t, err == nil) + } +} + +func TestEmbeddingOK(t *testing.T) { + var res EmbeddingResponse + repStr := `{"predictions": [{"embeddings": {"statistics": {"truncated": false, "token_count": 4}, "values": [-0.028420744463801384, 0.037183016538619995]}}, {"embeddings": {"statistics": {"truncated": false, "token_count": 8}, "values": [-0.04367655888199806, 0.03777721896767616, 0.0158217903226614]}}], "metadata": {"billableCharacterCount": 27}}` + err := json.Unmarshal([]byte(repStr), &res) + assert.NoError(t, err) + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + data, _ := json.Marshal(res) + w.Write(data) + })) + + defer ts.Close() + url := ts.URL + + { + c := NewVertexAIEmbedding(url, []byte{1, 2, 3}, "mock_scopes", "mock_token") + err := c.Check() + assert.True(t, err == nil) + _, err = c.Embedding("text-embedding-005", []string{"sentence"}, 0, "query", 0) + assert.True(t, err == nil) + } +} + +func TestEmbeddingFailed(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusUnauthorized) + })) + + defer ts.Close() + url := ts.URL + + { + c := NewVertexAIEmbedding(url, []byte{1, 2, 3}, "mock_scopes", "mock_token") + err := c.Check() + assert.True(t, err == nil) + _, err = c.Embedding("text-embedding-v2", []string{"sentence"}, 0, "query", 0) + assert.True(t, err != nil) + } +} diff --git a/internal/proxy/task_insert.go b/internal/proxy/task_insert.go index 9de31cd53d600..37f0a58b84aa4 100644 --- a/internal/proxy/task_insert.go +++ b/internal/proxy/task_insert.go @@ -12,6 +12,7 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus/internal/allocator" + "github.com/milvus-io/milvus/internal/util/function" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/metrics" "github.com/milvus-io/milvus/pkg/mq/msgstream" @@ -141,6 +142,16 @@ func (it *insertTask) PreExecute(ctx context.Context) error { } it.schema = schema.CollectionSchema + // Calculate embedding fields + if function.HasFunctions(schema.CollectionSchema.Functions, []int64{}) { + exec, err := function.NewFunctionExecutor(schema.CollectionSchema) + if err != nil { + return err + } + if err := exec.ProcessInsert(it.insertMsg); err != nil { + return err + } + } rowNums := uint32(it.insertMsg.NRows()) // set insertTask.rowIDs var rowIDBegin UniqueID diff --git a/internal/proxy/task_insert_test.go b/internal/proxy/task_insert_test.go index cdf90290567ce..006d383be3146 100644 --- a/internal/proxy/task_insert_test.go +++ b/internal/proxy/task_insert_test.go @@ -10,6 +10,10 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/msgpb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/internal/allocator" + "github.com/milvus-io/milvus/internal/mocks" + "github.com/milvus-io/milvus/internal/proto/rootcoordpb" + "github.com/milvus-io/milvus/internal/util/function" "github.com/milvus-io/milvus/pkg/mq/msgstream" "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/paramtable" @@ -308,3 +312,128 @@ func TestMaxInsertSize(t *testing.T) { assert.ErrorIs(t, err, merr.ErrParameterTooLarge) }) } + +func TestInsertTask_Function(t *testing.T) { + ts := function.CreateOpenAIEmbeddingServer() + defer ts.Close() + data := []*schemapb.FieldData{} + f := schemapb.FieldData{ + Type: schemapb.DataType_VarChar, + FieldId: 101, + FieldName: "text", + IsDynamic: false, + Field: &schemapb.FieldData_Scalars{ + Scalars: &schemapb.ScalarField{ + Data: &schemapb.ScalarField_StringData{ + StringData: &schemapb.StringArray{ + Data: []string{"sentence", "sentence"}, + }, + }, + }, + }, + } + data = append(data, &f) + collectionName := "TestInsertTask_function" + schema := &schemapb.CollectionSchema{ + Name: collectionName, + Description: "TestInsertTask_function", + AutoID: true, + Fields: []*schemapb.FieldSchema{ + {FieldID: 100, Name: "id", DataType: schemapb.DataType_Int64, IsPrimaryKey: true, AutoID: true}, + { + FieldID: 101, Name: "text", DataType: schemapb.DataType_VarChar, + TypeParams: []*commonpb.KeyValuePair{ + {Key: "max_length", Value: "200"}, + }, + }, + { + FieldID: 102, Name: "vector", DataType: schemapb.DataType_FloatVector, + TypeParams: []*commonpb.KeyValuePair{ + {Key: "dim", Value: "4"}, + }, + IsFunctionOutput: true, + }, + }, + Functions: []*schemapb.FunctionSchema{ + { + Name: "test_function", + Type: schemapb.FunctionType_TextEmbedding, + InputFieldIds: []int64{101}, + InputFieldNames: []string{"text"}, + OutputFieldIds: []int64{102}, + OutputFieldNames: []string{"vector"}, + Params: []*commonpb.KeyValuePair{ + {Key: function.Provider, Value: function.OpenAIProvider}, + {Key: "model_name", Value: "text-embedding-ada-002"}, + {Key: "api_key", Value: "mock"}, + {Key: "url", Value: ts.URL}, + {Key: "dim", Value: "4"}, + }, + }, + }, + } + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + rc := mocks.NewMockRootCoordClient(t) + rc.EXPECT().AllocID(mock.Anything, mock.Anything).Return(&rootcoordpb.AllocIDResponse{ + Status: merr.Status(nil), + ID: 11198, + Count: 10, + }, nil) + idAllocator, err := allocator.NewIDAllocator(ctx, rc, 0) + idAllocator.Start() + defer idAllocator.Close() + assert.NoError(t, err) + task := insertTask{ + ctx: context.Background(), + insertMsg: &BaseInsertTask{ + InsertRequest: &msgpb.InsertRequest{ + CollectionName: collectionName, + DbName: "hooooooo", + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_Insert, + }, + Version: msgpb.InsertDataVersion_ColumnBased, + FieldsData: data, + NumRows: 2, + }, + }, + schema: schema, + idAllocator: idAllocator, + } + + info := newSchemaInfo(schema) + cache := NewMockCache(t) + cache.On("GetCollectionSchema", + mock.Anything, // context.Context + mock.AnythingOfType("string"), + mock.AnythingOfType("string"), + ).Return(info, nil) + + cache.On("GetPartitionInfo", + mock.Anything, // context.Context + mock.AnythingOfType("string"), + mock.AnythingOfType("string"), + mock.AnythingOfType("string"), + ).Return(&partitionInfo{ + name: "p1", + partitionID: 10, + createdTimestamp: 10001, + createdUtcTimestamp: 10002, + }, nil) + cache.On("GetCollectionInfo", + mock.Anything, + mock.Anything, + mock.Anything, + mock.Anything, + ).Return(&collectionInfo{schema: info}, nil) + cache.On("GetDatabaseInfo", + mock.Anything, + mock.Anything, + ).Return(&databaseInfo{properties: []*commonpb.KeyValuePair{}}, nil) + + globalMetaCache = cache + err = task.PreExecute(ctx) + assert.NoError(t, err) +} diff --git a/internal/proxy/task_search.go b/internal/proxy/task_search.go index ffa7c9b23b8ee..0b5988961c84f 100644 --- a/internal/proxy/task_search.go +++ b/internal/proxy/task_search.go @@ -22,6 +22,7 @@ import ( "github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/internal/types" "github.com/milvus-io/milvus/internal/util/exprutil" + "github.com/milvus-io/milvus/internal/util/function" "github.com/milvus-io/milvus/internal/util/reduce" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/metrics" @@ -419,6 +420,17 @@ func (t *searchTask) initAdvancedSearchRequest(ctx context.Context) error { zap.Stringer("plan", plan)) // may be very large if large term passed. } + var err error + if function.HasFunctions(t.schema.CollectionSchema.Functions, []int64{}) { + exec, err := function.NewFunctionExecutor(t.schema.CollectionSchema) + if err != nil { + return err + } + if err := exec.ProcessSearch(t.SearchRequest); err != nil { + return err + } + } + t.SearchRequest.GroupByFieldId = t.rankParams.GetGroupByFieldId() t.SearchRequest.GroupSize = t.rankParams.GetGroupSize() @@ -426,7 +438,7 @@ func (t *searchTask) initAdvancedSearchRequest(ctx context.Context) error { if t.partitionKeyMode { t.SearchRequest.PartitionIDs = t.partitionIDsSet.Collect() } - var err error + t.reScorers, err = NewReScorers(ctx, len(t.request.GetSubReqs()), t.request.GetSearchParams()) if err != nil { log.Info("generate reScorer failed", zap.Any("params", t.request.GetSearchParams()), zap.Error(err)) @@ -497,6 +509,16 @@ func (t *searchTask) initSearchRequest(ctx context.Context) error { t.SearchRequest.DslType = commonpb.DslType_BoolExprV1 t.SearchRequest.GroupByFieldId = queryInfo.GroupByFieldId t.SearchRequest.GroupSize = queryInfo.GroupSize + + if function.HasFunctions(t.schema.CollectionSchema.Functions, []int64{queryInfo.GetQueryFieldId()}) { + exec, err := function.NewFunctionExecutor(t.schema.CollectionSchema) + if err != nil { + return err + } + if err := exec.ProcessSearch(t.SearchRequest); err != nil { + return err + } + } log.Debug("proxy init search request", zap.Int64s("plan.OutputFieldIds", plan.GetOutputFieldIds()), zap.Stringer("plan", plan)) // may be very large if large term passed. diff --git a/internal/proxy/task_search_test.go b/internal/proxy/task_search_test.go index 1edf764c8b418..b255027cf7316 100644 --- a/internal/proxy/task_search_test.go +++ b/internal/proxy/task_search_test.go @@ -26,6 +26,7 @@ import ( "github.com/cockroachdb/errors" "github.com/google/uuid" + "github.com/samber/lo" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" @@ -42,6 +43,7 @@ import ( "github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/internal/types" "github.com/milvus-io/milvus/internal/util/dependency" + "github.com/milvus-io/milvus/internal/util/function" "github.com/milvus-io/milvus/internal/util/reduce" "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/util/funcutil" @@ -475,6 +477,232 @@ func TestSearchTask_PreExecute(t *testing.T) { }) } +func TestSearchTask_WithFunctions(t *testing.T) { + ts := function.CreateOpenAIEmbeddingServer() + defer ts.Close() + collectionName := "TestSearchTask_function" + schema := &schemapb.CollectionSchema{ + Name: collectionName, + Description: "TestSearchTask_function", + AutoID: true, + Fields: []*schemapb.FieldSchema{ + {FieldID: 100, Name: "id", DataType: schemapb.DataType_Int64, IsPrimaryKey: true, AutoID: true}, + { + FieldID: 101, Name: "text", DataType: schemapb.DataType_VarChar, + TypeParams: []*commonpb.KeyValuePair{ + {Key: "max_length", Value: "200"}, + }, + }, + { + FieldID: 102, Name: "vector1", DataType: schemapb.DataType_FloatVector, + TypeParams: []*commonpb.KeyValuePair{ + {Key: "dim", Value: "4"}, + }, + }, + { + FieldID: 103, Name: "vector2", DataType: schemapb.DataType_FloatVector, + TypeParams: []*commonpb.KeyValuePair{ + {Key: "dim", Value: "4"}, + }, + }, + }, + Functions: []*schemapb.FunctionSchema{ + { + Name: "func1", + Type: schemapb.FunctionType_TextEmbedding, + InputFieldIds: []int64{101}, + InputFieldNames: []string{"text"}, + OutputFieldIds: []int64{102}, + OutputFieldNames: []string{"vector1"}, + Params: []*commonpb.KeyValuePair{ + {Key: function.Provider, Value: function.OpenAIProvider}, + {Key: "model_name", Value: "text-embedding-ada-002"}, + {Key: "api_key", Value: "mock"}, + {Key: "url", Value: ts.URL}, + {Key: "dim", Value: "4"}, + }, + }, + { + Name: "func2", + Type: schemapb.FunctionType_TextEmbedding, + InputFieldIds: []int64{101}, + InputFieldNames: []string{"text"}, + OutputFieldIds: []int64{103}, + OutputFieldNames: []string{"vector2"}, + Params: []*commonpb.KeyValuePair{ + {Key: function.Provider, Value: function.OpenAIProvider}, + {Key: "model_name", Value: "text-embedding-ada-002"}, + {Key: "api_key", Value: "mock"}, + {Key: "url", Value: ts.URL}, + {Key: "dim", Value: "4"}, + }, + }, + }, + } + + var err error + var ( + rc = NewRootCoordMock() + qc = mocks.NewMockQueryCoordClient(t) + ctx = context.TODO() + ) + + defer rc.Close() + require.NoError(t, err) + mgr := newShardClientMgr() + qc.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(&querypb.ShowCollectionsResponse{}, nil).Maybe() + err = InitMetaCache(ctx, rc, qc, mgr) + require.NoError(t, err) + + getSearchTask := func(t *testing.T, collName string, data []string) *searchTask { + placeholderValue := &commonpb.PlaceholderValue{ + Tag: "$0", + Type: commonpb.PlaceholderType_VarChar, + Values: lo.Map(data, func(str string, _ int) []byte { return []byte(str) }), + } + holder := &commonpb.PlaceholderGroup{ + Placeholders: []*commonpb.PlaceholderValue{placeholderValue}, + } + holderByte, _ := proto.Marshal(holder) + task := &searchTask{ + ctx: ctx, + collectionName: collectionName, + SearchRequest: &internalpb.SearchRequest{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_Search, + Timestamp: uint64(time.Now().UnixNano()), + }, + }, + request: &milvuspb.SearchRequest{ + CollectionName: collectionName, + Nq: int64(len(data)), + SearchParams: []*commonpb.KeyValuePair{ + {Key: AnnsFieldKey, Value: "vector1"}, + {Key: TopKKey, Value: "10"}, + }, + PlaceholderGroup: holderByte, + }, + qc: qc, + tr: timerecord.NewTimeRecorder("test-search"), + } + require.NoError(t, task.OnEnqueue()) + return task + } + + collectionID := UniqueID(1000) + cache := NewMockCache(t) + info := newSchemaInfo(schema) + cache.EXPECT().GetCollectionID(mock.Anything, mock.Anything, mock.Anything).Return(collectionID, nil).Maybe() + cache.EXPECT().GetCollectionSchema(mock.Anything, mock.Anything, mock.Anything).Return(info, nil).Maybe() + cache.EXPECT().GetPartitions(mock.Anything, mock.Anything, mock.Anything).Return(map[string]int64{"_default": UniqueID(1)}, nil).Maybe() + cache.EXPECT().GetCollectionInfo(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(&collectionInfo{}, nil).Maybe() + cache.EXPECT().GetShards(mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(map[string][]nodeInfo{}, nil).Maybe() + cache.EXPECT().DeprecateShardCache(mock.Anything, mock.Anything).Return().Maybe() + globalMetaCache = cache + + { + task := getSearchTask(t, collectionName, []string{"sentence"}) + err = task.PreExecute(ctx) + assert.NoError(t, err) + pb := &commonpb.PlaceholderGroup{} + proto.Unmarshal(task.SearchRequest.PlaceholderGroup, pb) + assert.Equal(t, len(pb.Placeholders), 1) + assert.Equal(t, len(pb.Placeholders[0].Values), 1) + assert.Equal(t, pb.Placeholders[0].Type, commonpb.PlaceholderType_FloatVector) + } + + { + task := getSearchTask(t, collectionName, []string{"sentence 1", "sentence 2"}) + err = task.PreExecute(ctx) + assert.NoError(t, err) + pb := &commonpb.PlaceholderGroup{} + proto.Unmarshal(task.SearchRequest.PlaceholderGroup, pb) + assert.Equal(t, len(pb.Placeholders), 1) + assert.Equal(t, len(pb.Placeholders[0].Values), 2) + assert.Equal(t, pb.Placeholders[0].Type, commonpb.PlaceholderType_FloatVector) + } + + getHybridSearchTask := func(t *testing.T, collName string, data [][]string) *searchTask { + subReqs := []*milvuspb.SubSearchRequest{} + for _, item := range data { + placeholderValue := &commonpb.PlaceholderValue{ + Tag: "$0", + Type: commonpb.PlaceholderType_VarChar, + Values: lo.Map(item, func(str string, _ int) []byte { return []byte(str) }), + } + holder := &commonpb.PlaceholderGroup{ + Placeholders: []*commonpb.PlaceholderValue{placeholderValue}, + } + holderByte, _ := proto.Marshal(holder) + subReq := &milvuspb.SubSearchRequest{ + PlaceholderGroup: holderByte, + SearchParams: []*commonpb.KeyValuePair{ + {Key: AnnsFieldKey, Value: "vector1"}, + {Key: TopKKey, Value: "10"}, + }, + Nq: int64(len(item)), + } + subReqs = append(subReqs, subReq) + } + task := &searchTask{ + ctx: ctx, + collectionName: collectionName, + SearchRequest: &internalpb.SearchRequest{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_Search, + Timestamp: uint64(time.Now().UnixNano()), + }, + }, + request: &milvuspb.SearchRequest{ + CollectionName: collectionName, + SubReqs: subReqs, + SearchParams: []*commonpb.KeyValuePair{ + {Key: LimitKey, Value: "10"}, + }, + }, + qc: qc, + tr: timerecord.NewTimeRecorder("test-search"), + } + require.NoError(t, task.OnEnqueue()) + return task + } + + { + task := getHybridSearchTask(t, collectionName, [][]string{ + {"sentence1"}, + {"sentence2"}, + }) + err = task.PreExecute(ctx) + assert.NoError(t, err) + assert.Equal(t, len(task.SearchRequest.SubReqs), 2) + for _, sub := range task.SearchRequest.SubReqs { + pb := &commonpb.PlaceholderGroup{} + proto.Unmarshal(sub.PlaceholderGroup, pb) + assert.Equal(t, len(pb.Placeholders), 1) + assert.Equal(t, len(pb.Placeholders[0].Values), 1) + assert.Equal(t, pb.Placeholders[0].Type, commonpb.PlaceholderType_FloatVector) + } + } + + { + task := getHybridSearchTask(t, collectionName, [][]string{ + {"sentence1", "sentence1"}, + {"sentence2", "sentence2"}, + {"sentence3", "sentence3"}, + }) + err = task.PreExecute(ctx) + assert.NoError(t, err) + assert.Equal(t, len(task.SearchRequest.SubReqs), 3) + for _, sub := range task.SearchRequest.SubReqs { + pb := &commonpb.PlaceholderGroup{} + proto.Unmarshal(sub.PlaceholderGroup, pb) + assert.Equal(t, len(pb.Placeholders), 1) + assert.Equal(t, len(pb.Placeholders[0].Values), 2) + assert.Equal(t, pb.Placeholders[0].Type, commonpb.PlaceholderType_FloatVector) + } + } +} + func getQueryCoord() *mocks.MockQueryCoord { qc := &mocks.MockQueryCoord{} qc.EXPECT().Start().Return(nil) diff --git a/internal/proxy/task_test.go b/internal/proxy/task_test.go index 019063bbbc9f1..d1fd055556dea 100644 --- a/internal/proxy/task_test.go +++ b/internal/proxy/task_test.go @@ -20,6 +20,7 @@ import ( "bytes" "context" "encoding/binary" + "fmt" "math/rand" "strconv" "testing" @@ -1022,6 +1023,47 @@ func TestCreateCollectionTask(t *testing.T) { err = task2.PreExecute(ctx) assert.Error(t, err) }) + + t.Run("collection with embedding function ", func(t *testing.T) { + fmt.Println(schema) + schema.Functions = []*schemapb.FunctionSchema{ + { + Name: "test", + Type: schemapb.FunctionType_TextEmbedding, + InputFieldNames: []string{varCharField}, + OutputFieldNames: []string{floatVecField}, + Params: []*commonpb.KeyValuePair{ + {Key: "provider", Value: "openai"}, + {Key: "model_name", Value: "text-embedding-ada-002"}, + {Key: "api_key", Value: "mock"}, + }, + }, + } + + marshaledSchema, err := proto.Marshal(schema) + assert.NoError(t, err) + + task2 := &createCollectionTask{ + Condition: NewTaskCondition(ctx), + CreateCollectionRequest: &milvuspb.CreateCollectionRequest{ + Base: nil, + DbName: dbName, + CollectionName: collectionName, + Schema: marshaledSchema, + ShardsNum: shardsNum, + }, + ctx: ctx, + rootCoord: rc, + result: nil, + schema: nil, + } + + err = task2.OnEnqueue() + assert.NoError(t, err) + + err = task2.PreExecute(ctx) + assert.NoError(t, err) + }) } func TestHasCollectionTask(t *testing.T) { diff --git a/internal/proxy/task_upsert.go b/internal/proxy/task_upsert.go index 1de223fa4124d..32f3e29f0e0a5 100644 --- a/internal/proxy/task_upsert.go +++ b/internal/proxy/task_upsert.go @@ -29,6 +29,7 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/msgpb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus/internal/allocator" + "github.com/milvus-io/milvus/internal/util/function" "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/metrics" @@ -152,6 +153,16 @@ func (it *upsertTask) insertPreExecute(ctx context.Context) error { return err } + // Calculate embedding fields + if function.HasFunctions(it.schema.CollectionSchema.Functions, []int64{}) { + exec, err := function.NewFunctionExecutor(it.schema.CollectionSchema) + if err != nil { + return err + } + if err := exec.ProcessInsert(it.upsertMsg.InsertMsg); err != nil { + return err + } + } rowNums := uint32(it.upsertMsg.InsertMsg.NRows()) // set upsertTask.insertRequest.rowIDs tr := timerecord.NewTimeRecorder("applyPK") diff --git a/internal/proxy/task_upsert_test.go b/internal/proxy/task_upsert_test.go index 75fd39964b00e..da0b3595cc45e 100644 --- a/internal/proxy/task_upsert_test.go +++ b/internal/proxy/task_upsert_test.go @@ -27,8 +27,13 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus-proto/go-api/v2/msgpb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/internal/allocator" + "github.com/milvus-io/milvus/internal/mocks" + "github.com/milvus-io/milvus/internal/proto/rootcoordpb" + "github.com/milvus-io/milvus/internal/util/function" "github.com/milvus-io/milvus/pkg/mq/msgstream" "github.com/milvus-io/milvus/pkg/util/commonpbutil" + "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/testutils" ) @@ -360,3 +365,124 @@ func TestUpsertTaskForReplicate(t *testing.T) { assert.Error(t, err) }) } + +func TestUpsertTask_Function(t *testing.T) { + ts := function.CreateOpenAIEmbeddingServer() + defer ts.Close() + data := []*schemapb.FieldData{} + f1 := schemapb.FieldData{ + Type: schemapb.DataType_Int64, + FieldId: 100, + FieldName: "id", + IsDynamic: false, + Field: &schemapb.FieldData_Scalars{ + Scalars: &schemapb.ScalarField{ + Data: &schemapb.ScalarField_LongData{ + LongData: &schemapb.LongArray{ + Data: []int64{0, 1}, + }, + }, + }, + }, + } + data = append(data, &f1) + f2 := schemapb.FieldData{ + Type: schemapb.DataType_VarChar, + FieldId: 101, + FieldName: "text", + IsDynamic: false, + Field: &schemapb.FieldData_Scalars{ + Scalars: &schemapb.ScalarField{ + Data: &schemapb.ScalarField_StringData{ + StringData: &schemapb.StringArray{ + Data: []string{"sentence", "sentence"}, + }, + }, + }, + }, + } + data = append(data, &f2) + collectionName := "TestUpsertTask_function" + schema := &schemapb.CollectionSchema{ + Name: collectionName, + Description: "TestUpsertTask_function", + AutoID: true, + Fields: []*schemapb.FieldSchema{ + {FieldID: 100, Name: "id", DataType: schemapb.DataType_Int64, IsPrimaryKey: true, AutoID: true}, + { + FieldID: 101, Name: "text", DataType: schemapb.DataType_VarChar, + TypeParams: []*commonpb.KeyValuePair{ + {Key: "max_length", Value: "200"}, + }, + }, + { + FieldID: 102, Name: "vector", DataType: schemapb.DataType_FloatVector, + TypeParams: []*commonpb.KeyValuePair{ + {Key: "dim", Value: "4"}, + }, + IsFunctionOutput: true, + }, + }, + Functions: []*schemapb.FunctionSchema{ + { + Name: "test_function", + Type: schemapb.FunctionType_TextEmbedding, + InputFieldIds: []int64{101}, + InputFieldNames: []string{"text"}, + OutputFieldIds: []int64{102}, + OutputFieldNames: []string{"vector"}, + Params: []*commonpb.KeyValuePair{ + {Key: function.Provider, Value: function.OpenAIProvider}, + {Key: "model_name", Value: "text-embedding-ada-002"}, + {Key: "api_key", Value: "mock"}, + {Key: "url", Value: ts.URL}, + {Key: "dim", Value: "4"}, + }, + }, + }, + } + + info := newSchemaInfo(schema) + collectionID := UniqueID(0) + cache := NewMockCache(t) + globalMetaCache = cache + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + rc := mocks.NewMockRootCoordClient(t) + rc.EXPECT().AllocID(mock.Anything, mock.Anything).Return(&rootcoordpb.AllocIDResponse{ + Status: merr.Status(nil), + ID: collectionID, + Count: 10, + }, nil) + idAllocator, err := allocator.NewIDAllocator(ctx, rc, 0) + idAllocator.Start() + defer idAllocator.Close() + assert.NoError(t, err) + task := upsertTask{ + ctx: context.Background(), + req: &milvuspb.UpsertRequest{ + CollectionName: collectionName, + }, + upsertMsg: &msgstream.UpsertMsg{ + InsertMsg: &msgstream.InsertMsg{ + InsertRequest: &msgpb.InsertRequest{ + Base: commonpbutil.NewMsgBase( + commonpbutil.WithMsgType(commonpb.MsgType_Insert), + ), + CollectionName: collectionName, + DbName: "hooooooo", + Version: msgpb.InsertDataVersion_ColumnBased, + FieldsData: data, + NumRows: 2, + PartitionName: Params.CommonCfg.DefaultPartitionName.GetValue(), + }, + }, + }, + idAllocator: idAllocator, + schema: info, + result: &milvuspb.MutationResult{}, + } + err = task.insertPreExecute(ctx) + assert.NoError(t, err) +} diff --git a/internal/proxy/util.go b/internal/proxy/util.go index 76b477400fea4..2efb9e2ebd9bd 100644 --- a/internal/proxy/util.go +++ b/internal/proxy/util.go @@ -39,6 +39,7 @@ import ( "github.com/milvus-io/milvus/internal/proto/planpb" "github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/internal/types" + "github.com/milvus-io/milvus/internal/util/function" "github.com/milvus-io/milvus/internal/util/hookutil" "github.com/milvus-io/milvus/internal/util/indexparamcheck" typeutil2 "github.com/milvus-io/milvus/internal/util/typeutil" @@ -705,6 +706,10 @@ func validateFunction(coll *schemapb.CollectionSchema) error { return err } } + + if err := function.CheckFunctions(coll); err != nil { + return err + } return nil } @@ -718,6 +723,11 @@ func checkFunctionOutputField(function *schemapb.FunctionSchema, fields []*schem if !typeutil.IsSparseFloatVectorType(fields[0].GetDataType()) { return fmt.Errorf("BM25 function output field must be a SparseFloatVector field, but got %s", fields[0].DataType.String()) } + case schemapb.FunctionType_TextEmbedding: + if len(fields) != 1 || fields[0].DataType != schemapb.DataType_FloatVector { + return fmt.Errorf("TextEmbedding function output field must be a FloatVector field, got %d field with type %s", + len(fields), fields[0].DataType.String()) + } default: return fmt.Errorf("check output field for unknown function type") } @@ -744,7 +754,11 @@ func checkFunctionInputField(function *schemapb.FunctionSchema, fields []*schema if !h.EnableAnalyzer() { return fmt.Errorf("BM25 function input field must set enable_analyzer to true") } - + case schemapb.FunctionType_TextEmbedding: + if len(fields) != 1 || fields[0].DataType != schemapb.DataType_VarChar { + return fmt.Errorf("TextEmbedding function input field must be a VARCHAR field, got %d field with type %s", + len(fields), fields[0].DataType.String()) + } default: return fmt.Errorf("check input field with unknown function type") } @@ -786,6 +800,10 @@ func checkFunctionBasicParams(function *schemapb.FunctionSchema) error { if len(function.GetParams()) != 0 { return fmt.Errorf("BM25 function accepts no params") } + case schemapb.FunctionType_TextEmbedding: + if len(function.GetParams()) == 0 { + return fmt.Errorf("TextEmbedding function need provider and model_name params") + } default: return fmt.Errorf("check function params with unknown function type") } @@ -942,7 +960,7 @@ func fillFieldPropertiesBySchema(columns []*schemapb.FieldData, schema *schemapb expectColumnNum := 0 for _, field := range schema.GetFields() { fieldName2Schema[field.Name] = field - if !field.GetIsFunctionOutput() { + if !IsBM25FunctionOutputField(field) { expectColumnNum++ } } @@ -1519,12 +1537,12 @@ func checkFieldsDataBySchema(schema *schemapb.CollectionSchema, insertMsg *msgst if fieldSchema.GetDefaultValue() != nil && fieldSchema.IsPrimaryKey { return merr.WrapErrParameterInvalidMsg("primary key can't be with default value") } - if (fieldSchema.IsPrimaryKey && fieldSchema.AutoID && !Params.ProxyCfg.SkipAutoIDCheck.GetAsBool() && inInsert) || fieldSchema.GetIsFunctionOutput() { + if (fieldSchema.IsPrimaryKey && fieldSchema.AutoID && !Params.ProxyCfg.SkipAutoIDCheck.GetAsBool() && inInsert) || IsBM25FunctionOutputField(fieldSchema) { // when inInsert, no need to pass when pk is autoid and SkipAutoIDCheck is false autoGenFieldNum++ } if _, ok := dataNameSet[fieldSchema.GetName()]; !ok { - if (fieldSchema.IsPrimaryKey && fieldSchema.AutoID && !Params.ProxyCfg.SkipAutoIDCheck.GetAsBool() && inInsert) || fieldSchema.GetIsFunctionOutput() { + if (fieldSchema.IsPrimaryKey && fieldSchema.AutoID && !Params.ProxyCfg.SkipAutoIDCheck.GetAsBool() && inInsert) || IsBM25FunctionOutputField(fieldSchema) { // autoGenField continue } @@ -1548,7 +1566,6 @@ func checkFieldsDataBySchema(schema *schemapb.CollectionSchema, insertMsg *msgst zap.Int64("primaryKeyNum", int64(primaryKeyNum))) return merr.WrapErrParameterInvalidMsg("more than 1 primary keys not supported, got %d", primaryKeyNum) } - expectedNum := len(schema.Fields) actualNum := len(insertMsg.FieldsData) + autoGenFieldNum @@ -2231,3 +2248,7 @@ func GetReplicateID(ctx context.Context, database, collectionName string) (strin replicateID, _ := common.GetReplicateID(dbInfo.properties) return replicateID, nil } + +func IsBM25FunctionOutputField(field *schemapb.FieldSchema) bool { + return field.GetIsFunctionOutput() && field.GetDataType() == schemapb.DataType_SparseFloatVector +} diff --git a/internal/querynodev2/pipeline/embedding_node.go b/internal/querynodev2/pipeline/embedding_node.go index da75099f0b8af..76ef66e9c6ef1 100644 --- a/internal/querynodev2/pipeline/embedding_node.go +++ b/internal/querynodev2/pipeline/embedding_node.go @@ -70,6 +70,9 @@ func newEmbeddingNode(collectionID int64, channelName string, manager *DataManag if err != nil { return nil, err } + if functionRunner == nil { + continue + } node.functionRunners = append(node.functionRunners, functionRunner) } return node, nil diff --git a/internal/storage/utils.go b/internal/storage/utils.go index 5e444027dcc48..614c29057c896 100644 --- a/internal/storage/utils.go +++ b/internal/storage/utils.go @@ -371,7 +371,7 @@ func RowBasedInsertMsgToInsertData(msg *msgstream.InsertMsg, collSchema *schemap } for _, field := range collSchema.Fields { - if skipFunction && field.GetIsFunctionOutput() { + if skipFunction && IsBM25FunctionOutputField(field) { continue } @@ -503,7 +503,7 @@ func ColumnBasedInsertMsgToInsertData(msg *msgstream.InsertMsg, collSchema *sche } length := 0 for _, field := range collSchema.Fields { - if field.GetIsFunctionOutput() { + if IsBM25FunctionOutputField(field) { continue } @@ -1340,3 +1340,8 @@ func (ni NullableInt) GetValue() int { func (ni NullableInt) IsNull() bool { return ni.Value == nil } + +// TODO: unify the function implementation, storage/utils.go & proxy/util.go +func IsBM25FunctionOutputField(field *schemapb.FieldSchema) bool { + return field.GetIsFunctionOutput() && field.GetDataType() == schemapb.DataType_SparseFloatVector +} diff --git a/internal/util/function/ali_embedding_provider.go b/internal/util/function/ali_embedding_provider.go new file mode 100644 index 0000000000000..966c530522e16 --- /dev/null +++ b/internal/util/function/ali_embedding_provider.go @@ -0,0 +1,153 @@ +/* + * # Licensed to the LF AI & Data foundation under one + * # or more contributor license agreements. See the NOTICE file + * # distributed with this work for additional information + * # regarding copyright ownership. The ASF licenses this file + * # to you under the Apache License, Version 2.0 (the + * # "License"); you may not use this file except in compliance + * # with the License. You may obtain a copy of the License at + * # + * # http://www.apache.org/licenses/LICENSE-2.0 + * # + * # Unless required by applicable law or agreed to in writing, software + * # distributed under the License is distributed on an "AS IS" BASIS, + * # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * # See the License for the specific language governing permissions and + * # limitations under the License. + */ + +package function + +import ( + "fmt" + "os" + "strings" + "time" + + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/internal/models/ali" + "github.com/milvus-io/milvus/pkg/util/typeutil" +) + +type AliEmbeddingProvider struct { + fieldDim int64 + + client *ali.AliDashScopeEmbedding + modelName string + embedDimParam int64 + outputType string + + maxBatch int + timeoutSec int +} + +func createAliEmbeddingClient(apiKey string, url string) (*ali.AliDashScopeEmbedding, error) { + if apiKey == "" { + apiKey = os.Getenv(dashscopeApiKey) + } + if apiKey == "" { + return nil, fmt.Errorf("Missing credentials. Please pass `api_key`, or configure the %s environment variable in the Milvus service.", dashscopeApiKey) + } + + if url == "" { + url = "https://dashscope.aliyuncs.com/api/v1/services/embeddings/text-embedding/text-embedding" + } + c := ali.NewAliDashScopeEmbeddingClient(apiKey, url) + return c, nil +} + +func NewAliDashScopeEmbeddingProvider(fieldSchema *schemapb.FieldSchema, functionSchema *schemapb.FunctionSchema) (*AliEmbeddingProvider, error) { + fieldDim, err := typeutil.GetDim(fieldSchema) + if err != nil { + return nil, err + } + var apiKey, url, modelName string + var dim int64 + + for _, param := range functionSchema.Params { + switch strings.ToLower(param.Key) { + case modelNameParamKey: + modelName = param.Value + case dimParamKey: + dim, err = parseAndCheckFieldDim(param.Value, fieldDim, fieldSchema.Name) + if err != nil { + return nil, err + } + case apiKeyParamKey: + apiKey = param.Value + case embeddingURLParamKey: + url = param.Value + default: + } + } + + if modelName != TextEmbeddingV1 && modelName != TextEmbeddingV2 && modelName != TextEmbeddingV3 { + return nil, fmt.Errorf("Unsupported model: %s, only support [%s, %s, %s]", + modelName, TextEmbeddingV1, TextEmbeddingV2, TextEmbeddingV3) + } + + c, err := createAliEmbeddingClient(apiKey, url) + if err != nil { + return nil, err + } + + maxBatch := 25 + if modelName == TextEmbeddingV3 { + maxBatch = 6 + } + + provider := AliEmbeddingProvider{ + client: c, + fieldDim: fieldDim, + modelName: modelName, + embedDimParam: dim, + // TextEmbedding only supports dense embedding + outputType: "dense", + maxBatch: maxBatch, + timeoutSec: 30, + } + return &provider, nil +} + +func (provider *AliEmbeddingProvider) MaxBatch() int { + return 5 * provider.maxBatch +} + +func (provider *AliEmbeddingProvider) FieldDim() int64 { + return provider.fieldDim +} + +func (provider *AliEmbeddingProvider) CallEmbedding(texts []string, batchLimit bool, mode string) ([][]float32, error) { + numRows := len(texts) + if batchLimit && numRows > provider.MaxBatch() { + return nil, fmt.Errorf("Ali text embedding supports up to [%d] pieces of data at a time, got [%d]", provider.MaxBatch(), numRows) + } + var textType string + if mode == SearchMode { + textType = "query" + } else { + textType = "document" + } + data := make([][]float32, 0, numRows) + for i := 0; i < numRows; i += provider.maxBatch { + end := i + provider.maxBatch + if end > numRows { + end = numRows + } + resp, err := provider.client.Embedding(provider.modelName, texts[i:end], int(provider.embedDimParam), textType, provider.outputType, time.Duration(provider.timeoutSec)) + if err != nil { + return nil, err + } + if end-i != len(resp.Output.Embeddings) { + return nil, fmt.Errorf("Get embedding failed. The number of texts and embeddings does not match text:[%d], embedding:[%d]", end-i, len(resp.Output.Embeddings)) + } + for _, item := range resp.Output.Embeddings { + if len(item.Embedding) != int(provider.fieldDim) { + return nil, fmt.Errorf("The required embedding dim is [%d], but the embedding obtained from the model is [%d]", + provider.fieldDim, len(item.Embedding)) + } + data = append(data, item.Embedding) + } + } + return data, nil +} diff --git a/internal/util/function/alitext_embedding_provider_test.go b/internal/util/function/alitext_embedding_provider_test.go new file mode 100644 index 0000000000000..a852b1b74e6ab --- /dev/null +++ b/internal/util/function/alitext_embedding_provider_test.go @@ -0,0 +1,165 @@ +/* + * # Licensed to the LF AI & Data foundation under one + * # or more contributor license agreements. See the NOTICE file + * # distributed with this work for additional information + * # regarding copyright ownership. The ASF licenses this file + * # to you under the Apache License, Version 2.0 (the + * # "License"); you may not use this file except in compliance + * # with the License. You may obtain a copy of the License at + * # + * # http://www.apache.org/licenses/LICENSE-2.0 + * # + * # Unless required by applicable law or agreed to in writing, software + * # distributed under the License is distributed on an "AS IS" BASIS, + * # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * # See the License for the specific language governing permissions and + * # limitations under the License. + */ + +package function + +import ( + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "testing" + + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/internal/models/ali" + + "github.com/stretchr/testify/suite" +) + +func TestAliTextEmbeddingProvider(t *testing.T) { + suite.Run(t, new(AliTextEmbeddingProviderSuite)) +} + +type AliTextEmbeddingProviderSuite struct { + suite.Suite + schema *schemapb.CollectionSchema + providers []string +} + +func (s *AliTextEmbeddingProviderSuite) SetupTest() { + s.schema = &schemapb.CollectionSchema{ + Name: "test", + Fields: []*schemapb.FieldSchema{ + {FieldID: 100, Name: "int64", DataType: schemapb.DataType_Int64}, + {FieldID: 101, Name: "text", DataType: schemapb.DataType_VarChar}, + {FieldID: 102, Name: "vector", DataType: schemapb.DataType_FloatVector, + TypeParams: []*commonpb.KeyValuePair{ + {Key: "dim", Value: "4"}, + }, + }, + }, + } + s.providers = []string{AliDashScopeProvider} +} + +func createAliProvider(url string, schema *schemapb.FieldSchema, providerName string) (TextEmbeddingProvider, error) { + functionSchema := &schemapb.FunctionSchema{ + Name: "test", + Type: schemapb.FunctionType_Unknown, + InputFieldNames: []string{"text"}, + OutputFieldNames: []string{"vector"}, + InputFieldIds: []int64{101}, + OutputFieldIds: []int64{102}, + Params: []*commonpb.KeyValuePair{ + {Key: modelNameParamKey, Value: TextEmbeddingV3}, + {Key: apiKeyParamKey, Value: "mock"}, + {Key: embeddingURLParamKey, Value: url}, + {Key: dimParamKey, Value: "4"}, + }, + } + switch providerName { + case AliDashScopeProvider: + return NewAliDashScopeEmbeddingProvider(schema, functionSchema) + default: + return nil, fmt.Errorf("Unknow provider") + } +} + +func (s *AliTextEmbeddingProviderSuite) TestEmbedding() { + ts := CreateAliEmbeddingServer() + + defer ts.Close() + for _, provderName := range s.providers { + provder, err := createAliProvider(ts.URL, s.schema.Fields[2], provderName) + s.NoError(err) + { + data := []string{"sentence"} + ret, err2 := provder.CallEmbedding(data, false, InsertMode) + s.NoError(err2) + s.Equal(1, len(ret)) + s.Equal(4, len(ret[0])) + s.Equal([]float32{0.0, 0.1, 0.2, 0.3}, ret[0]) + } + { + data := []string{"sentence 1", "sentence 2", "sentence 3"} + ret, _ := provder.CallEmbedding(data, false, SearchMode) + s.Equal([][]float32{{0.0, 0.1, 0.2, 0.3}, {1.0, 1.1, 1.2, 1.3}, {2.0, 2.1, 2.2, 2.3}}, ret) + } + } +} + +func (s *AliTextEmbeddingProviderSuite) TestEmbeddingDimNotMatch() { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var res ali.EmbeddingResponse + res.Output.Embeddings = append(res.Output.Embeddings, ali.Embeddings{ + Embedding: []float32{1.0, 1.0, 1.0, 1.0}, + TextIndex: 0, + }) + + res.Output.Embeddings = append(res.Output.Embeddings, ali.Embeddings{ + Embedding: []float32{1.0, 1.0}, + TextIndex: 1, + }) + res.Usage = ali.Usage{ + TotalTokens: 100, + } + w.WriteHeader(http.StatusOK) + data, _ := json.Marshal(res) + w.Write(data) + })) + + defer ts.Close() + for _, provderName := range s.providers { + provder, err := createAliProvider(ts.URL, s.schema.Fields[2], provderName) + s.NoError(err) + + // embedding dim not match + data := []string{"sentence", "sentence"} + _, err2 := provder.CallEmbedding(data, false, InsertMode) + s.Error(err2) + } +} + +func (s *AliTextEmbeddingProviderSuite) TestEmbeddingNubmerNotMatch() { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var res ali.EmbeddingResponse + res.Output.Embeddings = append(res.Output.Embeddings, ali.Embeddings{ + Embedding: []float32{1.0, 1.0, 1.0, 1.0}, + TextIndex: 0, + }) + res.Usage = ali.Usage{ + TotalTokens: 100, + } + w.WriteHeader(http.StatusOK) + data, _ := json.Marshal(res) + w.Write(data) + })) + + defer ts.Close() + for _, provderName := range s.providers { + provder, err := createAliProvider(ts.URL, s.schema.Fields[2], provderName) + + s.NoError(err) + + // embedding dim not match + data := []string{"sentence", "sentence2"} + _, err2 := provder.CallEmbedding(data, false, InsertMode) + s.Error(err2) + } +} diff --git a/internal/util/function/bedrock_embedding_provider.go b/internal/util/function/bedrock_embedding_provider.go new file mode 100644 index 0000000000000..eb54712ce5499 --- /dev/null +++ b/internal/util/function/bedrock_embedding_provider.go @@ -0,0 +1,203 @@ +/* + * # Licensed to the LF AI & Data foundation under one + * # or more contributor license agreements. See the NOTICE file + * # distributed with this work for additional information + * # regarding copyright ownership. The ASF licenses this file + * # to you under the Apache License, Version 2.0 (the + * # "License"); you may not use this file except in compliance + * # with the License. You may obtain a copy of the License at + * # + * # http://www.apache.org/licenses/LICENSE-2.0 + * # + * # Unless required by applicable law or agreed to in writing, software + * # distributed under the License is distributed on an "AS IS" BASIS, + * # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * # See the License for the specific language governing permissions and + * # limitations under the License. + */ + +package function + +import ( + "context" + "encoding/json" + "fmt" + "os" + "strings" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/credentials" + "github.com/aws/aws-sdk-go-v2/service/bedrockruntime" + + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/pkg/util/typeutil" +) + +type BedrockClient interface { + InvokeModel(ctx context.Context, params *bedrockruntime.InvokeModelInput, optFns ...func(*bedrockruntime.Options)) (*bedrockruntime.InvokeModelOutput, error) +} + +type BedrockEmbeddingProvider struct { + fieldDim int64 + + client BedrockClient + modelName string + embedDimParam int64 + normalize bool + + maxBatch int + timeoutSec int +} + +func createBedRockEmbeddingClient(awsAccessKeyId string, awsSecretAccessKey string, region string) (*bedrockruntime.Client, error) { + if awsAccessKeyId == "" { + awsAccessKeyId = os.Getenv(bedrockAccessKeyId) + } + if awsAccessKeyId == "" { + return nil, fmt.Errorf("Missing credentials. Please pass `aws_access_key_id`, or configure the %s environment variable in the Milvus service.", bedrockAccessKeyId) + } + + if awsSecretAccessKey == "" { + awsSecretAccessKey = os.Getenv(bedrockSecretAccessKey) + } + if awsSecretAccessKey == "" { + return nil, fmt.Errorf("Missing credentials. Please pass `aws_secret_access_key`, or configure the %s environment variable in the Milvus service.", bedrockSecretAccessKey) + } + if region == "" { + return nil, fmt.Errorf("Missing region. Please pass `region` param.") + } + + cfg, err := config.LoadDefaultConfig(context.Background(), config.WithRegion(region), + config.WithCredentialsProvider(credentials.NewStaticCredentialsProvider( + awsAccessKeyId, awsSecretAccessKey, "")), + ) + if err != nil { + return nil, err + } + + return bedrockruntime.NewFromConfig(cfg), nil +} + +func NewBedrockEmbeddingProvider(fieldSchema *schemapb.FieldSchema, functionSchema *schemapb.FunctionSchema, c BedrockClient) (*BedrockEmbeddingProvider, error) { + fieldDim, err := typeutil.GetDim(fieldSchema) + if err != nil { + return nil, err + } + var awsAccessKeyId, awsSecretAccessKey, region, modelName string + var dim int64 + normalize := true + + for _, param := range functionSchema.Params { + switch strings.ToLower(param.Key) { + case modelNameParamKey: + modelName = param.Value + case dimParamKey: + dim, err = parseAndCheckFieldDim(param.Value, fieldDim, fieldSchema.Name) + if err != nil { + return nil, err + } + case awsAccessKeyIdParamKey: + awsAccessKeyId = param.Value + case awsSecretAccessKeyParamKey: + awsSecretAccessKey = param.Value + case regionParamKey: + region = param.Value + case normalizeParamKey: + switch strings.ToLower(param.Value) { + case "false": + normalize = false + default: + return nil, fmt.Errorf("Illegal [%s:%s] param, ", normalizeParamKey, param.Value) + } + default: + } + } + + if modelName != BedRockTitanTextEmbeddingsV2 { + return nil, fmt.Errorf("Unsupported model: %s, only support [%s]", + modelName, BedRockTitanTextEmbeddingsV2) + } + var client BedrockClient + if c == nil { + client, err = createBedRockEmbeddingClient(awsAccessKeyId, awsSecretAccessKey, region) + if err != nil { + return nil, err + } + } else { + client = c + } + + return &BedrockEmbeddingProvider{ + client: client, + fieldDim: fieldDim, + modelName: modelName, + embedDimParam: dim, + normalize: normalize, + maxBatch: 1, + timeoutSec: 30, + }, nil +} + +func (provider *BedrockEmbeddingProvider) MaxBatch() int { + return 12 * provider.maxBatch +} + +func (provider *BedrockEmbeddingProvider) FieldDim() int64 { + return provider.fieldDim +} + +func (provider *BedrockEmbeddingProvider) CallEmbedding(texts []string, batchLimit bool, _ string) ([][]float32, error) { + numRows := len(texts) + if batchLimit && numRows > provider.MaxBatch() { + return nil, fmt.Errorf("Bedrock text embedding supports up to [%d] pieces of data at a time, got [%d]", provider.MaxBatch(), numRows) + } + + data := make([][]float32, 0, numRows) + for i := 0; i < numRows; i += 1 { + payload := BedRockRequest{ + InputText: texts[i], + Normalize: provider.normalize, + } + if provider.embedDimParam != 0 { + payload.Dimensions = provider.embedDimParam + } + + payloadBytes, err := json.Marshal(payload) + if err != nil { + return nil, err + } + + output, err := provider.client.InvokeModel(context.Background(), &bedrockruntime.InvokeModelInput{ + Body: payloadBytes, + ModelId: aws.String(provider.modelName), + ContentType: aws.String("application/json"), + }) + if err != nil { + return nil, err + } + + var resp BedRockResponse + err = json.Unmarshal(output.Body, &resp) + if err != nil { + return nil, err + } + if len(resp.Embedding) != int(provider.fieldDim) { + return nil, fmt.Errorf("The required embedding dim is [%d], but the embedding obtained from the model is [%d]", + provider.fieldDim, len(resp.Embedding)) + } + data = append(data, resp.Embedding) + } + return data, nil +} + +type BedRockRequest struct { + InputText string `json:"inputText"` + Dimensions int64 `json:"dimensions,omitempty"` + Normalize bool `json:"normalize,omitempty"` +} + +type BedRockResponse struct { + Embedding []float32 `json:"embedding"` + InputTextTokenCount int `json:"inputTextTokenCount"` +} diff --git a/internal/util/function/bedrock_text_embedding_provider_test.go b/internal/util/function/bedrock_text_embedding_provider_test.go new file mode 100644 index 0000000000000..e8f08df77e8d1 --- /dev/null +++ b/internal/util/function/bedrock_text_embedding_provider_test.go @@ -0,0 +1,109 @@ +/* + * # Licensed to the LF AI & Data foundation under one + * # or more contributor license agreements. See the NOTICE file + * # distributed with this work for additional information + * # regarding copyright ownership. The ASF licenses this file + * # to you under the Apache License, Version 2.0 (the + * # "License"); you may not use this file except in compliance + * # with the License. You may obtain a copy of the License at + * # + * # http://www.apache.org/licenses/LICENSE-2.0 + * # + * # Unless required by applicable law or agreed to in writing, software + * # distributed under the License is distributed on an "AS IS" BASIS, + * # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * # See the License for the specific language governing permissions and + * # limitations under the License. + */ + +package function + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/suite" + + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" +) + +func TestBedrockTextEmbeddingProvider(t *testing.T) { + suite.Run(t, new(BedrockTextEmbeddingProviderSuite)) +} + +type BedrockTextEmbeddingProviderSuite struct { + suite.Suite + schema *schemapb.CollectionSchema + providers []string +} + +func (s *BedrockTextEmbeddingProviderSuite) SetupTest() { + s.schema = &schemapb.CollectionSchema{ + Name: "test", + Fields: []*schemapb.FieldSchema{ + {FieldID: 100, Name: "int64", DataType: schemapb.DataType_Int64}, + {FieldID: 101, Name: "text", DataType: schemapb.DataType_VarChar}, + {FieldID: 102, Name: "vector", DataType: schemapb.DataType_FloatVector, + TypeParams: []*commonpb.KeyValuePair{ + {Key: "dim", Value: "4"}, + }, + }, + }, + } + s.providers = []string{BedrockProvider} +} + +func createBedrockProvider(schema *schemapb.FieldSchema, providerName string, dim int) (TextEmbeddingProvider, error) { + functionSchema := &schemapb.FunctionSchema{ + Name: "test", + Type: schemapb.FunctionType_Unknown, + InputFieldNames: []string{"text"}, + OutputFieldNames: []string{"vector"}, + InputFieldIds: []int64{101}, + OutputFieldIds: []int64{102}, + Params: []*commonpb.KeyValuePair{ + {Key: modelNameParamKey, Value: BedRockTitanTextEmbeddingsV2}, + {Key: apiKeyParamKey, Value: "mock"}, + {Key: dimParamKey, Value: "4"}, + }, + } + switch providerName { + case BedrockProvider: + return NewBedrockEmbeddingProvider(schema, functionSchema, &MockBedrockClient{dim: dim}) + default: + return nil, fmt.Errorf("Unknow provider") + } +} + +func (s *BedrockTextEmbeddingProviderSuite) TestEmbedding() { + for _, provderName := range s.providers { + provder, err := createBedrockProvider(s.schema.Fields[2], provderName, 4) + s.NoError(err) + { + data := []string{"sentence"} + ret, err2 := provder.CallEmbedding(data, false, InsertMode) + s.NoError(err2) + s.Equal(1, len(ret)) + s.Equal(4, len(ret[0])) + s.Equal([]float32{0.0, 0.1, 0.2, 0.3}, ret[0]) + } + { + data := []string{"sentence 1", "sentence 2", "sentence 3"} + ret, _ := provder.CallEmbedding(data, false, SearchMode) + s.Equal([][]float32{{0.0, 0.1, 0.2, 0.3}, {0.0, 0.1, 0.2, 0.3}, {0.0, 0.1, 0.2, 0.3}}, ret) + } + } +} + +func (s *BedrockTextEmbeddingProviderSuite) TestEmbeddingDimNotMatch() { + for _, provderName := range s.providers { + provder, err := createBedrockProvider(s.schema.Fields[2], provderName, 2) + s.NoError(err) + + // embedding dim not match + data := []string{"sentence", "sentence"} + _, err2 := provder.CallEmbedding(data, false, InsertMode) + s.Error(err2) + } +} diff --git a/internal/util/function/common.go b/internal/util/function/common.go new file mode 100644 index 0000000000000..a6c4fe4840b0e --- /dev/null +++ b/internal/util/function/common.go @@ -0,0 +1,99 @@ +/* + * # Licensed to the LF AI & Data foundation under one + * # or more contributor license agreements. See the NOTICE file + * # distributed with this work for additional information + * # regarding copyright ownership. The ASF licenses this file + * # to you under the Apache License, Version 2.0 (the + * # "License"); you may not use this file except in compliance + * # with the License. You may obtain a copy of the License at + * # + * # http://www.apache.org/licenses/LICENSE-2.0 + * # + * # Unless required by applicable law or agreed to in writing, software + * # distributed under the License is distributed on an "AS IS" BASIS, + * # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * # See the License for the specific language governing permissions and + * # limitations under the License. + */ + +package function + +import ( + "fmt" + "strconv" +) + +const ( + InsertMode string = "Insert" + SearchMode string = "Search" +) + +// common params +const ( + modelNameParamKey string = "model_name" + dimParamKey string = "dim" + embeddingURLParamKey string = "url" + apiKeyParamKey string = "api_key" +) + +// ali text embedding +const ( + TextEmbeddingV1 string = "text-embedding-v1" + TextEmbeddingV2 string = "text-embedding-v2" + TextEmbeddingV3 string = "text-embedding-v3" + + dashscopeApiKey string = "MILVUS_DASHSCOPE_API_KEY" +) + +// openai/azure text embedding + +const ( + TextEmbeddingAda002 string = "text-embedding-ada-002" + TextEmbedding3Small string = "text-embedding-3-small" + TextEmbedding3Large string = "text-embedding-3-large" + + openaiApiKey string = "MILVUSAI_OPENAI_API_KEY" + + azureOpenaiApiKey string = "MILVUSAI_AZURE_OPENAI_API_KEY" + azureOpenaiEndpoint string = "MILVUSAI_AZURE_OPENAI_ENDPOINT" + + userParamKey string = "user" +) + +// bedrock emebdding + +const ( + BedRockTitanTextEmbeddingsV2 string = "amazon.titan-embed-text-v2:0" + awsAccessKeyIdParamKey string = "aws_access_key_id" + awsSecretAccessKeyParamKey string = "aws_secret_access_key" + regionParamKey string = "regin" + normalizeParamKey string = "normalize" + + bedrockAccessKeyId string = "MILVUSAI_BEDROCK_ACCESS_KEY_ID" + bedrockSecretAccessKey string = "MILVUSAI_BEDROCK_SECRET_ACCESS_KEY" +) + +// vertexAI + +const ( + locationParamKey string = "location" + projectIDParamKey string = "projectid" + taskTypeParamKey string = "task" + + textEmbedding005 string = "text-embedding-005" + textMultilingualEmbedding002 string = "text-multilingual-embedding-002" + + vertexServiceAccountJSONEnv string = "MILVUSAI_GOOGLE_APPLICATION_CREDENTIALS" +) + +func parseAndCheckFieldDim(dimStr string, fieldDim int64, fieldName string) (int64, error) { + dim, err := strconv.ParseInt(dimStr, 10, 64) + if err != nil { + return 0, fmt.Errorf("dim [%s] is not int", dimStr) + } + + if dim != 0 && dim != fieldDim { + return 0, fmt.Errorf("Field %s's dim is [%d], but embedding's dim is [%d]", fieldName, fieldDim, dim) + } + return dim, nil +} diff --git a/internal/util/function/function.go b/internal/util/function/function.go index a9056af41298d..7c3bae8ca4833 100644 --- a/internal/util/function/function.go +++ b/internal/util/function/function.go @@ -35,6 +35,8 @@ func NewFunctionRunner(coll *schemapb.CollectionSchema, schema *schemapb.Functio switch schema.GetType() { case schemapb.FunctionType_BM25: return NewBM25FunctionRunner(coll, schema) + case schemapb.FunctionType_TextEmbedding: + return nil, nil default: return nil, fmt.Errorf("unknown functionRunner type %s", schema.GetType().String()) } diff --git a/internal/util/function/function_base.go b/internal/util/function/function_base.go new file mode 100644 index 0000000000000..aabcfdf5c0ea2 --- /dev/null +++ b/internal/util/function/function_base.go @@ -0,0 +1,57 @@ +/* + * # Licensed to the LF AI & Data foundation under one + * # or more contributor license agreements. See the NOTICE file + * # distributed with this work for additional information + * # regarding copyright ownership. The ASF licenses this file + * # to you under the Apache License, Version 2.0 (the + * # "License"); you may not use this file except in compliance + * # with the License. You may obtain a copy of the License at + * # + * # http://www.apache.org/licenses/LICENSE-2.0 + * # + * # Unless required by applicable law or agreed to in writing, software + * # distributed under the License is distributed on an "AS IS" BASIS, + * # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * # See the License for the specific language governing permissions and + * # limitations under the License. + */ + +package function + +import ( + "fmt" + + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" +) + +type FunctionBase struct { + schema *schemapb.FunctionSchema + outputFields []*schemapb.FieldSchema +} + +func NewFunctionBase(coll *schemapb.CollectionSchema, fSchema *schemapb.FunctionSchema) (*FunctionBase, error) { + var base FunctionBase + base.schema = fSchema + for _, fieldName := range fSchema.GetOutputFieldNames() { + for _, field := range coll.GetFields() { + if field.GetName() == fieldName { + base.outputFields = append(base.outputFields, field) + break + } + } + } + + if len(base.outputFields) != len(fSchema.GetOutputFieldNames()) { + return &base, fmt.Errorf("The collection [%s]'s information is wrong, function [%s]'s outputs does not match the schema", + coll.Name, fSchema.Name) + } + return &base, nil +} + +func (base *FunctionBase) GetSchema() *schemapb.FunctionSchema { + return base.schema +} + +func (base *FunctionBase) GetOutputFields() []*schemapb.FieldSchema { + return base.outputFields +} diff --git a/internal/util/function/function_executor.go b/internal/util/function/function_executor.go new file mode 100644 index 0000000000000..011380ae5a226 --- /dev/null +++ b/internal/util/function/function_executor.go @@ -0,0 +1,257 @@ +/* + * # Licensed to the LF AI & Data foundation under one + * # or more contributor license agreements. See the NOTICE file + * # distributed with this work for additional information + * # regarding copyright ownership. The ASF licenses this file + * # to you under the Apache License, Version 2.0 (the + * # "License"); you may not use this file except in compliance + * # with the License. You may obtain a copy of the License at + * # + * # http://www.apache.org/licenses/LICENSE-2.0 + * # + * # Unless required by applicable law or agreed to in writing, software + * # distributed under the License is distributed on an "AS IS" BASIS, + * # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * # See the License for the specific language governing permissions and + * # limitations under the License. + */ + +package function + +import ( + "fmt" + "sync" + + "google.golang.org/protobuf/proto" + + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/internal/proto/internalpb" + "github.com/milvus-io/milvus/internal/storage" + "github.com/milvus-io/milvus/pkg/mq/msgstream" + "github.com/milvus-io/milvus/pkg/util/merr" +) + +type Runner interface { + GetSchema() *schemapb.FunctionSchema + GetOutputFields() []*schemapb.FieldSchema + + MaxBatch() int + ProcessInsert(inputs []*schemapb.FieldData) ([]*schemapb.FieldData, error) + ProcessSearch(placeholderGroup *commonpb.PlaceholderGroup) (*commonpb.PlaceholderGroup, error) + ProcessBulkInsert(inputs []storage.FieldData) (map[storage.FieldID]storage.FieldData, error) +} + +type FunctionExecutor struct { + runners map[int64]Runner +} + +func createFunction(coll *schemapb.CollectionSchema, schema *schemapb.FunctionSchema) (Runner, error) { + switch schema.GetType() { + case schemapb.FunctionType_BM25: // ignore bm25 function + return nil, nil + case schemapb.FunctionType_TextEmbedding: + f, err := NewTextEmbeddingFunction(coll, schema) + if err != nil { + return nil, err + } + return f, nil + default: + return nil, fmt.Errorf("unknown functionRunner type %s", schema.GetType().String()) + } +} + +func CheckFunctions(schema *schemapb.CollectionSchema) error { + for _, fSchema := range schema.Functions { + if _, err := createFunction(schema, fSchema); err != nil { + return err + } + } + return nil +} + +func NewFunctionExecutor(schema *schemapb.CollectionSchema) (*FunctionExecutor, error) { + // If the function's outputs exists in outputIDs, then create the function + // when outputIDs is empty, create all functions + // Please note that currently we only support the function with one input and one output. + executor := &FunctionExecutor{ + runners: make(map[int64]Runner), + } + for _, fSchema := range schema.Functions { + if runner, err := createFunction(schema, fSchema); err != nil { + return nil, err + } else { + if runner != nil { + executor.runners[fSchema.GetOutputFieldIds()[0]] = runner + } + } + } + return executor, nil +} + +func (executor *FunctionExecutor) processSingleFunction(runner Runner, msg *msgstream.InsertMsg) ([]*schemapb.FieldData, error) { + inputs := make([]*schemapb.FieldData, 0, len(runner.GetSchema().GetInputFieldNames())) + for _, name := range runner.GetSchema().GetInputFieldNames() { + for _, field := range msg.FieldsData { + if field.GetFieldName() == name { + inputs = append(inputs, field) + } + } + } + if len(inputs) != len(runner.GetSchema().InputFieldIds) { + return nil, fmt.Errorf("Input field not found") + } + + outputs, err := runner.ProcessInsert(inputs) + if err != nil { + return nil, err + } + return outputs, nil +} + +func (executor *FunctionExecutor) ProcessInsert(msg *msgstream.InsertMsg) error { + numRows := msg.NumRows + for _, runner := range executor.runners { + if numRows > uint64(runner.MaxBatch()) { + return fmt.Errorf("numRows [%d] > function [%s]'s max batch [%d]", numRows, runner.GetSchema().Name, runner.MaxBatch()) + } + } + + outputs := make(chan []*schemapb.FieldData, len(executor.runners)) + errChan := make(chan error, len(executor.runners)) + var wg sync.WaitGroup + for _, runner := range executor.runners { + wg.Add(1) + go func(runner Runner) { + defer wg.Done() + data, err := executor.processSingleFunction(runner, msg) + if err != nil { + errChan <- err + return + } + outputs <- data + }(runner) + } + wg.Wait() + close(errChan) + close(outputs) + + // Collect all errors + var errs []error + for err := range errChan { + errs = append(errs, err) + } + if len(errs) > 0 { + return fmt.Errorf("multiple errors occurred: %v", errs) + } + + for output := range outputs { + msg.FieldsData = append(msg.FieldsData, output...) + } + return nil +} + +func (executor *FunctionExecutor) processSingleSearch(runner Runner, placeholderGroup []byte) ([]byte, error) { + pb := &commonpb.PlaceholderGroup{} + proto.Unmarshal(placeholderGroup, pb) + if len(pb.Placeholders) != 1 { + return nil, merr.WrapErrParameterInvalidMsg("No placeholders founded") + } + if pb.Placeholders[0].Type != commonpb.PlaceholderType_VarChar { + return placeholderGroup, nil + } + res, err := runner.ProcessSearch(pb) + if err != nil { + return nil, err + } + return proto.Marshal(res) +} + +func (executor *FunctionExecutor) prcessSearch(req *internalpb.SearchRequest) error { + runner, exist := executor.runners[req.FieldId] + if !exist { + return nil + } + if req.Nq > int64(runner.MaxBatch()) { + return fmt.Errorf("Nq [%d] > function [%s]'s max batch [%d]", req.Nq, runner.GetSchema().Name, runner.MaxBatch()) + } + if newHolder, err := executor.processSingleSearch(runner, req.GetPlaceholderGroup()); err != nil { + return err + } else { + req.PlaceholderGroup = newHolder + } + return nil +} + +func (executor *FunctionExecutor) prcessAdvanceSearch(req *internalpb.SearchRequest) error { + outputs := make(chan map[int64][]byte, len(req.GetSubReqs())) + errChan := make(chan error, len(req.GetSubReqs())) + var wg sync.WaitGroup + for idx, sub := range req.GetSubReqs() { + if runner, exist := executor.runners[sub.FieldId]; exist { + if sub.Nq > int64(runner.MaxBatch()) { + return fmt.Errorf("Nq [%d] > function [%s]'s max batch [%d]", sub.Nq, runner.GetSchema().Name, runner.MaxBatch()) + } + wg.Add(1) + go func(runner Runner, idx int64) { + defer wg.Done() + if newHolder, err := executor.processSingleSearch(runner, sub.GetPlaceholderGroup()); err != nil { + errChan <- err + } else { + outputs <- map[int64][]byte{idx: newHolder} + } + }(runner, int64(idx)) + } + } + wg.Wait() + close(errChan) + close(outputs) + for err := range errChan { + return err + } + + for output := range outputs { + for idx, holder := range output { + req.SubReqs[idx].PlaceholderGroup = holder + } + } + return nil +} + +func (executor *FunctionExecutor) ProcessSearch(req *internalpb.SearchRequest) error { + if !req.IsAdvanced { + return executor.prcessSearch(req) + } + return executor.prcessAdvanceSearch(req) +} + +func (executor *FunctionExecutor) processSingleBulkInsert(runner Runner, data *storage.InsertData) (map[storage.FieldID]storage.FieldData, error) { + inputs := make([]storage.FieldData, 0, len(runner.GetSchema().InputFieldIds)) + for idx, id := range runner.GetSchema().InputFieldIds { + field, exist := data.Data[id] + if !exist { + return nil, fmt.Errorf("Can not find input field: [%s]", runner.GetSchema().GetInputFieldNames()[idx]) + } + inputs = append(inputs, field) + } + + outputs, err := runner.ProcessBulkInsert(inputs) + if err != nil { + return nil, err + } + return outputs, nil +} + +func (executor *FunctionExecutor) ProcessBulkInsert(data *storage.InsertData) error { + // Since concurrency has already been used in the outer layer, only a serial logic access model is used here. + for _, runner := range executor.runners { + output, err := executor.processSingleBulkInsert(runner, data) + if err != nil { + return nil + } + for k, v := range output { + data.Data[k] = v + } + } + return nil +} diff --git a/internal/util/function/function_executor_test.go b/internal/util/function/function_executor_test.go new file mode 100644 index 0000000000000..a38360d343660 --- /dev/null +++ b/internal/util/function/function_executor_test.go @@ -0,0 +1,194 @@ +/* + * # Licensed to the LF AI & Data foundation under one + * # or more contributor license agreements. See the NOTICE file + * # distributed with this work for additional information + * # regarding copyright ownership. The ASF licenses this file + * # to you under the Apache License, Version 2.0 (the + * # "License"); you may not use this file except in compliance + * # with the License. You may obtain a copy of the License at + * # + * # http://www.apache.org/licenses/LICENSE-2.0 + * # + * # Unless required by applicable law or agreed to in writing, software + * # distributed under the License is distributed on an "AS IS" BASIS, + * # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * # See the License for the specific language governing permissions and + * # limitations under the License. + */ + +package function + +import ( + "encoding/json" + "fmt" + "io" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/suite" + + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" + "github.com/milvus-io/milvus-proto/go-api/v2/msgpb" + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/internal/models/openai" + "github.com/milvus-io/milvus/pkg/mq/msgstream" +) + +func TestFunctionExecutor(t *testing.T) { + suite.Run(t, new(FunctionExecutorSuite)) +} + +type FunctionExecutorSuite struct { + suite.Suite +} + +func (s *FunctionExecutorSuite) creataSchema(url string) *schemapb.CollectionSchema { + return &schemapb.CollectionSchema{ + Name: "test", + Fields: []*schemapb.FieldSchema{ + {FieldID: 100, Name: "int64", DataType: schemapb.DataType_Int64}, + {FieldID: 101, Name: "text", DataType: schemapb.DataType_VarChar}, + { + FieldID: 102, Name: "vector", DataType: schemapb.DataType_FloatVector, + TypeParams: []*commonpb.KeyValuePair{ + {Key: "dim", Value: "4"}, + }, + IsFunctionOutput: true, + }, + { + FieldID: 103, Name: "vector2", DataType: schemapb.DataType_FloatVector, + TypeParams: []*commonpb.KeyValuePair{ + {Key: "dim", Value: "8"}, + }, + IsFunctionOutput: true, + }, + }, + Functions: []*schemapb.FunctionSchema{ + { + Name: "test", + Type: schemapb.FunctionType_TextEmbedding, + InputFieldIds: []int64{101}, + InputFieldNames: []string{"text"}, + OutputFieldIds: []int64{102}, + OutputFieldNames: []string{"vector"}, + Params: []*commonpb.KeyValuePair{ + {Key: Provider, Value: OpenAIProvider}, + {Key: modelNameParamKey, Value: "text-embedding-ada-002"}, + {Key: apiKeyParamKey, Value: "mock"}, + {Key: embeddingURLParamKey, Value: url}, + {Key: dimParamKey, Value: "4"}, + }, + }, + { + Name: "test", + Type: schemapb.FunctionType_TextEmbedding, + InputFieldIds: []int64{101}, + InputFieldNames: []string{"text"}, + OutputFieldIds: []int64{103}, + OutputFieldNames: []string{"vector2"}, + Params: []*commonpb.KeyValuePair{ + {Key: Provider, Value: OpenAIProvider}, + {Key: modelNameParamKey, Value: "text-embedding-ada-002"}, + {Key: apiKeyParamKey, Value: "mock"}, + {Key: embeddingURLParamKey, Value: url}, + {Key: dimParamKey, Value: "8"}, + }, + }, + }, + } +} + +func (s *FunctionExecutorSuite) createMsg(texts []string) *msgstream.InsertMsg { + data := []*schemapb.FieldData{} + f := schemapb.FieldData{ + Type: schemapb.DataType_VarChar, + FieldId: 101, + FieldName: "text", + IsDynamic: false, + Field: &schemapb.FieldData_Scalars{ + Scalars: &schemapb.ScalarField{ + Data: &schemapb.ScalarField_StringData{ + StringData: &schemapb.StringArray{ + Data: texts, + }, + }, + }, + }, + } + data = append(data, &f) + + msg := msgstream.InsertMsg{ + InsertRequest: &msgpb.InsertRequest{ + FieldsData: data, + }, + } + return &msg +} + +func (s *FunctionExecutorSuite) createEmbedding(texts []string, dim int) [][]float32 { + embeddings := make([][]float32, 0) + for i := 0; i < len(texts); i++ { + f := float32(i) + emb := make([]float32, 0) + for j := 0; j < dim; j++ { + emb = append(emb, f+float32(j)*0.1) + } + embeddings = append(embeddings, emb) + } + return embeddings +} + +func (s *FunctionExecutorSuite) TestExecutor() { + ts := CreateOpenAIEmbeddingServer() + defer ts.Close() + schema := s.creataSchema(ts.URL) + exec, err := NewFunctionExecutor(schema) + s.NoError(err) + msg := s.createMsg([]string{"sentence", "sentence"}) + exec.ProcessInsert(msg) + s.Equal(len(msg.FieldsData), 3) +} + +func (s *FunctionExecutorSuite) TestErrorEmbedding() { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var req openai.EmbeddingRequest + body, _ := io.ReadAll(r.Body) + defer r.Body.Close() + json.Unmarshal(body, &req) + + var res openai.EmbeddingResponse + res.Object = "list" + res.Model = "text-embedding-3-small" + for i := 0; i < len(req.Input); i++ { + res.Data = append(res.Data, openai.EmbeddingData{ + Object: "embedding", + Embedding: []float32{}, + Index: i, + }) + } + + res.Usage = openai.Usage{ + PromptTokens: 1, + TotalTokens: 100, + } + w.WriteHeader(http.StatusOK) + data, _ := json.Marshal(res) + w.Write(data) + })) + defer ts.Close() + schema := s.creataSchema(ts.URL) + exec, err := NewFunctionExecutor(schema) + fmt.Println(err) + s.NoError(err) + msg := s.createMsg([]string{"sentence", "sentence"}) + err = exec.ProcessInsert(msg) + s.Error(err) +} + +func (s *FunctionExecutorSuite) TestErrorSchema() { + schema := s.creataSchema("http://localhost") + schema.Functions[0].Type = schemapb.FunctionType_Unknown + _, err := NewFunctionExecutor(schema) + s.Error(err) +} diff --git a/internal/util/function/function_util.go b/internal/util/function/function_util.go new file mode 100644 index 0000000000000..240e13615b4f7 --- /dev/null +++ b/internal/util/function/function_util.go @@ -0,0 +1,61 @@ +/* + * # Licensed to the LF AI & Data foundation under one + * # or more contributor license agreements. See the NOTICE file + * # distributed with this work for additional information + * # regarding copyright ownership. The ASF licenses this file + * # to you under the Apache License, Version 2.0 (the + * # "License"); you may not use this file except in compliance + * # with the License. You may obtain a copy of the License at + * # + * # http://www.apache.org/licenses/LICENSE-2.0 + * # + * # Unless required by applicable law or agreed to in writing, software + * # distributed under the License is distributed on an "AS IS" BASIS, + * # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * # See the License for the specific language governing permissions and + * # limitations under the License. + */ + +package function + +import ( + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/pkg/util/merr" +) + +func HasFunctions(functions []*schemapb.FunctionSchema, outputIDs []int64) bool { + // Determine whether the column corresponding to outputIDs contains functions, except bm25 function, + // if outputIDs is empty, check all cols + for _, fSchema := range functions { + switch fSchema.GetType() { + case schemapb.FunctionType_BM25: + case schemapb.FunctionType_Unknown: + default: + if len(outputIDs) == 0 { + return true + } else { + for _, id := range outputIDs { + if fSchema.GetOutputFieldIds()[0] == id { + return true + } + } + } + } + } + return false +} + +func GetOutputIDFunctionsMap(functions []*schemapb.FunctionSchema) (map[int64]*schemapb.FunctionSchema, error) { + outputIdMap := map[int64]*schemapb.FunctionSchema{} + for _, fSchema := range functions { + switch fSchema.GetType() { + case schemapb.FunctionType_BM25: + default: + if len(fSchema.OutputFieldIds) != 1 { + return nil, merr.WrapErrParameterInvalidMsg("Function [%s]'s outputs err, only supports one outputs", fSchema.Name) + } + outputIdMap[fSchema.OutputFieldIds[0]] = fSchema + } + } + return outputIdMap, nil +} diff --git a/internal/util/function/mock_embedding_service.go b/internal/util/function/mock_embedding_service.go new file mode 100644 index 0000000000000..c071a2056df72 --- /dev/null +++ b/internal/util/function/mock_embedding_service.go @@ -0,0 +1,150 @@ +/* + * # Licensed to the LF AI & Data foundation under one + * # or more contributor license agreements. See the NOTICE file + * # distributed with this work for additional information + * # regarding copyright ownership. The ASF licenses this file + * # to you under the Apache License, Version 2.0 (the + * # "License"); you may not use this file except in compliance + * # with the License. You may obtain a copy of the License at + * # + * # http://www.apache.org/licenses/LICENSE-2.0 + * # + * # Unless required by applicable law or agreed to in writing, software + * # distributed under the License is distributed on an "AS IS" BASIS, + * # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * # See the License for the specific language governing permissions and + * # limitations under the License. + */ + +package function + +import ( + "context" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + + "github.com/milvus-io/milvus/internal/models/ali" + "github.com/milvus-io/milvus/internal/models/openai" + "github.com/milvus-io/milvus/internal/models/vertexai" + + "github.com/aws/aws-sdk-go-v2/service/bedrockruntime" +) + +func mockEmbedding(texts []string, dim int) [][]float32 { + embeddings := make([][]float32, 0) + for i := 0; i < len(texts); i++ { + f := float32(i) + emb := make([]float32, 0) + for j := 0; j < dim; j++ { + emb = append(emb, f+float32(j)*0.1) + } + embeddings = append(embeddings, emb) + } + return embeddings +} + +func CreateOpenAIEmbeddingServer() *httptest.Server { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var req openai.EmbeddingRequest + body, _ := io.ReadAll(r.Body) + defer r.Body.Close() + json.Unmarshal(body, &req) + embs := mockEmbedding(req.Input, req.Dimensions) + var res openai.EmbeddingResponse + res.Object = "list" + res.Model = "text-embedding-3-small" + for i := 0; i < len(req.Input); i++ { + res.Data = append(res.Data, openai.EmbeddingData{ + Object: "embedding", + Embedding: embs[i], + Index: i, + }) + } + + res.Usage = openai.Usage{ + PromptTokens: 1, + TotalTokens: 100, + } + w.WriteHeader(http.StatusOK) + data, _ := json.Marshal(res) + w.Write(data) + })) + return ts +} + +func CreateAliEmbeddingServer() *httptest.Server { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var req ali.EmbeddingRequest + body, _ := io.ReadAll(r.Body) + defer r.Body.Close() + json.Unmarshal(body, &req) + embs := mockEmbedding(req.Input.Texts, req.Parameters.Dimension) + var res ali.EmbeddingResponse + for i := 0; i < len(req.Input.Texts); i++ { + res.Output.Embeddings = append(res.Output.Embeddings, ali.Embeddings{ + Embedding: embs[i], + TextIndex: i, + }) + } + + res.Usage = ali.Usage{ + TotalTokens: 100, + } + w.WriteHeader(http.StatusOK) + data, _ := json.Marshal(res) + w.Write(data) + })) + return ts +} + +func CreateVertexAIEmbeddingServer() *httptest.Server { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var req vertexai.EmbeddingRequest + body, _ := io.ReadAll(r.Body) + defer r.Body.Close() + json.Unmarshal(body, &req) + var texts []string + for _, item := range req.Instances { + texts = append(texts, item.Content) + } + embs := mockEmbedding(texts, int(req.Parameters.OutputDimensionality)) + var res vertexai.EmbeddingResponse + for i := 0; i < len(req.Instances); i++ { + res.Predictions = append(res.Predictions, vertexai.Prediction{ + Embeddings: vertexai.Embeddings{ + Statistics: vertexai.Statistics{ + Truncated: false, + TokenCount: 10, + }, + Values: embs[i], + }, + }) + } + + res.Metadata = vertexai.Metadata{ + BillableCharacterCount: 100, + } + w.WriteHeader(http.StatusOK) + data, _ := json.Marshal(res) + w.Write(data) + })) + return ts +} + +type MockBedrockClient struct { + dim int +} + +func (c *MockBedrockClient) InvokeModel(ctx context.Context, params *bedrockruntime.InvokeModelInput, optFns ...func(*bedrockruntime.Options)) (*bedrockruntime.InvokeModelOutput, error) { + var req BedRockRequest + json.Unmarshal(params.Body, &req) + embs := mockEmbedding([]string{req.InputText}, c.dim) + + var resp BedRockResponse + resp.Embedding = embs[0] + resp.InputTextTokenCount = 2 + body, _ := json.Marshal(resp) + return &bedrockruntime.InvokeModelOutput{Body: body}, nil +} diff --git a/internal/util/function/openai_embedding_provider.go b/internal/util/function/openai_embedding_provider.go new file mode 100644 index 0000000000000..8b5f53bc7fd2c --- /dev/null +++ b/internal/util/function/openai_embedding_provider.go @@ -0,0 +1,179 @@ +/* + * # Licensed to the LF AI & Data foundation under one + * # or more contributor license agreements. See the NOTICE file + * # distributed with this work for additional information + * # regarding copyright ownership. The ASF licenses this file + * # to you under the Apache License, Version 2.0 (the + * # "License"); you may not use this file except in compliance + * # with the License. You may obtain a copy of the License at + * # + * # http://www.apache.org/licenses/LICENSE-2.0 + * # + * # Unless required by applicable law or agreed to in writing, software + * # distributed under the License is distributed on an "AS IS" BASIS, + * # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * # See the License for the specific language governing permissions and + * # limitations under the License. + */ + +package function + +import ( + "fmt" + "os" + "strings" + "time" + + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/internal/models/openai" + "github.com/milvus-io/milvus/pkg/util/typeutil" +) + +type OpenAIEmbeddingProvider struct { + fieldDim int64 + + client openai.OpenAIEmbeddingInterface + modelName string + embedDimParam int64 + user string + + maxBatch int + timeoutSec int +} + +func createOpenAIEmbeddingClient(apiKey string, url string) (*openai.OpenAIEmbeddingClient, error) { + if apiKey == "" { + apiKey = os.Getenv(openaiApiKey) + } + if apiKey == "" { + return nil, fmt.Errorf("Missing credentials. Please pass `api_key`, or configure the %s environment variable in the Milvus service.", openaiApiKey) + } + + if url == "" { + url = "https://api.openai.com/v1/embeddings" + } + + c := openai.NewOpenAIEmbeddingClient(apiKey, url) + return c, nil +} + +func createAzureOpenAIEmbeddingClient(apiKey string, url string) (*openai.AzureOpenAIEmbeddingClient, error) { + if apiKey == "" { + apiKey = os.Getenv(azureOpenaiApiKey) + } + if apiKey == "" { + return nil, fmt.Errorf("Missing credentials. Please pass `api_key`, or configure the %s environment variable in the Milvus service", azureOpenaiApiKey) + } + + if url == "" { + url = os.Getenv(azureOpenaiEndpoint) + } + if url == "" { + return nil, fmt.Errorf("Must provide `url` arguments or configure the %s environment variable in the Milvus service", azureOpenaiEndpoint) + } + c := openai.NewAzureOpenAIEmbeddingClient(apiKey, url) + return c, nil +} + +func newOpenAIEmbeddingProvider(fieldSchema *schemapb.FieldSchema, functionSchema *schemapb.FunctionSchema, isAzure bool) (*OpenAIEmbeddingProvider, error) { + fieldDim, err := typeutil.GetDim(fieldSchema) + if err != nil { + return nil, err + } + var apiKey, url, modelName, user string + var dim int64 + + for _, param := range functionSchema.Params { + switch strings.ToLower(param.Key) { + case modelNameParamKey: + modelName = param.Value + case dimParamKey: + dim, err = parseAndCheckFieldDim(param.Value, fieldDim, fieldSchema.Name) + if err != nil { + return nil, err + } + case userParamKey: + user = param.Value + case apiKeyParamKey: + apiKey = param.Value + case embeddingURLParamKey: + url = param.Value + default: + } + } + + var c openai.OpenAIEmbeddingInterface + if !isAzure { + if modelName != TextEmbeddingAda002 && modelName != TextEmbedding3Small && modelName != TextEmbedding3Large { + return nil, fmt.Errorf("Unsupported model: %s, only support [%s, %s, %s]", + modelName, TextEmbeddingAda002, TextEmbedding3Small, TextEmbedding3Large) + } + + c, err = createOpenAIEmbeddingClient(apiKey, url) + if err != nil { + return nil, err + } + } else { + c, err = createAzureOpenAIEmbeddingClient(apiKey, url) + if err != nil { + return nil, err + } + } + + provider := OpenAIEmbeddingProvider{ + client: c, + fieldDim: fieldDim, + modelName: modelName, + user: user, + embedDimParam: dim, + maxBatch: 128, + timeoutSec: 30, + } + return &provider, nil +} + +func NewOpenAIEmbeddingProvider(fieldSchema *schemapb.FieldSchema, functionSchema *schemapb.FunctionSchema) (*OpenAIEmbeddingProvider, error) { + return newOpenAIEmbeddingProvider(fieldSchema, functionSchema, false) +} + +func NewAzureOpenAIEmbeddingProvider(fieldSchema *schemapb.FieldSchema, functionSchema *schemapb.FunctionSchema) (*OpenAIEmbeddingProvider, error) { + return newOpenAIEmbeddingProvider(fieldSchema, functionSchema, true) +} + +func (provider *OpenAIEmbeddingProvider) MaxBatch() int { + return 5 * provider.maxBatch +} + +func (provider *OpenAIEmbeddingProvider) FieldDim() int64 { + return provider.fieldDim +} + +func (provider *OpenAIEmbeddingProvider) CallEmbedding(texts []string, batchLimit bool, _ string) ([][]float32, error) { + numRows := len(texts) + if batchLimit && numRows > provider.MaxBatch() { + return nil, fmt.Errorf("OpenAI embedding supports up to [%d] pieces of data at a time, got [%d]", provider.MaxBatch(), numRows) + } + + data := make([][]float32, 0, numRows) + for i := 0; i < numRows; i += provider.maxBatch { + end := i + provider.maxBatch + if end > numRows { + end = numRows + } + resp, err := provider.client.Embedding(provider.modelName, texts[i:end], int(provider.embedDimParam), provider.user, time.Duration(provider.timeoutSec)) + if err != nil { + return nil, err + } + if end-i != len(resp.Data) { + return nil, fmt.Errorf("Get embedding failed. The number of texts and embeddings does not match text:[%d], embedding:[%d]", end-i, len(resp.Data)) + } + for _, item := range resp.Data { + if len(item.Embedding) != int(provider.fieldDim) { + return nil, fmt.Errorf("The required embedding dim is [%d], but the embedding obtained from the model is [%d]", + provider.fieldDim, len(item.Embedding)) + } + data = append(data, item.Embedding) + } + } + return data, nil +} diff --git a/internal/util/function/openai_text_embedding_provider_test.go b/internal/util/function/openai_text_embedding_provider_test.go new file mode 100644 index 0000000000000..09b120e0603d5 --- /dev/null +++ b/internal/util/function/openai_text_embedding_provider_test.go @@ -0,0 +1,177 @@ +/* + * # Licensed to the LF AI & Data foundation under one + * # or more contributor license agreements. See the NOTICE file + * # distributed with this work for additional information + * # regarding copyright ownership. The ASF licenses this file + * # to you under the Apache License, Version 2.0 (the + * # "License"); you may not use this file except in compliance + * # with the License. You may obtain a copy of the License at + * # + * # http://www.apache.org/licenses/LICENSE-2.0 + * # + * # Unless required by applicable law or agreed to in writing, software + * # distributed under the License is distributed on an "AS IS" BASIS, + * # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * # See the License for the specific language governing permissions and + * # limitations under the License. + */ + +package function + +import ( + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "testing" + + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/internal/models/openai" + + "github.com/stretchr/testify/suite" +) + +func TestOpenAITextEmbeddingProvider(t *testing.T) { + suite.Run(t, new(OpenAITextEmbeddingProviderSuite)) +} + +type OpenAITextEmbeddingProviderSuite struct { + suite.Suite + schema *schemapb.CollectionSchema + providers []string +} + +func (s *OpenAITextEmbeddingProviderSuite) SetupTest() { + s.schema = &schemapb.CollectionSchema{ + Name: "test", + Fields: []*schemapb.FieldSchema{ + {FieldID: 100, Name: "int64", DataType: schemapb.DataType_Int64}, + {FieldID: 101, Name: "text", DataType: schemapb.DataType_VarChar}, + { + FieldID: 102, Name: "vector", DataType: schemapb.DataType_FloatVector, + TypeParams: []*commonpb.KeyValuePair{ + {Key: "dim", Value: "4"}, + }, + }, + }, + } + s.providers = []string{OpenAIProvider, AzureOpenAIProvider} +} + +func createOpenAIProvider(url string, schema *schemapb.FieldSchema, providerName string) (TextEmbeddingProvider, error) { + functionSchema := &schemapb.FunctionSchema{ + Name: "test", + Type: schemapb.FunctionType_Unknown, + InputFieldNames: []string{"text"}, + OutputFieldNames: []string{"vector"}, + InputFieldIds: []int64{101}, + OutputFieldIds: []int64{102}, + Params: []*commonpb.KeyValuePair{ + {Key: modelNameParamKey, Value: "text-embedding-ada-002"}, + {Key: apiKeyParamKey, Value: "mock"}, + {Key: dimParamKey, Value: "4"}, + {Key: embeddingURLParamKey, Value: url}, + }, + } + switch providerName { + case OpenAIProvider: + return NewOpenAIEmbeddingProvider(schema, functionSchema) + case AzureOpenAIProvider: + return NewAzureOpenAIEmbeddingProvider(schema, functionSchema) + default: + return nil, fmt.Errorf("Unknow provider") + } +} + +func (s *OpenAITextEmbeddingProviderSuite) TestEmbedding() { + ts := CreateOpenAIEmbeddingServer() + + defer ts.Close() + for _, provderName := range s.providers { + provder, err := createOpenAIProvider(ts.URL, s.schema.Fields[2], provderName) + s.NoError(err) + { + data := []string{"sentence"} + ret, err2 := provder.CallEmbedding(data, false, InsertMode) + s.NoError(err2) + s.Equal(1, len(ret)) + s.Equal(4, len(ret[0])) + s.Equal([]float32{0.0, 0.1, 0.2, 0.3}, ret[0]) + } + { + data := []string{"sentence 1", "sentence 2", "sentence 3"} + ret, _ := provder.CallEmbedding(data, false, SearchMode) + s.Equal([][]float32{{0.0, 0.1, 0.2, 0.3}, {1.0, 1.1, 1.2, 1.3}, {2.0, 2.1, 2.2, 2.3}}, ret) + } + } +} + +func (s *OpenAITextEmbeddingProviderSuite) TestEmbeddingDimNotMatch() { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var res openai.EmbeddingResponse + res.Object = "list" + res.Model = "text-embedding-3-small" + res.Data = append(res.Data, openai.EmbeddingData{ + Object: "embedding", + Embedding: []float32{1.0, 1.0, 1.0, 1.0}, + Index: 0, + }) + + res.Data = append(res.Data, openai.EmbeddingData{ + Object: "embedding", + Embedding: []float32{1.0, 1.0}, + Index: 1, + }) + res.Usage = openai.Usage{ + PromptTokens: 1, + TotalTokens: 100, + } + w.WriteHeader(http.StatusOK) + data, _ := json.Marshal(res) + w.Write(data) + })) + + defer ts.Close() + for _, provderName := range s.providers { + provder, err := createOpenAIProvider(ts.URL, s.schema.Fields[2], provderName) + s.NoError(err) + + // embedding dim not match + data := []string{"sentence", "sentence"} + _, err2 := provder.CallEmbedding(data, false, InsertMode) + s.Error(err2) + } +} + +func (s *OpenAITextEmbeddingProviderSuite) TestEmbeddingNubmerNotMatch() { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var res openai.EmbeddingResponse + res.Object = "list" + res.Model = "text-embedding-3-small" + res.Data = append(res.Data, openai.EmbeddingData{ + Object: "embedding", + Embedding: []float32{1.0, 1.0, 1.0, 1.0}, + Index: 0, + }) + res.Usage = openai.Usage{ + PromptTokens: 1, + TotalTokens: 100, + } + w.WriteHeader(http.StatusOK) + data, _ := json.Marshal(res) + w.Write(data) + })) + + defer ts.Close() + for _, provderName := range s.providers { + provder, err := createOpenAIProvider(ts.URL, s.schema.Fields[2], provderName) + + s.NoError(err) + + // embedding dim not match + data := []string{"sentence", "sentence2"} + _, err2 := provder.CallEmbedding(data, false, InsertMode) + s.Error(err2) + } +} diff --git a/internal/util/function/text_embedding_function.go b/internal/util/function/text_embedding_function.go new file mode 100644 index 0000000000000..030679df812fb --- /dev/null +++ b/internal/util/function/text_embedding_function.go @@ -0,0 +1,222 @@ +/* + * # Licensed to the LF AI & Data foundation under one + * # or more contributor license agreements. See the NOTICE file + * # distributed with this work for additional information + * # regarding copyright ownership. The ASF licenses this file + * # to you under the Apache License, Version 2.0 (the + * # "License"); you may not use this file except in compliance + * # with the License. You may obtain a copy of the License at + * # + * # http://www.apache.org/licenses/LICENSE-2.0 + * # + * # Unless required by applicable law or agreed to in writing, software + * # distributed under the License is distributed on an "AS IS" BASIS, + * # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * # See the License for the specific language governing permissions and + * # limitations under the License. + */ + +package function + +import ( + "fmt" + "strings" + + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/internal/storage" + "github.com/milvus-io/milvus/pkg/util/funcutil" +) + +const ( + Provider string = "provider" +) + +const ( + OpenAIProvider string = "openai" + AzureOpenAIProvider string = "azure_openai" + AliDashScopeProvider string = "dashscope" + BedrockProvider string = "bedrock" + VertexAIProvider string = "vertexai" +) + +// Text embedding for retrieval task +type TextEmbeddingProvider interface { + MaxBatch() int + CallEmbedding(texts []string, batchLimit bool, mode string) ([][]float32, error) + FieldDim() int64 +} + +func getProvider(functionSchema *schemapb.FunctionSchema) (string, error) { + for _, param := range functionSchema.Params { + switch strings.ToLower(param.Key) { + case Provider: + return strings.ToLower(param.Value), nil + default: + } + } + return "", fmt.Errorf("The text embedding service provider parameter:[%s] was not found", Provider) +} + +type TextEmebddingFunction struct { + FunctionBase + + embProvider TextEmbeddingProvider +} + +func NewTextEmbeddingFunction(coll *schemapb.CollectionSchema, functionSchema *schemapb.FunctionSchema) (*TextEmebddingFunction, error) { + if len(functionSchema.GetOutputFieldNames()) != 1 { + return nil, fmt.Errorf("Text function should only have one output field, but now is %d", len(functionSchema.GetOutputFieldNames())) + } + + base, err := NewFunctionBase(coll, functionSchema) + if err != nil { + return nil, err + } + + if base.outputFields[0].DataType != schemapb.DataType_FloatVector { + return nil, fmt.Errorf("Text embedding function's output field not match, needs [%s], got [%s]", + schemapb.DataType_name[int32(schemapb.DataType_FloatVector)], + schemapb.DataType_name[int32(base.outputFields[0].DataType)]) + } + + provider, err := getProvider(functionSchema) + if err != nil { + return nil, err + } + switch provider { + case OpenAIProvider: + embP, err := NewOpenAIEmbeddingProvider(base.outputFields[0], functionSchema) + if err != nil { + return nil, err + } + return &TextEmebddingFunction{ + FunctionBase: *base, + embProvider: embP, + }, nil + case AzureOpenAIProvider: + embP, err := NewAzureOpenAIEmbeddingProvider(base.outputFields[0], functionSchema) + if err != nil { + return nil, err + } + return &TextEmebddingFunction{ + FunctionBase: *base, + embProvider: embP, + }, nil + case BedrockProvider: + embP, err := NewBedrockEmbeddingProvider(base.outputFields[0], functionSchema, nil) + if err != nil { + return nil, err + } + return &TextEmebddingFunction{ + FunctionBase: *base, + embProvider: embP, + }, nil + case AliDashScopeProvider: + embP, err := NewAliDashScopeEmbeddingProvider(base.outputFields[0], functionSchema) + if err != nil { + return nil, err + } + return &TextEmebddingFunction{ + FunctionBase: *base, + embProvider: embP, + }, nil + case VertexAIProvider: + embP, err := NewVertextAIEmbeddingProvider(base.outputFields[0], functionSchema, nil) + if err != nil { + return nil, err + } + return &TextEmebddingFunction{ + FunctionBase: *base, + embProvider: embP, + }, nil + default: + return nil, fmt.Errorf("Unsupported embedding service provider: [%s] , list of supported [%s, %s, %s, %s]", provider, OpenAIProvider, AzureOpenAIProvider, AliDashScopeProvider, BedrockProvider) + } +} + +func (runner *TextEmebddingFunction) MaxBatch() int { + return runner.embProvider.MaxBatch() +} + +func (runner *TextEmebddingFunction) ProcessInsert(inputs []*schemapb.FieldData) ([]*schemapb.FieldData, error) { + if len(inputs) != 1 { + return nil, fmt.Errorf("Text embedding function only receives one input, bug got [%d]", len(inputs)) + } + + if inputs[0].Type != schemapb.DataType_VarChar { + return nil, fmt.Errorf("Text embedding only supports varchar field as input field, but got %s", schemapb.DataType_name[int32(inputs[0].Type)]) + } + + texts := inputs[0].GetScalars().GetStringData().GetData() + if texts == nil { + return nil, fmt.Errorf("Input texts is empty") + } + + embds, err := runner.embProvider.CallEmbedding(texts, true, InsertMode) + if err != nil { + return nil, err + } + data := make([]float32, 0, len(texts)*int(runner.embProvider.FieldDim())) + for _, emb := range embds { + data = append(data, emb...) + } + + var outputField schemapb.FieldData + outputField.FieldId = runner.outputFields[0].FieldID + outputField.FieldName = runner.outputFields[0].Name + outputField.Type = runner.outputFields[0].DataType + outputField.IsDynamic = runner.outputFields[0].IsDynamic + outputField.Field = &schemapb.FieldData_Vectors{ + Vectors: &schemapb.VectorField{ + Data: &schemapb.VectorField_FloatVector{ + FloatVector: &schemapb.FloatArray{ + Data: data, + }, + }, + Dim: runner.embProvider.FieldDim(), + }, + } + return []*schemapb.FieldData{&outputField}, nil +} + +func (runner *TextEmebddingFunction) ProcessSearch(placeholderGroup *commonpb.PlaceholderGroup) (*commonpb.PlaceholderGroup, error) { + texts := funcutil.GetVarCharFromPlaceholder(placeholderGroup.Placeholders[0]) // Already checked externally + embds, err := runner.embProvider.CallEmbedding(texts, true, SearchMode) + if err != nil { + return nil, err + } + return funcutil.Float32VectorsToPlaceholderGroup(embds), nil +} + +func (runner *TextEmebddingFunction) ProcessBulkInsert(inputs []storage.FieldData) (map[storage.FieldID]storage.FieldData, error) { + if len(inputs) != 1 { + return nil, fmt.Errorf("TextEmbedding function only receives one input, bug got [%d]", len(inputs)) + } + + if inputs[0].GetDataType() != schemapb.DataType_VarChar { + return nil, fmt.Errorf(" only supports varchar field, the input is not varchar") + } + + texts, ok := inputs[0].GetDataRows().([]string) + if !ok { + return nil, fmt.Errorf("Input texts is empty") + } + + embds, err := runner.embProvider.CallEmbedding(texts, false, InsertMode) + if err != nil { + return nil, err + } + data := make([]float32, 0, len(texts)*int(runner.embProvider.FieldDim())) + for _, emb := range embds { + data = append(data, emb...) + } + + field := &storage.FloatVectorFieldData{ + Data: data, + Dim: int(runner.embProvider.FieldDim()), + } + return map[storage.FieldID]storage.FieldData{ + runner.outputFields[0].FieldID: field, + }, nil +} diff --git a/internal/util/function/text_embedding_function_test.go b/internal/util/function/text_embedding_function_test.go new file mode 100644 index 0000000000000..353684e55b838 --- /dev/null +++ b/internal/util/function/text_embedding_function_test.go @@ -0,0 +1,311 @@ +/* + * # Licensed to the LF AI & Data foundation under one + * # or more contributor license agreements. See the NOTICE file + * # distributed with this work for additional information + * # regarding copyright ownership. The ASF licenses this file + * # to you under the Apache License, Version 2.0 (the + * # "License"); you may not use this file except in compliance + * # with the License. You may obtain a copy of the License at + * # + * # http://www.apache.org/licenses/LICENSE-2.0 + * # + * # Unless required by applicable law or agreed to in writing, software + * # distributed under the License is distributed on an "AS IS" BASIS, + * # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * # See the License for the specific language governing permissions and + * # limitations under the License. + */ + +package function + +import ( + "testing" + + "github.com/stretchr/testify/suite" + + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" +) + +func TestTextEmbeddingFunction(t *testing.T) { + suite.Run(t, new(TextEmbeddingFunctionSuite)) +} + +type TextEmbeddingFunctionSuite struct { + suite.Suite + schema *schemapb.CollectionSchema +} + +func (s *TextEmbeddingFunctionSuite) SetupTest() { + s.schema = &schemapb.CollectionSchema{ + Name: "test", + Fields: []*schemapb.FieldSchema{ + {FieldID: 100, Name: "int64", DataType: schemapb.DataType_Int64}, + {FieldID: 101, Name: "text", DataType: schemapb.DataType_VarChar}, + { + FieldID: 102, Name: "vector", DataType: schemapb.DataType_FloatVector, + TypeParams: []*commonpb.KeyValuePair{ + {Key: "dim", Value: "4"}, + }, + }, + }, + } +} + +func createData(texts []string) []*schemapb.FieldData { + data := []*schemapb.FieldData{} + f := schemapb.FieldData{ + Type: schemapb.DataType_VarChar, + FieldId: 101, + IsDynamic: false, + Field: &schemapb.FieldData_Scalars{ + Scalars: &schemapb.ScalarField{ + Data: &schemapb.ScalarField_StringData{ + StringData: &schemapb.StringArray{ + Data: texts, + }, + }, + }, + }, + } + data = append(data, &f) + return data +} + +func (s *TextEmbeddingFunctionSuite) TestOpenAIEmbedding() { + ts := CreateOpenAIEmbeddingServer() + defer ts.Close() + { + runner, err := NewTextEmbeddingFunction(s.schema, &schemapb.FunctionSchema{ + Name: "test", + Type: schemapb.FunctionType_Unknown, + InputFieldNames: []string{"text"}, + OutputFieldNames: []string{"vector"}, + InputFieldIds: []int64{101}, + OutputFieldIds: []int64{102}, + Params: []*commonpb.KeyValuePair{ + {Key: Provider, Value: OpenAIProvider}, + {Key: modelNameParamKey, Value: "text-embedding-ada-002"}, + {Key: dimParamKey, Value: "4"}, + {Key: apiKeyParamKey, Value: "mock"}, + {Key: embeddingURLParamKey, Value: ts.URL}, + }, + }) + s.NoError(err) + + { + data := createData([]string{"sentence"}) + ret, err2 := runner.ProcessInsert(data) + s.NoError(err2) + s.Equal(1, len(ret)) + s.Equal(int64(4), ret[0].GetVectors().Dim) + s.Equal([]float32{0.0, 0.1, 0.2, 0.3}, ret[0].GetVectors().GetFloatVector().Data) + } + { + data := createData([]string{"sentence 1", "sentence 2", "sentence 3"}) + ret, _ := runner.ProcessInsert(data) + s.Equal([]float32{0.0, 0.1, 0.2, 0.3, 1.0, 1.1, 1.2, 1.3, 2.0, 2.1, 2.2, 2.3}, ret[0].GetVectors().GetFloatVector().Data) + } + } + { + + runner, err := NewTextEmbeddingFunction(s.schema, &schemapb.FunctionSchema{ + Name: "test", + Type: schemapb.FunctionType_Unknown, + InputFieldNames: []string{"text"}, + OutputFieldNames: []string{"vector"}, + InputFieldIds: []int64{101}, + OutputFieldIds: []int64{102}, + Params: []*commonpb.KeyValuePair{ + {Key: Provider, Value: AzureOpenAIProvider}, + {Key: modelNameParamKey, Value: "text-embedding-ada-002"}, + {Key: dimParamKey, Value: "4"}, + {Key: apiKeyParamKey, Value: "mock"}, + {Key: embeddingURLParamKey, Value: ts.URL}, + }, + }) + s.NoError(err) + + { + data := createData([]string{"sentence"}) + ret, err2 := runner.ProcessInsert(data) + s.NoError(err2) + s.Equal(1, len(ret)) + s.Equal(int64(4), ret[0].GetVectors().Dim) + s.Equal([]float32{0.0, 0.1, 0.2, 0.3}, ret[0].GetVectors().GetFloatVector().Data) + } + { + data := createData([]string{"sentence 1", "sentence 2", "sentence 3"}) + ret, _ := runner.ProcessInsert(data) + s.Equal([]float32{0.0, 0.1, 0.2, 0.3, 1.0, 1.1, 1.2, 1.3, 2.0, 2.1, 2.2, 2.3}, ret[0].GetVectors().GetFloatVector().Data) + } + } +} + +func (s *TextEmbeddingFunctionSuite) TestAliEmbedding() { + ts := CreateAliEmbeddingServer() + defer ts.Close() + + runner, err := NewTextEmbeddingFunction(s.schema, &schemapb.FunctionSchema{ + Name: "test", + Type: schemapb.FunctionType_Unknown, + InputFieldNames: []string{"text"}, + OutputFieldNames: []string{"vector"}, + InputFieldIds: []int64{101}, + OutputFieldIds: []int64{102}, + Params: []*commonpb.KeyValuePair{ + {Key: Provider, Value: AliDashScopeProvider}, + {Key: modelNameParamKey, Value: TextEmbeddingV3}, + {Key: dimParamKey, Value: "4"}, + {Key: apiKeyParamKey, Value: "mock"}, + {Key: embeddingURLParamKey, Value: ts.URL}, + }, + }) + s.NoError(err) + + { + data := createData([]string{"sentence"}) + ret, err2 := runner.ProcessInsert(data) + s.NoError(err2) + s.Equal(1, len(ret)) + s.Equal(int64(4), ret[0].GetVectors().Dim) + s.Equal([]float32{0.0, 0.1, 0.2, 0.3}, ret[0].GetVectors().GetFloatVector().Data) + } + { + data := createData([]string{"sentence 1", "sentence 2", "sentence 3"}) + ret, _ := runner.ProcessInsert(data) + s.Equal([]float32{0.0, 0.1, 0.2, 0.3, 1.0, 1.1, 1.2, 1.3, 2.0, 2.1, 2.2, 2.3}, ret[0].GetVectors().GetFloatVector().Data) + } +} + +func (s *TextEmbeddingFunctionSuite) TestRunnerParamsErr() { + // outputfield datatype mismatch + { + schema := &schemapb.CollectionSchema{ + Name: "test", + Fields: []*schemapb.FieldSchema{ + {FieldID: 100, Name: "int64", DataType: schemapb.DataType_Int64}, + {FieldID: 101, Name: "text", DataType: schemapb.DataType_VarChar}, + { + FieldID: 102, Name: "vector", DataType: schemapb.DataType_BFloat16Vector, + TypeParams: []*commonpb.KeyValuePair{ + {Key: "dim", Value: "4"}, + }, + }, + }, + } + + _, err := NewTextEmbeddingFunction(schema, &schemapb.FunctionSchema{ + Name: "test", + Type: schemapb.FunctionType_Unknown, + InputFieldNames: []string{"text"}, + OutputFieldNames: []string{"vector"}, + InputFieldIds: []int64{101}, + OutputFieldIds: []int64{102}, + Params: []*commonpb.KeyValuePair{ + {Key: Provider, Value: OpenAIProvider}, + {Key: modelNameParamKey, Value: "text-embedding-ada-002"}, + {Key: dimParamKey, Value: "4"}, + {Key: apiKeyParamKey, Value: "mock"}, + {Key: embeddingURLParamKey, Value: "mock"}, + }, + }) + s.Error(err) + } + + // outputfield number mismatc + { + schema := &schemapb.CollectionSchema{ + Name: "test", + Fields: []*schemapb.FieldSchema{ + {FieldID: 100, Name: "int64", DataType: schemapb.DataType_Int64}, + {FieldID: 101, Name: "text", DataType: schemapb.DataType_VarChar}, + { + FieldID: 102, Name: "vector", DataType: schemapb.DataType_FloatVector, + TypeParams: []*commonpb.KeyValuePair{ + {Key: "dim", Value: "4"}, + }, + }, + { + FieldID: 103, Name: "vector2", DataType: schemapb.DataType_FloatVector, + TypeParams: []*commonpb.KeyValuePair{ + {Key: "dim", Value: "4"}, + }, + }, + }, + } + _, err := NewTextEmbeddingFunction(schema, &schemapb.FunctionSchema{ + Name: "test", + Type: schemapb.FunctionType_Unknown, + InputFieldNames: []string{"text"}, + OutputFieldNames: []string{"vector", "vector2"}, + InputFieldIds: []int64{101}, + OutputFieldIds: []int64{102, 103}, + Params: []*commonpb.KeyValuePair{ + {Key: Provider, Value: OpenAIProvider}, + {Key: modelNameParamKey, Value: "text-embedding-ada-002"}, + {Key: dimParamKey, Value: "4"}, + {Key: apiKeyParamKey, Value: "mock"}, + {Key: embeddingURLParamKey, Value: "mock"}, + }, + }) + s.Error(err) + } + + // outputfield miss + { + _, err := NewTextEmbeddingFunction(s.schema, &schemapb.FunctionSchema{ + Name: "test", + Type: schemapb.FunctionType_Unknown, + InputFieldNames: []string{"text"}, + OutputFieldNames: []string{"vector2"}, + InputFieldIds: []int64{101}, + OutputFieldIds: []int64{103}, + Params: []*commonpb.KeyValuePair{ + {Key: Provider, Value: OpenAIProvider}, + {Key: modelNameParamKey, Value: "text-embedding-ada-002"}, + {Key: dimParamKey, Value: "4"}, + {Key: apiKeyParamKey, Value: "mock"}, + {Key: embeddingURLParamKey, Value: "mock"}, + }, + }) + s.Error(err) + } + + // error model name + { + _, err := NewTextEmbeddingFunction(s.schema, &schemapb.FunctionSchema{ + Name: "test", + Type: schemapb.FunctionType_Unknown, + InputFieldNames: []string{"text"}, + OutputFieldNames: []string{"vector"}, + InputFieldIds: []int64{101}, + OutputFieldIds: []int64{102}, + Params: []*commonpb.KeyValuePair{ + {Key: Provider, Value: OpenAIProvider}, + {Key: modelNameParamKey, Value: "text-embedding-ada-004"}, + {Key: dimParamKey, Value: "4"}, + {Key: apiKeyParamKey, Value: "mock"}, + {Key: embeddingURLParamKey, Value: "mock"}, + }, + }) + s.Error(err) + } + + // no openai api key + { + _, err := NewTextEmbeddingFunction(s.schema, &schemapb.FunctionSchema{ + Name: "test", + Type: schemapb.FunctionType_Unknown, + InputFieldNames: []string{"text"}, + OutputFieldNames: []string{"vector"}, + InputFieldIds: []int64{101}, + OutputFieldIds: []int64{102}, + Params: []*commonpb.KeyValuePair{ + {Key: Provider, Value: OpenAIProvider}, + {Key: modelNameParamKey, Value: "text-embedding-ada-003"}, + }, + }) + s.Error(err) + } +} diff --git a/internal/util/function/vertexai_embedding_provider.go b/internal/util/function/vertexai_embedding_provider.go new file mode 100644 index 0000000000000..8eb4ad52ff1f9 --- /dev/null +++ b/internal/util/function/vertexai_embedding_provider.go @@ -0,0 +1,216 @@ +/* + * # Licensed to the LF AI & Data foundation under one + * # or more contributor license agreements. See the NOTICE file + * # distributed with this work for additional information + * # regarding copyright ownership. The ASF licenses this file + * # to you under the Apache License, Version 2.0 (the + * # "License"); you may not use this file except in compliance + * # with the License. You may obtain a copy of the License at + * # + * # http://www.apache.org/licenses/LICENSE-2.0 + * # + * # Unless required by applicable law or agreed to in writing, software + * # distributed under the License is distributed on an "AS IS" BASIS, + * # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * # See the License for the specific language governing permissions and + * # limitations under the License. + */ + +package function + +import ( + "fmt" + "os" + "strings" + "sync" + "time" + + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/internal/models/vertexai" + "github.com/milvus-io/milvus/pkg/util/typeutil" +) + +type vertexAIJsonKey struct { + jsonKey []byte + once sync.Once + initErr error +} + +var vtxKey vertexAIJsonKey + +func getVertexAIJsonKey() ([]byte, error) { + vtxKey.once.Do(func() { + jsonKeyPath := os.Getenv(vertexServiceAccountJSONEnv) + jsonKey, err := os.ReadFile(jsonKeyPath) + if err != nil { + vtxKey.initErr = fmt.Errorf("Read service account json file failed, %v", err) + return + } + vtxKey.jsonKey = jsonKey + }) + return vtxKey.jsonKey, vtxKey.initErr +} + +const ( + vertexAIDocRetrival string = "DOC_RETRIEVAL" + vertexAICodeRetrival string = "CODE_RETRIEVAL" + vertexAISTS string = "STS" +) + +func checkTask(modelName string, task string) error { + if task != vertexAIDocRetrival && task != vertexAICodeRetrival && task != vertexAISTS { + return fmt.Errorf("Unsupport task %s, the supported list: [%s, %s, %s]", task, vertexAIDocRetrival, vertexAICodeRetrival, vertexAISTS) + } + if modelName == textMultilingualEmbedding002 && task == vertexAICodeRetrival { + return fmt.Errorf("Model %s doesn't support %s task", textMultilingualEmbedding002, vertexAICodeRetrival) + } + return nil +} + +type VertextAIEmbeddingProvider struct { + fieldDim int64 + + client *vertexai.VertexAIEmbedding + modelName string + embedDimParam int64 + task string + + maxBatch int + timeoutSec int +} + +func createVertextAIEmbeddingClient(url string) (*vertexai.VertexAIEmbedding, error) { + jsonKey, err := getVertexAIJsonKey() + if err != nil { + return nil, err + } + c := vertexai.NewVertexAIEmbedding(url, jsonKey, "https://www.googleapis.com/auth/cloud-platform", "") + return c, nil +} + +func NewVertextAIEmbeddingProvider(fieldSchema *schemapb.FieldSchema, functionSchema *schemapb.FunctionSchema, c *vertexai.VertexAIEmbedding) (*VertextAIEmbeddingProvider, error) { + fieldDim, err := typeutil.GetDim(fieldSchema) + if err != nil { + return nil, err + } + var location, projectID, task, modelName string + var dim int64 + + for _, param := range functionSchema.Params { + switch strings.ToLower(param.Key) { + case modelNameParamKey: + modelName = param.Value + case dimParamKey: + dim, err = parseAndCheckFieldDim(param.Value, fieldDim, fieldSchema.Name) + if err != nil { + return nil, err + } + case locationParamKey: + location = param.Value + case projectIDParamKey: + projectID = param.Value + case taskTypeParamKey: + task = param.Value + default: + } + } + + if task == "" { + task = vertexAIDocRetrival + } + if err := checkTask(modelName, task); err != nil { + return nil, err + } + + if location == "" { + location = "us-central1" + } + + if modelName != textEmbedding005 && modelName != textMultilingualEmbedding002 { + return nil, fmt.Errorf("Unsupported model: %s, only support [%s, %s]", + modelName, textEmbedding005, textMultilingualEmbedding002) + } + url := fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:predict", location, projectID, location, modelName) + var client *vertexai.VertexAIEmbedding + if c == nil { + client, err = createVertextAIEmbeddingClient(url) + if err != nil { + return nil, err + } + } else { + client = c + } + + provider := VertextAIEmbeddingProvider{ + fieldDim: fieldDim, + client: client, + modelName: modelName, + embedDimParam: dim, + task: task, + maxBatch: 128, + timeoutSec: 30, + } + return &provider, nil +} + +func (provider *VertextAIEmbeddingProvider) MaxBatch() int { + return 5 * provider.maxBatch +} + +func (provider *VertextAIEmbeddingProvider) FieldDim() int64 { + return provider.fieldDim +} + +func (provider *VertextAIEmbeddingProvider) getTaskType(mode string) string { + if mode == SearchMode { + switch provider.task { + case vertexAIDocRetrival: + return "RETRIEVAL_QUERY" + case vertexAICodeRetrival: + return "CODE_RETRIEVAL_QUERY" + case vertexAISTS: + return "SEMANTIC_SIMILARITY" + } + } else { + switch provider.task { + case vertexAIDocRetrival: + return "RETRIEVAL_DOCUMENT" + case vertexAICodeRetrival: + return "RETRIEVAL_DOCUMENT" + case vertexAISTS: + return "SEMANTIC_SIMILARITY" + } + } + return "" +} + +func (provider *VertextAIEmbeddingProvider) CallEmbedding(texts []string, batchLimit bool, mode string) ([][]float32, error) { + numRows := len(texts) + if batchLimit && numRows > provider.MaxBatch() { + return nil, fmt.Errorf("VertextAI text embedding supports up to [%d] pieces of data at a time, got [%d]", provider.MaxBatch(), numRows) + } + + taskType := provider.getTaskType(mode) + data := make([][]float32, 0, numRows) + for i := 0; i < numRows; i += provider.maxBatch { + end := i + provider.maxBatch + if end > numRows { + end = numRows + } + resp, err := provider.client.Embedding(provider.modelName, texts[i:end], provider.embedDimParam, taskType, time.Duration(provider.timeoutSec)) + if err != nil { + return nil, err + } + if end-i != len(resp.Predictions) { + return nil, fmt.Errorf("Get embedding failed. The number of texts and embeddings does not match text:[%d], embedding:[%d]", end-i, len(resp.Predictions)) + } + for _, item := range resp.Predictions { + if len(item.Embeddings.Values) != int(provider.fieldDim) { + return nil, fmt.Errorf("The required embedding dim is [%d], but the embedding obtained from the model is [%d]", + provider.fieldDim, len(item.Embeddings.Values)) + } + data = append(data, item.Embeddings.Values) + } + } + return data, nil +} diff --git a/internal/util/function/vertexai_embedding_provider_test.go b/internal/util/function/vertexai_embedding_provider_test.go new file mode 100644 index 0000000000000..10a9093d69634 --- /dev/null +++ b/internal/util/function/vertexai_embedding_provider_test.go @@ -0,0 +1,173 @@ +/* + * # Licensed to the LF AI & Data foundation under one + * # or more contributor license agreements. See the NOTICE file + * # distributed with this work for additional information + * # regarding copyright ownership. The ASF licenses this file + * # to you under the Apache License, Version 2.0 (the + * # "License"); you may not use this file except in compliance + * # with the License. You may obtain a copy of the License at + * # + * # http://www.apache.org/licenses/LICENSE-2.0 + * # + * # Unless required by applicable law or agreed to in writing, software + * # distributed under the License is distributed on an "AS IS" BASIS, + * # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * # See the License for the specific language governing permissions and + * # limitations under the License. + */ + +package function + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/suite" + + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/internal/models/vertexai" +) + +func TestVertextAITextEmbeddingProvider(t *testing.T) { + suite.Run(t, new(VertextAITextEmbeddingProviderSuite)) +} + +type VertextAITextEmbeddingProviderSuite struct { + suite.Suite + schema *schemapb.CollectionSchema + providers []string +} + +func (s *VertextAITextEmbeddingProviderSuite) SetupTest() { + s.schema = &schemapb.CollectionSchema{ + Name: "test", + Fields: []*schemapb.FieldSchema{ + {FieldID: 100, Name: "int64", DataType: schemapb.DataType_Int64}, + {FieldID: 101, Name: "text", DataType: schemapb.DataType_VarChar}, + { + FieldID: 102, Name: "vector", DataType: schemapb.DataType_FloatVector, + TypeParams: []*commonpb.KeyValuePair{ + {Key: "dim", Value: "4"}, + }, + }, + }, + } +} + +func createVertextAIProvider(url string, schema *schemapb.FieldSchema) (TextEmbeddingProvider, error) { + functionSchema := &schemapb.FunctionSchema{ + Name: "test", + Type: schemapb.FunctionType_Unknown, + InputFieldNames: []string{"text"}, + OutputFieldNames: []string{"vector"}, + InputFieldIds: []int64{101}, + OutputFieldIds: []int64{102}, + Params: []*commonpb.KeyValuePair{ + {Key: modelNameParamKey, Value: textEmbedding005}, + {Key: locationParamKey, Value: "mock_local"}, + {Key: projectIDParamKey, Value: "mock_id"}, + {Key: taskTypeParamKey, Value: vertexAICodeRetrival}, + {Key: embeddingURLParamKey, Value: url}, + {Key: dimParamKey, Value: "4"}, + }, + } + mockClient := vertexai.NewVertexAIEmbedding(url, []byte{1, 2, 3}, "mock scope", "mock token") + return NewVertextAIEmbeddingProvider(schema, functionSchema, mockClient) +} + +func (s *VertextAITextEmbeddingProviderSuite) TestEmbedding() { + ts := CreateVertexAIEmbeddingServer() + + defer ts.Close() + provder, err := createVertextAIProvider(ts.URL, s.schema.Fields[2]) + s.NoError(err) + { + data := []string{"sentence"} + ret, err2 := provder.CallEmbedding(data, false, InsertMode) + s.NoError(err2) + s.Equal(1, len(ret)) + s.Equal(4, len(ret[0])) + s.Equal([]float32{0.0, 0.1, 0.2, 0.3}, ret[0]) + } + { + data := []string{"sentence 1", "sentence 2", "sentence 3"} + ret, _ := provder.CallEmbedding(data, false, SearchMode) + s.Equal([][]float32{{0.0, 0.1, 0.2, 0.3}, {1.0, 1.1, 1.2, 1.3}, {2.0, 2.1, 2.2, 2.3}}, ret) + } +} + +func (s *VertextAITextEmbeddingProviderSuite) TestEmbeddingDimNotMatch() { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var res vertexai.EmbeddingResponse + res.Predictions = append(res.Predictions, vertexai.Prediction{ + Embeddings: vertexai.Embeddings{ + Statistics: vertexai.Statistics{ + Truncated: false, + TokenCount: 10, + }, + Values: []float32{1.0, 1.0, 1.0, 1.0}, + }, + }) + res.Predictions = append(res.Predictions, vertexai.Prediction{ + Embeddings: vertexai.Embeddings{ + Statistics: vertexai.Statistics{ + Truncated: false, + TokenCount: 10, + }, + Values: []float32{1.0, 1.0}, + }, + }) + + res.Metadata = vertexai.Metadata{ + BillableCharacterCount: 100, + } + + w.WriteHeader(http.StatusOK) + data, _ := json.Marshal(res) + w.Write(data) + })) + + defer ts.Close() + provder, err := createVertextAIProvider(ts.URL, s.schema.Fields[2]) + s.NoError(err) + + // embedding dim not match + data := []string{"sentence", "sentence"} + _, err2 := provder.CallEmbedding(data, false, InsertMode) + s.Error(err2) +} + +func (s *VertextAITextEmbeddingProviderSuite) TestEmbeddingNubmerNotMatch() { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var res vertexai.EmbeddingResponse + res.Predictions = append(res.Predictions, vertexai.Prediction{ + Embeddings: vertexai.Embeddings{ + Statistics: vertexai.Statistics{ + Truncated: false, + TokenCount: 10, + }, + Values: []float32{1.0, 1.0, 1.0, 1.0}, + }, + }) + res.Metadata = vertexai.Metadata{ + BillableCharacterCount: 100, + } + + w.WriteHeader(http.StatusOK) + data, _ := json.Marshal(res) + w.Write(data) + })) + + defer ts.Close() + provder, err := createVertextAIProvider(ts.URL, s.schema.Fields[2]) + + s.NoError(err) + + // embedding dim not match + data := []string{"sentence", "sentence2"} + _, err2 := provder.CallEmbedding(data, false, InsertMode) + s.Error(err2) +} diff --git a/pkg/util/funcutil/placeholdergroup.go b/pkg/util/funcutil/placeholdergroup.go index 96ecdfa4df8fb..60aa1aa7ecbe4 100644 --- a/pkg/util/funcutil/placeholdergroup.go +++ b/pkg/util/funcutil/placeholdergroup.go @@ -25,6 +25,21 @@ func SparseVectorDataToPlaceholderGroupBytes(contents [][]byte) []byte { return bytes } +func Float32VectorsToPlaceholderGroup(embs [][]float32) *commonpb.PlaceholderGroup { + result := make([][]byte, 0, len(embs)) + for _, floatVector := range embs { + result = append(result, floatVectorToByteVector(floatVector)) + } + placeholderGroup := &commonpb.PlaceholderGroup{ + Placeholders: []*commonpb.PlaceholderValue{{ + Tag: "$0", + Type: commonpb.PlaceholderType_FloatVector, + Values: result, + }}, + } + return placeholderGroup +} + func FieldDataToPlaceholderGroupBytes(fieldData *schemapb.FieldData) ([]byte, error) { placeholderValue, err := fieldDataToPlaceholderValue(fieldData) if err != nil {