Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ANSI support for Add, Subtract & Multiply #1135

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions native/core/src/execution/datafusion/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -185,34 +185,39 @@ impl PhysicalPlanner {
ExprStruct::Add(expr) => self.create_binary_expr(
expr.left.as_ref().unwrap(),
expr.right.as_ref().unwrap(),
expr.fail_on_error,
expr.return_type.as_ref(),
DataFusionOperator::Plus,
input_schema,
),
ExprStruct::Subtract(expr) => self.create_binary_expr(
expr.left.as_ref().unwrap(),
expr.right.as_ref().unwrap(),
expr.fail_on_error,
expr.return_type.as_ref(),
DataFusionOperator::Minus,
input_schema,
),
ExprStruct::Multiply(expr) => self.create_binary_expr(
expr.left.as_ref().unwrap(),
expr.right.as_ref().unwrap(),
expr.fail_on_error,
expr.return_type.as_ref(),
DataFusionOperator::Multiply,
input_schema,
),
ExprStruct::Divide(expr) => self.create_binary_expr(
expr.left.as_ref().unwrap(),
expr.right.as_ref().unwrap(),
expr.fail_on_error,
expr.return_type.as_ref(),
DataFusionOperator::Divide,
input_schema,
),
ExprStruct::Remainder(expr) => self.create_binary_expr(
expr.left.as_ref().unwrap(),
expr.right.as_ref().unwrap(),
expr.fail_on_error,
expr.return_type.as_ref(),
DataFusionOperator::Modulo,
input_schema,
Expand Down Expand Up @@ -777,6 +782,7 @@ impl PhysicalPlanner {
&self,
left: &Expr,
right: &Expr,
fail_on_error: bool,
return_type: Option<&spark_expression::DataType>,
op: DataFusionOperator,
input_schema: SchemaRef,
Expand Down Expand Up @@ -818,7 +824,8 @@ impl PhysicalPlanner {
EvalMode::Legacy,
false,
));
let child = Arc::new(BinaryExpr::new(left, op, right));
let child =
Arc::new(BinaryExpr::new(left, op, right).with_fail_on_overflow(fail_on_error));
Ok(Arc::new(Cast::new_without_timezone(
child,
data_type,
Expand All @@ -844,7 +851,9 @@ impl PhysicalPlanner {
data_type,
)))
}
_ => Ok(Arc::new(BinaryExpr::new(left, op, right))),
_ => Ok(Arc::new(
BinaryExpr::new(left, op, right).with_fail_on_overflow(fail_on_error),
)),
}
}

Expand Down
79 changes: 52 additions & 27 deletions spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -942,6 +942,26 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
}
}

test("add overflow (ANSI enabled)") {
// With ANSI mode on, integer addition that overflows must raise an error in
// both Spark and Comet; checkOverflow asserts both engines throw (or both
// succeed with matching answers). An empty dtype prefix matches any
// "... overflow" message. NOTE(review): dtype could be "integer" to tighten
// the assertion — confirm Comet's native error message includes the type name.
withAnsiMode(enabled = true) {
withParquetTable(Seq((Int.MaxValue, 1)), "tbl") {
checkOverflow("SELECT _1 + _2 FROM tbl", "")
}
}
}

test("subtract overflow (ANSI enabled)") {
  // With ANSI mode on, Int.MinValue - 1 underflows and must raise an error in
  // both Spark and Comet; checkOverflow asserts both engines agree.
  // Fix: the test name promises ANSI enabled, but the original passed
  // `enabled = false`, so the overflow path was never exercised (mirrors the
  // parallel "add overflow (ANSI enabled)" test, which uses `true`).
  withAnsiMode(enabled = true) {
    withParquetTable(Seq((Int.MinValue, 1)), "tbl") {
      checkOverflow("SELECT _1 - _2 FROM tbl", "")
    }
  }
}

