From eaa6cf6818cf407adedde5e8ded30939723c5a7e Mon Sep 17 00:00:00 2001 From: Dharan Aditya Date: Thu, 12 Dec 2024 01:52:36 +0530 Subject: [PATCH] feat: add support for array_contains expression --- .../core/src/execution/datafusion/planner.rs | 24 +++++++++++++++++++ native/proto/src/proto/expr.proto | 1 + .../apache/comet/serde/QueryPlanSerde.scala | 6 +++++ 3 files changed, 31 insertions(+) diff --git a/native/core/src/execution/datafusion/planner.rs b/native/core/src/execution/datafusion/planner.rs index 33c4924cb..4b8f01b4e 100644 --- a/native/core/src/execution/datafusion/planner.rs +++ b/native/core/src/execution/datafusion/planner.rs @@ -113,6 +113,7 @@ use datafusion_expr::{ AggregateUDF, ScalarUDF, WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition, }; +use datafusion_functions_nested::array_has::ArrayHas; use datafusion_physical_expr::expressions::{Literal, StatsType}; use datafusion_physical_expr::window::WindowExpr; use datafusion_physical_expr::LexOrdering; @@ -735,6 +736,29 @@ impl PhysicalPlanner { expr.legacy_negative_index, ))) } + ExprStruct::ArrayContains(expr) => { + println!("dharan code got executed"); + let src_array_expr = + self.create_expr(expr.left.as_ref().unwrap(), Arc::clone(&input_schema))?; + let key_expr = + self.create_expr(expr.right.as_ref().unwrap(), Arc::clone(&input_schema))?; + let args = vec![Arc::clone(&src_array_expr), key_expr]; + let array_has_expr = Arc::new(ScalarFunctionExpr::new( + "array_has", + Arc::new(ScalarUDF::new_from_impl(ArrayHas::new())), + args, + DataType::Boolean, + )); + let is_array_null: Arc = + Arc::new(IsNullExpr::new(src_array_expr)); + let null_literal_expr: Arc = + Arc::new(Literal::new(ScalarValue::Null)); + Ok(Arc::new(CaseExpr::try_new( + None, + vec![(is_array_null, null_literal_expr)], + Some(array_has_expr), + )?)) + } expr => Err(ExecutionError::GeneralError(format!( "Not implemented: {:?}", expr diff --git a/native/proto/src/proto/expr.proto b/native/proto/src/proto/expr.proto index 7a8ea78d5..e76ecdccf 100644 --- a/native/proto/src/proto/expr.proto +++ b/native/proto/src/proto/expr.proto @@ -84,6 +84,7 @@ message Expr { GetArrayStructFields get_array_struct_fields = 57; BinaryExpr array_append = 58; ArrayInsert array_insert = 59; + BinaryExpr array_contains = 60; } } diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index b33f6b5a6..e168b2b21 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -2266,6 +2266,12 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim withInfo(expr, "unsupported arguments for GetArrayStructFields", child) None } + case expr if expr.prettyName == "array_contains" => + createBinaryExpr( + expr.children(0), + expr.children(1), + inputs, + (builder, binaryExpr) => builder.setArrayContains(binaryExpr)) case _ if expr.prettyName == "array_append" => createBinaryExpr( expr.children(0),