Skip to content

Commit

Permalink
Insert & Upsert support functions
Browse files Browse the repository at this point in the history
Signed-off-by: junjie.jiang <[email protected]>
  • Loading branch information
junjiejiangjjj committed Oct 9, 2024
1 parent 247588f commit 0f30960
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 5 deletions.
13 changes: 13 additions & 0 deletions internal/proxy/task_insert.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/allocator"
"github.com/milvus-io/milvus/internal/util/function"
"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/metrics"
"github.com/milvus-io/milvus/pkg/mq/msgstream"
Expand Down Expand Up @@ -132,6 +133,18 @@ func (it *insertTask) PreExecute(ctx context.Context) error {
}
it.schema = schema.CollectionSchema

// Calculate embedding fields
exec, err := function.NewFunctionExecutor(schema.CollectionSchema)
if err != nil {
return err
}

if !exec.Empty() {
if err := exec.ProcessInsert(it.insertMsg); err != nil {
return err
}
}

rowNums := uint32(it.insertMsg.NRows())
// set insertTask.rowIDs
var rowIDBegin UniqueID
Expand Down
15 changes: 15 additions & 0 deletions internal/proxy/task_upsert.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/msgpb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/allocator"
"github.com/milvus-io/milvus/internal/util/function"
"github.com/milvus-io/milvus/pkg/common"
"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/metrics"
Expand Down Expand Up @@ -152,6 +153,20 @@ func (it *upsertTask) insertPreExecute(ctx context.Context) error {
return err
}

// Calculate embedding fields
{
exec, err := function.NewFunctionExecutor(it.schema.CollectionSchema)
if err != nil {
return err
}

if !exec.Empty() {
if err := exec.ProcessInsert(it.upsertMsg.InsertMsg); err != nil {
return err
}
}
}

rowNums := uint32(it.upsertMsg.InsertMsg.NRows())
// set upsertTask.insertRequest.rowIDs
tr := timerecord.NewTimeRecorder("applyPK")
Expand Down
7 changes: 5 additions & 2 deletions internal/util/function/function_executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ type FunctionExecutor struct {
runners []Runner
}

func newFunctionExecutor(schema *schemapb.CollectionSchema) (*FunctionExecutor, error) {
func NewFunctionExecutor(schema *schemapb.CollectionSchema) (*FunctionExecutor, error) {
executor := new(FunctionExecutor)
for _, f_schema := range schema.Functions {
switch f_schema.GetType() {
Expand All @@ -61,6 +61,10 @@ func newFunctionExecutor(schema *schemapb.CollectionSchema) (*FunctionExecutor,
return executor, nil
}

func (executor *FunctionExecutor)Empty() bool {
return len(executor.runners) != 0
}

func (executor *FunctionExecutor)processSingleFunction(idx int, msg *msgstream.InsertMsg) ([]*schemapb.FieldData, error) {
runner := executor.runners[idx]
inputs := make([]*schemapb.FieldData, 0, len(runner.GetSchema().InputFieldIds))
Expand Down Expand Up @@ -119,7 +123,6 @@ func (executor *FunctionExecutor)ProcessInsert(msg *msgstream.InsertMsg) error {
return nil
}


func (executor *FunctionExecutor)ProcessSearch(msg *milvuspb.SearchRequest) error {
return nil
}
6 changes: 3 additions & 3 deletions internal/util/function/function_executor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ func (s *FunctionExecutorSuite) TestExecutor() {

defer ts.Close()
schema := s.creataSchema(ts.URL)
exec, err := newFunctionExecutor(schema)
exec, err := NewFunctionExecutor(schema)
s.NoError(err)
msg := s.createMsg([]string{"sentence", "sentence"})
exec.ProcessInsert(msg)
Expand Down Expand Up @@ -198,7 +198,7 @@ func (s *FunctionExecutorSuite) TestErrorEmbedding() {
}))
defer ts.Close()
schema := s.creataSchema(ts.URL)
exec, err := newFunctionExecutor(schema)
exec, err := NewFunctionExecutor(schema)
s.NoError(err)
msg := s.createMsg([]string{"sentence", "sentence"})
err = exec.ProcessInsert(msg)
Expand All @@ -208,6 +208,6 @@ func (s *FunctionExecutorSuite) TestErrorEmbedding() {
func (s *FunctionExecutorSuite) TestErrorSchema() {
schema := s.creataSchema("http://localhost")
schema.Functions[0].Type = schemapb.FunctionType_Unknown
_, err := newFunctionExecutor(schema)
_, err := NewFunctionExecutor(schema)
s.Error(err)
}

0 comments on commit 0f30960

Please sign in to comment.