build: Use specified branch of arrow-rs with workaround for invalid offset buffers from Java Arrow (#239)

* feat: Use specified branch of arrow-rs with workaround for invalid offset buffers from Java Arrow

* Use FunctionRegistry

* Fix

* Update

* Restore config

* Restore plan stability
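
For context on the workaround: arrow-rs expects a variable-length array's offsets buffer to hold len + 1 entries, so even an empty StringArray carries a single 0 offset, whereas Java Arrow can export a zero-length offsets buffer for empty arrays over the C data interface. A minimal sketch of the invariant (illustrative only, not code from this commit):

use arrow_array::{Array, StringArray};

fn main() {
    // An empty StringArray built natively still has one offset entry, [0];
    // a zero-length offsets buffer arriving via FFI violates this invariant,
    // which is what the pinned arrow-rs branch works around.
    let empty = StringArray::from(Vec::<&str>::new());
    assert_eq!(empty.len(), 0);
    assert_eq!(empty.offsets().len(), 1);
}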
viirya authored Apr 8, 2024
1 parent 8a512ba commit 59f535c
Showing 15 changed files with 688 additions and 524 deletions.
699 changes: 345 additions & 354 deletions core/Cargo.lock

Large diffs are not rendered by default.

19 changes: 10 additions & 9 deletions core/Cargo.toml
@@ -29,12 +29,12 @@ include = [

[dependencies]
parquet-format = "4.0.0" # This must be kept in sync with that from parquet crate
arrow = { version = "~50.0.0", features = ["prettyprint", "ffi", "chrono-tz"] }
arrow-array = { version = "~50.0.0" }
arrow-data = { version = "~50.0.0" }
arrow-schema = { version = "~50.0.0" }
arrow-string = { version = "~50.0.0" }
parquet = { version = "~50.0.0", default-features = false, features = ["experimental"] }
arrow = { git = "https://github.com/viirya/arrow-rs.git", rev = "3f1ae0c", features = ["prettyprint", "ffi", "chrono-tz"] }
arrow-array = { git = "https://github.com/viirya/arrow-rs.git", rev = "3f1ae0c" }
arrow-data = { git = "https://github.com/viirya/arrow-rs.git", rev = "3f1ae0c" }
arrow-schema = { git = "https://github.com/viirya/arrow-rs.git", rev = "3f1ae0c" }
arrow-string = { git = "https://github.com/viirya/arrow-rs.git", rev = "3f1ae0c" }
parquet = { git = "https://github.com/viirya/arrow-rs.git", rev = "3f1ae0c", default-features = false, features = ["experimental"] }
half = { version = "~2.1", default-features = false }
futures = "0.3.28"
mimalloc = { version = "*", default-features = false, optional = true }
@@ -66,9 +66,10 @@ itertools = "0.11.0"
chrono = { version = "0.4", default-features = false, features = ["clock"] }
chrono-tz = { version = "0.8" }
paste = "1.0.14"
datafusion-common = { version = "36.0.0" }
datafusion = { default-features = false, version = "36.0.0", features = ["unicode_expressions"] }
datafusion-physical-expr = { version = "36.0.0", default-features = false , features = ["unicode_expressions"] }
datafusion-common = { git = "https://github.com/viirya/arrow-datafusion.git", rev = "111a940" }
datafusion = { default-features = false, git = "https://github.com/viirya/arrow-datafusion.git", rev = "111a940", features = ["unicode_expressions"] }
datafusion-functions = { git = "https://github.com/viirya/arrow-datafusion.git", rev = "111a940" }
datafusion-physical-expr = { git = "https://github.com/viirya/arrow-datafusion.git", rev = "111a940", default-features = false, features = ["unicode_expressions"] }
unicode-segmentation = "^1.10.1"
once_cell = "1.18.0"
regex = "1.9.6"
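
A design note on the pinning approach (an alternative, not what this commit does): overriding each dependency entry in place, as above, redirects only this crate's direct dependencies. Cargo's [patch.crates-io] table in the workspace root would instead redirect every crates.io reference to the fork, including transitive ones, provided the fork's version still satisfies the original semver requirements. A hypothetical sketch:

[patch.crates-io]
arrow = { git = "https://github.com/viirya/arrow-rs.git", rev = "3f1ae0c" }
parquet = { git = "https://github.com/viirya/arrow-rs.git", rev = "3f1ae0c" }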
2 changes: 1 addition & 1 deletion core/src/execution/datafusion/expressions/avg.rs
@@ -27,7 +27,7 @@ use arrow_schema::{DataType, Field};
use datafusion::logical_expr::{
type_coercion::aggregates::avg_return_type, Accumulator, EmitTo, GroupsAccumulator,
};
use datafusion_common::{not_impl_err, DataFusionError, Result, ScalarValue};
use datafusion_common::{not_impl_err, Result, ScalarValue};
use datafusion_physical_expr::{expressions::format_state_name, AggregateExpr, PhysicalExpr};
use std::{any::Any, sync::Arc};

2 changes: 1 addition & 1 deletion core/src/execution/datafusion/expressions/avg_decimal.rs
@@ -25,7 +25,7 @@ use arrow_array::{
};
use arrow_schema::{DataType, Field};
use datafusion::logical_expr::{Accumulator, EmitTo, GroupsAccumulator};
use datafusion_common::{not_impl_err, DataFusionError, Result, ScalarValue};
use datafusion_common::{not_impl_err, Result, ScalarValue};
use datafusion_physical_expr::{expressions::format_state_name, AggregateExpr, PhysicalExpr};
use std::{any::Any, sync::Arc};

@@ -22,7 +22,7 @@ use arrow::record_batch::RecordBatch;
use arrow_array::cast::as_primitive_array;
use arrow_schema::{DataType, Schema};
use datafusion::physical_plan::ColumnarValue;
use datafusion_common::{internal_err, DataFusionError, Result, ScalarValue};
use datafusion_common::{internal_err, Result, ScalarValue};
use datafusion_physical_expr::{aggregate::utils::down_cast_any_ref, PhysicalExpr};
use std::{
any::Any,
163 changes: 147 additions & 16 deletions core/src/execution/datafusion/expressions/scalar_funcs.rs
@@ -15,7 +15,7 @@
// specific language governing permissions and limitations
// under the License.

use std::{cmp::min, str::FromStr, sync::Arc};
use std::{any::Any, cmp::min, fmt::Debug, str::FromStr, sync::Arc};

use arrow::{
array::{
@@ -27,16 +27,18 @@ use arrow::{
use arrow_array::{Array, ArrowNativeTypeOp, Decimal128Array};
use arrow_schema::DataType;
use datafusion::{
logical_expr::{BuiltinScalarFunction, ScalarFunctionImplementation},
execution::FunctionRegistry,
logical_expr::{
BuiltinScalarFunction, ScalarFunctionDefinition, ScalarFunctionImplementation,
ScalarUDFImpl, Signature, Volatility,
},
physical_plan::ColumnarValue,
};
use datafusion_common::{
cast::as_generic_string_array, exec_err, internal_err, DataFusionError,
Result as DataFusionResult, ScalarValue,
};
use datafusion_physical_expr::{
execution_props::ExecutionProps, functions::create_physical_fun, math_expressions,
};
use datafusion_physical_expr::{math_expressions, udf::ScalarUDF};
use num::{
integer::{div_ceil, div_floor},
BigInt, Signed, ToPrimitive,
@@ -46,20 +48,94 @@ use unicode_segmentation::UnicodeSegmentation;
/// Create a physical scalar function.
pub fn create_comet_physical_fun(
fun_name: &str,
execution_props: &ExecutionProps,
data_type: DataType,
) -> Result<ScalarFunctionImplementation, DataFusionError> {
registry: &dyn FunctionRegistry,
) -> Result<ScalarFunctionDefinition, DataFusionError> {
match fun_name {
"ceil" => Ok(Arc::new(move |x| spark_ceil(x, &data_type))),
"floor" => Ok(Arc::new(move |x| spark_floor(x, &data_type))),
"rpad" => Ok(Arc::new(spark_rpad)),
"round" => Ok(Arc::new(move |x| spark_round(x, &data_type))),
"unscaled_value" => Ok(Arc::new(spark_unscaled_value)),
"make_decimal" => Ok(Arc::new(move |x| spark_make_decimal(x, &data_type))),
"decimal_div" => Ok(Arc::new(move |x| spark_decimal_div(x, &data_type))),
"ceil" => {
let scalar_func = CometScalarFunction::new(
"ceil".to_string(),
Signature::variadic_any(Volatility::Immutable),
data_type.clone(),
Arc::new(move |args| spark_ceil(args, &data_type)),
);
Ok(ScalarFunctionDefinition::UDF(Arc::new(
ScalarUDF::new_from_impl(scalar_func),
)))
}
"floor" => {
let scalar_func = CometScalarFunction::new(
"floor".to_string(),
Signature::variadic_any(Volatility::Immutable),
data_type.clone(),
Arc::new(move |args| spark_floor(args, &data_type)),
);
Ok(ScalarFunctionDefinition::UDF(Arc::new(
ScalarUDF::new_from_impl(scalar_func),
)))
}
"rpad" => {
let scalar_func = CometScalarFunction::new(
"rpad".to_string(),
Signature::variadic_any(Volatility::Immutable),
data_type.clone(),
Arc::new(spark_rpad),
);
Ok(ScalarFunctionDefinition::UDF(Arc::new(
ScalarUDF::new_from_impl(scalar_func),
)))
}
"round" => {
let scalar_func = CometScalarFunction::new(
"round".to_string(),
Signature::variadic_any(Volatility::Immutable),
data_type.clone(),
Arc::new(move |args| spark_round(args, &data_type)),
);
Ok(ScalarFunctionDefinition::UDF(Arc::new(
ScalarUDF::new_from_impl(scalar_func),
)))
}
"unscaled_value" => {
let scalar_func = CometScalarFunction::new(
"unscaled_value".to_string(),
Signature::variadic_any(Volatility::Immutable),
data_type.clone(),
Arc::new(spark_unscaled_value),
);
Ok(ScalarFunctionDefinition::UDF(Arc::new(
ScalarUDF::new_from_impl(scalar_func),
)))
}
"make_decimal" => {
let scalar_func = CometScalarFunction::new(
"make_decimal".to_string(),
Signature::variadic_any(Volatility::Immutable),
data_type.clone(),
Arc::new(move |args| spark_make_decimal(args, &data_type)),
);
Ok(ScalarFunctionDefinition::UDF(Arc::new(
ScalarUDF::new_from_impl(scalar_func),
)))
}
"decimal_div" => {
let scalar_func = CometScalarFunction::new(
"decimal_div".to_string(),
Signature::variadic_any(Volatility::Immutable),
data_type.clone(),
Arc::new(move |args| spark_decimal_div(args, &data_type)),
);
Ok(ScalarFunctionDefinition::UDF(Arc::new(
ScalarUDF::new_from_impl(scalar_func),
)))
}
_ => {
let fun = &BuiltinScalarFunction::from_str(fun_name)?;
create_physical_fun(fun, execution_props)
let fun = BuiltinScalarFunction::from_str(fun_name);
if fun.is_err() {
Ok(ScalarFunctionDefinition::UDF(registry.udf(fun_name)?))
} else {
Ok(ScalarFunctionDefinition::BuiltIn(fun?))
}
}
}
}
@@ -89,6 +165,61 @@ macro_rules! downcast_compute_op {
}};
}

struct CometScalarFunction {
name: String,
signature: Signature,
data_type: DataType,
func: ScalarFunctionImplementation,
}

impl Debug for CometScalarFunction {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("CometScalarFunction")
.field("name", &self.name)
.field("signature", &self.signature)
.field("data_type", &self.data_type)
.finish()
}
}

impl CometScalarFunction {
fn new(
name: String,
signature: Signature,
data_type: DataType,
func: ScalarFunctionImplementation,
) -> Self {
Self {
name,
signature,
data_type,
func,
}
}
}

impl ScalarUDFImpl for CometScalarFunction {
fn as_any(&self) -> &dyn Any {
self
}

fn name(&self) -> &str {
self.name.as_str()
}

fn signature(&self) -> &Signature {
&self.signature
}

fn return_type(&self, _: &[DataType]) -> DataFusionResult<DataType> {
Ok(self.data_type.clone())
}

fn invoke(&self, args: &[ColumnarValue]) -> DataFusionResult<ColumnarValue> {
(self.func)(args)
}
}

/// `ceil` function that simulates Spark `ceil` expression
pub fn spark_ceil(
args: &[ColumnarValue],
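
The CometScalarFunction type above implements DataFusion's ScalarUDFImpl trait, and create_comet_physical_fun now wraps each Comet-specific function in a ScalarUDF instead of returning a bare ScalarFunctionImplementation; names Comet does not override fall back to BuiltinScalarFunction::from_str and, failing that, to registry.udf(), so functions registered with the session (such as those from the newly added datafusion-functions crate) also resolve. A self-contained sketch of the same pattern with a hypothetical plus_one function (the name and Int64-only signature are assumptions for illustration; import paths mirror the ones in this diff):

use std::{any::Any, sync::Arc};

use arrow_schema::DataType;
use datafusion::logical_expr::{
    ScalarFunctionDefinition, ScalarUDFImpl, Signature, Volatility,
};
use datafusion::physical_plan::ColumnarValue;
use datafusion_common::{internal_err, Result, ScalarValue};
use datafusion_physical_expr::udf::ScalarUDF;

// Hypothetical UDF: adds one to a non-null Int64 scalar argument.
#[derive(Debug)]
struct PlusOne(Signature);

impl ScalarUDFImpl for PlusOne {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn name(&self) -> &str {
        "plus_one"
    }

    fn signature(&self) -> &Signature {
        &self.0
    }

    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
        Ok(DataType::Int64)
    }

    fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
        match &args[0] {
            ColumnarValue::Scalar(ScalarValue::Int64(Some(v))) => {
                Ok(ColumnarValue::Scalar(ScalarValue::Int64(Some(v + 1))))
            }
            _ => internal_err!("plus_one expects a non-null Int64 scalar"),
        }
    }
}

fn make_plus_one() -> ScalarFunctionDefinition {
    // The same wrapping create_comet_physical_fun performs for each function.
    let udf = PlusOne(Signature::exact(vec![DataType::Int64], Volatility::Immutable));
    ScalarFunctionDefinition::UDF(Arc::new(ScalarUDF::new_from_impl(udf)))
}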
2 changes: 1 addition & 1 deletion core/src/execution/datafusion/expressions/subquery.rs
@@ -18,7 +18,7 @@
use arrow_array::RecordBatch;
use arrow_schema::{DataType, Schema, TimeUnit};
use datafusion::logical_expr::ColumnarValue;
use datafusion_common::{internal_err, DataFusionError, ScalarValue};
use datafusion_common::{internal_err, ScalarValue};
use datafusion_physical_expr::PhysicalExpr;
use jni::{
objects::JByteArray,
8 changes: 4 additions & 4 deletions core/src/execution/datafusion/expressions/temporal.rs
@@ -23,7 +23,7 @@ use std::{
};

use arrow::{
compute::{hour_dyn, minute_dyn, second_dyn},
compute::{date_part, DatePart},
record_batch::RecordBatch,
};
use arrow_schema::{DataType, Schema, TimeUnit::Microsecond};
@@ -101,7 +101,7 @@ impl PhysicalExpr for HourExec {
Some(self.timezone.clone().into()),
)),
);
let result = hour_dyn(&array)?;
let result = date_part(&array, DatePart::Hour)?;

Ok(ColumnarValue::Array(result))
}
@@ -195,7 +195,7 @@ impl PhysicalExpr for MinuteExec {
Some(self.timezone.clone().into()),
)),
);
let result = minute_dyn(&array)?;
let result = date_part(&array, DatePart::Minute)?;

Ok(ColumnarValue::Array(result))
}
@@ -289,7 +289,7 @@ impl PhysicalExpr for SecondExec {
Some(self.timezone.clone().into()),
)),
);
let result = second_dyn(&array)?;
let result = date_part(&array, DatePart::Second)?;

Ok(ColumnarValue::Array(result))
}
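
The per-unit hour_dyn/minute_dyn/second_dyn kernels are gone in the pinned arrow-rs branch; a single date_part kernel takes the unit as an enum. A minimal sketch of the new call (the timestamp value is arbitrary; the diff above confirms the kernel's name and argument order):

use arrow::compute::{date_part, DatePart};
use arrow_array::TimestampMicrosecondArray;
use arrow_schema::ArrowError;

fn main() -> Result<(), ArrowError> {
    // 2021-01-01T12:34:56Z in microseconds since the Unix epoch.
    let ts = TimestampMicrosecondArray::from(vec![Some(1_609_504_496_000_000i64)]);
    // One kernel now covers Hour, Minute, Second (and more) via DatePart.
    let hours = date_part(&ts, DatePart::Hour)?;
    let seconds = date_part(&ts, DatePart::Second)?;
    println!("{hours:?} {seconds:?}");
    Ok(())
}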
26 changes: 15 additions & 11 deletions core/src/execution/datafusion/operators/expand.rs
@@ -20,12 +20,12 @@ use arrow_schema::SchemaRef;
use datafusion::{
execution::TaskContext,
physical_plan::{
DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, RecordBatchStream,
SendableRecordBatchStream,
DisplayAs, DisplayFormatType, ExecutionMode, ExecutionPlan, Partitioning, PlanProperties,
RecordBatchStream, SendableRecordBatchStream,
},
};
use datafusion_common::DataFusionError;
use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr};
use datafusion_physical_expr::{EquivalenceProperties, PhysicalExpr};
use futures::{Stream, StreamExt};
use std::{
any::Any,
@@ -41,6 +41,7 @@ pub struct CometExpandExec {
projections: Vec<Vec<Arc<dyn PhysicalExpr>>>,
child: Arc<dyn ExecutionPlan>,
schema: SchemaRef,
cache: PlanProperties,
}

impl CometExpandExec {
@@ -50,10 +51,17 @@ impl CometExpandExec {
child: Arc<dyn ExecutionPlan>,
schema: SchemaRef,
) -> Self {
let cache = PlanProperties::new(
EquivalenceProperties::new(schema.clone()),
Partitioning::UnknownPartitioning(1),
ExecutionMode::Bounded,
);

Self {
projections,
child,
schema,
cache,
}
}
}
@@ -88,14 +96,6 @@ impl ExecutionPlan for CometExpandExec {
self.schema.clone()
}

fn output_partitioning(&self) -> Partitioning {
Partitioning::UnknownPartitioning(1)
}

fn output_ordering(&self) -> Option<&[PhysicalSortExpr]> {
None
}

fn children(&self) -> Vec<Arc<dyn ExecutionPlan>> {
vec![self.child.clone()]
}
@@ -122,6 +122,10 @@ impl ExecutionPlan for CometExpandExec {
ExpandStream::new(self.projections.clone(), child_stream, self.schema.clone());
Ok(Box::pin(expand_stream))
}

fn properties(&self) -> &PlanProperties {
&self.cache
}
}

pub struct ExpandStream {
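
With the pinned arrow-datafusion revision, partitioning, ordering, and execution mode move out of dedicated ExecutionPlan methods into a PlanProperties value that each operator computes once and serves from properties(). A self-contained sketch mirroring what CometExpandExec caches in its constructor (the output_partitioning accessor is an assumption based on that revision's API):

use std::sync::Arc;

use arrow_schema::{DataType, Field, Schema};
use datafusion::physical_plan::{ExecutionMode, Partitioning, PlanProperties};
use datafusion_physical_expr::EquivalenceProperties;

fn main() {
    let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
    // Equivalences, partitioning, and boundedness now live in one cached
    // struct instead of the removed output_partitioning()/output_ordering().
    let cache = PlanProperties::new(
        EquivalenceProperties::new(schema),
        Partitioning::UnknownPartitioning(1),
        ExecutionMode::Bounded,
    );
    assert_eq!(cache.output_partitioning().partition_count(), 1);
}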