Skip to content

Commit

Permalink
enhance: [cherry-pick ]Enhance the expression template to support AND…
Browse files Browse the repository at this point in the history
… and OR operations (#37217)

issue: #36672

master pr: #37033

Signed-off-by: Cai Zhang <[email protected]>
  • Loading branch information
xiaocai2333 authored Oct 29, 2024
1 parent 3d1e81f commit 05c4052
Show file tree
Hide file tree
Showing 6 changed files with 117 additions and 139 deletions.
56 changes: 30 additions & 26 deletions internal/parser/planparserv2/fill_expression_value.go
Original file line number Diff line number Diff line change
Expand Up @@ -143,39 +143,43 @@ func FillBinaryArithOpEvalRangeExpressionValue(expr *planpb.BinaryArithOpEvalRan
var err error
var ok bool

operand := expr.GetRightOperand()
if operand == nil || expr.GetOperandTemplateVariableName() != "" {
operand, ok = templateValues[expr.GetOperandTemplateVariableName()]
if !ok {
return fmt.Errorf("the right operand value of expression template variable name {%s} is not found", expr.GetOperandTemplateVariableName())
if expr.ArithOp == planpb.ArithOpType_ArrayLength {
dataType = schemapb.DataType_Int64
} else {
operand := expr.GetRightOperand()
if operand == nil || expr.GetOperandTemplateVariableName() != "" {
operand, ok = templateValues[expr.GetOperandTemplateVariableName()]
if !ok {
return fmt.Errorf("the right operand value of expression template variable name {%s} is not found", expr.GetOperandTemplateVariableName())
}
}
}

operandExpr := toValueExpr(operand)
lDataType, rDataType := expr.GetColumnInfo().GetDataType(), operandExpr.dataType
if typeutil.IsArrayType(expr.GetColumnInfo().GetDataType()) {
lDataType = expr.GetColumnInfo().GetElementType()
}
operandExpr := toValueExpr(operand)
lDataType, rDataType := expr.GetColumnInfo().GetDataType(), operandExpr.dataType
if typeutil.IsArrayType(expr.GetColumnInfo().GetDataType()) {
lDataType = expr.GetColumnInfo().GetElementType()
}

if err = checkValidModArith(expr.GetArithOp(), expr.GetColumnInfo().GetDataType(), expr.GetColumnInfo().GetElementType(),
rDataType, schemapb.DataType_None); err != nil {
return err
}
if err = checkValidModArith(expr.GetArithOp(), expr.GetColumnInfo().GetDataType(), expr.GetColumnInfo().GetElementType(),
rDataType, schemapb.DataType_None); err != nil {
return err
}

if operand.GetArrayVal() != nil {
return fmt.Errorf("can not comparisons array directly")
}
if operand.GetArrayVal() != nil {
return fmt.Errorf("can not comparisons array directly")
}

dataType, err = getTargetType(lDataType, rDataType)
if err != nil {
return err
}
dataType, err = getTargetType(lDataType, rDataType)
if err != nil {
return err
}

castedOperand, err := castValue(dataType, operand)
if err != nil {
return err
castedOperand, err := castValue(dataType, operand)
if err != nil {
return err
}
expr.RightOperand = castedOperand
}
expr.RightOperand = castedOperand

value := expr.GetValue()
if expr.GetValue() == nil || expr.GetValueTemplateVariableName() != "" {
Expand Down
193 changes: 80 additions & 113 deletions internal/parser/planparserv2/fill_expression_value_test.go
Original file line number Diff line number Diff line change
@@ -1,14 +1,12 @@
package planparserv2

import (
"encoding/json"
"fmt"
"testing"

"github.com/stretchr/testify/suite"

"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/proto/planpb"
"github.com/milvus-io/milvus/pkg/util/typeutil"
)

type FillExpressionValueSuite struct {
Expand All @@ -24,10 +22,17 @@ type testcase struct {
values map[string]*schemapb.TemplateValue
}

func (s *FillExpressionValueSuite) jsonMarshal(v interface{}) []byte {
r, err := json.Marshal(v)
s.NoError(err)
return r
func (s *FillExpressionValueSuite) assertValidExpr(helper *typeutil.SchemaHelper, exprStr string, templateValues map[string]*schemapb.TemplateValue) {
expr, err := ParseExpr(helper, exprStr, templateValues)
s.NoError(err, exprStr)
s.NotNil(expr, exprStr)
ShowExpr(expr)
}

func (s *FillExpressionValueSuite) assertInvalidExpr(helper *typeutil.SchemaHelper, exprStr string, templateValues map[string]*schemapb.TemplateValue) {
expr, err := ParseExpr(helper, exprStr, templateValues)
s.Error(err, exprStr)
s.Nil(expr, exprStr)
}

func (s *FillExpressionValueSuite) TestTermExpr() {
Expand Down Expand Up @@ -88,17 +93,8 @@ func (s *FillExpressionValueSuite) TestTermExpr() {
}},
}
schemaH := newTestSchemaHelper(s.T())

for _, c := range testcases {
plan, err := CreateSearchPlan(schemaH, c.expr, "FloatVectorField", &planpb.QueryInfo{
Topk: 0,
MetricType: "",
SearchParams: "",
RoundDecimal: 0,
}, c.values)

s.NoError(err)
s.NotNil(plan)
s.assertValidExpr(schemaH, c.expr, c.values)
}
})

Expand Down Expand Up @@ -130,17 +126,8 @@ func (s *FillExpressionValueSuite) TestTermExpr() {
}},
}
schemaH := newTestSchemaHelper(s.T())

for _, c := range testcases {
plan, err := CreateSearchPlan(schemaH, c.expr, "FloatVectorField", &planpb.QueryInfo{
Topk: 0,
MetricType: "",
SearchParams: "",
RoundDecimal: 0,
}, c.values)
s.Error(err)
s.Nil(plan)
fmt.Println(plan)
s.assertInvalidExpr(schemaH, c.expr, c.values)
}
})
}
Expand Down Expand Up @@ -172,19 +159,8 @@ func (s *FillExpressionValueSuite) TestUnaryRange() {
}},
}
schemaH := newTestSchemaHelper(s.T())

