diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamCalcRel.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamCalcRel.java index d647027a9ae5..4895c1478766 100644 --- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamCalcRel.java +++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamCalcRel.java @@ -395,6 +395,7 @@ static Object toBeamObject(Object value, FieldType fieldType, boolean verifyValu } return ((ByteString) value).getBytes(); case ARRAY: + case ITERABLE: return toBeamList((List) value, fieldType.getCollectionElementType(), verifyValues); case MAP: return toBeamMap( @@ -558,6 +559,9 @@ private static Expression getBeamField( case ROW: value = Expressions.call(expression, "getRow", fieldName); break; + case ITERABLE: + value = Expressions.call(expression, "getIterable", fieldName); + break; case LOGICAL_TYPE: String identifier = fieldType.getLogicalType().getIdentifier(); if (FixedString.IDENTIFIER.equals(identifier) @@ -634,6 +638,7 @@ private static Expression toCalciteValue(Expression value, FieldType fieldType) return nullOr( value, Expressions.new_(ByteString.class, Expressions.convert_(value, byte[].class))); case ARRAY: + case ITERABLE: return nullOr(value, toCalciteList(value, fieldType.getCollectionElementType())); case MAP: return nullOr(value, toCalciteMap(value, fieldType.getMapValueType())); diff --git a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/BeamSqlDslBase.java b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/BeamSqlDslBase.java index 426b95ae6df6..68b672e1814c 100644 --- a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/BeamSqlDslBase.java +++ b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/BeamSqlDslBase.java @@ -38,6 +38,7 @@ import org.apache.beam.sdk.values.PBegin; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.Row; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists; import org.joda.time.Duration; import org.joda.time.Instant; import org.junit.Before; @@ -94,6 +95,7 @@ public static void prepareClass() throws ParseException { .addDateTimeField("f_timestamp") .addInt32Field("f_int2") .addDecimalField("f_decimal") + .addIterableField("f_iterable", FieldType.STRING) .build(); rowsInTableA = @@ -111,7 +113,8 @@ public static void prepareClass() throws ParseException { LocalDateTime.of(2017, 1, 1, 1, 1, 3), parseTimestampWithoutTimeZone("2017-01-01 01:01:03"), 0, - new BigDecimal(1)) + new BigDecimal(1), + Lists.newArrayList("s1", "s2")) .addRows( 2, 2000L, @@ -125,7 +128,8 @@ public static void prepareClass() throws ParseException { LocalDateTime.of(2017, 1, 1, 1, 2, 3), parseTimestampWithoutTimeZone("2017-01-01 01:02:03"), 0, - new BigDecimal(2)) + new BigDecimal(2), + Lists.newArrayList("s1", "s2")) .addRows( 3, 3000L, @@ -139,7 +143,8 @@ public static void prepareClass() throws ParseException { LocalDateTime.of(2017, 1, 1, 1, 6, 3), parseTimestampWithoutTimeZone("2017-01-01 01:06:03"), 0, - new BigDecimal(3)) + new BigDecimal(3), + Lists.newArrayList("s1", "s2")) .addRows( 4, 4000L, @@ -153,7 +158,8 @@ public static void prepareClass() throws ParseException { LocalDateTime.of(2017, 1, 1, 2, 4, 3), parseTimestampWithoutTimeZone("2017-01-01 02:04:03"), 0, - new BigDecimal(4)) + new BigDecimal(4), + Lists.newArrayList("s1", "s2")) .getRows(); monthlyRowsInTableA = @@ -171,7 +177,8 @@ public static void prepareClass() throws ParseException { LocalDateTime.of(2017, 1, 1, 1, 1, 3), parseTimestampWithUTCTimeZone("2017-01-01 01:01:03"), 0, - new BigDecimal(1)) + new BigDecimal(1), + Lists.newArrayList("s1", "s2")) .addRows( 2, 2000L, @@ -185,7 +192,8 @@ public static void prepareClass() throws ParseException { LocalDateTime.of(2017, 1, 1, 1, 2, 3), parseTimestampWithUTCTimeZone("2017-02-01 01:02:03"), 0, - new BigDecimal(2)) + new BigDecimal(2), + Lists.newArrayList("s1", "s2")) .addRows( 3, 3000L, @@ -199,7 +207,8 @@ public static void prepareClass() throws ParseException { LocalDateTime.of(2017, 1, 1, 1, 6, 3), parseTimestampWithUTCTimeZone("2017-03-01 01:06:03"), 0, - new BigDecimal(3)) + new BigDecimal(3), + Lists.newArrayList("s1", "s2")) .getRows(); schemaFloatDouble =