From 0f309609b0bb85fe3ea0f1e6790f7a6dc9fa2142 Mon Sep 17 00:00:00 2001 From: "junjie.jiang" Date: Wed, 9 Oct 2024 11:25:05 +0800 Subject: [PATCH] Insert & Upsert support functions Signed-off-by: junjie.jiang --- internal/proxy/task_insert.go | 13 +++++++++++++ internal/proxy/task_upsert.go | 15 +++++++++++++++ internal/util/function/function_executor.go | 7 +++++-- internal/util/function/function_executor_test.go | 6 +++--- 4 files changed, 36 insertions(+), 5 deletions(-) diff --git a/internal/proxy/task_insert.go b/internal/proxy/task_insert.go index fd86fc9d3c343..ffec44204438a 100644 --- a/internal/proxy/task_insert.go +++ b/internal/proxy/task_insert.go @@ -12,6 +12,7 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus/internal/allocator" + "github.com/milvus-io/milvus/internal/util/function" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/metrics" "github.com/milvus-io/milvus/pkg/mq/msgstream" @@ -132,6 +133,18 @@ func (it *insertTask) PreExecute(ctx context.Context) error { } it.schema = schema.CollectionSchema + // Calculate embedding fields + exec, err := function.NewFunctionExecutor(schema.CollectionSchema) + if err != nil { + return err + } + + if !exec.Empty() { + if err := exec.ProcessInsert(it.insertMsg); err != nil { + return err + } + } + rowNums := uint32(it.insertMsg.NRows()) // set insertTask.rowIDs var rowIDBegin UniqueID diff --git a/internal/proxy/task_upsert.go b/internal/proxy/task_upsert.go index 154bbba8753b7..e3e32195369c9 100644 --- a/internal/proxy/task_upsert.go +++ b/internal/proxy/task_upsert.go @@ -29,6 +29,7 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/msgpb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus/internal/allocator" + "github.com/milvus-io/milvus/internal/util/function" "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/metrics" @@ -152,6 +153,20 @@ func (it *upsertTask) insertPreExecute(ctx context.Context) error { return err } + // Calculate embedding fields + { + exec, err := function.NewFunctionExecutor(it.schema.CollectionSchema) + if err != nil { + return err + } + + if !exec.Empty() { + if err := exec.ProcessInsert(it.upsertMsg.InsertMsg); err != nil { + return err + } + } + } + rowNums := uint32(it.upsertMsg.InsertMsg.NRows()) // set upsertTask.insertRequest.rowIDs tr := timerecord.NewTimeRecorder("applyPK") diff --git a/internal/util/function/function_executor.go b/internal/util/function/function_executor.go index beaaf4b1e8e2a..abd490dc30fbd 100644 --- a/internal/util/function/function_executor.go +++ b/internal/util/function/function_executor.go @@ -43,7 +43,7 @@ type FunctionExecutor struct { runners []Runner } -func newFunctionExecutor(schema *schemapb.CollectionSchema) (*FunctionExecutor, error) { +func NewFunctionExecutor(schema *schemapb.CollectionSchema) (*FunctionExecutor, error) { executor := new(FunctionExecutor) for _, f_schema := range schema.Functions { switch f_schema.GetType() { @@ -61,6 +61,10 @@ func newFunctionExecutor(schema *schemapb.CollectionSchema) (*FunctionExecutor, return executor, nil } +func (executor *FunctionExecutor)Empty() bool { + return len(executor.runners) != 0 +} + func (executor *FunctionExecutor)processSingleFunction(idx int, msg *msgstream.InsertMsg) ([]*schemapb.FieldData, error) { runner := executor.runners[idx] inputs := make([]*schemapb.FieldData, 0, len(runner.GetSchema().InputFieldIds)) @@ -119,7 +123,6 @@ func (executor *FunctionExecutor)ProcessInsert(msg *msgstream.InsertMsg) error { return nil } - func (executor *FunctionExecutor)ProcessSearch(msg *milvuspb.SearchRequest) error { return nil } diff --git a/internal/util/function/function_executor_test.go b/internal/util/function/function_executor_test.go index 2d0a701352f75..8eef0b7b55857 100644 --- a/internal/util/function/function_executor_test.go +++ b/internal/util/function/function_executor_test.go @@ -162,7 +162,7 @@ func (s *FunctionExecutorSuite) TestExecutor() { defer ts.Close() schema := s.creataSchema(ts.URL) - exec, err := newFunctionExecutor(schema) + exec, err := NewFunctionExecutor(schema) s.NoError(err) msg := s.createMsg([]string{"sentence", "sentence"}) exec.ProcessInsert(msg) @@ -198,7 +198,7 @@ func (s *FunctionExecutorSuite) TestErrorEmbedding() { })) defer ts.Close() schema := s.creataSchema(ts.URL) - exec, err := newFunctionExecutor(schema) + exec, err := NewFunctionExecutor(schema) s.NoError(err) msg := s.createMsg([]string{"sentence", "sentence"}) err = exec.ProcessInsert(msg) @@ -208,6 +208,6 @@ func (s *FunctionExecutorSuite) TestErrorEmbedding() { func (s *FunctionExecutorSuite) TestErrorSchema() { schema := s.creataSchema("http://localhost") schema.Functions[0].Type = schemapb.FunctionType_Unknown - _, err := newFunctionExecutor(schema) + _, err := NewFunctionExecutor(schema) s.Error(err) }