diff --git a/go.mod b/go.mod index 9eff0323dba59..f8fcbca181cb2 100644 --- a/go.mod +++ b/go.mod @@ -22,7 +22,7 @@ require ( github.com/grpc-ecosystem/go-grpc-middleware v1.3.0 github.com/klauspost/compress v1.17.7 github.com/mgutz/ansi v0.0.0-20200706080929-d51e80ef957d - github.com/milvus-io/milvus-proto/go-api/v2 v2.3.4-0.20240923125106-ef9b8fd69497 + github.com/milvus-io/milvus-proto/go-api/v2 v2.3.4-0.20241008034515-85ccff4d57fe github.com/minio/minio-go/v7 v7.0.61 github.com/pingcap/log v1.1.1-0.20221015072633-39906604fb81 github.com/prometheus/client_golang v1.14.0 @@ -56,6 +56,7 @@ require ( ) require ( + cloud.google.com/go/storage v1.43.0 github.com/bits-and-blooms/bitset v1.10.0 github.com/bytedance/sonic v1.12.2 github.com/cenkalti/backoff/v4 v4.2.1 @@ -70,6 +71,7 @@ require ( github.com/valyala/fastjson v1.6.4 github.com/zeebo/xxh3 v1.0.2 go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.49.0 + google.golang.org/api v0.187.0 google.golang.org/protobuf v1.34.2 gopkg.in/yaml.v3 v3.0.1 ) @@ -80,7 +82,6 @@ require ( cloud.google.com/go/auth/oauth2adapt v0.2.2 // indirect cloud.google.com/go/compute/metadata v0.3.0 // indirect cloud.google.com/go/iam v1.1.8 // indirect - cloud.google.com/go/storage v1.43.0 // indirect github.com/99designs/go-keychain v0.0.0-20191008050251-8e49817e8af4 // indirect github.com/99designs/keyring v1.2.1 // indirect github.com/AthenZ/athenz v1.10.39 // indirect @@ -246,7 +247,6 @@ require ( golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d // indirect golang.org/x/xerrors v0.0.0-20220907171357-04be3eba64a2 // indirect gonum.org/v1/gonum v0.11.0 // indirect - google.golang.org/api v0.187.0 // indirect google.golang.org/genproto v0.0.0-20240624140628-dc46fd24d27d // indirect google.golang.org/genproto/googleapis/api v0.0.0-20240617180043-68d350f18fd4 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20240730163845-b1a4ccb954bf // indirect diff --git a/go.sum b/go.sum index b5108edda424b..88728d620c6a5 100644 --- a/go.sum +++ b/go.sum @@ -37,6 +37,8 @@ cloud.google.com/go/datastore v1.1.0/go.mod h1:umbIZjpQpHh4hmRpGhH4tLFup+FVzqBi1 cloud.google.com/go/firestore v1.1.0/go.mod h1:ulACoGHTpvq5r8rxGJ4ddJZBZqakUQqClKRT5SZwBmk= cloud.google.com/go/iam v1.1.8 h1:r7umDwhj+BQyz0ScZMp4QrGXjSTI3ZINnpgU2nlB/K0= cloud.google.com/go/iam v1.1.8/go.mod h1:GvE6lyMmfxXauzNq8NbgJbeVQNspG+tcdL/W8QO1+zE= +cloud.google.com/go/longrunning v0.5.7 h1:WLbHekDbjK1fVFD3ibpFFVoyizlLRl73I7YKuAKilhU= +cloud.google.com/go/longrunning v0.5.7/go.mod h1:8GClkudohy1Fxm3owmBGid8W0pSgodEMwEAztp38Xng= cloud.google.com/go/pubsub v1.0.1/go.mod h1:R0Gpsv3s54REJCy4fxDixWD93lHJMoZTyQ2kNxGRt3I= cloud.google.com/go/pubsub v1.1.0/go.mod h1:EwwdRX2sKPjnvnqCa270oGRyludottCI76h+R3AArQw= cloud.google.com/go/pubsub v1.2.0/go.mod h1:jhfEVHT8odbXTkndysNHCcx0awwzvfOlguIAii9o8iA= @@ -412,9 +414,12 @@ github.com/google/go-querystring v1.0.0/go.mod h1:odCYkC5MyYFN7vkCjXpyrEuKhc/BUO github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/gofuzz v1.2.0 h1:xRy4A+RhZaiKjJ1bPfwQ8sedCA+YS2YcCHW6ec7JMi0= github.com/google/gofuzz v1.2.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= +github.com/google/martian v2.1.0+incompatible h1:/CP5g8u/VJHijgedC/Legn3BAbAaWPgecwXBIDzw5no= github.com/google/martian v2.1.0+incompatible/go.mod h1:9I4somxYTbIHy5NJKHRl3wXiIaQGbYVAs8BPL6v8lEs= github.com/google/martian/v3 v3.0.0/go.mod h1:y5Zk1BBys9G+gd6Jrk0W3cC1+ELVxBWuIGO+w/tUAp0= github.com/google/martian/v3 v3.1.0/go.mod h1:y5Zk1BBys9G+gd6Jrk0W3cC1+ELVxBWuIGO+w/tUAp0= +github.com/google/martian/v3 v3.3.3 h1:DIhPTQrbPkgs2yJYdXU/eNACCG5DVQjySNRNlflZ9Fc= +github.com/google/martian/v3 v3.3.3/go.mod h1:iEPrYcgCF7jA9OtScMFQyAlZZ4YXTKEtJ1E6RWzmBA0= github.com/google/pprof v0.0.0-20181206194817-3ea8567a2e57/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc= github.com/google/pprof v0.0.0-20190515194954-54271f7e092f/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc= github.com/google/pprof v0.0.0-20191218002539-d4f498aebedc/go.mod h1:ZgVRPoUq/hfqzAqh7sHMqb3I9Rq5C59dIz2SbBwJ4eM= @@ -620,8 +625,8 @@ github.com/milvus-io/cgosymbolizer v0.0.0-20240722103217-b7dee0e50119 h1:9VXijWu github.com/milvus-io/cgosymbolizer v0.0.0-20240722103217-b7dee0e50119/go.mod h1:DvXTE/K/RtHehxU8/GtDs4vFtfw64jJ3PaCnFri8CRg= github.com/milvus-io/gorocksdb v0.0.0-20220624081344-8c5f4212846b h1:TfeY0NxYxZzUfIfYe5qYDBzt4ZYRqzUjTR6CvUzjat8= github.com/milvus-io/gorocksdb v0.0.0-20220624081344-8c5f4212846b/go.mod h1:iwW+9cWfIzzDseEBCCeDSN5SD16Tidvy8cwQ7ZY8Qj4= -github.com/milvus-io/milvus-proto/go-api/v2 v2.3.4-0.20240923125106-ef9b8fd69497 h1:t4sQMbSy05p8qgMGvEGyLYYLoZ9fD1dushS1bj5X6+0= -github.com/milvus-io/milvus-proto/go-api/v2 v2.3.4-0.20240923125106-ef9b8fd69497/go.mod h1:/6UT4zZl6awVeXLeE7UGDWZvXj3IWkRsh3mqsn0DiAs= +github.com/milvus-io/milvus-proto/go-api/v2 v2.3.4-0.20241008034515-85ccff4d57fe h1:nvdFNyfPEKKL3/q2mB/H3BWh7rfIvEr36gwqDf7siEE= +github.com/milvus-io/milvus-proto/go-api/v2 v2.3.4-0.20241008034515-85ccff4d57fe/go.mod h1:/6UT4zZl6awVeXLeE7UGDWZvXj3IWkRsh3mqsn0DiAs= github.com/milvus-io/pulsar-client-go v0.12.1 h1:O2JZp1tsYiO7C0MQ4hrUY/aJXnn2Gry6hpm7UodghmE= github.com/milvus-io/pulsar-client-go v0.12.1/go.mod h1:dkutuH4oS2pXiGm+Ti7fQZ4MRjrMPZ8IJeEGAWMeckk= github.com/minio/asm2plan9s v0.0.0-20200509001527-cdd76441f9d8 h1:AMFGa4R4MiIpspGNG7Z948v4n35fFGB3RR3G/ry4FWs= @@ -1143,8 +1148,6 @@ golang.org/x/oauth2 v0.0.0-20210218202405-ba52d332ba99/go.mod h1:KelEdhl1UZF7XfJ golang.org/x/oauth2 v0.0.0-20210220000619-9bb904979d93/go.mod h1:KelEdhl1UZF7XfJ4dDtk6s++YSgaE7mD/BuKKDLBl4A= golang.org/x/oauth2 v0.0.0-20210313182246-cd4f82c27b84/go.mod h1:KelEdhl1UZF7XfJ4dDtk6s++YSgaE7mD/BuKKDLBl4A= golang.org/x/oauth2 v0.0.0-20210402161424-2e8d93401602/go.mod h1:KelEdhl1UZF7XfJ4dDtk6s++YSgaE7mD/BuKKDLBl4A= -golang.org/x/oauth2 v0.20.0 h1:4mQdhULixXKP1rwYBW0vAijoXnkTG0BLCDRzfe1idMo= -golang.org/x/oauth2 v0.20.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbhtI= golang.org/x/oauth2 v0.21.0 h1:tsimM75w1tF/uws5rbeHzIWxEqElMehnc+iW793zsZs= golang.org/x/oauth2 v0.21.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbhtI= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= @@ -1417,12 +1420,8 @@ google.golang.org/genproto v0.0.0-20210402141018-6c239bbf2bb1/go.mod h1:9lPAdzaE google.golang.org/genproto v0.0.0-20210602131652-f16073e35f0c/go.mod h1:UODoCrxHCcBojKKwX1terBiRUaqAsFqJiF615XL43r0= google.golang.org/genproto v0.0.0-20210624195500-8bfb893ecb84/go.mod h1:SzzZ/N+nwJDaO1kznhnlzqS8ocJICar6hYhVyhi++24= google.golang.org/genproto v0.0.0-20220503193339-ba3ae3f07e29/go.mod h1:RAyBrSAP7Fh3Nc84ghnVLDPuV51xc9agzmm4Ph6i0Q4= -google.golang.org/genproto v0.0.0-20231106174013-bbf56f31fb17 h1:wpZ8pe2x1Q3f2KyT5f8oP/fa9rHAKgFPr/HZdNuS+PQ= -google.golang.org/genproto v0.0.0-20231106174013-bbf56f31fb17/go.mod h1:J7XzRzVy1+IPwWHZUzoD0IccYZIrXILAQpc+Qy9CMhY= google.golang.org/genproto v0.0.0-20240624140628-dc46fd24d27d h1:PksQg4dV6Sem3/HkBX+Ltq8T0ke0PKIRBNBatoDTVls= google.golang.org/genproto v0.0.0-20240624140628-dc46fd24d27d/go.mod h1:s7iA721uChleev562UJO2OYB0PPT9CMFjV+Ce7VJH5M= -google.golang.org/genproto/googleapis/api v0.0.0-20240528184218-531527333157 h1:7whR9kGa5LUwFtpLm2ArCEejtnxlGeLbAyjFY8sGNFw= -google.golang.org/genproto/googleapis/api v0.0.0-20240528184218-531527333157/go.mod h1:99sLkeliLXfdj2J75X3Ho+rrVCaJze0uwN7zDDkjPVU= google.golang.org/genproto/googleapis/api v0.0.0-20240617180043-68d350f18fd4 h1:MuYw1wJzT+ZkybKfaOXKp5hJiZDn2iHaXRw0mRYdHSc= google.golang.org/genproto/googleapis/api v0.0.0-20240617180043-68d350f18fd4/go.mod h1:px9SlOOZBg1wM1zdnr8jEL4CNGUBZ+ZKYtNPApNQc4c= google.golang.org/genproto/googleapis/rpc v0.0.0-20240730163845-b1a4ccb954bf h1:liao9UHurZLtiEwBgT9LMOnKYsHze6eA6w1KQCMVN2Q= diff --git a/internal/util/function/function_base.go b/internal/util/function/function_base.go index 54cd55d18d496..1914ed7dbaae2 100644 --- a/internal/util/function/function_base.go +++ b/internal/util/function/function_base.go @@ -25,24 +25,14 @@ import ( ) -type RunnerMode int - -const ( - InsertMode RunnerMode = iota - SearchMode -) - - type FunctionBase struct { schema *schemapb.FunctionSchema outputFields []*schemapb.FieldSchema - mode RunnerMode } -func NewBase(coll *schemapb.CollectionSchema, schema *schemapb.FunctionSchema, mode RunnerMode) (*FunctionBase, error) { +func NewBase(coll *schemapb.CollectionSchema, schema *schemapb.FunctionSchema) (*FunctionBase, error) { var base FunctionBase base.schema = schema - base.mode = mode for _, field_id := range schema.GetOutputFieldIds() { for _, field := range coll.GetFields() { if field.GetFieldID() == field_id { @@ -53,7 +43,8 @@ func NewBase(coll *schemapb.CollectionSchema, schema *schemapb.FunctionSchema, m } if len(base.outputFields) != len(schema.GetOutputFieldIds()) { - return &base, fmt.Errorf("Collection [%s]'s function [%s]'s outputs mismatch schema", coll.Name, schema.Name) + return &base, fmt.Errorf("The collection [%s]'s information is wrong, function [%s]'s outputs does not match the schema", + coll.Name, schema.Name) } return &base, nil } diff --git a/internal/util/function/function_executor.go b/internal/util/function/function_executor.go new file mode 100644 index 0000000000000..beaaf4b1e8e2a --- /dev/null +++ b/internal/util/function/function_executor.go @@ -0,0 +1,125 @@ +/* + * # Licensed to the LF AI & Data foundation under one + * # or more contributor license agreements. See the NOTICE file + * # distributed with this work for additional information + * # regarding copyright ownership. The ASF licenses this file + * # to you under the Apache License, Version 2.0 (the + * # "License"); you may not use this file except in compliance + * # with the License. You may obtain a copy of the License at + * # + * # http://www.apache.org/licenses/LICENSE-2.0 + * # + * # Unless required by applicable law or agreed to in writing, software + * # distributed under the License is distributed on an "AS IS" BASIS, + * # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * # See the License for the specific language governing permissions and + * # limitations under the License. + */ + + +package function + + +import ( + "fmt" + "sync" + + "github.com/milvus-io/milvus/pkg/mq/msgstream" + "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" +) + + +type Runner interface { + GetSchema() *schemapb.FunctionSchema + GetOutputFields() []*schemapb.FieldSchema + + MaxBatch() int + ProcessInsert(inputs []*schemapb.FieldData) ([]*schemapb.FieldData, error) +} + + +type FunctionExecutor struct { + runners []Runner +} + +func newFunctionExecutor(schema *schemapb.CollectionSchema) (*FunctionExecutor, error) { + executor := new(FunctionExecutor) + for _, f_schema := range schema.Functions { + switch f_schema.GetType() { + case schemapb.FunctionType_BM25: + case schemapb.FunctionType_OpenAIEmbedding: + f, err := NewOpenAIEmbeddingFunction(schema, f_schema) + if err != nil { + return nil, err + } + executor.runners = append(executor.runners, f) + default: + return nil, fmt.Errorf("unknown functionRunner type %s", f_schema.GetType().String()) + } + } + return executor, nil +} + +func (executor *FunctionExecutor)processSingleFunction(idx int, msg *msgstream.InsertMsg) ([]*schemapb.FieldData, error) { + runner := executor.runners[idx] + inputs := make([]*schemapb.FieldData, 0, len(runner.GetSchema().InputFieldIds)) + for _, id := range runner.GetSchema().InputFieldIds { + for _, field := range msg.FieldsData{ + if field.FieldId == id { + inputs = append(inputs, field) + } + } + } + + if len(inputs) != len(runner.GetSchema().InputFieldIds) { + return nil, fmt.Errorf("Input field not found") + } + + outputs, err := runner.ProcessInsert(inputs) + if err != nil { + return nil, err + } + return outputs, nil +} + +func (executor *FunctionExecutor)ProcessInsert(msg *msgstream.InsertMsg) error { + numRows := msg.NumRows + for _, runner := range executor.runners { + if numRows > uint64(runner.MaxBatch()) { + return fmt.Errorf("numRows [%d] > function [%s]'s max batch [%d]", numRows, runner.GetSchema().Name, runner.MaxBatch()) + } + } + + outputs := make(chan []*schemapb.FieldData, len(executor.runners)) + errChan := make(chan error, len(executor.runners)) + var wg sync.WaitGroup + for idx, _ := range executor.runners { + wg.Add(1) + go func(index int) { + defer wg.Done() + data, err := executor.processSingleFunction(index, msg) + if err != nil { + errChan <- err + } else { + outputs <- data + } + + }(idx) + } + wg.Wait() + close(errChan) + close(outputs) + for err := range errChan { + return err + } + for output := range outputs { + msg.FieldsData = append(msg.FieldsData, output...) + } + return nil +} + + +func (executor *FunctionExecutor)ProcessSearch(msg *milvuspb.SearchRequest) error { + return nil +} diff --git a/internal/util/function/function_executor_test.go b/internal/util/function/function_executor_test.go new file mode 100644 index 0000000000000..2d0a701352f75 --- /dev/null +++ b/internal/util/function/function_executor_test.go @@ -0,0 +1,213 @@ +/* + * # Licensed to the LF AI & Data foundation under one + * # or more contributor license agreements. See the NOTICE file + * # distributed with this work for additional information + * # regarding copyright ownership. The ASF licenses this file + * # to you under the Apache License, Version 2.0 (the + * # "License"); you may not use this file except in compliance + * # with the License. You may obtain a copy of the License at + * # + * # http://www.apache.org/licenses/LICENSE-2.0 + * # + * # Unless required by applicable law or agreed to in writing, software + * # distributed under the License is distributed on an "AS IS" BASIS, + * # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * # See the License for the specific language governing permissions and + * # limitations under the License. + */ + + +package function + + +import ( + "io" + "testing" + "net/http" + "net/http/httptest" + "encoding/json" + + "github.com/stretchr/testify/suite" + + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" + "github.com/milvus-io/milvus/internal/models" + "github.com/milvus-io/milvus/pkg/mq/msgstream" + "github.com/milvus-io/milvus-proto/go-api/v2/msgpb" +) + + +func TestFunctionExecutor(t *testing.T) { + suite.Run(t, new(FunctionExecutorSuite)) +} + +type FunctionExecutorSuite struct { + suite.Suite +} + + +func (s *FunctionExecutorSuite) creataSchema(url string) *schemapb.CollectionSchema{ + return &schemapb.CollectionSchema{ + Name: "test", + Fields: []*schemapb.FieldSchema{ + {FieldID: 100, Name: "int64", DataType: schemapb.DataType_Int64}, + {FieldID: 101, Name: "text", DataType: schemapb.DataType_VarChar}, + {FieldID: 102, Name: "vector", DataType: schemapb.DataType_FloatVector, + TypeParams: []*commonpb.KeyValuePair{ + {Key: "dim", Value: "4"}, + }}, + {FieldID: 103, Name: "vector2", DataType: schemapb.DataType_FloatVector, + TypeParams: []*commonpb.KeyValuePair{ + {Key: "dim", Value: "8"}, + }}, + }, + Functions: []*schemapb.FunctionSchema{ + { + Name: "test", + Type: schemapb.FunctionType_OpenAIEmbedding, + InputFieldIds: []int64{101}, + OutputFieldIds: []int64{102}, + Params: []*commonpb.KeyValuePair{ + {Key: ModelNameParamKey, Value: "text-embedding-ada-002"}, + {Key: OpenaiApiKeyParamKey, Value: "mock"}, + {Key: OpenaiEmbeddingUrlParamKey, Value: url}, + {Key: DimParamKey, Value: "4"}, + }, + }, + { + Name: "test", + Type: schemapb.FunctionType_OpenAIEmbedding, + InputFieldIds: []int64{101}, + OutputFieldIds: []int64{103}, + Params: []*commonpb.KeyValuePair{ + {Key: ModelNameParamKey, Value: "text-embedding-ada-002"}, + {Key: OpenaiApiKeyParamKey, Value: "mock"}, + {Key: OpenaiEmbeddingUrlParamKey, Value: url}, + {Key: DimParamKey, Value: "8"}, + }, + }, + }, + } + +} + +func (s *FunctionExecutorSuite)createMsg(texts []string) *msgstream.InsertMsg{ + + data := []*schemapb.FieldData{} + f := schemapb.FieldData{ + Type: schemapb.DataType_VarChar, + FieldId: 101, + IsDynamic: false, + Field: &schemapb.FieldData_Scalars{ + Scalars: &schemapb.ScalarField{ + Data: &schemapb.ScalarField_StringData{ + StringData: &schemapb.StringArray{ + Data: texts, + }, + }, + }, + }, + } + data = append(data, &f) + + msg := msgstream.InsertMsg{ + InsertRequest: &msgpb.InsertRequest{ + FieldsData: data, + }, + } + return &msg +} + +func (s *FunctionExecutorSuite)createEmbedding(texts []string, dim int) [][]float32{ + embeddings := make([][]float32, 0) + for i := 0; i < len(texts); i++ { + f := float32(i) + emb := make([]float32, 0) + for j := 0; j < dim; j++ { + emb = append(emb, f + float32(j) * 0.1) + } + embeddings = append(embeddings, emb) + } + return embeddings +} + +func (s *FunctionExecutorSuite) TestExecutor() { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request){ + var req models.EmbeddingRequest + body, _ := io.ReadAll(r.Body) + defer r.Body.Close() + json.Unmarshal(body, &req) + + var res models.EmbeddingResponse + res.Object = "list" + res.Model = "text-embedding-3-small" + embs := s.createEmbedding(req.Input, req.Dimensions) + for i := 0; i < len(req.Input); i++ { + res.Data = append(res.Data, models.EmbeddingData{ + Object: "embedding", + Embedding: embs[i], + Index: i, + }) + } + + res.Usage = models.Usage{ + PromptTokens: 1, + TotalTokens: 100, + } + w.WriteHeader(http.StatusOK) + data, _ := json.Marshal(res) + w.Write(data) + + })) + + defer ts.Close() + schema := s.creataSchema(ts.URL) + exec, err := newFunctionExecutor(schema) + s.NoError(err) + msg := s.createMsg([]string{"sentence", "sentence"}) + exec.ProcessInsert(msg) + s.Equal(len(msg.FieldsData), 3) +} + +func (s *FunctionExecutorSuite) TestErrorEmbedding() { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request){ + var req models.EmbeddingRequest + body, _ := io.ReadAll(r.Body) + defer r.Body.Close() + json.Unmarshal(body, &req) + + var res models.EmbeddingResponse + res.Object = "list" + res.Model = "text-embedding-3-small" + for i := 0; i < len(req.Input); i++ { + res.Data = append(res.Data, models.EmbeddingData{ + Object: "embedding", + Embedding: []float32{}, + Index: i, + }) + } + + res.Usage = models.Usage{ + PromptTokens: 1, + TotalTokens: 100, + } + w.WriteHeader(http.StatusOK) + data, _ := json.Marshal(res) + w.Write(data) + + })) + defer ts.Close() + schema := s.creataSchema(ts.URL) + exec, err := newFunctionExecutor(schema) + s.NoError(err) + msg := s.createMsg([]string{"sentence", "sentence"}) + err = exec.ProcessInsert(msg) + s.Error(err) +} + +func (s *FunctionExecutorSuite) TestErrorSchema() { + schema := s.creataSchema("http://localhost") + schema.Functions[0].Type = schemapb.FunctionType_Unknown + _, err := newFunctionExecutor(schema) + s.Error(err) +} diff --git a/internal/util/function/openai_embedding_function.go b/internal/util/function/openai_embedding_function.go index 10182cf9fa7cc..11151f92a6adb 100644 --- a/internal/util/function/openai_embedding_function.go +++ b/internal/util/function/openai_embedding_function.go @@ -38,8 +38,7 @@ const ( const ( maxBatch = 128 - timeoutSec = 60 - maxRowNum = 60 * maxBatch + timeoutSec = 30 ) const ( @@ -52,7 +51,7 @@ const ( type OpenAIEmbeddingFunction struct { - base *FunctionBase + FunctionBase fieldDim int64 client *models.OpenAIEmbeddingClient @@ -79,12 +78,12 @@ func createOpenAIEmbeddingClient(apiKey string, url string) (*models.OpenAIEmbed return &c, nil } -func NewOpenAIEmbeddingFunction(coll *schemapb.CollectionSchema, schema *schemapb.FunctionSchema, mode RunnerMode) (*OpenAIEmbeddingFunction, error) { +func NewOpenAIEmbeddingFunction(coll *schemapb.CollectionSchema, schema *schemapb.FunctionSchema) (*OpenAIEmbeddingFunction, error) { if len(schema.GetOutputFieldIds()) != 1 { - return nil, fmt.Errorf("OpenAIEmbedding function should only have one output field, but now %d", len(schema.GetOutputFieldIds())) + return nil, fmt.Errorf("OpenAIEmbedding function should only have one output field, but now is %d", len(schema.GetOutputFieldIds())) } - base, err := NewBase(coll, schema, mode) + base, err := NewBase(coll, schema) if err != nil { return nil, err } @@ -107,13 +106,13 @@ func NewOpenAIEmbeddingFunction(coll *schemapb.CollectionSchema, schema *schemap case ModelNameParamKey: modelName = param.Value case DimParamKey: - dim, err := strconv.ParseInt(param.Value, 10, 64) + dim, err = strconv.ParseInt(param.Value, 10, 64) if err != nil { return nil, fmt.Errorf("dim [%s] is not int", param.Value) } if dim != 0 && dim != fieldDim { - return nil, fmt.Errorf("Dim in field's schema is [%d], but embeding dim is [%d]", fieldDim, dim) + return nil, fmt.Errorf("Field %s's dim is [%d], but embeding's dim is [%d]", schema.Name, fieldDim, dim) } case UserParamKey: user = param.Value @@ -131,7 +130,7 @@ func NewOpenAIEmbeddingFunction(coll *schemapb.CollectionSchema, schema *schemap } runner := OpenAIEmbeddingFunction{ - base: base, + FunctionBase: *base, client: c, fieldDim: fieldDim, modelName: modelName, @@ -146,7 +145,16 @@ func NewOpenAIEmbeddingFunction(coll *schemapb.CollectionSchema, schema *schemap return &runner, nil } -func (runner *OpenAIEmbeddingFunction) Run(inputs []*schemapb.FieldData) ([]*schemapb.FieldData, error) { +func (runner *OpenAIEmbeddingFunction)MaxBatch() int { + return 5 * maxBatch +} + + +func (runner *OpenAIEmbeddingFunction) ProcessInsert(inputs []*schemapb.FieldData) ([]*schemapb.FieldData, error) { + return runner.Run(inputs) +} + +func (runner *OpenAIEmbeddingFunction) Run( inputs []*schemapb.FieldData) ([]*schemapb.FieldData, error) { if len(inputs) != 1 { return nil, fmt.Errorf("OpenAIEmbedding function only receives one input, bug got [%d]", len(inputs)) } @@ -161,15 +169,15 @@ func (runner *OpenAIEmbeddingFunction) Run(inputs []*schemapb.FieldData) ([]*sch } numRows := len(texts) - if numRows > maxRowNum { - return nil, fmt.Errorf("OpenAI embedding supports up to [%d] pieces of data at a time, got [%d]", maxRowNum, numRows) + if numRows > runner.MaxBatch() { + return nil, fmt.Errorf("OpenAI embedding supports up to [%d] pieces of data at a time, got [%d]", runner.MaxBatch(), numRows) } var output_field schemapb.FieldData - output_field.FieldId = runner.base.outputFields[0].FieldID - output_field.FieldName = runner.base.outputFields[0].Name - output_field.Type = runner.base.outputFields[0].DataType - output_field.IsDynamic = runner.base.outputFields[0].IsDynamic + output_field.FieldId = runner.outputFields[0].FieldID + output_field.FieldName = runner.outputFields[0].Name + output_field.Type = runner.outputFields[0].DataType + output_field.IsDynamic = runner.outputFields[0].IsDynamic data := make([]float32, 0, numRows * int(runner.fieldDim)) for i := 0; i < numRows; i += maxBatch { end := i + maxBatch @@ -185,8 +193,8 @@ func (runner *OpenAIEmbeddingFunction) Run(inputs []*schemapb.FieldData) ([]*sch } for _, item := range resp.Data { if len(item.Embedding) != int(runner.fieldDim) { - return nil, fmt.Errorf("Dim in field's schema is [%d], but embeding dim is [%d]", - runner.fieldDim, len(resp.Data[0].Embedding)) + return nil, fmt.Errorf("The required embedding dim for field [%s] is [%d], but the embedding obtained from the model is [%d]", + output_field.FieldName, runner.fieldDim, len(item.Embedding)) } data = append(data, item.Embedding...) } diff --git a/internal/util/function/openai_embedding_function_test.go b/internal/util/function/openai_embedding_function_test.go index 68420cbeddbeb..f295edbbe577d 100644 --- a/internal/util/function/openai_embedding_function_test.go +++ b/internal/util/function/openai_embedding_function_test.go @@ -21,7 +21,6 @@ package function import ( "io" - "fmt" "testing" "net/http" "net/http/httptest" @@ -103,7 +102,7 @@ func createRunner(url string, schema *schemapb.CollectionSchema) (*OpenAIEmbeddi {Key: OpenaiApiKeyParamKey, Value: "mock"}, {Key: OpenaiEmbeddingUrlParamKey, Value: url}, }, - }, InsertMode) + }) } func (s *OpenAIEmbeddingFunctionSuite) TestEmbedding() { @@ -186,8 +185,6 @@ func (s *OpenAIEmbeddingFunctionSuite) TestEmbeddingDimNotMatch() { data := createData([]string{"sentence", "sentence"}) _, err2 := runner.Run(data) s.Error(err2) - fmt.Println(err2.Error()) - // s.NoError(err2) } func (s *OpenAIEmbeddingFunctionSuite) TestEmbeddingNubmerNotMatch() { @@ -218,8 +215,6 @@ func (s *OpenAIEmbeddingFunctionSuite) TestEmbeddingNubmerNotMatch() { data := createData([]string{"sentence", "sentence2"}) _, err2 := runner.Run(data) s.Error(err2) - fmt.Println(err2.Error()) - // s.NoError(err2) } func (s *OpenAIEmbeddingFunctionSuite) TestRunnerParamsErr() { @@ -248,9 +243,8 @@ func (s *OpenAIEmbeddingFunctionSuite) TestRunnerParamsErr() { {Key: OpenaiApiKeyParamKey, Value: "mock"}, {Key: OpenaiEmbeddingUrlParamKey, Value: "mock"}, }, - }, InsertMode) + }) s.Error(err) - fmt.Println(err.Error()) } // outputfield number mismatc @@ -281,9 +275,8 @@ func (s *OpenAIEmbeddingFunctionSuite) TestRunnerParamsErr() { {Key: OpenaiApiKeyParamKey, Value: "mock"}, {Key: OpenaiEmbeddingUrlParamKey, Value: "mock"}, }, - }, InsertMode) + }) s.Error(err) - fmt.Println(err.Error()) } // outputfield miss @@ -299,9 +292,8 @@ func (s *OpenAIEmbeddingFunctionSuite) TestRunnerParamsErr() { {Key: OpenaiApiKeyParamKey, Value: "mock"}, {Key: OpenaiEmbeddingUrlParamKey, Value: "mock"}, }, - }, InsertMode) + }) s.Error(err) - fmt.Println(err.Error()) } // error model name @@ -317,9 +309,8 @@ func (s *OpenAIEmbeddingFunctionSuite) TestRunnerParamsErr() { {Key: OpenaiApiKeyParamKey, Value: "mock"}, {Key: OpenaiEmbeddingUrlParamKey, Value: "mock"}, }, - }, InsertMode) + }) s.Error(err) - fmt.Println(err.Error()) } // no openai api key @@ -332,8 +323,7 @@ func (s *OpenAIEmbeddingFunctionSuite) TestRunnerParamsErr() { Params: []*commonpb.KeyValuePair{ {Key: ModelNameParamKey, Value: "text-embedding-ada-003"}, }, - }, InsertMode) + }) s.Error(err) - fmt.Println(err.Error()) } }