for _, c := range testcases {
plan, err := CreateSearchPlan(schemaH, c.expr, "FloatVectorField", &planpb.QueryInfo{
Topk: 0,
MetricType: "",
SearchParams: "",
RoundDecimal: 0,
}, c.values)

s.NoError(err)
s.NotNil(plan)
s.NotNil(plan.GetVectorAnns())
s.NotNil(plan.GetVectorAnns().GetPredicates())
s.assertValidExpr(schemaH, c.expr, c.values)
}
})

Expand Down Expand Up @@ -214,17 +190,8 @@ func (s *FillExpressionValueSuite) TestUnaryRange() {
}},
}
schemaH := newTestSchemaHelper(s.T())

for _, c := range testcases {
plan, err := CreateSearchPlan(schemaH, c.expr, "FloatVectorField", &planpb.QueryInfo{
Topk: 0,
MetricType: "",
SearchParams: "",
RoundDecimal: 0,
}, c.values)

s.Error(err)
s.Nil(plan)
s.assertInvalidExpr(schemaH, c.expr, c.values)
}
})
}
Expand Down Expand Up @@ -264,19 +231,8 @@ func (s *FillExpressionValueSuite) TestBinaryRange() {
}

schemaH := newTestSchemaHelper(s.T())

for _, c := range testcases {
plan, err := CreateSearchPlan(schemaH, c.expr, "FloatVectorField", &planpb.QueryInfo{
Topk: 0,
MetricType: "",
SearchParams: "",
RoundDecimal: 0,
}, c.values)

s.NoError(err)
s.NotNil(plan)
s.NotNil(plan.GetVectorAnns())
s.NotNil(plan.GetVectorAnns().GetPredicates())
s.assertValidExpr(schemaH, c.expr, c.values)
}
})

