Skip to content

Commit

Permalink
[SPARK-47412][SQL] Add Collation Support for LPad/RPad
Browse files Browse the repository at this point in the history
Add collation support for LPAD and RPAD

### What changes were proposed in this pull request?

Add collation support for LPAD and RPAD

### Why are the changes needed?

### Does this PR introduce _any_ user-facing change?

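Yes. `lpad` and `rpad` now accept string arguments with any collation, and the result has the collation of the input string. Passing two different explicit collations for the string and the pad fails with `COLLATION_MISMATCH.EXPLICIT`.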

### How was this patch tested?

Unit tests and spark-shell

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes apache#46041 from GideonPotok/spark_47412_collation_lpad_rpad.

Authored-by: GideonPotok <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
  • Loading branch information
GideonPotok authored and cloud-fan committed Apr 23, 2024
1 parent eba6364 commit 885e98e
Show file tree
Hide file tree
Showing 3 changed files with 84 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import javax.annotation.Nullable
import scala.annotation.tailrec

import org.apache.spark.sql.catalyst.analysis.TypeCoercion.{hasStringType, haveSameType}
import org.apache.spark.sql.catalyst.expressions.{ArrayJoin, BinaryExpression, CaseWhen, Cast, Coalesce, Collate, Concat, ConcatWs, CreateArray, Elt, Expression, Greatest, If, In, InSubquery, Least, Overlay}
import org.apache.spark.sql.catalyst.expressions.{ArrayJoin, BinaryExpression, CaseWhen, Cast, Coalesce, Collate, Concat, ConcatWs, CreateArray, Elt, Expression, Greatest, If, In, InSubquery, Least, Overlay, StringLPad, StringRPad}
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{ArrayType, DataType, StringType}
Expand Down Expand Up @@ -52,6 +52,11 @@ object CollationTypeCasts extends TypeCoercionRule {
overlay.withNewChildren(collateToSingleType(Seq(overlay.input, overlay.replace))
++ Seq(overlay.pos, overlay.len))

// LPAD / RPAD: children are (str, len, pad). Only `str` and `pad` are handed to
// collateToSingleType; `len` is put back in its original position untouched.
// NOTE(review): collateToSingleType presumably casts both inputs to one common collation
// (or fails on an explicit mismatch) -- confirm against its definition.
case stringPadExpr @ (_: StringRPad | _: StringLPad) =>
val Seq(str, len, pad) = stringPadExpr.children
val Seq(newStr, newPad) = collateToSingleType(Seq(str, pad))
stringPadExpr.withNewChildren(Seq(newStr, len, newPad))

case otherExpr @ (
_: In | _: InSubquery | _: CreateArray | _: ArrayJoin | _: Concat | _: Greatest | _: Least |
_: Coalesce | _: BinaryExpression | _: ConcatWs) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1586,7 +1586,8 @@ case class StringLPad(str: Expression, len: Expression, pad: Expression)
override def third: Expression = pad

override def dataType: DataType = str.dataType
override def inputTypes: Seq[AbstractDataType] = Seq(StringType, IntegerType, StringType)
override def inputTypes: Seq[AbstractDataType] =
Seq(StringTypeAnyCollation, IntegerType, StringTypeAnyCollation)

override def nullSafeEval(string: Any, len: Any, pad: Any): Any = {
string.asInstanceOf[UTF8String].lpad(len.asInstanceOf[Int], pad.asInstanceOf[UTF8String])
Expand Down Expand Up @@ -1665,7 +1666,8 @@ case class StringRPad(str: Expression, len: Expression, pad: Expression = Litera
override def third: Expression = pad

override def dataType: DataType = str.dataType
override def inputTypes: Seq[AbstractDataType] = Seq(StringType, IntegerType, StringType)
override def inputTypes: Seq[AbstractDataType] =
Seq(StringTypeAnyCollation, IntegerType, StringTypeAnyCollation)

override def nullSafeEval(string: Any, len: Any, pad: Any): Any = {
string.asInstanceOf[UTF8String].rpad(len.asInstanceOf[Int], pad.asInstanceOf[UTF8String])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -534,6 +534,80 @@ class CollationStringExpressionsSuite
}
}

test("Support StringRPad string expressions with collation") {
  // Each case right-pads `s` to exactly `len` characters using `pad`, with the string
  // arguments under collation `c`; `result` is the expected output of rpad.
  case class StringRPadTestCase[R](s: String, len: Int, pad: String, c: String, result: R)
  val testCases = Seq(
    // Padding with a single space must produce exactly `len` characters in total.
    StringRPadTestCase("", 5, " ", "UTF8_BINARY", "     "),
    StringRPadTestCase("abc", 5, " ", "UNICODE", "abc  "),
    StringRPadTestCase("Hello", 7, "Wörld", "UTF8_BINARY_LCASE", "HelloWö"),
    // When `len` is shorter than the input, the input is truncated and no padding is added.
    StringRPadTestCase("1234567890", 5, "aaaAAa", "UNICODE_CI", "12345"),
    StringRPadTestCase("aaAA", 2, " ", "UTF8_BINARY", "aa"),
    StringRPadTestCase("ÀÃÂĀĂȦÄäåäáâãȻȻȻȻȻǢǼÆ℀℃", 2, "1", "UTF8_BINARY_LCASE", "ÀÃ"),
    StringRPadTestCase("ĂȦÄäåäá", 20, "ÀÃÂĀĂȦÄäåäáâãȻȻȻȻȻǢǼÆ", "UNICODE", "ĂȦÄäåäáÀÃÂĀĂȦÄäåäáâã"),
    StringRPadTestCase("aȦÄä", 8, "a1", "UNICODE_CI", "aȦÄäa1a1")
  )
  testCases.foreach(t => {
    // Both string arguments carry the same explicit collation.
    val query = s"SELECT rpad(collate('${t.s}', '${t.c}')," +
      s" ${t.len}, collate('${t.pad}', '${t.c}'))"
    // Result & data type
    checkAnswer(sql(query), Row(t.result))
    assert(sql(query).schema.fields.head.dataType.sameType(StringType(t.c)))
    // Implicit casting: only one of the two string arguments is explicitly collated.
    checkAnswer(
      sql(s"SELECT rpad(collate('${t.s}', '${t.c}'), ${t.len}, '${t.pad}')"),
      Row(t.result))
    checkAnswer(
      sql(s"SELECT rpad('${t.s}', ${t.len}, collate('${t.pad}', '${t.c}'))"),
      Row(t.result))
  })
  // Collation mismatch: two different explicit collations must be rejected at analysis time.
  val collationMismatch = intercept[AnalysisException] {
    sql("SELECT rpad(collate('abcde', 'UNICODE_CI'),1,collate('C', 'UTF8_BINARY_LCASE'))")
  }
  assert(collationMismatch.getErrorClass === "COLLATION_MISMATCH.EXPLICIT")
}

test("Support StringLPad string expressions with collation") {
  // Each case left-pads `s` to exactly `len` characters using `pad`, with the string
  // arguments under collation `c`; `result` is the expected output of lpad.
  case class StringLPadTestCase[R](s: String, len: Int, pad: String, c: String, result: R)
  val testCases = Seq(
    // Padding with a single space must produce exactly `len` characters in total.
    StringLPadTestCase("", 5, " ", "UTF8_BINARY", "     "),
    StringLPadTestCase("abc", 5, " ", "UNICODE", "  abc"),
    StringLPadTestCase("Hello", 7, "Wörld", "UTF8_BINARY_LCASE", "WöHello"),
    // When `len` is shorter than the input, the input is truncated and no padding is added.
    StringLPadTestCase("1234567890", 5, "aaaAAa", "UNICODE_CI", "12345"),
    StringLPadTestCase("aaAA", 2, " ", "UTF8_BINARY", "aa"),
    StringLPadTestCase("ÀÃÂĀĂȦÄäåäáâãȻȻȻȻȻǢǼÆ℀℃", 2, "1", "UTF8_BINARY_LCASE", "ÀÃ"),
    StringLPadTestCase("ĂȦÄäåäá", 20, "ÀÃÂĀĂȦÄäåäáâãȻȻȻȻȻǢǼÆ", "UNICODE", "ÀÃÂĀĂȦÄäåäáâãĂȦÄäåäá"),
    StringLPadTestCase("aȦÄä", 8, "a1", "UNICODE_CI", "a1a1aȦÄä")
  )
  testCases.foreach(t => {
    // Both string arguments carry the same explicit collation.
    val query = s"SELECT lpad(collate('${t.s}', '${t.c}')," +
      s" ${t.len}, collate('${t.pad}', '${t.c}'))"
    // Result & data type
    checkAnswer(sql(query), Row(t.result))
    assert(sql(query).schema.fields.head.dataType.sameType(StringType(t.c)))
    // Implicit casting: only one of the two string arguments is explicitly collated.
    checkAnswer(
      sql(s"SELECT lpad(collate('${t.s}', '${t.c}'), ${t.len}, '${t.pad}')"),
      Row(t.result))
    checkAnswer(
      sql(s"SELECT lpad('${t.s}', ${t.len}, collate('${t.pad}', '${t.c}'))"),
      Row(t.result))
  })
  // Collation mismatch: two different explicit collations must be rejected at analysis time.
  val collationMismatch = intercept[AnalysisException] {
    sql("SELECT lpad(collate('abcde', 'UNICODE_CI'),1,collate('C', 'UTF8_BINARY_LCASE'))")
  }
  assert(collationMismatch.getErrorClass === "COLLATION_MISMATCH.EXPLICIT")
}

test("Support StringLPad string expressions with explicit collation on second parameter") {
  // The length argument is supplied as a collated string literal ('5'); the query is still
  // expected to succeed and pad 'abc' to five characters.
  val query = "SELECT lpad('abc', collate('5', 'unicode_ci'), ' ')"
  checkAnswer(sql(query), Row("  abc"))
  // The collation on the length argument must not leak into the result type.
  // NOTE(review): StringType(0) is presumably the default UTF8_BINARY collation id -- confirm.
  assert(sql(query).schema.fields.head.dataType.sameType(StringType(0)))
}

// TODO: Add more tests for other string expressions

}
Expand Down

0 comments on commit 885e98e

Please sign in to comment.