From 9705f0c5a423b3ba36ee29d1eb590084a77043e7 Mon Sep 17 00:00:00 2001
From: "junjie.jiang"
Date: Thu, 19 Sep 2024 17:50:17 +0800
Subject: [PATCH 01/18] Add openai embedding client

Signed-off-by: junjie.jiang
---
 internal/models/openai_embedding.go      | 192 +++++++++++++++++++++
 internal/models/openai_embedding_test.go | 185 ++++++++++++++++++++
 2 files changed, 377 insertions(+)
 create mode 100644 internal/models/openai_embedding.go
 create mode 100644 internal/models/openai_embedding_test.go

diff --git a/internal/models/openai_embedding.go b/internal/models/openai_embedding.go
new file mode 100644
index 0000000000000..dbb568377c648
--- /dev/null
+++ b/internal/models/openai_embedding.go
@@ -0,0 +1,192 @@
+// Licensed to the LF AI & Data foundation under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package models
+
+import (
+    "bytes"
+    "encoding/json"
+    "fmt"
+    "io"
+    "net/http"
+    "time"
+)
+
+const (
+    TextEmbeddingAda002 string = "text-embedding-ada-002"
+    TextEmbedding3Small string = "text-embedding-3-small"
+    TextEmbedding3Large string = "text-embedding-3-large"
+)
+
+
+type EmbeddingRequest struct {
+    // ID of the model to use.
+    Model string `json:"model"`
+
+    // Input text to embed, encoded as a string.
+    Input []string `json:"input"`
+
+    // A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse.
+    User string `json:"user,omitempty"`
+
+    // The format to return the embeddings in. Can be either float or base64.
+    EncodingFormat string `json:"encoding_format,omitempty"`
+
+    // The number of dimensions the resulting output embeddings should have. Only supported in text-embedding-3 and later models.
+    Dimensions int `json:"dimensions,omitempty"`
+}
+
+type Usage struct {
+    // The number of tokens used by the prompt.
+    PromptTokens int `json:"prompt_tokens"`
+
+    // The total number of tokens used by the request.
+    TotalTokens int `json:"total_tokens"`
+}
+
+
+type EmbeddingData struct {
+    // The object type, which is always "embedding".
+    Object string `json:"object"`
+
+    // The embedding vector, which is a list of floats.
+    Embedding []float32 `json:"embedding"`
+
+    // The index of the embedding in the list of embeddings.
+    Index int `json:"index"`
+}
+
+
+type EmbeddingResponse struct {
+    // The object type, which is always "list".
+    Object string `json:"object"`
+
+    // The list of embeddings generated by the model.
+    Data []EmbeddingData `json:"data"`
+
+    // The name of the model used to generate the embedding.
+    Model string `json:"model"`
+
+    // The usage information for the request.
+    Usage Usage `json:"usage"`
+}
+
+type ErrorInfo struct {
+    Code    string `json:"code"`
+    Message string `json:"message"`
+    Param   string `json:"param,omitempty"`
+    Type    string `json:"type"`
+}
+
+type EmbedddingError struct {
+    Error ErrorInfo `json:"error"`
+}
+
+type OpenAIEmbeddingClient struct {
+    api_key    string
+    uri        string
+    model_name string
+}
+
+func (c *OpenAIEmbeddingClient) Check() error {
+    if c.model_name != TextEmbeddingAda002 && c.model_name != TextEmbedding3Small && c.model_name != TextEmbedding3Large {
+        return fmt.Errorf("Unsupported model: %s, only support [%s, %s, %s]",
+            c.model_name, TextEmbeddingAda002, TextEmbedding3Small, TextEmbedding3Large)
+    }
+
+    if c.api_key == "" {
+        return fmt.Errorf("OpenAI api key is empty")
+    }
+
+    if c.uri == "" {
+        return fmt.Errorf("OpenAI embedding uri is empty")
+    }
+    return nil
+}
+
+
+func (c *OpenAIEmbeddingClient) send(client *http.Client, req *http.Request, res *EmbeddingResponse) error {
+    // call openai
+    resp, err := client.Do(req)
+
+    if err != nil {
+        return err
+    }
+    defer resp.Body.Close()
+
+    body, err := io.ReadAll(resp.Body)
+    if err != nil {
+        return err
+    }
+
+    if resp.StatusCode != 200 {
+        return fmt.Errorf(string(body))
+    }
+
+    err = json.Unmarshal(body, &res)
+    if err != nil {
+        return err
+    }
+    return nil
+}
+
+func (c *OpenAIEmbeddingClient) sendWithRetry(client *http.Client, req *http.Request,res *EmbeddingResponse, max_retries int) error {
+    var err error
+    for i := 0; i < max_retries; i++ {
+        err = c.send(client, req, res)
+        if err == nil {
+            return nil
+        }
+    }
+    return err
+}
+
+func (c *OpenAIEmbeddingClient) Embedding(texts []string, dim int, user string, timeout_sec time.Duration) (EmbeddingResponse, error) {
+    var r EmbeddingRequest
+    r.Model = c.model_name
+    r.Input = texts
+    r.EncodingFormat = "float"
+    if user != "" {
+        r.User = user
+    }
+    if dim != 0 {
+        r.Dimensions = dim
+    }
+
+    var res EmbeddingResponse
+    data, err := json.Marshal(r)
+    if err != nil {
+        return res, err
+    }
+
+    // call openai
+    if timeout_sec <= 0 {
+        timeout_sec = 30
+    }
+    client := &http.Client{
+        Timeout: timeout_sec * time.Second,
+    }
+    req, err := http.NewRequest("POST" , c.uri, bytes.NewBuffer(data))
+    if err != nil {
+        return res, err
+    }
+    req.Header.Set("Content-Type", "application/json")
+    req.Header.Set("api-key", c.api_key)
+
+    err = c.sendWithRetry(client, req, &res, 3)
+    return res, err
+
+}
diff --git a/internal/models/openai_embedding_test.go b/internal/models/openai_embedding_test.go
new file mode 100644
index 0000000000000..788e95e84f1cd
--- /dev/null
+++ b/internal/models/openai_embedding_test.go
@@ -0,0 +1,185 @@
+// Licensed to the LF AI & Data foundation under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package models
+
+import (
+    // "bytes"
+    "encoding/json"
+    "fmt"
+    "net/http"
+    "net/http/httptest"
+    "testing"
+    "time"
+
+    "github.com/stretchr/testify/assert"
+)
+
+func TestEmbeddingClientCheck(t *testing.T) {
+    {
+        c := OpenAIEmbeddingClient{"mock_key", "mock_uri", "unknow_model"}
+        err := c.Check();
+        assert.True(t, err != nil)
+        fmt.Println(err)
+    }
+
+    {
+        c := OpenAIEmbeddingClient{"", "mock_uri", TextEmbeddingAda002}
+        err := c.Check();
+        assert.True(t, err != nil)
+        fmt.Println(err)
+    }
+
+    {
+        c := OpenAIEmbeddingClient{"mock_key", "", TextEmbedding3Small}
+        err := c.Check();
+        assert.True(t, err != nil)
+        fmt.Println(err)
+    }
+
+    {
+        c := OpenAIEmbeddingClient{"mock_key", "mock_uri", TextEmbedding3Small}
+        err := c.Check();
+        assert.True(t, err == nil)
+    }
+}
+
+
+func TestEmbeddingOK(t *testing.T) {
+    var res EmbeddingResponse
+    res.Object = "list"
+    res.Model = TextEmbedding3Small
+    res.Data = []EmbeddingData{
+        {
+            Object:    "embedding",
+            Embedding: []float32{1.1, 2.2, 3.3, 4.4},
+            Index:     0,
+        },
+    }
+    res.Usage = Usage{
+        PromptTokens: 1,
+        TotalTokens:  100,
+    }
+
+    ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+        w.WriteHeader(http.StatusOK)
+        data, _ := json.Marshal(res)
+        w.Write(data)
+    }))
+
+    defer ts.Close()
+    url := ts.URL
+
+    {
+        c := OpenAIEmbeddingClient{"mock_key", url, TextEmbedding3Small}
+        err := c.Check();
+        assert.True(t, err == nil)
+        ret, err := c.Embedding([]string{"sentence"}, 0, "", 0)
+        assert.True(t, err == nil)
+        assert.Equal(t, ret, res)
+    }
+}
+
+
+func TestEmbeddingRetry(t *testing.T) {
+    var res EmbeddingResponse
+    res.Object = "list"
+    res.Model = TextEmbedding3Small
+    res.Data = []EmbeddingData{
+        {
+            Object:    "embedding",
+            Embedding: []float32{1.1, 2.2, 3.3, 4.4},
+            Index:     0,
+        },
+    }
+    res.Usage = Usage{
+        PromptTokens: 1,
+        TotalTokens:  100,
+    }
+
+    var count = 0
+
+    ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+        if count < 2 {
+            count += 1
+            w.WriteHeader(http.StatusUnauthorized)
+        } else {
+            w.WriteHeader(http.StatusOK)
+            data, _ := json.Marshal(res)
+            w.Write(data)
+        }
+    }))
+
+    defer ts.Close()
+    url := ts.URL
+
+    {
+        c := OpenAIEmbeddingClient{"mock_key", url, TextEmbedding3Small}
+        err := c.Check();
+        assert.True(t, err == nil)
+        ret, err := c.Embedding([]string{"sentence"}, 0, "", 0)
+        assert.True(t, err == nil)
+        assert.Equal(t, ret, res)
+        assert.Equal(t, count, 2)
+    }
+}
+
+
+func TestEmbeddingFailed(t *testing.T) {
+    var count = 0
+
+    ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+        count += 1
+        w.WriteHeader(http.StatusUnauthorized)
+    }))
+
+    defer ts.Close()
+    url := ts.URL
+
+    {
+        c := OpenAIEmbeddingClient{"mock_key", url, TextEmbedding3Small}
+        err := c.Check();
+        assert.True(t, err == nil)
+        _, err = c.Embedding([]string{"sentence"}, 0, "", 0)
+        assert.True(t, err != nil)
+        assert.Equal(t, count, 3)
+    }
+}
+
+func TestTimeout(t *testing.T) {
+    var st = "Doing"
+
+    ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+        time.Sleep(3 * time.Second)
+        st = "Done"
+        w.WriteHeader(http.StatusUnauthorized)
+
+    }))
+
+    defer ts.Close()
+    url := ts.URL
+
+    {
+        c := OpenAIEmbeddingClient{"mock_key", url, TextEmbedding3Small}
+        err := c.Check();
+        assert.True(t, err == nil)
+        _, err = c.Embedding([]string{"sentence"}, 0, "", 1)
+        assert.True(t, err != nil)
+        assert.Equal(t, st, "Doing")
+        time.Sleep(3 * time.Second)
+        assert.Equal(t, st, "Done")
+    }
+}
From 
b117d64d2d9e8c6cbd44f3761bc7874d73a4dd82 Mon Sep 17 00:00:00 2001 From: "junjie.jiang" Date: Mon, 23 Sep 2024 20:57:11 +0800 Subject: [PATCH 02/18] Add embedding runner Signed-off-by: junjie.jiang --- internal/models/openai_embedding.go | 70 ++-- internal/models/openai_embedding_test.go | 79 ++-- internal/util/function/function.go | 1 + internal/util/function/function_base.go | 67 ++++ .../function/openai_embedding_function.go | 205 +++++++++++ .../openai_embedding_function_test.go | 339 ++++++++++++++++++ 6 files changed, 694 insertions(+), 67 deletions(-) create mode 100644 internal/util/function/function_base.go create mode 100644 internal/util/function/openai_embedding_function.go create mode 100644 internal/util/function/openai_embedding_function_test.go diff --git a/internal/models/openai_embedding.go b/internal/models/openai_embedding.go index dbb568377c648..70e8e2508a14c 100644 --- a/internal/models/openai_embedding.go +++ b/internal/models/openai_embedding.go @@ -22,15 +22,10 @@ import ( "fmt" "io" "net/http" + "sort" "time" ) -const ( - TextEmbeddingAda002 string = "text-embedding-ada-002" - TextEmbedding3Small string = "text-embedding-3-small" - TextEmbedding3Large string = "text-embedding-3-large" -) - type EmbeddingRequest struct { // ID of the model to use. @@ -84,6 +79,16 @@ type EmbeddingResponse struct { Usage Usage `json:"usage"` } + +type ByIndex struct { + resp *EmbeddingResponse +} + +func (eb *ByIndex) Len() int { return len(eb.resp.Data) } +func (eb *ByIndex) Swap(i, j int) { eb.resp.Data[i], eb.resp.Data[j] = eb.resp.Data[j], eb.resp.Data[i] } +func (eb *ByIndex) Less(i, j int) bool { return eb.resp.Data[i].Index < eb.resp.Data[j].Index } + + type ErrorInfo struct { Code string `json:"code"` Message string `json:"message"` @@ -96,27 +101,28 @@ type EmbedddingError struct { } type OpenAIEmbeddingClient struct { - api_key string - uri string - model_name string + apiKey string + url string } func (c *OpenAIEmbeddingClient) Check() error { - if c.model_name != TextEmbeddingAda002 && c.model_name != TextEmbedding3Small && c.model_name != TextEmbedding3Large { - return fmt.Errorf("Unsupported model: %s, only support [%s, %s, %s]", - c.model_name, TextEmbeddingAda002, TextEmbedding3Small, TextEmbedding3Large) - } - - if c.api_key == "" { + if c.apiKey == "" { return fmt.Errorf("OpenAI api key is empty") } - if c.uri == "" { - return fmt.Errorf("OpenAI embedding uri is empty") + if c.url == "" { + return fmt.Errorf("OpenAI embedding url is empty") } return nil } +func NewOpenAIEmbeddingClient(apiKey string, url string) OpenAIEmbeddingClient{ + return OpenAIEmbeddingClient{ + apiKey: apiKey, + url: url, + } +} + func (c *OpenAIEmbeddingClient) send(client *http.Client, req *http.Request, res *EmbeddingResponse) error { // call openai @@ -143,9 +149,9 @@ func (c *OpenAIEmbeddingClient) send(client *http.Client, req *http.Request, res return nil } -func (c *OpenAIEmbeddingClient) sendWithRetry(client *http.Client, req *http.Request,res *EmbeddingResponse, max_retries int) error { +func (c *OpenAIEmbeddingClient) sendWithRetry(client *http.Client, req *http.Request,res *EmbeddingResponse, maxRetries int) error { var err error - for i := 0; i < max_retries; i++ { + for i := 0; i < maxRetries; i++ { err = c.send(client, req, res) if err == nil { return nil @@ -154,9 +160,9 @@ func (c *OpenAIEmbeddingClient) sendWithRetry(client *http.Client, req *http.Req return err } -func (c *OpenAIEmbeddingClient) Embedding(texts []string, dim int, user string, timeout_sec time.Duration) 
(EmbeddingResponse, error) { +func (c *OpenAIEmbeddingClient) Embedding(modelName string, texts []string, dim int, user string, timeoutSec time.Duration) (*EmbeddingResponse, error) { var r EmbeddingRequest - r.Model = c.model_name + r.Model = modelName r.Input = texts r.EncodingFormat = "float" if user != "" { @@ -166,27 +172,31 @@ func (c *OpenAIEmbeddingClient) Embedding(texts []string, dim int, user string, r.Dimensions = dim } - var res EmbeddingResponse data, err := json.Marshal(r) if err != nil { - return res, err + return nil, err } // call openai - if timeout_sec <= 0 { - timeout_sec = 30 + if timeoutSec <= 0 { + timeoutSec = 30 } client := &http.Client{ - Timeout: timeout_sec * time.Second, + Timeout: timeoutSec * time.Second, } - req, err := http.NewRequest("POST" , c.uri, bytes.NewBuffer(data)) + req, err := http.NewRequest("POST" , c.url, bytes.NewBuffer(data)) if err != nil { - return res, err + return nil, err } req.Header.Set("Content-Type", "application/json") - req.Header.Set("api-key", c.api_key) + req.Header.Set("api-key", c.apiKey) + var res EmbeddingResponse err = c.sendWithRetry(client, req, &res, 3) - return res, err + if err != nil { + return nil, err + } + sort.Sort(&ByIndex{&res}) + return &res, err } diff --git a/internal/models/openai_embedding_test.go b/internal/models/openai_embedding_test.go index 788e95e84f1cd..eb31b9c23dffc 100644 --- a/internal/models/openai_embedding_test.go +++ b/internal/models/openai_embedding_test.go @@ -17,41 +17,34 @@ package models import ( - // "bytes" "encoding/json" "fmt" "net/http" "net/http/httptest" "testing" "time" + "sync/atomic" "github.com/stretchr/testify/assert" ) func TestEmbeddingClientCheck(t *testing.T) { { - c := OpenAIEmbeddingClient{"mock_key", "mock_uri", "unknow_model"} + c := OpenAIEmbeddingClient{"", "mock_uri"} err := c.Check(); assert.True(t, err != nil) fmt.Println(err) } { - c := OpenAIEmbeddingClient{"", "mock_uri", TextEmbeddingAda002} + c := OpenAIEmbeddingClient{"mock_key", ""} err := c.Check(); assert.True(t, err != nil) fmt.Println(err) } { - c := OpenAIEmbeddingClient{"mock_key", "", TextEmbedding3Small} - err := c.Check(); - assert.True(t, err != nil) - fmt.Println(err) - } - - { - c := OpenAIEmbeddingClient{"mock_key", "mock_uri", TextEmbedding3Small} + c := OpenAIEmbeddingClient{"mock_key", "mock_uri"} err := c.Check(); assert.True(t, err == nil) } @@ -61,7 +54,7 @@ func TestEmbeddingClientCheck(t *testing.T) { func TestEmbeddingOK(t *testing.T) { var res EmbeddingResponse res.Object = "list" - res.Model = TextEmbedding3Small + res.Model = "text-embedding-3-small" res.Data = []EmbeddingData{ { Object: "embedding", @@ -84,12 +77,12 @@ func TestEmbeddingOK(t *testing.T) { url := ts.URL { - c := OpenAIEmbeddingClient{"mock_key", url, TextEmbedding3Small} + c := OpenAIEmbeddingClient{"mock_key", url} err := c.Check(); assert.True(t, err == nil) - ret, err := c.Embedding([]string{"sentence"}, 0, "", 0) + ret, err := c.Embedding("text-embedding-3-small", []string{"sentence"}, 0, "", 0) assert.True(t, err == nil) - assert.Equal(t, ret, res) + assert.Equal(t, ret, &res) } } @@ -97,24 +90,34 @@ func TestEmbeddingOK(t *testing.T) { func TestEmbeddingRetry(t *testing.T) { var res EmbeddingResponse res.Object = "list" - res.Model = TextEmbedding3Small + res.Model = "text-embedding-3-small" res.Data = []EmbeddingData{ + { + Object: "embedding", + Embedding: []float32{1.1, 2.2, 3.2, 4.5}, + Index: 2, + }, { Object: "embedding", Embedding: []float32{1.1, 2.2, 3.3, 4.4}, Index: 0, }, + { + Object: "embedding", 
+ Embedding: []float32{1.1, 2.2, 3.2, 4.3}, + Index: 1, + }, } res.Usage = Usage{ PromptTokens: 1, TotalTokens: 100, } - var count = 0 + var count int32 = 0 ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if count < 2 { - count += 1 + if atomic.LoadInt32(&count) < 2 { + atomic.AddInt32(&count, 1) w.WriteHeader(http.StatusUnauthorized) } else { w.WriteHeader(http.StatusOK) @@ -127,22 +130,26 @@ func TestEmbeddingRetry(t *testing.T) { url := ts.URL { - c := OpenAIEmbeddingClient{"mock_key", url, TextEmbedding3Small} + c := OpenAIEmbeddingClient{"mock_key", url} err := c.Check(); assert.True(t, err == nil) - ret, err := c.Embedding([]string{"sentence"}, 0, "", 0) + ret, err := c.Embedding("text-embedding-3-small", []string{"sentence"}, 0, "", 0) assert.True(t, err == nil) - assert.Equal(t, ret, res) - assert.Equal(t, count, 2) + assert.Equal(t, ret.Usage, res.Usage) + assert.Equal(t, ret.Object, res.Object) + assert.Equal(t, ret.Model, res.Model) + assert.Equal(t, ret.Data[0], res.Data[1]) + assert.Equal(t, ret.Data[1], res.Data[2]) + assert.Equal(t, ret.Data[2], res.Data[0]) + assert.Equal(t, atomic.LoadInt32(&count), int32(2)) } } func TestEmbeddingFailed(t *testing.T) { - var count = 0 - + var count int32 = 0 ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - count += 1 + atomic.AddInt32(&count, 1) w.WriteHeader(http.StatusUnauthorized) })) @@ -150,36 +157,34 @@ func TestEmbeddingFailed(t *testing.T) { url := ts.URL { - c := OpenAIEmbeddingClient{"mock_key", url, TextEmbedding3Small} + c := OpenAIEmbeddingClient{"mock_key", url} err := c.Check(); assert.True(t, err == nil) - _, err = c.Embedding([]string{"sentence"}, 0, "", 0) + _, err = c.Embedding("text-embedding-3-small", []string{"sentence"}, 0, "", 0) assert.True(t, err != nil) - assert.Equal(t, count, 3) + assert.Equal(t, atomic.LoadInt32(&count), int32(3)) } } func TestTimeout(t *testing.T) { - var st = "Doing" - + var st int32 = 0 ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { time.Sleep(3 * time.Second) - st = "Done" + atomic.AddInt32(&st, 1) w.WriteHeader(http.StatusUnauthorized) - })) defer ts.Close() url := ts.URL { - c := OpenAIEmbeddingClient{"mock_key", url, TextEmbedding3Small} + c := OpenAIEmbeddingClient{"mock_key", url} err := c.Check(); assert.True(t, err == nil) - _, err = c.Embedding([]string{"sentence"}, 0, "", 1) + _, err = c.Embedding("text-embedding-3-small", []string{"sentence"}, 0, "", 1) assert.True(t, err != nil) - assert.Equal(t, st, "Doing") + assert.Equal(t, atomic.LoadInt32(&st), int32(0)) time.Sleep(3 * time.Second) - assert.Equal(t, st, "Done") + assert.Equal(t, atomic.LoadInt32(&st), int32(1)) } } diff --git a/internal/util/function/function.go b/internal/util/function/function.go index a9056af41298d..9eeaa110c3d03 100644 --- a/internal/util/function/function.go +++ b/internal/util/function/function.go @@ -31,6 +31,7 @@ type FunctionRunner interface { GetOutputFields() []*schemapb.FieldSchema } + func NewFunctionRunner(coll *schemapb.CollectionSchema, schema *schemapb.FunctionSchema) (FunctionRunner, error) { switch schema.GetType() { case schemapb.FunctionType_BM25: diff --git a/internal/util/function/function_base.go b/internal/util/function/function_base.go new file mode 100644 index 0000000000000..54cd55d18d496 --- /dev/null +++ b/internal/util/function/function_base.go @@ -0,0 +1,67 @@ +/* + * # Licensed to the LF AI & Data foundation under one + * # or more contributor 
license agreements. See the NOTICE file + * # distributed with this work for additional information + * # regarding copyright ownership. The ASF licenses this file + * # to you under the Apache License, Version 2.0 (the + * # "License"); you may not use this file except in compliance + * # with the License. You may obtain a copy of the License at + * # + * # http://www.apache.org/licenses/LICENSE-2.0 + * # + * # Unless required by applicable law or agreed to in writing, software + * # distributed under the License is distributed on an "AS IS" BASIS, + * # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * # See the License for the specific language governing permissions and + * # limitations under the License. + */ + +package function + +import ( + "fmt" + + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" +) + + +type RunnerMode int + +const ( + InsertMode RunnerMode = iota + SearchMode +) + + +type FunctionBase struct { + schema *schemapb.FunctionSchema + outputFields []*schemapb.FieldSchema + mode RunnerMode +} + +func NewBase(coll *schemapb.CollectionSchema, schema *schemapb.FunctionSchema, mode RunnerMode) (*FunctionBase, error) { + var base FunctionBase + base.schema = schema + base.mode = mode + for _, field_id := range schema.GetOutputFieldIds() { + for _, field := range coll.GetFields() { + if field.GetFieldID() == field_id { + base.outputFields = append(base.outputFields, field) + break + } + } + } + + if len(base.outputFields) != len(schema.GetOutputFieldIds()) { + return &base, fmt.Errorf("Collection [%s]'s function [%s]'s outputs mismatch schema", coll.Name, schema.Name) + } + return &base, nil +} + +func (base *FunctionBase) GetSchema() *schemapb.FunctionSchema { + return base.schema +} + +func (base *FunctionBase) GetOutputFields() []*schemapb.FieldSchema { + return base.outputFields +} diff --git a/internal/util/function/openai_embedding_function.go b/internal/util/function/openai_embedding_function.go new file mode 100644 index 0000000000000..10182cf9fa7cc --- /dev/null +++ b/internal/util/function/openai_embedding_function.go @@ -0,0 +1,205 @@ +/* + * # Licensed to the LF AI & Data foundation under one + * # or more contributor license agreements. See the NOTICE file + * # distributed with this work for additional information + * # regarding copyright ownership. The ASF licenses this file + * # to you under the Apache License, Version 2.0 (the + * # "License"); you may not use this file except in compliance + * # with the License. You may obtain a copy of the License at + * # + * # http://www.apache.org/licenses/LICENSE-2.0 + * # + * # Unless required by applicable law or agreed to in writing, software + * # distributed under the License is distributed on an "AS IS" BASIS, + * # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * # See the License for the specific language governing permissions and + * # limitations under the License. 
+ */ + +package function + +import ( + "fmt" + "os" + "strconv" + "strings" + + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/internal/models" + "github.com/milvus-io/milvus/pkg/util/typeutil" +) + + +const ( + TextEmbeddingAda002 string = "text-embedding-ada-002" + TextEmbedding3Small string = "text-embedding-3-small" + TextEmbedding3Large string = "text-embedding-3-large" +) + +const ( + maxBatch = 128 + timeoutSec = 60 + maxRowNum = 60 * maxBatch +) + +const ( + ModelNameParamKey string = "model_name" + DimParamKey string = "dim" + UserParamKey string = "user" + OpenaiEmbeddingUrlParamKey string = "embedding_url" + OpenaiApiKeyParamKey string = "api_key" +) + + +type OpenAIEmbeddingFunction struct { + base *FunctionBase + fieldDim int64 + + client *models.OpenAIEmbeddingClient + modelName string + embedDimParam int64 + user string +} + +func createOpenAIEmbeddingClient(apiKey string, url string) (*models.OpenAIEmbeddingClient, error) { + if apiKey == "" { + apiKey = os.Getenv("OPENAI_API_KEY") + } + if apiKey == "" { + return nil, fmt.Errorf("The apiKey configuration was not found in the environment variables") + } + + if url == "" { + url = os.Getenv("OPENAI_EMBEDDING_URL") + } + if url == "" { + url = "https://api.openai.com/v1/embeddings" + } + c := models.NewOpenAIEmbeddingClient(apiKey, url) + return &c, nil +} + +func NewOpenAIEmbeddingFunction(coll *schemapb.CollectionSchema, schema *schemapb.FunctionSchema, mode RunnerMode) (*OpenAIEmbeddingFunction, error) { + if len(schema.GetOutputFieldIds()) != 1 { + return nil, fmt.Errorf("OpenAIEmbedding function should only have one output field, but now %d", len(schema.GetOutputFieldIds())) + } + + base, err := NewBase(coll, schema, mode) + if err != nil { + return nil, err + } + + if base.outputFields[0].DataType != schemapb.DataType_FloatVector { + return nil, fmt.Errorf("Output field not match, openai embedding needs [%s], got [%s]", + schemapb.DataType_name[int32(schemapb.DataType_FloatVector)], + schemapb.DataType_name[int32(base.outputFields[0].DataType)]) + } + + fieldDim, err := typeutil.GetDim(base.outputFields[0]) + if err != nil { + return nil, err + } + var apiKey, url, modelName, user string + var dim int64 + + for _, param := range schema.Params { + switch strings.ToLower(param.Key) { + case ModelNameParamKey: + modelName = param.Value + case DimParamKey: + dim, err := strconv.ParseInt(param.Value, 10, 64) + if err != nil { + return nil, fmt.Errorf("dim [%s] is not int", param.Value) + } + + if dim != 0 && dim != fieldDim { + return nil, fmt.Errorf("Dim in field's schema is [%d], but embeding dim is [%d]", fieldDim, dim) + } + case UserParamKey: + user = param.Value + case OpenaiApiKeyParamKey: + apiKey = param.Value + case OpenaiEmbeddingUrlParamKey: + url = param.Value + default: + } + } + + c, err := createOpenAIEmbeddingClient(apiKey, url) + if err != nil { + return nil, err + } + + runner := OpenAIEmbeddingFunction{ + base: base, + client: c, + fieldDim: fieldDim, + modelName: modelName, + user: user, + embedDimParam: dim, + } + + if runner.modelName != TextEmbeddingAda002 && runner.modelName != TextEmbedding3Small && runner.modelName != TextEmbedding3Large { + return nil, fmt.Errorf("Unsupported model: %s, only support [%s, %s, %s]", + runner.modelName, TextEmbeddingAda002, TextEmbedding3Small, TextEmbedding3Large) + } + return &runner, nil +} + +func (runner *OpenAIEmbeddingFunction) Run(inputs []*schemapb.FieldData) ([]*schemapb.FieldData, error) { + if len(inputs) != 1 { + 
return nil, fmt.Errorf("OpenAIEmbedding function only receives one input, bug got [%d]", len(inputs)) + } + + if inputs[0].Type != schemapb.DataType_VarChar { + return nil, fmt.Errorf("OpenAIEmbedding only supports varchar field, the input is not varchar") + } + + texts := inputs[0].GetScalars().GetStringData().GetData() + if texts == nil { + return nil, fmt.Errorf("Input texts is empty") + } + + numRows := len(texts) + if numRows > maxRowNum { + return nil, fmt.Errorf("OpenAI embedding supports up to [%d] pieces of data at a time, got [%d]", maxRowNum, numRows) + } + + var output_field schemapb.FieldData + output_field.FieldId = runner.base.outputFields[0].FieldID + output_field.FieldName = runner.base.outputFields[0].Name + output_field.Type = runner.base.outputFields[0].DataType + output_field.IsDynamic = runner.base.outputFields[0].IsDynamic + data := make([]float32, 0, numRows * int(runner.fieldDim)) + for i := 0; i < numRows; i += maxBatch { + end := i + maxBatch + if end > numRows { + end = numRows + } + resp, err := runner.client.Embedding(runner.modelName, texts[i:end], int(runner.embedDimParam), runner.user, timeoutSec) + if err != nil { + return nil, err + } + if end - i != len(resp.Data) { + return nil, fmt.Errorf("The texts number is [%d], but got embedding number [%d]", end - i, len(resp.Data)) + } + for _, item := range resp.Data { + if len(item.Embedding) != int(runner.fieldDim) { + return nil, fmt.Errorf("Dim in field's schema is [%d], but embeding dim is [%d]", + runner.fieldDim, len(resp.Data[0].Embedding)) + } + data = append(data, item.Embedding...) + } + } + output_field.Field = &schemapb.FieldData_Vectors{ + Vectors: &schemapb.VectorField{ + Data: &schemapb.VectorField_FloatVector{ + FloatVector: &schemapb.FloatArray{ + Data: data, + }, + }, + Dim: runner.fieldDim, + }, + } + return []*schemapb.FieldData{&output_field}, nil +} diff --git a/internal/util/function/openai_embedding_function_test.go b/internal/util/function/openai_embedding_function_test.go new file mode 100644 index 0000000000000..68420cbeddbeb --- /dev/null +++ b/internal/util/function/openai_embedding_function_test.go @@ -0,0 +1,339 @@ +/* + * # Licensed to the LF AI & Data foundation under one + * # or more contributor license agreements. See the NOTICE file + * # distributed with this work for additional information + * # regarding copyright ownership. The ASF licenses this file + * # to you under the Apache License, Version 2.0 (the + * # "License"); you may not use this file except in compliance + * # with the License. You may obtain a copy of the License at + * # + * # http://www.apache.org/licenses/LICENSE-2.0 + * # + * # Unless required by applicable law or agreed to in writing, software + * # distributed under the License is distributed on an "AS IS" BASIS, + * # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * # See the License for the specific language governing permissions and + * # limitations under the License. 
+ */ + + +package function + +import ( + "io" + "fmt" + "testing" + "net/http" + "net/http/httptest" + "encoding/json" + + "github.com/stretchr/testify/suite" + + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" + + "github.com/milvus-io/milvus/internal/models" +) + + +func TestOpenAIEmbeddingFunction(t *testing.T) { + suite.Run(t, new(OpenAIEmbeddingFunctionSuite)) +} + +type OpenAIEmbeddingFunctionSuite struct { + suite.Suite + schema *schemapb.CollectionSchema +} + +func (s *OpenAIEmbeddingFunctionSuite) SetupTest() { + s.schema = &schemapb.CollectionSchema{ + Name: "test", + Fields: []*schemapb.FieldSchema{ + {FieldID: 100, Name: "int64", DataType: schemapb.DataType_Int64}, + {FieldID: 101, Name: "text", DataType: schemapb.DataType_VarChar}, + {FieldID: 102, Name: "vector", DataType: schemapb.DataType_FloatVector, + TypeParams: []*commonpb.KeyValuePair{ + {Key: "dim", Value: "4"}, + }}, + }, + } +} + +func createData(texts []string) []*schemapb.FieldData{ + data := []*schemapb.FieldData{} + f := schemapb.FieldData{ + Type: schemapb.DataType_VarChar, + FieldId: 101, + IsDynamic: false, + Field: &schemapb.FieldData_Scalars{ + Scalars: &schemapb.ScalarField{ + Data: &schemapb.ScalarField_StringData{ + StringData: &schemapb.StringArray{ + Data: texts, + }, + }, + }, + }, + } + data = append(data, &f) + return data +} + +func createEmbedding(texts []string, dim int) [][]float32 { + embeddings := make([][]float32, 0) + for i := 0; i < len(texts); i++ { + f := float32(i) + emb := make([]float32, 0) + for j := 0; j < dim; j++ { + emb = append(emb, f + float32(j) * 0.1) + } + embeddings = append(embeddings, emb) + } + return embeddings +} + +func createRunner(url string, schema *schemapb.CollectionSchema) (*OpenAIEmbeddingFunction, error) { + return NewOpenAIEmbeddingFunction(schema, &schemapb.FunctionSchema{ + Name: "test", + Type: schemapb.FunctionType_Unknown, + InputFieldIds: []int64{101}, + OutputFieldIds: []int64{102}, + Params: []*commonpb.KeyValuePair{ + {Key: ModelNameParamKey, Value: "text-embedding-ada-002"}, + {Key: OpenaiApiKeyParamKey, Value: "mock"}, + {Key: OpenaiEmbeddingUrlParamKey, Value: url}, + }, + }, InsertMode) +} + +func (s *OpenAIEmbeddingFunctionSuite) TestEmbedding() { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var req models.EmbeddingRequest + body, _ := io.ReadAll(r.Body) + defer r.Body.Close() + json.Unmarshal(body, &req) + + var res models.EmbeddingResponse + res.Object = "list" + res.Model = "text-embedding-3-small" + embs := createEmbedding(req.Input, 4) + for i := 0; i < len(req.Input); i++ { + res.Data = append(res.Data, models.EmbeddingData{ + Object: "embedding", + Embedding: embs[i], + Index: i, + }) + } + + res.Usage = models.Usage{ + PromptTokens: 1, + TotalTokens: 100, + } + w.WriteHeader(http.StatusOK) + data, _ := json.Marshal(res) + w.Write(data) + + })) + + defer ts.Close() + runner, err := createRunner(ts.URL, s.schema) + s.NoError(err) + { + data := createData([]string{"sentence"}) + ret, err2 := runner.Run(data) + s.NoError(err2) + s.Equal(1, len(ret)) + s.Equal(int64(4), ret[0].GetVectors().Dim) + s.Equal([]float32{0.0, 0.1, 0.2, 0.3}, ret[0].GetVectors().GetFloatVector().Data) + } + { + data := createData([]string{"sentence 1", "sentence 2", "sentence 3"}) + ret, _ := runner.Run(data) + s.Equal([]float32{0.0, 0.1, 0.2, 0.3, 1.0, 1.1, 1.2, 1.3, 2.0, 2.1, 2.2, 2.3}, ret[0].GetVectors().GetFloatVector().Data) + } +} + +func (s 
*OpenAIEmbeddingFunctionSuite) TestEmbeddingDimNotMatch() { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var res models.EmbeddingResponse + res.Object = "list" + res.Model = "text-embedding-3-small" + res.Data = append(res.Data, models.EmbeddingData{ + Object: "embedding", + Embedding: []float32{1.0, 1.0, 1.0, 1.0}, + Index: 0, + }) + + res.Data = append(res.Data, models.EmbeddingData{ + Object: "embedding", + Embedding: []float32{1.0, 1.0}, + Index: 1, + }) + res.Usage = models.Usage{ + PromptTokens: 1, + TotalTokens: 100, + } + w.WriteHeader(http.StatusOK) + data, _ := json.Marshal(res) + w.Write(data) + })) + + defer ts.Close() + runner, err := createRunner(ts.URL, s.schema) + s.NoError(err) + + // embedding dim not match + data := createData([]string{"sentence", "sentence"}) + _, err2 := runner.Run(data) + s.Error(err2) + fmt.Println(err2.Error()) + // s.NoError(err2) +} + +func (s *OpenAIEmbeddingFunctionSuite) TestEmbeddingNubmerNotMatch() { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var res models.EmbeddingResponse + res.Object = "list" + res.Model = "text-embedding-3-small" + res.Data = append(res.Data, models.EmbeddingData{ + Object: "embedding", + Embedding: []float32{1.0, 1.0, 1.0, 1.0}, + Index: 0, + }) + res.Usage = models.Usage{ + PromptTokens: 1, + TotalTokens: 100, + } + w.WriteHeader(http.StatusOK) + data, _ := json.Marshal(res) + w.Write(data) + })) + + defer ts.Close() + runner, err := createRunner(ts.URL, s.schema) + + s.NoError(err) + + // embedding dim not match + data := createData([]string{"sentence", "sentence2"}) + _, err2 := runner.Run(data) + s.Error(err2) + fmt.Println(err2.Error()) + // s.NoError(err2) +} + +func (s *OpenAIEmbeddingFunctionSuite) TestRunnerParamsErr() { + // outputfield datatype mismatch + { + schema := &schemapb.CollectionSchema{ + Name: "test", + Fields: []*schemapb.FieldSchema{ + {FieldID: 100, Name: "int64", DataType: schemapb.DataType_Int64}, + {FieldID: 101, Name: "text", DataType: schemapb.DataType_VarChar}, + {FieldID: 102, Name: "vector", DataType: schemapb.DataType_BFloat16Vector, + TypeParams: []*commonpb.KeyValuePair{ + {Key: "dim", Value: "4"}, + }}, + }, + } + + _, err := NewOpenAIEmbeddingFunction(schema, &schemapb.FunctionSchema{ + Name: "test", + Type: schemapb.FunctionType_Unknown, + InputFieldIds: []int64{101}, + OutputFieldIds: []int64{102}, + Params: []*commonpb.KeyValuePair{ + {Key: ModelNameParamKey, Value: "text-embedding-ada-002"}, + {Key: DimParamKey, Value: "4"}, + {Key: OpenaiApiKeyParamKey, Value: "mock"}, + {Key: OpenaiEmbeddingUrlParamKey, Value: "mock"}, + }, + }, InsertMode) + s.Error(err) + fmt.Println(err.Error()) + } + + // outputfield number mismatc + { + schema := &schemapb.CollectionSchema{ + Name: "test", + Fields: []*schemapb.FieldSchema{ + {FieldID: 100, Name: "int64", DataType: schemapb.DataType_Int64}, + {FieldID: 101, Name: "text", DataType: schemapb.DataType_VarChar}, + {FieldID: 102, Name: "vector", DataType: schemapb.DataType_FloatVector, + TypeParams: []*commonpb.KeyValuePair{ + {Key: "dim", Value: "4"}, + }}, + {FieldID: 103, Name: "vector", DataType: schemapb.DataType_FloatVector, + TypeParams: []*commonpb.KeyValuePair{ + {Key: "dim", Value: "4"}, + }}, + }, + } + _, err := NewOpenAIEmbeddingFunction(schema, &schemapb.FunctionSchema{ + Name: "test", + Type: schemapb.FunctionType_Unknown, + InputFieldIds: []int64{101}, + OutputFieldIds: []int64{102, 103}, + Params: []*commonpb.KeyValuePair{ + {Key: 
ModelNameParamKey, Value: "text-embedding-ada-002"}, + {Key: DimParamKey, Value: "4"}, + {Key: OpenaiApiKeyParamKey, Value: "mock"}, + {Key: OpenaiEmbeddingUrlParamKey, Value: "mock"}, + }, + }, InsertMode) + s.Error(err) + fmt.Println(err.Error()) + } + + // outputfield miss + { + _, err := NewOpenAIEmbeddingFunction(s.schema, &schemapb.FunctionSchema{ + Name: "test", + Type: schemapb.FunctionType_Unknown, + InputFieldIds: []int64{101}, + OutputFieldIds: []int64{103}, + Params: []*commonpb.KeyValuePair{ + {Key: ModelNameParamKey, Value: "text-embedding-ada-002"}, + {Key: DimParamKey, Value: "4"}, + {Key: OpenaiApiKeyParamKey, Value: "mock"}, + {Key: OpenaiEmbeddingUrlParamKey, Value: "mock"}, + }, + }, InsertMode) + s.Error(err) + fmt.Println(err.Error()) + } + + // error model name + { + _, err := NewOpenAIEmbeddingFunction(s.schema, &schemapb.FunctionSchema{ + Name: "test", + Type: schemapb.FunctionType_Unknown, + InputFieldIds: []int64{101}, + OutputFieldIds: []int64{102}, + Params: []*commonpb.KeyValuePair{ + {Key: ModelNameParamKey, Value: "text-embedding-ada-004"}, + {Key: DimParamKey, Value: "4"}, + {Key: OpenaiApiKeyParamKey, Value: "mock"}, + {Key: OpenaiEmbeddingUrlParamKey, Value: "mock"}, + }, + }, InsertMode) + s.Error(err) + fmt.Println(err.Error()) + } + + // no openai api key + { + _, err := NewOpenAIEmbeddingFunction(s.schema, &schemapb.FunctionSchema{ + Name: "test", + Type: schemapb.FunctionType_Unknown, + InputFieldIds: []int64{101}, + OutputFieldIds: []int64{102}, + Params: []*commonpb.KeyValuePair{ + {Key: ModelNameParamKey, Value: "text-embedding-ada-003"}, + }, + }, InsertMode) + s.Error(err) + fmt.Println(err.Error()) + } +} From 4c2baae84e3705e732b503a2679cb00652171647 Mon Sep 17 00:00:00 2001 From: "junjie.jiang" Date: Sun, 29 Sep 2024 20:13:19 +0800 Subject: [PATCH 03/18] Add function executor Signed-off-by: junjie.jiang --- internal/util/function/function_base.go | 15 +- internal/util/function/function_executor.go | 125 ++++++++++ .../util/function/function_executor_test.go | 213 ++++++++++++++++++ .../function/openai_embedding_function.go | 44 ++-- .../openai_embedding_function_test.go | 22 +- 5 files changed, 373 insertions(+), 46 deletions(-) create mode 100644 internal/util/function/function_executor.go create mode 100644 internal/util/function/function_executor_test.go diff --git a/internal/util/function/function_base.go b/internal/util/function/function_base.go index 54cd55d18d496..1914ed7dbaae2 100644 --- a/internal/util/function/function_base.go +++ b/internal/util/function/function_base.go @@ -25,24 +25,14 @@ import ( ) -type RunnerMode int - -const ( - InsertMode RunnerMode = iota - SearchMode -) - - type FunctionBase struct { schema *schemapb.FunctionSchema outputFields []*schemapb.FieldSchema - mode RunnerMode } -func NewBase(coll *schemapb.CollectionSchema, schema *schemapb.FunctionSchema, mode RunnerMode) (*FunctionBase, error) { +func NewBase(coll *schemapb.CollectionSchema, schema *schemapb.FunctionSchema) (*FunctionBase, error) { var base FunctionBase base.schema = schema - base.mode = mode for _, field_id := range schema.GetOutputFieldIds() { for _, field := range coll.GetFields() { if field.GetFieldID() == field_id { @@ -53,7 +43,8 @@ func NewBase(coll *schemapb.CollectionSchema, schema *schemapb.FunctionSchema, m } if len(base.outputFields) != len(schema.GetOutputFieldIds()) { - return &base, fmt.Errorf("Collection [%s]'s function [%s]'s outputs mismatch schema", coll.Name, schema.Name) + return &base, fmt.Errorf("The collection [%s]'s 
information is wrong, function [%s]'s outputs does not match the schema", + coll.Name, schema.Name) } return &base, nil } diff --git a/internal/util/function/function_executor.go b/internal/util/function/function_executor.go new file mode 100644 index 0000000000000..beaaf4b1e8e2a --- /dev/null +++ b/internal/util/function/function_executor.go @@ -0,0 +1,125 @@ +/* + * # Licensed to the LF AI & Data foundation under one + * # or more contributor license agreements. See the NOTICE file + * # distributed with this work for additional information + * # regarding copyright ownership. The ASF licenses this file + * # to you under the Apache License, Version 2.0 (the + * # "License"); you may not use this file except in compliance + * # with the License. You may obtain a copy of the License at + * # + * # http://www.apache.org/licenses/LICENSE-2.0 + * # + * # Unless required by applicable law or agreed to in writing, software + * # distributed under the License is distributed on an "AS IS" BASIS, + * # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * # See the License for the specific language governing permissions and + * # limitations under the License. + */ + + +package function + + +import ( + "fmt" + "sync" + + "github.com/milvus-io/milvus/pkg/mq/msgstream" + "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" +) + + +type Runner interface { + GetSchema() *schemapb.FunctionSchema + GetOutputFields() []*schemapb.FieldSchema + + MaxBatch() int + ProcessInsert(inputs []*schemapb.FieldData) ([]*schemapb.FieldData, error) +} + + +type FunctionExecutor struct { + runners []Runner +} + +func newFunctionExecutor(schema *schemapb.CollectionSchema) (*FunctionExecutor, error) { + executor := new(FunctionExecutor) + for _, f_schema := range schema.Functions { + switch f_schema.GetType() { + case schemapb.FunctionType_BM25: + case schemapb.FunctionType_OpenAIEmbedding: + f, err := NewOpenAIEmbeddingFunction(schema, f_schema) + if err != nil { + return nil, err + } + executor.runners = append(executor.runners, f) + default: + return nil, fmt.Errorf("unknown functionRunner type %s", f_schema.GetType().String()) + } + } + return executor, nil +} + +func (executor *FunctionExecutor)processSingleFunction(idx int, msg *msgstream.InsertMsg) ([]*schemapb.FieldData, error) { + runner := executor.runners[idx] + inputs := make([]*schemapb.FieldData, 0, len(runner.GetSchema().InputFieldIds)) + for _, id := range runner.GetSchema().InputFieldIds { + for _, field := range msg.FieldsData{ + if field.FieldId == id { + inputs = append(inputs, field) + } + } + } + + if len(inputs) != len(runner.GetSchema().InputFieldIds) { + return nil, fmt.Errorf("Input field not found") + } + + outputs, err := runner.ProcessInsert(inputs) + if err != nil { + return nil, err + } + return outputs, nil +} + +func (executor *FunctionExecutor)ProcessInsert(msg *msgstream.InsertMsg) error { + numRows := msg.NumRows + for _, runner := range executor.runners { + if numRows > uint64(runner.MaxBatch()) { + return fmt.Errorf("numRows [%d] > function [%s]'s max batch [%d]", numRows, runner.GetSchema().Name, runner.MaxBatch()) + } + } + + outputs := make(chan []*schemapb.FieldData, len(executor.runners)) + errChan := make(chan error, len(executor.runners)) + var wg sync.WaitGroup + for idx, _ := range executor.runners { + wg.Add(1) + go func(index int) { + defer wg.Done() + data, err := executor.processSingleFunction(index, msg) + if err != nil { + errChan <- err 
+ } else { + outputs <- data + } + + }(idx) + } + wg.Wait() + close(errChan) + close(outputs) + for err := range errChan { + return err + } + for output := range outputs { + msg.FieldsData = append(msg.FieldsData, output...) + } + return nil +} + + +func (executor *FunctionExecutor)ProcessSearch(msg *milvuspb.SearchRequest) error { + return nil +} diff --git a/internal/util/function/function_executor_test.go b/internal/util/function/function_executor_test.go new file mode 100644 index 0000000000000..2d0a701352f75 --- /dev/null +++ b/internal/util/function/function_executor_test.go @@ -0,0 +1,213 @@ +/* + * # Licensed to the LF AI & Data foundation under one + * # or more contributor license agreements. See the NOTICE file + * # distributed with this work for additional information + * # regarding copyright ownership. The ASF licenses this file + * # to you under the Apache License, Version 2.0 (the + * # "License"); you may not use this file except in compliance + * # with the License. You may obtain a copy of the License at + * # + * # http://www.apache.org/licenses/LICENSE-2.0 + * # + * # Unless required by applicable law or agreed to in writing, software + * # distributed under the License is distributed on an "AS IS" BASIS, + * # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * # See the License for the specific language governing permissions and + * # limitations under the License. + */ + + +package function + + +import ( + "io" + "testing" + "net/http" + "net/http/httptest" + "encoding/json" + + "github.com/stretchr/testify/suite" + + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" + "github.com/milvus-io/milvus/internal/models" + "github.com/milvus-io/milvus/pkg/mq/msgstream" + "github.com/milvus-io/milvus-proto/go-api/v2/msgpb" +) + + +func TestFunctionExecutor(t *testing.T) { + suite.Run(t, new(FunctionExecutorSuite)) +} + +type FunctionExecutorSuite struct { + suite.Suite +} + + +func (s *FunctionExecutorSuite) creataSchema(url string) *schemapb.CollectionSchema{ + return &schemapb.CollectionSchema{ + Name: "test", + Fields: []*schemapb.FieldSchema{ + {FieldID: 100, Name: "int64", DataType: schemapb.DataType_Int64}, + {FieldID: 101, Name: "text", DataType: schemapb.DataType_VarChar}, + {FieldID: 102, Name: "vector", DataType: schemapb.DataType_FloatVector, + TypeParams: []*commonpb.KeyValuePair{ + {Key: "dim", Value: "4"}, + }}, + {FieldID: 103, Name: "vector2", DataType: schemapb.DataType_FloatVector, + TypeParams: []*commonpb.KeyValuePair{ + {Key: "dim", Value: "8"}, + }}, + }, + Functions: []*schemapb.FunctionSchema{ + { + Name: "test", + Type: schemapb.FunctionType_OpenAIEmbedding, + InputFieldIds: []int64{101}, + OutputFieldIds: []int64{102}, + Params: []*commonpb.KeyValuePair{ + {Key: ModelNameParamKey, Value: "text-embedding-ada-002"}, + {Key: OpenaiApiKeyParamKey, Value: "mock"}, + {Key: OpenaiEmbeddingUrlParamKey, Value: url}, + {Key: DimParamKey, Value: "4"}, + }, + }, + { + Name: "test", + Type: schemapb.FunctionType_OpenAIEmbedding, + InputFieldIds: []int64{101}, + OutputFieldIds: []int64{103}, + Params: []*commonpb.KeyValuePair{ + {Key: ModelNameParamKey, Value: "text-embedding-ada-002"}, + {Key: OpenaiApiKeyParamKey, Value: "mock"}, + {Key: OpenaiEmbeddingUrlParamKey, Value: url}, + {Key: DimParamKey, Value: "8"}, + }, + }, + }, + } + +} + +func (s *FunctionExecutorSuite)createMsg(texts []string) *msgstream.InsertMsg{ + + data := []*schemapb.FieldData{} + f := 
schemapb.FieldData{ + Type: schemapb.DataType_VarChar, + FieldId: 101, + IsDynamic: false, + Field: &schemapb.FieldData_Scalars{ + Scalars: &schemapb.ScalarField{ + Data: &schemapb.ScalarField_StringData{ + StringData: &schemapb.StringArray{ + Data: texts, + }, + }, + }, + }, + } + data = append(data, &f) + + msg := msgstream.InsertMsg{ + InsertRequest: &msgpb.InsertRequest{ + FieldsData: data, + }, + } + return &msg +} + +func (s *FunctionExecutorSuite)createEmbedding(texts []string, dim int) [][]float32{ + embeddings := make([][]float32, 0) + for i := 0; i < len(texts); i++ { + f := float32(i) + emb := make([]float32, 0) + for j := 0; j < dim; j++ { + emb = append(emb, f + float32(j) * 0.1) + } + embeddings = append(embeddings, emb) + } + return embeddings +} + +func (s *FunctionExecutorSuite) TestExecutor() { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request){ + var req models.EmbeddingRequest + body, _ := io.ReadAll(r.Body) + defer r.Body.Close() + json.Unmarshal(body, &req) + + var res models.EmbeddingResponse + res.Object = "list" + res.Model = "text-embedding-3-small" + embs := s.createEmbedding(req.Input, req.Dimensions) + for i := 0; i < len(req.Input); i++ { + res.Data = append(res.Data, models.EmbeddingData{ + Object: "embedding", + Embedding: embs[i], + Index: i, + }) + } + + res.Usage = models.Usage{ + PromptTokens: 1, + TotalTokens: 100, + } + w.WriteHeader(http.StatusOK) + data, _ := json.Marshal(res) + w.Write(data) + + })) + + defer ts.Close() + schema := s.creataSchema(ts.URL) + exec, err := newFunctionExecutor(schema) + s.NoError(err) + msg := s.createMsg([]string{"sentence", "sentence"}) + exec.ProcessInsert(msg) + s.Equal(len(msg.FieldsData), 3) +} + +func (s *FunctionExecutorSuite) TestErrorEmbedding() { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request){ + var req models.EmbeddingRequest + body, _ := io.ReadAll(r.Body) + defer r.Body.Close() + json.Unmarshal(body, &req) + + var res models.EmbeddingResponse + res.Object = "list" + res.Model = "text-embedding-3-small" + for i := 0; i < len(req.Input); i++ { + res.Data = append(res.Data, models.EmbeddingData{ + Object: "embedding", + Embedding: []float32{}, + Index: i, + }) + } + + res.Usage = models.Usage{ + PromptTokens: 1, + TotalTokens: 100, + } + w.WriteHeader(http.StatusOK) + data, _ := json.Marshal(res) + w.Write(data) + + })) + defer ts.Close() + schema := s.creataSchema(ts.URL) + exec, err := newFunctionExecutor(schema) + s.NoError(err) + msg := s.createMsg([]string{"sentence", "sentence"}) + err = exec.ProcessInsert(msg) + s.Error(err) +} + +func (s *FunctionExecutorSuite) TestErrorSchema() { + schema := s.creataSchema("http://localhost") + schema.Functions[0].Type = schemapb.FunctionType_Unknown + _, err := newFunctionExecutor(schema) + s.Error(err) +} diff --git a/internal/util/function/openai_embedding_function.go b/internal/util/function/openai_embedding_function.go index 10182cf9fa7cc..11151f92a6adb 100644 --- a/internal/util/function/openai_embedding_function.go +++ b/internal/util/function/openai_embedding_function.go @@ -38,8 +38,7 @@ const ( const ( maxBatch = 128 - timeoutSec = 60 - maxRowNum = 60 * maxBatch + timeoutSec = 30 ) const ( @@ -52,7 +51,7 @@ const ( type OpenAIEmbeddingFunction struct { - base *FunctionBase + FunctionBase fieldDim int64 client *models.OpenAIEmbeddingClient @@ -79,12 +78,12 @@ func createOpenAIEmbeddingClient(apiKey string, url string) (*models.OpenAIEmbed return &c, nil } -func 
NewOpenAIEmbeddingFunction(coll *schemapb.CollectionSchema, schema *schemapb.FunctionSchema, mode RunnerMode) (*OpenAIEmbeddingFunction, error) { +func NewOpenAIEmbeddingFunction(coll *schemapb.CollectionSchema, schema *schemapb.FunctionSchema) (*OpenAIEmbeddingFunction, error) { if len(schema.GetOutputFieldIds()) != 1 { - return nil, fmt.Errorf("OpenAIEmbedding function should only have one output field, but now %d", len(schema.GetOutputFieldIds())) + return nil, fmt.Errorf("OpenAIEmbedding function should only have one output field, but now is %d", len(schema.GetOutputFieldIds())) } - base, err := NewBase(coll, schema, mode) + base, err := NewBase(coll, schema) if err != nil { return nil, err } @@ -107,13 +106,13 @@ func NewOpenAIEmbeddingFunction(coll *schemapb.CollectionSchema, schema *schemap case ModelNameParamKey: modelName = param.Value case DimParamKey: - dim, err := strconv.ParseInt(param.Value, 10, 64) + dim, err = strconv.ParseInt(param.Value, 10, 64) if err != nil { return nil, fmt.Errorf("dim [%s] is not int", param.Value) } if dim != 0 && dim != fieldDim { - return nil, fmt.Errorf("Dim in field's schema is [%d], but embeding dim is [%d]", fieldDim, dim) + return nil, fmt.Errorf("Field %s's dim is [%d], but embeding's dim is [%d]", schema.Name, fieldDim, dim) } case UserParamKey: user = param.Value @@ -131,7 +130,7 @@ func NewOpenAIEmbeddingFunction(coll *schemapb.CollectionSchema, schema *schemap } runner := OpenAIEmbeddingFunction{ - base: base, + FunctionBase: *base, client: c, fieldDim: fieldDim, modelName: modelName, @@ -146,7 +145,16 @@ func NewOpenAIEmbeddingFunction(coll *schemapb.CollectionSchema, schema *schemap return &runner, nil } -func (runner *OpenAIEmbeddingFunction) Run(inputs []*schemapb.FieldData) ([]*schemapb.FieldData, error) { +func (runner *OpenAIEmbeddingFunction)MaxBatch() int { + return 5 * maxBatch +} + + +func (runner *OpenAIEmbeddingFunction) ProcessInsert(inputs []*schemapb.FieldData) ([]*schemapb.FieldData, error) { + return runner.Run(inputs) +} + +func (runner *OpenAIEmbeddingFunction) Run( inputs []*schemapb.FieldData) ([]*schemapb.FieldData, error) { if len(inputs) != 1 { return nil, fmt.Errorf("OpenAIEmbedding function only receives one input, bug got [%d]", len(inputs)) } @@ -161,15 +169,15 @@ func (runner *OpenAIEmbeddingFunction) Run(inputs []*schemapb.FieldData) ([]*sch } numRows := len(texts) - if numRows > maxRowNum { - return nil, fmt.Errorf("OpenAI embedding supports up to [%d] pieces of data at a time, got [%d]", maxRowNum, numRows) + if numRows > runner.MaxBatch() { + return nil, fmt.Errorf("OpenAI embedding supports up to [%d] pieces of data at a time, got [%d]", runner.MaxBatch(), numRows) } var output_field schemapb.FieldData - output_field.FieldId = runner.base.outputFields[0].FieldID - output_field.FieldName = runner.base.outputFields[0].Name - output_field.Type = runner.base.outputFields[0].DataType - output_field.IsDynamic = runner.base.outputFields[0].IsDynamic + output_field.FieldId = runner.outputFields[0].FieldID + output_field.FieldName = runner.outputFields[0].Name + output_field.Type = runner.outputFields[0].DataType + output_field.IsDynamic = runner.outputFields[0].IsDynamic data := make([]float32, 0, numRows * int(runner.fieldDim)) for i := 0; i < numRows; i += maxBatch { end := i + maxBatch @@ -185,8 +193,8 @@ func (runner *OpenAIEmbeddingFunction) Run(inputs []*schemapb.FieldData) ([]*sch } for _, item := range resp.Data { if len(item.Embedding) != int(runner.fieldDim) { - return nil, fmt.Errorf("Dim in field's 
schema is [%d], but embeding dim is [%d]", - runner.fieldDim, len(resp.Data[0].Embedding)) + return nil, fmt.Errorf("The required embedding dim for field [%s] is [%d], but the embedding obtained from the model is [%d]", + output_field.FieldName, runner.fieldDim, len(item.Embedding)) } data = append(data, item.Embedding...) } diff --git a/internal/util/function/openai_embedding_function_test.go b/internal/util/function/openai_embedding_function_test.go index 68420cbeddbeb..f295edbbe577d 100644 --- a/internal/util/function/openai_embedding_function_test.go +++ b/internal/util/function/openai_embedding_function_test.go @@ -21,7 +21,6 @@ package function import ( "io" - "fmt" "testing" "net/http" "net/http/httptest" @@ -103,7 +102,7 @@ func createRunner(url string, schema *schemapb.CollectionSchema) (*OpenAIEmbeddi {Key: OpenaiApiKeyParamKey, Value: "mock"}, {Key: OpenaiEmbeddingUrlParamKey, Value: url}, }, - }, InsertMode) + }) } func (s *OpenAIEmbeddingFunctionSuite) TestEmbedding() { @@ -186,8 +185,6 @@ func (s *OpenAIEmbeddingFunctionSuite) TestEmbeddingDimNotMatch() { data := createData([]string{"sentence", "sentence"}) _, err2 := runner.Run(data) s.Error(err2) - fmt.Println(err2.Error()) - // s.NoError(err2) } func (s *OpenAIEmbeddingFunctionSuite) TestEmbeddingNubmerNotMatch() { @@ -218,8 +215,6 @@ func (s *OpenAIEmbeddingFunctionSuite) TestEmbeddingNubmerNotMatch() { data := createData([]string{"sentence", "sentence2"}) _, err2 := runner.Run(data) s.Error(err2) - fmt.Println(err2.Error()) - // s.NoError(err2) } func (s *OpenAIEmbeddingFunctionSuite) TestRunnerParamsErr() { @@ -248,9 +243,8 @@ func (s *OpenAIEmbeddingFunctionSuite) TestRunnerParamsErr() { {Key: OpenaiApiKeyParamKey, Value: "mock"}, {Key: OpenaiEmbeddingUrlParamKey, Value: "mock"}, }, - }, InsertMode) + }) s.Error(err) - fmt.Println(err.Error()) } // outputfield number mismatc @@ -281,9 +275,8 @@ func (s *OpenAIEmbeddingFunctionSuite) TestRunnerParamsErr() { {Key: OpenaiApiKeyParamKey, Value: "mock"}, {Key: OpenaiEmbeddingUrlParamKey, Value: "mock"}, }, - }, InsertMode) + }) s.Error(err) - fmt.Println(err.Error()) } // outputfield miss @@ -299,9 +292,8 @@ func (s *OpenAIEmbeddingFunctionSuite) TestRunnerParamsErr() { {Key: OpenaiApiKeyParamKey, Value: "mock"}, {Key: OpenaiEmbeddingUrlParamKey, Value: "mock"}, }, - }, InsertMode) + }) s.Error(err) - fmt.Println(err.Error()) } // error model name @@ -317,9 +309,8 @@ func (s *OpenAIEmbeddingFunctionSuite) TestRunnerParamsErr() { {Key: OpenaiApiKeyParamKey, Value: "mock"}, {Key: OpenaiEmbeddingUrlParamKey, Value: "mock"}, }, - }, InsertMode) + }) s.Error(err) - fmt.Println(err.Error()) } // no openai api key @@ -332,8 +323,7 @@ func (s *OpenAIEmbeddingFunctionSuite) TestRunnerParamsErr() { Params: []*commonpb.KeyValuePair{ {Key: ModelNameParamKey, Value: "text-embedding-ada-003"}, }, - }, InsertMode) + }) s.Error(err) - fmt.Println(err.Error()) } } From 542c2493527b13794c383b898d7b453b359a69a8 Mon Sep 17 00:00:00 2001 From: "junjie.jiang" Date: Wed, 9 Oct 2024 11:25:05 +0800 Subject: [PATCH 04/18] Insert & Upsert support functions Signed-off-by: junjie.jiang --- internal/proxy/task_insert.go | 13 ++ internal/proxy/task_insert_test.go | 128 ++++++++++++++++++ internal/proxy/task_upsert.go | 15 ++ internal/util/function/function_executor.go | 18 +-- .../util/function/function_executor_test.go | 6 +- .../function/openai_embedding_function.go | 19 ++- .../openai_embedding_function_test.go | 8 +- 7 files changed, 188 insertions(+), 19 deletions(-) diff --git 
a/internal/proxy/task_insert.go b/internal/proxy/task_insert.go index 9de31cd53d600..3c80703fce9e4 100644 --- a/internal/proxy/task_insert.go +++ b/internal/proxy/task_insert.go @@ -12,6 +12,7 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus/internal/allocator" + "github.com/milvus-io/milvus/internal/util/function" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/metrics" "github.com/milvus-io/milvus/pkg/mq/msgstream" @@ -141,6 +142,18 @@ func (it *insertTask) PreExecute(ctx context.Context) error { } it.schema = schema.CollectionSchema + // Calculate embedding fields + exec, err := function.NewFunctionExecutor(schema.CollectionSchema) + if err != nil { + return err + } + + if !exec.Empty() { + if err := exec.ProcessInsert(it.insertMsg); err != nil { + return err + } + } + rowNums := uint32(it.insertMsg.NRows()) // set insertTask.rowIDs var rowIDBegin UniqueID diff --git a/internal/proxy/task_insert_test.go b/internal/proxy/task_insert_test.go index cdf90290567ce..7b1d3a71c99db 100644 --- a/internal/proxy/task_insert_test.go +++ b/internal/proxy/task_insert_test.go @@ -1,15 +1,25 @@ package proxy import ( + "io" + "fmt" "context" "testing" + "net/http" + "net/http/httptest" + "encoding/json" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" + "github.com/milvus-io/milvus/internal/models" + "github.com/milvus-io/milvus/internal/allocator" + "github.com/milvus-io/milvus/internal/mocks" + "github.com/milvus-io/milvus/internal/util/function" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/msgpb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/internal/proto/rootcoordpb" "github.com/milvus-io/milvus/pkg/mq/msgstream" "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/paramtable" @@ -308,3 +318,121 @@ func TestMaxInsertSize(t *testing.T) { assert.ErrorIs(t, err, merr.ErrParameterTooLarge) }) } + +func TestInsertTask_Function(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request){ + var req models.EmbeddingRequest + body, _ := io.ReadAll(r.Body) + defer r.Body.Close() + json.Unmarshal(body, &req) + + var res models.EmbeddingResponse + res.Object = "list" + res.Model = "text-embedding-3-small" + for i := 0; i < len(req.Input); i++ { + res.Data = append(res.Data, models.EmbeddingData{ + Object: "embedding", + Embedding: make([]float32, req.Dimensions), + Index: i, + }) + } + + res.Usage = models.Usage{ + PromptTokens: 1, + TotalTokens: 100, + } + w.WriteHeader(http.StatusOK) + data, _ := json.Marshal(res) + w.Write(data) + })) + defer ts.Close() + data := []*schemapb.FieldData{} + f := schemapb.FieldData{ + Type: schemapb.DataType_VarChar, + FieldId: 101, + FieldName: "text", + IsDynamic: false, + Field: &schemapb.FieldData_Scalars{ + Scalars: &schemapb.ScalarField{ + Data: &schemapb.ScalarField_StringData{ + StringData: &schemapb.StringArray{ + Data: []string{"sentence", "sentence"}, + }, + }, + }, + }, + } + data = append(data, &f) + collectionName := "TestInsertTask_function" + schema := &schemapb.CollectionSchema{ + Name: collectionName, + Description: "TestInsertTask_function", + AutoID: true, + Fields: []*schemapb.FieldSchema{ + {FieldID: 100, Name: "id", DataType: schemapb.DataType_Int64, IsPrimaryKey: true, AutoID: true}, + {FieldID: 101, Name: "text", DataType: 
schemapb.DataType_VarChar, + TypeParams: []*commonpb.KeyValuePair{ + {Key: "max_length", Value: "200"}, + }}, + {FieldID: 102, Name: "vector", DataType: schemapb.DataType_FloatVector, + TypeParams: []*commonpb.KeyValuePair{ + {Key: "dim", Value: "4"}, + }}, + }, + Functions: []*schemapb.FunctionSchema{ + { + Name: "test_function", + Type: schemapb.FunctionType_OpenAIEmbedding, + InputFieldIds: []int64{101}, + OutputFieldIds: []int64{102}, + Params: []*commonpb.KeyValuePair{ + {Key: function.ModelNameParamKey, Value: "text-embedding-ada-002"}, + {Key: function.OpenaiApiKeyParamKey, Value: "mock"}, + {Key: function.OpenaiEmbeddingUrlParamKey, Value: ts.URL}, + {Key: function.DimParamKey, Value: "4"}, + }, + }, + }, + } + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + rc := mocks.NewMockRootCoordClient(t) + rc.EXPECT().AllocID(mock.Anything, mock.Anything).Return(&rootcoordpb.AllocIDResponse{ + Status: merr.Status(nil), + ID: 11198, + Count: 10, + }, nil) + idAllocator, err := allocator.NewIDAllocator(ctx, rc, 0) + idAllocator.Start() + defer idAllocator.Close() + assert.NoError(t, err) + task := insertTask{ + ctx: context.Background(), + insertMsg: &BaseInsertTask{ + InsertRequest: &msgpb.InsertRequest{ + CollectionName: collectionName, + DbName: "hooooooo", + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_Insert, + }, + Version: msgpb.InsertDataVersion_ColumnBased, + FieldsData: data, + NumRows: 2, + }, + }, + schema: schema, + idAllocator: idAllocator, + } + + info := newSchemaInfo(schema) + cache := NewMockCache(t) + cache.On("GetCollectionSchema", + mock.Anything, // context.Context + mock.AnythingOfType("string"), + mock.AnythingOfType("string"), + ).Return(info, nil) + globalMetaCache = cache + err = task.PreExecute(ctx) + assert.NoError(t, err) +} diff --git a/internal/proxy/task_upsert.go b/internal/proxy/task_upsert.go index 1de223fa4124d..390f27107ed5c 100644 --- a/internal/proxy/task_upsert.go +++ b/internal/proxy/task_upsert.go @@ -29,6 +29,7 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/msgpb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus/internal/allocator" + "github.com/milvus-io/milvus/internal/util/function" "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/metrics" @@ -152,6 +153,20 @@ func (it *upsertTask) insertPreExecute(ctx context.Context) error { return err } + // Calculate embedding fields + { + exec, err := function.NewFunctionExecutor(it.schema.CollectionSchema) + if err != nil { + return err + } + + if !exec.Empty() { + if err := exec.ProcessInsert(it.upsertMsg.InsertMsg); err != nil { + return err + } + } + } + rowNums := uint32(it.upsertMsg.InsertMsg.NRows()) // set upsertTask.insertRequest.rowIDs tr := timerecord.NewTimeRecorder("applyPK") diff --git a/internal/util/function/function_executor.go b/internal/util/function/function_executor.go index beaaf4b1e8e2a..1af56952d9b3a 100644 --- a/internal/util/function/function_executor.go +++ b/internal/util/function/function_executor.go @@ -23,9 +23,8 @@ package function import ( "fmt" "sync" - + "github.com/milvus-io/milvus/internal/proto/internalpb" "github.com/milvus-io/milvus/pkg/mq/msgstream" - "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" ) @@ -35,15 +34,15 @@ type Runner interface { GetOutputFields() []*schemapb.FieldSchema MaxBatch() int - ProcessInsert(inputs []*schemapb.FieldData) 
([]*schemapb.FieldData, error) + ProcessInsert(inputs []*schemapb.FieldData) ([]*schemapb.FieldData, error) + ProcessSearch(placeholderGroups [][]byte) ([][]byte, error) } - type FunctionExecutor struct { runners []Runner } -func newFunctionExecutor(schema *schemapb.CollectionSchema) (*FunctionExecutor, error) { +func NewFunctionExecutor(schema *schemapb.CollectionSchema) (*FunctionExecutor, error) { executor := new(FunctionExecutor) for _, f_schema := range schema.Functions { switch f_schema.GetType() { @@ -61,6 +60,10 @@ func newFunctionExecutor(schema *schemapb.CollectionSchema) (*FunctionExecutor, return executor, nil } +func (executor *FunctionExecutor)Empty() bool { + return len(executor.runners) == 0 +} + func (executor *FunctionExecutor)processSingleFunction(idx int, msg *msgstream.InsertMsg) ([]*schemapb.FieldData, error) { runner := executor.runners[idx] inputs := make([]*schemapb.FieldData, 0, len(runner.GetSchema().InputFieldIds)) @@ -119,7 +122,6 @@ func (executor *FunctionExecutor)ProcessInsert(msg *msgstream.InsertMsg) error { return nil } - -func (executor *FunctionExecutor)ProcessSearch(msg *milvuspb.SearchRequest) error { - return nil +func (executor *FunctionExecutor)ProcessSearch(req *internalpb.SearchRequest) (interface{}, error) { + return nil, nil } diff --git a/internal/util/function/function_executor_test.go b/internal/util/function/function_executor_test.go index 2d0a701352f75..8eef0b7b55857 100644 --- a/internal/util/function/function_executor_test.go +++ b/internal/util/function/function_executor_test.go @@ -162,7 +162,7 @@ func (s *FunctionExecutorSuite) TestExecutor() { defer ts.Close() schema := s.creataSchema(ts.URL) - exec, err := newFunctionExecutor(schema) + exec, err := NewFunctionExecutor(schema) s.NoError(err) msg := s.createMsg([]string{"sentence", "sentence"}) exec.ProcessInsert(msg) @@ -198,7 +198,7 @@ func (s *FunctionExecutorSuite) TestErrorEmbedding() { })) defer ts.Close() schema := s.creataSchema(ts.URL) - exec, err := newFunctionExecutor(schema) + exec, err := NewFunctionExecutor(schema) s.NoError(err) msg := s.createMsg([]string{"sentence", "sentence"}) err = exec.ProcessInsert(msg) @@ -208,6 +208,6 @@ func (s *FunctionExecutorSuite) TestErrorEmbedding() { func (s *FunctionExecutorSuite) TestErrorSchema() { schema := s.creataSchema("http://localhost") schema.Functions[0].Type = schemapb.FunctionType_Unknown - _, err := newFunctionExecutor(schema) + _, err := NewFunctionExecutor(schema) s.Error(err) } diff --git a/internal/util/function/openai_embedding_function.go b/internal/util/function/openai_embedding_function.go index 11151f92a6adb..25021ed802959 100644 --- a/internal/util/function/openai_embedding_function.go +++ b/internal/util/function/openai_embedding_function.go @@ -151,10 +151,6 @@ func (runner *OpenAIEmbeddingFunction)MaxBatch() int { func (runner *OpenAIEmbeddingFunction) ProcessInsert(inputs []*schemapb.FieldData) ([]*schemapb.FieldData, error) { - return runner.Run(inputs) -} - -func (runner *OpenAIEmbeddingFunction) Run( inputs []*schemapb.FieldData) ([]*schemapb.FieldData, error) { if len(inputs) != 1 { return nil, fmt.Errorf("OpenAIEmbedding function only receives one input, bug got [%d]", len(inputs)) } @@ -211,3 +207,18 @@ func (runner *OpenAIEmbeddingFunction) Run( inputs []*schemapb.FieldData) ([]*sc } return []*schemapb.FieldData{&output_field}, nil } + +func (runner *OpenAIEmbeddingFunction)ProcessSearch(placeholderGroups [][]byte) ([][]byte, error){ + if len(placeholderGroups) != 1 { + return nil, 
fmt.Errorf("OpenAIEmbedding function only receives one input, bug got [%d]", len(placeholderGroups)) + } + + // get tests from placeholderGroups + + // texts := []string{} + + // calc embedding + + //to placeholderGroups + return nil, nil +} diff --git a/internal/util/function/openai_embedding_function_test.go b/internal/util/function/openai_embedding_function_test.go index f295edbbe577d..46e09d8e68108 100644 --- a/internal/util/function/openai_embedding_function_test.go +++ b/internal/util/function/openai_embedding_function_test.go @@ -139,7 +139,7 @@ func (s *OpenAIEmbeddingFunctionSuite) TestEmbedding() { s.NoError(err) { data := createData([]string{"sentence"}) - ret, err2 := runner.Run(data) + ret, err2 := runner.ProcessInsert(data) s.NoError(err2) s.Equal(1, len(ret)) s.Equal(int64(4), ret[0].GetVectors().Dim) @@ -147,7 +147,7 @@ func (s *OpenAIEmbeddingFunctionSuite) TestEmbedding() { } { data := createData([]string{"sentence 1", "sentence 2", "sentence 3"}) - ret, _ := runner.Run(data) + ret, _ := runner.ProcessInsert(data) s.Equal([]float32{0.0, 0.1, 0.2, 0.3, 1.0, 1.1, 1.2, 1.3, 2.0, 2.1, 2.2, 2.3}, ret[0].GetVectors().GetFloatVector().Data) } } @@ -183,7 +183,7 @@ func (s *OpenAIEmbeddingFunctionSuite) TestEmbeddingDimNotMatch() { // embedding dim not match data := createData([]string{"sentence", "sentence"}) - _, err2 := runner.Run(data) + _, err2 := runner.ProcessInsert(data) s.Error(err2) } @@ -213,7 +213,7 @@ func (s *OpenAIEmbeddingFunctionSuite) TestEmbeddingNubmerNotMatch() { // embedding dim not match data := createData([]string{"sentence", "sentence2"}) - _, err2 := runner.Run(data) + _, err2 := runner.ProcessInsert(data) s.Error(err2) } From 45875720d2ae63b489c5623f93e35064c9852559 Mon Sep 17 00:00:00 2001 From: "junjie.jiang" Date: Thu, 17 Oct 2024 10:15:48 +0800 Subject: [PATCH 05/18] Search supports embedding function Signed-off-by: junjie.jiang --- internal/proxy/task_insert.go | 12 +- internal/proxy/task_search.go | 23 +++ internal/proxy/task_upsert.go | 10 +- internal/util/function/function_executor.go | 137 ++++++++++++++---- internal/util/function/function_util.go | 62 ++++++++ .../function/openai_embedding_function.go | 81 ++++++----- pkg/util/funcutil/placeholdergroup.go | 15 ++ 7 files changed, 265 insertions(+), 75 deletions(-) create mode 100644 internal/util/function/function_util.go diff --git a/internal/proxy/task_insert.go b/internal/proxy/task_insert.go index 3c80703fce9e4..37f0a58b84aa4 100644 --- a/internal/proxy/task_insert.go +++ b/internal/proxy/task_insert.go @@ -143,17 +143,15 @@ func (it *insertTask) PreExecute(ctx context.Context) error { it.schema = schema.CollectionSchema // Calculate embedding fields - exec, err := function.NewFunctionExecutor(schema.CollectionSchema) - if err != nil { - return err - } - - if !exec.Empty() { + if function.HasFunctions(schema.CollectionSchema.Functions, []int64{}) { + exec, err := function.NewFunctionExecutor(schema.CollectionSchema) + if err != nil { + return err + } if err := exec.ProcessInsert(it.insertMsg); err != nil { return err } } - rowNums := uint32(it.insertMsg.NRows()) // set insertTask.rowIDs var rowIDBegin UniqueID diff --git a/internal/proxy/task_search.go b/internal/proxy/task_search.go index ffa7c9b23b8ee..b6082452216d6 100644 --- a/internal/proxy/task_search.go +++ b/internal/proxy/task_search.go @@ -22,6 +22,7 @@ import ( "github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/internal/types" "github.com/milvus-io/milvus/internal/util/exprutil" + 
"github.com/milvus-io/milvus/internal/util/function" "github.com/milvus-io/milvus/internal/util/reduce" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/metrics" @@ -419,6 +420,17 @@ func (t *searchTask) initAdvancedSearchRequest(ctx context.Context) error { zap.Stringer("plan", plan)) // may be very large if large term passed. } + var err error + if function.HasFunctions(t.schema.CollectionSchema.Functions, []int64{}) { + exec, err := function.NewFunctionExecutor(t.schema.CollectionSchema) + if err != nil { + return err + } + if err := exec.ProcessSearchReq(t.SearchRequest); err != nil { + return err + } + } + t.SearchRequest.GroupByFieldId = t.rankParams.GetGroupByFieldId() t.SearchRequest.GroupSize = t.rankParams.GetGroupSize() @@ -426,6 +438,7 @@ func (t *searchTask) initAdvancedSearchRequest(ctx context.Context) error { if t.partitionKeyMode { t.SearchRequest.PartitionIDs = t.partitionIDsSet.Collect() } + var err error t.reScorers, err = NewReScorers(ctx, len(t.request.GetSubReqs()), t.request.GetSearchParams()) if err != nil { @@ -497,6 +510,16 @@ func (t *searchTask) initSearchRequest(ctx context.Context) error { t.SearchRequest.DslType = commonpb.DslType_BoolExprV1 t.SearchRequest.GroupByFieldId = queryInfo.GroupByFieldId t.SearchRequest.GroupSize = queryInfo.GroupSize + + if function.HasFunctions(t.schema.CollectionSchema.Functions, []int64{queryInfo.GetQueryFieldId()}) { + exec, err := function.NewFunctionExecutor(t.schema.CollectionSchema) + if err != nil { + return err + } + if err := exec.ProcessSearchReq(t.SearchRequest); err != nil { + return err + } + } log.Debug("proxy init search request", zap.Int64s("plan.OutputFieldIds", plan.GetOutputFieldIds()), zap.Stringer("plan", plan)) // may be very large if large term passed. 
diff --git a/internal/proxy/task_upsert.go b/internal/proxy/task_upsert.go index 390f27107ed5c..32f3e29f0e0a5 100644 --- a/internal/proxy/task_upsert.go +++ b/internal/proxy/task_upsert.go @@ -154,19 +154,15 @@ func (it *upsertTask) insertPreExecute(ctx context.Context) error { } // Calculate embedding fields - { + if function.HasFunctions(it.schema.CollectionSchema.Functions, []int64{}) { exec, err := function.NewFunctionExecutor(it.schema.CollectionSchema) if err != nil { return err } - - if !exec.Empty() { - if err := exec.ProcessInsert(it.upsertMsg.InsertMsg); err != nil { - return err - } + if err := exec.ProcessInsert(it.upsertMsg.InsertMsg); err != nil { + return err } } - rowNums := uint32(it.upsertMsg.InsertMsg.NRows()) // set upsertTask.insertRequest.rowIDs tr := timerecord.NewTimeRecorder("applyPK") diff --git a/internal/util/function/function_executor.go b/internal/util/function/function_executor.go index 1af56952d9b3a..51913d54a97c0 100644 --- a/internal/util/function/function_executor.go +++ b/internal/util/function/function_executor.go @@ -23,9 +23,14 @@ package function import ( "fmt" "sync" - "github.com/milvus-io/milvus/internal/proto/internalpb" + + "google.golang.org/protobuf/proto" + "github.com/milvus-io/milvus/pkg/mq/msgstream" + "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" + "github.com/milvus-io/milvus/internal/proto/internalpb" ) @@ -35,37 +40,48 @@ type Runner interface { MaxBatch() int ProcessInsert(inputs []*schemapb.FieldData) ([]*schemapb.FieldData, error) - ProcessSearch(placeholderGroups [][]byte) ([][]byte, error) + ProcessSearch(placeholderGroup *commonpb.PlaceholderGroup) (*commonpb.PlaceholderGroup, error) } type FunctionExecutor struct { - runners []Runner + runners map[int64]Runner +} + +func createFunction(coll *schemapb.CollectionSchema, schema *schemapb.FunctionSchema) (Runner, error) { + switch schema.GetType() { + case schemapb.FunctionType_BM25: // ignore bm25 function + return nil, nil + case schemapb.FunctionType_OpenAIEmbedding: + f, err := NewOpenAIEmbeddingFunction(coll, schema) + if err != nil { + return nil, err + } + return f, nil + default: + return nil, fmt.Errorf("unknown functionRunner type %s", schema.GetType().String()) + } } func NewFunctionExecutor(schema *schemapb.CollectionSchema) (*FunctionExecutor, error) { - executor := new(FunctionExecutor) + // If the function's outputs exists in outputIDs, then create the function + // when outputIDs is empty, create all functions + // Please note that currently we only support the function with one input and one output. 
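// Illustrative note on the map built just below (field and function names are
// taken from the tests later in this series, not additional API): with a schema
// that defines
//
//	func1 (OpenAI embedding) -> output field 102 "vector1"
//	func2 (OpenAI embedding) -> output field 103 "vector2"
//
// the executor ends up holding runners[102] and runners[103], so a search
// request can be routed to its runner with a single lookup on req.FieldId.
// BM25 functions never appear in this map: createFunction returns (nil, nil)
// for FunctionType_BM25 above, and HasFunctions skips them as well.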
+ executor := &FunctionExecutor{ + runners: make(map[int64]Runner), + } for _, f_schema := range schema.Functions { - switch f_schema.GetType() { - case schemapb.FunctionType_BM25: - case schemapb.FunctionType_OpenAIEmbedding: - f, err := NewOpenAIEmbeddingFunction(schema, f_schema) - if err != nil { - return nil, err + if runner, err := createFunction(schema, f_schema); err != nil { + return nil, err + } else { + if runner != nil { + executor.runners[f_schema.GetOutputFieldIds()[0]] = runner } - executor.runners = append(executor.runners, f) - default: - return nil, fmt.Errorf("unknown functionRunner type %s", f_schema.GetType().String()) } } return executor, nil } -func (executor *FunctionExecutor)Empty() bool { - return len(executor.runners) == 0 -} - -func (executor *FunctionExecutor)processSingleFunction(idx int, msg *msgstream.InsertMsg) ([]*schemapb.FieldData, error) { - runner := executor.runners[idx] +func (executor *FunctionExecutor)processSingleFunction(runner Runner, msg *msgstream.InsertMsg) ([]*schemapb.FieldData, error) { inputs := make([]*schemapb.FieldData, 0, len(runner.GetSchema().InputFieldIds)) for _, id := range runner.GetSchema().InputFieldIds { for _, field := range msg.FieldsData{ @@ -97,18 +113,18 @@ func (executor *FunctionExecutor)ProcessInsert(msg *msgstream.InsertMsg) error { outputs := make(chan []*schemapb.FieldData, len(executor.runners)) errChan := make(chan error, len(executor.runners)) var wg sync.WaitGroup - for idx, _ := range executor.runners { + for _, runner := range executor.runners { wg.Add(1) - go func(index int) { + go func(runner Runner) { defer wg.Done() - data, err := executor.processSingleFunction(index, msg) + data, err := executor.processSingleFunction(runner, msg) if err != nil { errChan <- err } else { outputs <- data } - }(idx) + }(runner) } wg.Wait() close(errChan) @@ -122,6 +138,77 @@ func (executor *FunctionExecutor)ProcessInsert(msg *msgstream.InsertMsg) error { return nil } -func (executor *FunctionExecutor)ProcessSearch(req *internalpb.SearchRequest) (interface{}, error) { - return nil, nil +func (executor *FunctionExecutor)processSingleSearch(runner Runner, placeholderGroup []byte) ([]byte, error) { + pb := &commonpb.PlaceholderGroup{} + proto.Unmarshal(placeholderGroup, pb) + if len(pb.Placeholders) != 1 { + return nil, merr.WrapErrParameterInvalidMsg("No placeholders founded") + } + if pb.Placeholders[0].Type != commonpb.PlaceholderType_VarChar { + return placeholderGroup, nil + } + res, err := runner.ProcessSearch(pb) + if err != nil { + return nil, err + } + return proto.Marshal(res) +} + +func (executor *FunctionExecutor)prcessSearch(req *internalpb.SearchRequest) error { + runner, exist := executor.runners[req.FieldId] + if !exist { + return nil + } + if req.Nq > int64(runner.MaxBatch()) { + return fmt.Errorf("Nq [%d] > function [%s]'s max batch [%d]", req.Nq, runner.GetSchema().Name, runner.MaxBatch()) + } + if newHolder, err := executor.processSingleSearch(runner, req.GetPlaceholderGroup()); err != nil { + return err + } else { + req.PlaceholderGroup = newHolder + } + return nil +} + +func (executor *FunctionExecutor)prcessAdvanceSearch(req *internalpb.SearchRequest) error { + outputs := make(chan map[int64][]byte, len(req.GetSubReqs())) + errChan := make(chan error, len(req.GetSubReqs())) + var wg sync.WaitGroup + for idx, sub := range req.GetSubReqs() { + if runner, exist := executor.runners[sub.FieldId]; exist { + if sub.Nq > int64(runner.MaxBatch()) { + return fmt.Errorf("Nq [%d] > function [%s]'s max batch [%d]", 
sub.Nq, runner.GetSchema().Name, runner.MaxBatch()) + } + go func(runner Runner, idx int64) { + defer wg.Done() + if newHolder, err := executor.processSingleSearch(runner, sub.GetPlaceholderGroup()); err != nil { + errChan <- err + } else { + outputs <- map[int64][]byte{idx: newHolder} + } + + }(runner, int64(idx)) + } + } + wg.Wait() + close(errChan) + close(outputs) + for err := range errChan { + return err + } + + for output := range outputs { + for idx, holder := range output { + req.SubReqs[idx].PlaceholderGroup = holder + } + } + return nil +} + +func (executor *FunctionExecutor)ProcessSearchReq(req *internalpb.SearchRequest) error { + if req.IsAdvanced { + return executor.prcessSearch(req) + } else { + return executor.prcessAdvanceSearch(req) + } } diff --git a/internal/util/function/function_util.go b/internal/util/function/function_util.go new file mode 100644 index 0000000000000..b725ac2a66bb1 --- /dev/null +++ b/internal/util/function/function_util.go @@ -0,0 +1,62 @@ +/* + * # Licensed to the LF AI & Data foundation under one + * # or more contributor license agreements. See the NOTICE file + * # distributed with this work for additional information + * # regarding copyright ownership. The ASF licenses this file + * # to you under the Apache License, Version 2.0 (the + * # "License"); you may not use this file except in compliance + * # with the License. You may obtain a copy of the License at + * # + * # http://www.apache.org/licenses/LICENSE-2.0 + * # + * # Unless required by applicable law or agreed to in writing, software + * # distributed under the License is distributed on an "AS IS" BASIS, + * # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * # See the License for the specific language governing permissions and + * # limitations under the License. 
+ */ + + +package function + + +import ( + "github.com/milvus-io/milvus/pkg/util/merr" + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" +) + +func HasFunctions(functions []*schemapb.FunctionSchema, outputIDs []int64) bool{ + // Determine whether the column corresponding to outputIDs contains functions, except bm25 function, + // if outputIDs is empty, check all cols + for _, f_schema := range functions { + switch f_schema.GetType() { + case schemapb.FunctionType_BM25: + default: + if len(outputIDs) == 0 { + return true + } else { + for _, id := range outputIDs { + if f_schema.GetOutputFieldIds()[0] == id { + return true + } + } + } + } + } + return false +} + +func GetOutputIDFunctionsMap(functions []*schemapb.FunctionSchema) (map[int64]*schemapb.FunctionSchema, error) { + outputIdMap := map[int64]*schemapb.FunctionSchema{} + for _, f_schema := range functions { + switch f_schema.GetType() { + case schemapb.FunctionType_BM25: + default: + if len(f_schema.OutputFieldIds) != 1 { + return nil, merr.WrapErrParameterInvalidMsg("Function [%s]'s outputs err, only supports one outputs", f_schema.Name) + } + outputIdMap[f_schema.OutputFieldIds[0]] = f_schema + } + } + return outputIdMap, nil +} diff --git a/internal/util/function/openai_embedding_function.go b/internal/util/function/openai_embedding_function.go index 25021ed802959..5554c5e461d30 100644 --- a/internal/util/function/openai_embedding_function.go +++ b/internal/util/function/openai_embedding_function.go @@ -25,7 +25,9 @@ import ( "strings" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus/internal/models" + "github.com/milvus-io/milvus/pkg/util/funcutil" "github.com/milvus-io/milvus/pkg/util/typeutil" ) @@ -150,31 +152,13 @@ func (runner *OpenAIEmbeddingFunction)MaxBatch() int { } -func (runner *OpenAIEmbeddingFunction) ProcessInsert(inputs []*schemapb.FieldData) ([]*schemapb.FieldData, error) { - if len(inputs) != 1 { - return nil, fmt.Errorf("OpenAIEmbedding function only receives one input, bug got [%d]", len(inputs)) - } - - if inputs[0].Type != schemapb.DataType_VarChar { - return nil, fmt.Errorf("OpenAIEmbedding only supports varchar field, the input is not varchar") - } - - texts := inputs[0].GetScalars().GetStringData().GetData() - if texts == nil { - return nil, fmt.Errorf("Input texts is empty") - } - +func (runner *OpenAIEmbeddingFunction)callEmbedding(texts []string) ([][]float32, error) { numRows := len(texts) if numRows > runner.MaxBatch() { return nil, fmt.Errorf("OpenAI embedding supports up to [%d] pieces of data at a time, got [%d]", runner.MaxBatch(), numRows) } - var output_field schemapb.FieldData - output_field.FieldId = runner.outputFields[0].FieldID - output_field.FieldName = runner.outputFields[0].Name - output_field.Type = runner.outputFields[0].DataType - output_field.IsDynamic = runner.outputFields[0].IsDynamic - data := make([]float32, 0, numRows * int(runner.fieldDim)) + data := make([][]float32, numRows) for i := 0; i < numRows; i += maxBatch { end := i + maxBatch if end > numRows { @@ -190,12 +174,43 @@ func (runner *OpenAIEmbeddingFunction) ProcessInsert(inputs []*schemapb.FieldDat for _, item := range resp.Data { if len(item.Embedding) != int(runner.fieldDim) { return nil, fmt.Errorf("The required embedding dim for field [%s] is [%d], but the embedding obtained from the model is [%d]", - output_field.FieldName, runner.fieldDim, len(item.Embedding)) + runner.outputFields[0].Name, runner.fieldDim, 
len(item.Embedding)) } - data = append(data, item.Embedding...) + data = append(data, item.Embedding) } } - output_field.Field = &schemapb.FieldData_Vectors{ + return data, nil +} + +func (runner *OpenAIEmbeddingFunction) ProcessInsert(inputs []*schemapb.FieldData) ([]*schemapb.FieldData, error) { + if len(inputs) != 1 { + return nil, fmt.Errorf("OpenAIEmbedding function only receives one input, bug got [%d]", len(inputs)) + } + + if inputs[0].Type != schemapb.DataType_VarChar { + return nil, fmt.Errorf("OpenAIEmbedding only supports varchar field, the input is not varchar") + } + + texts := inputs[0].GetScalars().GetStringData().GetData() + if texts == nil { + return nil, fmt.Errorf("Input texts is empty") + } + + embds, err := runner.callEmbedding(texts) + if err != nil { + return nil, err + } + data := make([]float32, 0, len(texts) * int(runner.fieldDim)) + for _, emb := range embds { + data = append(data, emb...) + } + + var outputField schemapb.FieldData + outputField.FieldId = runner.outputFields[0].FieldID + outputField.FieldName = runner.outputFields[0].Name + outputField.Type = runner.outputFields[0].DataType + outputField.IsDynamic = runner.outputFields[0].IsDynamic + outputField.Field = &schemapb.FieldData_Vectors{ Vectors: &schemapb.VectorField{ Data: &schemapb.VectorField_FloatVector{ FloatVector: &schemapb.FloatArray{ @@ -205,20 +220,14 @@ func (runner *OpenAIEmbeddingFunction) ProcessInsert(inputs []*schemapb.FieldDat Dim: runner.fieldDim, }, } - return []*schemapb.FieldData{&output_field}, nil + return []*schemapb.FieldData{&outputField}, nil } -func (runner *OpenAIEmbeddingFunction)ProcessSearch(placeholderGroups [][]byte) ([][]byte, error){ - if len(placeholderGroups) != 1 { - return nil, fmt.Errorf("OpenAIEmbedding function only receives one input, bug got [%d]", len(placeholderGroups)) +func (runner *OpenAIEmbeddingFunction)ProcessSearch(placeholderGroup *commonpb.PlaceholderGroup) (*commonpb.PlaceholderGroup, error){ + texts := funcutil.GetVarCharFromPlaceholder(placeholderGroup.Placeholders[0]) // Already checked externally + embds, err := runner.callEmbedding(texts) + if err == nil { + return nil, err } - - // get tests from placeholderGroups - - // texts := []string{} - - // calc embedding - - //to placeholderGroups - return nil, nil + return funcutil.Float32VectorsToPlaceholderGroup(embds), nil } diff --git a/pkg/util/funcutil/placeholdergroup.go b/pkg/util/funcutil/placeholdergroup.go index 96ecdfa4df8fb..abdec286e6fd8 100644 --- a/pkg/util/funcutil/placeholdergroup.go +++ b/pkg/util/funcutil/placeholdergroup.go @@ -25,6 +25,21 @@ func SparseVectorDataToPlaceholderGroupBytes(contents [][]byte) []byte { return bytes } +func Float32VectorsToPlaceholderGroup(embs [][]float32) *commonpb.PlaceholderGroup { + result := make([][]byte, 0) + for _, floatVector := range embs { + result = append(result, floatVectorToByteVector(floatVector)) + } + placeholderGroup := &commonpb.PlaceholderGroup{ + Placeholders: []*commonpb.PlaceholderValue{{ + Tag: "$0", + Type: commonpb.PlaceholderType_SparseFloatVector, + Values: result, + }}, + } + return placeholderGroup +} + func FieldDataToPlaceholderGroupBytes(fieldData *schemapb.FieldData) ([]byte, error) { placeholderValue, err := fieldDataToPlaceholderValue(fieldData) if err != nil { From c01d61cbeece8d8b2d946e09be734b9a6a42bbff Mon Sep 17 00:00:00 2001 From: "junjie.jiang" Date: Thu, 17 Oct 2024 17:40:32 +0800 Subject: [PATCH 06/18] Add search embedding test Signed-off-by: junjie.jiang --- internal/models/openai_embedding.go | 70 
+++-- internal/models/openai_embedding_test.go | 43 ++- internal/proxy/task_insert_test.go | 31 ++- internal/proxy/task_search.go | 4 +- internal/proxy/task_search_test.go | 246 ++++++++++++++++++ internal/util/function/function.go | 1 - internal/util/function/function_base.go | 3 +- internal/util/function/function_executor.go | 28 +- .../util/function/function_executor_test.go | 50 ++-- internal/util/function/function_util.go | 6 +- .../function/openai_embedding_function.go | 59 ++--- .../openai_embedding_function_test.go | 46 ++-- pkg/util/funcutil/placeholdergroup.go | 4 +- 13 files changed, 407 insertions(+), 184 deletions(-) diff --git a/internal/models/openai_embedding.go b/internal/models/openai_embedding.go index 70e8e2508a14c..dc1c3660c7dda 100644 --- a/internal/models/openai_embedding.go +++ b/internal/models/openai_embedding.go @@ -26,83 +26,80 @@ import ( "time" ) - type EmbeddingRequest struct { - // ID of the model to use. - Model string `json:"model"` + // ID of the model to use. + Model string `json:"model"` // Input text to embed, encoded as a string. - Input []string `json:"input"` + Input []string `json:"input"` // A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse. - User string `json:"user,omitempty"` + User string `json:"user,omitempty"` // The format to return the embeddings in. Can be either float or base64. - EncodingFormat string `json:"encoding_format,omitempty"` + EncodingFormat string `json:"encoding_format,omitempty"` // The number of dimensions the resulting output embeddings should have. Only supported in text-embedding-3 and later models. - Dimensions int `json:"dimensions,omitempty"` + Dimensions int `json:"dimensions,omitempty"` } type Usage struct { // The number of tokens used by the prompt. - PromptTokens int `json:"prompt_tokens"` - + PromptTokens int `json:"prompt_tokens"` + // The total number of tokens used by the request. - TotalTokens int `json:"total_tokens"` + TotalTokens int `json:"total_tokens"` } - type EmbeddingData struct { // The object type, which is always "embedding". - Object string `json:"object"` + Object string `json:"object"` // The embedding vector, which is a list of floats. Embedding []float32 `json:"embedding"` // The index of the embedding in the list of embeddings. - Index int `json:"index"` + Index int `json:"index"` } - type EmbeddingResponse struct { // The object type, which is always "list". - Object string `json:"object"` + Object string `json:"object"` // The list of embeddings generated by the model. - Data []EmbeddingData `json:"data"` + Data []EmbeddingData `json:"data"` // The name of the model used to generate the embedding. - Model string `json:"model"` + Model string `json:"model"` // The usage information for the request. 
- Usage Usage `json:"usage"` + Usage Usage `json:"usage"` } - type ByIndex struct { resp *EmbeddingResponse } -func (eb *ByIndex) Len() int { return len(eb.resp.Data) } -func (eb *ByIndex) Swap(i, j int) { eb.resp.Data[i], eb.resp.Data[j] = eb.resp.Data[j], eb.resp.Data[i] } +func (eb *ByIndex) Len() int { return len(eb.resp.Data) } +func (eb *ByIndex) Swap(i, j int) { + eb.resp.Data[i], eb.resp.Data[j] = eb.resp.Data[j], eb.resp.Data[i] +} func (eb *ByIndex) Less(i, j int) bool { return eb.resp.Data[i].Index < eb.resp.Data[j].Index } - type ErrorInfo struct { - Code string `json:"code"` - Message string `json:"message"` - Param string `json:"param,omitempty"` - Type string `json:"type"` + Code string `json:"code"` + Message string `json:"message"` + Param string `json:"param,omitempty"` + Type string `json:"type"` } type EmbedddingError struct { - Error ErrorInfo `json:"error"` + Error ErrorInfo `json:"error"` } type OpenAIEmbeddingClient struct { apiKey string - url string + url string } func (c *OpenAIEmbeddingClient) Check() error { @@ -116,18 +113,17 @@ func (c *OpenAIEmbeddingClient) Check() error { return nil } -func NewOpenAIEmbeddingClient(apiKey string, url string) OpenAIEmbeddingClient{ +func NewOpenAIEmbeddingClient(apiKey string, url string) OpenAIEmbeddingClient { return OpenAIEmbeddingClient{ apiKey: apiKey, - url: url, + url: url, } } - func (c *OpenAIEmbeddingClient) send(client *http.Client, req *http.Request, res *EmbeddingResponse) error { // call openai resp, err := client.Do(req) - + if err != nil { return err } @@ -142,14 +138,14 @@ func (c *OpenAIEmbeddingClient) send(client *http.Client, req *http.Request, res return fmt.Errorf(string(body)) } - err = json.Unmarshal(body, &res) + err = json.Unmarshal(body, &res) if err != nil { return err } return nil } -func (c *OpenAIEmbeddingClient) sendWithRetry(client *http.Client, req *http.Request,res *EmbeddingResponse, maxRetries int) error { +func (c *OpenAIEmbeddingClient) sendWithRetry(client *http.Client, req *http.Request, res *EmbeddingResponse, maxRetries int) error { var err error for i := 0; i < maxRetries; i++ { err = c.send(client, req, res) @@ -176,7 +172,7 @@ func (c *OpenAIEmbeddingClient) Embedding(modelName string, texts []string, dim if err != nil { return nil, err } - + // call openai if timeoutSec <= 0 { timeoutSec = 30 @@ -184,7 +180,7 @@ func (c *OpenAIEmbeddingClient) Embedding(modelName string, texts []string, dim client := &http.Client{ Timeout: timeoutSec * time.Second, } - req, err := http.NewRequest("POST" , c.url, bytes.NewBuffer(data)) + req, err := http.NewRequest("POST", c.url, bytes.NewBuffer(data)) if err != nil { return nil, err } @@ -198,5 +194,5 @@ func (c *OpenAIEmbeddingClient) Embedding(modelName string, texts []string, dim } sort.Sort(&ByIndex{&res}) return &res, err - + } diff --git a/internal/models/openai_embedding_test.go b/internal/models/openai_embedding_test.go index eb31b9c23dffc..0c4cfed6d3ece 100644 --- a/internal/models/openai_embedding_test.go +++ b/internal/models/openai_embedding_test.go @@ -21,9 +21,9 @@ import ( "fmt" "net/http" "net/http/httptest" + "sync/atomic" "testing" "time" - "sync/atomic" "github.com/stretchr/testify/assert" ) @@ -31,42 +31,41 @@ import ( func TestEmbeddingClientCheck(t *testing.T) { { c := OpenAIEmbeddingClient{"", "mock_uri"} - err := c.Check(); + err := c.Check() assert.True(t, err != nil) fmt.Println(err) } { c := OpenAIEmbeddingClient{"mock_key", ""} - err := c.Check(); + err := c.Check() assert.True(t, err != nil) fmt.Println(err) } { c 
:= OpenAIEmbeddingClient{"mock_key", "mock_uri"} - err := c.Check(); + err := c.Check() assert.True(t, err == nil) } } - func TestEmbeddingOK(t *testing.T) { var res EmbeddingResponse res.Object = "list" res.Model = "text-embedding-3-small" res.Data = []EmbeddingData{ { - Object: "embedding", + Object: "embedding", Embedding: []float32{1.1, 2.2, 3.3, 4.4}, - Index: 0, + Index: 0, }, } res.Usage = Usage{ PromptTokens: 1, - TotalTokens: 100, + TotalTokens: 100, } - + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) data, _ := json.Marshal(res) @@ -78,7 +77,7 @@ func TestEmbeddingOK(t *testing.T) { { c := OpenAIEmbeddingClient{"mock_key", url} - err := c.Check(); + err := c.Check() assert.True(t, err == nil) ret, err := c.Embedding("text-embedding-3-small", []string{"sentence"}, 0, "", 0) assert.True(t, err == nil) @@ -86,35 +85,34 @@ func TestEmbeddingOK(t *testing.T) { } } - func TestEmbeddingRetry(t *testing.T) { var res EmbeddingResponse res.Object = "list" res.Model = "text-embedding-3-small" res.Data = []EmbeddingData{ { - Object: "embedding", + Object: "embedding", Embedding: []float32{1.1, 2.2, 3.2, 4.5}, - Index: 2, + Index: 2, }, { - Object: "embedding", + Object: "embedding", Embedding: []float32{1.1, 2.2, 3.3, 4.4}, - Index: 0, + Index: 0, }, { - Object: "embedding", + Object: "embedding", Embedding: []float32{1.1, 2.2, 3.2, 4.3}, - Index: 1, + Index: 1, }, } res.Usage = Usage{ PromptTokens: 1, - TotalTokens: 100, + TotalTokens: 100, } var count int32 = 0 - + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if atomic.LoadInt32(&count) < 2 { atomic.AddInt32(&count, 1) @@ -131,7 +129,7 @@ func TestEmbeddingRetry(t *testing.T) { { c := OpenAIEmbeddingClient{"mock_key", url} - err := c.Check(); + err := c.Check() assert.True(t, err == nil) ret, err := c.Embedding("text-embedding-3-small", []string{"sentence"}, 0, "", 0) assert.True(t, err == nil) @@ -145,7 +143,6 @@ func TestEmbeddingRetry(t *testing.T) { } } - func TestEmbeddingFailed(t *testing.T) { var count int32 = 0 ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -158,7 +155,7 @@ func TestEmbeddingFailed(t *testing.T) { { c := OpenAIEmbeddingClient{"mock_key", url} - err := c.Check(); + err := c.Check() assert.True(t, err == nil) _, err = c.Embedding("text-embedding-3-small", []string{"sentence"}, 0, "", 0) assert.True(t, err != nil) @@ -179,7 +176,7 @@ func TestTimeout(t *testing.T) { { c := OpenAIEmbeddingClient{"mock_key", url} - err := c.Check(); + err := c.Check() assert.True(t, err == nil) _, err = c.Embedding("text-embedding-3-small", []string{"sentence"}, 0, "", 1) assert.True(t, err != nil) diff --git a/internal/proxy/task_insert_test.go b/internal/proxy/task_insert_test.go index 7b1d3a71c99db..a45c0046e6de2 100644 --- a/internal/proxy/task_insert_test.go +++ b/internal/proxy/task_insert_test.go @@ -1,25 +1,24 @@ package proxy import ( - "io" - "fmt" "context" - "testing" + "encoding/json" + "io" "net/http" "net/http/httptest" - "encoding/json" + "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" - "github.com/milvus-io/milvus/internal/models" - "github.com/milvus-io/milvus/internal/allocator" - "github.com/milvus-io/milvus/internal/mocks" - "github.com/milvus-io/milvus/internal/util/function" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/msgpb" 
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/internal/allocator" + "github.com/milvus-io/milvus/internal/mocks" + "github.com/milvus-io/milvus/internal/models" "github.com/milvus-io/milvus/internal/proto/rootcoordpb" + "github.com/milvus-io/milvus/internal/util/function" "github.com/milvus-io/milvus/pkg/mq/msgstream" "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/paramtable" @@ -320,7 +319,7 @@ func TestMaxInsertSize(t *testing.T) { } func TestInsertTask_Function(t *testing.T) { - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request){ + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { var req models.EmbeddingRequest body, _ := io.ReadAll(r.Body) defer r.Body.Close() @@ -331,15 +330,15 @@ func TestInsertTask_Function(t *testing.T) { res.Model = "text-embedding-3-small" for i := 0; i < len(req.Input); i++ { res.Data = append(res.Data, models.EmbeddingData{ - Object: "embedding", + Object: "embedding", Embedding: make([]float32, req.Dimensions), - Index: i, + Index: i, }) } res.Usage = models.Usage{ PromptTokens: 1, - TotalTokens: 100, + TotalTokens: 100, } w.WriteHeader(http.StatusOK) data, _ := json.Marshal(res) @@ -348,8 +347,8 @@ func TestInsertTask_Function(t *testing.T) { defer ts.Close() data := []*schemapb.FieldData{} f := schemapb.FieldData{ - Type: schemapb.DataType_VarChar, - FieldId: 101, + Type: schemapb.DataType_VarChar, + FieldId: 101, FieldName: "text", IsDynamic: false, Field: &schemapb.FieldData_Scalars{ @@ -418,10 +417,10 @@ func TestInsertTask_Function(t *testing.T) { }, Version: msgpb.InsertDataVersion_ColumnBased, FieldsData: data, - NumRows: 2, + NumRows: 2, }, }, - schema: schema, + schema: schema, idAllocator: idAllocator, } diff --git a/internal/proxy/task_search.go b/internal/proxy/task_search.go index b6082452216d6..6c7f271f74e90 100644 --- a/internal/proxy/task_search.go +++ b/internal/proxy/task_search.go @@ -426,7 +426,7 @@ func (t *searchTask) initAdvancedSearchRequest(ctx context.Context) error { if err != nil { return err } - if err := exec.ProcessSearchReq(t.SearchRequest); err != nil { + if err := exec.ProcessSearch(t.SearchRequest); err != nil { return err } } @@ -516,7 +516,7 @@ func (t *searchTask) initSearchRequest(ctx context.Context) error { if err != nil { return err } - if err := exec.ProcessSearchReq(t.SearchRequest); err != nil { + if err := exec.ProcessSearch(t.SearchRequest); err != nil { return err } } diff --git a/internal/proxy/task_search_test.go b/internal/proxy/task_search_test.go index 1edf764c8b418..f775fa53deb7e 100644 --- a/internal/proxy/task_search_test.go +++ b/internal/proxy/task_search_test.go @@ -17,8 +17,12 @@ package proxy import ( "context" + "encoding/json" "fmt" + "io" "math" + "net/http" + "net/http/httptest" "strconv" "strings" "testing" @@ -26,6 +30,7 @@ import ( "github.com/cockroachdb/errors" "github.com/google/uuid" + "github.com/samber/lo" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" @@ -37,11 +42,13 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus/internal/mocks" + "github.com/milvus-io/milvus/internal/models" "github.com/milvus-io/milvus/internal/proto/internalpb" "github.com/milvus-io/milvus/internal/proto/planpb" "github.com/milvus-io/milvus/internal/proto/querypb" 
"github.com/milvus-io/milvus/internal/types" "github.com/milvus-io/milvus/internal/util/dependency" + "github.com/milvus-io/milvus/internal/util/function" "github.com/milvus-io/milvus/internal/util/reduce" "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/util/funcutil" @@ -475,6 +482,245 @@ func TestSearchTask_PreExecute(t *testing.T) { }) } +func TestSearchTask_WithFunctions(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var req models.EmbeddingRequest + body, _ := io.ReadAll(r.Body) + defer r.Body.Close() + json.Unmarshal(body, &req) + + var res models.EmbeddingResponse + res.Object = "list" + res.Model = "text-embedding-3-small" + for i := 0; i < len(req.Input); i++ { + res.Data = append(res.Data, models.EmbeddingData{ + Object: "embedding", + Embedding: make([]float32, req.Dimensions), + Index: i, + }) + } + + res.Usage = models.Usage{ + PromptTokens: 1, + TotalTokens: 100, + } + w.WriteHeader(http.StatusOK) + data, _ := json.Marshal(res) + w.Write(data) + })) + defer ts.Close() + + collectionName := "TestInsertTask_function" + schema := &schemapb.CollectionSchema{ + Name: collectionName, + Description: "TestInsertTask_function", + AutoID: true, + Fields: []*schemapb.FieldSchema{ + {FieldID: 100, Name: "id", DataType: schemapb.DataType_Int64, IsPrimaryKey: true, AutoID: true}, + {FieldID: 101, Name: "text", DataType: schemapb.DataType_VarChar, + TypeParams: []*commonpb.KeyValuePair{ + {Key: "max_length", Value: "200"}, + }}, + {FieldID: 102, Name: "vector1", DataType: schemapb.DataType_FloatVector, + TypeParams: []*commonpb.KeyValuePair{ + {Key: "dim", Value: "4"}, + }}, + {FieldID: 103, Name: "vector2", DataType: schemapb.DataType_FloatVector, + TypeParams: []*commonpb.KeyValuePair{ + {Key: "dim", Value: "4"}, + }}, + }, + Functions: []*schemapb.FunctionSchema{ + { + Name: "func1", + Type: schemapb.FunctionType_OpenAIEmbedding, + InputFieldIds: []int64{101}, + OutputFieldIds: []int64{102}, + Params: []*commonpb.KeyValuePair{ + {Key: function.ModelNameParamKey, Value: "text-embedding-ada-002"}, + {Key: function.OpenaiApiKeyParamKey, Value: "mock"}, + {Key: function.OpenaiEmbeddingUrlParamKey, Value: ts.URL}, + {Key: function.DimParamKey, Value: "4"}, + }, + }, + { + Name: "func2", + Type: schemapb.FunctionType_OpenAIEmbedding, + InputFieldIds: []int64{101}, + OutputFieldIds: []int64{103}, + Params: []*commonpb.KeyValuePair{ + {Key: function.ModelNameParamKey, Value: "text-embedding-ada-002"}, + {Key: function.OpenaiApiKeyParamKey, Value: "mock"}, + {Key: function.OpenaiEmbeddingUrlParamKey, Value: ts.URL}, + {Key: function.DimParamKey, Value: "4"}, + }, + }, + }, + } + + var err error + var ( + rc = NewRootCoordMock() + qc = mocks.NewMockQueryCoordClient(t) + ctx = context.TODO() + ) + + defer rc.Close() + require.NoError(t, err) + mgr := newShardClientMgr() + qc.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(&querypb.ShowCollectionsResponse{}, nil).Maybe() + err = InitMetaCache(ctx, rc, qc, mgr) + require.NoError(t, err) + + getSearchTask := func(t *testing.T, collName string, data []string) *searchTask { + placeholderValue := &commonpb.PlaceholderValue{ + Tag: "$0", + Type: commonpb.PlaceholderType_VarChar, + Values: lo.Map(data, func(str string, _ int) []byte { return []byte(str) }), + } + holder := &commonpb.PlaceholderGroup{ + Placeholders: []*commonpb.PlaceholderValue{placeholderValue}, + } + holderByte, _ := proto.Marshal(holder) + task := &searchTask{ + ctx: ctx, + 
collectionName: collectionName, + SearchRequest: &internalpb.SearchRequest{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_Search, + Timestamp: uint64(time.Now().UnixNano()), + }, + }, + request: &milvuspb.SearchRequest{ + CollectionName: collectionName, + Nq: int64(len(data)), + SearchParams: []*commonpb.KeyValuePair{ + {Key: AnnsFieldKey, Value: "vector1"}, + {Key: TopKKey, Value: "10"}, + }, + PlaceholderGroup: holderByte, + }, + qc: qc, + tr: timerecord.NewTimeRecorder("test-search"), + } + require.NoError(t, task.OnEnqueue()) + return task + } + + collectionID := UniqueID(1000) + cache := NewMockCache(t) + info := newSchemaInfo(schema) + cache.EXPECT().GetCollectionID(mock.Anything, mock.Anything, mock.Anything).Return(collectionID, nil).Maybe() + cache.EXPECT().GetCollectionSchema(mock.Anything, mock.Anything, mock.Anything).Return(info, nil).Maybe() + cache.EXPECT().GetPartitions(mock.Anything, mock.Anything, mock.Anything).Return(map[string]int64{"_default": UniqueID(1)}, nil).Maybe() + cache.EXPECT().GetCollectionInfo(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(&collectionBasicInfo{}, nil).Maybe() + cache.EXPECT().GetShards(mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(map[string][]nodeInfo{}, nil).Maybe() + cache.EXPECT().DeprecateShardCache(mock.Anything, mock.Anything).Return().Maybe() + globalMetaCache = cache + + { + task := getSearchTask(t, collectionName, []string{"sentence"}) + err = task.PreExecute(ctx) + assert.NoError(t, err) + pb := &commonpb.PlaceholderGroup{} + proto.Unmarshal(task.SearchRequest.PlaceholderGroup, pb) + assert.Equal(t, len(pb.Placeholders), 1) + assert.Equal(t, len(pb.Placeholders[0].Values), 1) + assert.Equal(t, pb.Placeholders[0].Type, commonpb.PlaceholderType_FloatVector) + } + + { + task := getSearchTask(t, collectionName, []string{"sentence 1", "sentence 2"}) + err = task.PreExecute(ctx) + assert.NoError(t, err) + pb := &commonpb.PlaceholderGroup{} + proto.Unmarshal(task.SearchRequest.PlaceholderGroup, pb) + assert.Equal(t, len(pb.Placeholders), 1) + assert.Equal(t, len(pb.Placeholders[0].Values), 2) + assert.Equal(t, pb.Placeholders[0].Type, commonpb.PlaceholderType_FloatVector) + } + + getHybridSearchTask := func(t *testing.T, collName string, data [][]string) *searchTask { + subReqs := []*milvuspb.SubSearchRequest{} + for _, item := range data { + placeholderValue := &commonpb.PlaceholderValue{ + Tag: "$0", + Type: commonpb.PlaceholderType_VarChar, + Values: lo.Map(item, func(str string, _ int) []byte { return []byte(str) }), + } + holder := &commonpb.PlaceholderGroup{ + Placeholders: []*commonpb.PlaceholderValue{placeholderValue}, + } + holderByte, _ := proto.Marshal(holder) + subReq := &milvuspb.SubSearchRequest{ + PlaceholderGroup: holderByte, + SearchParams: []*commonpb.KeyValuePair{ + {Key: AnnsFieldKey, Value: "vector1"}, + {Key: TopKKey, Value: "10"}, + }, + Nq: int64(len(item)), + } + subReqs = append(subReqs, subReq) + } + task := &searchTask{ + ctx: ctx, + collectionName: collectionName, + SearchRequest: &internalpb.SearchRequest{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_Search, + Timestamp: uint64(time.Now().UnixNano()), + }, + }, + request: &milvuspb.SearchRequest{ + CollectionName: collectionName, + SubReqs: subReqs, + SearchParams: []*commonpb.KeyValuePair{ + {Key: LimitKey, Value: "10"}, + }, + }, + qc: qc, + tr: timerecord.NewTimeRecorder("test-search"), + } + require.NoError(t, task.OnEnqueue()) + return task + } + + { + task := 
getHybridSearchTask(t, collectionName, [][]string{ + {"sentence1"}, + {"sentence2"}, + }) + err = task.PreExecute(ctx) + assert.NoError(t, err) + assert.Equal(t, len(task.SearchRequest.SubReqs), 2) + for _, sub := range task.SearchRequest.SubReqs { + pb := &commonpb.PlaceholderGroup{} + proto.Unmarshal(sub.PlaceholderGroup, pb) + assert.Equal(t, len(pb.Placeholders), 1) + assert.Equal(t, len(pb.Placeholders[0].Values), 1) + assert.Equal(t, pb.Placeholders[0].Type, commonpb.PlaceholderType_FloatVector) + } + } + + { + task := getHybridSearchTask(t, collectionName, [][]string{ + {"sentence1", "sentence1"}, + {"sentence2", "sentence2"}, + {"sentence3", "sentence3"}, + }) + err = task.PreExecute(ctx) + assert.NoError(t, err) + assert.Equal(t, len(task.SearchRequest.SubReqs), 3) + for _, sub := range task.SearchRequest.SubReqs { + pb := &commonpb.PlaceholderGroup{} + proto.Unmarshal(sub.PlaceholderGroup, pb) + assert.Equal(t, len(pb.Placeholders), 1) + assert.Equal(t, len(pb.Placeholders[0].Values), 2) + assert.Equal(t, pb.Placeholders[0].Type, commonpb.PlaceholderType_FloatVector) + } + } +} + func getQueryCoord() *mocks.MockQueryCoord { qc := &mocks.MockQueryCoord{} qc.EXPECT().Start().Return(nil) diff --git a/internal/util/function/function.go b/internal/util/function/function.go index 9eeaa110c3d03..a9056af41298d 100644 --- a/internal/util/function/function.go +++ b/internal/util/function/function.go @@ -31,7 +31,6 @@ type FunctionRunner interface { GetOutputFields() []*schemapb.FieldSchema } - func NewFunctionRunner(coll *schemapb.CollectionSchema, schema *schemapb.FunctionSchema) (FunctionRunner, error) { switch schema.GetType() { case schemapb.FunctionType_BM25: diff --git a/internal/util/function/function_base.go b/internal/util/function/function_base.go index 1914ed7dbaae2..393e03d131c52 100644 --- a/internal/util/function/function_base.go +++ b/internal/util/function/function_base.go @@ -24,9 +24,8 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" ) - type FunctionBase struct { - schema *schemapb.FunctionSchema + schema *schemapb.FunctionSchema outputFields []*schemapb.FieldSchema } diff --git a/internal/util/function/function_executor.go b/internal/util/function/function_executor.go index 51913d54a97c0..6ad806401cd5d 100644 --- a/internal/util/function/function_executor.go +++ b/internal/util/function/function_executor.go @@ -16,24 +16,21 @@ * # limitations under the License. 
*/ - package function - import ( "fmt" "sync" "google.golang.org/protobuf/proto" - "github.com/milvus-io/milvus/pkg/mq/msgstream" - "github.com/milvus-io/milvus/pkg/util/merr" - "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus/internal/proto/internalpb" + "github.com/milvus-io/milvus/pkg/mq/msgstream" + "github.com/milvus-io/milvus/pkg/util/merr" ) - type Runner interface { GetSchema() *schemapb.FunctionSchema GetOutputFields() []*schemapb.FieldSchema @@ -81,10 +78,10 @@ func NewFunctionExecutor(schema *schemapb.CollectionSchema) (*FunctionExecutor, return executor, nil } -func (executor *FunctionExecutor)processSingleFunction(runner Runner, msg *msgstream.InsertMsg) ([]*schemapb.FieldData, error) { +func (executor *FunctionExecutor) processSingleFunction(runner Runner, msg *msgstream.InsertMsg) ([]*schemapb.FieldData, error) { inputs := make([]*schemapb.FieldData, 0, len(runner.GetSchema().InputFieldIds)) for _, id := range runner.GetSchema().InputFieldIds { - for _, field := range msg.FieldsData{ + for _, field := range msg.FieldsData { if field.FieldId == id { inputs = append(inputs, field) } @@ -102,14 +99,14 @@ func (executor *FunctionExecutor)processSingleFunction(runner Runner, msg *msgst return outputs, nil } -func (executor *FunctionExecutor)ProcessInsert(msg *msgstream.InsertMsg) error { +func (executor *FunctionExecutor) ProcessInsert(msg *msgstream.InsertMsg) error { numRows := msg.NumRows for _, runner := range executor.runners { if numRows > uint64(runner.MaxBatch()) { return fmt.Errorf("numRows [%d] > function [%s]'s max batch [%d]", numRows, runner.GetSchema().Name, runner.MaxBatch()) } } - + outputs := make(chan []*schemapb.FieldData, len(executor.runners)) errChan := make(chan error, len(executor.runners)) var wg sync.WaitGroup @@ -138,7 +135,7 @@ func (executor *FunctionExecutor)ProcessInsert(msg *msgstream.InsertMsg) error { return nil } -func (executor *FunctionExecutor)processSingleSearch(runner Runner, placeholderGroup []byte) ([]byte, error) { +func (executor *FunctionExecutor) processSingleSearch(runner Runner, placeholderGroup []byte) ([]byte, error) { pb := &commonpb.PlaceholderGroup{} proto.Unmarshal(placeholderGroup, pb) if len(pb.Placeholders) != 1 { @@ -154,7 +151,7 @@ func (executor *FunctionExecutor)processSingleSearch(runner Runner, placeholderG return proto.Marshal(res) } -func (executor *FunctionExecutor)prcessSearch(req *internalpb.SearchRequest) error { +func (executor *FunctionExecutor) prcessSearch(req *internalpb.SearchRequest) error { runner, exist := executor.runners[req.FieldId] if !exist { return nil @@ -170,7 +167,7 @@ func (executor *FunctionExecutor)prcessSearch(req *internalpb.SearchRequest) err return nil } -func (executor *FunctionExecutor)prcessAdvanceSearch(req *internalpb.SearchRequest) error { +func (executor *FunctionExecutor) prcessAdvanceSearch(req *internalpb.SearchRequest) error { outputs := make(chan map[int64][]byte, len(req.GetSubReqs())) errChan := make(chan error, len(req.GetSubReqs())) var wg sync.WaitGroup @@ -179,6 +176,7 @@ func (executor *FunctionExecutor)prcessAdvanceSearch(req *internalpb.SearchReque if sub.Nq > int64(runner.MaxBatch()) { return fmt.Errorf("Nq [%d] > function [%s]'s max batch [%d]", sub.Nq, runner.GetSchema().Name, runner.MaxBatch()) } + wg.Add(1) go func(runner Runner, idx int64) { defer wg.Done() if newHolder, err := executor.processSingleSearch(runner, 
sub.GetPlaceholderGroup()); err != nil { @@ -205,8 +203,8 @@ func (executor *FunctionExecutor)prcessAdvanceSearch(req *internalpb.SearchReque return nil } -func (executor *FunctionExecutor)ProcessSearchReq(req *internalpb.SearchRequest) error { - if req.IsAdvanced { +func (executor *FunctionExecutor) ProcessSearch(req *internalpb.SearchRequest) error { + if !req.IsAdvanced { return executor.prcessSearch(req) } else { return executor.prcessAdvanceSearch(req) diff --git a/internal/util/function/function_executor_test.go b/internal/util/function/function_executor_test.go index 8eef0b7b55857..b7e7e7cc932cc 100644 --- a/internal/util/function/function_executor_test.go +++ b/internal/util/function/function_executor_test.go @@ -16,27 +16,24 @@ * # limitations under the License. */ - package function - import ( + "encoding/json" "io" - "testing" "net/http" "net/http/httptest" - "encoding/json" + "testing" "github.com/stretchr/testify/suite" - "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" + "github.com/milvus-io/milvus-proto/go-api/v2/msgpb" + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus/internal/models" "github.com/milvus-io/milvus/pkg/mq/msgstream" - "github.com/milvus-io/milvus-proto/go-api/v2/msgpb" ) - func TestFunctionExecutor(t *testing.T) { suite.Run(t, new(FunctionExecutorSuite)) } @@ -45,8 +42,7 @@ type FunctionExecutorSuite struct { suite.Suite } - -func (s *FunctionExecutorSuite) creataSchema(url string) *schemapb.CollectionSchema{ +func (s *FunctionExecutorSuite) creataSchema(url string) *schemapb.CollectionSchema { return &schemapb.CollectionSchema{ Name: "test", Fields: []*schemapb.FieldSchema{ @@ -59,7 +55,7 @@ func (s *FunctionExecutorSuite) creataSchema(url string) *schemapb.CollectionSch {FieldID: 103, Name: "vector2", DataType: schemapb.DataType_FloatVector, TypeParams: []*commonpb.KeyValuePair{ {Key: "dim", Value: "8"}, - }}, + }}, }, Functions: []*schemapb.FunctionSchema{ { @@ -72,7 +68,7 @@ func (s *FunctionExecutorSuite) creataSchema(url string) *schemapb.CollectionSch {Key: OpenaiApiKeyParamKey, Value: "mock"}, {Key: OpenaiEmbeddingUrlParamKey, Value: url}, {Key: DimParamKey, Value: "4"}, - }, + }, }, { Name: "test", @@ -88,15 +84,15 @@ func (s *FunctionExecutorSuite) creataSchema(url string) *schemapb.CollectionSch }, }, } - + } -func (s *FunctionExecutorSuite)createMsg(texts []string) *msgstream.InsertMsg{ +func (s *FunctionExecutorSuite) createMsg(texts []string) *msgstream.InsertMsg { data := []*schemapb.FieldData{} f := schemapb.FieldData{ - Type: schemapb.DataType_VarChar, - FieldId: 101, + Type: schemapb.DataType_VarChar, + FieldId: 101, IsDynamic: false, Field: &schemapb.FieldData_Scalars{ Scalars: &schemapb.ScalarField{ @@ -118,13 +114,13 @@ func (s *FunctionExecutorSuite)createMsg(texts []string) *msgstream.InsertMsg{ return &msg } -func (s *FunctionExecutorSuite)createEmbedding(texts []string, dim int) [][]float32{ +func (s *FunctionExecutorSuite) createEmbedding(texts []string, dim int) [][]float32 { embeddings := make([][]float32, 0) for i := 0; i < len(texts); i++ { f := float32(i) emb := make([]float32, 0) for j := 0; j < dim; j++ { - emb = append(emb, f + float32(j) * 0.1) + emb = append(emb, f+float32(j)*0.1) } embeddings = append(embeddings, emb) } @@ -132,7 +128,7 @@ func (s *FunctionExecutorSuite)createEmbedding(texts []string, dim int) [][]floa } func (s *FunctionExecutorSuite) TestExecutor() { - ts := httptest.NewServer(http.HandlerFunc(func(w 
http.ResponseWriter, r *http.Request){ + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { var req models.EmbeddingRequest body, _ := io.ReadAll(r.Body) defer r.Body.Close() @@ -144,20 +140,20 @@ func (s *FunctionExecutorSuite) TestExecutor() { embs := s.createEmbedding(req.Input, req.Dimensions) for i := 0; i < len(req.Input); i++ { res.Data = append(res.Data, models.EmbeddingData{ - Object: "embedding", + Object: "embedding", Embedding: embs[i], - Index: i, + Index: i, }) } res.Usage = models.Usage{ PromptTokens: 1, - TotalTokens: 100, + TotalTokens: 100, } w.WriteHeader(http.StatusOK) data, _ := json.Marshal(res) w.Write(data) - + })) defer ts.Close() @@ -170,7 +166,7 @@ func (s *FunctionExecutorSuite) TestExecutor() { } func (s *FunctionExecutorSuite) TestErrorEmbedding() { - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request){ + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { var req models.EmbeddingRequest body, _ := io.ReadAll(r.Body) defer r.Body.Close() @@ -181,20 +177,20 @@ func (s *FunctionExecutorSuite) TestErrorEmbedding() { res.Model = "text-embedding-3-small" for i := 0; i < len(req.Input); i++ { res.Data = append(res.Data, models.EmbeddingData{ - Object: "embedding", + Object: "embedding", Embedding: []float32{}, - Index: i, + Index: i, }) } res.Usage = models.Usage{ PromptTokens: 1, - TotalTokens: 100, + TotalTokens: 100, } w.WriteHeader(http.StatusOK) data, _ := json.Marshal(res) w.Write(data) - + })) defer ts.Close() schema := s.creataSchema(ts.URL) diff --git a/internal/util/function/function_util.go b/internal/util/function/function_util.go index b725ac2a66bb1..cc32a2bfde397 100644 --- a/internal/util/function/function_util.go +++ b/internal/util/function/function_util.go @@ -16,16 +16,14 @@ * # limitations under the License. 
*/ - package function - import ( - "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/pkg/util/merr" ) -func HasFunctions(functions []*schemapb.FunctionSchema, outputIDs []int64) bool{ +func HasFunctions(functions []*schemapb.FunctionSchema, outputIDs []int64) bool { // Determine whether the column corresponding to outputIDs contains functions, except bm25 function, // if outputIDs is empty, check all cols for _, f_schema := range functions { diff --git a/internal/util/function/openai_embedding_function.go b/internal/util/function/openai_embedding_function.go index 5554c5e461d30..15e7409628507 100644 --- a/internal/util/function/openai_embedding_function.go +++ b/internal/util/function/openai_embedding_function.go @@ -24,42 +24,40 @@ import ( "strconv" "strings" - "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus/internal/models" "github.com/milvus-io/milvus/pkg/util/funcutil" - "github.com/milvus-io/milvus/pkg/util/typeutil" + "github.com/milvus-io/milvus/pkg/util/typeutil" ) - const ( - TextEmbeddingAda002 string = "text-embedding-ada-002" + TextEmbeddingAda002 string = "text-embedding-ada-002" TextEmbedding3Small string = "text-embedding-3-small" TextEmbedding3Large string = "text-embedding-3-large" ) const ( - maxBatch = 128 + maxBatch = 128 timeoutSec = 30 ) const ( - ModelNameParamKey string = "model_name" - DimParamKey string = "dim" - UserParamKey string = "user" + ModelNameParamKey string = "model_name" + DimParamKey string = "dim" + UserParamKey string = "user" OpenaiEmbeddingUrlParamKey string = "embedding_url" - OpenaiApiKeyParamKey string = "api_key" + OpenaiApiKeyParamKey string = "api_key" ) - type OpenAIEmbeddingFunction struct { FunctionBase fieldDim int64 - - client *models.OpenAIEmbeddingClient - modelName string + + client *models.OpenAIEmbeddingClient + modelName string embedDimParam int64 - user string + user string } func createOpenAIEmbeddingClient(apiKey string, url string) (*models.OpenAIEmbeddingClient, error) { @@ -125,18 +123,18 @@ func NewOpenAIEmbeddingFunction(coll *schemapb.CollectionSchema, schema *schemap default: } } - + c, err := createOpenAIEmbeddingClient(apiKey, url) if err != nil { return nil, err } runner := OpenAIEmbeddingFunction{ - FunctionBase: *base, - client: c, - fieldDim: fieldDim, - modelName: modelName, - user: user, + FunctionBase: *base, + client: c, + fieldDim: fieldDim, + modelName: modelName, + user: user, embedDimParam: dim, } @@ -147,18 +145,17 @@ func NewOpenAIEmbeddingFunction(coll *schemapb.CollectionSchema, schema *schemap return &runner, nil } -func (runner *OpenAIEmbeddingFunction)MaxBatch() int { - return 5 * maxBatch +func (runner *OpenAIEmbeddingFunction) MaxBatch() int { + return 5 * maxBatch } - -func (runner *OpenAIEmbeddingFunction)callEmbedding(texts []string) ([][]float32, error) { +func (runner *OpenAIEmbeddingFunction) callEmbedding(texts []string) ([][]float32, error) { numRows := len(texts) if numRows > runner.MaxBatch() { return nil, fmt.Errorf("OpenAI embedding supports up to [%d] pieces of data at a time, got [%d]", runner.MaxBatch(), numRows) } - - data := make([][]float32, numRows) + + data := make([][]float32, 0, numRows) for i := 0; i < numRows; i += maxBatch { end := i + maxBatch if end > numRows { @@ -168,8 +165,8 @@ func (runner *OpenAIEmbeddingFunction)callEmbedding(texts []string) 
([][]float32 if err != nil { return nil, err } - if end - i != len(resp.Data) { - return nil, fmt.Errorf("The texts number is [%d], but got embedding number [%d]", end - i, len(resp.Data)) + if end-i != len(resp.Data) { + return nil, fmt.Errorf("The texts number is [%d], but got embedding number [%d]", end-i, len(resp.Data)) } for _, item := range resp.Data { if len(item.Embedding) != int(runner.fieldDim) { @@ -200,7 +197,7 @@ func (runner *OpenAIEmbeddingFunction) ProcessInsert(inputs []*schemapb.FieldDat if err != nil { return nil, err } - data := make([]float32, 0, len(texts) * int(runner.fieldDim)) + data := make([]float32, 0, len(texts)*int(runner.fieldDim)) for _, emb := range embds { data = append(data, emb...) } @@ -223,10 +220,10 @@ func (runner *OpenAIEmbeddingFunction) ProcessInsert(inputs []*schemapb.FieldDat return []*schemapb.FieldData{&outputField}, nil } -func (runner *OpenAIEmbeddingFunction)ProcessSearch(placeholderGroup *commonpb.PlaceholderGroup) (*commonpb.PlaceholderGroup, error){ +func (runner *OpenAIEmbeddingFunction) ProcessSearch(placeholderGroup *commonpb.PlaceholderGroup) (*commonpb.PlaceholderGroup, error) { texts := funcutil.GetVarCharFromPlaceholder(placeholderGroup.Placeholders[0]) // Already checked externally embds, err := runner.callEmbedding(texts) - if err == nil { + if err != nil { return nil, err } return funcutil.Float32VectorsToPlaceholderGroup(embds), nil diff --git a/internal/util/function/openai_embedding_function_test.go b/internal/util/function/openai_embedding_function_test.go index 46e09d8e68108..81f8ede4fe7e1 100644 --- a/internal/util/function/openai_embedding_function_test.go +++ b/internal/util/function/openai_embedding_function_test.go @@ -16,25 +16,23 @@ * # limitations under the License. */ - package function import ( + "encoding/json" "io" - "testing" "net/http" "net/http/httptest" - "encoding/json" + "testing" "github.com/stretchr/testify/suite" - "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" - "github.com/milvus-io/milvus/internal/models" + "github.com/milvus-io/milvus/internal/models" ) - func TestOpenAIEmbeddingFunction(t *testing.T) { suite.Run(t, new(OpenAIEmbeddingFunctionSuite)) } @@ -58,11 +56,11 @@ func (s *OpenAIEmbeddingFunctionSuite) SetupTest() { } } -func createData(texts []string) []*schemapb.FieldData{ +func createData(texts []string) []*schemapb.FieldData { data := []*schemapb.FieldData{} f := schemapb.FieldData{ - Type: schemapb.DataType_VarChar, - FieldId: 101, + Type: schemapb.DataType_VarChar, + FieldId: 101, IsDynamic: false, Field: &schemapb.FieldData_Scalars{ Scalars: &schemapb.ScalarField{ @@ -84,7 +82,7 @@ func createEmbedding(texts []string, dim int) [][]float32 { f := float32(i) emb := make([]float32, 0) for j := 0; j < dim; j++ { - emb = append(emb, f + float32(j) * 0.1) + emb = append(emb, f+float32(j)*0.1) } embeddings = append(embeddings, emb) } @@ -118,20 +116,20 @@ func (s *OpenAIEmbeddingFunctionSuite) TestEmbedding() { embs := createEmbedding(req.Input, 4) for i := 0; i < len(req.Input); i++ { res.Data = append(res.Data, models.EmbeddingData{ - Object: "embedding", + Object: "embedding", Embedding: embs[i], - Index: i, + Index: i, }) } res.Usage = models.Usage{ PromptTokens: 1, - TotalTokens: 100, + TotalTokens: 100, } w.WriteHeader(http.StatusOK) data, _ := json.Marshal(res) w.Write(data) - + })) defer ts.Close() @@ -158,19 +156,19 @@ func (s *OpenAIEmbeddingFunctionSuite) 
TestEmbeddingDimNotMatch() { res.Object = "list" res.Model = "text-embedding-3-small" res.Data = append(res.Data, models.EmbeddingData{ - Object: "embedding", + Object: "embedding", Embedding: []float32{1.0, 1.0, 1.0, 1.0}, - Index: 0, + Index: 0, }) res.Data = append(res.Data, models.EmbeddingData{ - Object: "embedding", + Object: "embedding", Embedding: []float32{1.0, 1.0}, - Index: 1, + Index: 1, }) res.Usage = models.Usage{ PromptTokens: 1, - TotalTokens: 100, + TotalTokens: 100, } w.WriteHeader(http.StatusOK) data, _ := json.Marshal(res) @@ -193,13 +191,13 @@ func (s *OpenAIEmbeddingFunctionSuite) TestEmbeddingNubmerNotMatch() { res.Object = "list" res.Model = "text-embedding-3-small" res.Data = append(res.Data, models.EmbeddingData{ - Object: "embedding", + Object: "embedding", Embedding: []float32{1.0, 1.0, 1.0, 1.0}, - Index: 0, + Index: 0, }) res.Usage = models.Usage{ PromptTokens: 1, - TotalTokens: 100, + TotalTokens: 100, } w.WriteHeader(http.StatusOK) data, _ := json.Marshal(res) @@ -208,7 +206,7 @@ func (s *OpenAIEmbeddingFunctionSuite) TestEmbeddingNubmerNotMatch() { defer ts.Close() runner, err := createRunner(ts.URL, s.schema) - + s.NoError(err) // embedding dim not match @@ -291,7 +289,7 @@ func (s *OpenAIEmbeddingFunctionSuite) TestRunnerParamsErr() { {Key: DimParamKey, Value: "4"}, {Key: OpenaiApiKeyParamKey, Value: "mock"}, {Key: OpenaiEmbeddingUrlParamKey, Value: "mock"}, - }, + }, }) s.Error(err) } diff --git a/pkg/util/funcutil/placeholdergroup.go b/pkg/util/funcutil/placeholdergroup.go index abdec286e6fd8..60aa1aa7ecbe4 100644 --- a/pkg/util/funcutil/placeholdergroup.go +++ b/pkg/util/funcutil/placeholdergroup.go @@ -26,14 +26,14 @@ func SparseVectorDataToPlaceholderGroupBytes(contents [][]byte) []byte { } func Float32VectorsToPlaceholderGroup(embs [][]float32) *commonpb.PlaceholderGroup { - result := make([][]byte, 0) + result := make([][]byte, 0, len(embs)) for _, floatVector := range embs { result = append(result, floatVectorToByteVector(floatVector)) } placeholderGroup := &commonpb.PlaceholderGroup{ Placeholders: []*commonpb.PlaceholderValue{{ Tag: "$0", - Type: commonpb.PlaceholderType_SparseFloatVector, + Type: commonpb.PlaceholderType_FloatVector, Values: result, }}, } From a5776f3e10b91ed53e5eebcabbec1a262f418109 Mon Sep 17 00:00:00 2001 From: "junjie.jiang" Date: Mon, 28 Oct 2024 20:12:01 +0800 Subject: [PATCH 07/18] Support bulkinsert Signed-off-by: junjie.jiang --- internal/datanode/importv2/scheduler_test.go | 99 +++++++++++++++++++ internal/datanode/importv2/util.go | 27 +++++ .../pipeline/flow_graph_embedding_node.go | 3 + internal/proxy/task_insert_test.go | 31 +----- internal/proxy/task_search_test.go | 29 +----- .../querynodev2/pipeline/embedding_node.go | 3 + internal/util/function/function.go | 2 + internal/util/function/function_executor.go | 33 +++++++ .../util/function/function_executor_test.go | 29 +----- .../util/function/mock_embedding_service.go | 71 +++++++++++++ .../function/openai_embedding_function.go | 41 +++++++- 11 files changed, 278 insertions(+), 90 deletions(-) create mode 100644 internal/util/function/mock_embedding_service.go diff --git a/internal/datanode/importv2/scheduler_test.go b/internal/datanode/importv2/scheduler_test.go index 7752c382187d1..9ba2cd3bd5061 100644 --- a/internal/datanode/importv2/scheduler_test.go +++ b/internal/datanode/importv2/scheduler_test.go @@ -37,6 +37,7 @@ import ( "github.com/milvus-io/milvus/internal/proto/datapb" "github.com/milvus-io/milvus/internal/proto/internalpb" 
"github.com/milvus-io/milvus/internal/storage" + "github.com/milvus-io/milvus/internal/util/function" "github.com/milvus-io/milvus/internal/util/importutilv2" "github.com/milvus-io/milvus/internal/util/testutil" "github.com/milvus-io/milvus/pkg/common" @@ -435,6 +436,104 @@ func (s *SchedulerSuite) TestScheduler_ImportFile() { s.NoError(err) } +func (s *SchedulerSuite) TestScheduler_ImportFileWithFunction() { + s.syncMgr.EXPECT().SyncData(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, task syncmgr.Task, callbacks ...func(error) error) *conc.Future[struct{}] { + future := conc.Go(func() (struct{}, error) { + return struct{}{}, nil + }) + return future + }) + ts := function.CreateEmbeddingServer() + defer ts.Close() + schema := &schemapb.CollectionSchema{ + Fields: []*schemapb.FieldSchema{ + { + FieldID: 100, + Name: "pk", + IsPrimaryKey: true, + DataType: schemapb.DataType_VarChar, + TypeParams: []*commonpb.KeyValuePair{ + {Key: common.MaxLengthKey, Value: "128"}, + }, + }, + { + FieldID: 101, + Name: "vec", + DataType: schemapb.DataType_FloatVector, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: common.DimKey, + Value: "4", + }, + }, + }, + { + FieldID: 102, + Name: "int64", + DataType: schemapb.DataType_Int64, + }, + }, + Functions: []*schemapb.FunctionSchema{ + { + Name: "test", + Type: schemapb.FunctionType_OpenAIEmbedding, + InputFieldIds: []int64{100}, + OutputFieldIds: []int64{101}, + Params: []*commonpb.KeyValuePair{ + {Key: function.ModelNameParamKey, Value: "text-embedding-ada-002"}, + {Key: function.OpenaiApiKeyParamKey, Value: "mock"}, + {Key: function.OpenaiEmbeddingUrlParamKey, Value: ts.URL}, + {Key: function.DimParamKey, Value: "4"}, + }, + }, + }, + } + + var once sync.Once + data, err := testutil.CreateInsertData(schema, s.numRows) + s.NoError(err) + s.reader = importutilv2.NewMockReader(s.T()) + s.reader.EXPECT().Read().RunAndReturn(func() (*storage.InsertData, error) { + var res *storage.InsertData + once.Do(func() { + res = data + }) + if res != nil { + return res, nil + } + return nil, io.EOF + }) + importReq := &datapb.ImportRequest{ + JobID: 10, + TaskID: 11, + CollectionID: 12, + PartitionIDs: []int64{13}, + Vchannels: []string{"v0"}, + Schema: schema, + Files: []*internalpb.ImportFile{ + { + Paths: []string{"dummy.json"}, + }, + }, + Ts: 1000, + IDRange: &datapb.IDRange{ + Begin: 0, + End: int64(s.numRows), + }, + RequestSegments: []*datapb.ImportRequestSegment{ + { + SegmentID: 14, + PartitionID: 13, + Vchannel: "v0", + }, + }, + } + importTask := NewImportTask(importReq, s.manager, s.syncMgr, s.cm) + s.manager.Add(importTask) + err = importTask.(*ImportTask).importFile(s.reader) + s.NoError(err) +} + func TestScheduler(t *testing.T) { suite.Run(t, new(SchedulerSuite)) } diff --git a/internal/datanode/importv2/util.go b/internal/datanode/importv2/util.go index eb6c592f85b12..9e683a6f65549 100644 --- a/internal/datanode/importv2/util.go +++ b/internal/datanode/importv2/util.go @@ -208,12 +208,39 @@ func AppendSystemFieldsData(task *ImportTask, data *storage.InsertData) error { } func RunEmbeddingFunction(task *ImportTask, data *storage.InsertData) error { + if err := RunBm25Function(task, data); err != nil { + return err + } + if err := RunDenseEmbedding(task, data); err != nil { + return err + } + return nil +} + +func RunDenseEmbedding(task *ImportTask, data *storage.InsertData) error { + schema := task.GetSchema() + if function.HasFunctions(schema.Functions, []int64{}) { + exec, err := function.NewFunctionExecutor(schema) + if err != 
nil { + return err + } + if err := exec.ProcessBulkInsert(data); err != nil { + return err + } + } + return nil +} + +func RunBm25Function(task *ImportTask, data *storage.InsertData) error { fns := task.GetSchema().GetFunctions() for _, fn := range fns { runner, err := function.NewFunctionRunner(task.GetSchema(), fn) if err != nil { return err } + if runner == nil { + continue + } inputDatas := make([]any, 0, len(fn.InputFieldIds)) for _, inputFieldID := range fn.InputFieldIds { inputDatas = append(inputDatas, data.Data[inputFieldID].GetDataRows()) diff --git a/internal/flushcommon/pipeline/flow_graph_embedding_node.go b/internal/flushcommon/pipeline/flow_graph_embedding_node.go index bf809c49aeb7a..f264b8b39aa8e 100644 --- a/internal/flushcommon/pipeline/flow_graph_embedding_node.go +++ b/internal/flushcommon/pipeline/flow_graph_embedding_node.go @@ -67,6 +67,9 @@ func newEmbeddingNode(channelName string, schema *schemapb.CollectionSchema) (*e if err != nil { return nil, err } + if functionRunner == nil { + continue + } node.functionRunners[tf.GetId()] = functionRunner } return node, nil diff --git a/internal/proxy/task_insert_test.go b/internal/proxy/task_insert_test.go index a45c0046e6de2..e20270757e8f5 100644 --- a/internal/proxy/task_insert_test.go +++ b/internal/proxy/task_insert_test.go @@ -2,10 +2,6 @@ package proxy import ( "context" - "encoding/json" - "io" - "net/http" - "net/http/httptest" "testing" "github.com/stretchr/testify/assert" @@ -16,7 +12,6 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus/internal/allocator" "github.com/milvus-io/milvus/internal/mocks" - "github.com/milvus-io/milvus/internal/models" "github.com/milvus-io/milvus/internal/proto/rootcoordpb" "github.com/milvus-io/milvus/internal/util/function" "github.com/milvus-io/milvus/pkg/mq/msgstream" @@ -319,31 +314,7 @@ func TestMaxInsertSize(t *testing.T) { } func TestInsertTask_Function(t *testing.T) { - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - var req models.EmbeddingRequest - body, _ := io.ReadAll(r.Body) - defer r.Body.Close() - json.Unmarshal(body, &req) - - var res models.EmbeddingResponse - res.Object = "list" - res.Model = "text-embedding-3-small" - for i := 0; i < len(req.Input); i++ { - res.Data = append(res.Data, models.EmbeddingData{ - Object: "embedding", - Embedding: make([]float32, req.Dimensions), - Index: i, - }) - } - - res.Usage = models.Usage{ - PromptTokens: 1, - TotalTokens: 100, - } - w.WriteHeader(http.StatusOK) - data, _ := json.Marshal(res) - w.Write(data) - })) + ts := function.CreateEmbeddingServer() defer ts.Close() data := []*schemapb.FieldData{} f := schemapb.FieldData{ diff --git a/internal/proxy/task_search_test.go b/internal/proxy/task_search_test.go index f775fa53deb7e..4493726a486a2 100644 --- a/internal/proxy/task_search_test.go +++ b/internal/proxy/task_search_test.go @@ -17,7 +17,6 @@ package proxy import ( "context" - "encoding/json" "fmt" "io" "math" @@ -42,7 +41,6 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus/internal/mocks" - "github.com/milvus-io/milvus/internal/models" "github.com/milvus-io/milvus/internal/proto/internalpb" "github.com/milvus-io/milvus/internal/proto/planpb" "github.com/milvus-io/milvus/internal/proto/querypb" @@ -483,33 +481,8 @@ func TestSearchTask_PreExecute(t *testing.T) { } func TestSearchTask_WithFunctions(t *testing.T) { - ts := 
httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - var req models.EmbeddingRequest - body, _ := io.ReadAll(r.Body) - defer r.Body.Close() - json.Unmarshal(body, &req) - - var res models.EmbeddingResponse - res.Object = "list" - res.Model = "text-embedding-3-small" - for i := 0; i < len(req.Input); i++ { - res.Data = append(res.Data, models.EmbeddingData{ - Object: "embedding", - Embedding: make([]float32, req.Dimensions), - Index: i, - }) - } - - res.Usage = models.Usage{ - PromptTokens: 1, - TotalTokens: 100, - } - w.WriteHeader(http.StatusOK) - data, _ := json.Marshal(res) - w.Write(data) - })) + ts := function.CreateEmbeddingServer() defer ts.Close() - collectionName := "TestInsertTask_function" schema := &schemapb.CollectionSchema{ Name: collectionName, diff --git a/internal/querynodev2/pipeline/embedding_node.go b/internal/querynodev2/pipeline/embedding_node.go index da75099f0b8af..76ef66e9c6ef1 100644 --- a/internal/querynodev2/pipeline/embedding_node.go +++ b/internal/querynodev2/pipeline/embedding_node.go @@ -70,6 +70,9 @@ func newEmbeddingNode(collectionID int64, channelName string, manager *DataManag if err != nil { return nil, err } + if functionRunner == nil { + continue + } node.functionRunners = append(node.functionRunners, functionRunner) } return node, nil diff --git a/internal/util/function/function.go b/internal/util/function/function.go index a9056af41298d..fcb451fe543af 100644 --- a/internal/util/function/function.go +++ b/internal/util/function/function.go @@ -35,6 +35,8 @@ func NewFunctionRunner(coll *schemapb.CollectionSchema, schema *schemapb.Functio switch schema.GetType() { case schemapb.FunctionType_BM25: return NewBM25FunctionRunner(coll, schema) + case schemapb.FunctionType_OpenAIEmbedding: + return nil, nil default: return nil, fmt.Errorf("unknown functionRunner type %s", schema.GetType().String()) } diff --git a/internal/util/function/function_executor.go b/internal/util/function/function_executor.go index 6ad806401cd5d..93d9ed68c7421 100644 --- a/internal/util/function/function_executor.go +++ b/internal/util/function/function_executor.go @@ -27,6 +27,7 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus/internal/proto/internalpb" + "github.com/milvus-io/milvus/internal/storage" "github.com/milvus-io/milvus/pkg/mq/msgstream" "github.com/milvus-io/milvus/pkg/util/merr" ) @@ -38,6 +39,7 @@ type Runner interface { MaxBatch() int ProcessInsert(inputs []*schemapb.FieldData) ([]*schemapb.FieldData, error) ProcessSearch(placeholderGroup *commonpb.PlaceholderGroup) (*commonpb.PlaceholderGroup, error) + ProcessBulkInsert(inputs []storage.FieldData) (map[storage.FieldID]storage.FieldData, error) } type FunctionExecutor struct { @@ -210,3 +212,34 @@ func (executor *FunctionExecutor) ProcessSearch(req *internalpb.SearchRequest) e return executor.prcessAdvanceSearch(req) } } + +func (executor *FunctionExecutor) processSingleBulkInsert(runner Runner, data *storage.InsertData) (map[storage.FieldID]storage.FieldData, error) { + inputs := make([]storage.FieldData, 0, len(runner.GetSchema().InputFieldIds)) + for idx, id := range runner.GetSchema().InputFieldIds { + field, exist := data.Data[id] + if !exist { + return nil, fmt.Errorf("Can not find input field: [%s]", runner.GetSchema().GetInputFieldNames()[idx]) + } + inputs = append(inputs, field) + } + + outputs, err := runner.ProcessBulkInsert(inputs) + if err != nil { + return nil, err + } + 
return outputs, nil +} + +func (executor *FunctionExecutor) ProcessBulkInsert(data *storage.InsertData) error { + // Since concurrency has already been used in the outer layer, only a serial logic access model is used here. + for _, runner := range executor.runners { + output, err := executor.processSingleBulkInsert(runner, data) + if err != nil { + return nil + } + for k, v := range output { + data.Data[k] = v + } + } + return nil +} diff --git a/internal/util/function/function_executor_test.go b/internal/util/function/function_executor_test.go index b7e7e7cc932cc..9034e8eec0368 100644 --- a/internal/util/function/function_executor_test.go +++ b/internal/util/function/function_executor_test.go @@ -128,34 +128,7 @@ func (s *FunctionExecutorSuite) createEmbedding(texts []string, dim int) [][]flo } func (s *FunctionExecutorSuite) TestExecutor() { - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - var req models.EmbeddingRequest - body, _ := io.ReadAll(r.Body) - defer r.Body.Close() - json.Unmarshal(body, &req) - - var res models.EmbeddingResponse - res.Object = "list" - res.Model = "text-embedding-3-small" - embs := s.createEmbedding(req.Input, req.Dimensions) - for i := 0; i < len(req.Input); i++ { - res.Data = append(res.Data, models.EmbeddingData{ - Object: "embedding", - Embedding: embs[i], - Index: i, - }) - } - - res.Usage = models.Usage{ - PromptTokens: 1, - TotalTokens: 100, - } - w.WriteHeader(http.StatusOK) - data, _ := json.Marshal(res) - w.Write(data) - - })) - + ts := CreateEmbeddingServer() defer ts.Close() schema := s.creataSchema(ts.URL) exec, err := NewFunctionExecutor(schema) diff --git a/internal/util/function/mock_embedding_service.go b/internal/util/function/mock_embedding_service.go new file mode 100644 index 0000000000000..9342ba59feaad --- /dev/null +++ b/internal/util/function/mock_embedding_service.go @@ -0,0 +1,71 @@ +/* + * # Licensed to the LF AI & Data foundation under one + * # or more contributor license agreements. See the NOTICE file + * # distributed with this work for additional information + * # regarding copyright ownership. The ASF licenses this file + * # to you under the Apache License, Version 2.0 (the + * # "License"); you may not use this file except in compliance + * # with the License. You may obtain a copy of the License at + * # + * # http://www.apache.org/licenses/LICENSE-2.0 + * # + * # Unless required by applicable law or agreed to in writing, software + * # distributed under the License is distributed on an "AS IS" BASIS, + * # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * # See the License for the specific language governing permissions and + * # limitations under the License. 
+ */ + +package function + +import ( + "encoding/json" + "io" + "net/http" + "net/http/httptest" + + "github.com/milvus-io/milvus/internal/models" +) + +func mockEmbedding(texts []string, dim int) [][]float32 { + embeddings := make([][]float32, 0) + for i := 0; i < len(texts); i++ { + f := float32(i) + emb := make([]float32, 0) + for j := 0; j < dim; j++ { + emb = append(emb, f+float32(j)*0.1) + } + embeddings = append(embeddings, emb) + } + return embeddings +} + +func CreateEmbeddingServer() *httptest.Server { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var req models.EmbeddingRequest + body, _ := io.ReadAll(r.Body) + defer r.Body.Close() + json.Unmarshal(body, &req) + embs := mockEmbedding(req.Input, req.Dimensions) + var res models.EmbeddingResponse + res.Object = "list" + res.Model = "text-embedding-3-small" + for i := 0; i < len(req.Input); i++ { + res.Data = append(res.Data, models.EmbeddingData{ + Object: "embedding", + Embedding: embs[i], + Index: i, + }) + } + + res.Usage = models.Usage{ + PromptTokens: 1, + TotalTokens: 100, + } + w.WriteHeader(http.StatusOK) + data, _ := json.Marshal(res) + w.Write(data) + + })) + return ts +} diff --git a/internal/util/function/openai_embedding_function.go b/internal/util/function/openai_embedding_function.go index 15e7409628507..438fade0756f0 100644 --- a/internal/util/function/openai_embedding_function.go +++ b/internal/util/function/openai_embedding_function.go @@ -27,6 +27,7 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus/internal/models" + "github.com/milvus-io/milvus/internal/storage" "github.com/milvus-io/milvus/pkg/util/funcutil" "github.com/milvus-io/milvus/pkg/util/typeutil" ) @@ -149,9 +150,9 @@ func (runner *OpenAIEmbeddingFunction) MaxBatch() int { return 5 * maxBatch } -func (runner *OpenAIEmbeddingFunction) callEmbedding(texts []string) ([][]float32, error) { +func (runner *OpenAIEmbeddingFunction) callEmbedding(texts []string, batchLimit bool) ([][]float32, error) { numRows := len(texts) - if numRows > runner.MaxBatch() { + if batchLimit && numRows > runner.MaxBatch() { return nil, fmt.Errorf("OpenAI embedding supports up to [%d] pieces of data at a time, got [%d]", runner.MaxBatch(), numRows) } @@ -193,7 +194,7 @@ func (runner *OpenAIEmbeddingFunction) ProcessInsert(inputs []*schemapb.FieldDat return nil, fmt.Errorf("Input texts is empty") } - embds, err := runner.callEmbedding(texts) + embds, err := runner.callEmbedding(texts, true) if err != nil { return nil, err } @@ -222,9 +223,41 @@ func (runner *OpenAIEmbeddingFunction) ProcessInsert(inputs []*schemapb.FieldDat func (runner *OpenAIEmbeddingFunction) ProcessSearch(placeholderGroup *commonpb.PlaceholderGroup) (*commonpb.PlaceholderGroup, error) { texts := funcutil.GetVarCharFromPlaceholder(placeholderGroup.Placeholders[0]) // Already checked externally - embds, err := runner.callEmbedding(texts) + embds, err := runner.callEmbedding(texts, true) if err != nil { return nil, err } return funcutil.Float32VectorsToPlaceholderGroup(embds), nil } + +func (runner *OpenAIEmbeddingFunction) ProcessBulkInsert(inputs []storage.FieldData) (map[storage.FieldID]storage.FieldData, error) { + if len(inputs) != 1 { + return nil, fmt.Errorf("OpenAIEmbedding function only receives one input, bug got [%d]", len(inputs)) + } + + if inputs[0].GetDataType() != schemapb.DataType_VarChar { + return nil, fmt.Errorf("OpenAIEmbedding only supports 
varchar field, the input is not varchar") + } + + texts, ok := inputs[0].GetDataRows().([]string) + if !ok { + return nil, fmt.Errorf("Input texts is empty") + } + + embds, err := runner.callEmbedding(texts, false) + if err != nil { + return nil, err + } + data := make([]float32, 0, len(texts)*int(runner.fieldDim)) + for _, emb := range embds { + data = append(data, emb...) + } + + field := &storage.FloatVectorFieldData{ + Data: data, + Dim: int(runner.fieldDim), + } + return map[storage.FieldID]storage.FieldData{ + runner.outputFields[0].FieldID: field, + }, nil +} From 5eea78fe1b2aae07e918f521c39c083490834e50 Mon Sep 17 00:00:00 2001 From: "junjie.jiang" Date: Mon, 4 Nov 2024 15:06:04 +0800 Subject: [PATCH 08/18] Use TextEmbedding function Signed-off-by: junjie.jiang --- internal/datanode/importv2/scheduler_test.go | 3 +- .../proxy/httpserver/handler_v2.go | 4 +- internal/models/openai_embedding.go | 2 +- internal/proxy/task_insert_test.go | 12 +- internal/proxy/task_search_test.go | 6 +- internal/proxy/task_upsert_test.go | 120 ++++++++++++++++++ internal/proxy/util.go | 26 +++- internal/storage/utils.go | 9 +- internal/util/function/function.go | 2 +- internal/util/function/function_executor.go | 11 +- .../util/function/function_executor_test.go | 29 +++-- internal/util/function/function_util.go | 1 + .../util/function/text_embedding_function.go | 56 ++++++++ 13 files changed, 246 insertions(+), 35 deletions(-) create mode 100644 internal/util/function/text_embedding_function.go diff --git a/internal/datanode/importv2/scheduler_test.go b/internal/datanode/importv2/scheduler_test.go index 9ba2cd3bd5061..0c49a9685f1c8 100644 --- a/internal/datanode/importv2/scheduler_test.go +++ b/internal/datanode/importv2/scheduler_test.go @@ -476,10 +476,11 @@ func (s *SchedulerSuite) TestScheduler_ImportFileWithFunction() { Functions: []*schemapb.FunctionSchema{ { Name: "test", - Type: schemapb.FunctionType_OpenAIEmbedding, + Type: schemapb.FunctionType_TextEmbedding, InputFieldIds: []int64{100}, OutputFieldIds: []int64{101}, Params: []*commonpb.KeyValuePair{ + {Key: function.Provider, Value: function.OpenAIProvider}, {Key: function.ModelNameParamKey, Value: "text-embedding-ada-002"}, {Key: function.OpenaiApiKeyParamKey, Value: "mock"}, {Key: function.OpenaiEmbeddingUrlParamKey, Value: ts.URL}, diff --git a/internal/distributed/proxy/httpserver/handler_v2.go b/internal/distributed/proxy/httpserver/handler_v2.go index dbf76a8bfe17c..8525effd00c02 100644 --- a/internal/distributed/proxy/httpserver/handler_v2.go +++ b/internal/distributed/proxy/httpserver/handler_v2.go @@ -947,8 +947,8 @@ func generatePlaceholderGroup(ctx context.Context, body string, collSchema *sche if vectorField.GetIsFunctionOutput() { for _, function := range collSchema.Functions { - if function.Type == schemapb.FunctionType_BM25 { - // TODO: currently only BM25 function is supported, thus guarantees one input field to one output field + if function.Type == schemapb.FunctionType_BM25 || function.Type == schemapb.FunctionType_TextEmbedding { + // TODO: currently only BM25 & text embedding function is supported, thus guarantees one input field to one output field if function.OutputFieldNames[0] == vectorField.Name { dataType = schemapb.DataType_VarChar } diff --git a/internal/models/openai_embedding.go b/internal/models/openai_embedding.go index dc1c3660c7dda..a6cb41df112a5 100644 --- a/internal/models/openai_embedding.go +++ b/internal/models/openai_embedding.go @@ -185,7 +185,7 @@ func (c *OpenAIEmbeddingClient) Embedding(modelName 
string, texts []string, dim return nil, err } req.Header.Set("Content-Type", "application/json") - req.Header.Set("api-key", c.apiKey) + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.apiKey)) var res EmbeddingResponse err = c.sendWithRetry(client, req, &res, 3) diff --git a/internal/proxy/task_insert_test.go b/internal/proxy/task_insert_test.go index e20270757e8f5..8c34b22e55477 100644 --- a/internal/proxy/task_insert_test.go +++ b/internal/proxy/task_insert_test.go @@ -347,15 +347,17 @@ func TestInsertTask_Function(t *testing.T) { {FieldID: 102, Name: "vector", DataType: schemapb.DataType_FloatVector, TypeParams: []*commonpb.KeyValuePair{ {Key: "dim", Value: "4"}, - }}, + }, IsFunctionOutput: true}, }, Functions: []*schemapb.FunctionSchema{ { - Name: "test_function", - Type: schemapb.FunctionType_OpenAIEmbedding, - InputFieldIds: []int64{101}, - OutputFieldIds: []int64{102}, + Name: "test_function", + Type: schemapb.FunctionType_TextEmbedding, + InputFieldIds: []int64{101}, + InputFieldNames: []string{"text"}, + OutputFieldIds: []int64{102}, Params: []*commonpb.KeyValuePair{ + {Key: function.Provider, Value: function.OpenAIProvider}, {Key: function.ModelNameParamKey, Value: "text-embedding-ada-002"}, {Key: function.OpenaiApiKeyParamKey, Value: "mock"}, {Key: function.OpenaiEmbeddingUrlParamKey, Value: ts.URL}, diff --git a/internal/proxy/task_search_test.go b/internal/proxy/task_search_test.go index 4493726a486a2..f0a5de503c6dd 100644 --- a/internal/proxy/task_search_test.go +++ b/internal/proxy/task_search_test.go @@ -506,10 +506,11 @@ func TestSearchTask_WithFunctions(t *testing.T) { Functions: []*schemapb.FunctionSchema{ { Name: "func1", - Type: schemapb.FunctionType_OpenAIEmbedding, + Type: schemapb.FunctionType_TextEmbedding, InputFieldIds: []int64{101}, OutputFieldIds: []int64{102}, Params: []*commonpb.KeyValuePair{ + {Key: function.Provider, Value: function.OpenAIProvider}, {Key: function.ModelNameParamKey, Value: "text-embedding-ada-002"}, {Key: function.OpenaiApiKeyParamKey, Value: "mock"}, {Key: function.OpenaiEmbeddingUrlParamKey, Value: ts.URL}, @@ -518,10 +519,11 @@ func TestSearchTask_WithFunctions(t *testing.T) { }, { Name: "func2", - Type: schemapb.FunctionType_OpenAIEmbedding, + Type: schemapb.FunctionType_TextEmbedding, InputFieldIds: []int64{101}, OutputFieldIds: []int64{103}, Params: []*commonpb.KeyValuePair{ + {Key: function.Provider, Value: function.OpenAIProvider}, {Key: function.ModelNameParamKey, Value: "text-embedding-ada-002"}, {Key: function.OpenaiApiKeyParamKey, Value: "mock"}, {Key: function.OpenaiEmbeddingUrlParamKey, Value: ts.URL}, diff --git a/internal/proxy/task_upsert_test.go b/internal/proxy/task_upsert_test.go index 75fd39964b00e..eec832cafa4c3 100644 --- a/internal/proxy/task_upsert_test.go +++ b/internal/proxy/task_upsert_test.go @@ -27,8 +27,13 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus-proto/go-api/v2/msgpb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/internal/allocator" + "github.com/milvus-io/milvus/internal/mocks" + "github.com/milvus-io/milvus/internal/proto/rootcoordpb" + "github.com/milvus-io/milvus/internal/util/function" "github.com/milvus-io/milvus/pkg/mq/msgstream" "github.com/milvus-io/milvus/pkg/util/commonpbutil" + "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/testutils" ) @@ -360,3 +365,118 @@ func TestUpsertTaskForReplicate(t *testing.T) { assert.Error(t, err) }) } + +func 
TestUpsertTask_Function(t *testing.T) { + ts := function.CreateEmbeddingServer() + defer ts.Close() + data := []*schemapb.FieldData{} + f1 := schemapb.FieldData{ + Type: schemapb.DataType_Int64, + FieldId: 100, + FieldName: "id", + IsDynamic: false, + Field: &schemapb.FieldData_Scalars{ + Scalars: &schemapb.ScalarField{ + Data: &schemapb.ScalarField_LongData{ + LongData: &schemapb.LongArray{ + Data: []int64{0, 1}, + }, + }, + }, + }, + } + data = append(data, &f1) + f2 := schemapb.FieldData{ + Type: schemapb.DataType_VarChar, + FieldId: 101, + FieldName: "text", + IsDynamic: false, + Field: &schemapb.FieldData_Scalars{ + Scalars: &schemapb.ScalarField{ + Data: &schemapb.ScalarField_StringData{ + StringData: &schemapb.StringArray{ + Data: []string{"sentence", "sentence"}, + }, + }, + }, + }, + } + data = append(data, &f2) + collectionName := "TestUpsertTask_function" + schema := &schemapb.CollectionSchema{ + Name: collectionName, + Description: "TestUpsertTask_function", + AutoID: true, + Fields: []*schemapb.FieldSchema{ + {FieldID: 100, Name: "id", DataType: schemapb.DataType_Int64, IsPrimaryKey: true, AutoID: true}, + {FieldID: 101, Name: "text", DataType: schemapb.DataType_VarChar, + TypeParams: []*commonpb.KeyValuePair{ + {Key: "max_length", Value: "200"}, + }}, + {FieldID: 102, Name: "vector", DataType: schemapb.DataType_FloatVector, + TypeParams: []*commonpb.KeyValuePair{ + {Key: "dim", Value: "4"}, + }, IsFunctionOutput: true}, + }, + Functions: []*schemapb.FunctionSchema{ + { + Name: "test_function", + Type: schemapb.FunctionType_TextEmbedding, + InputFieldIds: []int64{101}, + InputFieldNames: []string{"text"}, + OutputFieldIds: []int64{102}, + Params: []*commonpb.KeyValuePair{ + {Key: function.Provider, Value: function.OpenAIProvider}, + {Key: function.ModelNameParamKey, Value: "text-embedding-ada-002"}, + {Key: function.OpenaiApiKeyParamKey, Value: "mock"}, + {Key: function.OpenaiEmbeddingUrlParamKey, Value: ts.URL}, + {Key: function.DimParamKey, Value: "4"}, + }, + }, + }, + } + + info := newSchemaInfo(schema) + collectionID := UniqueID(0) + cache := NewMockCache(t) + globalMetaCache = cache + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + rc := mocks.NewMockRootCoordClient(t) + rc.EXPECT().AllocID(mock.Anything, mock.Anything).Return(&rootcoordpb.AllocIDResponse{ + Status: merr.Status(nil), + ID: collectionID, + Count: 10, + }, nil) + idAllocator, err := allocator.NewIDAllocator(ctx, rc, 0) + idAllocator.Start() + defer idAllocator.Close() + assert.NoError(t, err) + task := upsertTask{ + ctx: context.Background(), + req: &milvuspb.UpsertRequest{ + CollectionName: collectionName, + }, + upsertMsg: &msgstream.UpsertMsg{ + InsertMsg: &msgstream.InsertMsg{ + InsertRequest: &msgpb.InsertRequest{ + Base: commonpbutil.NewMsgBase( + commonpbutil.WithMsgType(commonpb.MsgType_Insert), + ), + CollectionName: collectionName, + DbName: "hooooooo", + Version: msgpb.InsertDataVersion_ColumnBased, + FieldsData: data, + NumRows: 2, + PartitionName: Params.CommonCfg.DefaultPartitionName.GetValue(), + }, + }, + }, + idAllocator: idAllocator, + schema: info, + result: &milvuspb.MutationResult{}, + } + err = task.insertPreExecute(ctx) + assert.NoError(t, err) +} diff --git a/internal/proxy/util.go b/internal/proxy/util.go index 76b477400fea4..1c5daba64c681 100644 --- a/internal/proxy/util.go +++ b/internal/proxy/util.go @@ -718,6 +718,11 @@ func checkFunctionOutputField(function *schemapb.FunctionSchema, fields []*schem if 
!typeutil.IsSparseFloatVectorType(fields[0].GetDataType()) { return fmt.Errorf("BM25 function output field must be a SparseFloatVector field, but got %s", fields[0].DataType.String()) } + case schemapb.FunctionType_TextEmbedding: + if len(fields) != 1 || fields[0].DataType != schemapb.DataType_FloatVector { + return fmt.Errorf("TextEmbedding function output field must be a FloatVector field, got %d field with type %s", + len(fields), fields[0].DataType.String()) + } default: return fmt.Errorf("check output field for unknown function type") } @@ -744,7 +749,11 @@ func checkFunctionInputField(function *schemapb.FunctionSchema, fields []*schema if !h.EnableAnalyzer() { return fmt.Errorf("BM25 function input field must set enable_analyzer to true") } - + case schemapb.FunctionType_TextEmbedding: + if len(fields) != 1 || fields[0].DataType != schemapb.DataType_VarChar { + return fmt.Errorf("TextEmbedding function input field must be a VARCHAR field, got %d field with type %s", + len(fields), fields[0].DataType.String()) + } default: return fmt.Errorf("check input field with unknown function type") } @@ -786,6 +795,10 @@ func checkFunctionBasicParams(function *schemapb.FunctionSchema) error { if len(function.GetParams()) != 0 { return fmt.Errorf("BM25 function accepts no params") } + case schemapb.FunctionType_TextEmbedding: + if len(function.GetParams()) == 0 { + return fmt.Errorf("TextEmbedding function need provider and model_name params") + } default: return fmt.Errorf("check function params with unknown function type") } @@ -942,7 +955,7 @@ func fillFieldPropertiesBySchema(columns []*schemapb.FieldData, schema *schemapb expectColumnNum := 0 for _, field := range schema.GetFields() { fieldName2Schema[field.Name] = field - if !field.GetIsFunctionOutput() { + if !IsBM25FunctionOutputField(field) { expectColumnNum++ } } @@ -1519,12 +1532,12 @@ func checkFieldsDataBySchema(schema *schemapb.CollectionSchema, insertMsg *msgst if fieldSchema.GetDefaultValue() != nil && fieldSchema.IsPrimaryKey { return merr.WrapErrParameterInvalidMsg("primary key can't be with default value") } - if (fieldSchema.IsPrimaryKey && fieldSchema.AutoID && !Params.ProxyCfg.SkipAutoIDCheck.GetAsBool() && inInsert) || fieldSchema.GetIsFunctionOutput() { + if (fieldSchema.IsPrimaryKey && fieldSchema.AutoID && !Params.ProxyCfg.SkipAutoIDCheck.GetAsBool() && inInsert) || IsBM25FunctionOutputField(fieldSchema) { // when inInsert, no need to pass when pk is autoid and SkipAutoIDCheck is false autoGenFieldNum++ } if _, ok := dataNameSet[fieldSchema.GetName()]; !ok { - if (fieldSchema.IsPrimaryKey && fieldSchema.AutoID && !Params.ProxyCfg.SkipAutoIDCheck.GetAsBool() && inInsert) || fieldSchema.GetIsFunctionOutput() { + if (fieldSchema.IsPrimaryKey && fieldSchema.AutoID && !Params.ProxyCfg.SkipAutoIDCheck.GetAsBool() && inInsert) || IsBM25FunctionOutputField(fieldSchema) { // autoGenField continue } @@ -1548,7 +1561,6 @@ func checkFieldsDataBySchema(schema *schemapb.CollectionSchema, insertMsg *msgst zap.Int64("primaryKeyNum", int64(primaryKeyNum))) return merr.WrapErrParameterInvalidMsg("more than 1 primary keys not supported, got %d", primaryKeyNum) } - expectedNum := len(schema.Fields) actualNum := len(insertMsg.FieldsData) + autoGenFieldNum @@ -2231,3 +2243,7 @@ func GetReplicateID(ctx context.Context, database, collectionName string) (strin replicateID, _ := common.GetReplicateID(dbInfo.properties) return replicateID, nil } + +func IsBM25FunctionOutputField(field *schemapb.FieldSchema) bool { + return 
field.GetIsFunctionOutput() && field.GetDataType() == schemapb.DataType_SparseFloatVector +} diff --git a/internal/storage/utils.go b/internal/storage/utils.go index 5e444027dcc48..614c29057c896 100644 --- a/internal/storage/utils.go +++ b/internal/storage/utils.go @@ -371,7 +371,7 @@ func RowBasedInsertMsgToInsertData(msg *msgstream.InsertMsg, collSchema *schemap } for _, field := range collSchema.Fields { - if skipFunction && field.GetIsFunctionOutput() { + if skipFunction && IsBM25FunctionOutputField(field) { continue } @@ -503,7 +503,7 @@ func ColumnBasedInsertMsgToInsertData(msg *msgstream.InsertMsg, collSchema *sche } length := 0 for _, field := range collSchema.Fields { - if field.GetIsFunctionOutput() { + if IsBM25FunctionOutputField(field) { continue } @@ -1340,3 +1340,8 @@ func (ni NullableInt) GetValue() int { func (ni NullableInt) IsNull() bool { return ni.Value == nil } + +// TODO: unify the function implementation, storage/utils.go & proxy/util.go +func IsBM25FunctionOutputField(field *schemapb.FieldSchema) bool { + return field.GetIsFunctionOutput() && field.GetDataType() == schemapb.DataType_SparseFloatVector +} diff --git a/internal/util/function/function.go b/internal/util/function/function.go index fcb451fe543af..7c3bae8ca4833 100644 --- a/internal/util/function/function.go +++ b/internal/util/function/function.go @@ -35,7 +35,7 @@ func NewFunctionRunner(coll *schemapb.CollectionSchema, schema *schemapb.Functio switch schema.GetType() { case schemapb.FunctionType_BM25: return NewBM25FunctionRunner(coll, schema) - case schemapb.FunctionType_OpenAIEmbedding: + case schemapb.FunctionType_TextEmbedding: return nil, nil default: return nil, fmt.Errorf("unknown functionRunner type %s", schema.GetType().String()) diff --git a/internal/util/function/function_executor.go b/internal/util/function/function_executor.go index 93d9ed68c7421..26f971c962ab5 100644 --- a/internal/util/function/function_executor.go +++ b/internal/util/function/function_executor.go @@ -50,8 +50,8 @@ func createFunction(coll *schemapb.CollectionSchema, schema *schemapb.FunctionSc switch schema.GetType() { case schemapb.FunctionType_BM25: // ignore bm25 function return nil, nil - case schemapb.FunctionType_OpenAIEmbedding: - f, err := NewOpenAIEmbeddingFunction(coll, schema) + case schemapb.FunctionType_TextEmbedding: + f, err := NewTextEmbeddingFunction(coll, schema) if err != nil { return nil, err } @@ -81,15 +81,14 @@ func NewFunctionExecutor(schema *schemapb.CollectionSchema) (*FunctionExecutor, } func (executor *FunctionExecutor) processSingleFunction(runner Runner, msg *msgstream.InsertMsg) ([]*schemapb.FieldData, error) { - inputs := make([]*schemapb.FieldData, 0, len(runner.GetSchema().InputFieldIds)) - for _, id := range runner.GetSchema().InputFieldIds { + inputs := make([]*schemapb.FieldData, 0, len(runner.GetSchema().GetInputFieldNames())) + for _, name := range runner.GetSchema().GetInputFieldNames() { for _, field := range msg.FieldsData { - if field.FieldId == id { + if field.GetFieldName() == name { inputs = append(inputs, field) } } } - if len(inputs) != len(runner.GetSchema().InputFieldIds) { return nil, fmt.Errorf("Input field not found") } diff --git a/internal/util/function/function_executor_test.go b/internal/util/function/function_executor_test.go index 9034e8eec0368..66b546fb38c73 100644 --- a/internal/util/function/function_executor_test.go +++ b/internal/util/function/function_executor_test.go @@ -20,6 +20,7 @@ package function import ( "encoding/json" + "fmt" "io" "net/http" 
"net/http/httptest" @@ -51,19 +52,23 @@ func (s *FunctionExecutorSuite) creataSchema(url string) *schemapb.CollectionSch {FieldID: 102, Name: "vector", DataType: schemapb.DataType_FloatVector, TypeParams: []*commonpb.KeyValuePair{ {Key: "dim", Value: "4"}, - }}, + }, + IsFunctionOutput: true, + }, {FieldID: 103, Name: "vector2", DataType: schemapb.DataType_FloatVector, TypeParams: []*commonpb.KeyValuePair{ {Key: "dim", Value: "8"}, - }}, + }, IsFunctionOutput: true}, }, Functions: []*schemapb.FunctionSchema{ { - Name: "test", - Type: schemapb.FunctionType_OpenAIEmbedding, - InputFieldIds: []int64{101}, - OutputFieldIds: []int64{102}, + Name: "test", + Type: schemapb.FunctionType_TextEmbedding, + InputFieldIds: []int64{101}, + InputFieldNames: []string{"text"}, + OutputFieldIds: []int64{102}, Params: []*commonpb.KeyValuePair{ + {Key: Provider, Value: OpenAIProvider}, {Key: ModelNameParamKey, Value: "text-embedding-ada-002"}, {Key: OpenaiApiKeyParamKey, Value: "mock"}, {Key: OpenaiEmbeddingUrlParamKey, Value: url}, @@ -71,11 +76,13 @@ func (s *FunctionExecutorSuite) creataSchema(url string) *schemapb.CollectionSch }, }, { - Name: "test", - Type: schemapb.FunctionType_OpenAIEmbedding, - InputFieldIds: []int64{101}, - OutputFieldIds: []int64{103}, + Name: "test", + Type: schemapb.FunctionType_TextEmbedding, + InputFieldIds: []int64{101}, + InputFieldNames: []string{"text"}, + OutputFieldIds: []int64{103}, Params: []*commonpb.KeyValuePair{ + {Key: Provider, Value: OpenAIProvider}, {Key: ModelNameParamKey, Value: "text-embedding-ada-002"}, {Key: OpenaiApiKeyParamKey, Value: "mock"}, {Key: OpenaiEmbeddingUrlParamKey, Value: url}, @@ -93,6 +100,7 @@ func (s *FunctionExecutorSuite) createMsg(texts []string) *msgstream.InsertMsg { f := schemapb.FieldData{ Type: schemapb.DataType_VarChar, FieldId: 101, + FieldName: "text", IsDynamic: false, Field: &schemapb.FieldData_Scalars{ Scalars: &schemapb.ScalarField{ @@ -168,6 +176,7 @@ func (s *FunctionExecutorSuite) TestErrorEmbedding() { defer ts.Close() schema := s.creataSchema(ts.URL) exec, err := NewFunctionExecutor(schema) + fmt.Println(err) s.NoError(err) msg := s.createMsg([]string{"sentence", "sentence"}) err = exec.ProcessInsert(msg) diff --git a/internal/util/function/function_util.go b/internal/util/function/function_util.go index cc32a2bfde397..bd0265336baa7 100644 --- a/internal/util/function/function_util.go +++ b/internal/util/function/function_util.go @@ -29,6 +29,7 @@ func HasFunctions(functions []*schemapb.FunctionSchema, outputIDs []int64) bool for _, f_schema := range functions { switch f_schema.GetType() { case schemapb.FunctionType_BM25: + case schemapb.FunctionType_Unknown: default: if len(outputIDs) == 0 { return true diff --git a/internal/util/function/text_embedding_function.go b/internal/util/function/text_embedding_function.go new file mode 100644 index 0000000000000..6813ff3c8ec47 --- /dev/null +++ b/internal/util/function/text_embedding_function.go @@ -0,0 +1,56 @@ +/* + * # Licensed to the LF AI & Data foundation under one + * # or more contributor license agreements. See the NOTICE file + * # distributed with this work for additional information + * # regarding copyright ownership. The ASF licenses this file + * # to you under the Apache License, Version 2.0 (the + * # "License"); you may not use this file except in compliance + * # with the License. 
You may obtain a copy of the License at + * # + * # http://www.apache.org/licenses/LICENSE-2.0 + * # + * # Unless required by applicable law or agreed to in writing, software + * # distributed under the License is distributed on an "AS IS" BASIS, + * # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * # See the License for the specific language governing permissions and + * # limitations under the License. + */ + +package function + +import ( + "fmt" + "strings" + + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" +) + +const ( + Provider string = "provider" +) + +const ( + OpenAIProvider string = "openai" +) + +func getProvider(schema *schemapb.FunctionSchema) (string, error) { + for _, param := range schema.Params { + switch strings.ToLower(param.Key) { + case Provider: + return strings.ToLower(param.Value), nil + default: + } + } + return "", fmt.Errorf("The provider parameter was not found in the function's parameters") +} + +func NewTextEmbeddingFunction(coll *schemapb.CollectionSchema, schema *schemapb.FunctionSchema) (*OpenAIEmbeddingFunction, error) { + provider, err := getProvider(schema) + if err != nil { + return nil, err + } + if provider == OpenAIProvider { + return NewOpenAIEmbeddingFunction(coll, schema) + } + return nil, fmt.Errorf("Provider: [%s] not exist, only supports [%s]", provider, OpenAIProvider) +} From df07a0614ba059470d45655f6fe969b880582e5e Mon Sep 17 00:00:00 2001 From: junjiejiangjjj Date: Wed, 6 Nov 2024 17:10:51 +0800 Subject: [PATCH 09/18] update Signed-off-by: junjiejiangjjj --- internal/util/function/function_executor.go | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/internal/util/function/function_executor.go b/internal/util/function/function_executor.go index 26f971c962ab5..221826e0b7378 100644 --- a/internal/util/function/function_executor.go +++ b/internal/util/function/function_executor.go @@ -118,18 +118,24 @@ func (executor *FunctionExecutor) ProcessInsert(msg *msgstream.InsertMsg) error data, err := executor.processSingleFunction(runner, msg) if err != nil { errChan <- err - } else { - outputs <- data + return } - + outputs <- data }(runner) } wg.Wait() close(errChan) close(outputs) + + // Collect all errors + var errs []error for err := range errChan { - return err + errs = append(errs, err) } + if len(errs) > 0 { + return fmt.Errorf("multiple errors occurred: %v", errs) + } + for output := range outputs { msg.FieldsData = append(msg.FieldsData, output...) 
} From ecc4ad89af6af26b16ef79413010eeb91529c02e Mon Sep 17 00:00:00 2001 From: "junjie.jiang" Date: Mon, 9 Dec 2024 20:24:27 +0800 Subject: [PATCH 10/18] Add bedrock, azure, ali text embedding Signed-off-by: junjie.jiang --- internal/datanode/importv2/scheduler_test.go | 10 +- .../ali/ali_dashscope_text_embedding.go | 154 ++++++++++ .../ali/ali_dashscope_text_embedding_test.go | 118 ++++++++ .../models/{ => openai}/openai_embedding.go | 131 ++++++--- .../{ => openai}/openai_embedding_test.go | 24 +- internal/models/utils/embedding_util.go | 53 ++++ internal/proxy/task_insert_test.go | 10 +- internal/proxy/task_search_test.go | 18 +- internal/proxy/task_upsert_test.go | 10 +- .../util/function/ali_embedding_provider.go | 147 ++++++++++ .../alitext_embedding_provider_test.go | 166 +++++++++++ .../function/bedrock_embedding_provider.go | 210 ++++++++++++++ .../bedrock_text_embedding_provider_test.go | 108 +++++++ internal/util/function/common.go | 56 ++++ .../util/function/function_executor_test.go | 28 +- .../util/function/mock_embedding_service.go | 57 +++- .../function/openai_embedding_function.go | 263 ------------------ .../function/openai_embedding_provider.go | 187 +++++++++++++ .../openai_text_embedding_provider_test.go | 177 ++++++++++++ .../util/function/text_embedding_function.go | 169 ++++++++++- ...est.go => text_embedding_function_test.go} | 251 +++++++---------- 21 files changed, 1836 insertions(+), 511 deletions(-) create mode 100644 internal/models/ali/ali_dashscope_text_embedding.go create mode 100644 internal/models/ali/ali_dashscope_text_embedding_test.go rename internal/models/{ => openai}/openai_embedding.go (64%) rename internal/models/{ => openai}/openai_embedding_test.go (88%) create mode 100644 internal/models/utils/embedding_util.go create mode 100644 internal/util/function/ali_embedding_provider.go create mode 100644 internal/util/function/alitext_embedding_provider_test.go create mode 100644 internal/util/function/bedrock_embedding_provider.go create mode 100644 internal/util/function/bedrock_text_embedding_provider_test.go create mode 100644 internal/util/function/common.go delete mode 100644 internal/util/function/openai_embedding_function.go create mode 100644 internal/util/function/openai_embedding_provider.go create mode 100644 internal/util/function/openai_text_embedding_provider_test.go rename internal/util/function/{openai_embedding_function_test.go => text_embedding_function_test.go} (51%) diff --git a/internal/datanode/importv2/scheduler_test.go b/internal/datanode/importv2/scheduler_test.go index 0c49a9685f1c8..99f8e5d3593ee 100644 --- a/internal/datanode/importv2/scheduler_test.go +++ b/internal/datanode/importv2/scheduler_test.go @@ -443,7 +443,7 @@ func (s *SchedulerSuite) TestScheduler_ImportFileWithFunction() { }) return future }) - ts := function.CreateEmbeddingServer() + ts := function.CreateOpenAIEmbeddingServer() defer ts.Close() schema := &schemapb.CollectionSchema{ Fields: []*schemapb.FieldSchema{ @@ -481,10 +481,10 @@ func (s *SchedulerSuite) TestScheduler_ImportFileWithFunction() { OutputFieldIds: []int64{101}, Params: []*commonpb.KeyValuePair{ {Key: function.Provider, Value: function.OpenAIProvider}, - {Key: function.ModelNameParamKey, Value: "text-embedding-ada-002"}, - {Key: function.OpenaiApiKeyParamKey, Value: "mock"}, - {Key: function.OpenaiEmbeddingUrlParamKey, Value: ts.URL}, - {Key: function.DimParamKey, Value: "4"}, + {Key: "model_name", Value: "text-embedding-ada-002"}, + {Key: "api_key", Value: "mock"}, + {Key: "url", Value: 
ts.URL}, + {Key: "dim", Value: "4"}, }, }, }, diff --git a/internal/models/ali/ali_dashscope_text_embedding.go b/internal/models/ali/ali_dashscope_text_embedding.go new file mode 100644 index 0000000000000..c53194c4d2cad --- /dev/null +++ b/internal/models/ali/ali_dashscope_text_embedding.go @@ -0,0 +1,154 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ali + +import ( + "bytes" + "encoding/json" + "fmt" + "net/http" + "sort" + "time" + + "github.com/milvus-io/milvus/internal/models/utils" +) + +type Input struct { + Texts []string `json:"texts"` +} + +type Parameters struct { + TextType string `json:"text_type,omitempty"` + Dimension int `json:"dimension,omitempty"` + OutputType string `json:"output_type,omitempty"` +} + +type EmbeddingRequest struct { + // ID of the model to use. + Model string `json:"model"` + + // Input text to embed, encoded as a string. + Input Input `json:"input"` + + Parameters Parameters `json:"parameters,omitempty"` +} + +type Usage struct { + // The total number of tokens used by the request. 
+ TotalTokens int `json:"total_tokens"` +} + +type SparseEmbedding struct { + Index int `json:"index"` + Value float32 `json:"value"` + Token string `json:"token"` +} + +type Embeddings struct { + TextIndex int `json:"text_index"` + Embedding []float32 `json:"embedding,omitempty"` + SparseEmbedding []SparseEmbedding `json:"sparse_embedding,omitempty"` +} + +type Output struct { + Embeddings []Embeddings `json:"embeddings"` +} + +type EmbeddingResponse struct { + Output Output `json:"output"` + Usage Usage `json:"usage"` + RequestID string `json:"request_id"` +} + +type ByIndex struct { + resp *EmbeddingResponse +} + +func (eb *ByIndex) Len() int { return len(eb.resp.Output.Embeddings) } +func (eb *ByIndex) Swap(i, j int) { + eb.resp.Output.Embeddings[i], eb.resp.Output.Embeddings[j] = eb.resp.Output.Embeddings[j], eb.resp.Output.Embeddings[i] +} +func (eb *ByIndex) Less(i, j int) bool { + return eb.resp.Output.Embeddings[i].TextIndex < eb.resp.Output.Embeddings[j].TextIndex +} + +type ErrorInfo struct { + Code string `json:"code"` + Message string `json:"message"` + RequestID string `json:"request_id"` +} + +type AliDashScopeEmbedding struct { + apiKey string + url string +} + +func NewAliDashScopeEmbeddingClient(apiKey string, url string) *AliDashScopeEmbedding { + return &AliDashScopeEmbedding{ + apiKey: apiKey, + url: url, + } +} + +func (c *AliDashScopeEmbedding) Check() error { + if c.apiKey == "" { + return fmt.Errorf("api key is empty") + } + + if c.url == "" { + return fmt.Errorf("url is empty") + } + return nil +} + +func (c *AliDashScopeEmbedding) Embedding(modelName string, texts []string, dim int, text_type string, output_type string, timeoutSec time.Duration) (*EmbeddingResponse, error) { + var r EmbeddingRequest + r.Model = modelName + r.Input = Input{texts} + r.Parameters.Dimension = dim + r.Parameters.TextType = text_type + r.Parameters.OutputType = output_type + data, err := json.Marshal(r) + if err != nil { + return nil, err + } + + if timeoutSec <= 0 { + timeoutSec = 30 + } + client := &http.Client{ + Timeout: timeoutSec * time.Second, + } + req, err := http.NewRequest("POST", c.url, bytes.NewBuffer(data)) + if err != nil { + return nil, err + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.apiKey)) + body, err := utils.RetrySend(client, req, 3) + if err != nil { + return nil, err + } + var res EmbeddingResponse + err = json.Unmarshal(body, &res) + if err != nil { + return nil, err + } + sort.Sort(&ByIndex{&res}) + return &res, err +} diff --git a/internal/models/ali/ali_dashscope_text_embedding_test.go b/internal/models/ali/ali_dashscope_text_embedding_test.go new file mode 100644 index 0000000000000..a371f10aa79bc --- /dev/null +++ b/internal/models/ali/ali_dashscope_text_embedding_test.go @@ -0,0 +1,118 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +package ali + +import ( + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + // "sync/atomic" + "testing" + // "time" + + "github.com/stretchr/testify/assert" +) + +func TestEmbeddingClientCheck(t *testing.T) { + { + c := NewAliDashScopeEmbeddingClient("", "mock_uri") + err := c.Check() + assert.True(t, err != nil) + fmt.Println(err) + } + + { + c := NewAliDashScopeEmbeddingClient("mock_key", "") + err := c.Check() + assert.True(t, err != nil) + fmt.Println(err) + } + + { + c := NewAliDashScopeEmbeddingClient("mock_key", "mock_uri") + err := c.Check() + assert.True(t, err == nil) + } +} + +func TestEmbeddingOK(t *testing.T) { + var res EmbeddingResponse + repStr := `{ +"output": { +"embeddings": [ +{ +"text_index": 1, +"embedding": [0.1] +}, +{ +"text_index": 0, +"embedding": [0.0] +}, +{ +"text_index": 2, +"embedding": [0.2] +} +] +}, +"usage": { +"total_tokens": 100 +}, +"request_id": "0000000000000" +}` + err := json.Unmarshal([]byte(repStr), &res) + assert.NoError(t, err) + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + data, _ := json.Marshal(res) + w.Write(data) + })) + + defer ts.Close() + url := ts.URL + + { + c := NewAliDashScopeEmbeddingClient("mock_key", url) + err := c.Check() + assert.True(t, err == nil) + ret, err := c.Embedding("text-embedding-v2", []string{"sentence"}, 0, "query", "dense", 0) + assert.True(t, err == nil) + assert.Equal(t, ret.Output.Embeddings[0].TextIndex, 0) + assert.Equal(t, ret.Output.Embeddings[1].TextIndex, 1) + assert.Equal(t, ret.Output.Embeddings[2].TextIndex, 2) + assert.Equal(t, ret.Output.Embeddings[0].Embedding, []float32{0.0}) + assert.Equal(t, ret.Output.Embeddings[1].Embedding, []float32{0.1}) + assert.Equal(t, ret.Output.Embeddings[2].Embedding, []float32{0.2}) + } +} + +func TestEmbeddingFailed(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusUnauthorized) + })) + + defer ts.Close() + url := ts.URL + + { + c := NewAliDashScopeEmbeddingClient("mock_key", url) + err := c.Check() + assert.True(t, err == nil) + _, err = c.Embedding("text-embedding-v2", []string{"sentence"}, 0, "query", "dense", 0) + assert.True(t, err != nil) + } +} diff --git a/internal/models/openai_embedding.go b/internal/models/openai/openai_embedding.go similarity index 64% rename from internal/models/openai_embedding.go rename to internal/models/openai/openai_embedding.go index a6cb41df112a5..ef1be424005c8 100644 --- a/internal/models/openai_embedding.go +++ b/internal/models/openai/openai_embedding.go @@ -14,16 +14,18 @@ // See the License for the specific language governing permissions and // limitations under the License. 
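For reference, the DashScope payload that the Ali client above builds from EmbeddingRequest marshals to the shape shown in this sketch (model name and values are illustrative):

package ali

import (
	"encoding/json"
	"fmt"
)

// Sketch: the request body produced by the types defined above.
func ExampleRequestShape() {
	r := EmbeddingRequest{
		Model:      "text-embedding-v2",
		Input:      Input{Texts: []string{"hello", "world"}},
		Parameters: Parameters{TextType: "query", Dimension: 4, OutputType: "dense"},
	}
	b, _ := json.Marshal(r)
	fmt.Println(string(b))
	// {"model":"text-embedding-v2","input":{"texts":["hello","world"]},"parameters":{"text_type":"query","dimension":4,"output_type":"dense"}}
}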
-package models +package openai import ( "bytes" "encoding/json" "fmt" - "io" "net/http" + "net/url" "sort" "time" + + "github.com/milvus-io/milvus/internal/models/utils" ) type EmbeddingRequest struct { @@ -97,86 +99,122 @@ type EmbedddingError struct { Error ErrorInfo `json:"error"` } -type OpenAIEmbeddingClient struct { +type OpenAIEmbeddingInterface interface { + Check() error + Embedding(modelName string, texts []string, dim int, user string, timeoutSec time.Duration) (*EmbeddingResponse, error) +} + +type openAIBase struct { apiKey string url string } -func (c *OpenAIEmbeddingClient) Check() error { +func (c *openAIBase) Check() error { if c.apiKey == "" { - return fmt.Errorf("OpenAI api key is empty") + return fmt.Errorf("api key is empty") } if c.url == "" { - return fmt.Errorf("OpenAI embedding url is empty") + return fmt.Errorf("url is empty") } return nil } -func NewOpenAIEmbeddingClient(apiKey string, url string) OpenAIEmbeddingClient { - return OpenAIEmbeddingClient{ - apiKey: apiKey, - url: url, +func (c *openAIBase) genReq(modelName string, texts []string, dim int, user string) *EmbeddingRequest { + var r EmbeddingRequest + r.Model = modelName + r.Input = texts + r.EncodingFormat = "float" + if user != "" { + r.User = user + } + if dim != 0 { + r.Dimensions = dim } + return &r } -func (c *OpenAIEmbeddingClient) send(client *http.Client, req *http.Request, res *EmbeddingResponse) error { - // call openai - resp, err := client.Do(req) +type OpenAIEmbeddingClient struct { + openAIBase +} - if err != nil { - return err +func NewOpenAIEmbeddingClient(apiKey string, url string) *OpenAIEmbeddingClient { + return &OpenAIEmbeddingClient{ + openAIBase{ + apiKey: apiKey, + url: url, + }, } - defer resp.Body.Close() +} - body, err := io.ReadAll(resp.Body) +func (c *OpenAIEmbeddingClient) Embedding(modelName string, texts []string, dim int, user string, timeoutSec time.Duration) (*EmbeddingResponse, error) { + r := c.genReq(modelName, texts, dim, user) + data, err := json.Marshal(r) if err != nil { - return err + return nil, err } - if resp.StatusCode != 200 { - return fmt.Errorf(string(body)) + if timeoutSec <= 0 { + timeoutSec = 30 } - + client := &http.Client{ + Timeout: timeoutSec * time.Second, + } + req, err := http.NewRequest("POST", c.url, bytes.NewBuffer(data)) + if err != nil { + return nil, err + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.apiKey)) + body, err := utils.RetrySend(client, req, 3) + if err != nil { + return nil, err + } + var res EmbeddingResponse err = json.Unmarshal(body, &res) if err != nil { - return err + return nil, err } - return nil + sort.Sort(&ByIndex{&res}) + return &res, err } -func (c *OpenAIEmbeddingClient) sendWithRetry(client *http.Client, req *http.Request, res *EmbeddingResponse, maxRetries int) error { - var err error - for i := 0; i < maxRetries; i++ { - err = c.send(client, req, res) - if err == nil { - return nil - } - } - return err +type AzureOpenAIEmbeddingClient struct { + openAIBase + apiVersion string } -func (c *OpenAIEmbeddingClient) Embedding(modelName string, texts []string, dim int, user string, timeoutSec time.Duration) (*EmbeddingResponse, error) { - var r EmbeddingRequest - r.Model = modelName - r.Input = texts - r.EncodingFormat = "float" - if user != "" { - r.User = user - } - if dim != 0 { - r.Dimensions = dim +func NewAzureOpenAIEmbeddingClient(apiKey string, url string) *AzureOpenAIEmbeddingClient { + return &AzureOpenAIEmbeddingClient{ + openAIBase: 
openAIBase{ + apiKey: apiKey, + url: url, + }, + apiVersion: "2024-06-01", } +} +func (c *AzureOpenAIEmbeddingClient) Embedding(modelName string, texts []string, dim int, user string, timeoutSec time.Duration) (*EmbeddingResponse, error) { + r := c.genReq(modelName, texts, dim, user) data, err := json.Marshal(r) if err != nil { return nil, err } - // call openai if timeoutSec <= 0 { timeoutSec = 30 } + + base, err := url.Parse(c.url) + if err != nil { + return nil, err + } + path := fmt.Sprintf("/openai/deployments/%s/embeddings", modelName) + base.Path = path + params := url.Values{} + params.Add("api-version", c.apiVersion) + base.RawQuery = params.Encode() + client := &http.Client{ Timeout: timeoutSec * time.Second, } @@ -186,13 +224,16 @@ func (c *OpenAIEmbeddingClient) Embedding(modelName string, texts []string, dim } req.Header.Set("Content-Type", "application/json") req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.apiKey)) + body, err := utils.RetrySend(client, req, 3) + if err != nil { + return nil, err + } var res EmbeddingResponse - err = c.sendWithRetry(client, req, &res, 3) + err = json.Unmarshal(body, &res) if err != nil { return nil, err } sort.Sort(&ByIndex{&res}) return &res, err - } diff --git a/internal/models/openai_embedding_test.go b/internal/models/openai/openai_embedding_test.go similarity index 88% rename from internal/models/openai_embedding_test.go rename to internal/models/openai/openai_embedding_test.go index 0c4cfed6d3ece..b50fabe8e46d7 100644 --- a/internal/models/openai_embedding_test.go +++ b/internal/models/openai/openai_embedding_test.go @@ -14,7 +14,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package models +package openai import ( "encoding/json" @@ -30,21 +30,21 @@ import ( func TestEmbeddingClientCheck(t *testing.T) { { - c := OpenAIEmbeddingClient{"", "mock_uri"} + c := NewOpenAIEmbeddingClient("", "mock_uri") err := c.Check() assert.True(t, err != nil) fmt.Println(err) } { - c := OpenAIEmbeddingClient{"mock_key", ""} + c := NewOpenAIEmbeddingClient("mock_key", "") err := c.Check() assert.True(t, err != nil) fmt.Println(err) } { - c := OpenAIEmbeddingClient{"mock_key", "mock_uri"} + c := NewOpenAIEmbeddingClient("mock_key", "mock_uri") err := c.Check() assert.True(t, err == nil) } @@ -55,6 +55,11 @@ func TestEmbeddingOK(t *testing.T) { res.Object = "list" res.Model = "text-embedding-3-small" res.Data = []EmbeddingData{ + { + Object: "embedding", + Embedding: []float32{1.1, 2.2, 3.3, 4.4}, + Index: 1, + }, { Object: "embedding", Embedding: []float32{1.1, 2.2, 3.3, 4.4}, @@ -76,12 +81,13 @@ func TestEmbeddingOK(t *testing.T) { url := ts.URL { - c := OpenAIEmbeddingClient{"mock_key", url} + c := NewOpenAIEmbeddingClient("mock_key", url) err := c.Check() assert.True(t, err == nil) ret, err := c.Embedding("text-embedding-3-small", []string{"sentence"}, 0, "", 0) assert.True(t, err == nil) - assert.Equal(t, ret, &res) + assert.Equal(t, ret.Data[0].Index, 0) + assert.Equal(t, ret.Data[1].Index, 1) } } @@ -128,7 +134,7 @@ func TestEmbeddingRetry(t *testing.T) { url := ts.URL { - c := OpenAIEmbeddingClient{"mock_key", url} + c := NewOpenAIEmbeddingClient("mock_key", url) err := c.Check() assert.True(t, err == nil) ret, err := c.Embedding("text-embedding-3-small", []string{"sentence"}, 0, "", 0) @@ -154,7 +160,7 @@ func TestEmbeddingFailed(t *testing.T) { url := ts.URL { - c := OpenAIEmbeddingClient{"mock_key", url} + c := NewOpenAIEmbeddingClient("mock_key", url) err := c.Check() 
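The Azure client above derives its endpoint from the deployment name plus a fixed api-version query parameter. A small sketch of the deployment-specific URL that the url.Parse / url.Values logic produces, assuming a hypothetical resource endpoint (the real value is whatever url the client was constructed with):

package main

import (
	"fmt"
	"net/url"
)

func main() {
	// Hypothetical Azure resource endpoint, for illustration only.
	base, _ := url.Parse("https://my-resource.openai.azure.com")
	base.Path = fmt.Sprintf("/openai/deployments/%s/embeddings", "text-embedding-3-small")
	q := url.Values{}
	q.Add("api-version", "2024-06-01")
	base.RawQuery = q.Encode()
	fmt.Println(base.String())
	// https://my-resource.openai.azure.com/openai/deployments/text-embedding-3-small/embeddings?api-version=2024-06-01
}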
assert.True(t, err == nil) _, err = c.Embedding("text-embedding-3-small", []string{"sentence"}, 0, "", 0) @@ -175,7 +181,7 @@ func TestTimeout(t *testing.T) { url := ts.URL { - c := OpenAIEmbeddingClient{"mock_key", url} + c := NewOpenAIEmbeddingClient("mock_key", url) err := c.Check() assert.True(t, err == nil) _, err = c.Embedding("text-embedding-3-small", []string{"sentence"}, 0, "", 1) diff --git a/internal/models/utils/embedding_util.go b/internal/models/utils/embedding_util.go new file mode 100644 index 0000000000000..fafb6f5cbd3c5 --- /dev/null +++ b/internal/models/utils/embedding_util.go @@ -0,0 +1,53 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package utils + +import ( + "fmt" + "io" + "net/http" +) + +func send(client *http.Client, req *http.Request) ([]byte, error) { + // call openai + resp, err := client.Do(req) + + if err != nil { + return nil, err + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, err + } + + if resp.StatusCode != 200 { + return nil, fmt.Errorf(string(body)) + } + return body, nil +} + +func RetrySend(client *http.Client, req *http.Request, maxRetries int) ([]byte, error) { + for i := 0; i < maxRetries; i++ { + res, err := send(client, req) + if err == nil { + return res, nil + } + } + return nil, nil +} diff --git a/internal/proxy/task_insert_test.go b/internal/proxy/task_insert_test.go index 8c34b22e55477..ccff3ae3bb88e 100644 --- a/internal/proxy/task_insert_test.go +++ b/internal/proxy/task_insert_test.go @@ -314,7 +314,7 @@ func TestMaxInsertSize(t *testing.T) { } func TestInsertTask_Function(t *testing.T) { - ts := function.CreateEmbeddingServer() + ts := function.CreateOpenAIEmbeddingServer() defer ts.Close() data := []*schemapb.FieldData{} f := schemapb.FieldData{ @@ -358,10 +358,10 @@ func TestInsertTask_Function(t *testing.T) { OutputFieldIds: []int64{102}, Params: []*commonpb.KeyValuePair{ {Key: function.Provider, Value: function.OpenAIProvider}, - {Key: function.ModelNameParamKey, Value: "text-embedding-ada-002"}, - {Key: function.OpenaiApiKeyParamKey, Value: "mock"}, - {Key: function.OpenaiEmbeddingUrlParamKey, Value: ts.URL}, - {Key: function.DimParamKey, Value: "4"}, + {Key: "model_name", Value: "text-embedding-ada-002"}, + {Key: "api_key", Value: "mock"}, + {Key: "url", Value: ts.URL}, + {Key: "dim", Value: "4"}, }, }, }, diff --git a/internal/proxy/task_search_test.go b/internal/proxy/task_search_test.go index f0a5de503c6dd..1fb6e9c2fdf60 100644 --- a/internal/proxy/task_search_test.go +++ b/internal/proxy/task_search_test.go @@ -481,7 +481,7 @@ func TestSearchTask_PreExecute(t *testing.T) { } func TestSearchTask_WithFunctions(t *testing.T) { - ts := function.CreateEmbeddingServer() + ts := function.CreateOpenAIEmbeddingServer() defer ts.Close() 
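One detail in RetrySend above is worth flagging: when every attempt fails it returns nil, nil, so the caller sees neither a body nor an error, and the request body is not rewound between attempts. A sketch of a variant that keeps the last error and refreshes the body via GetBody (which http.NewRequest populates for bytes.Buffer bodies), reusing the send helper defined in the same file; this is a sketch, not the version in this patch:

package utils

import "net/http"

// retrySendKeepErr retries a request, preserving the last error and
// re-creating the body reader before each retry.
func retrySendKeepErr(client *http.Client, req *http.Request, maxRetries int) ([]byte, error) {
	var lastErr error
	for i := 0; i < maxRetries; i++ {
		if i > 0 && req.GetBody != nil {
			body, err := req.GetBody() // fresh reader so the retry resends the payload
			if err != nil {
				return nil, err
			}
			req.Body = body
		}
		res, err := send(client, req)
		if err == nil {
			return res, nil
		}
		lastErr = err
	}
	return nil, lastErr
}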
collectionName := "TestInsertTask_function" schema := &schemapb.CollectionSchema{ @@ -511,10 +511,10 @@ func TestSearchTask_WithFunctions(t *testing.T) { OutputFieldIds: []int64{102}, Params: []*commonpb.KeyValuePair{ {Key: function.Provider, Value: function.OpenAIProvider}, - {Key: function.ModelNameParamKey, Value: "text-embedding-ada-002"}, - {Key: function.OpenaiApiKeyParamKey, Value: "mock"}, - {Key: function.OpenaiEmbeddingUrlParamKey, Value: ts.URL}, - {Key: function.DimParamKey, Value: "4"}, + {Key: "model_name", Value: "text-embedding-ada-002"}, + {Key: "api_key", Value: "mock"}, + {Key: "url", Value: ts.URL}, + {Key: "dim", Value: "4"}, }, }, { @@ -524,10 +524,10 @@ func TestSearchTask_WithFunctions(t *testing.T) { OutputFieldIds: []int64{103}, Params: []*commonpb.KeyValuePair{ {Key: function.Provider, Value: function.OpenAIProvider}, - {Key: function.ModelNameParamKey, Value: "text-embedding-ada-002"}, - {Key: function.OpenaiApiKeyParamKey, Value: "mock"}, - {Key: function.OpenaiEmbeddingUrlParamKey, Value: ts.URL}, - {Key: function.DimParamKey, Value: "4"}, + {Key: "model_name", Value: "text-embedding-ada-002"}, + {Key: "api_key", Value: "mock"}, + {Key: "url", Value: ts.URL}, + {Key: "dim", Value: "4"}, }, }, }, diff --git a/internal/proxy/task_upsert_test.go b/internal/proxy/task_upsert_test.go index eec832cafa4c3..10ac7d6a70487 100644 --- a/internal/proxy/task_upsert_test.go +++ b/internal/proxy/task_upsert_test.go @@ -367,7 +367,7 @@ func TestUpsertTaskForReplicate(t *testing.T) { } func TestUpsertTask_Function(t *testing.T) { - ts := function.CreateEmbeddingServer() + ts := function.CreateOpenAIEmbeddingServer() defer ts.Close() data := []*schemapb.FieldData{} f1 := schemapb.FieldData{ @@ -427,10 +427,10 @@ func TestUpsertTask_Function(t *testing.T) { OutputFieldIds: []int64{102}, Params: []*commonpb.KeyValuePair{ {Key: function.Provider, Value: function.OpenAIProvider}, - {Key: function.ModelNameParamKey, Value: "text-embedding-ada-002"}, - {Key: function.OpenaiApiKeyParamKey, Value: "mock"}, - {Key: function.OpenaiEmbeddingUrlParamKey, Value: ts.URL}, - {Key: function.DimParamKey, Value: "4"}, + {Key: "model_name", Value: "text-embedding-ada-002"}, + {Key: "api_key", Value: "mock"}, + {Key: "url", Value: ts.URL}, + {Key: "dim", Value: "4"}, }, }, }, diff --git a/internal/util/function/ali_embedding_provider.go b/internal/util/function/ali_embedding_provider.go new file mode 100644 index 0000000000000..d106426f5c020 --- /dev/null +++ b/internal/util/function/ali_embedding_provider.go @@ -0,0 +1,147 @@ +/* + * # Licensed to the LF AI & Data foundation under one + * # or more contributor license agreements. See the NOTICE file + * # distributed with this work for additional information + * # regarding copyright ownership. The ASF licenses this file + * # to you under the Apache License, Version 2.0 (the + * # "License"); you may not use this file except in compliance + * # with the License. You may obtain a copy of the License at + * # + * # http://www.apache.org/licenses/LICENSE-2.0 + * # + * # Unless required by applicable law or agreed to in writing, software + * # distributed under the License is distributed on an "AS IS" BASIS, + * # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * # See the License for the specific language governing permissions and + * # limitations under the License. 
+ */ + +package function + +import ( + "fmt" + "os" + "strconv" + "strings" + "time" + + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/internal/models/ali" + "github.com/milvus-io/milvus/pkg/util/typeutil" +) + +type AliEmbeddingProvider struct { + fieldDim int64 + + client *ali.AliDashScopeEmbedding + modelName string + embedDimParam int64 + + maxBatch int + timeoutSec int +} + +func createAliEmbeddingClient(apiKey string, url string) (*ali.AliDashScopeEmbedding, error) { + if apiKey == "" { + apiKey = os.Getenv("DASHSCOPE_API_KEY") + } + if apiKey == "" { + return nil, fmt.Errorf("Missing credentials. Please pass `api_key`, or configure the DASHSCOPE_API_KEY environment variable in the Milvus service.") + } + + if url == "" { + url = "https://dashscope.aliyuncs.com/api/v1/services/embeddings/text-embedding/text-embedding" + } + if url == "" { + return nil, fmt.Errorf("Must provide `url` arguments or configure the DASHSCOPE_ENDPOINT environment variable in the Milvus service") + } + + c := ali.NewAliDashScopeEmbeddingClient(apiKey, url) + return c, nil +} + +func NewAliDashScopeEmbeddingProvider(fieldSchema *schemapb.FieldSchema, functionSchema *schemapb.FunctionSchema) (*AliEmbeddingProvider, error) { + fieldDim, err := typeutil.GetDim(fieldSchema) + if err != nil { + return nil, err + } + var apiKey, url, modelName string + var dim int64 + + for _, param := range functionSchema.Params { + switch strings.ToLower(param.Key) { + case modelNameParamKey: + modelName = param.Value + case dimParamKey: + dim, err = strconv.ParseInt(param.Value, 10, 64) + if err != nil { + return nil, fmt.Errorf("dim [%s] is not int", param.Value) + } + + if dim != 0 && dim != fieldDim { + return nil, fmt.Errorf("Field %s's dim is [%d], but embeding's dim is [%d]", functionSchema.Name, fieldDim, dim) + } + case apiKeyParamKey: + apiKey = param.Value + case embeddingUrlParamKey: + url = param.Value + default: + } + } + + if modelName != TextEmbeddingV1 && modelName != TextEmbeddingV2 && modelName != TextEmbeddingV3 { + return nil, fmt.Errorf("Unsupported model: %s, only support [%s, %s, %s]", + modelName, TextEmbeddingV1, TextEmbeddingV2, TextEmbeddingV3) + } + c, err := createAliEmbeddingClient(apiKey, url) + if err != nil { + return nil, err + } + provider := AliEmbeddingProvider{ + client: c, + fieldDim: fieldDim, + modelName: modelName, + embedDimParam: dim, + maxBatch: 25, + timeoutSec: 30, + } + return &provider, nil +} + +func (provider *AliEmbeddingProvider) MaxBatch() int { + return 5 * provider.maxBatch +} + +func (provider *AliEmbeddingProvider) FieldDim() int64 { + return provider.fieldDim +} + +func (provider *AliEmbeddingProvider) CallEmbedding(texts []string, batchLimit bool) ([][]float32, error) { + numRows := len(texts) + if batchLimit && numRows > provider.MaxBatch() { + return nil, fmt.Errorf("Ali text embedding supports up to [%d] pieces of data at a time, got [%d]", provider.MaxBatch(), numRows) + } + + data := make([][]float32, 0, numRows) + for i := 0; i < numRows; i += provider.maxBatch { + end := i + provider.maxBatch + if end > numRows { + end = numRows + } + resp, err := provider.client.Embedding(provider.modelName, texts[i:end], int(provider.embedDimParam), "query", "dense", time.Duration(provider.timeoutSec)) + if err != nil { + return nil, err + } + if end-i != len(resp.Output.Embeddings) { + return nil, fmt.Errorf("Get embedding failed. 
The number of texts and embeddings does not match text:[%d], embedding:[%d]", end-i, len(resp.Output.Embeddings)) + } + for _, item := range resp.Output.Embeddings { + if len(item.Embedding) != int(provider.fieldDim) { + return nil, fmt.Errorf("The required embedding dim is [%d], but the embedding obtained from the model is [%d]", + provider.fieldDim, len(item.Embedding)) + } + data = append(data, item.Embedding) + } + } + return data, nil +} diff --git a/internal/util/function/alitext_embedding_provider_test.go b/internal/util/function/alitext_embedding_provider_test.go new file mode 100644 index 0000000000000..100e42c31e918 --- /dev/null +++ b/internal/util/function/alitext_embedding_provider_test.go @@ -0,0 +1,166 @@ +/* + * # Licensed to the LF AI & Data foundation under one + * # or more contributor license agreements. See the NOTICE file + * # distributed with this work for additional information + * # regarding copyright ownership. The ASF licenses this file + * # to you under the Apache License, Version 2.0 (the + * # "License"); you may not use this file except in compliance + * # with the License. You may obtain a copy of the License at + * # + * # http://www.apache.org/licenses/LICENSE-2.0 + * # + * # Unless required by applicable law or agreed to in writing, software + * # distributed under the License is distributed on an "AS IS" BASIS, + * # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * # See the License for the specific language governing permissions and + * # limitations under the License. + */ + +package function + +import ( + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/suite" + + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + + "github.com/milvus-io/milvus/internal/models/ali" +) + +func TestAliTextEmbeddingProvider(t *testing.T) { + suite.Run(t, new(AliTextEmbeddingProviderSuite)) +} + +type AliTextEmbeddingProviderSuite struct { + suite.Suite + schema *schemapb.CollectionSchema + providers []string +} + +func (s *AliTextEmbeddingProviderSuite) SetupTest() { + s.schema = &schemapb.CollectionSchema{ + Name: "test", + Fields: []*schemapb.FieldSchema{ + {FieldID: 100, Name: "int64", DataType: schemapb.DataType_Int64}, + {FieldID: 101, Name: "text", DataType: schemapb.DataType_VarChar}, + {FieldID: 102, Name: "vector", DataType: schemapb.DataType_FloatVector, + TypeParams: []*commonpb.KeyValuePair{ + {Key: "dim", Value: "4"}, + }}, + }, + } + s.providers = []string{AliDashScopeProvider} +} + +func createAliProvider(url string, schema *schemapb.FieldSchema, providerName string) (TextEmbeddingProvider, error) { + functionSchema := &schemapb.FunctionSchema{ + Name: "test", + Type: schemapb.FunctionType_Unknown, + InputFieldIds: []int64{101}, + OutputFieldIds: []int64{102}, + Params: []*commonpb.KeyValuePair{ + {Key: modelNameParamKey, Value: TextEmbeddingV3}, + {Key: apiKeyParamKey, Value: "mock"}, + {Key: embeddingUrlParamKey, Value: url}, + {Key: dimParamKey, Value: "4"}, + }, + } + switch providerName { + case AliDashScopeProvider: + return NewAliDashScopeEmbeddingProvider(schema, functionSchema) + default: + return nil, fmt.Errorf("Unknow provider") + } +} + +func (s *AliTextEmbeddingProviderSuite) TestEmbedding() { + ts := CreateAliEmbeddingServer() + + defer ts.Close() + for _, provderName := range s.providers { + provder, err := createAliProvider(ts.URL, s.schema.Fields[2], provderName) + s.NoError(err) + { + data 
:= []string{"sentence"} + ret, err2 := provder.CallEmbedding(data, false) + s.NoError(err2) + s.Equal(1, len(ret)) + s.Equal(4, len(ret[0])) + s.Equal([]float32{0.0, 0.1, 0.2, 0.3}, ret[0]) + } + { + data := []string{"sentence 1", "sentence 2", "sentence 3"} + ret, _ := provder.CallEmbedding(data, false) + s.Equal([][]float32{{0.0, 0.1, 0.2, 0.3}, {1.0, 1.1, 1.2, 1.3}, {2.0, 2.1, 2.2, 2.3}}, ret) + } + + } +} + +func (s *AliTextEmbeddingProviderSuite) TestEmbeddingDimNotMatch() { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var res ali.EmbeddingResponse + res.Output.Embeddings = append(res.Output.Embeddings, ali.Embeddings{ + Embedding: []float32{1.0, 1.0, 1.0, 1.0}, + TextIndex: 0, + }) + + res.Output.Embeddings = append(res.Output.Embeddings, ali.Embeddings{ + Embedding: []float32{1.0, 1.0}, + TextIndex: 1, + }) + res.Usage = ali.Usage{ + TotalTokens: 100, + } + w.WriteHeader(http.StatusOK) + data, _ := json.Marshal(res) + w.Write(data) + })) + + defer ts.Close() + for _, provderName := range s.providers { + provder, err := createAliProvider(ts.URL, s.schema.Fields[2], provderName) + s.NoError(err) + + // embedding dim not match + data := []string{"sentence", "sentence"} + _, err2 := provder.CallEmbedding(data, false) + s.Error(err2) + + } +} + +func (s *AliTextEmbeddingProviderSuite) TestEmbeddingNubmerNotMatch() { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var res ali.EmbeddingResponse + res.Output.Embeddings = append(res.Output.Embeddings, ali.Embeddings{ + Embedding: []float32{1.0, 1.0, 1.0, 1.0}, + TextIndex: 0, + }) + res.Usage = ali.Usage{ + TotalTokens: 100, + } + w.WriteHeader(http.StatusOK) + data, _ := json.Marshal(res) + w.Write(data) + })) + + defer ts.Close() + for _, provderName := range s.providers { + provder, err := createAliProvider(ts.URL, s.schema.Fields[2], provderName) + + s.NoError(err) + + // embedding dim not match + data := []string{"sentence", "sentence2"} + _, err2 := provder.CallEmbedding(data, false) + s.Error(err2) + + } +} diff --git a/internal/util/function/bedrock_embedding_provider.go b/internal/util/function/bedrock_embedding_provider.go new file mode 100644 index 0000000000000..5a8367fe23e49 --- /dev/null +++ b/internal/util/function/bedrock_embedding_provider.go @@ -0,0 +1,210 @@ +/* + * # Licensed to the LF AI & Data foundation under one + * # or more contributor license agreements. See the NOTICE file + * # distributed with this work for additional information + * # regarding copyright ownership. The ASF licenses this file + * # to you under the Apache License, Version 2.0 (the + * # "License"); you may not use this file except in compliance + * # with the License. You may obtain a copy of the License at + * # + * # http://www.apache.org/licenses/LICENSE-2.0 + * # + * # Unless required by applicable law or agreed to in writing, software + * # distributed under the License is distributed on an "AS IS" BASIS, + * # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * # See the License for the specific language governing permissions and + * # limitations under the License. 
+ */ + +package function + +import ( + "context" + "encoding/json" + "fmt" + "os" + "strconv" + "strings" + + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/pkg/util/typeutil" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/credentials" + "github.com/aws/aws-sdk-go-v2/service/bedrockruntime" +) + +type BedrockClient interface { + InvokeModel(ctx context.Context, params *bedrockruntime.InvokeModelInput, optFns ...func(*bedrockruntime.Options)) (*bedrockruntime.InvokeModelOutput, error) +} + +type BedrockEmbeddingProvider struct { + fieldDim int64 + + client BedrockClient + modelName string + embedDimParam int64 + normalize bool + + maxBatch int + timeoutSec int +} + +func createBedRockEmbeddingClient(awsAccessKeyId string, awsSecretAccessKey string, region string) (*bedrockruntime.Client, error) { + if awsAccessKeyId == "" { + awsAccessKeyId = os.Getenv("BEDROCK_ACCESS_KEY_ID") + } + if awsAccessKeyId == "" { + return nil, fmt.Errorf("Missing credentials. Please pass `aws_access_key_id`, or configure the BEDROCK_ACCESS_KEY_ID environment variable in the Milvus service.") + } + + if awsSecretAccessKey == "" { + awsSecretAccessKey = os.Getenv("BEDROCK_SECRET_ACCESS_KEY") + } + if awsSecretAccessKey == "" { + return nil, fmt.Errorf("Missing credentials. Please pass `aws_secret_access_key`, or configure the BEDROCK_SECRET_ACCESS_KEY environment variable in the Milvus service.") + } + if region == "" { + return nil, fmt.Errorf("Missing region. Please pass `region` param.") + } + + cfg, err := config.LoadDefaultConfig(context.Background(), config.WithRegion(region), + config.WithCredentialsProvider(credentials.NewStaticCredentialsProvider( + awsAccessKeyId, awsSecretAccessKey, "")), + ) + if err != nil { + return nil, err + } + + return bedrockruntime.NewFromConfig(cfg), nil +} + +func NewBedrockEmbeddingProvider(fieldSchema *schemapb.FieldSchema, functionSchema *schemapb.FunctionSchema, c BedrockClient) (*BedrockEmbeddingProvider, error) { + fieldDim, err := typeutil.GetDim(fieldSchema) + if err != nil { + return nil, err + } + var awsAccessKeyId, awsSecretAccessKey, region, modelName string + var dim int64 + var normalize bool + + for _, param := range functionSchema.Params { + switch strings.ToLower(param.Key) { + case modelNameParamKey: + modelName = param.Value + case dimParamKey: + dim, err = strconv.ParseInt(param.Value, 10, 64) + if err != nil { + return nil, fmt.Errorf("dim [%s] is not int", param.Value) + } + + if dim != 0 && dim != fieldDim { + return nil, fmt.Errorf("Field %s's dim is [%d], but embeding's dim is [%d]", functionSchema.Name, fieldDim, dim) + } + case awsAccessKeyIdParamKey: + awsAccessKeyId = param.Value + case awsSecretAccessKeyParamKey: + awsSecretAccessKey = param.Value + case regionParamKey: + region = param.Value + case normalizeParamKey: + switch strings.ToLower(param.Value) { + case "true": + normalize = true + case "false": + normalize = false + default: + return nil, fmt.Errorf("Illegal [%s:%s] param, ", normalizeParamKey, param.Value) + } + default: + } + } + + if modelName != BedRockTitanTextEmbeddingsV2 { + return nil, fmt.Errorf("Unsupported model: %s, only support [%s]", + modelName, BedRockTitanTextEmbeddingsV2) + } + var client BedrockClient + if c == nil { + client, err = createBedRockEmbeddingClient(awsAccessKeyId, awsSecretAccessKey, region) + if err != nil { + return nil, err + } + } else { + client = c + } + + return &BedrockEmbeddingProvider{ + 
client: client, + fieldDim: fieldDim, + modelName: modelName, + embedDimParam: dim, + normalize: normalize, + maxBatch: 1, + timeoutSec: 30, + }, nil +} + +func (provider *BedrockEmbeddingProvider) MaxBatch() int { + return 5 * provider.maxBatch +} + +func (provider *BedrockEmbeddingProvider) FieldDim() int64 { + return 5 * provider.fieldDim +} + +func (provider *BedrockEmbeddingProvider) CallEmbedding(texts []string, batchLimit bool) ([][]float32, error) { + numRows := len(texts) + if batchLimit && numRows > provider.MaxBatch() { + return nil, fmt.Errorf("Bedrock text embedding supports up to [%d] pieces of data at a time, got [%d]", provider.MaxBatch(), numRows) + } + + data := make([][]float32, 0, numRows) + for i := 0; i < numRows; i += 1 { + payload := BedRockRequest{ + InputText: texts[i], + } + if provider.embedDimParam != 0 { + payload.Dimensions = provider.embedDimParam + } + + payloadBytes, err := json.Marshal(payload) + if err != nil { + return nil, err + } + + output, err := provider.client.InvokeModel(context.Background(), &bedrockruntime.InvokeModelInput{ + Body: payloadBytes, + ModelId: aws.String(provider.modelName), + ContentType: aws.String("application/json"), + }) + + if err != nil { + return nil, err + } + + var resp BedRockResponse + err = json.Unmarshal(output.Body, &resp) + if err != nil { + return nil, err + } + if len(resp.Embedding) != int(provider.fieldDim) { + return nil, fmt.Errorf("The required embedding dim is [%d], but the embedding obtained from the model is [%d]", + provider.fieldDim, len(resp.Embedding)) + } + data = append(data, resp.Embedding) + } + return data, nil +} + +type BedRockRequest struct { + InputText string `json:"inputText"` + Dimensions int64 `json:"dimensions,omitempty"` + Normalize bool `json:"normalize,omitempty"` +} + +type BedRockResponse struct { + Embedding []float32 `json:"embedding"` + InputTextTokenCount int `json:"inputTextTokenCount"` +} diff --git a/internal/util/function/bedrock_text_embedding_provider_test.go b/internal/util/function/bedrock_text_embedding_provider_test.go new file mode 100644 index 0000000000000..8ba9b7763167f --- /dev/null +++ b/internal/util/function/bedrock_text_embedding_provider_test.go @@ -0,0 +1,108 @@ +/* + * # Licensed to the LF AI & Data foundation under one + * # or more contributor license agreements. See the NOTICE file + * # distributed with this work for additional information + * # regarding copyright ownership. The ASF licenses this file + * # to you under the Apache License, Version 2.0 (the + * # "License"); you may not use this file except in compliance + * # with the License. You may obtain a copy of the License at + * # + * # http://www.apache.org/licenses/LICENSE-2.0 + * # + * # Unless required by applicable law or agreed to in writing, software + * # distributed under the License is distributed on an "AS IS" BASIS, + * # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * # See the License for the specific language governing permissions and + * # limitations under the License. 
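Given the BedRockRequest and BedRockResponse structs above, each InvokeModel call exchanges JSON bodies shaped like the following sketch (values illustrative). Note that FieldDim on this provider returns 5 * fieldDim, unlike the other providers, which return the field dimension unchanged; if the multiplier is intentional it deserves a comment.

package function

import (
	"encoding/json"
	"fmt"
)

// Sketch: the Titan request/response bodies implied by the struct tags above.
func ExampleTitanPayload() {
	req := BedRockRequest{InputText: "hello", Dimensions: 4, Normalize: true}
	b, _ := json.Marshal(req)
	fmt.Println(string(b))
	// {"inputText":"hello","dimensions":4,"normalize":true}

	resp := `{"embedding":[0.1,0.2,0.3,0.4],"inputTextTokenCount":2}`
	var out BedRockResponse
	_ = json.Unmarshal([]byte(resp), &out)
	fmt.Println(len(out.Embedding), out.InputTextTokenCount) // 4 2
}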
+ */ + +package function + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/suite" + + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" +) + +func TestBedrockTextEmbeddingProvider(t *testing.T) { + suite.Run(t, new(BedrockTextEmbeddingProviderSuite)) +} + +type BedrockTextEmbeddingProviderSuite struct { + suite.Suite + schema *schemapb.CollectionSchema + providers []string +} + +func (s *BedrockTextEmbeddingProviderSuite) SetupTest() { + s.schema = &schemapb.CollectionSchema{ + Name: "test", + Fields: []*schemapb.FieldSchema{ + {FieldID: 100, Name: "int64", DataType: schemapb.DataType_Int64}, + {FieldID: 101, Name: "text", DataType: schemapb.DataType_VarChar}, + {FieldID: 102, Name: "vector", DataType: schemapb.DataType_FloatVector, + TypeParams: []*commonpb.KeyValuePair{ + {Key: "dim", Value: "4"}, + }}, + }, + } + s.providers = []string{BedrockProvider} +} + +func createBedrockProvider(schema *schemapb.FieldSchema, providerName string, dim int) (TextEmbeddingProvider, error) { + functionSchema := &schemapb.FunctionSchema{ + Name: "test", + Type: schemapb.FunctionType_Unknown, + InputFieldIds: []int64{101}, + OutputFieldIds: []int64{102}, + Params: []*commonpb.KeyValuePair{ + {Key: modelNameParamKey, Value: BedRockTitanTextEmbeddingsV2}, + {Key: apiKeyParamKey, Value: "mock"}, + {Key: dimParamKey, Value: "4"}, + }, + } + switch providerName { + case BedrockProvider: + return NewBedrockEmbeddingProvider(schema, functionSchema, &MockBedrockClient{dim: dim}) + default: + return nil, fmt.Errorf("Unknow provider") + } +} + +func (s *BedrockTextEmbeddingProviderSuite) TestEmbedding() { + for _, provderName := range s.providers { + provder, err := createBedrockProvider(s.schema.Fields[2], provderName, 4) + s.NoError(err) + { + data := []string{"sentence"} + ret, err2 := provder.CallEmbedding(data, false) + s.NoError(err2) + s.Equal(1, len(ret)) + s.Equal(4, len(ret[0])) + s.Equal([]float32{0.0, 0.1, 0.2, 0.3}, ret[0]) + } + { + data := []string{"sentence 1", "sentence 2", "sentence 3"} + ret, _ := provder.CallEmbedding(data, false) + s.Equal([][]float32{{0.0, 0.1, 0.2, 0.3}, {0.0, 0.1, 0.2, 0.3}, {0.0, 0.1, 0.2, 0.3}}, ret) + } + + } +} + +func (s *BedrockTextEmbeddingProviderSuite) TestEmbeddingDimNotMatch() { + for _, provderName := range s.providers { + provder, err := createBedrockProvider(s.schema.Fields[2], provderName, 2) + s.NoError(err) + + // embedding dim not match + data := []string{"sentence", "sentence"} + _, err2 := provder.CallEmbedding(data, false) + s.Error(err2) + + } +} diff --git a/internal/util/function/common.go b/internal/util/function/common.go new file mode 100644 index 0000000000000..5063d895283ff --- /dev/null +++ b/internal/util/function/common.go @@ -0,0 +1,56 @@ +/* + * # Licensed to the LF AI & Data foundation under one + * # or more contributor license agreements. See the NOTICE file + * # distributed with this work for additional information + * # regarding copyright ownership. The ASF licenses this file + * # to you under the Apache License, Version 2.0 (the + * # "License"); you may not use this file except in compliance + * # with the License. You may obtain a copy of the License at + * # + * # http://www.apache.org/licenses/LICENSE-2.0 + * # + * # Unless required by applicable law or agreed to in writing, software + * # distributed under the License is distributed on an "AS IS" BASIS, + * # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * # See the License for the specific language governing permissions and + * # limitations under the License. + */ + +package function + +// common params +const ( + modelNameParamKey string = "model_name" + dimParamKey string = "dim" + embeddingUrlParamKey string = "url" + apiKeyParamKey string = "api_key" +) + +// ali text embedding +const ( + TextEmbeddingV1 string = "text-embedding-v1" + TextEmbeddingV2 string = "text-embedding-v2" + TextEmbeddingV3 string = "text-embedding-v1" +) + +// openai/azure text embedding + +const ( + TextEmbeddingAda002 string = "text-embedding-ada-002" + TextEmbedding3Small string = "text-embedding-3-small" + TextEmbedding3Large string = "text-embedding-3-large" +) + +const ( + userParamKey string = "user" +) + +// bedrock emebdding + +const ( + BedRockTitanTextEmbeddingsV2 string = "amazon.titan-embed-text-v2:0" + awsAccessKeyIdParamKey string = "aws_access_key_id" + awsSecretAccessKeyParamKey string = "aws_secret_access_key" + regionParamKey string = "regin" + normalizeParamKey string = "normalize" +) diff --git a/internal/util/function/function_executor_test.go b/internal/util/function/function_executor_test.go index 66b546fb38c73..5a4cd3aae0fcc 100644 --- a/internal/util/function/function_executor_test.go +++ b/internal/util/function/function_executor_test.go @@ -31,7 +31,7 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/msgpb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" - "github.com/milvus-io/milvus/internal/models" + "github.com/milvus-io/milvus/internal/models/openai" "github.com/milvus-io/milvus/pkg/mq/msgstream" ) @@ -69,10 +69,10 @@ func (s *FunctionExecutorSuite) creataSchema(url string) *schemapb.CollectionSch OutputFieldIds: []int64{102}, Params: []*commonpb.KeyValuePair{ {Key: Provider, Value: OpenAIProvider}, - {Key: ModelNameParamKey, Value: "text-embedding-ada-002"}, - {Key: OpenaiApiKeyParamKey, Value: "mock"}, - {Key: OpenaiEmbeddingUrlParamKey, Value: url}, - {Key: DimParamKey, Value: "4"}, + {Key: modelNameParamKey, Value: "text-embedding-ada-002"}, + {Key: apiKeyParamKey, Value: "mock"}, + {Key: embeddingUrlParamKey, Value: url}, + {Key: dimParamKey, Value: "4"}, }, }, { @@ -83,10 +83,10 @@ func (s *FunctionExecutorSuite) creataSchema(url string) *schemapb.CollectionSch OutputFieldIds: []int64{103}, Params: []*commonpb.KeyValuePair{ {Key: Provider, Value: OpenAIProvider}, - {Key: ModelNameParamKey, Value: "text-embedding-ada-002"}, - {Key: OpenaiApiKeyParamKey, Value: "mock"}, - {Key: OpenaiEmbeddingUrlParamKey, Value: url}, - {Key: DimParamKey, Value: "8"}, + {Key: modelNameParamKey, Value: "text-embedding-ada-002"}, + {Key: apiKeyParamKey, Value: "mock"}, + {Key: embeddingUrlParamKey, Value: url}, + {Key: dimParamKey, Value: "8"}, }, }, }, @@ -136,7 +136,7 @@ func (s *FunctionExecutorSuite) createEmbedding(texts []string, dim int) [][]flo } func (s *FunctionExecutorSuite) TestExecutor() { - ts := CreateEmbeddingServer() + ts := CreateOpenAIEmbeddingServer() defer ts.Close() schema := s.creataSchema(ts.URL) exec, err := NewFunctionExecutor(schema) @@ -148,23 +148,23 @@ func (s *FunctionExecutorSuite) TestExecutor() { func (s *FunctionExecutorSuite) TestErrorEmbedding() { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - var req models.EmbeddingRequest + var req openai.EmbeddingRequest body, _ := io.ReadAll(r.Body) defer r.Body.Close() json.Unmarshal(body, &req) - var res models.EmbeddingResponse + var res 
openai.EmbeddingResponse res.Object = "list" res.Model = "text-embedding-3-small" for i := 0; i < len(req.Input); i++ { - res.Data = append(res.Data, models.EmbeddingData{ + res.Data = append(res.Data, openai.EmbeddingData{ Object: "embedding", Embedding: []float32{}, Index: i, }) } - res.Usage = models.Usage{ + res.Usage = openai.Usage{ PromptTokens: 1, TotalTokens: 100, } diff --git a/internal/util/function/mock_embedding_service.go b/internal/util/function/mock_embedding_service.go index 9342ba59feaad..f48315c7eff41 100644 --- a/internal/util/function/mock_embedding_service.go +++ b/internal/util/function/mock_embedding_service.go @@ -19,12 +19,15 @@ package function import ( + "context" "encoding/json" "io" "net/http" "net/http/httptest" - "github.com/milvus-io/milvus/internal/models" + "github.com/aws/aws-sdk-go-v2/service/bedrockruntime" + "github.com/milvus-io/milvus/internal/models/ali" + "github.com/milvus-io/milvus/internal/models/openai" ) func mockEmbedding(texts []string, dim int) [][]float32 { @@ -40,25 +43,25 @@ func mockEmbedding(texts []string, dim int) [][]float32 { return embeddings } -func CreateEmbeddingServer() *httptest.Server { +func CreateOpenAIEmbeddingServer() *httptest.Server { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - var req models.EmbeddingRequest + var req openai.EmbeddingRequest body, _ := io.ReadAll(r.Body) defer r.Body.Close() json.Unmarshal(body, &req) embs := mockEmbedding(req.Input, req.Dimensions) - var res models.EmbeddingResponse + var res openai.EmbeddingResponse res.Object = "list" res.Model = "text-embedding-3-small" for i := 0; i < len(req.Input); i++ { - res.Data = append(res.Data, models.EmbeddingData{ + res.Data = append(res.Data, openai.EmbeddingData{ Object: "embedding", Embedding: embs[i], Index: i, }) } - res.Usage = models.Usage{ + res.Usage = openai.Usage{ PromptTokens: 1, TotalTokens: 100, } @@ -69,3 +72,45 @@ func CreateEmbeddingServer() *httptest.Server { })) return ts } + +func CreateAliEmbeddingServer() *httptest.Server { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var req ali.EmbeddingRequest + body, _ := io.ReadAll(r.Body) + defer r.Body.Close() + json.Unmarshal(body, &req) + embs := mockEmbedding(req.Input.Texts, req.Parameters.Dimension) + var res ali.EmbeddingResponse + for i := 0; i < len(req.Input.Texts); i++ { + res.Output.Embeddings = append(res.Output.Embeddings, ali.Embeddings{ + Embedding: embs[i], + TextIndex: i, + }) + } + + res.Usage = ali.Usage{ + TotalTokens: 100, + } + w.WriteHeader(http.StatusOK) + data, _ := json.Marshal(res) + w.Write(data) + + })) + return ts +} + +type MockBedrockClient struct { + dim int +} + +func (c *MockBedrockClient) InvokeModel(ctx context.Context, params *bedrockruntime.InvokeModelInput, optFns ...func(*bedrockruntime.Options)) (*bedrockruntime.InvokeModelOutput, error) { + var req BedRockRequest + json.Unmarshal(params.Body, &req) + embs := mockEmbedding([]string{req.InputText}, c.dim) + + var resp BedRockResponse + resp.Embedding = embs[0] + resp.InputTextTokenCount = 2 + body, _ := json.Marshal(resp) + return &bedrockruntime.InvokeModelOutput{Body: body}, nil +} diff --git a/internal/util/function/openai_embedding_function.go b/internal/util/function/openai_embedding_function.go deleted file mode 100644 index 438fade0756f0..0000000000000 --- a/internal/util/function/openai_embedding_function.go +++ /dev/null @@ -1,263 +0,0 @@ -/* - * # Licensed to the LF AI & Data foundation under one - 
* # or more contributor license agreements. See the NOTICE file - * # distributed with this work for additional information - * # regarding copyright ownership. The ASF licenses this file - * # to you under the Apache License, Version 2.0 (the - * # "License"); you may not use this file except in compliance - * # with the License. You may obtain a copy of the License at - * # - * # http://www.apache.org/licenses/LICENSE-2.0 - * # - * # Unless required by applicable law or agreed to in writing, software - * # distributed under the License is distributed on an "AS IS" BASIS, - * # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * # See the License for the specific language governing permissions and - * # limitations under the License. - */ - -package function - -import ( - "fmt" - "os" - "strconv" - "strings" - - "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" - "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" - "github.com/milvus-io/milvus/internal/models" - "github.com/milvus-io/milvus/internal/storage" - "github.com/milvus-io/milvus/pkg/util/funcutil" - "github.com/milvus-io/milvus/pkg/util/typeutil" -) - -const ( - TextEmbeddingAda002 string = "text-embedding-ada-002" - TextEmbedding3Small string = "text-embedding-3-small" - TextEmbedding3Large string = "text-embedding-3-large" -) - -const ( - maxBatch = 128 - timeoutSec = 30 -) - -const ( - ModelNameParamKey string = "model_name" - DimParamKey string = "dim" - UserParamKey string = "user" - OpenaiEmbeddingUrlParamKey string = "embedding_url" - OpenaiApiKeyParamKey string = "api_key" -) - -type OpenAIEmbeddingFunction struct { - FunctionBase - fieldDim int64 - - client *models.OpenAIEmbeddingClient - modelName string - embedDimParam int64 - user string -} - -func createOpenAIEmbeddingClient(apiKey string, url string) (*models.OpenAIEmbeddingClient, error) { - if apiKey == "" { - apiKey = os.Getenv("OPENAI_API_KEY") - } - if apiKey == "" { - return nil, fmt.Errorf("The apiKey configuration was not found in the environment variables") - } - - if url == "" { - url = os.Getenv("OPENAI_EMBEDDING_URL") - } - if url == "" { - url = "https://api.openai.com/v1/embeddings" - } - c := models.NewOpenAIEmbeddingClient(apiKey, url) - return &c, nil -} - -func NewOpenAIEmbeddingFunction(coll *schemapb.CollectionSchema, schema *schemapb.FunctionSchema) (*OpenAIEmbeddingFunction, error) { - if len(schema.GetOutputFieldIds()) != 1 { - return nil, fmt.Errorf("OpenAIEmbedding function should only have one output field, but now is %d", len(schema.GetOutputFieldIds())) - } - - base, err := NewBase(coll, schema) - if err != nil { - return nil, err - } - - if base.outputFields[0].DataType != schemapb.DataType_FloatVector { - return nil, fmt.Errorf("Output field not match, openai embedding needs [%s], got [%s]", - schemapb.DataType_name[int32(schemapb.DataType_FloatVector)], - schemapb.DataType_name[int32(base.outputFields[0].DataType)]) - } - - fieldDim, err := typeutil.GetDim(base.outputFields[0]) - if err != nil { - return nil, err - } - var apiKey, url, modelName, user string - var dim int64 - - for _, param := range schema.Params { - switch strings.ToLower(param.Key) { - case ModelNameParamKey: - modelName = param.Value - case DimParamKey: - dim, err = strconv.ParseInt(param.Value, 10, 64) - if err != nil { - return nil, fmt.Errorf("dim [%s] is not int", param.Value) - } - - if dim != 0 && dim != fieldDim { - return nil, fmt.Errorf("Field %s's dim is [%d], but embeding's dim is [%d]", schema.Name, fieldDim, dim) - } - 
case UserParamKey: - user = param.Value - case OpenaiApiKeyParamKey: - apiKey = param.Value - case OpenaiEmbeddingUrlParamKey: - url = param.Value - default: - } - } - - c, err := createOpenAIEmbeddingClient(apiKey, url) - if err != nil { - return nil, err - } - - runner := OpenAIEmbeddingFunction{ - FunctionBase: *base, - client: c, - fieldDim: fieldDim, - modelName: modelName, - user: user, - embedDimParam: dim, - } - - if runner.modelName != TextEmbeddingAda002 && runner.modelName != TextEmbedding3Small && runner.modelName != TextEmbedding3Large { - return nil, fmt.Errorf("Unsupported model: %s, only support [%s, %s, %s]", - runner.modelName, TextEmbeddingAda002, TextEmbedding3Small, TextEmbedding3Large) - } - return &runner, nil -} - -func (runner *OpenAIEmbeddingFunction) MaxBatch() int { - return 5 * maxBatch -} - -func (runner *OpenAIEmbeddingFunction) callEmbedding(texts []string, batchLimit bool) ([][]float32, error) { - numRows := len(texts) - if batchLimit && numRows > runner.MaxBatch() { - return nil, fmt.Errorf("OpenAI embedding supports up to [%d] pieces of data at a time, got [%d]", runner.MaxBatch(), numRows) - } - - data := make([][]float32, 0, numRows) - for i := 0; i < numRows; i += maxBatch { - end := i + maxBatch - if end > numRows { - end = numRows - } - resp, err := runner.client.Embedding(runner.modelName, texts[i:end], int(runner.embedDimParam), runner.user, timeoutSec) - if err != nil { - return nil, err - } - if end-i != len(resp.Data) { - return nil, fmt.Errorf("The texts number is [%d], but got embedding number [%d]", end-i, len(resp.Data)) - } - for _, item := range resp.Data { - if len(item.Embedding) != int(runner.fieldDim) { - return nil, fmt.Errorf("The required embedding dim for field [%s] is [%d], but the embedding obtained from the model is [%d]", - runner.outputFields[0].Name, runner.fieldDim, len(item.Embedding)) - } - data = append(data, item.Embedding) - } - } - return data, nil -} - -func (runner *OpenAIEmbeddingFunction) ProcessInsert(inputs []*schemapb.FieldData) ([]*schemapb.FieldData, error) { - if len(inputs) != 1 { - return nil, fmt.Errorf("OpenAIEmbedding function only receives one input, bug got [%d]", len(inputs)) - } - - if inputs[0].Type != schemapb.DataType_VarChar { - return nil, fmt.Errorf("OpenAIEmbedding only supports varchar field, the input is not varchar") - } - - texts := inputs[0].GetScalars().GetStringData().GetData() - if texts == nil { - return nil, fmt.Errorf("Input texts is empty") - } - - embds, err := runner.callEmbedding(texts, true) - if err != nil { - return nil, err - } - data := make([]float32, 0, len(texts)*int(runner.fieldDim)) - for _, emb := range embds { - data = append(data, emb...) 
- } - - var outputField schemapb.FieldData - outputField.FieldId = runner.outputFields[0].FieldID - outputField.FieldName = runner.outputFields[0].Name - outputField.Type = runner.outputFields[0].DataType - outputField.IsDynamic = runner.outputFields[0].IsDynamic - outputField.Field = &schemapb.FieldData_Vectors{ - Vectors: &schemapb.VectorField{ - Data: &schemapb.VectorField_FloatVector{ - FloatVector: &schemapb.FloatArray{ - Data: data, - }, - }, - Dim: runner.fieldDim, - }, - } - return []*schemapb.FieldData{&outputField}, nil -} - -func (runner *OpenAIEmbeddingFunction) ProcessSearch(placeholderGroup *commonpb.PlaceholderGroup) (*commonpb.PlaceholderGroup, error) { - texts := funcutil.GetVarCharFromPlaceholder(placeholderGroup.Placeholders[0]) // Already checked externally - embds, err := runner.callEmbedding(texts, true) - if err != nil { - return nil, err - } - return funcutil.Float32VectorsToPlaceholderGroup(embds), nil -} - -func (runner *OpenAIEmbeddingFunction) ProcessBulkInsert(inputs []storage.FieldData) (map[storage.FieldID]storage.FieldData, error) { - if len(inputs) != 1 { - return nil, fmt.Errorf("OpenAIEmbedding function only receives one input, bug got [%d]", len(inputs)) - } - - if inputs[0].GetDataType() != schemapb.DataType_VarChar { - return nil, fmt.Errorf("OpenAIEmbedding only supports varchar field, the input is not varchar") - } - - texts, ok := inputs[0].GetDataRows().([]string) - if !ok { - return nil, fmt.Errorf("Input texts is empty") - } - - embds, err := runner.callEmbedding(texts, false) - if err != nil { - return nil, err - } - data := make([]float32, 0, len(texts)*int(runner.fieldDim)) - for _, emb := range embds { - data = append(data, emb...) - } - - field := &storage.FloatVectorFieldData{ - Data: data, - Dim: int(runner.fieldDim), - } - return map[storage.FieldID]storage.FieldData{ - runner.outputFields[0].FieldID: field, - }, nil -} diff --git a/internal/util/function/openai_embedding_provider.go b/internal/util/function/openai_embedding_provider.go new file mode 100644 index 0000000000000..f9fa78f64c8b6 --- /dev/null +++ b/internal/util/function/openai_embedding_provider.go @@ -0,0 +1,187 @@ +/* + * # Licensed to the LF AI & Data foundation under one + * # or more contributor license agreements. See the NOTICE file + * # distributed with this work for additional information + * # regarding copyright ownership. The ASF licenses this file + * # to you under the Apache License, Version 2.0 (the + * # "License"); you may not use this file except in compliance + * # with the License. You may obtain a copy of the License at + * # + * # http://www.apache.org/licenses/LICENSE-2.0 + * # + * # Unless required by applicable law or agreed to in writing, software + * # distributed under the License is distributed on an "AS IS" BASIS, + * # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * # See the License for the specific language governing permissions and + * # limitations under the License. 
+ */ + +package function + +import ( + "fmt" + "os" + "strconv" + "strings" + "time" + + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/internal/models/openai" + "github.com/milvus-io/milvus/pkg/util/typeutil" +) + +type OpenAIEmbeddingProvider struct { + fieldDim int64 + + client openai.OpenAIEmbeddingInterface + modelName string + embedDimParam int64 + user string + + maxBatch int + timeoutSec int +} + +func createOpenAIEmbeddingClient(apiKey string, url string) (*openai.OpenAIEmbeddingClient, error) { + if apiKey == "" { + apiKey = os.Getenv("OPENAI_API_KEY") + } + if apiKey == "" { + return nil, fmt.Errorf("Missing credentials. Please pass `api_key`, or configure the OPENAI_API_KEY environment variable in the Milvus service.") + } + + if url == "" { + url = "https://api.openai.com/v1/embeddings" + } + if url == "" { + return nil, fmt.Errorf("Must provide `url` arguments or configure the OPENAI_ENDPOINT environment variable in the Milvus service") + } + + c := openai.NewOpenAIEmbeddingClient(apiKey, url) + return c, nil +} + +func createAzureOpenAIEmbeddingClient(apiKey string, url string) (*openai.AzureOpenAIEmbeddingClient, error) { + if apiKey == "" { + apiKey = os.Getenv("AZURE_OPENAI_API_KEY") + } + if apiKey == "" { + return nil, fmt.Errorf("Missing credentials. Please pass `api_key`, or configure the AZURE_OPENAI_API_KEY environment variable in the Milvus service") + } + + if url == "" { + url = os.Getenv("AZURE_OPENAI_ENDPOINT") + } + if url == "" { + return nil, fmt.Errorf("Must provide `url` arguments or configure the AZURE_OPENAI_ENDPOINT environment variable in the Milvus service") + } + c := openai.NewAzureOpenAIEmbeddingClient(apiKey, url) + return c, nil +} + +func newOpenAIEmbeddingProvider(fieldSchema *schemapb.FieldSchema, functionSchema *schemapb.FunctionSchema, isAzure bool) (*OpenAIEmbeddingProvider, error) { + fieldDim, err := typeutil.GetDim(fieldSchema) + if err != nil { + return nil, err + } + var apiKey, url, modelName, user string + var dim int64 + + for _, param := range functionSchema.Params { + switch strings.ToLower(param.Key) { + case modelNameParamKey: + modelName = param.Value + case dimParamKey: + dim, err = strconv.ParseInt(param.Value, 10, 64) + if err != nil { + return nil, fmt.Errorf("dim [%s] is not int", param.Value) + } + + if dim != 0 && dim != fieldDim { + return nil, fmt.Errorf("Field %s's dim is [%d], but embeding's dim is [%d]", fieldSchema.Name, fieldDim, dim) + } + case userParamKey: + user = param.Value + case apiKeyParamKey: + apiKey = param.Value + case embeddingUrlParamKey: + url = param.Value + default: + } + } + + var c openai.OpenAIEmbeddingInterface + if !isAzure { + if modelName != TextEmbeddingAda002 && modelName != TextEmbedding3Small && modelName != TextEmbedding3Large { + return nil, fmt.Errorf("Unsupported model: %s, only support [%s, %s, %s]", + modelName, TextEmbeddingAda002, TextEmbedding3Small, TextEmbedding3Large) + } + + c, err = createOpenAIEmbeddingClient(apiKey, url) + if err != nil { + return nil, err + } + } else { + c, err = createAzureOpenAIEmbeddingClient(apiKey, url) + if err != nil { + return nil, err + } + } + + provider := OpenAIEmbeddingProvider{ + client: c, + fieldDim: fieldDim, + modelName: modelName, + user: user, + embedDimParam: dim, + maxBatch: 128, + timeoutSec: 30, + } + return &provider, nil +} + +func NewOpenAIEmbeddingProvider(fieldSchema *schemapb.FieldSchema, functionSchema *schemapb.FunctionSchema) (*OpenAIEmbeddingProvider, error) { + return 
newOpenAIEmbeddingProvider(fieldSchema, functionSchema, false) +} + +func NewAzureOpenAIEmbeddingProvider(fieldSchema *schemapb.FieldSchema, functionSchema *schemapb.FunctionSchema) (*OpenAIEmbeddingProvider, error) { + return newOpenAIEmbeddingProvider(fieldSchema, functionSchema, true) +} + +func (provider *OpenAIEmbeddingProvider) MaxBatch() int { + return 5 * provider.maxBatch +} + +func (provider *OpenAIEmbeddingProvider) FieldDim() int64 { + return provider.fieldDim +} + +func (provider *OpenAIEmbeddingProvider) CallEmbedding(texts []string, batchLimit bool) ([][]float32, error) { + numRows := len(texts) + if batchLimit && numRows > provider.MaxBatch() { + return nil, fmt.Errorf("OpenAI embedding supports up to [%d] pieces of data at a time, got [%d]", provider.MaxBatch(), numRows) + } + + data := make([][]float32, 0, numRows) + for i := 0; i < numRows; i += provider.maxBatch { + end := i + provider.maxBatch + if end > numRows { + end = numRows + } + resp, err := provider.client.Embedding(provider.modelName, texts[i:end], int(provider.embedDimParam), provider.user, time.Duration(provider.timeoutSec)) + if err != nil { + return nil, err + } + if end-i != len(resp.Data) { + return nil, fmt.Errorf("Get embedding failed. The number of texts and embeddings does not match text:[%d], embedding:[%d]", end-i, len(resp.Data)) + } + for _, item := range resp.Data { + if len(item.Embedding) != int(provider.fieldDim) { + return nil, fmt.Errorf("The required embedding dim is [%d], but the embedding obtained from the model is [%d]", + provider.fieldDim, len(item.Embedding)) + } + data = append(data, item.Embedding) + } + } + return data, nil +} diff --git a/internal/util/function/openai_text_embedding_provider_test.go b/internal/util/function/openai_text_embedding_provider_test.go new file mode 100644 index 0000000000000..7681161f9c3e5 --- /dev/null +++ b/internal/util/function/openai_text_embedding_provider_test.go @@ -0,0 +1,177 @@ +/* + * # Licensed to the LF AI & Data foundation under one + * # or more contributor license agreements. See the NOTICE file + * # distributed with this work for additional information + * # regarding copyright ownership. The ASF licenses this file + * # to you under the Apache License, Version 2.0 (the + * # "License"); you may not use this file except in compliance + * # with the License. You may obtain a copy of the License at + * # + * # http://www.apache.org/licenses/LICENSE-2.0 + * # + * # Unless required by applicable law or agreed to in writing, software + * # distributed under the License is distributed on an "AS IS" BASIS, + * # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * # See the License for the specific language governing permissions and + * # limitations under the License. 
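// Editorial sketch (not part of the patch): newOpenAIEmbeddingProvider above accepts
// an optional "dim" parameter and rejects it when it disagrees with the vector field's
// declared dimension. That check reduces to the small helper below; parseDim is an
// illustrative name, not an API introduced by this patch.

package sketch

import (
	"fmt"
	"strconv"
)

// parseDim parses the user-supplied dim parameter and verifies it against the schema's
// field dimension. An empty value yields 0, meaning the request omits the dimension and
// the service falls back to the model default.
func parseDim(raw string, fieldDim int64) (int64, error) {
	if raw == "" {
		return 0, nil
	}
	dim, err := strconv.ParseInt(raw, 10, 64)
	if err != nil {
		return 0, fmt.Errorf("dim [%s] is not an integer", raw)
	}
	if dim != 0 && dim != fieldDim {
		return 0, fmt.Errorf("field dim is [%d], but the requested embedding dim is [%d]", fieldDim, dim)
	}
	return dim, nil
}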
+ */ + +package function + +import ( + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/suite" + + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + + "github.com/milvus-io/milvus/internal/models/openai" +) + +func TestOpenAITextEmbeddingProvider(t *testing.T) { + suite.Run(t, new(OpenAITextEmbeddingProviderSuite)) +} + +type OpenAITextEmbeddingProviderSuite struct { + suite.Suite + schema *schemapb.CollectionSchema + providers []string +} + +func (s *OpenAITextEmbeddingProviderSuite) SetupTest() { + s.schema = &schemapb.CollectionSchema{ + Name: "test", + Fields: []*schemapb.FieldSchema{ + {FieldID: 100, Name: "int64", DataType: schemapb.DataType_Int64}, + {FieldID: 101, Name: "text", DataType: schemapb.DataType_VarChar}, + {FieldID: 102, Name: "vector", DataType: schemapb.DataType_FloatVector, + TypeParams: []*commonpb.KeyValuePair{ + {Key: "dim", Value: "4"}, + }}, + }, + } + s.providers = []string{OpenAIProvider, AzureOpenAIProvider} +} + +func createOpenAIProvider(url string, schema *schemapb.FieldSchema, providerName string) (TextEmbeddingProvider, error) { + functionSchema := &schemapb.FunctionSchema{ + Name: "test", + Type: schemapb.FunctionType_Unknown, + InputFieldIds: []int64{101}, + OutputFieldIds: []int64{102}, + Params: []*commonpb.KeyValuePair{ + {Key: modelNameParamKey, Value: "text-embedding-ada-002"}, + {Key: apiKeyParamKey, Value: "mock"}, + {Key: dimParamKey, Value: "4"}, + {Key: embeddingUrlParamKey, Value: url}, + }, + } + switch providerName { + case OpenAIProvider: + return NewOpenAIEmbeddingProvider(schema, functionSchema) + case AzureOpenAIProvider: + return NewAzureOpenAIEmbeddingProvider(schema, functionSchema) + default: + return nil, fmt.Errorf("Unknow provider") + } +} + +func (s *OpenAITextEmbeddingProviderSuite) TestEmbedding() { + ts := CreateOpenAIEmbeddingServer() + + defer ts.Close() + for _, provderName := range s.providers { + provder, err := createOpenAIProvider(ts.URL, s.schema.Fields[2], provderName) + s.NoError(err) + { + data := []string{"sentence"} + ret, err2 := provder.CallEmbedding(data, false) + s.NoError(err2) + s.Equal(1, len(ret)) + s.Equal(4, len(ret[0])) + s.Equal([]float32{0.0, 0.1, 0.2, 0.3}, ret[0]) + } + { + data := []string{"sentence 1", "sentence 2", "sentence 3"} + ret, _ := provder.CallEmbedding(data, false) + s.Equal([][]float32{{0.0, 0.1, 0.2, 0.3}, {1.0, 1.1, 1.2, 1.3}, {2.0, 2.1, 2.2, 2.3}}, ret) + } + + } +} + +func (s *OpenAITextEmbeddingProviderSuite) TestEmbeddingDimNotMatch() { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var res openai.EmbeddingResponse + res.Object = "list" + res.Model = "text-embedding-3-small" + res.Data = append(res.Data, openai.EmbeddingData{ + Object: "embedding", + Embedding: []float32{1.0, 1.0, 1.0, 1.0}, + Index: 0, + }) + + res.Data = append(res.Data, openai.EmbeddingData{ + Object: "embedding", + Embedding: []float32{1.0, 1.0}, + Index: 1, + }) + res.Usage = openai.Usage{ + PromptTokens: 1, + TotalTokens: 100, + } + w.WriteHeader(http.StatusOK) + data, _ := json.Marshal(res) + w.Write(data) + })) + + defer ts.Close() + for _, provderName := range s.providers { + provder, err := createOpenAIProvider(ts.URL, s.schema.Fields[2], provderName) + s.NoError(err) + + // embedding dim not match + data := []string{"sentence", "sentence"} + _, err2 := provder.CallEmbedding(data, false) + s.Error(err2) + + } +} + +func (s 
*OpenAITextEmbeddingProviderSuite) TestEmbeddingNubmerNotMatch() { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var res openai.EmbeddingResponse + res.Object = "list" + res.Model = "text-embedding-3-small" + res.Data = append(res.Data, openai.EmbeddingData{ + Object: "embedding", + Embedding: []float32{1.0, 1.0, 1.0, 1.0}, + Index: 0, + }) + res.Usage = openai.Usage{ + PromptTokens: 1, + TotalTokens: 100, + } + w.WriteHeader(http.StatusOK) + data, _ := json.Marshal(res) + w.Write(data) + })) + + defer ts.Close() + for _, provderName := range s.providers { + provder, err := createOpenAIProvider(ts.URL, s.schema.Fields[2], provderName) + + s.NoError(err) + + // embedding dim not match + data := []string{"sentence", "sentence2"} + _, err2 := provder.CallEmbedding(data, false) + s.Error(err2) + + } +} diff --git a/internal/util/function/text_embedding_function.go b/internal/util/function/text_embedding_function.go index 6813ff3c8ec47..a0ea82dc96282 100644 --- a/internal/util/function/text_embedding_function.go +++ b/internal/util/function/text_embedding_function.go @@ -22,7 +22,11 @@ import ( "fmt" "strings" + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/internal/storage" + "github.com/milvus-io/milvus/pkg/util/funcutil" + // "github.com/milvus-io/milvus/pkg/util/typeutil" ) const ( @@ -30,9 +34,18 @@ const ( ) const ( - OpenAIProvider string = "openai" + OpenAIProvider string = "openai" + AzureOpenAIProvider string = "azure_openai" + AliDashScopeProvider string = "dashscope" + BedrockProvider string = "bedrock" ) +type TextEmbeddingProvider interface { + MaxBatch() int + CallEmbedding(texts []string, batchLimit bool) ([][]float32, error) + FieldDim() int64 +} + func getProvider(schema *schemapb.FunctionSchema) (string, error) { for _, param := range schema.Params { switch strings.ToLower(param.Key) { @@ -44,13 +57,157 @@ func getProvider(schema *schemapb.FunctionSchema) (string, error) { return "", fmt.Errorf("The provider parameter was not found in the function's parameters") } -func NewTextEmbeddingFunction(coll *schemapb.CollectionSchema, schema *schemapb.FunctionSchema) (*OpenAIEmbeddingFunction, error) { - provider, err := getProvider(schema) +type TextEmebddingFunction struct { + FunctionBase + + embProvider TextEmbeddingProvider +} + +func NewTextEmbeddingFunction(coll *schemapb.CollectionSchema, functionSchema *schemapb.FunctionSchema) (*TextEmebddingFunction, error) { + if len(functionSchema.GetOutputFieldIds()) != 1 { + return nil, fmt.Errorf("Text function should only have one output field, but now is %d", len(functionSchema.GetOutputFieldIds())) + } + + base, err := NewBase(coll, functionSchema) if err != nil { return nil, err } - if provider == OpenAIProvider { - return NewOpenAIEmbeddingFunction(coll, schema) + + if base.outputFields[0].DataType != schemapb.DataType_FloatVector { + return nil, fmt.Errorf("Output field not match, openai embedding needs [%s], got [%s]", + schemapb.DataType_name[int32(schemapb.DataType_FloatVector)], + schemapb.DataType_name[int32(base.outputFields[0].DataType)]) + } + + provider, err := getProvider(functionSchema) + if err != nil { + return nil, err + } + switch provider { + case OpenAIProvider: + embP, err := NewOpenAIEmbeddingProvider(base.outputFields[0], functionSchema) + if err != nil { + return nil, err + } + return &TextEmebddingFunction{ + FunctionBase: *base, + embProvider: embP, + }, nil + case 
AzureOpenAIProvider: + embP, err := NewAzureOpenAIEmbeddingProvider(base.outputFields[0], functionSchema) + if err != nil { + return nil, err + } + return &TextEmebddingFunction{ + FunctionBase: *base, + embProvider: embP, + }, nil + case BedrockProvider: + embP, err := NewBedrockEmbeddingProvider(base.outputFields[0], functionSchema, nil) + if err != nil { + return nil, err + } + return &TextEmebddingFunction{ + FunctionBase: *base, + embProvider: embP, + }, nil + case AliDashScopeProvider: + embP, err := NewAliDashScopeEmbeddingProvider(base.outputFields[0], functionSchema) + if err != nil { + return nil, err + } + return &TextEmebddingFunction{ + FunctionBase: *base, + embProvider: embP, + }, nil + default: + return nil, fmt.Errorf("Provider: [%s] not exist, only supports [%s, %s, %s]", provider, OpenAIProvider, AzureOpenAIProvider, AliDashScopeProvider) + } + +} + +func (runner *TextEmebddingFunction) MaxBatch() int { + return runner.embProvider.MaxBatch() +} + +func (runner *TextEmebddingFunction) ProcessInsert(inputs []*schemapb.FieldData) ([]*schemapb.FieldData, error) { + if len(inputs) != 1 { + return nil, fmt.Errorf("Text embedding function only receives one input, bug got [%d]", len(inputs)) + } + + if inputs[0].Type != schemapb.DataType_VarChar { + return nil, fmt.Errorf("Text embedding only supports varchar field, the input is not varchar") + } + + texts := inputs[0].GetScalars().GetStringData().GetData() + if texts == nil { + return nil, fmt.Errorf("Input texts is empty") + } + + embds, err := runner.embProvider.CallEmbedding(texts, true) + if err != nil { + return nil, err + } + data := make([]float32, 0, len(texts)*int(runner.embProvider.FieldDim())) + for _, emb := range embds { + data = append(data, emb...) + } + + var outputField schemapb.FieldData + outputField.FieldId = runner.outputFields[0].FieldID + outputField.FieldName = runner.outputFields[0].Name + outputField.Type = runner.outputFields[0].DataType + outputField.IsDynamic = runner.outputFields[0].IsDynamic + outputField.Field = &schemapb.FieldData_Vectors{ + Vectors: &schemapb.VectorField{ + Data: &schemapb.VectorField_FloatVector{ + FloatVector: &schemapb.FloatArray{ + Data: data, + }, + }, + Dim: runner.embProvider.FieldDim(), + }, + } + return []*schemapb.FieldData{&outputField}, nil +} + +func (runner *TextEmebddingFunction) ProcessSearch(placeholderGroup *commonpb.PlaceholderGroup) (*commonpb.PlaceholderGroup, error) { + texts := funcutil.GetVarCharFromPlaceholder(placeholderGroup.Placeholders[0]) // Already checked externally + embds, err := runner.embProvider.CallEmbedding(texts, true) + if err != nil { + return nil, err + } + return funcutil.Float32VectorsToPlaceholderGroup(embds), nil +} + +func (runner *TextEmebddingFunction) ProcessBulkInsert(inputs []storage.FieldData) (map[storage.FieldID]storage.FieldData, error) { + if len(inputs) != 1 { + return nil, fmt.Errorf("OpenAIEmbedding function only receives one input, bug got [%d]", len(inputs)) + } + + if inputs[0].GetDataType() != schemapb.DataType_VarChar { + return nil, fmt.Errorf("OpenAIEmbedding only supports varchar field, the input is not varchar") + } + + texts, ok := inputs[0].GetDataRows().([]string) + if !ok { + return nil, fmt.Errorf("Input texts is empty") + } + + embds, err := runner.embProvider.CallEmbedding(texts, false) + if err != nil { + return nil, err + } + data := make([]float32, 0, len(texts)*int(runner.embProvider.FieldDim())) + for _, emb := range embds { + data = append(data, emb...) 
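// Editorial sketch (not part of the patch): ProcessInsert and ProcessBulkInsert above
// both flatten the per-row [][]float32 returned by the provider into the single
// contiguous []float32 that FloatVector field data carries, with the row count
// recovered later from len(data)/dim. A minimal illustration of that layout, using
// the hypothetical name flattenRows:

package sketch

// flattenRows concatenates fixed-width row vectors into one contiguous slice, the
// layout expected by schemapb.FloatArray and storage.FloatVectorFieldData.
func flattenRows(rows [][]float32, dim int) []float32 {
	flat := make([]float32, 0, len(rows)*dim)
	for _, row := range rows {
		flat = append(flat, row...)
	}
	return flat
}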
+ } + + field := &storage.FloatVectorFieldData{ + Data: data, + Dim: int(runner.embProvider.FieldDim()), } - return nil, fmt.Errorf("Provider: [%s] not exist, only supports [%s]", provider, OpenAIProvider) + return map[storage.FieldID]storage.FieldData{ + runner.outputFields[0].FieldID: field, + }, nil } diff --git a/internal/util/function/openai_embedding_function_test.go b/internal/util/function/text_embedding_function_test.go similarity index 51% rename from internal/util/function/openai_embedding_function_test.go rename to internal/util/function/text_embedding_function_test.go index 81f8ede4fe7e1..9fb6d946a8b4d 100644 --- a/internal/util/function/openai_embedding_function_test.go +++ b/internal/util/function/text_embedding_function_test.go @@ -19,30 +19,24 @@ package function import ( - "encoding/json" - "io" - "net/http" - "net/http/httptest" "testing" "github.com/stretchr/testify/suite" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" - - "github.com/milvus-io/milvus/internal/models" ) -func TestOpenAIEmbeddingFunction(t *testing.T) { - suite.Run(t, new(OpenAIEmbeddingFunctionSuite)) +func TestTextEmbeddingFunction(t *testing.T) { + suite.Run(t, new(TextEmbeddingFunctionSuite)) } -type OpenAIEmbeddingFunctionSuite struct { +type TextEmbeddingFunctionSuite struct { suite.Suite schema *schemapb.CollectionSchema } -func (s *OpenAIEmbeddingFunctionSuite) SetupTest() { +func (s *TextEmbeddingFunctionSuite) SetupTest() { s.schema = &schemapb.CollectionSchema{ Name: "test", Fields: []*schemapb.FieldSchema{ @@ -76,65 +70,93 @@ func createData(texts []string) []*schemapb.FieldData { return data } -func createEmbedding(texts []string, dim int) [][]float32 { - embeddings := make([][]float32, 0) - for i := 0; i < len(texts); i++ { - f := float32(i) - emb := make([]float32, 0) - for j := 0; j < dim; j++ { - emb = append(emb, f+float32(j)*0.1) +func (s *TextEmbeddingFunctionSuite) TestOpenAIEmbedding() { + ts := CreateOpenAIEmbeddingServer() + defer ts.Close() + { + + runner, err := NewTextEmbeddingFunction(s.schema, &schemapb.FunctionSchema{ + Name: "test", + Type: schemapb.FunctionType_Unknown, + InputFieldIds: []int64{101}, + OutputFieldIds: []int64{102}, + Params: []*commonpb.KeyValuePair{ + {Key: Provider, Value: OpenAIProvider}, + {Key: modelNameParamKey, Value: "text-embedding-ada-002"}, + {Key: dimParamKey, Value: "4"}, + {Key: apiKeyParamKey, Value: "mock"}, + {Key: embeddingUrlParamKey, Value: ts.URL}, + }, + }) + s.NoError(err) + + { + data := createData([]string{"sentence"}) + ret, err2 := runner.ProcessInsert(data) + s.NoError(err2) + s.Equal(1, len(ret)) + s.Equal(int64(4), ret[0].GetVectors().Dim) + s.Equal([]float32{0.0, 0.1, 0.2, 0.3}, ret[0].GetVectors().GetFloatVector().Data) + } + { + data := createData([]string{"sentence 1", "sentence 2", "sentence 3"}) + ret, _ := runner.ProcessInsert(data) + s.Equal([]float32{0.0, 0.1, 0.2, 0.3, 1.0, 1.1, 1.2, 1.3, 2.0, 2.1, 2.2, 2.3}, ret[0].GetVectors().GetFloatVector().Data) + } + } + + { + + runner, err := NewTextEmbeddingFunction(s.schema, &schemapb.FunctionSchema{ + Name: "test", + Type: schemapb.FunctionType_Unknown, + InputFieldIds: []int64{101}, + OutputFieldIds: []int64{102}, + Params: []*commonpb.KeyValuePair{ + {Key: Provider, Value: AzureOpenAIProvider}, + {Key: modelNameParamKey, Value: "text-embedding-ada-002"}, + {Key: dimParamKey, Value: "4"}, + {Key: apiKeyParamKey, Value: "mock"}, + {Key: embeddingUrlParamKey, Value: ts.URL}, + }, + }) + s.NoError(err) + + { + 
data := createData([]string{"sentence"}) + ret, err2 := runner.ProcessInsert(data) + s.NoError(err2) + s.Equal(1, len(ret)) + s.Equal(int64(4), ret[0].GetVectors().Dim) + s.Equal([]float32{0.0, 0.1, 0.2, 0.3}, ret[0].GetVectors().GetFloatVector().Data) + } + { + data := createData([]string{"sentence 1", "sentence 2", "sentence 3"}) + ret, _ := runner.ProcessInsert(data) + s.Equal([]float32{0.0, 0.1, 0.2, 0.3, 1.0, 1.1, 1.2, 1.3, 2.0, 2.1, 2.2, 2.3}, ret[0].GetVectors().GetFloatVector().Data) } - embeddings = append(embeddings, emb) } - return embeddings } -func createRunner(url string, schema *schemapb.CollectionSchema) (*OpenAIEmbeddingFunction, error) { - return NewOpenAIEmbeddingFunction(schema, &schemapb.FunctionSchema{ +func (s *TextEmbeddingFunctionSuite) TestAliEmbedding() { + ts := CreateAliEmbeddingServer() + defer ts.Close() + + runner, err := NewTextEmbeddingFunction(s.schema, &schemapb.FunctionSchema{ Name: "test", Type: schemapb.FunctionType_Unknown, InputFieldIds: []int64{101}, OutputFieldIds: []int64{102}, Params: []*commonpb.KeyValuePair{ - {Key: ModelNameParamKey, Value: "text-embedding-ada-002"}, - {Key: OpenaiApiKeyParamKey, Value: "mock"}, - {Key: OpenaiEmbeddingUrlParamKey, Value: url}, + {Key: Provider, Value: AliDashScopeProvider}, + {Key: modelNameParamKey, Value: TextEmbeddingV3}, + {Key: dimParamKey, Value: "4"}, + {Key: apiKeyParamKey, Value: "mock"}, + {Key: embeddingUrlParamKey, Value: ts.URL}, }, }) -} - -func (s *OpenAIEmbeddingFunctionSuite) TestEmbedding() { - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - var req models.EmbeddingRequest - body, _ := io.ReadAll(r.Body) - defer r.Body.Close() - json.Unmarshal(body, &req) - - var res models.EmbeddingResponse - res.Object = "list" - res.Model = "text-embedding-3-small" - embs := createEmbedding(req.Input, 4) - for i := 0; i < len(req.Input); i++ { - res.Data = append(res.Data, models.EmbeddingData{ - Object: "embedding", - Embedding: embs[i], - Index: i, - }) - } - - res.Usage = models.Usage{ - PromptTokens: 1, - TotalTokens: 100, - } - w.WriteHeader(http.StatusOK) - data, _ := json.Marshal(res) - w.Write(data) - - })) - - defer ts.Close() - runner, err := createRunner(ts.URL, s.schema) s.NoError(err) + { data := createData([]string{"sentence"}) ret, err2 := runner.ProcessInsert(data) @@ -148,74 +170,10 @@ func (s *OpenAIEmbeddingFunctionSuite) TestEmbedding() { ret, _ := runner.ProcessInsert(data) s.Equal([]float32{0.0, 0.1, 0.2, 0.3, 1.0, 1.1, 1.2, 1.3, 2.0, 2.1, 2.2, 2.3}, ret[0].GetVectors().GetFloatVector().Data) } -} - -func (s *OpenAIEmbeddingFunctionSuite) TestEmbeddingDimNotMatch() { - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - var res models.EmbeddingResponse - res.Object = "list" - res.Model = "text-embedding-3-small" - res.Data = append(res.Data, models.EmbeddingData{ - Object: "embedding", - Embedding: []float32{1.0, 1.0, 1.0, 1.0}, - Index: 0, - }) - - res.Data = append(res.Data, models.EmbeddingData{ - Object: "embedding", - Embedding: []float32{1.0, 1.0}, - Index: 1, - }) - res.Usage = models.Usage{ - PromptTokens: 1, - TotalTokens: 100, - } - w.WriteHeader(http.StatusOK) - data, _ := json.Marshal(res) - w.Write(data) - })) - - defer ts.Close() - runner, err := createRunner(ts.URL, s.schema) - s.NoError(err) - - // embedding dim not match - data := createData([]string{"sentence", "sentence"}) - _, err2 := runner.ProcessInsert(data) - s.Error(err2) -} - -func (s *OpenAIEmbeddingFunctionSuite) 
TestEmbeddingNubmerNotMatch() { - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - var res models.EmbeddingResponse - res.Object = "list" - res.Model = "text-embedding-3-small" - res.Data = append(res.Data, models.EmbeddingData{ - Object: "embedding", - Embedding: []float32{1.0, 1.0, 1.0, 1.0}, - Index: 0, - }) - res.Usage = models.Usage{ - PromptTokens: 1, - TotalTokens: 100, - } - w.WriteHeader(http.StatusOK) - data, _ := json.Marshal(res) - w.Write(data) - })) - - defer ts.Close() - runner, err := createRunner(ts.URL, s.schema) - - s.NoError(err) - // embedding dim not match - data := createData([]string{"sentence", "sentence2"}) - _, err2 := runner.ProcessInsert(data) - s.Error(err2) } -func (s *OpenAIEmbeddingFunctionSuite) TestRunnerParamsErr() { +func (s *TextEmbeddingFunctionSuite) TestRunnerParamsErr() { // outputfield datatype mismatch { schema := &schemapb.CollectionSchema{ @@ -230,16 +188,17 @@ func (s *OpenAIEmbeddingFunctionSuite) TestRunnerParamsErr() { }, } - _, err := NewOpenAIEmbeddingFunction(schema, &schemapb.FunctionSchema{ + _, err := NewTextEmbeddingFunction(schema, &schemapb.FunctionSchema{ Name: "test", Type: schemapb.FunctionType_Unknown, InputFieldIds: []int64{101}, OutputFieldIds: []int64{102}, Params: []*commonpb.KeyValuePair{ - {Key: ModelNameParamKey, Value: "text-embedding-ada-002"}, - {Key: DimParamKey, Value: "4"}, - {Key: OpenaiApiKeyParamKey, Value: "mock"}, - {Key: OpenaiEmbeddingUrlParamKey, Value: "mock"}, + {Key: Provider, Value: OpenAIProvider}, + {Key: modelNameParamKey, Value: "text-embedding-ada-002"}, + {Key: dimParamKey, Value: "4"}, + {Key: apiKeyParamKey, Value: "mock"}, + {Key: embeddingUrlParamKey, Value: "mock"}, }, }) s.Error(err) @@ -262,16 +221,17 @@ func (s *OpenAIEmbeddingFunctionSuite) TestRunnerParamsErr() { }}, }, } - _, err := NewOpenAIEmbeddingFunction(schema, &schemapb.FunctionSchema{ + _, err := NewTextEmbeddingFunction(schema, &schemapb.FunctionSchema{ Name: "test", Type: schemapb.FunctionType_Unknown, InputFieldIds: []int64{101}, OutputFieldIds: []int64{102, 103}, Params: []*commonpb.KeyValuePair{ - {Key: ModelNameParamKey, Value: "text-embedding-ada-002"}, - {Key: DimParamKey, Value: "4"}, - {Key: OpenaiApiKeyParamKey, Value: "mock"}, - {Key: OpenaiEmbeddingUrlParamKey, Value: "mock"}, + {Key: Provider, Value: OpenAIProvider}, + {Key: modelNameParamKey, Value: "text-embedding-ada-002"}, + {Key: dimParamKey, Value: "4"}, + {Key: apiKeyParamKey, Value: "mock"}, + {Key: embeddingUrlParamKey, Value: "mock"}, }, }) s.Error(err) @@ -279,16 +239,17 @@ func (s *OpenAIEmbeddingFunctionSuite) TestRunnerParamsErr() { // outputfield miss { - _, err := NewOpenAIEmbeddingFunction(s.schema, &schemapb.FunctionSchema{ + _, err := NewTextEmbeddingFunction(s.schema, &schemapb.FunctionSchema{ Name: "test", Type: schemapb.FunctionType_Unknown, InputFieldIds: []int64{101}, OutputFieldIds: []int64{103}, Params: []*commonpb.KeyValuePair{ - {Key: ModelNameParamKey, Value: "text-embedding-ada-002"}, - {Key: DimParamKey, Value: "4"}, - {Key: OpenaiApiKeyParamKey, Value: "mock"}, - {Key: OpenaiEmbeddingUrlParamKey, Value: "mock"}, + {Key: Provider, Value: OpenAIProvider}, + {Key: modelNameParamKey, Value: "text-embedding-ada-002"}, + {Key: dimParamKey, Value: "4"}, + {Key: apiKeyParamKey, Value: "mock"}, + {Key: embeddingUrlParamKey, Value: "mock"}, }, }) s.Error(err) @@ -296,16 +257,17 @@ func (s *OpenAIEmbeddingFunctionSuite) TestRunnerParamsErr() { // error model name { - _, err := 
NewOpenAIEmbeddingFunction(s.schema, &schemapb.FunctionSchema{ + _, err := NewTextEmbeddingFunction(s.schema, &schemapb.FunctionSchema{ Name: "test", Type: schemapb.FunctionType_Unknown, InputFieldIds: []int64{101}, OutputFieldIds: []int64{102}, Params: []*commonpb.KeyValuePair{ - {Key: ModelNameParamKey, Value: "text-embedding-ada-004"}, - {Key: DimParamKey, Value: "4"}, - {Key: OpenaiApiKeyParamKey, Value: "mock"}, - {Key: OpenaiEmbeddingUrlParamKey, Value: "mock"}, + {Key: Provider, Value: OpenAIProvider}, + {Key: modelNameParamKey, Value: "text-embedding-ada-004"}, + {Key: dimParamKey, Value: "4"}, + {Key: apiKeyParamKey, Value: "mock"}, + {Key: embeddingUrlParamKey, Value: "mock"}, }, }) s.Error(err) @@ -313,13 +275,14 @@ func (s *OpenAIEmbeddingFunctionSuite) TestRunnerParamsErr() { // no openai api key { - _, err := NewOpenAIEmbeddingFunction(s.schema, &schemapb.FunctionSchema{ + _, err := NewTextEmbeddingFunction(s.schema, &schemapb.FunctionSchema{ Name: "test", Type: schemapb.FunctionType_Unknown, InputFieldIds: []int64{101}, OutputFieldIds: []int64{102}, Params: []*commonpb.KeyValuePair{ - {Key: ModelNameParamKey, Value: "text-embedding-ada-003"}, + {Key: Provider, Value: OpenAIProvider}, + {Key: modelNameParamKey, Value: "text-embedding-ada-003"}, }, }) s.Error(err) From c459d41bde6a917efb6d461af078819ae9cbd38f Mon Sep 17 00:00:00 2001 From: "junjie.jiang" Date: Mon, 9 Dec 2024 21:00:12 +0800 Subject: [PATCH 11/18] Update go mod Signed-off-by: junjie.jiang --- go.mod | 15 +++++++++++++++ go.sum | 30 ++++++++++++++++++++++++++++++ internal/proxy/task_insert_test.go | 12 ++++++++++++ internal/proxy/task_search_test.go | 2 +- 4 files changed, 58 insertions(+), 1 deletion(-) diff --git a/go.mod b/go.mod index 5c8ff76d9e3cc..fdfc3e2bd43e9 100644 --- a/go.mod +++ b/go.mod @@ -59,6 +59,10 @@ require ( require ( cloud.google.com/go/storage v1.43.0 github.com/antlr4-go/antlr/v4 v4.13.1 + github.com/aws/aws-sdk-go-v2 v1.32.6 + github.com/aws/aws-sdk-go-v2/config v1.28.6 + github.com/aws/aws-sdk-go-v2/credentials v1.17.47 + github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.23.0 github.com/bits-and-blooms/bitset v1.10.0 github.com/bytedance/sonic v1.12.2 github.com/cenkalti/backoff/v4 v4.2.1 @@ -101,6 +105,17 @@ require ( github.com/apache/pulsar-client-go v0.6.1-0.20210728062540-29414db801a7 // indirect github.com/apache/thrift v0.18.1 // indirect github.com/ardielle/ardielle-go v1.5.2 // indirect + github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.7 // indirect + github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.21 // indirect + github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.25 // indirect + github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.25 // indirect + github.com/aws/aws-sdk-go-v2/internal/ini v1.8.1 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.1 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.6 // indirect + github.com/aws/aws-sdk-go-v2/service/sso v1.24.7 // indirect + github.com/aws/aws-sdk-go-v2/service/ssooidc v1.28.6 // indirect + github.com/aws/aws-sdk-go-v2/service/sts v1.33.2 // indirect + github.com/aws/smithy-go v1.22.1 // indirect github.com/benesch/cgosymbolizer v0.0.0-20190515212042-bec6fe6e597b // indirect github.com/beorn7/perks v1.0.1 // indirect github.com/bytedance/sonic/loader v0.2.0 // indirect diff --git a/go.sum b/go.sum index 228644055077f..52dd7ac9b51bc 100644 --- a/go.sum +++ b/go.sum @@ -120,6 +120,36 @@ github.com/armon/consul-api 
v0.0.0-20180202201655-eb2c6b5be1b6/go.mod h1:grANhF5 github.com/armon/go-metrics v0.0.0-20180917152333-f0300d1749da/go.mod h1:Q73ZrmVTwzkszR9V5SSuryQ31EELlFMUz1kKyl939pY= github.com/armon/go-radix v0.0.0-20180808171621-7fddfc383310/go.mod h1:ufUuZ+zHj4x4TnLV4JWEpy2hxWSpsRywHrMgIH9cCH8= github.com/aws/aws-sdk-go v1.32.6/go.mod h1:5zCpMtNQVjRREroY7sYe8lOMRSxkhG6MZveU8YkpAk0= +github.com/aws/aws-sdk-go-v2 v1.32.6 h1:7BokKRgRPuGmKkFMhEg/jSul+tB9VvXhcViILtfG8b4= +github.com/aws/aws-sdk-go-v2 v1.32.6/go.mod h1:P5WJBrYqqbWVaOxgH0X/FYYD47/nooaPOZPlQdmiN2U= +github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.7 h1:lL7IfaFzngfx0ZwUGOZdsFFnQ5uLvR0hWqqhyE7Q9M8= +github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.7/go.mod h1:QraP0UcVlQJsmHfioCrveWOC1nbiWUl3ej08h4mXWoc= +github.com/aws/aws-sdk-go-v2/config v1.28.6 h1:D89IKtGrs/I3QXOLNTH93NJYtDhm8SYa9Q5CsPShmyo= +github.com/aws/aws-sdk-go-v2/config v1.28.6/go.mod h1:GDzxJ5wyyFSCoLkS+UhGB0dArhb9mI+Co4dHtoTxbko= +github.com/aws/aws-sdk-go-v2/credentials v1.17.47 h1:48bA+3/fCdi2yAwVt+3COvmatZ6jUDNkDTIsqDiMUdw= +github.com/aws/aws-sdk-go-v2/credentials v1.17.47/go.mod h1:+KdckOejLW3Ks3b0E3b5rHsr2f9yuORBum0WPnE5o5w= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.21 h1:AmoU1pziydclFT/xRV+xXE/Vb8fttJCLRPv8oAkprc0= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.21/go.mod h1:AjUdLYe4Tgs6kpH4Bv7uMZo7pottoyHMn4eTcIcneaY= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.25 h1:s/fF4+yDQDoElYhfIVvSNyeCydfbuTKzhxSXDXCPasU= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.25/go.mod h1:IgPfDv5jqFIzQSNbUEMoitNooSMXjRSDkhXv8jiROvU= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.25 h1:ZntTCl5EsYnhN/IygQEUugpdwbhdkom9uHcbCftiGgA= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.25/go.mod h1:DBdPrgeocww+CSl1C8cEV8PN1mHMBhuCDLpXezyvWkE= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.1 h1:VaRN3TlFdd6KxX1x3ILT5ynH6HvKgqdiXoTxAF4HQcQ= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.1/go.mod h1:FbtygfRFze9usAadmnGJNc8KsP346kEe+y2/oyhGAGc= +github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.23.0 h1:mfV5tcLXeRLbiyI4EHoHWH1sIU7JvbfXVvymUCIgZEo= +github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.23.0/go.mod h1:YSSgYnasDKm5OjU3bOPkaz+2PFO6WjEQGIA6KQNsR3Q= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.1 h1:iXtILhvDxB6kPvEXgsDhGaZCSC6LQET5ZHSdJozeI0Y= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.1/go.mod h1:9nu0fVANtYiAePIBh2/pFUSwtJ402hLnp854CNoDOeE= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.6 h1:50+XsN70RS7dwJ2CkVNXzj7U2L1HKP8nqTd3XWEXBN4= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.6/go.mod h1:WqgLmwY7so32kG01zD8CPTJWVWM+TzJoOVHwTg4aPug= +github.com/aws/aws-sdk-go-v2/service/sso v1.24.7 h1:rLnYAfXQ3YAccocshIH5mzNNwZBkBo+bP6EhIxak6Hw= +github.com/aws/aws-sdk-go-v2/service/sso v1.24.7/go.mod h1:ZHtuQJ6t9A/+YDuxOLnbryAmITtr8UysSny3qcyvJTc= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.28.6 h1:JnhTZR3PiYDNKlXy50/pNeix9aGMo6lLpXwJ1mw8MD4= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.28.6/go.mod h1:URronUEGfXZN1VpdktPSD1EkAL9mfrV+2F4sjH38qOY= +github.com/aws/aws-sdk-go-v2/service/sts v1.33.2 h1:s4074ZO1Hk8qv65GqNXqDjmkf4HSQqJukaLuuW0TpDA= +github.com/aws/aws-sdk-go-v2/service/sts v1.33.2/go.mod h1:mVggCnIWoM09jP71Wh+ea7+5gAp53q+49wDFs1SW5z8= +github.com/aws/smithy-go v1.22.1 h1:/HPHZQ0g7f4eUeK6HKglFz8uwVfZKgoI25rb/J+dnro= +github.com/aws/smithy-go v1.22.1/go.mod h1:irrKGvNn1InZwb2d7fkIRNucdfwR8R+Ts3wxYa/cJHg= 
github.com/aymerick/raymond v2.0.3-0.20180322193309-b565731e1464+incompatible/go.mod h1:osfaiScAUVup+UC9Nfq76eWqDhXlp+4UYaA8uhTBO6g= github.com/benbjohnson/clock v1.1.0/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZxNJlLklBHA= github.com/benesch/cgosymbolizer v0.0.0-20190515212042-bec6fe6e597b h1:5JgaFtHFRnOPReItxvhMDXbvuBkjSWE+9glJyF466yw= diff --git a/internal/proxy/task_insert_test.go b/internal/proxy/task_insert_test.go index ccff3ae3bb88e..2946af3c4bfe2 100644 --- a/internal/proxy/task_insert_test.go +++ b/internal/proxy/task_insert_test.go @@ -404,6 +404,18 @@ func TestInsertTask_Function(t *testing.T) { mock.AnythingOfType("string"), mock.AnythingOfType("string"), ).Return(info, nil) + + cache.On("GetPartitionInfo", + mock.Anything, // context.Context + mock.AnythingOfType("string"), + mock.AnythingOfType("string"), + mock.AnythingOfType("string"), + ).Return(&partitionInfo{ + name: "p1", + partitionID: 10, + createdTimestamp: 10001, + createdUtcTimestamp: 10002, + }, nil) globalMetaCache = cache err = task.PreExecute(ctx) assert.NoError(t, err) diff --git a/internal/proxy/task_search_test.go b/internal/proxy/task_search_test.go index 1fb6e9c2fdf60..377c0d27d70be 100644 --- a/internal/proxy/task_search_test.go +++ b/internal/proxy/task_search_test.go @@ -588,7 +588,7 @@ func TestSearchTask_WithFunctions(t *testing.T) { cache.EXPECT().GetCollectionID(mock.Anything, mock.Anything, mock.Anything).Return(collectionID, nil).Maybe() cache.EXPECT().GetCollectionSchema(mock.Anything, mock.Anything, mock.Anything).Return(info, nil).Maybe() cache.EXPECT().GetPartitions(mock.Anything, mock.Anything, mock.Anything).Return(map[string]int64{"_default": UniqueID(1)}, nil).Maybe() - cache.EXPECT().GetCollectionInfo(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(&collectionBasicInfo{}, nil).Maybe() + cache.EXPECT().GetCollectionInfo(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(&collectionInfo{}, nil).Maybe() cache.EXPECT().GetShards(mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(map[string][]nodeInfo{}, nil).Maybe() cache.EXPECT().DeprecateShardCache(mock.Anything, mock.Anything).Return().Maybe() globalMetaCache = cache From 9ff9419a9f696bbd53ce3b0436b66efad1f9fbb8 Mon Sep 17 00:00:00 2001 From: "junjie.jiang" Date: Wed, 11 Dec 2024 11:50:59 +0800 Subject: [PATCH 12/18] Polish error infos Signed-off-by: junjie.jiang --- internal/util/function/function_base.go | 2 +- internal/util/function/text_embedding_function.go | 10 +++++----- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/internal/util/function/function_base.go b/internal/util/function/function_base.go index 393e03d131c52..ac501a96dcc95 100644 --- a/internal/util/function/function_base.go +++ b/internal/util/function/function_base.go @@ -29,7 +29,7 @@ type FunctionBase struct { outputFields []*schemapb.FieldSchema } -func NewBase(coll *schemapb.CollectionSchema, schema *schemapb.FunctionSchema) (*FunctionBase, error) { +func NewFunctionBase(coll *schemapb.CollectionSchema, schema *schemapb.FunctionSchema) (*FunctionBase, error) { var base FunctionBase base.schema = schema for _, field_id := range schema.GetOutputFieldIds() { diff --git a/internal/util/function/text_embedding_function.go b/internal/util/function/text_embedding_function.go index a0ea82dc96282..274ddde0b8ba9 100644 --- a/internal/util/function/text_embedding_function.go +++ b/internal/util/function/text_embedding_function.go @@ -46,15 +46,15 @@ type TextEmbeddingProvider interface { FieldDim() 
int64 } -func getProvider(schema *schemapb.FunctionSchema) (string, error) { - for _, param := range schema.Params { +func getProvider(functionSchema *schemapb.FunctionSchema) (string, error) { + for _, param := range functionSchema.Params { switch strings.ToLower(param.Key) { case Provider: return strings.ToLower(param.Value), nil default: } } - return "", fmt.Errorf("The provider parameter was not found in the function's parameters") + return "", fmt.Errorf("The text embedding service provider parameter:[%s] was not found", Provider) } type TextEmebddingFunction struct { @@ -68,7 +68,7 @@ func NewTextEmbeddingFunction(coll *schemapb.CollectionSchema, functionSchema *s return nil, fmt.Errorf("Text function should only have one output field, but now is %d", len(functionSchema.GetOutputFieldIds())) } - base, err := NewBase(coll, functionSchema) + base, err := NewFunctionBase(coll, functionSchema) if err != nil { return nil, err } @@ -121,7 +121,7 @@ func NewTextEmbeddingFunction(coll *schemapb.CollectionSchema, functionSchema *s embProvider: embP, }, nil default: - return nil, fmt.Errorf("Provider: [%s] not exist, only supports [%s, %s, %s]", provider, OpenAIProvider, AzureOpenAIProvider, AliDashScopeProvider) + return nil, fmt.Errorf("Unsupported embedding service provider: [%s] , list of supported [%s, %s, %s, %s]", provider, OpenAIProvider, AzureOpenAIProvider, AliDashScopeProvider, BedrockProvider) } } From 0bcc13a5c814944f737b4959d087cab4ccf7e6bf Mon Sep 17 00:00:00 2001 From: "junjie.jiang" Date: Wed, 11 Dec 2024 15:18:36 +0800 Subject: [PATCH 13/18] Reuse http client Signed-off-by: junjie.jiang --- .../ali/ali_dashscope_text_embedding.go | 11 ++++++----- .../ali/ali_dashscope_text_embedding_test.go | 2 -- internal/models/openai/openai_embedding.go | 19 +++++++++---------- .../models/openai/openai_embedding_test.go | 1 + internal/models/utils/embedding_util.go | 10 ++++------ 5 files changed, 20 insertions(+), 23 deletions(-) diff --git a/internal/models/ali/ali_dashscope_text_embedding.go b/internal/models/ali/ali_dashscope_text_embedding.go index c53194c4d2cad..329451577f07f 100644 --- a/internal/models/ali/ali_dashscope_text_embedding.go +++ b/internal/models/ali/ali_dashscope_text_embedding.go @@ -18,6 +18,7 @@ package ali import ( "bytes" + "context" "encoding/json" "fmt" "net/http" @@ -130,17 +131,17 @@ func (c *AliDashScopeEmbedding) Embedding(modelName string, texts []string, dim if timeoutSec <= 0 { timeoutSec = 30 } - client := &http.Client{ - Timeout: timeoutSec * time.Second, - } - req, err := http.NewRequest("POST", c.url, bytes.NewBuffer(data)) + + ctx, cancel := context.WithTimeout(context.Background(), timeoutSec*time.Second) + defer cancel() + req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.url, bytes.NewBuffer(data)) if err != nil { return nil, err } req.Header.Set("Content-Type", "application/json") req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.apiKey)) - body, err := utils.RetrySend(client, req, 3) + body, err := utils.RetrySend(req, 3) if err != nil { return nil, err } diff --git a/internal/models/ali/ali_dashscope_text_embedding_test.go b/internal/models/ali/ali_dashscope_text_embedding_test.go index a371f10aa79bc..6fb7cd04e5ac4 100644 --- a/internal/models/ali/ali_dashscope_text_embedding_test.go +++ b/internal/models/ali/ali_dashscope_text_embedding_test.go @@ -21,9 +21,7 @@ import ( "fmt" "net/http" "net/http/httptest" - // "sync/atomic" "testing" - // "time" "github.com/stretchr/testify/assert" ) diff --git 
a/internal/models/openai/openai_embedding.go b/internal/models/openai/openai_embedding.go index ef1be424005c8..ee1b7cc03330b 100644 --- a/internal/models/openai/openai_embedding.go +++ b/internal/models/openai/openai_embedding.go @@ -18,6 +18,7 @@ package openai import ( "bytes" + "context" "encoding/json" "fmt" "net/http" @@ -157,16 +158,15 @@ func (c *OpenAIEmbeddingClient) Embedding(modelName string, texts []string, dim if timeoutSec <= 0 { timeoutSec = 30 } - client := &http.Client{ - Timeout: timeoutSec * time.Second, - } - req, err := http.NewRequest("POST", c.url, bytes.NewBuffer(data)) + ctx, cancel := context.WithTimeout(context.Background(), timeoutSec*time.Second) + defer cancel() + req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.url, bytes.NewBuffer(data)) if err != nil { return nil, err } req.Header.Set("Content-Type", "application/json") req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.apiKey)) - body, err := utils.RetrySend(client, req, 3) + body, err := utils.RetrySend(req, 3) if err != nil { return nil, err } @@ -215,16 +215,15 @@ func (c *AzureOpenAIEmbeddingClient) Embedding(modelName string, texts []string, params.Add("api-version", c.apiVersion) base.RawQuery = params.Encode() - client := &http.Client{ - Timeout: timeoutSec * time.Second, - } - req, err := http.NewRequest("POST", c.url, bytes.NewBuffer(data)) + ctx, cancel := context.WithTimeout(context.Background(), timeoutSec) + defer cancel() + req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.url, bytes.NewBuffer(data)) if err != nil { return nil, err } req.Header.Set("Content-Type", "application/json") req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.apiKey)) - body, err := utils.RetrySend(client, req, 3) + body, err := utils.RetrySend(req, 3) if err != nil { return nil, err } diff --git a/internal/models/openai/openai_embedding_test.go b/internal/models/openai/openai_embedding_test.go index b50fabe8e46d7..c935e7c2cfbff 100644 --- a/internal/models/openai/openai_embedding_test.go +++ b/internal/models/openai/openai_embedding_test.go @@ -84,6 +84,7 @@ func TestEmbeddingOK(t *testing.T) { c := NewOpenAIEmbeddingClient("mock_key", url) err := c.Check() assert.True(t, err == nil) + _, err = c.Embedding("text-embedding-3-small", []string{"sentence"}, 0, "", 0) ret, err := c.Embedding("text-embedding-3-small", []string{"sentence"}, 0, "", 0) assert.True(t, err == nil) assert.Equal(t, ret.Data[0].Index, 0) diff --git a/internal/models/utils/embedding_util.go b/internal/models/utils/embedding_util.go index fafb6f5cbd3c5..e67dcf0a4bd9c 100644 --- a/internal/models/utils/embedding_util.go +++ b/internal/models/utils/embedding_util.go @@ -22,10 +22,8 @@ import ( "net/http" ) -func send(client *http.Client, req *http.Request) ([]byte, error) { - // call openai - resp, err := client.Do(req) - +func send(req *http.Request) ([]byte, error) { + resp, err := http.DefaultClient.Do(req) if err != nil { return nil, err } @@ -42,9 +40,9 @@ func send(client *http.Client, req *http.Request) ([]byte, error) { return body, nil } -func RetrySend(client *http.Client, req *http.Request, maxRetries int) ([]byte, error) { +func RetrySend(req *http.Request, maxRetries int) ([]byte, error) { for i := 0; i < maxRetries; i++ { - res, err := send(client, req) + res, err := send(req) if err == nil { return res, nil } From 2692fea46a92668db85c72f251339e2cc793f95f Mon Sep 17 00:00:00 2001 From: "junjie.jiang" Date: Tue, 17 Dec 2024 17:32:44 +0800 Subject: [PATCH 14/18] Add vertexai Signed-off-by: 
junjie.jiang --- internal/models/openai/openai_embedding.go | 6 +- .../models/openai/openai_embedding_test.go | 62 ++++- internal/models/utils/embedding_util.go | 6 +- .../vertexai/vertexai_text_embedding.go | 163 +++++++++++++ .../vertexai/vertexai_text_embedding_test.go | 90 +++++++ internal/proxy/task_search.go | 1 - internal/proxy/task_test.go | 44 ++++ internal/proxy/util.go | 5 + .../util/function/ali_embedding_provider.go | 27 ++- .../alitext_embedding_provider_test.go | 8 +- .../function/bedrock_embedding_provider.go | 10 +- .../bedrock_text_embedding_provider_test.go | 6 +- internal/util/function/common.go | 32 ++- internal/util/function/function_base.go | 10 +- internal/util/function/function_executor.go | 9 + .../util/function/mock_embedding_service.go | 35 +++ .../function/openai_embedding_provider.go | 17 +- .../openai_text_embedding_provider_test.go | 8 +- .../util/function/text_embedding_function.go | 19 +- .../function/vertexai_embedding_provider.go | 221 ++++++++++++++++++ .../vertexai_embedding_provider_test.go | 170 ++++++++++++++ 21 files changed, 892 insertions(+), 57 deletions(-) create mode 100644 internal/models/vertexai/vertexai_text_embedding.go create mode 100644 internal/models/vertexai/vertexai_text_embedding_test.go create mode 100644 internal/util/function/vertexai_embedding_provider.go create mode 100644 internal/util/function/vertexai_embedding_provider_test.go diff --git a/internal/models/openai/openai_embedding.go b/internal/models/openai/openai_embedding.go index ee1b7cc03330b..bb6f88be0cd18 100644 --- a/internal/models/openai/openai_embedding.go +++ b/internal/models/openai/openai_embedding.go @@ -215,14 +215,14 @@ func (c *AzureOpenAIEmbeddingClient) Embedding(modelName string, texts []string, params.Add("api-version", c.apiVersion) base.RawQuery = params.Encode() - ctx, cancel := context.WithTimeout(context.Background(), timeoutSec) + ctx, cancel := context.WithTimeout(context.Background(), timeoutSec*time.Second) defer cancel() - req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.url, bytes.NewBuffer(data)) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, base.String(), bytes.NewBuffer(data)) if err != nil { return nil, err } req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.apiKey)) + req.Header.Set("api-key", c.apiKey) body, err := utils.RetrySend(req, 3) if err != nil { return nil, err diff --git a/internal/models/openai/openai_embedding_test.go b/internal/models/openai/openai_embedding_test.go index c935e7c2cfbff..87f44b4ea6308 100644 --- a/internal/models/openai/openai_embedding_test.go +++ b/internal/models/openai/openai_embedding_test.go @@ -72,7 +72,20 @@ func TestEmbeddingOK(t *testing.T) { } ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusOK) + if r.URL.Path == "/" { + if r.Header["Authorization"][0] != "" { + w.WriteHeader(http.StatusOK) + } else { + w.WriteHeader(http.StatusBadRequest) + } + } else { + if r.Header["Api-Key"][0] != "" { + w.WriteHeader(http.StatusOK) + } else { + w.WriteHeader(http.StatusBadRequest) + } + } + data, _ := json.Marshal(res) w.Write(data) })) @@ -84,7 +97,15 @@ func TestEmbeddingOK(t *testing.T) { c := NewOpenAIEmbeddingClient("mock_key", url) err := c.Check() assert.True(t, err == nil) - _, err = c.Embedding("text-embedding-3-small", []string{"sentence"}, 0, "", 0) + ret, err := c.Embedding("text-embedding-3-small", []string{"sentence"}, 0, "", 0) + 
assert.True(t, err == nil) + assert.Equal(t, ret.Data[0].Index, 0) + assert.Equal(t, ret.Data[1].Index, 1) + } + { + c := NewAzureOpenAIEmbeddingClient("mock_key", url) + err := c.Check() + assert.True(t, err == nil) ret, err := c.Embedding("text-embedding-3-small", []string{"sentence"}, 0, "", 0) assert.True(t, err == nil) assert.Equal(t, ret.Data[0].Index, 0) @@ -148,6 +169,20 @@ func TestEmbeddingRetry(t *testing.T) { assert.Equal(t, ret.Data[2], res.Data[0]) assert.Equal(t, atomic.LoadInt32(&count), int32(2)) } + { + c := NewAzureOpenAIEmbeddingClient("mock_key", url) + err := c.Check() + assert.True(t, err == nil) + ret, err := c.Embedding("text-embedding-3-small", []string{"sentence"}, 0, "", 0) + assert.True(t, err == nil) + assert.Equal(t, ret.Usage, res.Usage) + assert.Equal(t, ret.Object, res.Object) + assert.Equal(t, ret.Model, res.Model) + assert.Equal(t, ret.Data[0], res.Data[1]) + assert.Equal(t, ret.Data[1], res.Data[2]) + assert.Equal(t, ret.Data[2], res.Data[0]) + assert.Equal(t, atomic.LoadInt32(&count), int32(2)) + } } func TestEmbeddingFailed(t *testing.T) { @@ -161,6 +196,7 @@ func TestEmbeddingFailed(t *testing.T) { url := ts.URL { + atomic.StoreInt32(&count, 0) c := NewOpenAIEmbeddingClient("mock_key", url) err := c.Check() assert.True(t, err == nil) @@ -168,6 +204,15 @@ func TestEmbeddingFailed(t *testing.T) { assert.True(t, err != nil) assert.Equal(t, atomic.LoadInt32(&count), int32(3)) } + { + atomic.StoreInt32(&count, 0) + c := NewAzureOpenAIEmbeddingClient("mock_key", url) + err := c.Check() + assert.True(t, err == nil) + _, err = c.Embedding("text-embedding-3-small", []string{"sentence"}, 0, "", 0) + assert.True(t, err != nil) + assert.Equal(t, atomic.LoadInt32(&count), int32(3)) + } } func TestTimeout(t *testing.T) { @@ -182,6 +227,7 @@ func TestTimeout(t *testing.T) { url := ts.URL { + atomic.StoreInt32(&st, 0) c := NewOpenAIEmbeddingClient("mock_key", url) err := c.Check() assert.True(t, err == nil) @@ -191,4 +237,16 @@ func TestTimeout(t *testing.T) { time.Sleep(3 * time.Second) assert.Equal(t, atomic.LoadInt32(&st), int32(1)) } + + { + atomic.StoreInt32(&st, 0) + c := NewAzureOpenAIEmbeddingClient("mock_key", url) + err := c.Check() + assert.True(t, err == nil) + _, err = c.Embedding("text-embedding-3-small", []string{"sentence"}, 0, "", 1) + assert.True(t, err != nil) + assert.Equal(t, atomic.LoadInt32(&st), int32(0)) + time.Sleep(3 * time.Second) + assert.Equal(t, atomic.LoadInt32(&st), int32(1)) + } } diff --git a/internal/models/utils/embedding_util.go b/internal/models/utils/embedding_util.go index e67dcf0a4bd9c..1d6e7d916cab2 100644 --- a/internal/models/utils/embedding_util.go +++ b/internal/models/utils/embedding_util.go @@ -41,11 +41,13 @@ func send(req *http.Request) ([]byte, error) { } func RetrySend(req *http.Request, maxRetries int) ([]byte, error) { + var err error + var res []byte for i := 0; i < maxRetries; i++ { - res, err := send(req) + res, err = send(req) if err == nil { return res, nil } } - return nil, nil + return nil, err } diff --git a/internal/models/vertexai/vertexai_text_embedding.go b/internal/models/vertexai/vertexai_text_embedding.go new file mode 100644 index 0000000000000..3842824616214 --- /dev/null +++ b/internal/models/vertexai/vertexai_text_embedding.go @@ -0,0 +1,163 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package vertexai + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "net/http" + "time" + + "github.com/milvus-io/milvus/internal/models/utils" + + "golang.org/x/oauth2/google" +) + +type Instance struct { + TaskType string `json:"task_type,omitempty"` + Content string `json:"content"` +} + +type Parameters struct { + OutputDimensionality int64 `json:"outputDimensionality,omitempty"` +} + +type EmbeddingRequest struct { + Instances []Instance `json:"instances"` + Parameters Parameters `json:"parameters,omitempty"` +} + +type Statistics struct { + Truncated bool `json:"truncated"` + TokenCount int `json:"token_count"` +} + +type Embeddings struct { + Statistics Statistics `json:"statistics"` + Values []float32 `json:"values"` +} + +type Prediction struct { + Embeddings Embeddings `json:"embeddings"` +} + +type Metadata struct { + BillableCharacterCount int `json:"billableCharacterCount"` +} + +type EmbeddingResponse struct { + Predictions []Prediction `json:"predictions"` + Metadata Metadata `json:"metadata"` +} + +type ErrorInfo struct { + Code string `json:"code"` + Message string `json:"message"` + RequestID string `json:"request_id"` +} + +type VertexAIEmbedding struct { + url string + jsonKey []byte + scopes string + token string +} + +func NewVertexAIEmbedding(url string, jsonKey []byte, scopes string, token string) *VertexAIEmbedding { + return &VertexAIEmbedding{ + url: url, + jsonKey: jsonKey, + scopes: scopes, + token: token, + } +} + +func (c *VertexAIEmbedding) Check() error { + if c.url == "" { + return fmt.Errorf("VertexAI embedding url is empty") + } + if len(c.jsonKey) == 0 { + return fmt.Errorf("jsonKey is empty") + } + if c.scopes == "" { + return fmt.Errorf("Scopes param is empty") + } + return nil +} + +func (c *VertexAIEmbedding) getAccessToken() (string, error) { + ctx := context.Background() + creds, err := google.CredentialsFromJSON(ctx, c.jsonKey, c.scopes) + if err != nil { + return "", fmt.Errorf("Failed to find credentials: %v", err) + } + token, err := creds.TokenSource.Token() + if err != nil { + return "", fmt.Errorf("Failed to get token: %v", err) + } + return token.AccessToken, nil +} + +func (c *VertexAIEmbedding) Embedding(modelName string, texts []string, dim int64, taskType string, timeoutSec time.Duration) (*EmbeddingResponse, error) { + var r EmbeddingRequest + for _, text := range texts { + r.Instances = append(r.Instances, Instance{TaskType: taskType, Content: text}) + } + if dim != 0 { + r.Parameters.OutputDimensionality = dim + } + + data, err := json.Marshal(r) + if err != nil { + return nil, err + } + + if timeoutSec <= 0 { + timeoutSec = 30 + } + + ctx, cancel := context.WithTimeout(context.Background(), timeoutSec*time.Second) + defer cancel() + req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.url, bytes.NewBuffer(data)) + if err != nil { + return nil, err + } + var token string + if c.token != "" { + token = c.token + } else { + token, err 
= c.getAccessToken() + if err != nil { + return nil, err + } + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token)) + body, err := utils.RetrySend(req, 3) + if err != nil { + return nil, err + } + var res EmbeddingResponse + err = json.Unmarshal(body, &res) + if err != nil { + return nil, err + } + return &res, err +} diff --git a/internal/models/vertexai/vertexai_text_embedding_test.go b/internal/models/vertexai/vertexai_text_embedding_test.go new file mode 100644 index 0000000000000..f138d659a3ea4 --- /dev/null +++ b/internal/models/vertexai/vertexai_text_embedding_test.go @@ -0,0 +1,90 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package vertexai + +import ( + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestEmbeddingClientCheck(t *testing.T) { + mockJsonKey := []byte{1, 2, 3} + { + c := NewVertexAIEmbedding("mock_url", []byte{}, "mock_scopes", "") + err := c.Check() + assert.True(t, err != nil) + fmt.Println(err) + } + + { + c := NewVertexAIEmbedding("", mockJsonKey, "", "") + err := c.Check() + assert.True(t, err != nil) + fmt.Println(err) + } + + { + c := NewVertexAIEmbedding("mock_url", mockJsonKey, "mock_scopes", "") + err := c.Check() + assert.True(t, err == nil) + } +} + +func TestEmbeddingOK(t *testing.T) { + var res EmbeddingResponse + repStr := `{"predictions": [{"embeddings": {"statistics": {"truncated": false, "token_count": 4}, "values": [-0.028420744463801384, 0.037183016538619995]}}, {"embeddings": {"statistics": {"truncated": false, "token_count": 8}, "values": [-0.04367655888199806, 0.03777721896767616, 0.0158217903226614]}}], "metadata": {"billableCharacterCount": 27}}` + err := json.Unmarshal([]byte(repStr), &res) + assert.NoError(t, err) + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + data, _ := json.Marshal(res) + w.Write(data) + })) + + defer ts.Close() + url := ts.URL + + { + c := NewVertexAIEmbedding(url, []byte{1, 2, 3}, "mock_scopes", "mock_token") + err := c.Check() + assert.True(t, err == nil) + _, err = c.Embedding("text-embedding-005", []string{"sentence"}, 0, "query", 0) + assert.True(t, err == nil) + } +} + +func TestEmbeddingFailed(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusUnauthorized) + })) + + defer ts.Close() + url := ts.URL + + { + c := NewVertexAIEmbedding(url, []byte{1, 2, 3}, "mock_scopes", "mock_token") + err := c.Check() + assert.True(t, err == nil) + _, err = c.Embedding("text-embedding-v2", []string{"sentence"}, 0, "query", 0) + assert.True(t, err != nil) + } +} diff --git 
a/internal/proxy/task_search.go b/internal/proxy/task_search.go index 6c7f271f74e90..0b5988961c84f 100644 --- a/internal/proxy/task_search.go +++ b/internal/proxy/task_search.go @@ -439,7 +439,6 @@ func (t *searchTask) initAdvancedSearchRequest(ctx context.Context) error { t.SearchRequest.PartitionIDs = t.partitionIDsSet.Collect() } - var err error t.reScorers, err = NewReScorers(ctx, len(t.request.GetSubReqs()), t.request.GetSearchParams()) if err != nil { log.Info("generate reScorer failed", zap.Any("params", t.request.GetSearchParams()), zap.Error(err)) diff --git a/internal/proxy/task_test.go b/internal/proxy/task_test.go index 019063bbbc9f1..7982f36731dbe 100644 --- a/internal/proxy/task_test.go +++ b/internal/proxy/task_test.go @@ -20,6 +20,7 @@ import ( "bytes" "context" "encoding/binary" + "fmt" "math/rand" "strconv" "testing" @@ -1022,6 +1023,49 @@ func TestCreateCollectionTask(t *testing.T) { err = task2.PreExecute(ctx) assert.Error(t, err) }) + + t.Run("collection with embedding function ", func(t *testing.T) { + fmt.Println(schema) + schema.Functions = []*schemapb.FunctionSchema{ + { + Name: "test", + Type: schemapb.FunctionType_TextEmbedding, + InputFieldNames: []string{varCharField}, + OutputFieldNames: []string{floatVecField}, + InputFieldIds: []int64{101}, + OutputFieldIds: []int64{102}, + Params: []*commonpb.KeyValuePair{ + {Key: "provider", Value: "openai"}, + {Key: "model_name", Value: "text-embedding-ada-002"}, + {Key: "api_key", Value: "mock"}, + }, + }, + } + + marshaledSchema, err := proto.Marshal(schema) + assert.NoError(t, err) + + task2 := &createCollectionTask{ + Condition: NewTaskCondition(ctx), + CreateCollectionRequest: &milvuspb.CreateCollectionRequest{ + Base: nil, + DbName: dbName, + CollectionName: collectionName, + Schema: marshaledSchema, + ShardsNum: shardsNum, + }, + ctx: ctx, + rootCoord: rc, + result: nil, + schema: nil, + } + + err = task2.OnEnqueue() + assert.NoError(t, err) + + err = task2.PreExecute(ctx) + assert.NoError(t, err) + }) } func TestHasCollectionTask(t *testing.T) { diff --git a/internal/proxy/util.go b/internal/proxy/util.go index 1c5daba64c681..2efb9e2ebd9bd 100644 --- a/internal/proxy/util.go +++ b/internal/proxy/util.go @@ -39,6 +39,7 @@ import ( "github.com/milvus-io/milvus/internal/proto/planpb" "github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/internal/types" + "github.com/milvus-io/milvus/internal/util/function" "github.com/milvus-io/milvus/internal/util/hookutil" "github.com/milvus-io/milvus/internal/util/indexparamcheck" typeutil2 "github.com/milvus-io/milvus/internal/util/typeutil" @@ -705,6 +706,10 @@ func validateFunction(coll *schemapb.CollectionSchema) error { return err } } + + if err := function.CheckFunctions(coll); err != nil { + return err + } return nil } diff --git a/internal/util/function/ali_embedding_provider.go b/internal/util/function/ali_embedding_provider.go index d106426f5c020..920041afadbc1 100644 --- a/internal/util/function/ali_embedding_provider.go +++ b/internal/util/function/ali_embedding_provider.go @@ -36,6 +36,7 @@ type AliEmbeddingProvider struct { client *ali.AliDashScopeEmbedding modelName string embedDimParam int64 + outputType string maxBatch int timeoutSec int @@ -43,19 +44,15 @@ type AliEmbeddingProvider struct { func createAliEmbeddingClient(apiKey string, url string) (*ali.AliDashScopeEmbedding, error) { if apiKey == "" { - apiKey = os.Getenv("DASHSCOPE_API_KEY") + apiKey = os.Getenv(dashscopeApiKey) } if apiKey == "" { - return nil, fmt.Errorf("Missing 
credentials. Please pass `api_key`, or configure the DASHSCOPE_API_KEY environment variable in the Milvus service.") + return nil, fmt.Errorf("Missing credentials. Please pass `api_key`, or configure the %s environment variable in the Milvus service.", dashscopeApiKey) } if url == "" { url = "https://dashscope.aliyuncs.com/api/v1/services/embeddings/text-embedding/text-embedding" } - if url == "" { - return nil, fmt.Errorf("Must provide `url` arguments or configure the DASHSCOPE_ENDPOINT environment variable in the Milvus service") - } - c := ali.NewAliDashScopeEmbeddingClient(apiKey, url) return c, nil } @@ -93,6 +90,7 @@ func NewAliDashScopeEmbeddingProvider(fieldSchema *schemapb.FieldSchema, functio return nil, fmt.Errorf("Unsupported model: %s, only support [%s, %s, %s]", modelName, TextEmbeddingV1, TextEmbeddingV2, TextEmbeddingV3) } + c, err := createAliEmbeddingClient(apiKey, url) if err != nil { return nil, err @@ -102,8 +100,10 @@ func NewAliDashScopeEmbeddingProvider(fieldSchema *schemapb.FieldSchema, functio fieldDim: fieldDim, modelName: modelName, embedDimParam: dim, - maxBatch: 25, - timeoutSec: 30, + // TextEmbedding only supports dense embedding + outputType: "dense", + maxBatch: 25, + timeoutSec: 30, } return &provider, nil } @@ -116,19 +116,24 @@ func (provider *AliEmbeddingProvider) FieldDim() int64 { return provider.fieldDim } -func (provider *AliEmbeddingProvider) CallEmbedding(texts []string, batchLimit bool) ([][]float32, error) { +func (provider *AliEmbeddingProvider) CallEmbedding(texts []string, batchLimit bool, mode string) ([][]float32, error) { numRows := len(texts) if batchLimit && numRows > provider.MaxBatch() { return nil, fmt.Errorf("Ali text embedding supports up to [%d] pieces of data at a time, got [%d]", provider.MaxBatch(), numRows) } - + var textType string + if mode == SearchMode { + textType = "query" + } else { + textType = "document" + } data := make([][]float32, 0, numRows) for i := 0; i < numRows; i += provider.maxBatch { end := i + provider.maxBatch if end > numRows { end = numRows } - resp, err := provider.client.Embedding(provider.modelName, texts[i:end], int(provider.embedDimParam), "query", "dense", time.Duration(provider.timeoutSec)) + resp, err := provider.client.Embedding(provider.modelName, texts[i:end], int(provider.embedDimParam), textType, provider.outputType, time.Duration(provider.timeoutSec)) if err != nil { return nil, err } diff --git a/internal/util/function/alitext_embedding_provider_test.go b/internal/util/function/alitext_embedding_provider_test.go index 100e42c31e918..f4a36bd2635a4 100644 --- a/internal/util/function/alitext_embedding_provider_test.go +++ b/internal/util/function/alitext_embedding_provider_test.go @@ -88,7 +88,7 @@ func (s *AliTextEmbeddingProviderSuite) TestEmbedding() { s.NoError(err) { data := []string{"sentence"} - ret, err2 := provder.CallEmbedding(data, false) + ret, err2 := provder.CallEmbedding(data, false, InsertMode) s.NoError(err2) s.Equal(1, len(ret)) s.Equal(4, len(ret[0])) @@ -96,7 +96,7 @@ func (s *AliTextEmbeddingProviderSuite) TestEmbedding() { } { data := []string{"sentence 1", "sentence 2", "sentence 3"} - ret, _ := provder.CallEmbedding(data, false) + ret, _ := provder.CallEmbedding(data, false, SearchMode) s.Equal([][]float32{{0.0, 0.1, 0.2, 0.3}, {1.0, 1.1, 1.2, 1.3}, {2.0, 2.1, 2.2, 2.3}}, ret) } @@ -130,7 +130,7 @@ func (s *AliTextEmbeddingProviderSuite) TestEmbeddingDimNotMatch() { // embedding dim not match data := []string{"sentence", "sentence"} - _, err2 := 
provder.CallEmbedding(data, false) + _, err2 := provder.CallEmbedding(data, false, InsertMode) s.Error(err2) } @@ -159,7 +159,7 @@ func (s *AliTextEmbeddingProviderSuite) TestEmbeddingNubmerNotMatch() { // embedding dim not match data := []string{"sentence", "sentence2"} - _, err2 := provder.CallEmbedding(data, false) + _, err2 := provder.CallEmbedding(data, false, InsertMode) s.Error(err2) } diff --git a/internal/util/function/bedrock_embedding_provider.go b/internal/util/function/bedrock_embedding_provider.go index 5a8367fe23e49..a9a6d56e95be9 100644 --- a/internal/util/function/bedrock_embedding_provider.go +++ b/internal/util/function/bedrock_embedding_provider.go @@ -53,17 +53,17 @@ type BedrockEmbeddingProvider struct { func createBedRockEmbeddingClient(awsAccessKeyId string, awsSecretAccessKey string, region string) (*bedrockruntime.Client, error) { if awsAccessKeyId == "" { - awsAccessKeyId = os.Getenv("BEDROCK_ACCESS_KEY_ID") + awsAccessKeyId = os.Getenv(bedrockAccessKeyId) } if awsAccessKeyId == "" { - return nil, fmt.Errorf("Missing credentials. Please pass `aws_access_key_id`, or configure the BEDROCK_ACCESS_KEY_ID environment variable in the Milvus service.") + return nil, fmt.Errorf("Missing credentials. Please pass `aws_access_key_id`, or configure the %s environment variable in the Milvus service.", bedrockAccessKeyId) } if awsSecretAccessKey == "" { - awsSecretAccessKey = os.Getenv("BEDROCK_SECRET_ACCESS_KEY") + awsSecretAccessKey = os.Getenv(bedrockSecretAccessKey) } if awsSecretAccessKey == "" { - return nil, fmt.Errorf("Missing credentials. Please pass `aws_secret_access_key`, or configure the BEDROCK_SECRET_ACCESS_KEY environment variable in the Milvus service.") + return nil, fmt.Errorf("Missing credentials. Please pass `aws_secret_access_key`, or configure the %s environment variable in the Milvus service.", bedrockSecretAccessKey) } if region == "" { return nil, fmt.Errorf("Missing region. 
Please pass `region` param.") @@ -154,7 +154,7 @@ func (provider *BedrockEmbeddingProvider) FieldDim() int64 { return 5 * provider.fieldDim } -func (provider *BedrockEmbeddingProvider) CallEmbedding(texts []string, batchLimit bool) ([][]float32, error) { +func (provider *BedrockEmbeddingProvider) CallEmbedding(texts []string, batchLimit bool, _ string) ([][]float32, error) { numRows := len(texts) if batchLimit && numRows > provider.MaxBatch() { return nil, fmt.Errorf("Bedrock text embedding supports up to [%d] pieces of data at a time, got [%d]", provider.MaxBatch(), numRows) diff --git a/internal/util/function/bedrock_text_embedding_provider_test.go b/internal/util/function/bedrock_text_embedding_provider_test.go index 8ba9b7763167f..9d74f7e2604cc 100644 --- a/internal/util/function/bedrock_text_embedding_provider_test.go +++ b/internal/util/function/bedrock_text_embedding_provider_test.go @@ -79,7 +79,7 @@ func (s *BedrockTextEmbeddingProviderSuite) TestEmbedding() { s.NoError(err) { data := []string{"sentence"} - ret, err2 := provder.CallEmbedding(data, false) + ret, err2 := provder.CallEmbedding(data, false, InsertMode) s.NoError(err2) s.Equal(1, len(ret)) s.Equal(4, len(ret[0])) @@ -87,7 +87,7 @@ func (s *BedrockTextEmbeddingProviderSuite) TestEmbedding() { } { data := []string{"sentence 1", "sentence 2", "sentence 3"} - ret, _ := provder.CallEmbedding(data, false) + ret, _ := provder.CallEmbedding(data, false, SearchMode) s.Equal([][]float32{{0.0, 0.1, 0.2, 0.3}, {0.0, 0.1, 0.2, 0.3}, {0.0, 0.1, 0.2, 0.3}}, ret) } @@ -101,7 +101,7 @@ func (s *BedrockTextEmbeddingProviderSuite) TestEmbeddingDimNotMatch() { // embedding dim not match data := []string{"sentence", "sentence"} - _, err2 := provder.CallEmbedding(data, false) + _, err2 := provder.CallEmbedding(data, false, InsertMode) s.Error(err2) } diff --git a/internal/util/function/common.go b/internal/util/function/common.go index 5063d895283ff..56da30e5ed42f 100644 --- a/internal/util/function/common.go +++ b/internal/util/function/common.go @@ -18,6 +18,11 @@ package function +const ( + InsertMode string = "Insert" + SearchMode string = "Search" +) + // common params const ( modelNameParamKey string = "model_name" @@ -30,7 +35,9 @@ const ( const ( TextEmbeddingV1 string = "text-embedding-v1" TextEmbeddingV2 string = "text-embedding-v2" - TextEmbeddingV3 string = "text-embedding-v1" + TextEmbeddingV3 string = "text-embedding-v3" + + dashscopeApiKey string = "MILVUS_DASHSCOPE_API_KEY" ) // openai/azure text embedding @@ -39,9 +46,12 @@ const ( TextEmbeddingAda002 string = "text-embedding-ada-002" TextEmbedding3Small string = "text-embedding-3-small" TextEmbedding3Large string = "text-embedding-3-large" -) -const ( + openaiApiKey string = "MILVUSAI_OPENAI_API_KEY" + + azureOpenaiApiKey string = "MILVUSAI_AZURE_OPENAI_API_KEY" + azureOpenaiEndpoint string = "MILVUSAI_AZURE_OPENAI_ENDPOINT" + userParamKey string = "user" ) @@ -53,4 +63,20 @@ const ( awsSecretAccessKeyParamKey string = "aws_secret_access_key" regionParamKey string = "regin" normalizeParamKey string = "normalize" + + bedrockAccessKeyId string = "MILVUSAI_BEDROCK_ACCESS_KEY_ID" + bedrockSecretAccessKey string = "MILVUSAI_BEDROCK_SECRET_ACCESS_KEY" +) + +// vertexAI + +const ( + locationParamKey string = "location" + projectIDParamKey string = "projectid" + taskTypeParamKey string = "task" + + textEmbedding005 string = "text-embedding-005" + textMultilingualEmbedding002 string = "text-multilingual-embedding-002" + + vertexServiceAccountJSONEnv string = 
"MILVUSAI_GOOGLE_APPLICATION_CREDENTIALS" ) diff --git a/internal/util/function/function_base.go b/internal/util/function/function_base.go index ac501a96dcc95..209b4788180dc 100644 --- a/internal/util/function/function_base.go +++ b/internal/util/function/function_base.go @@ -29,10 +29,10 @@ type FunctionBase struct { outputFields []*schemapb.FieldSchema } -func NewFunctionBase(coll *schemapb.CollectionSchema, schema *schemapb.FunctionSchema) (*FunctionBase, error) { +func NewFunctionBase(coll *schemapb.CollectionSchema, f_schema *schemapb.FunctionSchema) (*FunctionBase, error) { var base FunctionBase - base.schema = schema - for _, field_id := range schema.GetOutputFieldIds() { + base.schema = f_schema + for _, field_id := range f_schema.GetOutputFieldIds() { for _, field := range coll.GetFields() { if field.GetFieldID() == field_id { base.outputFields = append(base.outputFields, field) @@ -41,9 +41,9 @@ func NewFunctionBase(coll *schemapb.CollectionSchema, schema *schemapb.FunctionS } } - if len(base.outputFields) != len(schema.GetOutputFieldIds()) { + if len(base.outputFields) != len(f_schema.GetOutputFieldIds()) { return &base, fmt.Errorf("The collection [%s]'s information is wrong, function [%s]'s outputs does not match the schema", - coll.Name, schema.Name) + coll.Name, f_schema.Name) } return &base, nil } diff --git a/internal/util/function/function_executor.go b/internal/util/function/function_executor.go index 221826e0b7378..6f2469cca9173 100644 --- a/internal/util/function/function_executor.go +++ b/internal/util/function/function_executor.go @@ -61,6 +61,15 @@ func createFunction(coll *schemapb.CollectionSchema, schema *schemapb.FunctionSc } } +func CheckFunctions(schema *schemapb.CollectionSchema) error { + for _, f_schema := range schema.Functions { + if _, err := createFunction(schema, f_schema); err != nil { + return err + } + } + return nil +} + func NewFunctionExecutor(schema *schemapb.CollectionSchema) (*FunctionExecutor, error) { // If the function's outputs exists in outputIDs, then create the function // when outputIDs is empty, create all functions diff --git a/internal/util/function/mock_embedding_service.go b/internal/util/function/mock_embedding_service.go index f48315c7eff41..4cb181a7a0c4f 100644 --- a/internal/util/function/mock_embedding_service.go +++ b/internal/util/function/mock_embedding_service.go @@ -28,6 +28,7 @@ import ( "github.com/aws/aws-sdk-go-v2/service/bedrockruntime" "github.com/milvus-io/milvus/internal/models/ali" "github.com/milvus-io/milvus/internal/models/openai" + "github.com/milvus-io/milvus/internal/models/vertexai" ) func mockEmbedding(texts []string, dim int) [][]float32 { @@ -94,6 +95,40 @@ func CreateAliEmbeddingServer() *httptest.Server { w.WriteHeader(http.StatusOK) data, _ := json.Marshal(res) w.Write(data) + })) + return ts +} + +func CreateVertexAIEmbeddingServer() *httptest.Server { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var req vertexai.EmbeddingRequest + body, _ := io.ReadAll(r.Body) + defer r.Body.Close() + json.Unmarshal(body, &req) + var texts []string + for _, item := range req.Instances { + texts = append(texts, item.Content) + } + embs := mockEmbedding(texts, int(req.Parameters.OutputDimensionality)) + var res vertexai.EmbeddingResponse + for i := 0; i < len(req.Instances); i++ { + res.Predictions = append(res.Predictions, vertexai.Prediction{ + Embeddings: vertexai.Embeddings{ + Statistics: vertexai.Statistics{ + Truncated: false, + TokenCount: 10, + }, + Values: 
embs[i], + }, + }) + } + + res.Metadata = vertexai.Metadata{ + BillableCharacterCount: 100, + } + w.WriteHeader(http.StatusOK) + data, _ := json.Marshal(res) + w.Write(data) })) return ts diff --git a/internal/util/function/openai_embedding_provider.go b/internal/util/function/openai_embedding_provider.go index f9fa78f64c8b6..32cfb945509f7 100644 --- a/internal/util/function/openai_embedding_provider.go +++ b/internal/util/function/openai_embedding_provider.go @@ -44,18 +44,15 @@ type OpenAIEmbeddingProvider struct { func createOpenAIEmbeddingClient(apiKey string, url string) (*openai.OpenAIEmbeddingClient, error) { if apiKey == "" { - apiKey = os.Getenv("OPENAI_API_KEY") + apiKey = os.Getenv(openaiApiKey) } if apiKey == "" { - return nil, fmt.Errorf("Missing credentials. Please pass `api_key`, or configure the OPENAI_API_KEY environment variable in the Milvus service.") + return nil, fmt.Errorf("Missing credentials. Please pass `api_key`, or configure the %s environment variable in the Milvus service.", openaiApiKey) } if url == "" { url = "https://api.openai.com/v1/embeddings" } - if url == "" { - return nil, fmt.Errorf("Must provide `url` arguments or configure the OPENAI_ENDPOINT environment variable in the Milvus service") - } c := openai.NewOpenAIEmbeddingClient(apiKey, url) return c, nil @@ -63,17 +60,17 @@ func createOpenAIEmbeddingClient(apiKey string, url string) (*openai.OpenAIEmbed func createAzureOpenAIEmbeddingClient(apiKey string, url string) (*openai.AzureOpenAIEmbeddingClient, error) { if apiKey == "" { - apiKey = os.Getenv("AZURE_OPENAI_API_KEY") + apiKey = os.Getenv(azureOpenaiApiKey) } if apiKey == "" { - return nil, fmt.Errorf("Missing credentials. Please pass `api_key`, or configure the AZURE_OPENAI_API_KEY environment variable in the Milvus service") + return nil, fmt.Errorf("Missing credentials. 
Please pass `api_key`, or configure the %s environment variable in the Milvus service", azureOpenaiApiKey) } if url == "" { - url = os.Getenv("AZURE_OPENAI_ENDPOINT") + url = os.Getenv(azureOpenaiEndpoint) } if url == "" { - return nil, fmt.Errorf("Must provide `url` arguments or configure the AZURE_OPENAI_ENDPOINT environment variable in the Milvus service") + return nil, fmt.Errorf("Must provide `url` arguments or configure the %s environment variable in the Milvus service", azureOpenaiEndpoint) } c := openai.NewAzureOpenAIEmbeddingClient(apiKey, url) return c, nil @@ -156,7 +153,7 @@ func (provider *OpenAIEmbeddingProvider) FieldDim() int64 { return provider.fieldDim } -func (provider *OpenAIEmbeddingProvider) CallEmbedding(texts []string, batchLimit bool) ([][]float32, error) { +func (provider *OpenAIEmbeddingProvider) CallEmbedding(texts []string, batchLimit bool, _ string) ([][]float32, error) { numRows := len(texts) if batchLimit && numRows > provider.MaxBatch() { return nil, fmt.Errorf("OpenAI embedding supports up to [%d] pieces of data at a time, got [%d]", provider.MaxBatch(), numRows) diff --git a/internal/util/function/openai_text_embedding_provider_test.go b/internal/util/function/openai_text_embedding_provider_test.go index 7681161f9c3e5..7c3667822956f 100644 --- a/internal/util/function/openai_text_embedding_provider_test.go +++ b/internal/util/function/openai_text_embedding_provider_test.go @@ -90,7 +90,7 @@ func (s *OpenAITextEmbeddingProviderSuite) TestEmbedding() { s.NoError(err) { data := []string{"sentence"} - ret, err2 := provder.CallEmbedding(data, false) + ret, err2 := provder.CallEmbedding(data, false, InsertMode) s.NoError(err2) s.Equal(1, len(ret)) s.Equal(4, len(ret[0])) @@ -98,7 +98,7 @@ func (s *OpenAITextEmbeddingProviderSuite) TestEmbedding() { } { data := []string{"sentence 1", "sentence 2", "sentence 3"} - ret, _ := provder.CallEmbedding(data, false) + ret, _ := provder.CallEmbedding(data, false, SearchMode) s.Equal([][]float32{{0.0, 0.1, 0.2, 0.3}, {1.0, 1.1, 1.2, 1.3}, {2.0, 2.1, 2.2, 2.3}}, ret) } @@ -137,7 +137,7 @@ func (s *OpenAITextEmbeddingProviderSuite) TestEmbeddingDimNotMatch() { // embedding dim not match data := []string{"sentence", "sentence"} - _, err2 := provder.CallEmbedding(data, false) + _, err2 := provder.CallEmbedding(data, false, InsertMode) s.Error(err2) } @@ -170,7 +170,7 @@ func (s *OpenAITextEmbeddingProviderSuite) TestEmbeddingNubmerNotMatch() { // embedding dim not match data := []string{"sentence", "sentence2"} - _, err2 := provder.CallEmbedding(data, false) + _, err2 := provder.CallEmbedding(data, false, InsertMode) s.Error(err2) } diff --git a/internal/util/function/text_embedding_function.go b/internal/util/function/text_embedding_function.go index 274ddde0b8ba9..8ddf4893dfc44 100644 --- a/internal/util/function/text_embedding_function.go +++ b/internal/util/function/text_embedding_function.go @@ -38,11 +38,13 @@ const ( AzureOpenAIProvider string = "azure_openai" AliDashScopeProvider string = "dashscope" BedrockProvider string = "bedrock" + VertexAIProvider string = "vertexai" ) +// Text embedding for retrieval task type TextEmbeddingProvider interface { MaxBatch() int - CallEmbedding(texts []string, batchLimit bool) ([][]float32, error) + CallEmbedding(texts []string, batchLimit bool, mode string) ([][]float32, error) FieldDim() int64 } @@ -120,6 +122,15 @@ func NewTextEmbeddingFunction(coll *schemapb.CollectionSchema, functionSchema *s FunctionBase: *base, embProvider: embP, }, nil + case VertexAIProvider: + embP, err := 
NewVertextAIEmbeddingProvider(base.outputFields[0], functionSchema, nil) + if err != nil { + return nil, err + } + return &TextEmebddingFunction{ + FunctionBase: *base, + embProvider: embP, + }, nil default: return nil, fmt.Errorf("Unsupported embedding service provider: [%s] , list of supported [%s, %s, %s, %s]", provider, OpenAIProvider, AzureOpenAIProvider, AliDashScopeProvider, BedrockProvider) } @@ -144,7 +155,7 @@ func (runner *TextEmebddingFunction) ProcessInsert(inputs []*schemapb.FieldData) return nil, fmt.Errorf("Input texts is empty") } - embds, err := runner.embProvider.CallEmbedding(texts, true) + embds, err := runner.embProvider.CallEmbedding(texts, true, InsertMode) if err != nil { return nil, err } @@ -173,7 +184,7 @@ func (runner *TextEmebddingFunction) ProcessInsert(inputs []*schemapb.FieldData) func (runner *TextEmebddingFunction) ProcessSearch(placeholderGroup *commonpb.PlaceholderGroup) (*commonpb.PlaceholderGroup, error) { texts := funcutil.GetVarCharFromPlaceholder(placeholderGroup.Placeholders[0]) // Already checked externally - embds, err := runner.embProvider.CallEmbedding(texts, true) + embds, err := runner.embProvider.CallEmbedding(texts, true, SearchMode) if err != nil { return nil, err } @@ -194,7 +205,7 @@ func (runner *TextEmebddingFunction) ProcessBulkInsert(inputs []storage.FieldDat return nil, fmt.Errorf("Input texts is empty") } - embds, err := runner.embProvider.CallEmbedding(texts, false) + embds, err := runner.embProvider.CallEmbedding(texts, false, InsertMode) if err != nil { return nil, err } diff --git a/internal/util/function/vertexai_embedding_provider.go b/internal/util/function/vertexai_embedding_provider.go new file mode 100644 index 0000000000000..1d9c997571dcf --- /dev/null +++ b/internal/util/function/vertexai_embedding_provider.go @@ -0,0 +1,221 @@ +/* + * # Licensed to the LF AI & Data foundation under one + * # or more contributor license agreements. See the NOTICE file + * # distributed with this work for additional information + * # regarding copyright ownership. The ASF licenses this file + * # to you under the Apache License, Version 2.0 (the + * # "License"); you may not use this file except in compliance + * # with the License. You may obtain a copy of the License at + * # + * # http://www.apache.org/licenses/LICENSE-2.0 + * # + * # Unless required by applicable law or agreed to in writing, software + * # distributed under the License is distributed on an "AS IS" BASIS, + * # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * # See the License for the specific language governing permissions and + * # limitations under the License. 
+ */ + +package function + +import ( + "fmt" + "os" + "strconv" + "strings" + "sync" + "time" + + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/internal/models/vertexai" + "github.com/milvus-io/milvus/pkg/util/typeutil" +) + +type vertexAIJsonKey struct { + jsonKey []byte + once sync.Once + initErr error +} + +var vtxKey vertexAIJsonKey + +func getVertexAIJsonKey() ([]byte, error) { + vtxKey.once.Do(func() { + jsonKeyPath := os.Getenv(vertexServiceAccountJSONEnv) + jsonKey, err := os.ReadFile(jsonKeyPath) + if err != nil { + vtxKey.initErr = fmt.Errorf("Read service account json file failed, %v", err) + return + } + vtxKey.jsonKey = jsonKey + }) + return vtxKey.jsonKey, vtxKey.initErr +} + +const ( + vertexAIDocRetrival string = "DOC_RETRIEVAL" + vertexAICodeRetrival string = "CODE_RETRIEVAL" + vertexAISTS string = "STS" +) + +func checkTask(modelName string, task string) error { + if task != vertexAIDocRetrival && task != vertexAICodeRetrival && task != vertexAISTS { + return fmt.Errorf("Unsupport task %s, the supported list: [%s, %s, %s]", task, vertexAIDocRetrival, vertexAICodeRetrival, vertexAISTS) + } + if modelName == textMultilingualEmbedding002 && task == vertexAICodeRetrival { + return fmt.Errorf("Model %s doesn't support %s task", textMultilingualEmbedding002, vertexAICodeRetrival) + } + return nil +} + +type VertextAIEmbeddingProvider struct { + fieldDim int64 + + client *vertexai.VertexAIEmbedding + modelName string + embedDimParam int64 + task string + + maxBatch int + timeoutSec int +} + +func createVertextAIEmbeddingClient(url string) (*vertexai.VertexAIEmbedding, error) { + jsonKey, err := getVertexAIJsonKey() + if err != nil { + return nil, err + } + c := vertexai.NewVertexAIEmbedding(url, jsonKey, "https://www.googleapis.com/auth/cloud-platform", "") + return c, nil +} + +func NewVertextAIEmbeddingProvider(fieldSchema *schemapb.FieldSchema, functionSchema *schemapb.FunctionSchema, c *vertexai.VertexAIEmbedding) (*VertextAIEmbeddingProvider, error) { + fieldDim, err := typeutil.GetDim(fieldSchema) + if err != nil { + return nil, err + } + var location, projectID, task, modelName string + var dim int64 + + for _, param := range functionSchema.Params { + switch strings.ToLower(param.Key) { + case modelNameParamKey: + modelName = param.Value + case dimParamKey: + dim, err = strconv.ParseInt(param.Value, 10, 64) + if err != nil { + return nil, fmt.Errorf("dim [%s] is not int", param.Value) + } + + if dim != 0 && dim != fieldDim { + return nil, fmt.Errorf("Field %s's dim is [%d], but embeding's dim is [%d]", functionSchema.Name, fieldDim, dim) + } + case locationParamKey: + location = param.Value + case projectIDParamKey: + projectID = param.Value + case taskTypeParamKey: + task = param.Value + default: + } + } + + if task == "" { + task = vertexAIDocRetrival + } + if err := checkTask(modelName, task); err != nil { + return nil, err + } + + if location == "" { + location = "us-central1" + } + + if modelName != textEmbedding005 && modelName != textMultilingualEmbedding002 { + return nil, fmt.Errorf("Unsupported model: %s, only support [%s, %s]", + modelName, textEmbedding005, textMultilingualEmbedding002) + } + url := fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:predict", location, projectID, location, modelName) + var client *vertexai.VertexAIEmbedding + if c == nil { + client, err = createVertextAIEmbeddingClient(url) + if err != nil { + return nil, err + } + } else { + client 
= c + } + + provider := VertextAIEmbeddingProvider{ + fieldDim: fieldDim, + client: client, + modelName: modelName, + embedDimParam: dim, + task: task, + maxBatch: 128, + timeoutSec: 30, + } + return &provider, nil +} + +func (provider *VertextAIEmbeddingProvider) MaxBatch() int { + return 5 * provider.maxBatch +} + +func (provider *VertextAIEmbeddingProvider) FieldDim() int64 { + return provider.fieldDim +} + +func (provider *VertextAIEmbeddingProvider) getTaskType(mode string) string { + if mode == SearchMode { + switch provider.task { + case vertexAIDocRetrival: + return "RETRIEVAL_QUERY" + case vertexAICodeRetrival: + return "CODE_RETRIEVAL_QUERY" + case vertexAISTS: + return "SEMANTIC_SIMILARITY" + } + } else { + switch provider.task { + case vertexAIDocRetrival: + return "RETRIEVAL_DOCUMENT" + case vertexAICodeRetrival: + return "RETRIEVAL_DOCUMENT" + case vertexAISTS: + return "SEMANTIC_SIMILARITY" + } + } + return "" +} + +func (provider *VertextAIEmbeddingProvider) CallEmbedding(texts []string, batchLimit bool, mode string) ([][]float32, error) { + numRows := len(texts) + if batchLimit && numRows > provider.MaxBatch() { + return nil, fmt.Errorf("VertextAI text embedding supports up to [%d] pieces of data at a time, got [%d]", provider.MaxBatch(), numRows) + } + + taskType := provider.getTaskType(mode) + data := make([][]float32, 0, numRows) + for i := 0; i < numRows; i += provider.maxBatch { + end := i + provider.maxBatch + if end > numRows { + end = numRows + } + resp, err := provider.client.Embedding(provider.modelName, texts[i:end], provider.embedDimParam, taskType, time.Duration(provider.timeoutSec)) + if err != nil { + return nil, err + } + if end-i != len(resp.Predictions) { + return nil, fmt.Errorf("Get embedding failed. The number of texts and embeddings does not match text:[%d], embedding:[%d]", end-i, len(resp.Predictions)) + } + for _, item := range resp.Predictions { + if len(item.Embeddings.Values) != int(provider.fieldDim) { + return nil, fmt.Errorf("The required embedding dim is [%d], but the embedding obtained from the model is [%d]", + provider.fieldDim, len(item.Embeddings.Values)) + } + data = append(data, item.Embeddings.Values) + } + } + return data, nil +} diff --git a/internal/util/function/vertexai_embedding_provider_test.go b/internal/util/function/vertexai_embedding_provider_test.go new file mode 100644 index 0000000000000..321a531a1ac0e --- /dev/null +++ b/internal/util/function/vertexai_embedding_provider_test.go @@ -0,0 +1,170 @@ +/* + * # Licensed to the LF AI & Data foundation under one + * # or more contributor license agreements. See the NOTICE file + * # distributed with this work for additional information + * # regarding copyright ownership. The ASF licenses this file + * # to you under the Apache License, Version 2.0 (the + * # "License"); you may not use this file except in compliance + * # with the License. You may obtain a copy of the License at + * # + * # http://www.apache.org/licenses/LICENSE-2.0 + * # + * # Unless required by applicable law or agreed to in writing, software + * # distributed under the License is distributed on an "AS IS" BASIS, + * # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * # See the License for the specific language governing permissions and + * # limitations under the License. 
+ */ + +package function + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/suite" + + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/internal/models/vertexai" +) + +func TestVertextAITextEmbeddingProvider(t *testing.T) { + suite.Run(t, new(VertextAITextEmbeddingProviderSuite)) +} + +type VertextAITextEmbeddingProviderSuite struct { + suite.Suite + schema *schemapb.CollectionSchema + providers []string +} + +func (s *VertextAITextEmbeddingProviderSuite) SetupTest() { + s.schema = &schemapb.CollectionSchema{ + Name: "test", + Fields: []*schemapb.FieldSchema{ + {FieldID: 100, Name: "int64", DataType: schemapb.DataType_Int64}, + {FieldID: 101, Name: "text", DataType: schemapb.DataType_VarChar}, + {FieldID: 102, Name: "vector", DataType: schemapb.DataType_FloatVector, + TypeParams: []*commonpb.KeyValuePair{ + {Key: "dim", Value: "4"}, + }}, + }, + } +} + +func createVertextAIProvider(url string, schema *schemapb.FieldSchema) (TextEmbeddingProvider, error) { + functionSchema := &schemapb.FunctionSchema{ + Name: "test", + Type: schemapb.FunctionType_Unknown, + InputFieldIds: []int64{101}, + OutputFieldIds: []int64{102}, + Params: []*commonpb.KeyValuePair{ + {Key: modelNameParamKey, Value: textEmbedding005}, + {Key: locationParamKey, Value: "mock_local"}, + {Key: projectIDParamKey, Value: "mock_id"}, + {Key: taskTypeParamKey, Value: vertexAICodeRetrival}, + {Key: embeddingUrlParamKey, Value: url}, + {Key: dimParamKey, Value: "4"}, + }, + } + mockClient := vertexai.NewVertexAIEmbedding(url, []byte{1, 2, 3}, "mock scope", "mock token") + return NewVertextAIEmbeddingProvider(schema, functionSchema, mockClient) +} + +func (s *VertextAITextEmbeddingProviderSuite) TestEmbedding() { + ts := CreateVertexAIEmbeddingServer() + + defer ts.Close() + provder, err := createVertextAIProvider(ts.URL, s.schema.Fields[2]) + s.NoError(err) + { + data := []string{"sentence"} + ret, err2 := provder.CallEmbedding(data, false, InsertMode) + s.NoError(err2) + s.Equal(1, len(ret)) + s.Equal(4, len(ret[0])) + s.Equal([]float32{0.0, 0.1, 0.2, 0.3}, ret[0]) + } + { + data := []string{"sentence 1", "sentence 2", "sentence 3"} + ret, _ := provder.CallEmbedding(data, false, SearchMode) + s.Equal([][]float32{{0.0, 0.1, 0.2, 0.3}, {1.0, 1.1, 1.2, 1.3}, {2.0, 2.1, 2.2, 2.3}}, ret) + } + +} + +func (s *VertextAITextEmbeddingProviderSuite) TestEmbeddingDimNotMatch() { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var res vertexai.EmbeddingResponse + res.Predictions = append(res.Predictions, vertexai.Prediction{ + Embeddings: vertexai.Embeddings{ + Statistics: vertexai.Statistics{ + Truncated: false, + TokenCount: 10, + }, + Values: []float32{1.0, 1.0, 1.0, 1.0}, + }, + }) + res.Predictions = append(res.Predictions, vertexai.Prediction{ + Embeddings: vertexai.Embeddings{ + Statistics: vertexai.Statistics{ + Truncated: false, + TokenCount: 10, + }, + Values: []float32{1.0, 1.0}, + }, + }) + + res.Metadata = vertexai.Metadata{ + BillableCharacterCount: 100, + } + + w.WriteHeader(http.StatusOK) + data, _ := json.Marshal(res) + w.Write(data) + })) + + defer ts.Close() + provder, err := createVertextAIProvider(ts.URL, s.schema.Fields[2]) + s.NoError(err) + + // embedding dim not match + data := []string{"sentence", "sentence"} + _, err2 := provder.CallEmbedding(data, false, InsertMode) + s.Error(err2) +} + +func (s 
*VertextAITextEmbeddingProviderSuite) TestEmbeddingNubmerNotMatch() { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var res vertexai.EmbeddingResponse + res.Predictions = append(res.Predictions, vertexai.Prediction{ + Embeddings: vertexai.Embeddings{ + Statistics: vertexai.Statistics{ + Truncated: false, + TokenCount: 10, + }, + Values: []float32{1.0, 1.0, 1.0, 1.0}, + }, + }) + res.Metadata = vertexai.Metadata{ + BillableCharacterCount: 100, + } + + w.WriteHeader(http.StatusOK) + data, _ := json.Marshal(res) + w.Write(data) + })) + + defer ts.Close() + provder, err := createVertextAIProvider(ts.URL, s.schema.Fields[2]) + + s.NoError(err) + + // embedding dim not match + data := []string{"sentence", "sentence2"} + _, err2 := provder.CallEmbedding(data, false, InsertMode) + s.Error(err2) +} From 5c0f58b030223e7705ff7d83fd3c8d25919ba6d8 Mon Sep 17 00:00:00 2001 From: "junjie.jiang" Date: Wed, 18 Dec 2024 16:32:42 +0800 Subject: [PATCH 15/18] Add function check Signed-off-by: junjie.jiang --- internal/datanode/importv2/scheduler_test.go | 10 ++- internal/proxy/task_insert_test.go | 22 +++-- internal/proxy/task_search_test.go | 24 +++--- internal/proxy/task_test.go | 2 - internal/proxy/task_upsert_test.go | 11 +-- .../alitext_embedding_provider_test.go | 10 ++- .../function/bedrock_embedding_provider.go | 9 +- .../bedrock_text_embedding_provider_test.go | 10 ++- internal/util/function/function_base.go | 6 +- .../util/function/function_executor_test.go | 22 ++--- .../openai_text_embedding_provider_test.go | 10 ++- .../util/function/text_embedding_function.go | 6 +- .../function/text_embedding_function_test.go | 82 +++++++++++-------- .../vertexai_embedding_provider_test.go | 10 ++- 14 files changed, 138 insertions(+), 96 deletions(-) diff --git a/internal/datanode/importv2/scheduler_test.go b/internal/datanode/importv2/scheduler_test.go index 99f8e5d3593ee..a7c56a4b2561b 100644 --- a/internal/datanode/importv2/scheduler_test.go +++ b/internal/datanode/importv2/scheduler_test.go @@ -475,10 +475,12 @@ func (s *SchedulerSuite) TestScheduler_ImportFileWithFunction() { }, Functions: []*schemapb.FunctionSchema{ { - Name: "test", - Type: schemapb.FunctionType_TextEmbedding, - InputFieldIds: []int64{100}, - OutputFieldIds: []int64{101}, + Name: "test", + Type: schemapb.FunctionType_TextEmbedding, + InputFieldIds: []int64{100}, + InputFieldNames: []string{"text"}, + OutputFieldIds: []int64{101}, + OutputFieldNames: []string{"vec"}, Params: []*commonpb.KeyValuePair{ {Key: function.Provider, Value: function.OpenAIProvider}, {Key: "model_name", Value: "text-embedding-ada-002"}, diff --git a/internal/proxy/task_insert_test.go b/internal/proxy/task_insert_test.go index 2946af3c4bfe2..2586ddf37afca 100644 --- a/internal/proxy/task_insert_test.go +++ b/internal/proxy/task_insert_test.go @@ -351,11 +351,12 @@ func TestInsertTask_Function(t *testing.T) { }, Functions: []*schemapb.FunctionSchema{ { - Name: "test_function", - Type: schemapb.FunctionType_TextEmbedding, - InputFieldIds: []int64{101}, - InputFieldNames: []string{"text"}, - OutputFieldIds: []int64{102}, + Name: "test_function", + Type: schemapb.FunctionType_TextEmbedding, + InputFieldIds: []int64{101}, + InputFieldNames: []string{"text"}, + OutputFieldIds: []int64{102}, + OutputFieldNames: []string{"vector"}, Params: []*commonpb.KeyValuePair{ {Key: function.Provider, Value: function.OpenAIProvider}, {Key: "model_name", Value: "text-embedding-ada-002"}, @@ -416,6 +417,17 @@ func TestInsertTask_Function(t 
*testing.T) { createdTimestamp: 10001, createdUtcTimestamp: 10002, }, nil) + cache.On("GetCollectionInfo", + mock.Anything, + mock.Anything, + mock.Anything, + mock.Anything, + ).Return(&collectionInfo{schema: info}, nil) + cache.On("GetDatabaseInfo", + mock.Anything, + mock.Anything, + ).Return(&databaseInfo{properties: []*commonpb.KeyValuePair{}}, nil) + globalMetaCache = cache err = task.PreExecute(ctx) assert.NoError(t, err) diff --git a/internal/proxy/task_search_test.go b/internal/proxy/task_search_test.go index 377c0d27d70be..21b69d5257ab7 100644 --- a/internal/proxy/task_search_test.go +++ b/internal/proxy/task_search_test.go @@ -483,10 +483,10 @@ func TestSearchTask_PreExecute(t *testing.T) { func TestSearchTask_WithFunctions(t *testing.T) { ts := function.CreateOpenAIEmbeddingServer() defer ts.Close() - collectionName := "TestInsertTask_function" + collectionName := "TestSearchTask_function" schema := &schemapb.CollectionSchema{ Name: collectionName, - Description: "TestInsertTask_function", + Description: "TestSearchTask_function", AutoID: true, Fields: []*schemapb.FieldSchema{ {FieldID: 100, Name: "id", DataType: schemapb.DataType_Int64, IsPrimaryKey: true, AutoID: true}, @@ -505,10 +505,12 @@ func TestSearchTask_WithFunctions(t *testing.T) { }, Functions: []*schemapb.FunctionSchema{ { - Name: "func1", - Type: schemapb.FunctionType_TextEmbedding, - InputFieldIds: []int64{101}, - OutputFieldIds: []int64{102}, + Name: "func1", + Type: schemapb.FunctionType_TextEmbedding, + InputFieldIds: []int64{101}, + InputFieldNames: []string{"text"}, + OutputFieldIds: []int64{102}, + OutputFieldNames: []string{"vector1"}, Params: []*commonpb.KeyValuePair{ {Key: function.Provider, Value: function.OpenAIProvider}, {Key: "model_name", Value: "text-embedding-ada-002"}, @@ -518,10 +520,12 @@ func TestSearchTask_WithFunctions(t *testing.T) { }, }, { - Name: "func2", - Type: schemapb.FunctionType_TextEmbedding, - InputFieldIds: []int64{101}, - OutputFieldIds: []int64{103}, + Name: "func2", + Type: schemapb.FunctionType_TextEmbedding, + InputFieldIds: []int64{101}, + InputFieldNames: []string{"text"}, + OutputFieldIds: []int64{103}, + OutputFieldNames: []string{"vector2"}, Params: []*commonpb.KeyValuePair{ {Key: function.Provider, Value: function.OpenAIProvider}, {Key: "model_name", Value: "text-embedding-ada-002"}, diff --git a/internal/proxy/task_test.go b/internal/proxy/task_test.go index 7982f36731dbe..d1fd055556dea 100644 --- a/internal/proxy/task_test.go +++ b/internal/proxy/task_test.go @@ -1032,8 +1032,6 @@ func TestCreateCollectionTask(t *testing.T) { Type: schemapb.FunctionType_TextEmbedding, InputFieldNames: []string{varCharField}, OutputFieldNames: []string{floatVecField}, - InputFieldIds: []int64{101}, - OutputFieldIds: []int64{102}, Params: []*commonpb.KeyValuePair{ {Key: "provider", Value: "openai"}, {Key: "model_name", Value: "text-embedding-ada-002"}, diff --git a/internal/proxy/task_upsert_test.go b/internal/proxy/task_upsert_test.go index 10ac7d6a70487..348ee64313d31 100644 --- a/internal/proxy/task_upsert_test.go +++ b/internal/proxy/task_upsert_test.go @@ -420,11 +420,12 @@ func TestUpsertTask_Function(t *testing.T) { }, Functions: []*schemapb.FunctionSchema{ { - Name: "test_function", - Type: schemapb.FunctionType_TextEmbedding, - InputFieldIds: []int64{101}, - InputFieldNames: []string{"text"}, - OutputFieldIds: []int64{102}, + Name: "test_function", + Type: schemapb.FunctionType_TextEmbedding, + InputFieldIds: []int64{101}, + InputFieldNames: []string{"text"}, + 
OutputFieldIds: []int64{102}, + OutputFieldNames: []string{"vector"}, Params: []*commonpb.KeyValuePair{ {Key: function.Provider, Value: function.OpenAIProvider}, {Key: "model_name", Value: "text-embedding-ada-002"}, diff --git a/internal/util/function/alitext_embedding_provider_test.go b/internal/util/function/alitext_embedding_provider_test.go index f4a36bd2635a4..73d8613f20a4d 100644 --- a/internal/util/function/alitext_embedding_provider_test.go +++ b/internal/util/function/alitext_embedding_provider_test.go @@ -60,10 +60,12 @@ func (s *AliTextEmbeddingProviderSuite) SetupTest() { func createAliProvider(url string, schema *schemapb.FieldSchema, providerName string) (TextEmbeddingProvider, error) { functionSchema := &schemapb.FunctionSchema{ - Name: "test", - Type: schemapb.FunctionType_Unknown, - InputFieldIds: []int64{101}, - OutputFieldIds: []int64{102}, + Name: "test", + Type: schemapb.FunctionType_Unknown, + InputFieldNames: []string{"text"}, + OutputFieldNames: []string{"vector"}, + InputFieldIds: []int64{101}, + OutputFieldIds: []int64{102}, Params: []*commonpb.KeyValuePair{ {Key: modelNameParamKey, Value: TextEmbeddingV3}, {Key: apiKeyParamKey, Value: "mock"}, diff --git a/internal/util/function/bedrock_embedding_provider.go b/internal/util/function/bedrock_embedding_provider.go index a9a6d56e95be9..5e64ca4f6f2b7 100644 --- a/internal/util/function/bedrock_embedding_provider.go +++ b/internal/util/function/bedrock_embedding_provider.go @@ -87,7 +87,7 @@ func NewBedrockEmbeddingProvider(fieldSchema *schemapb.FieldSchema, functionSche } var awsAccessKeyId, awsSecretAccessKey, region, modelName string var dim int64 - var normalize bool + normalize := false for _, param := range functionSchema.Params { switch strings.ToLower(param.Key) { @@ -112,8 +112,6 @@ func NewBedrockEmbeddingProvider(fieldSchema *schemapb.FieldSchema, functionSche switch strings.ToLower(param.Value) { case "true": normalize = true - case "false": - normalize = false default: return nil, fmt.Errorf("Illegal [%s:%s] param, ", normalizeParamKey, param.Value) } @@ -147,11 +145,11 @@ func NewBedrockEmbeddingProvider(fieldSchema *schemapb.FieldSchema, functionSche } func (provider *BedrockEmbeddingProvider) MaxBatch() int { - return 5 * provider.maxBatch + return 12 * provider.maxBatch } func (provider *BedrockEmbeddingProvider) FieldDim() int64 { - return 5 * provider.fieldDim + return provider.fieldDim } func (provider *BedrockEmbeddingProvider) CallEmbedding(texts []string, batchLimit bool, _ string) ([][]float32, error) { @@ -164,6 +162,7 @@ func (provider *BedrockEmbeddingProvider) CallEmbedding(texts []string, batchLim for i := 0; i < numRows; i += 1 { payload := BedRockRequest{ InputText: texts[i], + Normalize: provider.normalize, } if provider.embedDimParam != 0 { payload.Dimensions = provider.embedDimParam diff --git a/internal/util/function/bedrock_text_embedding_provider_test.go b/internal/util/function/bedrock_text_embedding_provider_test.go index 9d74f7e2604cc..eb26aa03ef3f7 100644 --- a/internal/util/function/bedrock_text_embedding_provider_test.go +++ b/internal/util/function/bedrock_text_embedding_provider_test.go @@ -55,10 +55,12 @@ func (s *BedrockTextEmbeddingProviderSuite) SetupTest() { func createBedrockProvider(schema *schemapb.FieldSchema, providerName string, dim int) (TextEmbeddingProvider, error) { functionSchema := &schemapb.FunctionSchema{ - Name: "test", - Type: schemapb.FunctionType_Unknown, - InputFieldIds: []int64{101}, - OutputFieldIds: []int64{102}, + Name: "test", + Type: 
schemapb.FunctionType_Unknown, + InputFieldNames: []string{"text"}, + OutputFieldNames: []string{"vector"}, + InputFieldIds: []int64{101}, + OutputFieldIds: []int64{102}, Params: []*commonpb.KeyValuePair{ {Key: modelNameParamKey, Value: BedRockTitanTextEmbeddingsV2}, {Key: apiKeyParamKey, Value: "mock"}, diff --git a/internal/util/function/function_base.go b/internal/util/function/function_base.go index 209b4788180dc..fa3253bf3c7bd 100644 --- a/internal/util/function/function_base.go +++ b/internal/util/function/function_base.go @@ -32,16 +32,16 @@ type FunctionBase struct { func NewFunctionBase(coll *schemapb.CollectionSchema, f_schema *schemapb.FunctionSchema) (*FunctionBase, error) { var base FunctionBase base.schema = f_schema - for _, field_id := range f_schema.GetOutputFieldIds() { + for _, fieldName := range f_schema.GetOutputFieldNames() { for _, field := range coll.GetFields() { - if field.GetFieldID() == field_id { + if field.GetName() == fieldName { base.outputFields = append(base.outputFields, field) break } } } - if len(base.outputFields) != len(f_schema.GetOutputFieldIds()) { + if len(base.outputFields) != len(f_schema.GetOutputFieldNames()) { return &base, fmt.Errorf("The collection [%s]'s information is wrong, function [%s]'s outputs does not match the schema", coll.Name, f_schema.Name) } diff --git a/internal/util/function/function_executor_test.go b/internal/util/function/function_executor_test.go index 5a4cd3aae0fcc..a791760351fb7 100644 --- a/internal/util/function/function_executor_test.go +++ b/internal/util/function/function_executor_test.go @@ -62,11 +62,12 @@ func (s *FunctionExecutorSuite) creataSchema(url string) *schemapb.CollectionSch }, Functions: []*schemapb.FunctionSchema{ { - Name: "test", - Type: schemapb.FunctionType_TextEmbedding, - InputFieldIds: []int64{101}, - InputFieldNames: []string{"text"}, - OutputFieldIds: []int64{102}, + Name: "test", + Type: schemapb.FunctionType_TextEmbedding, + InputFieldIds: []int64{101}, + InputFieldNames: []string{"text"}, + OutputFieldIds: []int64{102}, + OutputFieldNames: []string{"vector"}, Params: []*commonpb.KeyValuePair{ {Key: Provider, Value: OpenAIProvider}, {Key: modelNameParamKey, Value: "text-embedding-ada-002"}, @@ -76,11 +77,12 @@ func (s *FunctionExecutorSuite) creataSchema(url string) *schemapb.CollectionSch }, }, { - Name: "test", - Type: schemapb.FunctionType_TextEmbedding, - InputFieldIds: []int64{101}, - InputFieldNames: []string{"text"}, - OutputFieldIds: []int64{103}, + Name: "test", + Type: schemapb.FunctionType_TextEmbedding, + InputFieldIds: []int64{101}, + InputFieldNames: []string{"text"}, + OutputFieldIds: []int64{103}, + OutputFieldNames: []string{"vector2"}, Params: []*commonpb.KeyValuePair{ {Key: Provider, Value: OpenAIProvider}, {Key: modelNameParamKey, Value: "text-embedding-ada-002"}, diff --git a/internal/util/function/openai_text_embedding_provider_test.go b/internal/util/function/openai_text_embedding_provider_test.go index 7c3667822956f..395ecf06cdc9d 100644 --- a/internal/util/function/openai_text_embedding_provider_test.go +++ b/internal/util/function/openai_text_embedding_provider_test.go @@ -60,10 +60,12 @@ func (s *OpenAITextEmbeddingProviderSuite) SetupTest() { func createOpenAIProvider(url string, schema *schemapb.FieldSchema, providerName string) (TextEmbeddingProvider, error) { functionSchema := &schemapb.FunctionSchema{ - Name: "test", - Type: schemapb.FunctionType_Unknown, - InputFieldIds: []int64{101}, - OutputFieldIds: []int64{102}, + Name: "test", + Type: 
schemapb.FunctionType_Unknown, + InputFieldNames: []string{"text"}, + OutputFieldNames: []string{"vector"}, + InputFieldIds: []int64{101}, + OutputFieldIds: []int64{102}, Params: []*commonpb.KeyValuePair{ {Key: modelNameParamKey, Value: "text-embedding-ada-002"}, {Key: apiKeyParamKey, Value: "mock"}, diff --git a/internal/util/function/text_embedding_function.go b/internal/util/function/text_embedding_function.go index 8ddf4893dfc44..d359514d6cfbc 100644 --- a/internal/util/function/text_embedding_function.go +++ b/internal/util/function/text_embedding_function.go @@ -66,8 +66,8 @@ type TextEmebddingFunction struct { } func NewTextEmbeddingFunction(coll *schemapb.CollectionSchema, functionSchema *schemapb.FunctionSchema) (*TextEmebddingFunction, error) { - if len(functionSchema.GetOutputFieldIds()) != 1 { - return nil, fmt.Errorf("Text function should only have one output field, but now is %d", len(functionSchema.GetOutputFieldIds())) + if len(functionSchema.GetOutputFieldNames()) != 1 { + return nil, fmt.Errorf("Text function should only have one output field, but now is %d", len(functionSchema.GetOutputFieldNames())) } base, err := NewFunctionBase(coll, functionSchema) @@ -76,7 +76,7 @@ func NewTextEmbeddingFunction(coll *schemapb.CollectionSchema, functionSchema *s } if base.outputFields[0].DataType != schemapb.DataType_FloatVector { - return nil, fmt.Errorf("Output field not match, openai embedding needs [%s], got [%s]", + return nil, fmt.Errorf("Text embedding function's output field not match, needs [%s], got [%s]", schemapb.DataType_name[int32(schemapb.DataType_FloatVector)], schemapb.DataType_name[int32(base.outputFields[0].DataType)]) } diff --git a/internal/util/function/text_embedding_function_test.go b/internal/util/function/text_embedding_function_test.go index 9fb6d946a8b4d..ce0bfc86dbf51 100644 --- a/internal/util/function/text_embedding_function_test.go +++ b/internal/util/function/text_embedding_function_test.go @@ -76,10 +76,12 @@ func (s *TextEmbeddingFunctionSuite) TestOpenAIEmbedding() { { runner, err := NewTextEmbeddingFunction(s.schema, &schemapb.FunctionSchema{ - Name: "test", - Type: schemapb.FunctionType_Unknown, - InputFieldIds: []int64{101}, - OutputFieldIds: []int64{102}, + Name: "test", + Type: schemapb.FunctionType_Unknown, + InputFieldNames: []string{"text"}, + OutputFieldNames: []string{"vector"}, + InputFieldIds: []int64{101}, + OutputFieldIds: []int64{102}, Params: []*commonpb.KeyValuePair{ {Key: Provider, Value: OpenAIProvider}, {Key: modelNameParamKey, Value: "text-embedding-ada-002"}, @@ -108,10 +110,12 @@ func (s *TextEmbeddingFunctionSuite) TestOpenAIEmbedding() { { runner, err := NewTextEmbeddingFunction(s.schema, &schemapb.FunctionSchema{ - Name: "test", - Type: schemapb.FunctionType_Unknown, - InputFieldIds: []int64{101}, - OutputFieldIds: []int64{102}, + Name: "test", + Type: schemapb.FunctionType_Unknown, + InputFieldNames: []string{"text"}, + OutputFieldNames: []string{"vector"}, + InputFieldIds: []int64{101}, + OutputFieldIds: []int64{102}, Params: []*commonpb.KeyValuePair{ {Key: Provider, Value: AzureOpenAIProvider}, {Key: modelNameParamKey, Value: "text-embedding-ada-002"}, @@ -143,10 +147,12 @@ func (s *TextEmbeddingFunctionSuite) TestAliEmbedding() { defer ts.Close() runner, err := NewTextEmbeddingFunction(s.schema, &schemapb.FunctionSchema{ - Name: "test", - Type: schemapb.FunctionType_Unknown, - InputFieldIds: []int64{101}, - OutputFieldIds: []int64{102}, + Name: "test", + Type: schemapb.FunctionType_Unknown, + InputFieldNames: 
[]string{"text"}, + OutputFieldNames: []string{"vector"}, + InputFieldIds: []int64{101}, + OutputFieldIds: []int64{102}, Params: []*commonpb.KeyValuePair{ {Key: Provider, Value: AliDashScopeProvider}, {Key: modelNameParamKey, Value: TextEmbeddingV3}, @@ -189,10 +195,12 @@ func (s *TextEmbeddingFunctionSuite) TestRunnerParamsErr() { } _, err := NewTextEmbeddingFunction(schema, &schemapb.FunctionSchema{ - Name: "test", - Type: schemapb.FunctionType_Unknown, - InputFieldIds: []int64{101}, - OutputFieldIds: []int64{102}, + Name: "test", + Type: schemapb.FunctionType_Unknown, + InputFieldNames: []string{"text"}, + OutputFieldNames: []string{"vector"}, + InputFieldIds: []int64{101}, + OutputFieldIds: []int64{102}, Params: []*commonpb.KeyValuePair{ {Key: Provider, Value: OpenAIProvider}, {Key: modelNameParamKey, Value: "text-embedding-ada-002"}, @@ -215,17 +223,19 @@ func (s *TextEmbeddingFunctionSuite) TestRunnerParamsErr() { TypeParams: []*commonpb.KeyValuePair{ {Key: "dim", Value: "4"}, }}, - {FieldID: 103, Name: "vector", DataType: schemapb.DataType_FloatVector, + {FieldID: 103, Name: "vector2", DataType: schemapb.DataType_FloatVector, TypeParams: []*commonpb.KeyValuePair{ {Key: "dim", Value: "4"}, }}, }, } _, err := NewTextEmbeddingFunction(schema, &schemapb.FunctionSchema{ - Name: "test", - Type: schemapb.FunctionType_Unknown, - InputFieldIds: []int64{101}, - OutputFieldIds: []int64{102, 103}, + Name: "test", + Type: schemapb.FunctionType_Unknown, + InputFieldNames: []string{"text"}, + OutputFieldNames: []string{"vector", "vector2"}, + InputFieldIds: []int64{101}, + OutputFieldIds: []int64{102, 103}, Params: []*commonpb.KeyValuePair{ {Key: Provider, Value: OpenAIProvider}, {Key: modelNameParamKey, Value: "text-embedding-ada-002"}, @@ -240,10 +250,12 @@ func (s *TextEmbeddingFunctionSuite) TestRunnerParamsErr() { // outputfield miss { _, err := NewTextEmbeddingFunction(s.schema, &schemapb.FunctionSchema{ - Name: "test", - Type: schemapb.FunctionType_Unknown, - InputFieldIds: []int64{101}, - OutputFieldIds: []int64{103}, + Name: "test", + Type: schemapb.FunctionType_Unknown, + InputFieldNames: []string{"text"}, + OutputFieldNames: []string{"vector2"}, + InputFieldIds: []int64{101}, + OutputFieldIds: []int64{103}, Params: []*commonpb.KeyValuePair{ {Key: Provider, Value: OpenAIProvider}, {Key: modelNameParamKey, Value: "text-embedding-ada-002"}, @@ -258,10 +270,12 @@ func (s *TextEmbeddingFunctionSuite) TestRunnerParamsErr() { // error model name { _, err := NewTextEmbeddingFunction(s.schema, &schemapb.FunctionSchema{ - Name: "test", - Type: schemapb.FunctionType_Unknown, - InputFieldIds: []int64{101}, - OutputFieldIds: []int64{102}, + Name: "test", + Type: schemapb.FunctionType_Unknown, + InputFieldNames: []string{"text"}, + OutputFieldNames: []string{"vector"}, + InputFieldIds: []int64{101}, + OutputFieldIds: []int64{102}, Params: []*commonpb.KeyValuePair{ {Key: Provider, Value: OpenAIProvider}, {Key: modelNameParamKey, Value: "text-embedding-ada-004"}, @@ -276,10 +290,12 @@ func (s *TextEmbeddingFunctionSuite) TestRunnerParamsErr() { // no openai api key { _, err := NewTextEmbeddingFunction(s.schema, &schemapb.FunctionSchema{ - Name: "test", - Type: schemapb.FunctionType_Unknown, - InputFieldIds: []int64{101}, - OutputFieldIds: []int64{102}, + Name: "test", + Type: schemapb.FunctionType_Unknown, + InputFieldNames: []string{"text"}, + OutputFieldNames: []string{"vector"}, + InputFieldIds: []int64{101}, + OutputFieldIds: []int64{102}, Params: []*commonpb.KeyValuePair{ {Key: Provider, Value: 
OpenAIProvider}, {Key: modelNameParamKey, Value: "text-embedding-ada-003"}, diff --git a/internal/util/function/vertexai_embedding_provider_test.go b/internal/util/function/vertexai_embedding_provider_test.go index 321a531a1ac0e..2c18b133cc974 100644 --- a/internal/util/function/vertexai_embedding_provider_test.go +++ b/internal/util/function/vertexai_embedding_provider_test.go @@ -57,10 +57,12 @@ func (s *VertextAITextEmbeddingProviderSuite) SetupTest() { func createVertextAIProvider(url string, schema *schemapb.FieldSchema) (TextEmbeddingProvider, error) { functionSchema := &schemapb.FunctionSchema{ - Name: "test", - Type: schemapb.FunctionType_Unknown, - InputFieldIds: []int64{101}, - OutputFieldIds: []int64{102}, + Name: "test", + Type: schemapb.FunctionType_Unknown, + InputFieldNames: []string{"text"}, + OutputFieldNames: []string{"vector"}, + InputFieldIds: []int64{101}, + OutputFieldIds: []int64{102}, Params: []*commonpb.KeyValuePair{ {Key: modelNameParamKey, Value: textEmbedding005}, {Key: locationParamKey, Value: "mock_local"}, From 8c627716e7497ffc9d4739eb676ac1db8dd971d2 Mon Sep 17 00:00:00 2001 From: "junjie.jiang" Date: Fri, 20 Dec 2024 11:21:52 +0800 Subject: [PATCH 16/18] Adjust dashscope model batch Signed-off-by: junjie.jiang --- internal/util/function/ali_embedding_provider.go | 8 +++++++- internal/util/function/bedrock_embedding_provider.go | 6 +++--- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/internal/util/function/ali_embedding_provider.go b/internal/util/function/ali_embedding_provider.go index 920041afadbc1..2def771d86971 100644 --- a/internal/util/function/ali_embedding_provider.go +++ b/internal/util/function/ali_embedding_provider.go @@ -95,6 +95,12 @@ func NewAliDashScopeEmbeddingProvider(fieldSchema *schemapb.FieldSchema, functio if err != nil { return nil, err } + + maxBatch := 25 + if modelName == TextEmbeddingV3 { + maxBatch = 6 + } + provider := AliEmbeddingProvider{ client: c, fieldDim: fieldDim, @@ -102,7 +108,7 @@ func NewAliDashScopeEmbeddingProvider(fieldSchema *schemapb.FieldSchema, functio embedDimParam: dim, // TextEmbedding only supports dense embedding outputType: "dense", - maxBatch: 25, + maxBatch: maxBatch, timeoutSec: 30, } return &provider, nil diff --git a/internal/util/function/bedrock_embedding_provider.go b/internal/util/function/bedrock_embedding_provider.go index 5e64ca4f6f2b7..f9d4d184e4c8e 100644 --- a/internal/util/function/bedrock_embedding_provider.go +++ b/internal/util/function/bedrock_embedding_provider.go @@ -87,7 +87,7 @@ func NewBedrockEmbeddingProvider(fieldSchema *schemapb.FieldSchema, functionSche } var awsAccessKeyId, awsSecretAccessKey, region, modelName string var dim int64 - normalize := false + normalize := true for _, param := range functionSchema.Params { switch strings.ToLower(param.Key) { @@ -110,8 +110,8 @@ func NewBedrockEmbeddingProvider(fieldSchema *schemapb.FieldSchema, functionSche region = param.Value case normalizeParamKey: switch strings.ToLower(param.Value) { - case "true": - normalize = true + case "false": + normalize = false default: return nil, fmt.Errorf("Illegal [%s:%s] param, ", normalizeParamKey, param.Value) } From b482b3ccbed1f84de64b9135af7fdbd6d480f04e Mon Sep 17 00:00:00 2001 From: "junjie.jiang" Date: Thu, 26 Dec 2024 16:44:00 +0800 Subject: [PATCH 17/18] Fix confict Signed-off-by: junjie.jiang --- internal/proxy/task_search_test.go | 3 --- 1 file changed, 3 deletions(-) diff --git a/internal/proxy/task_search_test.go b/internal/proxy/task_search_test.go index 
21b69d5257ab7..4af9d69a82a67 100644 --- a/internal/proxy/task_search_test.go +++ b/internal/proxy/task_search_test.go @@ -18,10 +18,7 @@ package proxy import ( "context" "fmt" - "io" "math" - "net/http" - "net/http/httptest" "strconv" "strings" "testing" From 0e9e9ca321d652cc338b734f572eaa6335e3fb08 Mon Sep 17 00:00:00 2001 From: "junjie.jiang" Date: Fri, 27 Dec 2024 16:15:13 +0800 Subject: [PATCH 18/18] Polish code Signed-off-by: junjie.jiang --- .../ali/ali_dashscope_text_embedding.go | 9 ++- internal/models/openai/openai_embedding.go | 81 ++++++++----------- internal/models/utils/embedding_util.go | 5 +- .../vertexai/vertexai_text_embedding.go | 6 +- .../vertexai/vertexai_text_embedding_test.go | 6 +- internal/proxy/task_insert_test.go | 13 ++- internal/proxy/task_search_test.go | 18 +++-- internal/proxy/task_upsert_test.go | 13 ++- .../util/function/ali_embedding_provider.go | 11 +-- .../alitext_embedding_provider_test.go | 13 ++- .../function/bedrock_embedding_provider.go | 16 ++-- .../bedrock_text_embedding_provider_test.go | 5 +- internal/util/function/common.go | 19 ++++- internal/util/function/function_base.go | 10 +-- internal/util/function/function_executor.go | 16 ++-- .../util/function/function_executor_test.go | 17 ++-- internal/util/function/function_util.go | 16 ++-- .../util/function/mock_embedding_service.go | 5 +- .../function/openai_embedding_provider.go | 11 +-- .../openai_text_embedding_provider_test.go | 16 ++-- .../util/function/text_embedding_function.go | 8 +- .../function/text_embedding_function_test.go | 41 +++++----- .../function/vertexai_embedding_provider.go | 11 +-- .../vertexai_embedding_provider_test.go | 9 ++- 24 files changed, 187 insertions(+), 188 deletions(-) diff --git a/internal/models/ali/ali_dashscope_text_embedding.go b/internal/models/ali/ali_dashscope_text_embedding.go index 329451577f07f..ee412c6e992f6 100644 --- a/internal/models/ali/ali_dashscope_text_embedding.go +++ b/internal/models/ali/ali_dashscope_text_embedding.go @@ -83,6 +83,7 @@ func (eb *ByIndex) Len() int { return len(eb.resp.Output.Embeddings) } func (eb *ByIndex) Swap(i, j int) { eb.resp.Output.Embeddings[i], eb.resp.Output.Embeddings[j] = eb.resp.Output.Embeddings[j], eb.resp.Output.Embeddings[i] } + func (eb *ByIndex) Less(i, j int) bool { return eb.resp.Output.Embeddings[i].TextIndex < eb.resp.Output.Embeddings[j].TextIndex } @@ -116,20 +117,20 @@ func (c *AliDashScopeEmbedding) Check() error { return nil } -func (c *AliDashScopeEmbedding) Embedding(modelName string, texts []string, dim int, text_type string, output_type string, timeoutSec time.Duration) (*EmbeddingResponse, error) { +func (c *AliDashScopeEmbedding) Embedding(modelName string, texts []string, dim int, textType string, outputType string, timeoutSec time.Duration) (*EmbeddingResponse, error) { var r EmbeddingRequest r.Model = modelName r.Input = Input{texts} r.Parameters.Dimension = dim - r.Parameters.TextType = text_type - r.Parameters.OutputType = output_type + r.Parameters.TextType = textType + r.Parameters.OutputType = outputType data, err := json.Marshal(r) if err != nil { return nil, err } if timeoutSec <= 0 { - timeoutSec = 30 + timeoutSec = utils.DefaultTimeout } ctx, cancel := context.WithTimeout(context.Background(), timeoutSec*time.Second) diff --git a/internal/models/openai/openai_embedding.go b/internal/models/openai/openai_embedding.go index bb6f88be0cd18..433a95a2e32a7 100644 --- a/internal/models/openai/openai_embedding.go +++ b/internal/models/openai/openai_embedding.go @@ -135,20 +135,7 @@ 
func (c *openAIBase) genReq(modelName string, texts []string, dim int, user stri return &r } -type OpenAIEmbeddingClient struct { - openAIBase -} - -func NewOpenAIEmbeddingClient(apiKey string, url string) *OpenAIEmbeddingClient { - return &OpenAIEmbeddingClient{ - openAIBase{ - apiKey: apiKey, - url: url, - }, - } -} - -func (c *OpenAIEmbeddingClient) Embedding(modelName string, texts []string, dim int, user string, timeoutSec time.Duration) (*EmbeddingResponse, error) { +func (c *openAIBase) embedding(url string, headers map[string]string, modelName string, texts []string, dim int, user string, timeoutSec time.Duration) (*EmbeddingResponse, error) { r := c.genReq(modelName, texts, dim, user) data, err := json.Marshal(r) if err != nil { @@ -156,20 +143,23 @@ func (c *OpenAIEmbeddingClient) Embedding(modelName string, texts []string, dim } if timeoutSec <= 0 { - timeoutSec = 30 + timeoutSec = utils.DefaultTimeout } + ctx, cancel := context.WithTimeout(context.Background(), timeoutSec*time.Second) defer cancel() - req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.url, bytes.NewBuffer(data)) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewBuffer(data)) if err != nil { return nil, err } - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.apiKey)) + for key, value := range headers { + req.Header.Set(key, value) + } body, err := utils.RetrySend(req, 3) if err != nil { return nil, err } + var res EmbeddingResponse err = json.Unmarshal(body, &res) if err != nil { @@ -179,6 +169,27 @@ func (c *OpenAIEmbeddingClient) Embedding(modelName string, texts []string, dim return &res, err } +type OpenAIEmbeddingClient struct { + openAIBase +} + +func NewOpenAIEmbeddingClient(apiKey string, url string) *OpenAIEmbeddingClient { + return &OpenAIEmbeddingClient{ + openAIBase{ + apiKey: apiKey, + url: url, + }, + } +} + +func (c *OpenAIEmbeddingClient) Embedding(modelName string, texts []string, dim int, user string, timeoutSec time.Duration) (*EmbeddingResponse, error) { + headers := map[string]string{ + "Content-Type": "application/json", + "Authorization": fmt.Sprintf("Bearer %s", c.apiKey), + } + return c.embedding(c.url, headers, modelName, texts, dim, user, timeoutSec) +} + type AzureOpenAIEmbeddingClient struct { openAIBase apiVersion string @@ -195,16 +206,6 @@ func NewAzureOpenAIEmbeddingClient(apiKey string, url string) *AzureOpenAIEmbedd } func (c *AzureOpenAIEmbeddingClient) Embedding(modelName string, texts []string, dim int, user string, timeoutSec time.Duration) (*EmbeddingResponse, error) { - r := c.genReq(modelName, texts, dim, user) - data, err := json.Marshal(r) - if err != nil { - return nil, err - } - - if timeoutSec <= 0 { - timeoutSec = 30 - } - base, err := url.Parse(c.url) if err != nil { return nil, err @@ -214,25 +215,11 @@ func (c *AzureOpenAIEmbeddingClient) Embedding(modelName string, texts []string, params := url.Values{} params.Add("api-version", c.apiVersion) base.RawQuery = params.Encode() + url := base.String() - ctx, cancel := context.WithTimeout(context.Background(), timeoutSec*time.Second) - defer cancel() - req, err := http.NewRequestWithContext(ctx, http.MethodPost, base.String(), bytes.NewBuffer(data)) - if err != nil { - return nil, err - } - req.Header.Set("Content-Type", "application/json") - req.Header.Set("api-key", c.apiKey) - body, err := utils.RetrySend(req, 3) - if err != nil { - return nil, err - } - - var res EmbeddingResponse - err = json.Unmarshal(body, 
&res) - if err != nil { - return nil, err + headers := map[string]string{ + "Content-Type": "application/json", + "api-key": c.apiKey, } - sort.Sort(&ByIndex{&res}) - return &res, err + return c.embedding(url, headers, modelName, texts, dim, user, timeoutSec) } diff --git a/internal/models/utils/embedding_util.go b/internal/models/utils/embedding_util.go index 1d6e7d916cab2..1383d5740e814 100644 --- a/internal/models/utils/embedding_util.go +++ b/internal/models/utils/embedding_util.go @@ -20,8 +20,11 @@ import ( "fmt" "io" "net/http" + "time" ) +const DefaultTimeout time.Duration = 30 + func send(req *http.Request) ([]byte, error) { resp, err := http.DefaultClient.Do(req) if err != nil { @@ -34,7 +37,7 @@ func send(req *http.Request) ([]byte, error) { return nil, err } - if resp.StatusCode != 200 { + if resp.StatusCode != http.StatusOK { return nil, fmt.Errorf(string(body)) } return body, nil diff --git a/internal/models/vertexai/vertexai_text_embedding.go b/internal/models/vertexai/vertexai_text_embedding.go index 3842824616214..1a63c59961f81 100644 --- a/internal/models/vertexai/vertexai_text_embedding.go +++ b/internal/models/vertexai/vertexai_text_embedding.go @@ -24,9 +24,9 @@ import ( "net/http" "time" - "github.com/milvus-io/milvus/internal/models/utils" - "golang.org/x/oauth2/google" + + "github.com/milvus-io/milvus/internal/models/utils" ) type Instance struct { @@ -129,7 +129,7 @@ func (c *VertexAIEmbedding) Embedding(modelName string, texts []string, dim int6 } if timeoutSec <= 0 { - timeoutSec = 30 + timeoutSec = utils.DefaultTimeout } ctx, cancel := context.WithTimeout(context.Background(), timeoutSec*time.Second) diff --git a/internal/models/vertexai/vertexai_text_embedding_test.go b/internal/models/vertexai/vertexai_text_embedding_test.go index f138d659a3ea4..83b26ac4d4634 100644 --- a/internal/models/vertexai/vertexai_text_embedding_test.go +++ b/internal/models/vertexai/vertexai_text_embedding_test.go @@ -27,7 +27,7 @@ import ( ) func TestEmbeddingClientCheck(t *testing.T) { - mockJsonKey := []byte{1, 2, 3} + mockJSONKey := []byte{1, 2, 3} { c := NewVertexAIEmbedding("mock_url", []byte{}, "mock_scopes", "") err := c.Check() @@ -36,14 +36,14 @@ func TestEmbeddingClientCheck(t *testing.T) { } { - c := NewVertexAIEmbedding("", mockJsonKey, "", "") + c := NewVertexAIEmbedding("", mockJSONKey, "", "") err := c.Check() assert.True(t, err != nil) fmt.Println(err) } { - c := NewVertexAIEmbedding("mock_url", mockJsonKey, "mock_scopes", "") + c := NewVertexAIEmbedding("mock_url", mockJSONKey, "mock_scopes", "") err := c.Check() assert.True(t, err == nil) } diff --git a/internal/proxy/task_insert_test.go b/internal/proxy/task_insert_test.go index 2586ddf37afca..006d383be3146 100644 --- a/internal/proxy/task_insert_test.go +++ b/internal/proxy/task_insert_test.go @@ -340,14 +340,19 @@ func TestInsertTask_Function(t *testing.T) { AutoID: true, Fields: []*schemapb.FieldSchema{ {FieldID: 100, Name: "id", DataType: schemapb.DataType_Int64, IsPrimaryKey: true, AutoID: true}, - {FieldID: 101, Name: "text", DataType: schemapb.DataType_VarChar, + { + FieldID: 101, Name: "text", DataType: schemapb.DataType_VarChar, TypeParams: []*commonpb.KeyValuePair{ {Key: "max_length", Value: "200"}, - }}, - {FieldID: 102, Name: "vector", DataType: schemapb.DataType_FloatVector, + }, + }, + { + FieldID: 102, Name: "vector", DataType: schemapb.DataType_FloatVector, TypeParams: []*commonpb.KeyValuePair{ {Key: "dim", Value: "4"}, - }, IsFunctionOutput: true}, + }, + IsFunctionOutput: true, + }, }, Functions: 
[]*schemapb.FunctionSchema{ { diff --git a/internal/proxy/task_search_test.go b/internal/proxy/task_search_test.go index 4af9d69a82a67..b255027cf7316 100644 --- a/internal/proxy/task_search_test.go +++ b/internal/proxy/task_search_test.go @@ -487,18 +487,24 @@ func TestSearchTask_WithFunctions(t *testing.T) { AutoID: true, Fields: []*schemapb.FieldSchema{ {FieldID: 100, Name: "id", DataType: schemapb.DataType_Int64, IsPrimaryKey: true, AutoID: true}, - {FieldID: 101, Name: "text", DataType: schemapb.DataType_VarChar, + { + FieldID: 101, Name: "text", DataType: schemapb.DataType_VarChar, TypeParams: []*commonpb.KeyValuePair{ {Key: "max_length", Value: "200"}, - }}, - {FieldID: 102, Name: "vector1", DataType: schemapb.DataType_FloatVector, + }, + }, + { + FieldID: 102, Name: "vector1", DataType: schemapb.DataType_FloatVector, TypeParams: []*commonpb.KeyValuePair{ {Key: "dim", Value: "4"}, - }}, - {FieldID: 103, Name: "vector2", DataType: schemapb.DataType_FloatVector, + }, + }, + { + FieldID: 103, Name: "vector2", DataType: schemapb.DataType_FloatVector, TypeParams: []*commonpb.KeyValuePair{ {Key: "dim", Value: "4"}, - }}, + }, + }, }, Functions: []*schemapb.FunctionSchema{ { diff --git a/internal/proxy/task_upsert_test.go b/internal/proxy/task_upsert_test.go index 348ee64313d31..da0b3595cc45e 100644 --- a/internal/proxy/task_upsert_test.go +++ b/internal/proxy/task_upsert_test.go @@ -409,14 +409,19 @@ func TestUpsertTask_Function(t *testing.T) { AutoID: true, Fields: []*schemapb.FieldSchema{ {FieldID: 100, Name: "id", DataType: schemapb.DataType_Int64, IsPrimaryKey: true, AutoID: true}, - {FieldID: 101, Name: "text", DataType: schemapb.DataType_VarChar, + { + FieldID: 101, Name: "text", DataType: schemapb.DataType_VarChar, TypeParams: []*commonpb.KeyValuePair{ {Key: "max_length", Value: "200"}, - }}, - {FieldID: 102, Name: "vector", DataType: schemapb.DataType_FloatVector, + }, + }, + { + FieldID: 102, Name: "vector", DataType: schemapb.DataType_FloatVector, TypeParams: []*commonpb.KeyValuePair{ {Key: "dim", Value: "4"}, - }, IsFunctionOutput: true}, + }, + IsFunctionOutput: true, + }, }, Functions: []*schemapb.FunctionSchema{ { diff --git a/internal/util/function/ali_embedding_provider.go b/internal/util/function/ali_embedding_provider.go index 2def771d86971..966c530522e16 100644 --- a/internal/util/function/ali_embedding_provider.go +++ b/internal/util/function/ali_embedding_provider.go @@ -21,7 +21,6 @@ package function import ( "fmt" "os" - "strconv" "strings" "time" @@ -70,17 +69,13 @@ func NewAliDashScopeEmbeddingProvider(fieldSchema *schemapb.FieldSchema, functio case modelNameParamKey: modelName = param.Value case dimParamKey: - dim, err = strconv.ParseInt(param.Value, 10, 64) + dim, err = parseAndCheckFieldDim(param.Value, fieldDim, fieldSchema.Name) if err != nil { - return nil, fmt.Errorf("dim [%s] is not int", param.Value) - } - - if dim != 0 && dim != fieldDim { - return nil, fmt.Errorf("Field %s's dim is [%d], but embeding's dim is [%d]", functionSchema.Name, fieldDim, dim) + return nil, err } case apiKeyParamKey: apiKey = param.Value - case embeddingUrlParamKey: + case embeddingURLParamKey: url = param.Value default: } diff --git a/internal/util/function/alitext_embedding_provider_test.go b/internal/util/function/alitext_embedding_provider_test.go index 73d8613f20a4d..a852b1b74e6ab 100644 --- a/internal/util/function/alitext_embedding_provider_test.go +++ b/internal/util/function/alitext_embedding_provider_test.go @@ -25,12 +25,11 @@ import ( "net/http/httptest" "testing" - 
"github.com/stretchr/testify/suite" - "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" - "github.com/milvus-io/milvus/internal/models/ali" + + "github.com/stretchr/testify/suite" ) func TestAliTextEmbeddingProvider(t *testing.T) { @@ -52,7 +51,8 @@ func (s *AliTextEmbeddingProviderSuite) SetupTest() { {FieldID: 102, Name: "vector", DataType: schemapb.DataType_FloatVector, TypeParams: []*commonpb.KeyValuePair{ {Key: "dim", Value: "4"}, - }}, + }, + }, }, } s.providers = []string{AliDashScopeProvider} @@ -69,7 +69,7 @@ func createAliProvider(url string, schema *schemapb.FieldSchema, providerName st Params: []*commonpb.KeyValuePair{ {Key: modelNameParamKey, Value: TextEmbeddingV3}, {Key: apiKeyParamKey, Value: "mock"}, - {Key: embeddingUrlParamKey, Value: url}, + {Key: embeddingURLParamKey, Value: url}, {Key: dimParamKey, Value: "4"}, }, } @@ -101,7 +101,6 @@ func (s *AliTextEmbeddingProviderSuite) TestEmbedding() { ret, _ := provder.CallEmbedding(data, false, SearchMode) s.Equal([][]float32{{0.0, 0.1, 0.2, 0.3}, {1.0, 1.1, 1.2, 1.3}, {2.0, 2.1, 2.2, 2.3}}, ret) } - } } @@ -134,7 +133,6 @@ func (s *AliTextEmbeddingProviderSuite) TestEmbeddingDimNotMatch() { data := []string{"sentence", "sentence"} _, err2 := provder.CallEmbedding(data, false, InsertMode) s.Error(err2) - } } @@ -163,6 +161,5 @@ func (s *AliTextEmbeddingProviderSuite) TestEmbeddingNubmerNotMatch() { data := []string{"sentence", "sentence2"} _, err2 := provder.CallEmbedding(data, false, InsertMode) s.Error(err2) - } } diff --git a/internal/util/function/bedrock_embedding_provider.go b/internal/util/function/bedrock_embedding_provider.go index f9d4d184e4c8e..eb54712ce5499 100644 --- a/internal/util/function/bedrock_embedding_provider.go +++ b/internal/util/function/bedrock_embedding_provider.go @@ -23,16 +23,15 @@ import ( "encoding/json" "fmt" "os" - "strconv" "strings" - "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" - "github.com/milvus-io/milvus/pkg/util/typeutil" - "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/config" "github.com/aws/aws-sdk-go-v2/credentials" "github.com/aws/aws-sdk-go-v2/service/bedrockruntime" + + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/pkg/util/typeutil" ) type BedrockClient interface { @@ -94,13 +93,9 @@ func NewBedrockEmbeddingProvider(fieldSchema *schemapb.FieldSchema, functionSche case modelNameParamKey: modelName = param.Value case dimParamKey: - dim, err = strconv.ParseInt(param.Value, 10, 64) + dim, err = parseAndCheckFieldDim(param.Value, fieldDim, fieldSchema.Name) if err != nil { - return nil, fmt.Errorf("dim [%s] is not int", param.Value) - } - - if dim != 0 && dim != fieldDim { - return nil, fmt.Errorf("Field %s's dim is [%d], but embeding's dim is [%d]", functionSchema.Name, fieldDim, dim) + return nil, err } case awsAccessKeyIdParamKey: awsAccessKeyId = param.Value @@ -178,7 +173,6 @@ func (provider *BedrockEmbeddingProvider) CallEmbedding(texts []string, batchLim ModelId: aws.String(provider.modelName), ContentType: aws.String("application/json"), }) - if err != nil { return nil, err } diff --git a/internal/util/function/bedrock_text_embedding_provider_test.go b/internal/util/function/bedrock_text_embedding_provider_test.go index eb26aa03ef3f7..e8f08df77e8d1 100644 --- a/internal/util/function/bedrock_text_embedding_provider_test.go +++ b/internal/util/function/bedrock_text_embedding_provider_test.go @@ -47,7 +47,8 @@ func (s 
*BedrockTextEmbeddingProviderSuite) SetupTest() { {FieldID: 102, Name: "vector", DataType: schemapb.DataType_FloatVector, TypeParams: []*commonpb.KeyValuePair{ {Key: "dim", Value: "4"}, - }}, + }, + }, }, } s.providers = []string{BedrockProvider} @@ -92,7 +93,6 @@ func (s *BedrockTextEmbeddingProviderSuite) TestEmbedding() { ret, _ := provder.CallEmbedding(data, false, SearchMode) s.Equal([][]float32{{0.0, 0.1, 0.2, 0.3}, {0.0, 0.1, 0.2, 0.3}, {0.0, 0.1, 0.2, 0.3}}, ret) } - } } @@ -105,6 +105,5 @@ func (s *BedrockTextEmbeddingProviderSuite) TestEmbeddingDimNotMatch() { data := []string{"sentence", "sentence"} _, err2 := provder.CallEmbedding(data, false, InsertMode) s.Error(err2) - } } diff --git a/internal/util/function/common.go b/internal/util/function/common.go index 56da30e5ed42f..a6c4fe4840b0e 100644 --- a/internal/util/function/common.go +++ b/internal/util/function/common.go @@ -18,6 +18,11 @@ package function +import ( + "fmt" + "strconv" +) + const ( InsertMode string = "Insert" SearchMode string = "Search" @@ -27,7 +32,7 @@ const ( const ( modelNameParamKey string = "model_name" dimParamKey string = "dim" - embeddingUrlParamKey string = "url" + embeddingURLParamKey string = "url" apiKeyParamKey string = "api_key" ) @@ -80,3 +85,15 @@ const ( vertexServiceAccountJSONEnv string = "MILVUSAI_GOOGLE_APPLICATION_CREDENTIALS" ) + +func parseAndCheckFieldDim(dimStr string, fieldDim int64, fieldName string) (int64, error) { + dim, err := strconv.ParseInt(dimStr, 10, 64) + if err != nil { + return 0, fmt.Errorf("dim [%s] is not int", dimStr) + } + + if dim != 0 && dim != fieldDim { + return 0, fmt.Errorf("Field %s's dim is [%d], but embedding's dim is [%d]", fieldName, fieldDim, dim) + } + return dim, nil +} diff --git a/internal/util/function/function_base.go b/internal/util/function/function_base.go index fa3253bf3c7bd..aabcfdf5c0ea2 100644 --- a/internal/util/function/function_base.go +++ b/internal/util/function/function_base.go @@ -29,10 +29,10 @@ type FunctionBase struct { outputFields []*schemapb.FieldSchema } -func NewFunctionBase(coll *schemapb.CollectionSchema, f_schema *schemapb.FunctionSchema) (*FunctionBase, error) { +func NewFunctionBase(coll *schemapb.CollectionSchema, fSchema *schemapb.FunctionSchema) (*FunctionBase, error) { var base FunctionBase - base.schema = f_schema - for _, fieldName := range f_schema.GetOutputFieldNames() { + base.schema = fSchema + for _, fieldName := range fSchema.GetOutputFieldNames() { for _, field := range coll.GetFields() { if field.GetName() == fieldName { base.outputFields = append(base.outputFields, field) @@ -41,9 +41,9 @@ func NewFunctionBase(coll *schemapb.CollectionSchema, f_schema *schemapb.Functio } } - if len(base.outputFields) != len(f_schema.GetOutputFieldNames()) { + if len(base.outputFields) != len(fSchema.GetOutputFieldNames()) { return &base, fmt.Errorf("The collection [%s]'s information is wrong, function [%s]'s outputs does not match the schema", - coll.Name, f_schema.Name) + coll.Name, fSchema.Name) } return &base, nil } diff --git a/internal/util/function/function_executor.go b/internal/util/function/function_executor.go index 6f2469cca9173..011380ae5a226 100644 --- a/internal/util/function/function_executor.go +++ b/internal/util/function/function_executor.go @@ -62,8 +62,8 @@ func createFunction(coll *schemapb.CollectionSchema, schema *schemapb.FunctionSc } func CheckFunctions(schema *schemapb.CollectionSchema) error { - for _, f_schema := range schema.Functions { - if _, err := createFunction(schema, f_schema); err != 
nil { + for _, fSchema := range schema.Functions { + if _, err := createFunction(schema, fSchema); err != nil { return err } } @@ -77,12 +77,12 @@ func NewFunctionExecutor(schema *schemapb.CollectionSchema) (*FunctionExecutor, executor := &FunctionExecutor{ runners: make(map[int64]Runner), } - for _, f_schema := range schema.Functions { - if runner, err := createFunction(schema, f_schema); err != nil { + for _, fSchema := range schema.Functions { + if runner, err := createFunction(schema, fSchema); err != nil { return nil, err } else { if runner != nil { - executor.runners[f_schema.GetOutputFieldIds()[0]] = runner + executor.runners[fSchema.GetOutputFieldIds()[0]] = runner } } } @@ -200,7 +200,6 @@ func (executor *FunctionExecutor) prcessAdvanceSearch(req *internalpb.SearchRequ } else { outputs <- map[int64][]byte{idx: newHolder} } - }(runner, int64(idx)) } } @@ -222,9 +221,8 @@ func (executor *FunctionExecutor) prcessAdvanceSearch(req *internalpb.SearchRequ func (executor *FunctionExecutor) ProcessSearch(req *internalpb.SearchRequest) error { if !req.IsAdvanced { return executor.prcessSearch(req) - } else { - return executor.prcessAdvanceSearch(req) - } + } + return executor.prcessAdvanceSearch(req) } func (executor *FunctionExecutor) processSingleBulkInsert(runner Runner, data *storage.InsertData) (map[storage.FieldID]storage.FieldData, error) { diff --git a/internal/util/function/function_executor_test.go b/internal/util/function/function_executor_test.go index a791760351fb7..a38360d343660 100644 --- a/internal/util/function/function_executor_test.go +++ b/internal/util/function/function_executor_test.go @@ -49,16 +49,20 @@ func (s *FunctionExecutorSuite) creataSchema(url string) *schemapb.CollectionSch Fields: []*schemapb.FieldSchema{ {FieldID: 100, Name: "int64", DataType: schemapb.DataType_Int64}, {FieldID: 101, Name: "text", DataType: schemapb.DataType_VarChar}, - {FieldID: 102, Name: "vector", DataType: schemapb.DataType_FloatVector, + { + FieldID: 102, Name: "vector", DataType: schemapb.DataType_FloatVector, TypeParams: []*commonpb.KeyValuePair{ {Key: "dim", Value: "4"}, }, IsFunctionOutput: true, }, - {FieldID: 103, Name: "vector2", DataType: schemapb.DataType_FloatVector, + { + FieldID: 103, Name: "vector2", DataType: schemapb.DataType_FloatVector, TypeParams: []*commonpb.KeyValuePair{ {Key: "dim", Value: "8"}, - }, IsFunctionOutput: true}, + }, + IsFunctionOutput: true, + }, }, Functions: []*schemapb.FunctionSchema{ { @@ -72,7 +76,7 @@ func (s *FunctionExecutorSuite) creataSchema(url string) *schemapb.CollectionSch {Key: Provider, Value: OpenAIProvider}, {Key: modelNameParamKey, Value: "text-embedding-ada-002"}, {Key: apiKeyParamKey, Value: "mock"}, - {Key: embeddingUrlParamKey, Value: url}, + {Key: embeddingURLParamKey, Value: url}, {Key: dimParamKey, Value: "4"}, }, }, @@ -87,17 +91,15 @@ func (s *FunctionExecutorSuite) creataSchema(url string) *schemapb.CollectionSch {Key: Provider, Value: OpenAIProvider}, {Key: modelNameParamKey, Value: "text-embedding-ada-002"}, {Key: apiKeyParamKey, Value: "mock"}, - {Key: embeddingUrlParamKey, Value: url}, + {Key: embeddingURLParamKey, Value: url}, {Key: dimParamKey, Value: "8"}, }, }, }, } - } func (s *FunctionExecutorSuite) createMsg(texts []string) *msgstream.InsertMsg { - data := []*schemapb.FieldData{} f := schemapb.FieldData{ Type: schemapb.DataType_VarChar, @@ -173,7 +175,6 @@ func (s *FunctionExecutorSuite) TestErrorEmbedding() { w.WriteHeader(http.StatusOK) data, _ := json.Marshal(res) w.Write(data) - })) defer ts.Close() schema 
:= s.creataSchema(ts.URL) diff --git a/internal/util/function/function_util.go b/internal/util/function/function_util.go index bd0265336baa7..240e13615b4f7 100644 --- a/internal/util/function/function_util.go +++ b/internal/util/function/function_util.go @@ -26,8 +26,8 @@ import ( func HasFunctions(functions []*schemapb.FunctionSchema, outputIDs []int64) bool { // Determine whether the column corresponding to outputIDs contains functions, except bm25 function, // if outputIDs is empty, check all cols - for _, f_schema := range functions { - switch f_schema.GetType() { + for _, fSchema := range functions { + switch fSchema.GetType() { case schemapb.FunctionType_BM25: case schemapb.FunctionType_Unknown: default: @@ -35,7 +35,7 @@ func HasFunctions(functions []*schemapb.FunctionSchema, outputIDs []int64) bool return true } else { for _, id := range outputIDs { - if f_schema.GetOutputFieldIds()[0] == id { + if fSchema.GetOutputFieldIds()[0] == id { return true } } @@ -47,14 +47,14 @@ func HasFunctions(functions []*schemapb.FunctionSchema, outputIDs []int64) bool func GetOutputIDFunctionsMap(functions []*schemapb.FunctionSchema) (map[int64]*schemapb.FunctionSchema, error) { outputIdMap := map[int64]*schemapb.FunctionSchema{} - for _, f_schema := range functions { - switch f_schema.GetType() { + for _, fSchema := range functions { + switch fSchema.GetType() { case schemapb.FunctionType_BM25: default: - if len(f_schema.OutputFieldIds) != 1 { - return nil, merr.WrapErrParameterInvalidMsg("Function [%s]'s outputs err, only supports one outputs", f_schema.Name) + if len(fSchema.OutputFieldIds) != 1 { + return nil, merr.WrapErrParameterInvalidMsg("Function [%s]'s outputs err, only supports one outputs", fSchema.Name) } - outputIdMap[f_schema.OutputFieldIds[0]] = f_schema + outputIdMap[fSchema.OutputFieldIds[0]] = fSchema } } return outputIdMap, nil diff --git a/internal/util/function/mock_embedding_service.go b/internal/util/function/mock_embedding_service.go index 4cb181a7a0c4f..c071a2056df72 100644 --- a/internal/util/function/mock_embedding_service.go +++ b/internal/util/function/mock_embedding_service.go @@ -25,10 +25,11 @@ import ( "net/http" "net/http/httptest" - "github.com/aws/aws-sdk-go-v2/service/bedrockruntime" "github.com/milvus-io/milvus/internal/models/ali" "github.com/milvus-io/milvus/internal/models/openai" "github.com/milvus-io/milvus/internal/models/vertexai" + + "github.com/aws/aws-sdk-go-v2/service/bedrockruntime" ) func mockEmbedding(texts []string, dim int) [][]float32 { @@ -69,7 +70,6 @@ func CreateOpenAIEmbeddingServer() *httptest.Server { w.WriteHeader(http.StatusOK) data, _ := json.Marshal(res) w.Write(data) - })) return ts } @@ -129,7 +129,6 @@ func CreateVertexAIEmbeddingServer() *httptest.Server { w.WriteHeader(http.StatusOK) data, _ := json.Marshal(res) w.Write(data) - })) return ts } diff --git a/internal/util/function/openai_embedding_provider.go b/internal/util/function/openai_embedding_provider.go index 32cfb945509f7..8b5f53bc7fd2c 100644 --- a/internal/util/function/openai_embedding_provider.go +++ b/internal/util/function/openai_embedding_provider.go @@ -21,7 +21,6 @@ package function import ( "fmt" "os" - "strconv" "strings" "time" @@ -89,19 +88,15 @@ func newOpenAIEmbeddingProvider(fieldSchema *schemapb.FieldSchema, functionSchem case modelNameParamKey: modelName = param.Value case dimParamKey: - dim, err = strconv.ParseInt(param.Value, 10, 64) + dim, err = parseAndCheckFieldDim(param.Value, fieldDim, fieldSchema.Name) if err != nil { - return nil, 
fmt.Errorf("dim [%s] is not int", param.Value) - } - - if dim != 0 && dim != fieldDim { - return nil, fmt.Errorf("Field %s's dim is [%d], but embeding's dim is [%d]", fieldSchema.Name, fieldDim, dim) + return nil, err } case userParamKey: user = param.Value case apiKeyParamKey: apiKey = param.Value - case embeddingUrlParamKey: + case embeddingURLParamKey: url = param.Value default: } diff --git a/internal/util/function/openai_text_embedding_provider_test.go b/internal/util/function/openai_text_embedding_provider_test.go index 395ecf06cdc9d..09b120e0603d5 100644 --- a/internal/util/function/openai_text_embedding_provider_test.go +++ b/internal/util/function/openai_text_embedding_provider_test.go @@ -25,12 +25,11 @@ import ( "net/http/httptest" "testing" - "github.com/stretchr/testify/suite" - "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" - "github.com/milvus-io/milvus/internal/models/openai" + + "github.com/stretchr/testify/suite" ) func TestOpenAITextEmbeddingProvider(t *testing.T) { @@ -49,10 +48,12 @@ func (s *OpenAITextEmbeddingProviderSuite) SetupTest() { Fields: []*schemapb.FieldSchema{ {FieldID: 100, Name: "int64", DataType: schemapb.DataType_Int64}, {FieldID: 101, Name: "text", DataType: schemapb.DataType_VarChar}, - {FieldID: 102, Name: "vector", DataType: schemapb.DataType_FloatVector, + { + FieldID: 102, Name: "vector", DataType: schemapb.DataType_FloatVector, TypeParams: []*commonpb.KeyValuePair{ {Key: "dim", Value: "4"}, - }}, + }, + }, }, } s.providers = []string{OpenAIProvider, AzureOpenAIProvider} @@ -70,7 +71,7 @@ func createOpenAIProvider(url string, schema *schemapb.FieldSchema, providerName {Key: modelNameParamKey, Value: "text-embedding-ada-002"}, {Key: apiKeyParamKey, Value: "mock"}, {Key: dimParamKey, Value: "4"}, - {Key: embeddingUrlParamKey, Value: url}, + {Key: embeddingURLParamKey, Value: url}, }, } switch providerName { @@ -103,7 +104,6 @@ func (s *OpenAITextEmbeddingProviderSuite) TestEmbedding() { ret, _ := provder.CallEmbedding(data, false, SearchMode) s.Equal([][]float32{{0.0, 0.1, 0.2, 0.3}, {1.0, 1.1, 1.2, 1.3}, {2.0, 2.1, 2.2, 2.3}}, ret) } - } } @@ -141,7 +141,6 @@ func (s *OpenAITextEmbeddingProviderSuite) TestEmbeddingDimNotMatch() { data := []string{"sentence", "sentence"} _, err2 := provder.CallEmbedding(data, false, InsertMode) s.Error(err2) - } } @@ -174,6 +173,5 @@ func (s *OpenAITextEmbeddingProviderSuite) TestEmbeddingNubmerNotMatch() { data := []string{"sentence", "sentence2"} _, err2 := provder.CallEmbedding(data, false, InsertMode) s.Error(err2) - } } diff --git a/internal/util/function/text_embedding_function.go b/internal/util/function/text_embedding_function.go index d359514d6cfbc..030679df812fb 100644 --- a/internal/util/function/text_embedding_function.go +++ b/internal/util/function/text_embedding_function.go @@ -26,7 +26,6 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus/internal/storage" "github.com/milvus-io/milvus/pkg/util/funcutil" - // "github.com/milvus-io/milvus/pkg/util/typeutil" ) const ( @@ -134,7 +133,6 @@ func NewTextEmbeddingFunction(coll *schemapb.CollectionSchema, functionSchema *s default: return nil, fmt.Errorf("Unsupported embedding service provider: [%s] , list of supported [%s, %s, %s, %s]", provider, OpenAIProvider, AzureOpenAIProvider, AliDashScopeProvider, BedrockProvider) } - } func (runner *TextEmebddingFunction) MaxBatch() int { @@ -147,7 +145,7 @@ func (runner *TextEmebddingFunction) 
ProcessInsert(inputs []*schemapb.FieldData)
 }
 if inputs[0].Type != schemapb.DataType_VarChar {
- return nil, fmt.Errorf("Text embedding only supports varchar field, the input is not varchar")
+ return nil, fmt.Errorf("Text embedding only supports varchar field as input field, but got %s", schemapb.DataType_name[int32(inputs[0].Type)])
 }
 texts := inputs[0].GetScalars().GetStringData().GetData()
@@ -193,11 +191,11 @@ func (runner *TextEmebddingFunction) ProcessSearch(placeholderGroup *commonpb.Pl
 func (runner *TextEmebddingFunction) ProcessBulkInsert(inputs []storage.FieldData) (map[storage.FieldID]storage.FieldData, error) {
 if len(inputs) != 1 {
- return nil, fmt.Errorf("OpenAIEmbedding function only receives one input, bug got [%d]", len(inputs))
+ return nil, fmt.Errorf("TextEmbedding function only receives one input, but got [%d]", len(inputs))
 }
 if inputs[0].GetDataType() != schemapb.DataType_VarChar {
- return nil, fmt.Errorf("OpenAIEmbedding only supports varchar field, the input is not varchar")
+ return nil, fmt.Errorf("TextEmbedding only supports varchar field, the input is not varchar")
 }
 texts, ok := inputs[0].GetDataRows().([]string)
diff --git a/internal/util/function/text_embedding_function_test.go b/internal/util/function/text_embedding_function_test.go
index ce0bfc86dbf51..353684e55b838 100644
--- a/internal/util/function/text_embedding_function_test.go
+++ b/internal/util/function/text_embedding_function_test.go
@@ -42,10 +42,12 @@ func (s *TextEmbeddingFunctionSuite) SetupTest() {
 Fields: []*schemapb.FieldSchema{
 {FieldID: 100, Name: "int64", DataType: schemapb.DataType_Int64},
 {FieldID: 101, Name: "text", DataType: schemapb.DataType_VarChar},
- {FieldID: 102, Name: "vector", DataType: schemapb.DataType_FloatVector,
+ {
+ FieldID: 102, Name: "vector", DataType: schemapb.DataType_FloatVector,
 TypeParams: []*commonpb.KeyValuePair{
 {Key: "dim", Value: "4"},
- }},
+ },
+ },
 },
 }
 }
@@ -74,7 +76,6 @@ func (s *TextEmbeddingFunctionSuite) TestOpenAIEmbedding() {
 ts := CreateOpenAIEmbeddingServer()
 defer ts.Close()
 {
-
 runner, err := NewTextEmbeddingFunction(s.schema, &schemapb.FunctionSchema{
 Name: "test",
 Type: schemapb.FunctionType_Unknown,
@@ -87,7 +88,7 @@ func (s *TextEmbeddingFunctionSuite) TestOpenAIEmbedding() {
 {Key: modelNameParamKey, Value: "text-embedding-ada-002"},
 {Key: dimParamKey, Value: "4"},
 {Key: apiKeyParamKey, Value: "mock"},
- {Key: embeddingUrlParamKey, Value: ts.URL},
+ {Key: embeddingURLParamKey, Value: ts.URL},
 },
 })
 s.NoError(err)
@@ -106,7 +107,6 @@ func (s *TextEmbeddingFunctionSuite) TestOpenAIEmbedding() {
 s.Equal([]float32{0.0, 0.1, 0.2, 0.3, 1.0, 1.1, 1.2, 1.3, 2.0, 2.1, 2.2, 2.3}, ret[0].GetVectors().GetFloatVector().Data)
 }
 }
-
 {
 runner, err := NewTextEmbeddingFunction(s.schema, &schemapb.FunctionSchema{
@@ -121,7 +121,7 @@ func (s *TextEmbeddingFunctionSuite) TestOpenAIEmbedding() {
 {Key: modelNameParamKey, Value: "text-embedding-ada-002"},
 {Key: dimParamKey, Value: "4"},
 {Key: apiKeyParamKey, Value: "mock"},
- {Key: embeddingUrlParamKey, Value: ts.URL},
+ {Key: embeddingURLParamKey, Value: ts.URL},
 },
 })
 s.NoError(err)
@@ -158,7 +158,7 @@ func (s *TextEmbeddingFunctionSuite) TestAliEmbedding() {
 {Key: modelNameParamKey, Value: TextEmbeddingV3},
 {Key: dimParamKey, Value: "4"},
 {Key: apiKeyParamKey, Value: "mock"},
- {Key: embeddingUrlParamKey, Value: ts.URL},
+ {Key: embeddingURLParamKey, Value: ts.URL},
 },
 })
 s.NoError(err)
@@ -176,7 +176,6 @@ func (s *TextEmbeddingFunctionSuite) TestAliEmbedding() {
 ret, _ := runner.ProcessInsert(data)
s.Equal([]float32{0.0, 0.1, 0.2, 0.3, 1.0, 1.1, 1.2, 1.3, 2.0, 2.1, 2.2, 2.3}, ret[0].GetVectors().GetFloatVector().Data) } - } func (s *TextEmbeddingFunctionSuite) TestRunnerParamsErr() { @@ -187,10 +186,12 @@ func (s *TextEmbeddingFunctionSuite) TestRunnerParamsErr() { Fields: []*schemapb.FieldSchema{ {FieldID: 100, Name: "int64", DataType: schemapb.DataType_Int64}, {FieldID: 101, Name: "text", DataType: schemapb.DataType_VarChar}, - {FieldID: 102, Name: "vector", DataType: schemapb.DataType_BFloat16Vector, + { + FieldID: 102, Name: "vector", DataType: schemapb.DataType_BFloat16Vector, TypeParams: []*commonpb.KeyValuePair{ {Key: "dim", Value: "4"}, - }}, + }, + }, }, } @@ -206,7 +207,7 @@ func (s *TextEmbeddingFunctionSuite) TestRunnerParamsErr() { {Key: modelNameParamKey, Value: "text-embedding-ada-002"}, {Key: dimParamKey, Value: "4"}, {Key: apiKeyParamKey, Value: "mock"}, - {Key: embeddingUrlParamKey, Value: "mock"}, + {Key: embeddingURLParamKey, Value: "mock"}, }, }) s.Error(err) @@ -219,14 +220,18 @@ func (s *TextEmbeddingFunctionSuite) TestRunnerParamsErr() { Fields: []*schemapb.FieldSchema{ {FieldID: 100, Name: "int64", DataType: schemapb.DataType_Int64}, {FieldID: 101, Name: "text", DataType: schemapb.DataType_VarChar}, - {FieldID: 102, Name: "vector", DataType: schemapb.DataType_FloatVector, + { + FieldID: 102, Name: "vector", DataType: schemapb.DataType_FloatVector, TypeParams: []*commonpb.KeyValuePair{ {Key: "dim", Value: "4"}, - }}, - {FieldID: 103, Name: "vector2", DataType: schemapb.DataType_FloatVector, + }, + }, + { + FieldID: 103, Name: "vector2", DataType: schemapb.DataType_FloatVector, TypeParams: []*commonpb.KeyValuePair{ {Key: "dim", Value: "4"}, - }}, + }, + }, }, } _, err := NewTextEmbeddingFunction(schema, &schemapb.FunctionSchema{ @@ -241,7 +246,7 @@ func (s *TextEmbeddingFunctionSuite) TestRunnerParamsErr() { {Key: modelNameParamKey, Value: "text-embedding-ada-002"}, {Key: dimParamKey, Value: "4"}, {Key: apiKeyParamKey, Value: "mock"}, - {Key: embeddingUrlParamKey, Value: "mock"}, + {Key: embeddingURLParamKey, Value: "mock"}, }, }) s.Error(err) @@ -261,7 +266,7 @@ func (s *TextEmbeddingFunctionSuite) TestRunnerParamsErr() { {Key: modelNameParamKey, Value: "text-embedding-ada-002"}, {Key: dimParamKey, Value: "4"}, {Key: apiKeyParamKey, Value: "mock"}, - {Key: embeddingUrlParamKey, Value: "mock"}, + {Key: embeddingURLParamKey, Value: "mock"}, }, }) s.Error(err) @@ -281,7 +286,7 @@ func (s *TextEmbeddingFunctionSuite) TestRunnerParamsErr() { {Key: modelNameParamKey, Value: "text-embedding-ada-004"}, {Key: dimParamKey, Value: "4"}, {Key: apiKeyParamKey, Value: "mock"}, - {Key: embeddingUrlParamKey, Value: "mock"}, + {Key: embeddingURLParamKey, Value: "mock"}, }, }) s.Error(err) diff --git a/internal/util/function/vertexai_embedding_provider.go b/internal/util/function/vertexai_embedding_provider.go index 1d9c997571dcf..8eb4ad52ff1f9 100644 --- a/internal/util/function/vertexai_embedding_provider.go +++ b/internal/util/function/vertexai_embedding_provider.go @@ -21,7 +21,6 @@ package function import ( "fmt" "os" - "strconv" "strings" "sync" "time" @@ -102,14 +101,10 @@ func NewVertextAIEmbeddingProvider(fieldSchema *schemapb.FieldSchema, functionSc case modelNameParamKey: modelName = param.Value case dimParamKey: - dim, err = strconv.ParseInt(param.Value, 10, 64) + dim, err = parseAndCheckFieldDim(param.Value, fieldDim, fieldSchema.Name) if err != nil { - return nil, fmt.Errorf("dim [%s] is not int", param.Value) - } - - if dim != 0 && dim != fieldDim { - return nil, 
fmt.Errorf("Field %s's dim is [%d], but embeding's dim is [%d]", functionSchema.Name, fieldDim, dim) - } + return nil, err + } case locationParamKey: location = param.Value case projectIDParamKey: diff --git a/internal/util/function/vertexai_embedding_provider_test.go b/internal/util/function/vertexai_embedding_provider_test.go index 2c18b133cc974..10a9093d69634 100644 --- a/internal/util/function/vertexai_embedding_provider_test.go +++ b/internal/util/function/vertexai_embedding_provider_test.go @@ -47,10 +47,12 @@ func (s *VertextAITextEmbeddingProviderSuite) SetupTest() { Fields: []*schemapb.FieldSchema{ {FieldID: 100, Name: "int64", DataType: schemapb.DataType_Int64}, {FieldID: 101, Name: "text", DataType: schemapb.DataType_VarChar}, - {FieldID: 102, Name: "vector", DataType: schemapb.DataType_FloatVector, + { + FieldID: 102, Name: "vector", DataType: schemapb.DataType_FloatVector, TypeParams: []*commonpb.KeyValuePair{ {Key: "dim", Value: "4"}, - }}, + }, + }, }, } } @@ -68,7 +70,7 @@ func createVertextAIProvider(url string, schema *schemapb.FieldSchema) (TextEmbe {Key: locationParamKey, Value: "mock_local"}, {Key: projectIDParamKey, Value: "mock_id"}, {Key: taskTypeParamKey, Value: vertexAICodeRetrival}, - {Key: embeddingUrlParamKey, Value: url}, + {Key: embeddingURLParamKey, Value: url}, {Key: dimParamKey, Value: "4"}, }, } @@ -95,7 +97,6 @@ func (s *VertextAITextEmbeddingProviderSuite) TestEmbedding() { ret, _ := provder.CallEmbedding(data, false, SearchMode) s.Equal([][]float32{{0.0, 0.1, 0.2, 0.3}, {1.0, 1.1, 1.2, 1.3}, {2.0, 2.1, 2.2, 2.3}}, ret) } - } func (s *VertextAITextEmbeddingProviderSuite) TestEmbeddingDimNotMatch() {