Expand Down Expand Up @@ -311,17 +267,8 @@ func (s *FillExpressionValueSuite) TestBinaryRange() {
}

schemaH := newTestSchemaHelper(s.T())

for _, c := range testcases {
plan, err := CreateSearchPlan(schemaH, c.expr, "FloatVectorField", &planpb.QueryInfo{
Topk: 0,
MetricType: "",
SearchParams: "",
RoundDecimal: 0,
}, c.values)

s.Error(err)
s.Nil(plan)
s.assertInvalidExpr(schemaH, c.expr, c.values)
}
})
}
Expand All @@ -343,22 +290,17 @@ func (s *FillExpressionValueSuite) TestBinaryArithOpEvalRange() {
{`ArrayField[0] % {offset} < 11`, map[string]*schemapb.TemplateValue{
"offset": generateExpressionFieldData(schemapb.DataType_Int64, int64(3)),
}},
{`array_length(ArrayField) == {length}`, map[string]*schemapb.TemplateValue{
"length": generateExpressionFieldData(schemapb.DataType_Int64, int64(3)),
}},
{`array_length(ArrayField) > {length}`, map[string]*schemapb.TemplateValue{
"length": generateExpressionFieldData(schemapb.DataType_Int64, int64(3)),
}},
}

schemaH := newTestSchemaHelper(s.T())

for _, c := range testcases {
plan, err := CreateSearchPlan(schemaH, c.expr, "FloatVectorField", &planpb.QueryInfo{
Topk: 0,
MetricType: "",
SearchParams: "",
RoundDecimal: 0,
}, c.values)

s.NoError(err)
s.NotNil(plan)
s.NotNil(plan.GetVectorAnns())
s.NotNil(plan.GetVectorAnns().GetPredicates())
s.assertValidExpr(schemaH, c.expr, c.values)
}
})

Expand Down Expand Up @@ -401,20 +343,14 @@ func (s *FillExpressionValueSuite) TestBinaryArithOpEvalRange() {
}),
"target": generateExpressionFieldData(schemapb.DataType_Int64, int64(5)),
}},
{`array_length(ArrayField) == {length}`, map[string]*schemapb.TemplateValue{
"length": generateExpressionFieldData(schemapb.DataType_String, "abc"),
}},
}

schemaH := newTestSchemaHelper(s.T())

for _, c := range testcases {
plan, err := CreateSearchPlan(schemaH, c.expr, "FloatVectorField", &planpb.QueryInfo{
Topk: 0,
MetricType: "",
SearchParams: "",
RoundDecimal: 0,
}, c.values)

s.Error(err)
s.Nil(plan)
s.assertInvalidExpr(schemaH, c.expr, c.values)
}
})
}
Expand Down Expand Up @@ -494,17 +430,7 @@ func (s *FillExpressionValueSuite) TestJSONContainsExpression() {
schemaH := newTestSchemaHelper(s.T())

for _, c := range testcases {
plan, err := CreateSearchPlan(schemaH, c.expr, "FloatVectorField", &planpb.QueryInfo{
Topk: 0,
MetricType: "",
SearchParams: "",
RoundDecimal: 0,
}, c.values)

s.NoError(err)
s.NotNil(plan)
s.NotNil(plan.GetVectorAnns())
s.NotNil(plan.GetVectorAnns().GetPredicates())
s.assertValidExpr(schemaH, c.expr, c.values)
}
})

Expand Down Expand Up @@ -554,15 +480,56 @@ func (s *FillExpressionValueSuite) TestJSONContainsExpression() {
schemaH := newTestSchemaHelper(s.T())

for _, c := range testcases {
plan, err := CreateSearchPlan(schemaH, c.expr, "FloatVectorField", &planpb.QueryInfo{
Topk: 0,
MetricType: "",
SearchParams: "",
RoundDecimal: 0,
}, c.values)

s.Error(err)
s.Nil(plan)
s.assertInvalidExpr(schemaH, c.expr, c.values)
}
})
}

