diff --git a/internal/parser/planparserv2/parser_visitor.go b/internal/parser/planparserv2/parser_visitor.go index 24f437feb7d52..ffc4f13623459 100644 --- a/internal/parser/planparserv2/parser_visitor.go +++ b/internal/parser/planparserv2/parser_visitor.go @@ -545,8 +545,7 @@ func (v *ParserVisitor) VisitTerm(ctx *parser.TermContext) interface{} { } else { elementValue := valueExpr.GetValue() if elementValue == nil { - return fmt.Errorf( - "contains_any operation are only supported explicitly specified element, got: %s", ctx.Expr(1).GetText()) + return fmt.Errorf("value '%s' in list cannot be a non-const expression", ctx.Expr(1).GetText()) } if !IsArray(elementValue) { @@ -662,12 +661,12 @@ func (v *ParserVisitor) VisitRange(ctx *parser.RangeContext) interface{} { lowerValue := lowerValueExpr.GetValue() upperValue := upperValueExpr.GetValue() if !isTemplateExpr(lowerValueExpr) { - if err = checkRangeCompared(fieldDataType, lowerValue); err != nil { + if lowerValue, err = castRangeValue(fieldDataType, lowerValue); err != nil { return err } } if !isTemplateExpr(upperValueExpr) { - if err = checkRangeCompared(fieldDataType, upperValue); err != nil { + if upperValue, err = castRangeValue(fieldDataType, upperValue); err != nil { return err } } @@ -744,12 +743,12 @@ func (v *ParserVisitor) VisitReverseRange(ctx *parser.ReverseRangeContext) inter lowerValue := lowerValueExpr.GetValue() upperValue := upperValueExpr.GetValue() if !isTemplateExpr(lowerValueExpr) { - if err = checkRangeCompared(fieldDataType, lowerValue); err != nil { + if lowerValue, err = castRangeValue(fieldDataType, lowerValue); err != nil { return err } } if !isTemplateExpr(upperValueExpr) { - if err = checkRangeCompared(fieldDataType, upperValue); err != nil { + if upperValue, err = castRangeValue(fieldDataType, upperValue); err != nil { return err } } diff --git a/internal/parser/planparserv2/plan_parser_v2_test.go b/internal/parser/planparserv2/plan_parser_v2_test.go index d3adb5b36577c..17cca040e0ffa 100644 --- a/internal/parser/planparserv2/plan_parser_v2_test.go +++ b/internal/parser/planparserv2/plan_parser_v2_test.go @@ -274,6 +274,28 @@ func TestExpr_BinaryRange(t *testing.T) { } } +func TestExpr_castValue(t *testing.T) { + schema := newTestSchema() + helper, err := typeutil.CreateSchemaHelper(schema) + assert.NoError(t, err) + + exprStr := `Int64Field + 1.1 == 2.1` + expr, err := ParseExpr(helper, exprStr, nil) + assert.NoError(t, err, exprStr) + assert.NotNil(t, expr, exprStr) + assert.NotNil(t, expr.GetBinaryArithOpEvalRangeExpr()) + assert.NotNil(t, expr.GetBinaryArithOpEvalRangeExpr().GetRightOperand().GetFloatVal()) + assert.NotNil(t, expr.GetBinaryArithOpEvalRangeExpr().GetValue().GetFloatVal()) + + exprStr = `FloatField +1 == 2` + expr, err = ParseExpr(helper, exprStr, nil) + assert.NoError(t, err, exprStr) + assert.NotNil(t, expr, exprStr) + assert.NotNil(t, expr.GetBinaryArithOpEvalRangeExpr()) + assert.NotNil(t, expr.GetBinaryArithOpEvalRangeExpr().GetRightOperand().GetFloatVal()) + assert.NotNil(t, expr.GetBinaryArithOpEvalRangeExpr().GetValue().GetFloatVal()) +} + func TestExpr_BinaryArith(t *testing.T) { schema := newTestSchema() helper, err := typeutil.CreateSchemaHelper(schema) @@ -283,7 +305,6 @@ func TestExpr_BinaryArith(t *testing.T) { `Int64Field % 10 == 9`, `Int64Field % 10 != 9`, `FloatField + 1.1 == 2.1`, - `Int64Field + 1.1 == 2.1`, `A % 10 != 2`, `Int8Field + 1 < 2`, `Int16Field - 3 <= 4`, diff --git a/internal/parser/planparserv2/utils.go b/internal/parser/planparserv2/utils.go index e61bbd237c9bc..4faef470dd7c1 100644 --- a/internal/parser/planparserv2/utils.go +++ b/internal/parser/planparserv2/utils.go @@ -241,13 +241,22 @@ func castValue(dataType schemapb.DataType, value *planpb.GenericValue) (*planpb. return nil, fmt.Errorf("cannot cast value to %s, value: %s", dataType.String(), value) } -func combineBinaryArithExpr(op planpb.OpType, arithOp planpb.ArithOpType, columnInfo *planpb.ColumnInfo, operandExpr, valueExpr *planpb.ValueExpr) *planpb.Expr { +func combineBinaryArithExpr(op planpb.OpType, arithOp planpb.ArithOpType, arithExprDataType schemapb.DataType, columnInfo *planpb.ColumnInfo, operandExpr, valueExpr *planpb.ValueExpr) (*planpb.Expr, error) { + var err error + operand := operandExpr.GetValue() + if !isTemplateExpr(operandExpr) { + operand, err = castValue(arithExprDataType, operand) + if err != nil { + return nil, err + } + } + return &planpb.Expr{ Expr: &planpb.Expr_BinaryArithOpEvalRangeExpr{ BinaryArithOpEvalRangeExpr: &planpb.BinaryArithOpEvalRangeExpr{ ColumnInfo: columnInfo, ArithOp: arithOp, - RightOperand: operandExpr.GetValue(), + RightOperand: operand, Op: op, Value: valueExpr.GetValue(), OperandTemplateVariableName: operandExpr.GetTemplateVariableName(), @@ -255,7 +264,7 @@ func combineBinaryArithExpr(op planpb.OpType, arithOp planpb.ArithOpType, column }, }, IsTemplate: isTemplateExpr(operandExpr) || isTemplateExpr(valueExpr), - } + }, nil } func combineArrayLengthExpr(op planpb.OpType, arithOp planpb.ArithOpType, columnInfo *planpb.ColumnInfo, valueExpr *planpb.ValueExpr) (*planpb.Expr, error) { @@ -297,7 +306,7 @@ func handleBinaryArithExpr(op planpb.OpType, arithExpr *planpb.BinaryArithExpr, // a * 2 == 3 // a / 2 == 3 // a % 2 == 3 - return combineBinaryArithExpr(op, arithOp, leftExpr.GetInfo(), rightValue, valueExpr), nil + return combineBinaryArithExpr(op, arithOp, arithExprDataType, leftExpr.GetInfo(), rightValue, valueExpr) } else if rightExpr != nil && leftValue != nil { // 2 + a == 3 // 2 - a == 3 @@ -307,7 +316,7 @@ func handleBinaryArithExpr(op planpb.OpType, arithExpr *planpb.BinaryArithExpr, switch arithExpr.GetOp() { case planpb.ArithOpType_Add, planpb.ArithOpType_Mul: - return combineBinaryArithExpr(op, arithOp, rightExpr.GetInfo(), leftValue, valueExpr), nil + return combineBinaryArithExpr(op, arithOp, arithExprDataType, rightExpr.GetInfo(), leftValue, valueExpr) default: return nil, fmt.Errorf("module field is not yet supported") } @@ -625,24 +634,27 @@ func checkValidModArith(tokenType planpb.ArithOpType, leftType, leftElementType, return nil } -func checkRangeCompared(dataType schemapb.DataType, value *planpb.GenericValue) error { +func castRangeValue(dataType schemapb.DataType, value *planpb.GenericValue) (*planpb.GenericValue, error) { switch dataType { case schemapb.DataType_String, schemapb.DataType_VarChar: if !IsString(value) { - return fmt.Errorf("invalid range operations") + return nil, fmt.Errorf("invalid range operations") } case schemapb.DataType_Bool: - return fmt.Errorf("invalid range operations on boolean expr") + return nil, fmt.Errorf("invalid range operations on boolean expr") case schemapb.DataType_Int8, schemapb.DataType_Int16, schemapb.DataType_Int32, schemapb.DataType_Int64: if !IsInteger(value) { - return fmt.Errorf("invalid range operations") + return nil, fmt.Errorf("invalid range operations") } case schemapb.DataType_Float, schemapb.DataType_Double: if !IsNumber(value) { - return fmt.Errorf("invalid range operations") + return nil, fmt.Errorf("invalid range operations") + } + if IsInteger(value) { + return NewFloat(float64(value.GetInt64Val())), nil } } - return nil + return value, nil } func checkContainsElement(columnExpr *ExprWithType, op planpb.JSONContainsExpr_JSONOp, elementValue *planpb.GenericValue) error {