Skip to content

Commit

Permalink
For review
Browse files Browse the repository at this point in the history
  • Loading branch information
viirya committed Jan 29, 2024
1 parent eaf4c11 commit fbd376e
Show file tree
Hide file tree
Showing 7 changed files with 36 additions and 53 deletions.
22 changes: 5 additions & 17 deletions datafusion/core/src/physical_optimizer/projection_pushdown.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1225,7 +1225,7 @@ mod tests {
use datafusion_common::{JoinSide, JoinType, Result, ScalarValue, Statistics};
use datafusion_execution::object_store::ObjectStoreUrl;
use datafusion_execution::{SendableRecordBatchStream, TaskContext};
use datafusion_expr::{ColumnarValue, Operator, Signature, Volatility};
use datafusion_expr::{ColumnarValue, Operator};
use datafusion_physical_expr::expressions::{
BinaryExpr, CaseExpr, CastExpr, Column, Literal, NegativeExpr,
};
Expand Down Expand Up @@ -1270,10 +1270,7 @@ mod tests {
],
DataType::Int32,
None,
Signature::exact(
vec![DataType::Float32, DataType::Float32],
Volatility::Immutable,
),
false,
)),
Arc::new(CaseExpr::try_new(
Some(Arc::new(Column::new("d", 2))),
Expand Down Expand Up @@ -1340,10 +1337,7 @@ mod tests {
],
DataType::Int32,
None,
Signature::exact(
vec![DataType::Float32, DataType::Float32],
Volatility::Immutable,
),
false,
)),
Arc::new(CaseExpr::try_new(
Some(Arc::new(Column::new("d", 3))),
Expand Down Expand Up @@ -1413,10 +1407,7 @@ mod tests {
],
DataType::Int32,
None,
Signature::exact(
vec![DataType::Float32, DataType::Float32],
Volatility::Immutable,
),
false,
)),
Arc::new(CaseExpr::try_new(
Some(Arc::new(Column::new("d", 2))),
Expand Down Expand Up @@ -1483,10 +1474,7 @@ mod tests {
],
DataType::Int32,
None,
Signature::exact(
vec![DataType::Float32, DataType::Float32],
Volatility::Immutable,
),
false,
)),
Arc::new(CaseExpr::try_new(
Some(Arc::new(Column::new("d_new", 3))),
Expand Down
36 changes: 14 additions & 22 deletions datafusion/core/tests/user_defined/user_defined_scalar_functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,13 @@
// under the License.

use arrow::compute::kernels::numeric::add;
use arrow_array::{
ArrayRef, Float64Array, Int32Array, Int64Array, RecordBatch, UInt64Array, UInt8Array,
};
use arrow_array::{Array, ArrayRef, Float64Array, Int32Array, RecordBatch, UInt8Array};
use arrow_schema::DataType::Float64;
use arrow_schema::{DataType, Field, Schema};
use datafusion::prelude::*;
use datafusion::{execution::registry::FunctionRegistry, test_util};
use datafusion_common::cast::as_float64_array;
use datafusion_common::{assert_batches_eq, cast::as_int32_array, Result, ScalarValue};
use datafusion_expr::TypeSignature::{Any, Variadic};
use datafusion_expr::{
create_udaf, create_udf, Accumulator, ColumnarValue, LogicalPlanBuilder, ScalarUDF,
ScalarUDFImpl, Signature, Volatility,
Expand Down Expand Up @@ -404,10 +401,7 @@ pub struct RandomUDF {
impl RandomUDF {
pub fn new() -> Self {
Self {
signature: Signature::one_of(
vec![Any(0), Variadic(vec![Float64])],
Volatility::Volatile,
),
signature: Signature::any(0, Volatility::Volatile),
}
}
}
Expand All @@ -431,7 +425,9 @@ impl ScalarUDFImpl for RandomUDF {

fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
let len: usize = match &args[0] {
ColumnarValue::Array(array) => array.len(),
// This UDF is always invoked with zero arguments, so the single value it
// receives is a null array whose length indicates the batch size.
ColumnarValue::Array(array) if array.data_type().is_null() => array.len(),
_ => {
return Err(datafusion::error::DataFusionError::Internal(
"Invalid argument type".to_string(),
Expand All @@ -445,25 +441,21 @@ impl ScalarUDFImpl for RandomUDF {
}
}

/// Ensure that a user-defined function with zero arguments is invoked
/// with a null array indicating the batch size.
#[tokio::test]
async fn test_user_defined_functions_zero_argument() -> Result<()> {
let ctx = SessionContext::new();

let schema = Arc::new(Schema::new(vec![
Field::new("index", DataType::UInt8, false),
Field::new("uint", DataType::UInt64, true),
Field::new("int", DataType::Int64, true),
Field::new("float", DataType::Float64, true),
]));
let schema = Arc::new(Schema::new(vec![Field::new(
"index",
DataType::UInt8,
false,
)]));

let batch = RecordBatch::try_new(
schema,
vec![
Arc::new(UInt8Array::from_iter_values([1, 2, 3])),
Arc::new(UInt64Array::from(vec![Some(2), Some(3), None])),
Arc::new(Int64Array::from(vec![Some(-2), Some(3), None])),
Arc::new(Float64Array::from(vec![Some(1.0), Some(3.3), None])),
],
vec![Arc::new(UInt8Array::from_iter_values([1, 2, 3]))],
)?;

ctx.register_batch("data_table", batch)?;
Expand Down Expand Up @@ -492,7 +484,7 @@ async fn test_user_defined_functions_zero_argument() -> Result<()> {

assert_eq!(random_udf.len(), native_random.len());

let mut previous = 1.0;
let mut previous = -1.0;
for i in 0..random_udf.len() {
assert!(random_udf.value(i) >= 0.0 && random_udf.value(i) < 1.0);
assert!(random_udf.value(i) != previous);
Expand Down
2 changes: 1 addition & 1 deletion datafusion/physical-expr/src/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ pub fn create_physical_expr(
input_phy_exprs.to_vec(),
data_type,
monotonicity,
fun.signature().clone(),
fun.signature().type_signature.supports_zero_argument(),
)))
}

Expand Down
21 changes: 12 additions & 9 deletions datafusion/physical-expr/src/scalar_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ use arrow::record_batch::RecordBatch;
use datafusion_common::Result;
use datafusion_expr::{
expr_vec_fmt, BuiltinScalarFunction, ColumnarValue, FuncMonotonicity,
ScalarFunctionImplementation, Signature,
ScalarFunctionImplementation,
};

/// Physical expression of a scalar function
Expand All @@ -58,8 +58,8 @@ pub struct ScalarFunctionExpr {
// and it specifies the effect of an increase or decrease in
// the corresponding `arg` to the function value.
monotonicity: Option<FuncMonotonicity>,
// Signature of the function
signature: Signature,
// Whether this function can be invoked with zero arguments
supports_zero_argument: bool,
}

impl Debug for ScalarFunctionExpr {
Expand All @@ -81,15 +81,15 @@ impl ScalarFunctionExpr {
args: Vec<Arc<dyn PhysicalExpr>>,
return_type: DataType,
monotonicity: Option<FuncMonotonicity>,
signature: Signature,
supports_zero_argument: bool,
) -> Self {
Self {
fun,
name: name.to_owned(),
args,
return_type,
monotonicity,
signature,
supports_zero_argument,
}
}

Expand Down Expand Up @@ -142,9 +142,12 @@ impl PhysicalExpr for ScalarFunctionExpr {
fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
// Evaluate the arguments. If there are no arguments, we instead pass in a null array
// indicating the batch size (as a convention).
let inputs = match (self.args.len(), self.name.parse::<BuiltinScalarFunction>()) {
let inputs = match (
self.args.is_empty(),
self.name.parse::<BuiltinScalarFunction>(),
) {
// MakeArray supports zero arguments, but behaves differently from an array with one null.
(0, Ok(scalar_fun))
(true, Ok(scalar_fun))
if scalar_fun
.signature()
.type_signature
Expand All @@ -155,7 +158,7 @@ impl PhysicalExpr for ScalarFunctionExpr {
}
// If the function supports zero arguments, we pass in a null array indicating the batch size.
// This is for user-defined functions.
(0, Err(_)) if self.signature.type_signature.supports_zero_argument() => {
(true, Err(_)) if self.supports_zero_argument => {
vec![ColumnarValue::create_null_array(batch.num_rows())]
}
_ => self
Expand Down Expand Up @@ -184,7 +187,7 @@ impl PhysicalExpr for ScalarFunctionExpr {
children,
self.return_type().clone(),
self.monotonicity.clone(),
self.signature.clone(),
self.supports_zero_argument,
)))
}

Expand Down
2 changes: 1 addition & 1 deletion datafusion/physical-expr/src/udf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ pub fn create_physical_expr(
input_phy_exprs.to_vec(),
fun.return_type(&input_exprs_types)?,
fun.monotonicity()?,
fun.signature().clone(),
fun.signature().type_signature.supports_zero_argument(),
)))
}

Expand Down
2 changes: 1 addition & 1 deletion datafusion/proto/src/physical_plan/from_proto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -364,7 +364,7 @@ pub fn parse_physical_expr(
args,
convert_required!(e.return_type)?,
None,
signature.clone(),
signature.type_signature.supports_zero_argument(),
))
}
ExprType::LikeExpr(like_expr) => Arc::new(LikeExpr::new(
Expand Down
4 changes: 2 additions & 2 deletions datafusion/proto/tests/cases/roundtrip_physical_plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -580,7 +580,7 @@ fn roundtrip_builtin_scalar_function() -> Result<()> {
vec![col("a", &schema)?],
DataType::Float64,
None,
Signature::exact(vec![DataType::Int64], Volatility::Immutable),
false,
);

let project =
Expand Down Expand Up @@ -618,7 +618,7 @@ fn roundtrip_scalar_udf() -> Result<()> {
vec![col("a", &schema)?],
DataType::Int64,
None,
Signature::exact(vec![DataType::Int64], Volatility::Immutable),
false,
);

let project =
Expand Down

0 comments on commit fbd376e

Please sign in to comment.