Skip to content

Commit

Permalink
fix: Optimize rpad
Browse files Browse the repository at this point in the history
  • Loading branch information
kazuyukitanimura committed Aug 3, 2024
1 parent 25957dd commit ed1a846
Showing 1 changed file with 29 additions and 20 deletions.
49 changes: 29 additions & 20 deletions native/spark-expr/src/scalar_funcs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,14 @@
// specific language governing permissions and limitations
// under the License.

use std::{cmp::min, sync::Arc};

use arrow::{
array::{
ArrayRef, AsArray, Decimal128Builder, Float32Array, Float64Array, GenericStringArray,
Int16Array, Int32Array, Int64Array, Int64Builder, Int8Array, OffsetSizeTrait,
ArrayRef, AsArray, Decimal128Builder, Float32Array, Float64Array, Int16Array, Int32Array,
Int64Array, Int64Builder, Int8Array, OffsetSizeTrait,
},
datatypes::{validate_decimal_precision, Decimal128Type, Int64Type},
};
use arrow_array::builder::GenericStringBuilder;
use arrow_array::{Array, ArrowNativeTypeOp, BooleanArray, Decimal128Array};
use arrow_schema::{DataType, DECIMAL128_MAX_PRECISION};
use datafusion::{functions::math::round::round, physical_plan::ColumnarValue};
Expand All @@ -35,6 +34,8 @@ use num::{
integer::{div_ceil, div_floor},
BigInt, Signed, ToPrimitive,
};
use std::fmt::Write;
use std::{cmp::min, sync::Arc};
use unicode_segmentation::UnicodeSegmentation;

mod unhex;
Expand Down Expand Up @@ -390,7 +391,7 @@ pub fn spark_round(
pub fn spark_rpad(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
match args {
[ColumnarValue::Array(array), ColumnarValue::Scalar(ScalarValue::Int32(Some(length)))] => {
match args[0].data_type() {
match array.data_type() {
DataType::Utf8 => spark_rpad_internal::<i32>(array, *length),
DataType::LargeUtf8 => spark_rpad_internal::<i64>(array, *length),
// TODO: handle Dictionary types
Expand All @@ -410,29 +411,37 @@ fn spark_rpad_internal<T: OffsetSizeTrait>(
length: i32,
) -> Result<ColumnarValue, DataFusionError> {
let string_array = as_generic_string_array::<T>(array)?;
let length = 0.max(length) as usize;
let empty_str = "";
let space_string = " ".repeat(length);

let mut builder =
GenericStringBuilder::<T>::with_capacity(string_array.len(), string_array.len() * length);

let result = string_array
.iter()
.map(|string| match string {
for string in string_array.iter() {
match string {
Some(string) => {
let length = if length < 0 { 0 } else { length as usize };
if length == 0 {
Ok(Some("".to_string()))
builder.append_value(empty_str);
} else if length == 1 && string.len() > 0 {
// Special case: when length == 1, no need to calculate expensive graphemes
builder.append_value(string);
} else {
let graphemes = string.graphemes(true).collect::<Vec<&str>>();
if length < graphemes.len() {
Ok(Some(string.to_string()))
let graphemes_len = string.graphemes(true).count();
if length <= graphemes_len {
builder.append_value(string);
} else {
let mut s = string.to_string();
s.push_str(" ".repeat(length - graphemes.len()).as_str());
Ok(Some(s))
// write_str updates only the value buffer, not null nor offset buffer
// This is convenient for concatenating str(s)
builder.write_str(string)?;
builder.append_value(&space_string[graphemes_len..]);
}
}
}
_ => Ok(None),
})
.collect::<Result<GenericStringArray<T>, DataFusionError>>()?;
Ok(ColumnarValue::Array(Arc::new(result)))
_ => builder.append_null(),
}
}
Ok(ColumnarValue::Array(Arc::new(builder.finish())))
}

// Let Decimal(p3, s3) as return type i.e. Decimal(p1, s1) / Decimal(p2, s2) = Decimal(p3, s3).
Expand Down

0 comments on commit ed1a846

Please sign in to comment.