Merge remote-tracking branch 'upstream/main' into columnar_shuffle_default
viirya committed May 6, 2024
2 parents 65a9515 + 19379a3 commit 70c53b4
Showing 10 changed files with 188 additions and 20 deletions.
4 changes: 4 additions & 0 deletions README.md
@@ -70,3 +70,7 @@ Linux, Apple OSX (Intel and M1)
## Getting started

See the [DataFusion Comet User Guide](https://datafusion.apache.org/comet/user-guide/) for installation instructions.

## Contributing
See the [DataFusion Comet Contribution Guide](https://datafusion.apache.org/comet/contributor-guide/contributing.html)
for information on how to get started contributing to the project.
11 changes: 11 additions & 0 deletions common/src/main/scala/org/apache/comet/CometConf.scala
@@ -149,6 +149,17 @@ object CometConf {
.booleanConf
.createWithDefault(true)

val COMET_SHUFFLE_ENFORCE_MODE_ENABLED: ConfigEntry[Boolean] =
conf("spark.comet.shuffle.enforceMode.enabled")
.doc(
"Comet shuffle doesn't support Spark AQE coalesce partitions. If AQE coalesce " +
"partitions is enabled, Comet shuffle won't be triggered even enabled. This config " +
"is used to enforce Comet to trigger shuffle even if AQE coalesce partitions is " +
"enabled. This is for testing purpose only.")
.internal()
.booleanConf
.createWithDefault(false)

val COMET_EXEC_BROADCAST_FORCE_ENABLED: ConfigEntry[Boolean] =
conf(s"$COMET_EXEC_CONFIG_PREFIX.broadcast.enabled")
.doc(
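The new enforce-mode flag above is internal and defaults to false, so it only takes effect when a test opts in explicitly. A minimal, hypothetical sketch of how a suite might combine it with AQE coalesce partitions (withSQLConf, sql, and the testData2 table are assumed to come from the surrounding Spark test harness, as in the suites further down):

import org.apache.comet.CometConf
import org.apache.spark.sql.internal.SQLConf

// Force Comet shuffle to stay active even though AQE coalesce partitions is on,
// so the columnar shuffle path is still exercised by the query below.
withSQLConf(
  SQLConf.COALESCE_PARTITIONS_ENABLED.key -> "true",
  CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true",
  CometConf.COMET_SHUFFLE_ENFORCE_MODE_ENABLED.key -> "true") {
  sql("SELECT a, COUNT(*) FROM testData2 GROUP BY a").collect()
}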
9 changes: 9 additions & 0 deletions core/src/errors.rs
@@ -72,6 +72,15 @@ pub enum CometError {
to_type: String,
},

#[error("[CAST_OVERFLOW] The value {value} of the type \"{from_type}\" cannot be cast to \"{to_type}\" \
due to an overflow. Use `try_cast` to tolerate overflow and return NULL instead. If necessary \
set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")]
CastOverFlow {
value: String,
from_type: String,
to_type: String,
},

#[error(transparent)]
Arrow {
#[from]
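The wording of the new CastOverFlow variant above follows Spark's own CAST_OVERFLOW error text, so that messages can be compared verbatim on Spark 3.4+ (see the CometCastSuite changes below). A hedged sketch of the Spark-side behaviour being mirrored, assuming a running SparkSession named spark on Spark 3.4+ and an illustrative out-of-range literal:

import scala.util.Try

spark.conf.set("spark.sql.ansi.enabled", "true")
// 2147483648 is one past Int.MaxValue, so an ANSI cast to INT must fail with CAST_OVERFLOW.
val result = Try(spark.sql("SELECT CAST(2147483648L AS INT)").collect())
assert(result.failed.get.getMessage.contains("[CAST_OVERFLOW]"))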
98 changes: 98 additions & 0 deletions core/src/execution/datafusion/expressions/cast.rs
@@ -176,6 +176,62 @@ macro_rules! cast_float_to_string {
}};
}

macro_rules! cast_int_to_int_macro {
(
$array: expr,
$eval_mode:expr,
$from_arrow_primitive_type: ty,
$to_arrow_primitive_type: ty,
$from_data_type: expr,
$to_native_type: ty,
$spark_from_data_type_name: expr,
$spark_to_data_type_name: expr
) => {{
let cast_array = $array
.as_any()
.downcast_ref::<PrimitiveArray<$from_arrow_primitive_type>>()
.unwrap();
let spark_int_literal_suffix = match $from_data_type {
&DataType::Int64 => "L",
&DataType::Int16 => "S",
&DataType::Int8 => "T",
_ => "",
};

let output_array = match $eval_mode {
EvalMode::Legacy => cast_array
.iter()
.map(|value| match value {
Some(value) => {
Ok::<Option<$to_native_type>, CometError>(Some(value as $to_native_type))
}
_ => Ok(None),
})
.collect::<Result<PrimitiveArray<$to_arrow_primitive_type>, _>>(),
_ => cast_array
.iter()
.map(|value| match value {
Some(value) => {
let res = <$to_native_type>::try_from(value);
if res.is_err() {
Err(CometError::CastOverFlow {
value: value.to_string() + spark_int_literal_suffix,
from_type: $spark_from_data_type_name.to_string(),
to_type: $spark_to_data_type_name.to_string(),
})
} else {
Ok::<Option<$to_native_type>, CometError>(Some(res.unwrap()))
}
}
_ => Ok(None),
})
.collect::<Result<PrimitiveArray<$to_arrow_primitive_type>, _>>(),
}?;
let result: CometResult<ArrayRef> = Ok(Arc::new(output_array) as ArrayRef);
result
}};
}

impl Cast {
pub fn new(
child: Arc<dyn PhysicalExpr>,
@@ -218,6 +274,16 @@ impl Cast {
(DataType::Utf8, DataType::Timestamp(_, _)) => {
Self::cast_string_to_timestamp(&array, to_type, self.eval_mode)?
}
(DataType::Int64, DataType::Int32)
| (DataType::Int64, DataType::Int16)
| (DataType::Int64, DataType::Int8)
| (DataType::Int32, DataType::Int16)
| (DataType::Int32, DataType::Int8)
| (DataType::Int16, DataType::Int8)
if self.eval_mode != EvalMode::Try =>
{
Self::spark_cast_int_to_int(&array, self.eval_mode, from_type, to_type)?
}
(
DataType::Utf8,
DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64,
@@ -349,6 +415,38 @@ impl Cast {
cast_float_to_string!(from, _eval_mode, f32, Float32Array, OffsetSize)
}

fn spark_cast_int_to_int(
array: &dyn Array,
eval_mode: EvalMode,
from_type: &DataType,
to_type: &DataType,
) -> CometResult<ArrayRef> {
match (from_type, to_type) {
(DataType::Int64, DataType::Int32) => cast_int_to_int_macro!(
array, eval_mode, Int64Type, Int32Type, from_type, i32, "BIGINT", "INT"
),
(DataType::Int64, DataType::Int16) => cast_int_to_int_macro!(
array, eval_mode, Int64Type, Int16Type, from_type, i16, "BIGINT", "SMALLINT"
),
(DataType::Int64, DataType::Int8) => cast_int_to_int_macro!(
array, eval_mode, Int64Type, Int8Type, from_type, i8, "BIGINT", "TINYINT"
),
(DataType::Int32, DataType::Int16) => cast_int_to_int_macro!(
array, eval_mode, Int32Type, Int16Type, from_type, i16, "INT", "SMALLINT"
),
(DataType::Int32, DataType::Int8) => cast_int_to_int_macro!(
array, eval_mode, Int32Type, Int8Type, from_type, i8, "INT", "TINYINT"
),
(DataType::Int16, DataType::Int8) => cast_int_to_int_macro!(
array, eval_mode, Int16Type, Int8Type, from_type, i8, "SMALLINT", "TINYINT"
),
_ => unreachable!("invalid integer type {to_type} in cast from {from_type}"),
}
}

fn spark_cast_utf8_to_boolean<OffsetSize>(
from: &dyn Array,
eval_mode: EvalMode,
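The int-to-int path above reproduces Spark's two narrowing behaviours: in LEGACY mode the value is reinterpreted with wrap-around (the plain `as` cast in cast_int_to_int_macro), while in ANSI mode an out-of-range value raises CAST_OVERFLOW (the try_from branch). A hedged Scala illustration of those semantics using plain JVM arithmetic; ansiLongToInt is a hypothetical helper, not Comet or Spark API:

// LEGACY semantics: narrowing keeps only the low 32 bits, exactly like Scala's .toInt.
val big = 2147483648L                 // Int.MaxValue + 1
assert(big.toInt == Int.MinValue)     // wraps to -2147483648 instead of failing

// ANSI semantics: values outside the target range are rejected rather than wrapped.
def ansiLongToInt(v: Long): Int =
  if (v < Int.MinValue || v > Int.MaxValue)
    throw new ArithmeticException(
      s"""[CAST_OVERFLOW] The value ${v}L of the type "BIGINT" cannot be cast to "INT" due to an overflow.""")
  else v.toInt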
@@ -917,8 +917,9 @@ object CometSparkSessionExtensions extends Logging {
COMET_EXEC_SHUFFLE_ENABLED.get(conf) &&
(conf.contains("spark.shuffle.manager") && conf.getConfString("spark.shuffle.manager") ==
"org.apache.spark.sql.comet.execution.shuffle.CometShuffleManager") &&
// TODO: AQE coalesce partitions feature causes Comet columnar shuffle memory leak
!conf.coalesceShufflePartitionsEnabled
// TODO: AQE coalesce partitions feature causes Comet shuffle memory leak.
// We should disable Comet shuffle when AQE coalesce partitions is enabled.
(!conf.coalesceShufflePartitionsEnabled || COMET_SHUFFLE_ENFORCE_MODE_ENABLED.get())

private[comet] def isCometScanEnabled(conf: SQLConf): Boolean = {
COMET_SCAN_ENABLED.get(conf)
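After this change, Comet shuffle is suppressed whenever AQE coalesce partitions is enabled, unless the new enforce-mode flag overrides that for tests. A hypothetical distillation of the effective predicate, with plain booleans standing in for the real SQLConf/CometConf lookups shown above:

def cometShuffleEnabled(
    shuffleEnabled: Boolean,            // COMET_EXEC_SHUFFLE_ENABLED
    cometShuffleManagerSet: Boolean,    // spark.shuffle.manager points at CometShuffleManager
    coalescePartitionsEnabled: Boolean, // spark.sql.adaptive.coalescePartitions.enabled
    enforceMode: Boolean                // COMET_SHUFFLE_ENFORCE_MODE_ENABLED
): Boolean =
  shuffleEnabled && cometShuffleManagerSet &&
    (!coalescePartitionsEnabled || enforceMode)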
37 changes: 24 additions & 13 deletions spark/src/test/scala/org/apache/comet/CometCastSuite.scala
@@ -166,7 +166,7 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
castTest(generateShorts(), DataTypes.BooleanType)
}

ignore("cast ShortType to ByteType") {
test("cast ShortType to ByteType") {
// https://github.com/apache/datafusion-comet/issues/311
castTest(generateShorts(), DataTypes.ByteType)
}
@@ -210,12 +210,12 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
castTest(generateInts(), DataTypes.BooleanType)
}

ignore("cast IntegerType to ByteType") {
test("cast IntegerType to ByteType") {
// https://github.com/apache/datafusion-comet/issues/311
castTest(generateInts(), DataTypes.ByteType)
}

ignore("cast IntegerType to ShortType") {
test("cast IntegerType to ShortType") {
// https://github.com/apache/datafusion-comet/issues/311
castTest(generateInts(), DataTypes.ShortType)
}
@@ -256,17 +256,17 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
castTest(generateLongs(), DataTypes.BooleanType)
}

ignore("cast LongType to ByteType") {
test("cast LongType to ByteType") {
// https://github.com/apache/datafusion-comet/issues/311
castTest(generateLongs(), DataTypes.ByteType)
}

ignore("cast LongType to ShortType") {
test("cast LongType to ShortType") {
// https://github.com/apache/datafusion-comet/issues/311
castTest(generateLongs(), DataTypes.ShortType)
}

ignore("cast LongType to IntegerType") {
test("cast LongType to IntegerType") {
// https://github.com/apache/datafusion-comet/issues/311
castTest(generateLongs(), DataTypes.IntegerType)
}
@@ -921,15 +921,26 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
val cometMessage = cometException.getCause.getMessage
.replace("Execution error: ", "")
if (CometSparkSessionExtensions.isSpark34Plus) {
// for Spark 3.4 we expect to reproduce the error message exactly
assert(cometMessage == sparkMessage)
} else if (CometSparkSessionExtensions.isSpark33Plus) {
// for Spark 3.3 we just need to strip the prefix from the Comet message
// before comparing
val cometMessageModified = cometMessage
.replace("[CAST_INVALID_INPUT] ", "")
.replace("[CAST_OVERFLOW] ", "")
assert(cometMessageModified == sparkMessage)
} else {
// Spark 3.2 and 3.3 have a different error message format so we can't do a direct
// comparison between Spark and Comet.
// Spark message is in format `invalid input syntax for type TYPE: VALUE`
// Comet message is in format `The value 'VALUE' of the type FROM_TYPE cannot be cast to TO_TYPE`
// We just check that the comet message contains the same invalid value as the Spark message
val sparkInvalidValue = sparkMessage.substring(sparkMessage.indexOf(':') + 2)
assert(cometMessage.contains(sparkInvalidValue))
// for Spark 3.2 we just make sure we are seeing a similar type of error
if (sparkMessage.contains("causes overflow")) {
assert(cometMessage.contains("due to an overflow"))
} else {
// assume that this is an invalid input message in the form:
// `invalid input syntax for type numeric: -9223372036854775809`
// we just check that the Comet message contains the same literal value
val sparkInvalidValue = sparkMessage.substring(sparkMessage.indexOf(':') + 2)
assert(cometMessage.contains(sparkInvalidValue))
}
}
}

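For pre-3.3 Spark versions the test above only checks that the Comet message carries the same offending literal, extracted from the Spark message after the colon. A small worked example of that substring step, using the illustrative message quoted in the comment:

val sparkMessage = "invalid input syntax for type numeric: -9223372036854775809"
// indexOf(':') finds the colon before the value; + 2 skips the colon and the following space.
val sparkInvalidValue = sparkMessage.substring(sparkMessage.indexOf(':') + 2)
assert(sparkInvalidValue == "-9223372036854775809")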
@@ -1369,8 +1369,10 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
assume(isSpark34Plus)
withSQLConf(
SQLConf.ANSI_ENABLED.key -> "false",
SQLConf.COALESCE_PARTITIONS_ENABLED.key -> "true",
CometConf.COMET_ENABLED.key -> "true",
CometConf.COMET_EXEC_ENABLED.key -> "true",
CometConf.COMET_SHUFFLE_ENFORCE_MODE_ENABLED.key -> "true",
CometConf.COMET_EXEC_ALL_EXPR_ENABLED.key -> "true",
CometConf.COMET_EXEC_ALL_OPERATOR_ENABLED.key -> "true",
"spark.sql.extendedExplainProvider" -> "org.apache.comet.ExtendedExplainInfo") {
@@ -549,7 +549,10 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper {
Seq(1, 100, 10000).foreach { numGroups =>
Seq(128, 1024, numValues + 1).foreach { batchSize =>
Seq(true, false).foreach { dictionaryEnabled =>
withSQLConf(CometConf.COMET_BATCH_SIZE.key -> batchSize.toString) {
withSQLConf(
SQLConf.COALESCE_PARTITIONS_ENABLED.key -> "true",
CometConf.COMET_SHUFFLE_ENFORCE_MODE_ENABLED.key -> "true",
CometConf.COMET_BATCH_SIZE.key -> batchSize.toString) {
withParquetTable(
(0 until numValues).map(i => (i, Random.nextInt() % numGroups)),
"tbl",
@@ -573,7 +576,10 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper {
Seq(1, 100, numValues).foreach { numGroups =>
Seq(128, numValues + 100).foreach { batchSize =>
Seq(true, false).foreach { dictionaryEnabled =>
withSQLConf(CometConf.COMET_BATCH_SIZE.key -> batchSize.toString) {
withSQLConf(
SQLConf.COALESCE_PARTITIONS_ENABLED.key -> "true",
CometConf.COMET_SHUFFLE_ENFORCE_MODE_ENABLED.key -> "true",
CometConf.COMET_BATCH_SIZE.key -> batchSize.toString) {
withTempPath { dir =>
val path = new Path(dir.toURI.toString, "test.parquet")
makeParquetFile(path, numValues, numGroups, dictionaryEnabled)
@@ -611,7 +617,10 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper {
Seq(1, 100, numValues).foreach { numGroups =>
Seq(128, numValues + 100).foreach { batchSize =>
Seq(true, false).foreach { dictionaryEnabled =>
withSQLConf(CometConf.COMET_BATCH_SIZE.key -> batchSize.toString) {
withSQLConf(
SQLConf.COALESCE_PARTITIONS_ENABLED.key -> "true",
CometConf.COMET_SHUFFLE_ENFORCE_MODE_ENABLED.key -> "true",
CometConf.COMET_BATCH_SIZE.key -> batchSize.toString) {
withTempPath { dir =>
val path = new Path(dir.toURI.toString, "test.parquet")
makeParquetFile(path, numValues, numGroups, dictionaryEnabled)
@@ -958,7 +967,9 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper {

test("first/last") {
withSQLConf(
SQLConf.COALESCE_PARTITIONS_ENABLED.key -> "true",
CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true",
CometConf.COMET_SHUFFLE_ENFORCE_MODE_ENABLED.key -> "true",
CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> "true") {
Seq(true, false).foreach { dictionary =>
withSQLConf("parquet.enable.dictionary" -> dictionary.toString) {
@@ -64,6 +64,27 @@ abstract class CometColumnarShuffleSuite extends CometTestBase with AdaptiveSparkPlanHelper {

import testImplicits._

setupTestData()

test("Disable Comet shuffle with AQE coalesce partitions enabled") {
Seq(true, false).foreach { coalescePartitionsEnabled =>
withSQLConf(
CometConf.COMET_EXEC_ENABLED.key -> "true",
SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
SQLConf.COALESCE_PARTITIONS_ENABLED.key -> coalescePartitionsEnabled.toString,
SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
val df = sql(
"SELECT * FROM (SELECT * FROM testData WHERE key = 0) t1 FULL JOIN " +
"testData2 t2 ON t1.key = t2.a")
if (coalescePartitionsEnabled) {
checkShuffleAnswer(df, 0)
} else {
checkShuffleAnswer(df, 2)
}
}
}
}

test("columnar shuffle on nested struct including nulls") {
Seq(10, 201).foreach { numPartitions =>
Seq("1.0", "10.0").foreach { ratio =>
@@ -339,11 +339,11 @@ class CometExecSuite extends CometTestBase {
}.map(_.metrics).get

assert(metrics.contains("input_batches"))
assert(metrics("input_batches").value == 2L)
assert(metrics("input_batches").value == 8L)
assert(metrics.contains("input_rows"))
assert(metrics("input_rows").value == 10L)
assert(metrics.contains("output_batches"))
assert(metrics("output_batches").value == 1L)
assert(metrics("output_batches").value == 4L)
assert(metrics.contains("output_rows"))
assert(metrics("output_rows").value == 5L)
assert(metrics.contains("peak_mem_used"))
