Skip to content

Commit

Permalink
Add function check
Browse files (browse the repository at this point in the history)
Signed-off-by: junjie.jiang <[email protected]>
  • Loading branch information
junjiejiangjjj committed Dec 19, 2024
1 parent e7e5e7c commit 6334590
Show file tree
Hide file tree
Showing 14 changed files with 138 additions and 96 deletions.
10 changes: 6 additions & 4 deletions internal/datanode/importv2/scheduler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -475,10 +475,12 @@ func (s *SchedulerSuite) TestScheduler_ImportFileWithFunction() {
},
Functions: []*schemapb.FunctionSchema{
{
Name: "test",
Type: schemapb.FunctionType_TextEmbedding,
InputFieldIds: []int64{100},
OutputFieldIds: []int64{101},
Name: "test",
Type: schemapb.FunctionType_TextEmbedding,
InputFieldIds: []int64{100},
InputFieldNames: []string{"text"},
OutputFieldIds: []int64{101},
OutputFieldNames: []string{"vec"},
Params: []*commonpb.KeyValuePair{
{Key: function.Provider, Value: function.OpenAIProvider},
{Key: "model_name", Value: "text-embedding-ada-002"},
Expand Down
22 changes: 17 additions & 5 deletions internal/proxy/task_insert_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -351,11 +351,12 @@ func TestInsertTask_Function(t *testing.T) {
},
Functions: []*schemapb.FunctionSchema{
{
Name: "test_function",
Type: schemapb.FunctionType_TextEmbedding,
InputFieldIds: []int64{101},
InputFieldNames: []string{"text"},
OutputFieldIds: []int64{102},
Name: "test_function",
Type: schemapb.FunctionType_TextEmbedding,
InputFieldIds: []int64{101},
InputFieldNames: []string{"text"},
OutputFieldIds: []int64{102},
OutputFieldNames: []string{"vector"},
Params: []*commonpb.KeyValuePair{
{Key: function.Provider, Value: function.OpenAIProvider},
{Key: "model_name", Value: "text-embedding-ada-002"},
Expand Down Expand Up @@ -416,6 +417,17 @@ func TestInsertTask_Function(t *testing.T) {
createdTimestamp: 10001,
createdUtcTimestamp: 10002,
}, nil)
cache.On("GetCollectionInfo",
mock.Anything,
mock.Anything,
mock.Anything,
mock.Anything,
).Return(&collectionInfo{schema: info}, nil)
cache.On("GetDatabaseInfo",
mock.Anything,
mock.Anything,
).Return(&databaseInfo{properties: []*commonpb.KeyValuePair{}}, nil)

globalMetaCache = cache
err = task.PreExecute(ctx)
assert.NoError(t, err)
Expand Down
24 changes: 14 additions & 10 deletions internal/proxy/task_search_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -361,10 +361,10 @@ func TestSearchTask_PreExecute(t *testing.T) {
func TestSearchTask_WithFunctions(t *testing.T) {
ts := function.CreateOpenAIEmbeddingServer()
defer ts.Close()
collectionName := "TestInsertTask_function"
collectionName := "TestSearchTask_function"
schema := &schemapb.CollectionSchema{
Name: collectionName,
Description: "TestInsertTask_function",
Description: "TestSearchTask_function",
AutoID: true,
Fields: []*schemapb.FieldSchema{
{FieldID: 100, Name: "id", DataType: schemapb.DataType_Int64, IsPrimaryKey: true, AutoID: true},
Expand All @@ -383,10 +383,12 @@ func TestSearchTask_WithFunctions(t *testing.T) {
},
Functions: []*schemapb.FunctionSchema{
{
Name: "func1",
Type: schemapb.FunctionType_TextEmbedding,
InputFieldIds: []int64{101},
OutputFieldIds: []int64{102},
Name: "func1",
Type: schemapb.FunctionType_TextEmbedding,
InputFieldIds: []int64{101},
InputFieldNames: []string{"text"},
OutputFieldIds: []int64{102},
OutputFieldNames: []string{"vector1"},
Params: []*commonpb.KeyValuePair{
{Key: function.Provider, Value: function.OpenAIProvider},
{Key: "model_name", Value: "text-embedding-ada-002"},
Expand All @@ -396,10 +398,12 @@ func TestSearchTask_WithFunctions(t *testing.T) {
},
},
{
Name: "func2",
Type: schemapb.FunctionType_TextEmbedding,
InputFieldIds: []int64{101},
OutputFieldIds: []int64{103},
Name: "func2",
Type: schemapb.FunctionType_TextEmbedding,
InputFieldIds: []int64{101},
InputFieldNames: []string{"text"},
OutputFieldIds: []int64{103},
OutputFieldNames: []string{"vector2"},
Params: []*commonpb.KeyValuePair{
{Key: function.Provider, Value: function.OpenAIProvider},
{Key: "model_name", Value: "text-embedding-ada-002"},
Expand Down
2 changes: 0 additions & 2 deletions internal/proxy/task_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1032,8 +1032,6 @@ func TestCreateCollectionTask(t *testing.T) {
Type: schemapb.FunctionType_TextEmbedding,
InputFieldNames: []string{varCharField},
OutputFieldNames: []string{floatVecField},
InputFieldIds: []int64{101},
OutputFieldIds: []int64{102},
Params: []*commonpb.KeyValuePair{
{Key: "provider", Value: "openai"},
{Key: "model_name", Value: "text-embedding-ada-002"},
Expand Down
11 changes: 6 additions & 5 deletions internal/proxy/task_upsert_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -420,11 +420,12 @@ func TestUpsertTask_Function(t *testing.T) {
},
Functions: []*schemapb.FunctionSchema{
{
Name: "test_function",
Type: schemapb.FunctionType_TextEmbedding,
InputFieldIds: []int64{101},
InputFieldNames: []string{"text"},
OutputFieldIds: []int64{102},
Name: "test_function",
Type: schemapb.FunctionType_TextEmbedding,
InputFieldIds: []int64{101},
InputFieldNames: []string{"text"},
OutputFieldIds: []int64{102},
OutputFieldNames: []string{"vector"},
Params: []*commonpb.KeyValuePair{
{Key: function.Provider, Value: function.OpenAIProvider},
{Key: "model_name", Value: "text-embedding-ada-002"},
Expand Down
10 changes: 6 additions & 4 deletions internal/util/function/alitext_embedding_provider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,10 +60,12 @@ func (s *AliTextEmbeddingProviderSuite) SetupTest() {

func createAliProvider(url string, schema *schemapb.FieldSchema, providerName string) (TextEmbeddingProvider, error) {
functionSchema := &schemapb.FunctionSchema{
Name: "test",
Type: schemapb.FunctionType_Unknown,
InputFieldIds: []int64{101},
OutputFieldIds: []int64{102},
Name: "test",
Type: schemapb.FunctionType_Unknown,
InputFieldNames: []string{"text"},
OutputFieldNames: []string{"vector"},
InputFieldIds: []int64{101},
OutputFieldIds: []int64{102},
Params: []*commonpb.KeyValuePair{
{Key: modelNameParamKey, Value: TextEmbeddingV3},
{Key: apiKeyParamKey, Value: "mock"},
Expand Down
9 changes: 4 additions & 5 deletions internal/util/function/bedrock_embedding_provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ func NewBedrockEmbeddingProvider(fieldSchema *schemapb.FieldSchema, functionSche
}
var awsAccessKeyId, awsSecretAccessKey, region, modelName string
var dim int64
var normalize bool
normalize := false

for _, param := range functionSchema.Params {
switch strings.ToLower(param.Key) {
Expand All @@ -112,8 +112,6 @@ func NewBedrockEmbeddingProvider(fieldSchema *schemapb.FieldSchema, functionSche
switch strings.ToLower(param.Value) {
case "true":
normalize = true
case "false":
normalize = false
default:
return nil, fmt.Errorf("Illegal [%s:%s] param, ", normalizeParamKey, param.Value)
}
Expand Down Expand Up @@ -147,11 +145,11 @@ func NewBedrockEmbeddingProvider(fieldSchema *schemapb.FieldSchema, functionSche
}

func (provider *BedrockEmbeddingProvider) MaxBatch() int {
return 5 * provider.maxBatch
return 12 * provider.maxBatch
}

func (provider *BedrockEmbeddingProvider) FieldDim() int64 {
return 5 * provider.fieldDim
return provider.fieldDim
}

func (provider *BedrockEmbeddingProvider) CallEmbedding(texts []string, batchLimit bool, _ string) ([][]float32, error) {
Expand All @@ -164,6 +162,7 @@ func (provider *BedrockEmbeddingProvider) CallEmbedding(texts []string, batchLim
for i := 0; i < numRows; i += 1 {
payload := BedRockRequest{
InputText: texts[i],
Normalize: provider.normalize,
}
if provider.embedDimParam != 0 {
payload.Dimensions = provider.embedDimParam
Expand Down
10 changes: 6 additions & 4 deletions internal/util/function/bedrock_text_embedding_provider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,10 +55,12 @@ func (s *BedrockTextEmbeddingProviderSuite) SetupTest() {

func createBedrockProvider(schema *schemapb.FieldSchema, providerName string, dim int) (TextEmbeddingProvider, error) {
functionSchema := &schemapb.FunctionSchema{
Name: "test",
Type: schemapb.FunctionType_Unknown,
InputFieldIds: []int64{101},
OutputFieldIds: []int64{102},
Name: "test",
Type: schemapb.FunctionType_Unknown,
InputFieldNames: []string{"text"},
OutputFieldNames: []string{"vector"},
InputFieldIds: []int64{101},
OutputFieldIds: []int64{102},
Params: []*commonpb.KeyValuePair{
{Key: modelNameParamKey, Value: BedRockTitanTextEmbeddingsV2},
{Key: apiKeyParamKey, Value: "mock"},
Expand Down
6 changes: 3 additions & 3 deletions internal/util/function/function_base.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,16 +32,16 @@ type FunctionBase struct {
func NewFunctionBase(coll *schemapb.CollectionSchema, f_schema *schemapb.FunctionSchema) (*FunctionBase, error) {
var base FunctionBase
base.schema = f_schema
for _, field_id := range f_schema.GetOutputFieldIds() {
for _, fieldName := range f_schema.GetOutputFieldNames() {
for _, field := range coll.GetFields() {
if field.GetFieldID() == field_id {
if field.GetName() == fieldName {
base.outputFields = append(base.outputFields, field)
break
}
}
}

if len(base.outputFields) != len(f_schema.GetOutputFieldIds()) {
if len(base.outputFields) != len(f_schema.GetOutputFieldNames()) {
return &base, fmt.Errorf("The collection [%s]'s information is wrong, function [%s]'s outputs does not match the schema",
coll.Name, f_schema.Name)
}
Expand Down
22 changes: 12 additions & 10 deletions internal/util/function/function_executor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,11 +62,12 @@ func (s *FunctionExecutorSuite) creataSchema(url string) *schemapb.CollectionSch
},
Functions: []*schemapb.FunctionSchema{
{
Name: "test",
Type: schemapb.FunctionType_TextEmbedding,
InputFieldIds: []int64{101},
InputFieldNames: []string{"text"},
OutputFieldIds: []int64{102},
Name: "test",
Type: schemapb.FunctionType_TextEmbedding,
InputFieldIds: []int64{101},
InputFieldNames: []string{"text"},
OutputFieldIds: []int64{102},
OutputFieldNames: []string{"vector"},
Params: []*commonpb.KeyValuePair{
{Key: Provider, Value: OpenAIProvider},
{Key: modelNameParamKey, Value: "text-embedding-ada-002"},
Expand All @@ -76,11 +77,12 @@ func (s *FunctionExecutorSuite) creataSchema(url string) *schemapb.CollectionSch
},
},
{
Name: "test",
Type: schemapb.FunctionType_TextEmbedding,
InputFieldIds: []int64{101},
InputFieldNames: []string{"text"},
OutputFieldIds: []int64{103},
Name: "test",
Type: schemapb.FunctionType_TextEmbedding,
InputFieldIds: []int64{101},
InputFieldNames: []string{"text"},
OutputFieldIds: []int64{103},
OutputFieldNames: []string{"vector2"},
Params: []*commonpb.KeyValuePair{
{Key: Provider, Value: OpenAIProvider},
{Key: modelNameParamKey, Value: "text-embedding-ada-002"},
Expand Down
10 changes: 6 additions & 4 deletions internal/util/function/openai_text_embedding_provider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,10 +60,12 @@ func (s *OpenAITextEmbeddingProviderSuite) SetupTest() {

func createOpenAIProvider(url string, schema *schemapb.FieldSchema, providerName string) (TextEmbeddingProvider, error) {
functionSchema := &schemapb.FunctionSchema{
Name: "test",
Type: schemapb.FunctionType_Unknown,
InputFieldIds: []int64{101},
OutputFieldIds: []int64{102},
Name: "test",
Type: schemapb.FunctionType_Unknown,
InputFieldNames: []string{"text"},
OutputFieldNames: []string{"vector"},
InputFieldIds: []int64{101},
OutputFieldIds: []int64{102},
Params: []*commonpb.KeyValuePair{
{Key: modelNameParamKey, Value: "text-embedding-ada-002"},
{Key: apiKeyParamKey, Value: "mock"},
Expand Down
6 changes: 3 additions & 3 deletions internal/util/function/text_embedding_function.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,8 @@ type TextEmebddingFunction struct {
}

func NewTextEmbeddingFunction(coll *schemapb.CollectionSchema, functionSchema *schemapb.FunctionSchema) (*TextEmebddingFunction, error) {
if len(functionSchema.GetOutputFieldIds()) != 1 {
return nil, fmt.Errorf("Text function should only have one output field, but now is %d", len(functionSchema.GetOutputFieldIds()))
if len(functionSchema.GetOutputFieldNames()) != 1 {
return nil, fmt.Errorf("Text function should only have one output field, but now is %d", len(functionSchema.GetOutputFieldNames()))
}

base, err := NewFunctionBase(coll, functionSchema)
Expand All @@ -76,7 +76,7 @@ func NewTextEmbeddingFunction(coll *schemapb.CollectionSchema, functionSchema *s
}

if base.outputFields[0].DataType != schemapb.DataType_FloatVector {
return nil, fmt.Errorf("Output field not match, openai embedding needs [%s], got [%s]",
return nil, fmt.Errorf("Text embedding function's output field not match, needs [%s], got [%s]",
schemapb.DataType_name[int32(schemapb.DataType_FloatVector)],
schemapb.DataType_name[int32(base.outputFields[0].DataType)])
}
Expand Down
Loading

0 comments on commit 6334590

Please sign in to comment.