func (s *FillExpressionValueSuite) TestBinaryExpression() {
s.Run("normal case", func() {
testcases := []testcase{
{`Int64Field > {int} && StringField in {list}`, map[string]*schemapb.TemplateValue{
"int": generateExpressionFieldData(schemapb.DataType_Int64, int64(10)),
"list": generateExpressionFieldData(schemapb.DataType_Array, []interface{}{
generateExpressionFieldData(schemapb.DataType_VarChar, "abc"),
generateExpressionFieldData(schemapb.DataType_VarChar, "def"),
generateExpressionFieldData(schemapb.DataType_VarChar, "ghi"),
}),
}},
{`{max} > FloatField >= {min} or BoolField == {bool}`, map[string]*schemapb.TemplateValue{
"min": generateExpressionFieldData(schemapb.DataType_Int64, int64(10)),
"max": generateExpressionFieldData(schemapb.DataType_Float, 22.22),
"bool": generateExpressionFieldData(schemapb.DataType_Bool, true),
}},
}

schemaH := newTestSchemaHelper(s.T())

for _, c := range testcases {
s.assertValidExpr(schemaH, c.expr, c.values)
}
})

s.Run("failed case", func() {
testcases := []testcase{
{`Int64Field > {int} && StringField in {list}`, map[string]*schemapb.TemplateValue{
"int": generateExpressionFieldData(schemapb.DataType_String, "abc"),
"list": generateExpressionFieldData(schemapb.DataType_Array, []interface{}{
generateExpressionFieldData(schemapb.DataType_VarChar, "abc"),
generateExpressionFieldData(schemapb.DataType_Int64, int64(10)),
generateExpressionFieldData(schemapb.DataType_VarChar, "ghi"),
}),
}},
{`{max} > FloatField >= {min} or BoolField == {bool}`, map[string]*schemapb.TemplateValue{
"min": generateExpressionFieldData(schemapb.DataType_Int64, int64(10)),
"bool": generateExpressionFieldData(schemapb.DataType_Bool, true),
}},
}

schemaH := newTestSchemaHelper(s.T())

for _, c := range testcases {
s.assertInvalidExpr(schemaH, c.expr, c.values)
}
})
}
2 changes: 2 additions & 0 deletions internal/parser/planparserv2/parser_visitor.go
Original file line number Diff line number Diff line change
Expand Up @@ -826,6 +826,7 @@ func (v *ParserVisitor) VisitLogicalOr(ctx *parser.LogicalOrContext) interface{}
Op: planpb.BinaryExpr_LogicalOr,
},
},
IsTemplate: leftExpr.expr.GetIsTemplate() || rightExpr.expr.GetIsTemplate(),
}

return &ExprWithType{
Expand Down Expand Up @@ -874,6 +875,7 @@ func (v *ParserVisitor) VisitLogicalAnd(ctx *parser.LogicalAndContext) interface
Op: planpb.BinaryExpr_LogicalAnd,
},
},
IsTemplate: leftExpr.expr.GetIsTemplate() || rightExpr.expr.GetIsTemplate(),
}

return &ExprWithType{
Expand Down
1 change: 1 addition & 0 deletions internal/parser/planparserv2/plan_parser_v2_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,7 @@ func TestExpr_BinaryArith(t *testing.T) {
exprStrs := []string{
`Int64Field % 10 == 9`,
`Int64Field % 10 != 9`,
`FloatField + 1.1 == 2.1`,
`Int64Field + 1.1 == 2.1`,
`A % 10 != 2`,
`Int8Field + 1 < 2`,
Expand Down
3 changes: 3 additions & 0 deletions internal/parser/planparserv2/show_visitor.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@ func extractColumnInfo(info *planpb.ColumnInfo) interface{} {
}

func extractGenericValue(value *planpb.GenericValue) interface{} {
if value == nil {
return nil
}
switch realValue := value.Val.(type) {
case *planpb.GenericValue_BoolVal:
return realValue.BoolVal
Expand Down
Loading

0 comments on commit 05c4052

Please sign in to comment.