From 05c40522ce5d88305fabbb6af70b480ef25324d7 Mon Sep 17 00:00:00 2001 From: "cai.zhang" Date: Tue, 29 Oct 2024 15:38:40 +0800 Subject: [PATCH] enhance: [cherry-pick ]Enhance the expression template to support AND and OR operations (#37217) issue: #36672 master pr: #37033 Signed-off-by: Cai Zhang --- .../planparserv2/fill_expression_value.go | 56 ++--- .../fill_expression_value_test.go | 193 ++++++++---------- .../parser/planparserv2/parser_visitor.go | 2 + .../planparserv2/plan_parser_v2_test.go | 1 + internal/parser/planparserv2/show_visitor.go | 3 + internal/parser/planparserv2/utils.go | 1 + 6 files changed, 117 insertions(+), 139 deletions(-) diff --git a/internal/parser/planparserv2/fill_expression_value.go b/internal/parser/planparserv2/fill_expression_value.go index 534bbc564772a..8840ac75b08bf 100644 --- a/internal/parser/planparserv2/fill_expression_value.go +++ b/internal/parser/planparserv2/fill_expression_value.go @@ -143,39 +143,43 @@ func FillBinaryArithOpEvalRangeExpressionValue(expr *planpb.BinaryArithOpEvalRan var err error var ok bool - operand := expr.GetRightOperand() - if operand == nil || expr.GetOperandTemplateVariableName() != "" { - operand, ok = templateValues[expr.GetOperandTemplateVariableName()] - if !ok { - return fmt.Errorf("the right operand value of expression template variable name {%s} is not found", expr.GetOperandTemplateVariableName()) + if expr.ArithOp == planpb.ArithOpType_ArrayLength { + dataType = schemapb.DataType_Int64 + } else { + operand := expr.GetRightOperand() + if operand == nil || expr.GetOperandTemplateVariableName() != "" { + operand, ok = templateValues[expr.GetOperandTemplateVariableName()] + if !ok { + return fmt.Errorf("the right operand value of expression template variable name {%s} is not found", expr.GetOperandTemplateVariableName()) + } } - } - operandExpr := toValueExpr(operand) - lDataType, rDataType := expr.GetColumnInfo().GetDataType(), operandExpr.dataType - if typeutil.IsArrayType(expr.GetColumnInfo().GetDataType()) { - lDataType = expr.GetColumnInfo().GetElementType() - } + operandExpr := toValueExpr(operand) + lDataType, rDataType := expr.GetColumnInfo().GetDataType(), operandExpr.dataType + if typeutil.IsArrayType(expr.GetColumnInfo().GetDataType()) { + lDataType = expr.GetColumnInfo().GetElementType() + } - if err = checkValidModArith(expr.GetArithOp(), expr.GetColumnInfo().GetDataType(), expr.GetColumnInfo().GetElementType(), - rDataType, schemapb.DataType_None); err != nil { - return err - } + if err = checkValidModArith(expr.GetArithOp(), expr.GetColumnInfo().GetDataType(), expr.GetColumnInfo().GetElementType(), + rDataType, schemapb.DataType_None); err != nil { + return err + } - if operand.GetArrayVal() != nil { - return fmt.Errorf("can not comparisons array directly") - } + if operand.GetArrayVal() != nil { + return fmt.Errorf("can not comparisons array directly") + } - dataType, err = getTargetType(lDataType, rDataType) - if err != nil { - return err - } + dataType, err = getTargetType(lDataType, rDataType) + if err != nil { + return err + } - castedOperand, err := castValue(dataType, operand) - if err != nil { - return err + castedOperand, err := castValue(dataType, operand) + if err != nil { + return err + } + expr.RightOperand = castedOperand } - expr.RightOperand = castedOperand value := expr.GetValue() if expr.GetValue() == nil || expr.GetValueTemplateVariableName() != "" { diff --git a/internal/parser/planparserv2/fill_expression_value_test.go b/internal/parser/planparserv2/fill_expression_value_test.go index b2adb65874035..0de63bb65b88f 100644 --- a/internal/parser/planparserv2/fill_expression_value_test.go +++ b/internal/parser/planparserv2/fill_expression_value_test.go @@ -1,14 +1,12 @@ package planparserv2 import ( - "encoding/json" - "fmt" "testing" "github.com/stretchr/testify/suite" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" - "github.com/milvus-io/milvus/internal/proto/planpb" + "github.com/milvus-io/milvus/pkg/util/typeutil" ) type FillExpressionValueSuite struct { @@ -24,10 +22,17 @@ type testcase struct { values map[string]*schemapb.TemplateValue } -func (s *FillExpressionValueSuite) jsonMarshal(v interface{}) []byte { - r, err := json.Marshal(v) - s.NoError(err) - return r +func (s *FillExpressionValueSuite) assertValidExpr(helper *typeutil.SchemaHelper, exprStr string, templateValues map[string]*schemapb.TemplateValue) { + expr, err := ParseExpr(helper, exprStr, templateValues) + s.NoError(err, exprStr) + s.NotNil(expr, exprStr) + ShowExpr(expr) +} + +func (s *FillExpressionValueSuite) assertInvalidExpr(helper *typeutil.SchemaHelper, exprStr string, templateValues map[string]*schemapb.TemplateValue) { + expr, err := ParseExpr(helper, exprStr, templateValues) + s.Error(err, exprStr) + s.Nil(expr, exprStr) } func (s *FillExpressionValueSuite) TestTermExpr() { @@ -88,17 +93,8 @@ func (s *FillExpressionValueSuite) TestTermExpr() { }}, } schemaH := newTestSchemaHelper(s.T()) - for _, c := range testcases { - plan, err := CreateSearchPlan(schemaH, c.expr, "FloatVectorField", &planpb.QueryInfo{ - Topk: 0, - MetricType: "", - SearchParams: "", - RoundDecimal: 0, - }, c.values) - - s.NoError(err) - s.NotNil(plan) + s.assertValidExpr(schemaH, c.expr, c.values) } }) @@ -130,17 +126,8 @@ func (s *FillExpressionValueSuite) TestTermExpr() { }}, } schemaH := newTestSchemaHelper(s.T()) - for _, c := range testcases { - plan, err := CreateSearchPlan(schemaH, c.expr, "FloatVectorField", &planpb.QueryInfo{ - Topk: 0, - MetricType: "", - SearchParams: "", - RoundDecimal: 0, - }, c.values) - s.Error(err) - s.Nil(plan) - fmt.Println(plan) + s.assertInvalidExpr(schemaH, c.expr, c.values) } }) } @@ -172,19 +159,8 @@ func (s *FillExpressionValueSuite) TestUnaryRange() { }}, } schemaH := newTestSchemaHelper(s.T()) - for _, c := range testcases { - plan, err := CreateSearchPlan(schemaH, c.expr, "FloatVectorField", &planpb.QueryInfo{ - Topk: 0, - MetricType: "", - SearchParams: "", - RoundDecimal: 0, - }, c.values) - - s.NoError(err) - s.NotNil(plan) - s.NotNil(plan.GetVectorAnns()) - s.NotNil(plan.GetVectorAnns().GetPredicates()) + s.assertValidExpr(schemaH, c.expr, c.values) } }) @@ -214,17 +190,8 @@ func (s *FillExpressionValueSuite) TestUnaryRange() { }}, } schemaH := newTestSchemaHelper(s.T()) - for _, c := range testcases { - plan, err := CreateSearchPlan(schemaH, c.expr, "FloatVectorField", &planpb.QueryInfo{ - Topk: 0, - MetricType: "", - SearchParams: "", - RoundDecimal: 0, - }, c.values) - - s.Error(err) - s.Nil(plan) + s.assertInvalidExpr(schemaH, c.expr, c.values) } }) } @@ -264,19 +231,8 @@ func (s *FillExpressionValueSuite) TestBinaryRange() { } schemaH := newTestSchemaHelper(s.T()) - for _, c := range testcases { - plan, err := CreateSearchPlan(schemaH, c.expr, "FloatVectorField", &planpb.QueryInfo{ - Topk: 0, - MetricType: "", - SearchParams: "", - RoundDecimal: 0, - }, c.values) - - s.NoError(err) - s.NotNil(plan) - s.NotNil(plan.GetVectorAnns()) - s.NotNil(plan.GetVectorAnns().GetPredicates()) + s.assertValidExpr(schemaH, c.expr, c.values) } }) @@ -311,17 +267,8 @@ func (s *FillExpressionValueSuite) TestBinaryRange() { } schemaH := newTestSchemaHelper(s.T()) - for _, c := range testcases { - plan, err := CreateSearchPlan(schemaH, c.expr, "FloatVectorField", &planpb.QueryInfo{ - Topk: 0, - MetricType: "", - SearchParams: "", - RoundDecimal: 0, - }, c.values) - - s.Error(err) - s.Nil(plan) + s.assertInvalidExpr(schemaH, c.expr, c.values) } }) } @@ -343,22 +290,17 @@ func (s *FillExpressionValueSuite) TestBinaryArithOpEvalRange() { {`ArrayField[0] % {offset} < 11`, map[string]*schemapb.TemplateValue{ "offset": generateExpressionFieldData(schemapb.DataType_Int64, int64(3)), }}, + {`array_length(ArrayField) == {length}`, map[string]*schemapb.TemplateValue{ + "length": generateExpressionFieldData(schemapb.DataType_Int64, int64(3)), + }}, + {`array_length(ArrayField) > {length}`, map[string]*schemapb.TemplateValue{ + "length": generateExpressionFieldData(schemapb.DataType_Int64, int64(3)), + }}, } schemaH := newTestSchemaHelper(s.T()) - for _, c := range testcases { - plan, err := CreateSearchPlan(schemaH, c.expr, "FloatVectorField", &planpb.QueryInfo{ - Topk: 0, - MetricType: "", - SearchParams: "", - RoundDecimal: 0, - }, c.values) - - s.NoError(err) - s.NotNil(plan) - s.NotNil(plan.GetVectorAnns()) - s.NotNil(plan.GetVectorAnns().GetPredicates()) + s.assertValidExpr(schemaH, c.expr, c.values) } }) @@ -401,20 +343,14 @@ func (s *FillExpressionValueSuite) TestBinaryArithOpEvalRange() { }), "target": generateExpressionFieldData(schemapb.DataType_Int64, int64(5)), }}, + {`array_length(ArrayField) == {length}`, map[string]*schemapb.TemplateValue{ + "length": generateExpressionFieldData(schemapb.DataType_String, "abc"), + }}, } schemaH := newTestSchemaHelper(s.T()) - for _, c := range testcases { - plan, err := CreateSearchPlan(schemaH, c.expr, "FloatVectorField", &planpb.QueryInfo{ - Topk: 0, - MetricType: "", - SearchParams: "", - RoundDecimal: 0, - }, c.values) - - s.Error(err) - s.Nil(plan) + s.assertInvalidExpr(schemaH, c.expr, c.values) } }) } @@ -494,17 +430,7 @@ func (s *FillExpressionValueSuite) TestJSONContainsExpression() { schemaH := newTestSchemaHelper(s.T()) for _, c := range testcases { - plan, err := CreateSearchPlan(schemaH, c.expr, "FloatVectorField", &planpb.QueryInfo{ - Topk: 0, - MetricType: "", - SearchParams: "", - RoundDecimal: 0, - }, c.values) - - s.NoError(err) - s.NotNil(plan) - s.NotNil(plan.GetVectorAnns()) - s.NotNil(plan.GetVectorAnns().GetPredicates()) + s.assertValidExpr(schemaH, c.expr, c.values) } }) @@ -554,15 +480,56 @@ func (s *FillExpressionValueSuite) TestJSONContainsExpression() { schemaH := newTestSchemaHelper(s.T()) for _, c := range testcases { - plan, err := CreateSearchPlan(schemaH, c.expr, "FloatVectorField", &planpb.QueryInfo{ - Topk: 0, - MetricType: "", - SearchParams: "", - RoundDecimal: 0, - }, c.values) - - s.Error(err) - s.Nil(plan) + s.assertInvalidExpr(schemaH, c.expr, c.values) + } + }) +} + +func (s *FillExpressionValueSuite) TestBinaryExpression() { + s.Run("normal case", func() { + testcases := []testcase{ + {`Int64Field > {int} && StringField in {list}`, map[string]*schemapb.TemplateValue{ + "int": generateExpressionFieldData(schemapb.DataType_Int64, int64(10)), + "list": generateExpressionFieldData(schemapb.DataType_Array, []interface{}{ + generateExpressionFieldData(schemapb.DataType_VarChar, "abc"), + generateExpressionFieldData(schemapb.DataType_VarChar, "def"), + generateExpressionFieldData(schemapb.DataType_VarChar, "ghi"), + }), + }}, + {`{max} > FloatField >= {min} or BoolField == {bool}`, map[string]*schemapb.TemplateValue{ + "min": generateExpressionFieldData(schemapb.DataType_Int64, int64(10)), + "max": generateExpressionFieldData(schemapb.DataType_Float, 22.22), + "bool": generateExpressionFieldData(schemapb.DataType_Bool, true), + }}, + } + + schemaH := newTestSchemaHelper(s.T()) + + for _, c := range testcases { + s.assertValidExpr(schemaH, c.expr, c.values) + } + }) + + s.Run("failed case", func() { + testcases := []testcase{ + {`Int64Field > {int} && StringField in {list}`, map[string]*schemapb.TemplateValue{ + "int": generateExpressionFieldData(schemapb.DataType_String, "abc"), + "list": generateExpressionFieldData(schemapb.DataType_Array, []interface{}{ + generateExpressionFieldData(schemapb.DataType_VarChar, "abc"), + generateExpressionFieldData(schemapb.DataType_Int64, int64(10)), + generateExpressionFieldData(schemapb.DataType_VarChar, "ghi"), + }), + }}, + {`{max} > FloatField >= {min} or BoolField == {bool}`, map[string]*schemapb.TemplateValue{ + "min": generateExpressionFieldData(schemapb.DataType_Int64, int64(10)), + "bool": generateExpressionFieldData(schemapb.DataType_Bool, true), + }}, + } + + schemaH := newTestSchemaHelper(s.T()) + + for _, c := range testcases { + s.assertInvalidExpr(schemaH, c.expr, c.values) } }) } diff --git a/internal/parser/planparserv2/parser_visitor.go b/internal/parser/planparserv2/parser_visitor.go index bd0f36dba169a..dba7f939bd900 100644 --- a/internal/parser/planparserv2/parser_visitor.go +++ b/internal/parser/planparserv2/parser_visitor.go @@ -826,6 +826,7 @@ func (v *ParserVisitor) VisitLogicalOr(ctx *parser.LogicalOrContext) interface{} Op: planpb.BinaryExpr_LogicalOr, }, }, + IsTemplate: leftExpr.expr.GetIsTemplate() || rightExpr.expr.GetIsTemplate(), } return &ExprWithType{ @@ -874,6 +875,7 @@ func (v *ParserVisitor) VisitLogicalAnd(ctx *parser.LogicalAndContext) interface Op: planpb.BinaryExpr_LogicalAnd, }, }, + IsTemplate: leftExpr.expr.GetIsTemplate() || rightExpr.expr.GetIsTemplate(), } return &ExprWithType{ diff --git a/internal/parser/planparserv2/plan_parser_v2_test.go b/internal/parser/planparserv2/plan_parser_v2_test.go index b492f62e1d7fe..50a9fdbc36161 100644 --- a/internal/parser/planparserv2/plan_parser_v2_test.go +++ b/internal/parser/planparserv2/plan_parser_v2_test.go @@ -226,6 +226,7 @@ func TestExpr_BinaryArith(t *testing.T) { exprStrs := []string{ `Int64Field % 10 == 9`, `Int64Field % 10 != 9`, + `FloatField + 1.1 == 2.1`, `Int64Field + 1.1 == 2.1`, `A % 10 != 2`, `Int8Field + 1 < 2`, diff --git a/internal/parser/planparserv2/show_visitor.go b/internal/parser/planparserv2/show_visitor.go index b9b263b6e0631..2c04194924102 100644 --- a/internal/parser/planparserv2/show_visitor.go +++ b/internal/parser/planparserv2/show_visitor.go @@ -21,6 +21,9 @@ func extractColumnInfo(info *planpb.ColumnInfo) interface{} { } func extractGenericValue(value *planpb.GenericValue) interface{} { + if value == nil { + return nil + } switch realValue := value.Val.(type) { case *planpb.GenericValue_BoolVal: return realValue.BoolVal diff --git a/internal/parser/planparserv2/utils.go b/internal/parser/planparserv2/utils.go index b358cc38f44fb..8ad20092300f2 100644 --- a/internal/parser/planparserv2/utils.go +++ b/internal/parser/planparserv2/utils.go @@ -269,6 +269,7 @@ func combineArrayLengthExpr(op planpb.OpType, arithOp planpb.ArithOpType, column ValueTemplateVariableName: valueExpr.GetTemplateVariableName(), }, }, + IsTemplate: isTemplateExpr(valueExpr), }, nil }