From cbbd46eb448615ce1ac5a63fc078b426459b4e9e Mon Sep 17 00:00:00 2001 From: panbingkun Date: Wed, 27 Nov 2024 15:36:35 +0800 Subject: [PATCH] [SPARK-50333][SQL][FOLLOWUP] Codegen Support for CsvToStructs(from_csv) - remove Invoke --- .../csv/CsvExpressionEvalUtils.scala | 3 +- .../catalyst/expressions/csvExpressions.scala | 50 ++++++++++++------- .../explain-results/function_from_csv.explain | 2 +- 3 files changed, 35 insertions(+), 20 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csv/CsvExpressionEvalUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csv/CsvExpressionEvalUtils.scala index a91e4ab13001b..fd298b33450b3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csv/CsvExpressionEvalUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csv/CsvExpressionEvalUtils.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.types.{DataType, NullType, StructType} import org.apache.spark.unsafe.types.UTF8String /** - * The expression `CsvToStructs` will utilize the `Invoke` to call it, support codegen. + * The expression `CsvToStructs` will utilize it to support codegen. */ case class CsvToStructsEvaluator( options: Map[String, String], @@ -86,6 +86,7 @@ case class CsvToStructsEvaluator( } final def evaluate(csv: UTF8String): InternalRow = { + if (csv == null) return null converter(parser.parse(csv.toString)) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala index 02e5488835c91..739151b6b05a8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala @@ -23,10 +23,10 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{DataTypeMismatch, TypeCheckSuccess} import org.apache.spark.sql.catalyst.csv._ -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.Block.BlockHelper import org.apache.spark.sql.catalyst.expressions.csv.{CsvToStructsEvaluator, SchemaOfCsvEvaluator} import org.apache.spark.sql.catalyst.expressions.objects.Invoke -import org.apache.spark.sql.catalyst.trees.TreePattern.{RUNTIME_REPLACEABLE, TreePattern} import org.apache.spark.sql.catalyst.util.TypeUtils._ import org.apache.spark.sql.errors.QueryErrorsBase import org.apache.spark.sql.internal.SQLConf @@ -57,17 +57,12 @@ case class CsvToStructs( timeZoneId: Option[String] = None, requiredSchema: Option[StructType] = None) extends UnaryExpression - with RuntimeReplaceable - with ExpectsInputTypes - with TimeZoneAwareExpression { + with TimeZoneAwareExpression + with ExpectsInputTypes { override def nullable: Boolean = child.nullable - override def nodePatternsInternal(): Seq[TreePattern] = Seq(RUNTIME_REPLACEABLE) - - // The CSV input data might be missing certain fields. We force the nullability - // of the user-provided schema to avoid data corruptions. - private val nullableSchema: StructType = schema.asNullable + override def nullIntolerant: Boolean = true // Used in `FunctionRegistry` def this(child: Expression, schema: Expression, options: Map[String, String]) = @@ -86,8 +81,6 @@ case class CsvToStructs( child = child, timeZoneId = None) - private val nameOfCorruptRecord = SQLConf.get.getConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD) - override def dataType: DataType = requiredSchema.getOrElse(schema).asNullable override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression = { @@ -98,16 +91,37 @@ case class CsvToStructs( override def prettyName: String = "from_csv" + // The CSV input data might be missing certain fields. We force the nullability + // of the user-provided schema to avoid data corruptions. + private val nullableSchema: StructType = schema.asNullable + + @transient + private val nameOfCorruptRecord = SQLConf.get.getConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD) + @transient private lazy val evaluator: CsvToStructsEvaluator = CsvToStructsEvaluator( options, nullableSchema, nameOfCorruptRecord, timeZoneId, requiredSchema) - override def replacement: Expression = Invoke( - Literal.create(evaluator, ObjectType(classOf[CsvToStructsEvaluator])), - "evaluate", - dataType, - Seq(child), - Seq(child.dataType)) + override def nullSafeEval(input: Any): Any = { + evaluator.evaluate(input.asInstanceOf[UTF8String]) + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val refEvaluator = ctx.addReferenceObj("evaluator", evaluator) + val eval = child.genCode(ctx) + val resultType = CodeGenerator.boxedType(dataType) + val resultTerm = ctx.freshName("result") + ev.copy(code = + code""" + |${eval.code} + |$resultType $resultTerm = ($resultType) $refEvaluator.evaluate(${eval.value}); + |boolean ${ev.isNull} = $resultTerm == null; + |${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; + |if (!${ev.isNull}) { + | ${ev.value} = $resultTerm; + |} + |""".stripMargin) + } override protected def withNewChildInternal(newChild: Expression): CsvToStructs = copy(child = newChild) diff --git a/sql/connect/common/src/test/resources/query-tests/explain-results/function_from_csv.explain b/sql/connect/common/src/test/resources/query-tests/explain-results/function_from_csv.explain index ef87c18948b23..89e03c8188232 100644 --- a/sql/connect/common/src/test/resources/query-tests/explain-results/function_from_csv.explain +++ b/sql/connect/common/src/test/resources/query-tests/explain-results/function_from_csv.explain @@ -1,2 +1,2 @@ -Project [invoke(CsvToStructsEvaluator(Map(mode -> FAILFAST),StructType(StructField(id,LongType,true),StructField(a,IntegerType,true),StructField(b,DoubleType,true)),_corrupt_record,Some(America/Los_Angeles),None).evaluate(g#0)) AS from_csv(g)#0] +Project [from_csv(StructField(id,LongType,true), StructField(a,IntegerType,true), StructField(b,DoubleType,true), (mode,FAILFAST), g#0, Some(America/Los_Angeles), None) AS from_csv(g)#0] +- LocalRelation , [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]