Skip to content

Commit

Permalink
Add function executor
Browse files Browse the repository at this point in the history
Signed-off-by: junjie.jiang <[email protected]>
  • Loading branch information
junjiejiangjjj committed Oct 8, 2024
1 parent 7f76022 commit 38243dd
Show file tree
Hide file tree
Showing 5 changed files with 318 additions and 29 deletions.
12 changes: 1 addition & 11 deletions internal/util/function/function_base.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,24 +25,14 @@ import (
)


type RunnerMode int

const (
InsertMode RunnerMode = iota
SearchMode
)


type FunctionBase struct {
schema *schemapb.FunctionSchema
outputFields []*schemapb.FieldSchema
mode RunnerMode
}

func NewBase(coll *schemapb.CollectionSchema, schema *schemapb.FunctionSchema, mode RunnerMode) (*FunctionBase, error) {
func NewBase(coll *schemapb.CollectionSchema, schema *schemapb.FunctionSchema) (*FunctionBase, error) {
var base FunctionBase
base.schema = schema
base.mode = mode
for _, field_id := range schema.GetOutputFieldIds() {
for _, field := range coll.GetFields() {
if field.GetFieldID() == field_id {
Expand Down
120 changes: 120 additions & 0 deletions internal/util/function/function_executor.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
/*
* # Licensed to the LF AI & Data foundation under one
* # or more contributor license agreements. See the NOTICE file
* # distributed with this work for additional information
* # regarding copyright ownership. The ASF licenses this file
* # to you under the Apache License, Version 2.0 (the
* # "License"); you may not use this file except in compliance
* # with the License. You may obtain a copy of the License at
* #
* # http://www.apache.org/licenses/LICENSE-2.0
* #
* # Unless required by applicable law or agreed to in writing, software
* # distributed under the License is distributed on an "AS IS" BASIS,
* # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* # See the License for the specific language governing permissions and
* # limitations under the License.
*/


package function


import (
"fmt"
"sync"

"github.com/milvus-io/milvus/pkg/mq/msgstream"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
)


type Runner interface {
GetSchema() *schemapb.FunctionSchema
GetOutputFields() []*schemapb.FieldSchema

MaxBatch() int
ProcessInsert(inputs []*schemapb.FieldData) ([]*schemapb.FieldData, error)
}


type FunctionExecutor struct {
runners []Runner
}


func newFunctionExecutor(schema *schemapb.CollectionSchema) (*FunctionExecutor, error) {
executor := new(FunctionExecutor)
for _, f_schema := range schema.Functions {
switch f_schema.GetType() {
case schemapb.FunctionType_BM25:
case schemapb.FunctionType_OpenAIEmbedding:
f, err := NewOpenAIEmbeddingFunction(schema, f_schema)
if err != nil {
return nil, err
}
executor.runners = append(executor.runners, f)
default:
return nil, fmt.Errorf("unknown functionRunner type %s", f_schema.GetType().String())
}
}
return executor, nil
}

func (executor *FunctionExecutor)processSingeFunction(idx int, msg *msgstream.InsertMsg) ([]*schemapb.FieldData, error){
runner := executor.runners[idx]
inputs := make([]*schemapb.FieldData, 0, len(runner.GetSchema().InputFieldIds))
for _, id := range runner.GetSchema().InputFieldIds {
for _, field := range msg.FieldsData{
if field.FieldId == id {
inputs = append(inputs, field)
}
}
}

if len(inputs) != len(runner.GetSchema().InputFieldIds) {
return nil, fmt.Errorf("Input field not found")
}

outputs, err := runner.ProcessInsert(inputs)
if err != nil {
return nil, err
}
return outputs, nil
}

func (executor *FunctionExecutor)ProcessInsert(msg *msgstream.InsertMsg) error{
numRows := msg.NumRows
for _, runner := range executor.runners {
if numRows > uint64(runner.MaxBatch()) {
return fmt.Errorf("numRows [%d] > function [%s]'s max batch [%d]", numRows, runner.GetSchema().Name, runner.MaxBatch())
}
}

outputs := make(chan []*schemapb.FieldData, len(executor.runners))
var wg sync.WaitGroup
for idx, _ := range executor.runners {
wg.Add(1)
go func(index int) {
defer wg.Done()
data, err := executor.processSingeFunction(index, msg)
if err != nil {
outputs <- nil
}
outputs <- data
}(idx)
}

wg.Wait()
close(outputs)
for output := range outputs {
msg.FieldsData = append(msg.FieldsData, output...)
}
return nil
}


func (executor *FunctionExecutor)ProcessSearch(msg *milvuspb.SearchRequest) error{
return nil
}
171 changes: 171 additions & 0 deletions internal/util/function/function_executor_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,171 @@
/*
* # Licensed to the LF AI & Data foundation under one
* # or more contributor license agreements. See the NOTICE file
* # distributed with this work for additional information
* # regarding copyright ownership. The ASF licenses this file
* # to you under the Apache License, Version 2.0 (the
* # "License"); you may not use this file except in compliance
* # with the License. You may obtain a copy of the License at
* #
* # http://www.apache.org/licenses/LICENSE-2.0
* #
* # Unless required by applicable law or agreed to in writing, software
* # distributed under the License is distributed on an "AS IS" BASIS,
* # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* # See the License for the specific language governing permissions and
* # limitations under the License.
*/


package function


import (
"io"
"fmt"
"testing"
"net/http"
"net/http/httptest"
"encoding/json"

"github.com/stretchr/testify/suite"

"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus/internal/models"
"github.com/milvus-io/milvus/pkg/mq/msgstream"
"github.com/milvus-io/milvus-proto/go-api/v2/msgpb"
)


func TestFunctionExecutor(t *testing.T) {
suite.Run(t, new(FunctionExecutorSuite))
}

type FunctionExecutorSuite struct {
suite.Suite
}


func (s *OpenAIEmbeddingFunctionSuite) creataSchema(url string) *schemapb.CollectionSchema{
return &schemapb.CollectionSchema{
Name: "test",
Fields: []*schemapb.FieldSchema{
{FieldID: 100, Name: "int64", DataType: schemapb.DataType_Int64},
{FieldID: 101, Name: "text", DataType: schemapb.DataType_VarChar},
{FieldID: 102, Name: "vector", DataType: schemapb.DataType_FloatVector,
TypeParams: []*commonpb.KeyValuePair{
{Key: "dim", Value: "4"},
}},
{FieldID: 103, Name: "vector2", DataType: schemapb.DataType_FloatVector,
TypeParams: []*commonpb.KeyValuePair{
{Key: "dim", Value: "4"},
}},
},
Functions: []*schemapb.FunctionSchema{
{
Name: "test",
Type: schemapb.FunctionType_OpenAIEmbedding,
InputFieldIds: []int64{101},
OutputFieldIds: []int64{102},
Params: []*commonpb.KeyValuePair{
{Key: ModelNameParamKey, Value: "text-embedding-ada-002"},
{Key: OpenaiApiKeyParamKey, Value: "mock"},
{Key: OpenaiEmbeddingUrlParamKey, Value: url},
},
},
{
Name: "test",
Type: schemapb.FunctionType_OpenAIEmbedding,
InputFieldIds: []int64{101},
OutputFieldIds: []int64{103},
Params: []*commonpb.KeyValuePair{
{Key: ModelNameParamKey, Value: "text-embedding-ada-002"},
{Key: OpenaiApiKeyParamKey, Value: "mock"},
{Key: OpenaiEmbeddingUrlParamKey, Value: url},
},
},
},
}

}

func (s *OpenAIEmbeddingFunctionSuite)createMsg(texts []string) *msgstream.InsertMsg{

data := []*schemapb.FieldData{}
f := schemapb.FieldData{
Type: schemapb.DataType_VarChar,
FieldId: 101,
IsDynamic: false,
Field: &schemapb.FieldData_Scalars{
Scalars: &schemapb.ScalarField{
Data: &schemapb.ScalarField_StringData{
StringData: &schemapb.StringArray{
Data: texts,
},
},
},
},
}
data = append(data, &f)

msg := msgstream.InsertMsg{
InsertRequest: &msgpb.InsertRequest{
FieldsData: data,
},
}
return &msg
}


func (s *OpenAIEmbeddingFunctionSuite)createEmbedding(texts []string, dim int) [][]float32 {
embeddings := make([][]float32, 0)
for i := 0; i < len(texts); i++ {
f := float32(i)
emb := make([]float32, 0)
for j := 0; j < dim; j++ {
emb = append(emb, f + float32(j) * 0.1)
}
embeddings = append(embeddings, emb)
}
return embeddings
}

func (s *OpenAIEmbeddingFunctionSuite) TestExecutor() {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var req models.EmbeddingRequest
body, _ := io.ReadAll(r.Body)
defer r.Body.Close()
json.Unmarshal(body, &req)

var res models.EmbeddingResponse
res.Object = "list"
res.Model = "text-embedding-3-small"
embs := s.createEmbedding(req.Input, 4)
for i := 0; i < len(req.Input); i++ {
res.Data = append(res.Data, models.EmbeddingData{
Object: "embedding",
Embedding: embs[i],
Index: i,
})
}

res.Usage = models.Usage{
PromptTokens: 1,
TotalTokens: 100,
}
w.WriteHeader(http.StatusOK)
data, _ := json.Marshal(res)
w.Write(data)

}))

defer ts.Close()
schema := s.creataSchema(ts.URL)
exec, err := newFunctionExecutor(schema)
s.NoError(err)
msg := s.createMsg([]string{"sentence", "sentence"})
exec.ProcessInsert(msg)
fmt.Println(msg)

}
32 changes: 20 additions & 12 deletions internal/util/function/openai_embedding_function.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@ const (
const (
maxBatch = 128
timeoutSec = 60
maxRowNum = 60 * maxBatch
)

const (
Expand All @@ -52,7 +51,7 @@ const (


type OpenAIEmbeddingFunction struct {
base *FunctionBase
FunctionBase
fieldDim int64

client *models.OpenAIEmbeddingClient
Expand All @@ -79,12 +78,12 @@ func createOpenAIEmbeddingClient(apiKey string, url string) (*models.OpenAIEmbed
return &c, nil
}

func NewOpenAIEmbeddingFunction(coll *schemapb.CollectionSchema, schema *schemapb.FunctionSchema, mode RunnerMode) (*OpenAIEmbeddingFunction, error) {
func NewOpenAIEmbeddingFunction(coll *schemapb.CollectionSchema, schema *schemapb.FunctionSchema) (*OpenAIEmbeddingFunction, error) {
if len(schema.GetOutputFieldIds()) != 1 {
return nil, fmt.Errorf("OpenAIEmbedding function should only have one output field, but now %d", len(schema.GetOutputFieldIds()))
}

base, err := NewBase(coll, schema, mode)
base, err := NewBase(coll, schema)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -131,7 +130,7 @@ func NewOpenAIEmbeddingFunction(coll *schemapb.CollectionSchema, schema *schemap
}

runner := OpenAIEmbeddingFunction{
base: base,
FunctionBase: *base,
client: c,
fieldDim: fieldDim,
modelName: modelName,
Expand All @@ -146,7 +145,16 @@ func NewOpenAIEmbeddingFunction(coll *schemapb.CollectionSchema, schema *schemap
return &runner, nil
}

func (runner *OpenAIEmbeddingFunction) Run(inputs []*schemapb.FieldData) ([]*schemapb.FieldData, error) {
func (runner *OpenAIEmbeddingFunction)MaxBatch() int {
return 5 * maxBatch
}


func (runner *OpenAIEmbeddingFunction) ProcessInsert(inputs []*schemapb.FieldData) ([]*schemapb.FieldData, error) {
return runner.Run(inputs)
}

func (runner *OpenAIEmbeddingFunction) Run( inputs []*schemapb.FieldData) ([]*schemapb.FieldData, error) {
if len(inputs) != 1 {
return nil, fmt.Errorf("OpenAIEmbedding function only receives one input, bug got [%d]", len(inputs))
}
Expand All @@ -161,15 +169,15 @@ func (runner *OpenAIEmbeddingFunction) Run(inputs []*schemapb.FieldData) ([]*sch
}

numRows := len(texts)
if numRows > maxRowNum {
return nil, fmt.Errorf("OpenAI embedding supports up to [%d] pieces of data at a time, got [%d]", maxRowNum, numRows)
if numRows > runner.MaxBatch() {
return nil, fmt.Errorf("OpenAI embedding supports up to [%d] pieces of data at a time, got [%d]", runner.MaxBatch(), numRows)
}

var output_field schemapb.FieldData
output_field.FieldId = runner.base.outputFields[0].FieldID
output_field.FieldName = runner.base.outputFields[0].Name
output_field.Type = runner.base.outputFields[0].DataType
output_field.IsDynamic = runner.base.outputFields[0].IsDynamic
output_field.FieldId = runner.outputFields[0].FieldID
output_field.FieldName = runner.outputFields[0].Name
output_field.Type = runner.outputFields[0].DataType
output_field.IsDynamic = runner.outputFields[0].IsDynamic
data := make([]float32, 0, numRows * int(runner.fieldDim))
for i := 0; i < numRows; i += maxBatch {
end := i + maxBatch
Expand Down
Loading

0 comments on commit 38243dd

Please sign in to comment.