Skip to content

Commit

Permalink
code refactor and adding test case for equalto
Browse files Browse the repository at this point in the history
  • Loading branch information
vaibhawvipul committed Jun 27, 2024
1 parent 20ac2ba commit 00e44ed
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 49 deletions.
89 changes: 40 additions & 49 deletions spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
Original file line number Diff line number Diff line change
Expand Up @@ -857,61 +857,52 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
case EqualTo(left, right) =>
val zero = Literal.default(left.dataType)
val negZero = UnaryMinus(zero)
if (left.dataType == DoubleType || left.dataType == FloatType) {
if (right == negZero) {
return Some(
ExprOuterClass.Expr
.newBuilder()
.setEq(
ExprOuterClass.Equal
.newBuilder()
.setLeft(exprToProtoInternal(left, inputs).get)
.setRight(exprToProtoInternal(Abs(right).child, inputs).get))
.build())
} else if (left == negZero) {
return Some(
ExprOuterClass.Expr
.newBuilder()
.setEq(
ExprOuterClass.Equal
.newBuilder()
.setLeft(exprToProtoInternal(Abs(left).child, inputs).get)
.setRight(exprToProtoInternal(right, inputs).get))
.build())
} else {
Some(
ExprOuterClass.Expr

def buildEqualExpr(leftExpr: Expr, rightExpr: Expr): ExprOuterClass.Expr = {
ExprOuterClass.Expr
.newBuilder()
.setEq(
ExprOuterClass.Equal
.newBuilder()
.setEq(
ExprOuterClass.Equal
.newBuilder()
.setLeft(exprToProtoInternal(left, inputs).get)
.setRight(exprToProtoInternal(right, inputs).get))
.build())
}
.setLeft(leftExpr)
.setRight(rightExpr))
.build()
}

var leftExpr, rightExpr: Option[Expr] = None
if (left.dataType == DoubleType || left.dataType == FloatType) {
leftExpr = exprToProtoInternal(If(EqualTo(left, negZero), zero, left), inputs)
rightExpr = exprToProtoInternal(If(EqualTo(right, negZero), zero, right), inputs)
} else {
leftExpr = exprToProtoInternal(left, inputs)
rightExpr = exprToProtoInternal(right, inputs)
(left, right) match {
case (`negZero`, _) =>
return Some(
buildEqualExpr(
exprToProtoInternal(Abs(left).child, inputs).get,
exprToProtoInternal(right, inputs).get))
case (_, `negZero`) =>
return Some(
buildEqualExpr(
exprToProtoInternal(left, inputs).get,
exprToProtoInternal(Abs(right).child, inputs).get))
case _ =>
Some(
buildEqualExpr(
exprToProtoInternal(left, inputs).get,
exprToProtoInternal(right, inputs).get))
}
}
if (leftExpr.isDefined && rightExpr.isDefined) {
val builder = ExprOuterClass.Equal.newBuilder()
builder.setLeft(leftExpr.get)
builder.setRight(rightExpr.get)

Some(
ExprOuterClass.Expr
.newBuilder()
.setEq(builder)
.build())
} else {
withInfo(expr, left, right)
None
val (leftExpr, rightExpr) =
if (left.dataType == DoubleType || left.dataType == FloatType) {
(
exprToProtoInternal(If(EqualTo(left, negZero), zero, left), inputs),
exprToProtoInternal(If(EqualTo(right, negZero), zero, right), inputs))
} else {
(exprToProtoInternal(left, inputs), exprToProtoInternal(right, inputs))
}

(leftExpr, rightExpr) match {
case (Some(l), Some(r)) => Some(buildEqualExpr(l, r))
case _ =>
withInfo(expr, left, right)
None
}

case Not(EqualTo(left, right)) =>
Expand Down
24 changes: 24 additions & 0 deletions spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -856,6 +856,30 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
}
}

test("zero equality") {
withParquetTable(
Seq(
(-0.0, 0.0),
(0.0, -0.0),
(-0.0, -0.0),
(0.0, 0.0),
(1.0, 2.0),
(1.0, 1.0),
(1.0, 0.0),
(0.0, 1.0),
(-0.0, 1.0),
(1.0, -0.0),
(1.0, -1.0),
(-1.0, 1.0),
(-1.0, -0.0),
(-1.0, -1.0),
(-1.0, 0.0),
(0.0, -1.0)),
"t") {
checkSparkAnswerAndOperator("SELECT _1 == _2 FROM t")
}
}

test("remainder") {
val query = "SELECT _1, _2, _1 % _2 FROM t"
withParquetTable(Seq((21840, -0.0), (21840, 5.0)), "t") {
Expand Down

0 comments on commit 00e44ed

Please sign in to comment.