From a47e771807d7155571449c6a4177b8b8b25afcf0 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Fri, 22 Dec 2023 09:46:20 -0800 Subject: [PATCH] Improve regexp_match performance by avoiding cloning Regex --- .../physical-expr/src/regex_expressions.rs | 83 ++++++++++++++++++- 1 file changed, 80 insertions(+), 3 deletions(-) diff --git a/datafusion/physical-expr/src/regex_expressions.rs b/datafusion/physical-expr/src/regex_expressions.rs index 41cd01949595..8658e4671e42 100644 --- a/datafusion/physical-expr/src/regex_expressions.rs +++ b/datafusion/physical-expr/src/regex_expressions.rs @@ -25,7 +25,8 @@ use arrow::array::{ new_null_array, Array, ArrayDataBuilder, ArrayRef, BufferBuilder, GenericStringArray, OffsetSizeTrait, }; -use arrow::compute; +use arrow_array::builder::{GenericStringBuilder, ListBuilder}; +use arrow_schema::ArrowError; use datafusion_common::plan_err; use datafusion_common::{ cast::as_generic_string_array, internal_err, DataFusionError, Result, @@ -58,7 +59,7 @@ pub fn regexp_match(args: &[ArrayRef]) -> Result { 2 => { let values = as_generic_string_array::(&args[0])?; let regex = as_generic_string_array::(&args[1])?; - compute::regexp_match(values, regex, None).map_err(DataFusionError::ArrowError) + _regexp_match(values, regex, None).map_err(DataFusionError::ArrowError) } 3 => { let values = as_generic_string_array::(&args[0])?; @@ -69,7 +70,7 @@ pub fn regexp_match(args: &[ArrayRef]) -> Result { Some(f) if f.iter().any(|s| s == Some("g")) => { plan_err!("regexp_match() does not support the \"global\" option") }, - _ => compute::regexp_match(values, regex, flags).map_err(DataFusionError::ArrowError), + _ => _regexp_match(values, regex, flags).map_err(DataFusionError::ArrowError), } } other => internal_err!( @@ -78,6 +79,82 @@ pub fn regexp_match(args: &[ArrayRef]) -> Result { } } +/// TODO: Remove this once it is included in arrow-rs new release. +fn _regexp_match( + array: &GenericStringArray, + regex_array: &GenericStringArray, + flags_array: Option<&GenericStringArray>, +) -> std::result::Result { + let mut patterns: std::collections::HashMap = + std::collections::HashMap::new(); + let builder: GenericStringBuilder = + GenericStringBuilder::with_capacity(0, 0); + let mut list_builder = ListBuilder::new(builder); + + let complete_pattern = match flags_array { + Some(flags) => Box::new(regex_array.iter().zip(flags.iter()).map( + |(pattern, flags)| { + pattern.map(|pattern| match flags { + Some(value) => format!("(?{value}){pattern}"), + None => pattern.to_string(), + }) + }, + )) as Box>>, + None => Box::new( + regex_array + .iter() + .map(|pattern| pattern.map(|pattern| pattern.to_string())), + ), + }; + + array + .iter() + .zip(complete_pattern) + .map(|(value, pattern)| { + match (value, pattern) { + // Required for Postgres compatibility: + // SELECT regexp_match('foobarbequebaz', ''); = {""} + (Some(_), Some(pattern)) if pattern == *"" => { + list_builder.values().append_value(""); + list_builder.append(true); + } + (Some(value), Some(pattern)) => { + let existing_pattern = patterns.get(&pattern); + let re = match existing_pattern { + Some(re) => re, + None => { + let re = Regex::new(pattern.as_str()).map_err(|e| { + ArrowError::ComputeError(format!( + "Regular expression did not compile: {e:?}" + )) + })?; + patterns.insert(pattern.clone(), re); + patterns.get(&pattern).unwrap() + } + }; + match re.captures(value) { + Some(caps) => { + let mut iter = caps.iter(); + if caps.len() > 1 { + iter.next(); + } + for m in iter.flatten() { + list_builder.values().append_value(m.as_str()); + } + + list_builder.append(true); + } + None => list_builder.append(false), + } + } + _ => list_builder.append(false), + } + Ok(()) + }) + .collect::, ArrowError>>()?; + Ok(Arc::new(list_builder.finish())) +} + /// replace POSIX capture groups (like \1) with Rust Regex group (like ${1}) /// used by regexp_replace fn regex_replace_posix_groups(replacement: &str) -> String {