From 0e9e9ca321d652cc338b734f572eaa6335e3fb08 Mon Sep 17 00:00:00 2001 From: "junjie.jiang" Date: Fri, 27 Dec 2024 16:15:13 +0800 Subject: [PATCH] Polish code Signed-off-by: junjie.jiang --- .../ali/ali_dashscope_text_embedding.go | 9 ++- internal/models/openai/openai_embedding.go | 81 ++++++++----------- internal/models/utils/embedding_util.go | 5 +- .../vertexai/vertexai_text_embedding.go | 6 +- .../vertexai/vertexai_text_embedding_test.go | 6 +- internal/proxy/task_insert_test.go | 13 ++- internal/proxy/task_search_test.go | 18 +++-- internal/proxy/task_upsert_test.go | 13 ++- .../util/function/ali_embedding_provider.go | 11 +-- .../alitext_embedding_provider_test.go | 13 ++- .../function/bedrock_embedding_provider.go | 16 ++-- .../bedrock_text_embedding_provider_test.go | 5 +- internal/util/function/common.go | 19 ++++- internal/util/function/function_base.go | 10 +-- internal/util/function/function_executor.go | 16 ++-- .../util/function/function_executor_test.go | 17 ++-- internal/util/function/function_util.go | 16 ++-- .../util/function/mock_embedding_service.go | 5 +- .../function/openai_embedding_provider.go | 11 +-- .../openai_text_embedding_provider_test.go | 16 ++-- .../util/function/text_embedding_function.go | 8 +- .../function/text_embedding_function_test.go | 41 +++++----- .../function/vertexai_embedding_provider.go | 11 +-- .../vertexai_embedding_provider_test.go | 9 ++- 24 files changed, 187 insertions(+), 188 deletions(-) diff --git a/internal/models/ali/ali_dashscope_text_embedding.go b/internal/models/ali/ali_dashscope_text_embedding.go index 329451577f07f..ee412c6e992f6 100644 --- a/internal/models/ali/ali_dashscope_text_embedding.go +++ b/internal/models/ali/ali_dashscope_text_embedding.go @@ -83,6 +83,7 @@ func (eb *ByIndex) Len() int { return len(eb.resp.Output.Embeddings) } func (eb *ByIndex) Swap(i, j int) { eb.resp.Output.Embeddings[i], eb.resp.Output.Embeddings[j] = eb.resp.Output.Embeddings[j], eb.resp.Output.Embeddings[i] } + func (eb *ByIndex) Less(i, j int) bool { return eb.resp.Output.Embeddings[i].TextIndex < eb.resp.Output.Embeddings[j].TextIndex } @@ -116,20 +117,20 @@ func (c *AliDashScopeEmbedding) Check() error { return nil } -func (c *AliDashScopeEmbedding) Embedding(modelName string, texts []string, dim int, text_type string, output_type string, timeoutSec time.Duration) (*EmbeddingResponse, error) { +func (c *AliDashScopeEmbedding) Embedding(modelName string, texts []string, dim int, textType string, outputType string, timeoutSec time.Duration) (*EmbeddingResponse, error) { var r EmbeddingRequest r.Model = modelName r.Input = Input{texts} r.Parameters.Dimension = dim - r.Parameters.TextType = text_type - r.Parameters.OutputType = output_type + r.Parameters.TextType = textType + r.Parameters.OutputType = outputType data, err := json.Marshal(r) if err != nil { return nil, err } if timeoutSec <= 0 { - timeoutSec = 30 + timeoutSec = utils.DefaultTimeout } ctx, cancel := context.WithTimeout(context.Background(), timeoutSec*time.Second) diff --git a/internal/models/openai/openai_embedding.go b/internal/models/openai/openai_embedding.go index bb6f88be0cd18..433a95a2e32a7 100644 --- a/internal/models/openai/openai_embedding.go +++ b/internal/models/openai/openai_embedding.go @@ -135,20 +135,7 @@ func (c *openAIBase) genReq(modelName string, texts []string, dim int, user stri return &r } -type OpenAIEmbeddingClient struct { - openAIBase -} - -func NewOpenAIEmbeddingClient(apiKey string, url string) *OpenAIEmbeddingClient { - return &OpenAIEmbeddingClient{ - openAIBase{ - apiKey: apiKey, - url: url, - }, - } -} - -func (c *OpenAIEmbeddingClient) Embedding(modelName string, texts []string, dim int, user string, timeoutSec time.Duration) (*EmbeddingResponse, error) { +func (c *openAIBase) embedding(url string, headers map[string]string, modelName string, texts []string, dim int, user string, timeoutSec time.Duration) (*EmbeddingResponse, error) { r := c.genReq(modelName, texts, dim, user) data, err := json.Marshal(r) if err != nil { @@ -156,20 +143,23 @@ func (c *OpenAIEmbeddingClient) Embedding(modelName string, texts []string, dim } if timeoutSec <= 0 { - timeoutSec = 30 + timeoutSec = utils.DefaultTimeout } + ctx, cancel := context.WithTimeout(context.Background(), timeoutSec*time.Second) defer cancel() - req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.url, bytes.NewBuffer(data)) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewBuffer(data)) if err != nil { return nil, err } - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.apiKey)) + for key, value := range headers { + req.Header.Set(key, value) + } body, err := utils.RetrySend(req, 3) if err != nil { return nil, err } + var res EmbeddingResponse err = json.Unmarshal(body, &res) if err != nil { @@ -179,6 +169,27 @@ func (c *OpenAIEmbeddingClient) Embedding(modelName string, texts []string, dim return &res, err } +type OpenAIEmbeddingClient struct { + openAIBase +} + +func NewOpenAIEmbeddingClient(apiKey string, url string) *OpenAIEmbeddingClient { + return &OpenAIEmbeddingClient{ + openAIBase{ + apiKey: apiKey, + url: url, + }, + } +} + +func (c *OpenAIEmbeddingClient) Embedding(modelName string, texts []string, dim int, user string, timeoutSec time.Duration) (*EmbeddingResponse, error) { + headers := map[string]string{ + "Content-Type": "application/json", + "Authorization": fmt.Sprintf("Bearer %s", c.apiKey), + } + return c.embedding(c.url, headers, modelName, texts, dim, user, timeoutSec) +} + type AzureOpenAIEmbeddingClient struct { openAIBase apiVersion string @@ -195,16 +206,6 @@ func NewAzureOpenAIEmbeddingClient(apiKey string, url string) *AzureOpenAIEmbedd } func (c *AzureOpenAIEmbeddingClient) Embedding(modelName string, texts []string, dim int, user string, timeoutSec time.Duration) (*EmbeddingResponse, error) { - r := c.genReq(modelName, texts, dim, user) - data, err := json.Marshal(r) - if err != nil { - return nil, err - } - - if timeoutSec <= 0 { - timeoutSec = 30 - } - base, err := url.Parse(c.url) if err != nil { return nil, err @@ -214,25 +215,11 @@ func (c *AzureOpenAIEmbeddingClient) Embedding(modelName string, texts []string, params := url.Values{} params.Add("api-version", c.apiVersion) base.RawQuery = params.Encode() + url := base.String() - ctx, cancel := context.WithTimeout(context.Background(), timeoutSec*time.Second) - defer cancel() - req, err := http.NewRequestWithContext(ctx, http.MethodPost, base.String(), bytes.NewBuffer(data)) - if err != nil { - return nil, err - } - req.Header.Set("Content-Type", "application/json") - req.Header.Set("api-key", c.apiKey) - body, err := utils.RetrySend(req, 3) - if err != nil { - return nil, err - } - - var res EmbeddingResponse - err = json.Unmarshal(body, &res) - if err != nil { - return nil, err + headers := map[string]string{ + "Content-Type": "application/json", + "api-key": c.apiKey, } - sort.Sort(&ByIndex{&res}) - return &res, err + return c.embedding(url, headers, modelName, texts, dim, user, timeoutSec) } diff --git a/internal/models/utils/embedding_util.go b/internal/models/utils/embedding_util.go index 1d6e7d916cab2..1383d5740e814 100644 --- a/internal/models/utils/embedding_util.go +++ b/internal/models/utils/embedding_util.go @@ -20,8 +20,11 @@ import ( "fmt" "io" "net/http" + "time" ) +const DefaultTimeout time.Duration = 30 + func send(req *http.Request) ([]byte, error) { resp, err := http.DefaultClient.Do(req) if err != nil { @@ -34,7 +37,7 @@ func send(req *http.Request) ([]byte, error) { return nil, err } - if resp.StatusCode != 200 { + if resp.StatusCode != http.StatusOK { return nil, fmt.Errorf(string(body)) } return body, nil diff --git a/internal/models/vertexai/vertexai_text_embedding.go b/internal/models/vertexai/vertexai_text_embedding.go index 3842824616214..1a63c59961f81 100644 --- a/internal/models/vertexai/vertexai_text_embedding.go +++ b/internal/models/vertexai/vertexai_text_embedding.go @@ -24,9 +24,9 @@ import ( "net/http" "time" - "github.com/milvus-io/milvus/internal/models/utils" - "golang.org/x/oauth2/google" + + "github.com/milvus-io/milvus/internal/models/utils" ) type Instance struct { @@ -129,7 +129,7 @@ func (c *VertexAIEmbedding) Embedding(modelName string, texts []string, dim int6 } if timeoutSec <= 0 { - timeoutSec = 30 + timeoutSec = utils.DefaultTimeout } ctx, cancel := context.WithTimeout(context.Background(), timeoutSec*time.Second) diff --git a/internal/models/vertexai/vertexai_text_embedding_test.go b/internal/models/vertexai/vertexai_text_embedding_test.go index f138d659a3ea4..83b26ac4d4634 100644 --- a/internal/models/vertexai/vertexai_text_embedding_test.go +++ b/internal/models/vertexai/vertexai_text_embedding_test.go @@ -27,7 +27,7 @@ import ( ) func TestEmbeddingClientCheck(t *testing.T) { - mockJsonKey := []byte{1, 2, 3} + mockJSONKey := []byte{1, 2, 3} { c := NewVertexAIEmbedding("mock_url", []byte{}, "mock_scopes", "") err := c.Check() @@ -36,14 +36,14 @@ func TestEmbeddingClientCheck(t *testing.T) { } { - c := NewVertexAIEmbedding("", mockJsonKey, "", "") + c := NewVertexAIEmbedding("", mockJSONKey, "", "") err := c.Check() assert.True(t, err != nil) fmt.Println(err) } { - c := NewVertexAIEmbedding("mock_url", mockJsonKey, "mock_scopes", "") + c := NewVertexAIEmbedding("mock_url", mockJSONKey, "mock_scopes", "") err := c.Check() assert.True(t, err == nil) } diff --git a/internal/proxy/task_insert_test.go b/internal/proxy/task_insert_test.go index 2586ddf37afca..006d383be3146 100644 --- a/internal/proxy/task_insert_test.go +++ b/internal/proxy/task_insert_test.go @@ -340,14 +340,19 @@ func TestInsertTask_Function(t *testing.T) { AutoID: true, Fields: []*schemapb.FieldSchema{ {FieldID: 100, Name: "id", DataType: schemapb.DataType_Int64, IsPrimaryKey: true, AutoID: true}, - {FieldID: 101, Name: "text", DataType: schemapb.DataType_VarChar, + { + FieldID: 101, Name: "text", DataType: schemapb.DataType_VarChar, TypeParams: []*commonpb.KeyValuePair{ {Key: "max_length", Value: "200"}, - }}, - {FieldID: 102, Name: "vector", DataType: schemapb.DataType_FloatVector, + }, + }, + { + FieldID: 102, Name: "vector", DataType: schemapb.DataType_FloatVector, TypeParams: []*commonpb.KeyValuePair{ {Key: "dim", Value: "4"}, - }, IsFunctionOutput: true}, + }, + IsFunctionOutput: true, + }, }, Functions: []*schemapb.FunctionSchema{ { diff --git a/internal/proxy/task_search_test.go b/internal/proxy/task_search_test.go index 4af9d69a82a67..b255027cf7316 100644 --- a/internal/proxy/task_search_test.go +++ b/internal/proxy/task_search_test.go @@ -487,18 +487,24 @@ func TestSearchTask_WithFunctions(t *testing.T) { AutoID: true, Fields: []*schemapb.FieldSchema{ {FieldID: 100, Name: "id", DataType: schemapb.DataType_Int64, IsPrimaryKey: true, AutoID: true}, - {FieldID: 101, Name: "text", DataType: schemapb.DataType_VarChar, + { + FieldID: 101, Name: "text", DataType: schemapb.DataType_VarChar, TypeParams: []*commonpb.KeyValuePair{ {Key: "max_length", Value: "200"}, - }}, - {FieldID: 102, Name: "vector1", DataType: schemapb.DataType_FloatVector, + }, + }, + { + FieldID: 102, Name: "vector1", DataType: schemapb.DataType_FloatVector, TypeParams: []*commonpb.KeyValuePair{ {Key: "dim", Value: "4"}, - }}, - {FieldID: 103, Name: "vector2", DataType: schemapb.DataType_FloatVector, + }, + }, + { + FieldID: 103, Name: "vector2", DataType: schemapb.DataType_FloatVector, TypeParams: []*commonpb.KeyValuePair{ {Key: "dim", Value: "4"}, - }}, + }, + }, }, Functions: []*schemapb.FunctionSchema{ { diff --git a/internal/proxy/task_upsert_test.go b/internal/proxy/task_upsert_test.go index 348ee64313d31..da0b3595cc45e 100644 --- a/internal/proxy/task_upsert_test.go +++ b/internal/proxy/task_upsert_test.go @@ -409,14 +409,19 @@ func TestUpsertTask_Function(t *testing.T) { AutoID: true, Fields: []*schemapb.FieldSchema{ {FieldID: 100, Name: "id", DataType: schemapb.DataType_Int64, IsPrimaryKey: true, AutoID: true}, - {FieldID: 101, Name: "text", DataType: schemapb.DataType_VarChar, + { + FieldID: 101, Name: "text", DataType: schemapb.DataType_VarChar, TypeParams: []*commonpb.KeyValuePair{ {Key: "max_length", Value: "200"}, - }}, - {FieldID: 102, Name: "vector", DataType: schemapb.DataType_FloatVector, + }, + }, + { + FieldID: 102, Name: "vector", DataType: schemapb.DataType_FloatVector, TypeParams: []*commonpb.KeyValuePair{ {Key: "dim", Value: "4"}, - }, IsFunctionOutput: true}, + }, + IsFunctionOutput: true, + }, }, Functions: []*schemapb.FunctionSchema{ { diff --git a/internal/util/function/ali_embedding_provider.go b/internal/util/function/ali_embedding_provider.go index 2def771d86971..966c530522e16 100644 --- a/internal/util/function/ali_embedding_provider.go +++ b/internal/util/function/ali_embedding_provider.go @@ -21,7 +21,6 @@ package function import ( "fmt" "os" - "strconv" "strings" "time" @@ -70,17 +69,13 @@ func NewAliDashScopeEmbeddingProvider(fieldSchema *schemapb.FieldSchema, functio case modelNameParamKey: modelName = param.Value case dimParamKey: - dim, err = strconv.ParseInt(param.Value, 10, 64) + dim, err = parseAndCheckFieldDim(param.Value, fieldDim, fieldSchema.Name) if err != nil { - return nil, fmt.Errorf("dim [%s] is not int", param.Value) - } - - if dim != 0 && dim != fieldDim { - return nil, fmt.Errorf("Field %s's dim is [%d], but embeding's dim is [%d]", functionSchema.Name, fieldDim, dim) + return nil, err } case apiKeyParamKey: apiKey = param.Value - case embeddingUrlParamKey: + case embeddingURLParamKey: url = param.Value default: } diff --git a/internal/util/function/alitext_embedding_provider_test.go b/internal/util/function/alitext_embedding_provider_test.go index 73d8613f20a4d..a852b1b74e6ab 100644 --- a/internal/util/function/alitext_embedding_provider_test.go +++ b/internal/util/function/alitext_embedding_provider_test.go @@ -25,12 +25,11 @@ import ( "net/http/httptest" "testing" - "github.com/stretchr/testify/suite" - "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" - "github.com/milvus-io/milvus/internal/models/ali" + + "github.com/stretchr/testify/suite" ) func TestAliTextEmbeddingProvider(t *testing.T) { @@ -52,7 +51,8 @@ func (s *AliTextEmbeddingProviderSuite) SetupTest() { {FieldID: 102, Name: "vector", DataType: schemapb.DataType_FloatVector, TypeParams: []*commonpb.KeyValuePair{ {Key: "dim", Value: "4"}, - }}, + }, + }, }, } s.providers = []string{AliDashScopeProvider} @@ -69,7 +69,7 @@ func createAliProvider(url string, schema *schemapb.FieldSchema, providerName st Params: []*commonpb.KeyValuePair{ {Key: modelNameParamKey, Value: TextEmbeddingV3}, {Key: apiKeyParamKey, Value: "mock"}, - {Key: embeddingUrlParamKey, Value: url}, + {Key: embeddingURLParamKey, Value: url}, {Key: dimParamKey, Value: "4"}, }, } @@ -101,7 +101,6 @@ func (s *AliTextEmbeddingProviderSuite) TestEmbedding() { ret, _ := provder.CallEmbedding(data, false, SearchMode) s.Equal([][]float32{{0.0, 0.1, 0.2, 0.3}, {1.0, 1.1, 1.2, 1.3}, {2.0, 2.1, 2.2, 2.3}}, ret) } - } } @@ -134,7 +133,6 @@ func (s *AliTextEmbeddingProviderSuite) TestEmbeddingDimNotMatch() { data := []string{"sentence", "sentence"} _, err2 := provder.CallEmbedding(data, false, InsertMode) s.Error(err2) - } } @@ -163,6 +161,5 @@ func (s *AliTextEmbeddingProviderSuite) TestEmbeddingNubmerNotMatch() { data := []string{"sentence", "sentence2"} _, err2 := provder.CallEmbedding(data, false, InsertMode) s.Error(err2) - } } diff --git a/internal/util/function/bedrock_embedding_provider.go b/internal/util/function/bedrock_embedding_provider.go index f9d4d184e4c8e..eb54712ce5499 100644 --- a/internal/util/function/bedrock_embedding_provider.go +++ b/internal/util/function/bedrock_embedding_provider.go @@ -23,16 +23,15 @@ import ( "encoding/json" "fmt" "os" - "strconv" "strings" - "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" - "github.com/milvus-io/milvus/pkg/util/typeutil" - "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/config" "github.com/aws/aws-sdk-go-v2/credentials" "github.com/aws/aws-sdk-go-v2/service/bedrockruntime" + + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/pkg/util/typeutil" ) type BedrockClient interface { @@ -94,13 +93,9 @@ func NewBedrockEmbeddingProvider(fieldSchema *schemapb.FieldSchema, functionSche case modelNameParamKey: modelName = param.Value case dimParamKey: - dim, err = strconv.ParseInt(param.Value, 10, 64) + dim, err = parseAndCheckFieldDim(param.Value, fieldDim, fieldSchema.Name) if err != nil { - return nil, fmt.Errorf("dim [%s] is not int", param.Value) - } - - if dim != 0 && dim != fieldDim { - return nil, fmt.Errorf("Field %s's dim is [%d], but embeding's dim is [%d]", functionSchema.Name, fieldDim, dim) + return nil, err } case awsAccessKeyIdParamKey: awsAccessKeyId = param.Value @@ -178,7 +173,6 @@ func (provider *BedrockEmbeddingProvider) CallEmbedding(texts []string, batchLim ModelId: aws.String(provider.modelName), ContentType: aws.String("application/json"), }) - if err != nil { return nil, err } diff --git a/internal/util/function/bedrock_text_embedding_provider_test.go b/internal/util/function/bedrock_text_embedding_provider_test.go index eb26aa03ef3f7..e8f08df77e8d1 100644 --- a/internal/util/function/bedrock_text_embedding_provider_test.go +++ b/internal/util/function/bedrock_text_embedding_provider_test.go @@ -47,7 +47,8 @@ func (s *BedrockTextEmbeddingProviderSuite) SetupTest() { {FieldID: 102, Name: "vector", DataType: schemapb.DataType_FloatVector, TypeParams: []*commonpb.KeyValuePair{ {Key: "dim", Value: "4"}, - }}, + }, + }, }, } s.providers = []string{BedrockProvider} @@ -92,7 +93,6 @@ func (s *BedrockTextEmbeddingProviderSuite) TestEmbedding() { ret, _ := provder.CallEmbedding(data, false, SearchMode) s.Equal([][]float32{{0.0, 0.1, 0.2, 0.3}, {0.0, 0.1, 0.2, 0.3}, {0.0, 0.1, 0.2, 0.3}}, ret) } - } } @@ -105,6 +105,5 @@ func (s *BedrockTextEmbeddingProviderSuite) TestEmbeddingDimNotMatch() { data := []string{"sentence", "sentence"} _, err2 := provder.CallEmbedding(data, false, InsertMode) s.Error(err2) - } } diff --git a/internal/util/function/common.go b/internal/util/function/common.go index 56da30e5ed42f..a6c4fe4840b0e 100644 --- a/internal/util/function/common.go +++ b/internal/util/function/common.go @@ -18,6 +18,11 @@ package function +import ( + "fmt" + "strconv" +) + const ( InsertMode string = "Insert" SearchMode string = "Search" @@ -27,7 +32,7 @@ const ( const ( modelNameParamKey string = "model_name" dimParamKey string = "dim" - embeddingUrlParamKey string = "url" + embeddingURLParamKey string = "url" apiKeyParamKey string = "api_key" ) @@ -80,3 +85,15 @@ const ( vertexServiceAccountJSONEnv string = "MILVUSAI_GOOGLE_APPLICATION_CREDENTIALS" ) + +func parseAndCheckFieldDim(dimStr string, fieldDim int64, fieldName string) (int64, error) { + dim, err := strconv.ParseInt(dimStr, 10, 64) + if err != nil { + return 0, fmt.Errorf("dim [%s] is not int", dimStr) + } + + if dim != 0 && dim != fieldDim { + return 0, fmt.Errorf("Field %s's dim is [%d], but embedding's dim is [%d]", fieldName, fieldDim, dim) + } + return dim, nil +} diff --git a/internal/util/function/function_base.go b/internal/util/function/function_base.go index fa3253bf3c7bd..aabcfdf5c0ea2 100644 --- a/internal/util/function/function_base.go +++ b/internal/util/function/function_base.go @@ -29,10 +29,10 @@ type FunctionBase struct { outputFields []*schemapb.FieldSchema } -func NewFunctionBase(coll *schemapb.CollectionSchema, f_schema *schemapb.FunctionSchema) (*FunctionBase, error) { +func NewFunctionBase(coll *schemapb.CollectionSchema, fSchema *schemapb.FunctionSchema) (*FunctionBase, error) { var base FunctionBase - base.schema = f_schema - for _, fieldName := range f_schema.GetOutputFieldNames() { + base.schema = fSchema + for _, fieldName := range fSchema.GetOutputFieldNames() { for _, field := range coll.GetFields() { if field.GetName() == fieldName { base.outputFields = append(base.outputFields, field) @@ -41,9 +41,9 @@ func NewFunctionBase(coll *schemapb.CollectionSchema, f_schema *schemapb.Functio } } - if len(base.outputFields) != len(f_schema.GetOutputFieldNames()) { + if len(base.outputFields) != len(fSchema.GetOutputFieldNames()) { return &base, fmt.Errorf("The collection [%s]'s information is wrong, function [%s]'s outputs does not match the schema", - coll.Name, f_schema.Name) + coll.Name, fSchema.Name) } return &base, nil } diff --git a/internal/util/function/function_executor.go b/internal/util/function/function_executor.go index 6f2469cca9173..011380ae5a226 100644 --- a/internal/util/function/function_executor.go +++ b/internal/util/function/function_executor.go @@ -62,8 +62,8 @@ func createFunction(coll *schemapb.CollectionSchema, schema *schemapb.FunctionSc } func CheckFunctions(schema *schemapb.CollectionSchema) error { - for _, f_schema := range schema.Functions { - if _, err := createFunction(schema, f_schema); err != nil { + for _, fSchema := range schema.Functions { + if _, err := createFunction(schema, fSchema); err != nil { return err } } @@ -77,12 +77,12 @@ func NewFunctionExecutor(schema *schemapb.CollectionSchema) (*FunctionExecutor, executor := &FunctionExecutor{ runners: make(map[int64]Runner), } - for _, f_schema := range schema.Functions { - if runner, err := createFunction(schema, f_schema); err != nil { + for _, fSchema := range schema.Functions { + if runner, err := createFunction(schema, fSchema); err != nil { return nil, err } else { if runner != nil { - executor.runners[f_schema.GetOutputFieldIds()[0]] = runner + executor.runners[fSchema.GetOutputFieldIds()[0]] = runner } } } @@ -200,7 +200,6 @@ func (executor *FunctionExecutor) prcessAdvanceSearch(req *internalpb.SearchRequ } else { outputs <- map[int64][]byte{idx: newHolder} } - }(runner, int64(idx)) } } @@ -222,9 +221,8 @@ func (executor *FunctionExecutor) prcessAdvanceSearch(req *internalpb.SearchRequ func (executor *FunctionExecutor) ProcessSearch(req *internalpb.SearchRequest) error { if !req.IsAdvanced { return executor.prcessSearch(req) - } else { - return executor.prcessAdvanceSearch(req) - } + } + return executor.prcessAdvanceSearch(req) } func (executor *FunctionExecutor) processSingleBulkInsert(runner Runner, data *storage.InsertData) (map[storage.FieldID]storage.FieldData, error) { diff --git a/internal/util/function/function_executor_test.go b/internal/util/function/function_executor_test.go index a791760351fb7..a38360d343660 100644 --- a/internal/util/function/function_executor_test.go +++ b/internal/util/function/function_executor_test.go @@ -49,16 +49,20 @@ func (s *FunctionExecutorSuite) creataSchema(url string) *schemapb.CollectionSch Fields: []*schemapb.FieldSchema{ {FieldID: 100, Name: "int64", DataType: schemapb.DataType_Int64}, {FieldID: 101, Name: "text", DataType: schemapb.DataType_VarChar}, - {FieldID: 102, Name: "vector", DataType: schemapb.DataType_FloatVector, + { + FieldID: 102, Name: "vector", DataType: schemapb.DataType_FloatVector, TypeParams: []*commonpb.KeyValuePair{ {Key: "dim", Value: "4"}, }, IsFunctionOutput: true, }, - {FieldID: 103, Name: "vector2", DataType: schemapb.DataType_FloatVector, + { + FieldID: 103, Name: "vector2", DataType: schemapb.DataType_FloatVector, TypeParams: []*commonpb.KeyValuePair{ {Key: "dim", Value: "8"}, - }, IsFunctionOutput: true}, + }, + IsFunctionOutput: true, + }, }, Functions: []*schemapb.FunctionSchema{ { @@ -72,7 +76,7 @@ func (s *FunctionExecutorSuite) creataSchema(url string) *schemapb.CollectionSch {Key: Provider, Value: OpenAIProvider}, {Key: modelNameParamKey, Value: "text-embedding-ada-002"}, {Key: apiKeyParamKey, Value: "mock"}, - {Key: embeddingUrlParamKey, Value: url}, + {Key: embeddingURLParamKey, Value: url}, {Key: dimParamKey, Value: "4"}, }, }, @@ -87,17 +91,15 @@ func (s *FunctionExecutorSuite) creataSchema(url string) *schemapb.CollectionSch {Key: Provider, Value: OpenAIProvider}, {Key: modelNameParamKey, Value: "text-embedding-ada-002"}, {Key: apiKeyParamKey, Value: "mock"}, - {Key: embeddingUrlParamKey, Value: url}, + {Key: embeddingURLParamKey, Value: url}, {Key: dimParamKey, Value: "8"}, }, }, }, } - } func (s *FunctionExecutorSuite) createMsg(texts []string) *msgstream.InsertMsg { - data := []*schemapb.FieldData{} f := schemapb.FieldData{ Type: schemapb.DataType_VarChar, @@ -173,7 +175,6 @@ func (s *FunctionExecutorSuite) TestErrorEmbedding() { w.WriteHeader(http.StatusOK) data, _ := json.Marshal(res) w.Write(data) - })) defer ts.Close() schema := s.creataSchema(ts.URL) diff --git a/internal/util/function/function_util.go b/internal/util/function/function_util.go index bd0265336baa7..240e13615b4f7 100644 --- a/internal/util/function/function_util.go +++ b/internal/util/function/function_util.go @@ -26,8 +26,8 @@ import ( func HasFunctions(functions []*schemapb.FunctionSchema, outputIDs []int64) bool { // Determine whether the column corresponding to outputIDs contains functions, except bm25 function, // if outputIDs is empty, check all cols - for _, f_schema := range functions { - switch f_schema.GetType() { + for _, fSchema := range functions { + switch fSchema.GetType() { case schemapb.FunctionType_BM25: case schemapb.FunctionType_Unknown: default: @@ -35,7 +35,7 @@ func HasFunctions(functions []*schemapb.FunctionSchema, outputIDs []int64) bool return true } else { for _, id := range outputIDs { - if f_schema.GetOutputFieldIds()[0] == id { + if fSchema.GetOutputFieldIds()[0] == id { return true } } @@ -47,14 +47,14 @@ func HasFunctions(functions []*schemapb.FunctionSchema, outputIDs []int64) bool func GetOutputIDFunctionsMap(functions []*schemapb.FunctionSchema) (map[int64]*schemapb.FunctionSchema, error) { outputIdMap := map[int64]*schemapb.FunctionSchema{} - for _, f_schema := range functions { - switch f_schema.GetType() { + for _, fSchema := range functions { + switch fSchema.GetType() { case schemapb.FunctionType_BM25: default: - if len(f_schema.OutputFieldIds) != 1 { - return nil, merr.WrapErrParameterInvalidMsg("Function [%s]'s outputs err, only supports one outputs", f_schema.Name) + if len(fSchema.OutputFieldIds) != 1 { + return nil, merr.WrapErrParameterInvalidMsg("Function [%s]'s outputs err, only supports one outputs", fSchema.Name) } - outputIdMap[f_schema.OutputFieldIds[0]] = f_schema + outputIdMap[fSchema.OutputFieldIds[0]] = fSchema } } return outputIdMap, nil diff --git a/internal/util/function/mock_embedding_service.go b/internal/util/function/mock_embedding_service.go index 4cb181a7a0c4f..c071a2056df72 100644 --- a/internal/util/function/mock_embedding_service.go +++ b/internal/util/function/mock_embedding_service.go @@ -25,10 +25,11 @@ import ( "net/http" "net/http/httptest" - "github.com/aws/aws-sdk-go-v2/service/bedrockruntime" "github.com/milvus-io/milvus/internal/models/ali" "github.com/milvus-io/milvus/internal/models/openai" "github.com/milvus-io/milvus/internal/models/vertexai" + + "github.com/aws/aws-sdk-go-v2/service/bedrockruntime" ) func mockEmbedding(texts []string, dim int) [][]float32 { @@ -69,7 +70,6 @@ func CreateOpenAIEmbeddingServer() *httptest.Server { w.WriteHeader(http.StatusOK) data, _ := json.Marshal(res) w.Write(data) - })) return ts } @@ -129,7 +129,6 @@ func CreateVertexAIEmbeddingServer() *httptest.Server { w.WriteHeader(http.StatusOK) data, _ := json.Marshal(res) w.Write(data) - })) return ts } diff --git a/internal/util/function/openai_embedding_provider.go b/internal/util/function/openai_embedding_provider.go index 32cfb945509f7..8b5f53bc7fd2c 100644 --- a/internal/util/function/openai_embedding_provider.go +++ b/internal/util/function/openai_embedding_provider.go @@ -21,7 +21,6 @@ package function import ( "fmt" "os" - "strconv" "strings" "time" @@ -89,19 +88,15 @@ func newOpenAIEmbeddingProvider(fieldSchema *schemapb.FieldSchema, functionSchem case modelNameParamKey: modelName = param.Value case dimParamKey: - dim, err = strconv.ParseInt(param.Value, 10, 64) + dim, err = parseAndCheckFieldDim(param.Value, fieldDim, fieldSchema.Name) if err != nil { - return nil, fmt.Errorf("dim [%s] is not int", param.Value) - } - - if dim != 0 && dim != fieldDim { - return nil, fmt.Errorf("Field %s's dim is [%d], but embeding's dim is [%d]", fieldSchema.Name, fieldDim, dim) + return nil, err } case userParamKey: user = param.Value case apiKeyParamKey: apiKey = param.Value - case embeddingUrlParamKey: + case embeddingURLParamKey: url = param.Value default: } diff --git a/internal/util/function/openai_text_embedding_provider_test.go b/internal/util/function/openai_text_embedding_provider_test.go index 395ecf06cdc9d..09b120e0603d5 100644 --- a/internal/util/function/openai_text_embedding_provider_test.go +++ b/internal/util/function/openai_text_embedding_provider_test.go @@ -25,12 +25,11 @@ import ( "net/http/httptest" "testing" - "github.com/stretchr/testify/suite" - "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" - "github.com/milvus-io/milvus/internal/models/openai" + + "github.com/stretchr/testify/suite" ) func TestOpenAITextEmbeddingProvider(t *testing.T) { @@ -49,10 +48,12 @@ func (s *OpenAITextEmbeddingProviderSuite) SetupTest() { Fields: []*schemapb.FieldSchema{ {FieldID: 100, Name: "int64", DataType: schemapb.DataType_Int64}, {FieldID: 101, Name: "text", DataType: schemapb.DataType_VarChar}, - {FieldID: 102, Name: "vector", DataType: schemapb.DataType_FloatVector, + { + FieldID: 102, Name: "vector", DataType: schemapb.DataType_FloatVector, TypeParams: []*commonpb.KeyValuePair{ {Key: "dim", Value: "4"}, - }}, + }, + }, }, } s.providers = []string{OpenAIProvider, AzureOpenAIProvider} @@ -70,7 +71,7 @@ func createOpenAIProvider(url string, schema *schemapb.FieldSchema, providerName {Key: modelNameParamKey, Value: "text-embedding-ada-002"}, {Key: apiKeyParamKey, Value: "mock"}, {Key: dimParamKey, Value: "4"}, - {Key: embeddingUrlParamKey, Value: url}, + {Key: embeddingURLParamKey, Value: url}, }, } switch providerName { @@ -103,7 +104,6 @@ func (s *OpenAITextEmbeddingProviderSuite) TestEmbedding() { ret, _ := provder.CallEmbedding(data, false, SearchMode) s.Equal([][]float32{{0.0, 0.1, 0.2, 0.3}, {1.0, 1.1, 1.2, 1.3}, {2.0, 2.1, 2.2, 2.3}}, ret) } - } } @@ -141,7 +141,6 @@ func (s *OpenAITextEmbeddingProviderSuite) TestEmbeddingDimNotMatch() { data := []string{"sentence", "sentence"} _, err2 := provder.CallEmbedding(data, false, InsertMode) s.Error(err2) - } } @@ -174,6 +173,5 @@ func (s *OpenAITextEmbeddingProviderSuite) TestEmbeddingNubmerNotMatch() { data := []string{"sentence", "sentence2"} _, err2 := provder.CallEmbedding(data, false, InsertMode) s.Error(err2) - } } diff --git a/internal/util/function/text_embedding_function.go b/internal/util/function/text_embedding_function.go index d359514d6cfbc..030679df812fb 100644 --- a/internal/util/function/text_embedding_function.go +++ b/internal/util/function/text_embedding_function.go @@ -26,7 +26,6 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus/internal/storage" "github.com/milvus-io/milvus/pkg/util/funcutil" - // "github.com/milvus-io/milvus/pkg/util/typeutil" ) const ( @@ -134,7 +133,6 @@ func NewTextEmbeddingFunction(coll *schemapb.CollectionSchema, functionSchema *s default: return nil, fmt.Errorf("Unsupported embedding service provider: [%s] , list of supported [%s, %s, %s, %s]", provider, OpenAIProvider, AzureOpenAIProvider, AliDashScopeProvider, BedrockProvider) } - } func (runner *TextEmebddingFunction) MaxBatch() int { @@ -147,7 +145,7 @@ func (runner *TextEmebddingFunction) ProcessInsert(inputs []*schemapb.FieldData) } if inputs[0].Type != schemapb.DataType_VarChar { - return nil, fmt.Errorf("Text embedding only supports varchar field, the input is not varchar") + return nil, fmt.Errorf("Text embedding only supports varchar field as input field, but got %s", schemapb.DataType_name[int32(inputs[0].Type)]) } texts := inputs[0].GetScalars().GetStringData().GetData() @@ -193,11 +191,11 @@ func (runner *TextEmebddingFunction) ProcessSearch(placeholderGroup *commonpb.Pl func (runner *TextEmebddingFunction) ProcessBulkInsert(inputs []storage.FieldData) (map[storage.FieldID]storage.FieldData, error) { if len(inputs) != 1 { - return nil, fmt.Errorf("OpenAIEmbedding function only receives one input, bug got [%d]", len(inputs)) + return nil, fmt.Errorf("TextEmbedding function only receives one input, bug got [%d]", len(inputs)) } if inputs[0].GetDataType() != schemapb.DataType_VarChar { - return nil, fmt.Errorf("OpenAIEmbedding only supports varchar field, the input is not varchar") + return nil, fmt.Errorf(" only supports varchar field, the input is not varchar") } texts, ok := inputs[0].GetDataRows().([]string) diff --git a/internal/util/function/text_embedding_function_test.go b/internal/util/function/text_embedding_function_test.go index ce0bfc86dbf51..353684e55b838 100644 --- a/internal/util/function/text_embedding_function_test.go +++ b/internal/util/function/text_embedding_function_test.go @@ -42,10 +42,12 @@ func (s *TextEmbeddingFunctionSuite) SetupTest() { Fields: []*schemapb.FieldSchema{ {FieldID: 100, Name: "int64", DataType: schemapb.DataType_Int64}, {FieldID: 101, Name: "text", DataType: schemapb.DataType_VarChar}, - {FieldID: 102, Name: "vector", DataType: schemapb.DataType_FloatVector, + { + FieldID: 102, Name: "vector", DataType: schemapb.DataType_FloatVector, TypeParams: []*commonpb.KeyValuePair{ {Key: "dim", Value: "4"}, - }}, + }, + }, }, } } @@ -74,7 +76,6 @@ func (s *TextEmbeddingFunctionSuite) TestOpenAIEmbedding() { ts := CreateOpenAIEmbeddingServer() defer ts.Close() { - runner, err := NewTextEmbeddingFunction(s.schema, &schemapb.FunctionSchema{ Name: "test", Type: schemapb.FunctionType_Unknown, @@ -87,7 +88,7 @@ func (s *TextEmbeddingFunctionSuite) TestOpenAIEmbedding() { {Key: modelNameParamKey, Value: "text-embedding-ada-002"}, {Key: dimParamKey, Value: "4"}, {Key: apiKeyParamKey, Value: "mock"}, - {Key: embeddingUrlParamKey, Value: ts.URL}, + {Key: embeddingURLParamKey, Value: ts.URL}, }, }) s.NoError(err) @@ -106,7 +107,6 @@ func (s *TextEmbeddingFunctionSuite) TestOpenAIEmbedding() { s.Equal([]float32{0.0, 0.1, 0.2, 0.3, 1.0, 1.1, 1.2, 1.3, 2.0, 2.1, 2.2, 2.3}, ret[0].GetVectors().GetFloatVector().Data) } } - { runner, err := NewTextEmbeddingFunction(s.schema, &schemapb.FunctionSchema{ @@ -121,7 +121,7 @@ func (s *TextEmbeddingFunctionSuite) TestOpenAIEmbedding() { {Key: modelNameParamKey, Value: "text-embedding-ada-002"}, {Key: dimParamKey, Value: "4"}, {Key: apiKeyParamKey, Value: "mock"}, - {Key: embeddingUrlParamKey, Value: ts.URL}, + {Key: embeddingURLParamKey, Value: ts.URL}, }, }) s.NoError(err) @@ -158,7 +158,7 @@ func (s *TextEmbeddingFunctionSuite) TestAliEmbedding() { {Key: modelNameParamKey, Value: TextEmbeddingV3}, {Key: dimParamKey, Value: "4"}, {Key: apiKeyParamKey, Value: "mock"}, - {Key: embeddingUrlParamKey, Value: ts.URL}, + {Key: embeddingURLParamKey, Value: ts.URL}, }, }) s.NoError(err) @@ -176,7 +176,6 @@ func (s *TextEmbeddingFunctionSuite) TestAliEmbedding() { ret, _ := runner.ProcessInsert(data) s.Equal([]float32{0.0, 0.1, 0.2, 0.3, 1.0, 1.1, 1.2, 1.3, 2.0, 2.1, 2.2, 2.3}, ret[0].GetVectors().GetFloatVector().Data) } - } func (s *TextEmbeddingFunctionSuite) TestRunnerParamsErr() { @@ -187,10 +186,12 @@ func (s *TextEmbeddingFunctionSuite) TestRunnerParamsErr() { Fields: []*schemapb.FieldSchema{ {FieldID: 100, Name: "int64", DataType: schemapb.DataType_Int64}, {FieldID: 101, Name: "text", DataType: schemapb.DataType_VarChar}, - {FieldID: 102, Name: "vector", DataType: schemapb.DataType_BFloat16Vector, + { + FieldID: 102, Name: "vector", DataType: schemapb.DataType_BFloat16Vector, TypeParams: []*commonpb.KeyValuePair{ {Key: "dim", Value: "4"}, - }}, + }, + }, }, } @@ -206,7 +207,7 @@ func (s *TextEmbeddingFunctionSuite) TestRunnerParamsErr() { {Key: modelNameParamKey, Value: "text-embedding-ada-002"}, {Key: dimParamKey, Value: "4"}, {Key: apiKeyParamKey, Value: "mock"}, - {Key: embeddingUrlParamKey, Value: "mock"}, + {Key: embeddingURLParamKey, Value: "mock"}, }, }) s.Error(err) @@ -219,14 +220,18 @@ func (s *TextEmbeddingFunctionSuite) TestRunnerParamsErr() { Fields: []*schemapb.FieldSchema{ {FieldID: 100, Name: "int64", DataType: schemapb.DataType_Int64}, {FieldID: 101, Name: "text", DataType: schemapb.DataType_VarChar}, - {FieldID: 102, Name: "vector", DataType: schemapb.DataType_FloatVector, + { + FieldID: 102, Name: "vector", DataType: schemapb.DataType_FloatVector, TypeParams: []*commonpb.KeyValuePair{ {Key: "dim", Value: "4"}, - }}, - {FieldID: 103, Name: "vector2", DataType: schemapb.DataType_FloatVector, + }, + }, + { + FieldID: 103, Name: "vector2", DataType: schemapb.DataType_FloatVector, TypeParams: []*commonpb.KeyValuePair{ {Key: "dim", Value: "4"}, - }}, + }, + }, }, } _, err := NewTextEmbeddingFunction(schema, &schemapb.FunctionSchema{ @@ -241,7 +246,7 @@ func (s *TextEmbeddingFunctionSuite) TestRunnerParamsErr() { {Key: modelNameParamKey, Value: "text-embedding-ada-002"}, {Key: dimParamKey, Value: "4"}, {Key: apiKeyParamKey, Value: "mock"}, - {Key: embeddingUrlParamKey, Value: "mock"}, + {Key: embeddingURLParamKey, Value: "mock"}, }, }) s.Error(err) @@ -261,7 +266,7 @@ func (s *TextEmbeddingFunctionSuite) TestRunnerParamsErr() { {Key: modelNameParamKey, Value: "text-embedding-ada-002"}, {Key: dimParamKey, Value: "4"}, {Key: apiKeyParamKey, Value: "mock"}, - {Key: embeddingUrlParamKey, Value: "mock"}, + {Key: embeddingURLParamKey, Value: "mock"}, }, }) s.Error(err) @@ -281,7 +286,7 @@ func (s *TextEmbeddingFunctionSuite) TestRunnerParamsErr() { {Key: modelNameParamKey, Value: "text-embedding-ada-004"}, {Key: dimParamKey, Value: "4"}, {Key: apiKeyParamKey, Value: "mock"}, - {Key: embeddingUrlParamKey, Value: "mock"}, + {Key: embeddingURLParamKey, Value: "mock"}, }, }) s.Error(err) diff --git a/internal/util/function/vertexai_embedding_provider.go b/internal/util/function/vertexai_embedding_provider.go index 1d9c997571dcf..8eb4ad52ff1f9 100644 --- a/internal/util/function/vertexai_embedding_provider.go +++ b/internal/util/function/vertexai_embedding_provider.go @@ -21,7 +21,6 @@ package function import ( "fmt" "os" - "strconv" "strings" "sync" "time" @@ -102,14 +101,10 @@ func NewVertextAIEmbeddingProvider(fieldSchema *schemapb.FieldSchema, functionSc case modelNameParamKey: modelName = param.Value case dimParamKey: - dim, err = strconv.ParseInt(param.Value, 10, 64) + dim, err = parseAndCheckFieldDim(param.Value, fieldDim, fieldSchema.Name) if err != nil { - return nil, fmt.Errorf("dim [%s] is not int", param.Value) - } - - if dim != 0 && dim != fieldDim { - return nil, fmt.Errorf("Field %s's dim is [%d], but embeding's dim is [%d]", functionSchema.Name, fieldDim, dim) - } + return nil, err + } case locationParamKey: location = param.Value case projectIDParamKey: diff --git a/internal/util/function/vertexai_embedding_provider_test.go b/internal/util/function/vertexai_embedding_provider_test.go index 2c18b133cc974..10a9093d69634 100644 --- a/internal/util/function/vertexai_embedding_provider_test.go +++ b/internal/util/function/vertexai_embedding_provider_test.go @@ -47,10 +47,12 @@ func (s *VertextAITextEmbeddingProviderSuite) SetupTest() { Fields: []*schemapb.FieldSchema{ {FieldID: 100, Name: "int64", DataType: schemapb.DataType_Int64}, {FieldID: 101, Name: "text", DataType: schemapb.DataType_VarChar}, - {FieldID: 102, Name: "vector", DataType: schemapb.DataType_FloatVector, + { + FieldID: 102, Name: "vector", DataType: schemapb.DataType_FloatVector, TypeParams: []*commonpb.KeyValuePair{ {Key: "dim", Value: "4"}, - }}, + }, + }, }, } } @@ -68,7 +70,7 @@ func createVertextAIProvider(url string, schema *schemapb.FieldSchema) (TextEmbe {Key: locationParamKey, Value: "mock_local"}, {Key: projectIDParamKey, Value: "mock_id"}, {Key: taskTypeParamKey, Value: vertexAICodeRetrival}, - {Key: embeddingUrlParamKey, Value: url}, + {Key: embeddingURLParamKey, Value: url}, {Key: dimParamKey, Value: "4"}, }, } @@ -95,7 +97,6 @@ func (s *VertextAITextEmbeddingProviderSuite) TestEmbedding() { ret, _ := provder.CallEmbedding(data, false, SearchMode) s.Equal([][]float32{{0.0, 0.1, 0.2, 0.3}, {1.0, 1.1, 1.2, 1.3}, {2.0, 2.1, 2.2, 2.3}}, ret) } - } func (s *VertextAITextEmbeddingProviderSuite) TestEmbeddingDimNotMatch() {