Handle cast from long to short
ganesh.maddula committed Apr 27, 2024
1 parent 8485558 commit 54a5f59
Showing 3 changed files with 83 additions and 6 deletions.
10 changes: 10 additions & 0 deletions core/src/errors.rs
@@ -71,6 +71,16 @@ pub enum CometError {
from_type: String,
to_type: String,
},
// Note that this message format is based on Spark 3.4 and is more detailed than the message
// returned by Spark 3.2 or 3.3
#[error("[CAST_OVERFLOW] The value {value} of the type \"{from_type}\" cannot be cast to \"{to_type}\" \
due to an overflow. Use `try_cast` to tolerate overflow and return NULL instead. If necessary \
set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")]
CastOverFlow {
value: String,
from_type: String,
to_type: String,
},

#[error(transparent)]
Arrow {
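
For reference, a minimal standalone sketch of how this variant renders, assuming the enum derives thiserror::Error as the #[error(...)] attribute suggests (message abbreviated; DemoError is a hypothetical stand-in for CometError):

use thiserror::Error;

#[derive(Debug, Error)]
enum DemoError {
    #[error("[CAST_OVERFLOW] The value {value} of the type \"{from_type}\" cannot be cast to \"{to_type}\" due to an overflow.")]
    CastOverFlow {
        value: String,
        from_type: String,
        to_type: String,
    },
}

fn main() {
    let err = DemoError::CastOverFlow {
        value: "9223372036854775807L".to_string(),
        from_type: "BIGINT".to_string(),
        to_type: "SMALLINT".to_string(),
    };
    // Prints: [CAST_OVERFLOW] The value 9223372036854775807L of the type
    // "BIGINT" cannot be cast to "SMALLINT" due to an overflow.
    println!("{err}");
}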
48 changes: 47 additions & 1 deletion core/src/execution/datafusion/expressions/cast.rs
@@ -28,7 +28,7 @@ use arrow::{
record_batch::RecordBatch,
util::display::FormatOptions,
};
use arrow_array::{Array, ArrayRef, BooleanArray, GenericStringArray, OffsetSizeTrait};
use arrow_array::{Array, ArrayRef, BooleanArray, GenericStringArray, Int16Array, Int64Array, OffsetSizeTrait};
use arrow_schema::{DataType, Schema};
use datafusion::logical_expr::ColumnarValue;
use datafusion_common::{internal_err, Result as DataFusionResult, ScalarValue};
@@ -103,11 +103,57 @@ impl Cast {
(DataType::LargeUtf8, DataType::Boolean) => {
Self::spark_cast_utf8_to_boolean::<i64>(&array, self.eval_mode)?
}
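// Try mode is deliberately excluded here: it falls through to the
// default cast_with_options call below, which is assumed to use a safe
// cast and therefore produce NULL on overflow (try_cast semantics).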
(DataType::Int64, DataType::Int16) if self.eval_mode != EvalMode::Try => {
    Self::spark_cast_int64_to_int16(&array, self.eval_mode)?
}
_ => cast_with_options(&array, to_type, &CAST_OPTIONS)?,
};
let result = spark_cast(cast_result, from_type, to_type);
Ok(result)
}
fn spark_cast_int64_to_int16(from: &dyn Array, eval_mode: EvalMode) -> CometResult<ArrayRef> {
    let array = from
        .as_any()
        .downcast_ref::<Int64Array>()
        .expect("spark_cast_int64_to_int16 expects an Int64Array");

    let output_array = match eval_mode {
        // Legacy mode matches Spark's non-ANSI cast: truncate to the low
        // 16 bits instead of erroring on overflow.
        EvalMode::Legacy => array
            .iter()
            .map(|value| value.map(|v| v as i16))
            .collect::<Int16Array>(),
        // Ansi (the only other mode reaching here) raises CAST_OVERFLOW
        // when the value does not fit in an i16.
        _ => array
            .iter()
            .map(|value| match value {
                Some(value) => i16::try_from(value)
                    .map(Some)
                    .map_err(|_| CometError::CastOverFlow {
                        value: format!("{value}L"),
                        from_type: "BIGINT".to_string(),
                        to_type: "SMALLINT".to_string(),
                    }),
                None => Ok(None),
            })
            .collect::<Result<Int16Array, _>>()?,
    };
    Ok(Arc::new(output_array))
}

fn spark_cast_utf8_to_boolean<OffsetSize>(
from: &dyn Array,
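
A small standalone sketch of the two behaviors the new kernel chooses between (assumed semantics: Legacy truncates the way Spark's non-ANSI cast does, while the ANSI path rejects anything outside the i16 range):

fn main() {
    for v in [100_000i64, i64::MAX, i64::MIN] {
        let legacy = v as i16;       // keeps only the low 16 bits
        let ansi = i16::try_from(v); // Err(_) when the value does not fit
        println!("{v}: legacy={legacy}, ansi={ansi:?}");
    }
    // 100000   -> legacy=-31072, ansi=Err(..)
    // i64::MAX -> legacy=-1,     ansi=Err(..)
    // i64::MIN -> legacy=0,      ansi=Err(..)
}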
31 changes: 26 additions & 5 deletions spark/src/test/scala/org/apache/comet/CometCastSuite.scala
@@ -43,10 +43,10 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
private val datePattern = "0123456789/" + whitespaceChars
private val timestampPattern = "0123456789/:T" + whitespaceChars

ignore("cast long to short") {
castTest(generateLongs, DataTypes.ShortType)
}

ignore("cast float to bool") {
castTest(generateFloats, DataTypes.BooleanType)
}
@@ -106,14 +106,35 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
castTest(values.toDF("a"), DataTypes.DoubleType)
}

// Placeholder tests for the remaining integral casts.
test("cast short to int") {}
test("cast short to long") {}
test("cast int to short") {}
test("cast int to long") {}

test("cast long to short") {
  castTest(generateLongs, DataTypes.ShortType)
}

test("cast long to int") {}

private def generateFloats(): DataFrame = {
val r = new Random(0)
Range(0, dataSize).map(_ => r.nextFloat()).toDF("a")
}

private def generateLongs(): DataFrame = {
val r = new Random(0)
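// The extreme values appended below guarantee that the overflow path
// (long values that do not fit in the target type) is always exercised.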
Range(0, dataSize).map(_ => r.nextLong()).toDF("a")
(Range(0, dataSize).map(_ => r.nextLong()) ++ Seq(Long.MaxValue, Long.MinValue)).toDF("a")
}

private def generateString(r: Random, chars: String, maxLen: Int): String = {
