diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs index 5a35c62e3..c40e2e73f 100644 --- a/native/core/src/execution/planner.rs +++ b/native/core/src/execution/planner.rs @@ -98,6 +98,7 @@ use datafusion_expr::{ AggregateUDF, ScalarUDF, WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition, }; +use datafusion_functions_nested::array_has::ArrayHas; use datafusion_physical_expr::expressions::{Literal, StatsType}; use datafusion_physical_expr::window::WindowExpr; use datafusion_physical_expr::LexOrdering; @@ -719,6 +720,20 @@ impl PhysicalPlanner { expr.legacy_negative_index, ))) } + ExprStruct::ArrayContains(expr) => { + let src_array_expr = + self.create_expr(expr.left.as_ref().unwrap(), Arc::clone(&input_schema))?; + let key_expr = + self.create_expr(expr.right.as_ref().unwrap(), Arc::clone(&input_schema))?; + let args = vec![Arc::clone(&src_array_expr), key_expr]; + let array_has_expr = Arc::new(ScalarFunctionExpr::new( + "array_has", + Arc::new(ScalarUDF::new_from_impl(ArrayHas::new())), + args, + DataType::Boolean, + )); + Ok(array_has_expr) + } expr => Err(ExecutionError::GeneralError(format!( "Not implemented: {:?}", expr diff --git a/native/proto/src/proto/expr.proto b/native/proto/src/proto/expr.proto index 7a8ea78d5..e76ecdccf 100644 --- a/native/proto/src/proto/expr.proto +++ b/native/proto/src/proto/expr.proto @@ -84,6 +84,7 @@ message Expr { GetArrayStructFields get_array_struct_fields = 57; BinaryExpr array_append = 58; ArrayInsert array_insert = 59; + BinaryExpr array_contains = 60; } } diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index 518fa0685..dc081b196 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -2266,6 +2266,12 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim withInfo(expr, "unsupported arguments for GetArrayStructFields", child) None } + case expr if expr.prettyName == "array_contains" => + createBinaryExpr( + expr.children(0), + expr.children(1), + inputs, + (builder, binaryExpr) => builder.setArrayContains(binaryExpr)) case _ if expr.prettyName == "array_append" => createBinaryExpr( expr.children(0), diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala index cce7cb20a..36d370650 100644 --- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala @@ -2517,4 +2517,16 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { checkSparkAnswer(df.select("arrUnsupportedArgs")) } } + + test("array_contains") { + withTempDir { dir => + val path = new Path(dir.toURI.toString, "test.parquet") + makeParquetFileAllTypes(path, dictionaryEnabled = false, n = 10000) + spark.read.parquet(path.toString).createOrReplaceTempView("t1"); + checkSparkAnswerAndOperator( + spark.sql("SELECT array_contains(array(_2, _3, _4), _2) FROM t1")) + checkSparkAnswerAndOperator( + spark.sql("SELECT array_contains((CASE WHEN _2 =_3 THEN array(_4) END), _4) FROM t1")); + } + } }