test("divide by zero (ANSI disable)") {
// Enabling ANSI will cause native engine failure, but as we cannot catch
// native error now, we cannot test it here.
Expand Down Expand Up @@ -1922,38 +1942,41 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
}
}
}
test("unary negative integer overflow test") {
def withAnsiMode(enabled: Boolean)(f: => Unit): Unit = {
withSQLConf(
SQLConf.ANSI_ENABLED.key -> enabled.toString,
CometConf.COMET_ANSI_MODE_ENABLED.key -> enabled.toString,
CometConf.COMET_ENABLED.key -> "true",
CometConf.COMET_EXEC_ENABLED.key -> "true")(f)
}

def checkOverflow(query: String, dtype: String): Unit = {
checkSparkMaybeThrows(sql(query)) match {
case (Some(sparkException), Some(cometException)) =>
assert(sparkException.getMessage.contains(dtype + " overflow"))
assert(cometException.getMessage.contains(dtype + " overflow"))
case (None, None) => checkSparkAnswerAndOperator(sql(query))
case (None, Some(ex)) =>
fail("Comet threw an exception but Spark did not " + ex.getMessage)
case (Some(_), None) =>
fail("Spark threw an exception but Comet did not")
}
// Runs `f` with Spark ANSI mode and Comet's ANSI support toggled together,
// while forcing Comet (and Comet exec) on so the native path is exercised.
// Configs are restored by withSQLConf when `f` completes.
def withAnsiMode(enabled: Boolean)(f: => Unit): Unit = {
withSQLConf(
SQLConf.ANSI_ENABLED.key -> enabled.toString,
CometConf.COMET_ANSI_MODE_ENABLED.key -> enabled.toString,
CometConf.COMET_ENABLED.key -> "true",
CometConf.COMET_EXEC_ENABLED.key -> "true")(f)
}

// Verifies Spark and Comet agree on overflow behavior for `query`:
// either both throw an exception whose message contains "<dtype> overflow",
// or both succeed, in which case the answers must also match. A one-sided
// exception (only Spark or only Comet throws) fails the test.
def checkOverflow(query: String, dtype: String): Unit = {
  checkSparkMaybeThrows(sql(query)) match {
    case (Some(sparkException), Some(cometException)) =>
      // Fix: removed leftover debug `println(cometException.getMessage)`.
      assert(sparkException.getMessage.contains(dtype + " overflow"))
      assert(cometException.getMessage.contains(dtype + " overflow"))
    case (None, None) => checkSparkAnswerAndOperator(sql(query))
    case (None, Some(ex)) =>
      fail("Comet threw an exception but Spark did not " + ex.getMessage)
    case (Some(_), None) =>
      fail("Spark threw an exception but Comet did not")
  }
}

def runArrayTest(query: String, dtype: String, path: String): Unit = {
withParquetTable(path, "t") {
withAnsiMode(enabled = false) {
checkSparkAnswerAndOperator(sql(query))
}
withAnsiMode(enabled = true) {
checkOverflow(query, dtype)
}
// Loads the Parquet file at `path` as table "t", then runs `query` twice:
// first with ANSI off (results must match Spark and stay on the Comet
// operator), then with ANSI on (both engines must agree on whether a
// "<dtype> overflow" error is raised — see checkOverflow).
def runArrayTest(query: String, dtype: String, path: String): Unit = {
withParquetTable(path, "t") {
withAnsiMode(enabled = false) {
checkSparkAnswerAndOperator(sql(query))
}
withAnsiMode(enabled = true) {
checkOverflow(query, dtype)
}
}
}

test("unary negative integer overflow test") {

withTempDir { dir =>
// Array values test
Expand Down Expand Up @@ -2012,6 +2035,8 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
}
}

test("add integer overflow test") {
  // Fix: this was an empty placeholder that asserted nothing. Implemented to
  // cover the negative-direction overflow of addition (Int.MinValue + -1),
  // complementing the existing positive-direction "add overflow" test.
  withAnsiMode(enabled = true) {
    withParquetTable(Seq((Int.MinValue, -1)), "tbl") {
      checkOverflow("SELECT _1 + _2 FROM tbl", "")
    }
  }
}

test("readSidePadding") {
// https://stackoverflow.com/a/46290728
val table = "test"
Expand Down