From efc62865ae9d66c854b2b9839eb3b085bdee642c Mon Sep 17 00:00:00 2001 From: Kazuyuki Tanimura Date: Sat, 3 Aug 2024 03:16:45 -0700 Subject: [PATCH] fix: Optimize rpad --- native/spark-expr/src/scalar_funcs.rs | 23 +++++++------------ .../tpcds-micro-benchmarks/char_type.sql | 7 ++++++ .../benchmark/CometTPCDSMicroBenchmark.scala | 1 + 3 files changed, 16 insertions(+), 15 deletions(-) create mode 100644 spark/src/test/resources/tpcds-micro-benchmarks/char_type.sql diff --git a/native/spark-expr/src/scalar_funcs.rs b/native/spark-expr/src/scalar_funcs.rs index 7cb163905..4212e48ab 100644 --- a/native/spark-expr/src/scalar_funcs.rs +++ b/native/spark-expr/src/scalar_funcs.rs @@ -36,7 +36,6 @@ use num::{ }; use std::fmt::Write; use std::{cmp::min, sync::Arc}; -use unicode_segmentation::UnicodeSegmentation; mod unhex; pub use unhex::spark_unhex; @@ -412,7 +411,6 @@ fn spark_rpad_internal( ) -> Result { let string_array = as_generic_string_array::(array)?; let length = 0.max(length) as usize; - let empty_str = ""; let space_string = " ".repeat(length); let mut builder = @@ -421,21 +419,16 @@ fn spark_rpad_internal( for string in string_array.iter() { match string { Some(string) => { - if length == 0 { - builder.append_value(empty_str); - } else if length == 1 && string.len() > 0 { - // Special case: when length == 1, no need to calculate expensive graphemes + // It looks Spark's UTF8String is closer to chars rather than graphemes + // https://stackoverflow.com/a/46290728 + let char_len = string.chars().count(); + if length <= char_len { builder.append_value(string); } else { - let graphemes_len = string.graphemes(true).count(); - if length <= graphemes_len { - builder.append_value(string); - } else { - // write_str updates only the value buffer, not null nor offset buffer - // This is convenient for concatenating str(s) - builder.write_str(string)?; - builder.append_value(&space_string[graphemes_len..]); - } + // write_str updates only the value buffer, not null nor offset buffer + // This is convenient for concatenating str(s) + builder.write_str(string)?; + builder.append_value(&space_string[char_len..]); } } _ => builder.append_null(), diff --git a/spark/src/test/resources/tpcds-micro-benchmarks/char_type.sql b/spark/src/test/resources/tpcds-micro-benchmarks/char_type.sql new file mode 100644 index 000000000..8a5359d4c --- /dev/null +++ b/spark/src/test/resources/tpcds-micro-benchmarks/char_type.sql @@ -0,0 +1,7 @@ +SELECT + cd_gender +FROM customer_demographics +WHERE + cd_gender = 'M' AND + cd_marital_status = 'S' AND + cd_education_status = 'College' diff --git a/spark/src/test/scala/org/apache/spark/sql/benchmark/CometTPCDSMicroBenchmark.scala b/spark/src/test/scala/org/apache/spark/sql/benchmark/CometTPCDSMicroBenchmark.scala index 01909a4d7..9e6c2fd7b 100644 --- a/spark/src/test/scala/org/apache/spark/sql/benchmark/CometTPCDSMicroBenchmark.scala +++ b/spark/src/test/scala/org/apache/spark/sql/benchmark/CometTPCDSMicroBenchmark.scala @@ -62,6 +62,7 @@ object CometTPCDSMicroBenchmark extends CometTPCQueryBenchmarkBase { "agg_sum_integers_no_grouping", "case_when_column_or_null", "case_when_scalar", + "char_type", "filter_highly_selective", "filter_less_selective", "if_column_or_null",