From 209d50e6765e48936b2efb7d52dd828c256734d4 Mon Sep 17 00:00:00 2001 From: Vitaly Terentyev Date: Tue, 2 Jul 2024 20:16:45 +0400 Subject: [PATCH] Support custom JdbcReadWithPartitionsHelper (#31733) --- .../org/apache/beam/sdk/io/jdbc/JdbcIO.java | 64 +++++++++++---- .../io/jdbc/JdbcReadWithPartitionsHelper.java | 48 +++++++++++ .../org/apache/beam/sdk/io/jdbc/JdbcUtil.java | 46 +++-------- .../apache/beam/sdk/io/jdbc/JdbcIOTest.java | 81 ++++++++++++++++++- .../apache/beam/sdk/io/jdbc/JdbcUtilTest.java | 11 +-- 5 files changed, 189 insertions(+), 61 deletions(-) create mode 100644 sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcReadWithPartitionsHelper.java diff --git a/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcIO.java b/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcIO.java index d466c0076c89..2f164fa3bb78 100644 --- a/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcIO.java +++ b/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcIO.java @@ -209,8 +209,10 @@ *
 * <h4>Parallel reading from a JDBC datasource</h4>
 *
 * <p>
Beam supports partitioned reading of all data from a table. Automatic partitioning is - * supported for a few data types: {@link Long}, {@link org.joda.time.DateTime}, {@link String}. To - * enable this, use {@link JdbcIO#readWithPartitions(TypeDescriptor)}. + * supported for a few data types: {@link Long}, {@link org.joda.time.DateTime}. To enable this, use + * {@link JdbcIO#readWithPartitions(TypeDescriptor)}. For other types, use {@link + * ReadWithPartitions#readWithPartitions(JdbcReadWithPartitionsHelper)} with custom {@link + * JdbcReadWithPartitionsHelper}. * *
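+ *
+ * <p>For example (a sketch only: {@code MyRow}, {@code myStringHelper}, {@code myRowMapper} and
+ * the table/column names are placeholders, not identifiers from this change; the builder methods
+ * are the ones exercised in the tests below):
+ *
+ * <pre>{@code
+ * PCollection<MyRow> rows =
+ *     pipeline.apply(
+ *         JdbcIO.<MyRow, String>readWithPartitions(myStringHelper)
+ *             .withDataSourceConfiguration(dataSourceConfiguration)
+ *             .withTable("my_table")
+ *             .withPartitionColumn("name")
+ *             .withNumPartitions(5)
+ *             .withRowMapper(myRowMapper));
+ * }</pre>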
 * <p>
The partitioning scheme depends on these parameters, which can be user-provided, or * automatically inferred by Beam (for the supported types): @@ -361,6 +363,7 @@ public static ReadAll readAll() { * Like {@link #readAll}, but executes multiple instances of the query on the same table * (subquery) using ranges. * + * @param partitioningColumnType Type descriptor for the partition column. * @param Type of the data to be read. */ public static ReadWithPartitions readWithPartitions( @@ -373,6 +376,23 @@ public static ReadWithPartitions read .build(); } + /** + * Like {@link #readAll}, but executes multiple instances of the query on the same table + * (subquery) using ranges. + * + * @param partitionsHelper Custom helper for defining partitions. + * @param Type of the data to be read. + */ + public static ReadWithPartitions readWithPartitions( + JdbcReadWithPartitionsHelper partitionsHelper) { + return new AutoValue_JdbcIO_ReadWithPartitions.Builder() + .setPartitionsHelper(partitionsHelper) + .setNumPartitions(DEFAULT_NUM_PARTITIONS) + .setFetchSize(DEFAULT_FETCH_SIZE) + .setUseBeamSchema(false) + .build(); + } + public static ReadWithPartitions readWithPartitions() { return JdbcIO.readWithPartitions(TypeDescriptors.longs()); } @@ -1229,7 +1249,10 @@ public abstract static class ReadWithPartitions abstract @Nullable String getTable(); @Pure - abstract TypeDescriptor getPartitionColumnType(); + abstract @Nullable TypeDescriptor getPartitionColumnType(); + + @Pure + abstract @Nullable JdbcReadWithPartitionsHelper getPartitionsHelper(); @Pure abstract Builder toBuilder(); @@ -1261,6 +1284,9 @@ abstract Builder setDataSourceProviderFn( abstract Builder setPartitionColumnType( TypeDescriptor partitionColumnType); + abstract Builder setPartitionsHelper( + JdbcReadWithPartitionsHelper partitionsHelper); + abstract ReadWithPartitions build(); } @@ -1360,10 +1386,19 @@ && getLowerBound() instanceof Comparable) { ((Comparable) getLowerBound()).compareTo(getUpperBound()) < EQUAL, "The lower bound of partitioning column is larger or equal than the upper bound"); } - checkNotNull( - JdbcUtil.JdbcReadWithPartitionsHelper.getPartitionsHelper(getPartitionColumnType()), - "readWithPartitions only supports the following types: %s", - JdbcUtil.PRESET_HELPERS.keySet()); + + JdbcReadWithPartitionsHelper partitionsHelper = getPartitionsHelper(); + if (partitionsHelper == null) { + partitionsHelper = + JdbcUtil.getPartitionsHelper( + checkStateNotNull( + getPartitionColumnType(), + "Provide partitionColumnType or partitionsHelper for JdbcIO.readWithPartitions()")); + checkNotNull( + partitionsHelper, + "readWithPartitions only supports the following types: %s", + JdbcUtil.PRESET_HELPERS.keySet()); + } PCollection>> params; @@ -1383,10 +1418,7 @@ && getLowerBound() instanceof Comparable) { JdbcIO.>>read() .withQuery(query) .withDataSourceProviderFn(dataSourceProviderFn) - .withRowMapper( - checkStateNotNull( - JdbcUtil.JdbcReadWithPartitionsHelper.getPartitionsHelper( - getPartitionColumnType()))) + .withRowMapper(checkStateNotNull(partitionsHelper)) .withFetchSize(getFetchSize())) .apply( MapElements.via( @@ -1441,7 +1473,9 @@ public KV> apply( PCollection> ranges = params - .apply("Partitioning", ParDo.of(new PartitioningFn<>(getPartitionColumnType()))) + .apply( + "Partitioning", + ParDo.of(new PartitioningFn<>(checkStateNotNull(partitionsHelper)))) .apply("Reshuffle partitions", Reshuffle.viaRandomKey()); JdbcIO.ReadAll, T> readAll = @@ -1452,11 +1486,7 @@ public KV> apply( "select * from %1$s where %2$s >= ? 
and %2$s < ?", table, partitionColumn)) .withRowMapper(rowMapper) .withFetchSize(getFetchSize()) - .withParameterSetter( - checkStateNotNull( - JdbcUtil.JdbcReadWithPartitionsHelper.getPartitionsHelper( - getPartitionColumnType())) - ::setParameters) + .withParameterSetter(checkStateNotNull(partitionsHelper)) .withOutputParallelization(false); if (getUseBeamSchema()) { diff --git a/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcReadWithPartitionsHelper.java b/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcReadWithPartitionsHelper.java new file mode 100644 index 000000000000..d04162eaa0e8 --- /dev/null +++ b/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcReadWithPartitionsHelper.java @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.jdbc; + +import java.sql.PreparedStatement; +import java.sql.ResultSet; +import org.apache.beam.sdk.io.jdbc.JdbcIO.PreparedStatementSetter; +import org.apache.beam.sdk.io.jdbc.JdbcIO.RowMapper; +import org.apache.beam.sdk.values.KV; + +/** + * A helper for {@link JdbcIO.ReadWithPartitions} that handles range calculations. + * + * @param Element type of the column used for partition. + */ +public interface JdbcReadWithPartitionsHelper + extends PreparedStatementSetter>, + RowMapper>> { + + /** + * Calculate the range of each partition from the lower and upper bound, and number of partitions. + * + *
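+   * <p>For instance (illustrative values only), a {@code Long}-based helper asked to split
+   * {@code [0, 100)} into two partitions could return {@code KV.of(0L, 50L)} and
+   * {@code KV.of(50L, 100L)}.
+   *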
+   * <p>
Return a list of pairs for each lower and upper bound within each partition. + */ + Iterable> calculateRanges( + PartitionT lowerBound, PartitionT upperBound, Long partitions); + + @Override + void setParameters(KV element, PreparedStatement preparedStatement); + + @Override + KV> mapRow(ResultSet resultSet) throws Exception; +} diff --git a/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcUtil.java b/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcUtil.java index 8d82938e5968..b3f46492f745 100644 --- a/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcUtil.java +++ b/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcUtil.java @@ -18,7 +18,6 @@ package org.apache.beam.sdk.io.jdbc; import static org.apache.beam.sdk.util.Preconditions.checkArgumentNotNull; -import static org.apache.beam.sdk.util.Preconditions.checkStateNotNull; import java.io.File; import java.io.IOException; @@ -47,9 +46,6 @@ import java.util.stream.IntStream; import org.apache.beam.sdk.io.FileSystems; import org.apache.beam.sdk.io.fs.ResourceId; -import org.apache.beam.sdk.io.jdbc.JdbcIO.PreparedStatementSetter; -import org.apache.beam.sdk.io.jdbc.JdbcIO.ReadWithPartitions; -import org.apache.beam.sdk.io.jdbc.JdbcIO.RowMapper; import org.apache.beam.sdk.schemas.Schema; import org.apache.beam.sdk.schemas.logicaltypes.FixedPrecisionNumeric; import org.apache.beam.sdk.schemas.logicaltypes.MicrosInstant; @@ -438,50 +434,30 @@ private static Calendar withTimestampAndTimezone(DateTime dateTime) { return calendar; } - /** - * A helper for {@link ReadWithPartitions} that handles range calculations. - * - * @param - */ - interface JdbcReadWithPartitionsHelper - extends PreparedStatementSetter>, - RowMapper>> { - static @Nullable JdbcReadWithPartitionsHelper getPartitionsHelper( - TypeDescriptor type) { - // This cast is unchecked, thus this is a small type-checking risk. We just need - // to make sure that all preset helpers in `JdbcUtil.PRESET_HELPERS` are matched - // in type from their Key and their Value. - return (JdbcReadWithPartitionsHelper) PRESET_HELPERS.get(type.getRawType()); - } - - Iterable> calculateRanges( - PartitionT lowerBound, PartitionT upperBound, Long partitions); - - @Override - void setParameters(KV element, PreparedStatement preparedStatement); - - @Override - KV> mapRow(ResultSet resultSet) throws Exception; + /** @return a {@code JdbcReadPartitionsHelper} instance associated with the given {@param type} */ + static @Nullable JdbcReadWithPartitionsHelper getPartitionsHelper(TypeDescriptor type) { + // This cast is unchecked, thus this is a small type-checking risk. We just need + // to make sure that all preset helpers in `JdbcUtil.PRESET_HELPERS` are matched + // in type from their Key and their Value. + return (JdbcReadWithPartitionsHelper) PRESET_HELPERS.get(type.getRawType()); } /** Create partitions on a table. 
*/ static class PartitioningFn extends DoFn>, KV> { private static final Logger LOG = LoggerFactory.getLogger(PartitioningFn.class); - final TypeDescriptor partitioningColumnType; + final JdbcReadWithPartitionsHelper partitionsHelper; - PartitioningFn(TypeDescriptor partitioningColumnType) { - this.partitioningColumnType = partitioningColumnType; + PartitioningFn(JdbcReadWithPartitionsHelper partitionsHelper) { + this.partitionsHelper = partitionsHelper; } @ProcessElement public void processElement(ProcessContext c) { T lowerBound = c.element().getValue().getKey(); T upperBound = c.element().getValue().getValue(); - JdbcReadWithPartitionsHelper helper = - checkStateNotNull( - JdbcReadWithPartitionsHelper.getPartitionsHelper(partitioningColumnType)); List> ranges = - Lists.newArrayList(helper.calculateRanges(lowerBound, upperBound, c.element().getKey())); + Lists.newArrayList( + partitionsHelper.calculateRanges(lowerBound, upperBound, c.element().getKey())); LOG.warn("Total of {} ranges: {}", ranges.size(), ranges); for (KV e : ranges) { c.output(e); diff --git a/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcIOTest.java b/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcIOTest.java index ccee14c4691b..013fc7996a95 100644 --- a/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcIOTest.java +++ b/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcIOTest.java @@ -45,6 +45,7 @@ import java.sql.Date; import java.sql.JDBCType; import java.sql.PreparedStatement; +import java.sql.ResultSet; import java.sql.SQLException; import java.sql.Statement; import java.sql.Time; @@ -91,6 +92,7 @@ import org.apache.beam.sdk.values.TypeDescriptors; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; import org.apache.commons.dbcp2.PoolingDataSource; +import org.apache.commons.lang3.StringUtils; import org.hamcrest.Description; import org.hamcrest.TypeSafeMatcher; import org.joda.time.DateTime; @@ -104,11 +106,14 @@ import org.junit.rules.ExpectedException; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; /** Test on the JdbcIO. 
*/ @RunWith(JUnit4.class) public class JdbcIOTest implements Serializable { + private static final Logger LOG = LoggerFactory.getLogger(JdbcIOTest.class); private static final DataSourceConfiguration DATA_SOURCE_CONFIGURATION = DataSourceConfiguration.create( "org.apache.derby.jdbc.EmbeddedDriver", "jdbc:derby:memory:testDB;create=true"); @@ -1326,7 +1331,10 @@ public void testPartitioningDateTime() { PCollection> ranges = pipeline .apply(Create.of(KV.of(10L, KV.of(new DateTime(0), DateTime.now())))) - .apply(ParDo.of(new PartitioningFn<>(TypeDescriptor.of(DateTime.class)))); + .apply( + ParDo.of( + new PartitioningFn<>( + JdbcUtil.getPartitionsHelper(TypeDescriptor.of(DateTime.class))))); PAssert.that(ranges.apply(Count.globally())) .satisfies( @@ -1407,9 +1415,78 @@ public void testPartitioningLongs() { PCollection> ranges = pipeline .apply(Create.of(KV.of(10L, KV.of(0L, 12346789L)))) - .apply(ParDo.of(new PartitioningFn<>(TypeDescriptors.longs()))); + .apply( + ParDo.of( + new PartitioningFn<>(JdbcUtil.getPartitionsHelper(TypeDescriptors.longs())))); PAssert.that(ranges.apply(Count.globally())).containsInAnyOrder(10L); pipeline.run().waitUntilFinish(); } + + @Test + public void testPartitioningStringsWithCustomPartitionsHelper() { + JdbcReadWithPartitionsHelper helper = + new JdbcReadWithPartitionsHelper() { + @Override + public Iterable> calculateRanges( + String lowerBound, String upperBound, Long partitions) { + // we expect the elements in the test case follow the format idx + String prefix = StringUtils.getCommonPrefix(lowerBound, upperBound); + int minChar = lowerBound.charAt(prefix.length()); + int maxChar = upperBound.charAt(prefix.length()); + int numPartition; + if (maxChar - minChar < partitions) { + LOG.warn( + "Partition large than possible! 
Adjust to {} partition instead", + maxChar - minChar); + numPartition = maxChar - minChar; + } else { + numPartition = Math.toIntExact(partitions); + } + List> ranges = new ArrayList<>(); + int stride = (maxChar - minChar) / numPartition + 1; + int highest = minChar; + for (int i = minChar; i < maxChar - stride; i += stride) { + ranges.add(KV.of(prefix + (char) i, prefix + (char) (i + stride))); + highest = i + stride; + } + if (highest <= maxChar) { + ranges.add(KV.of(prefix + (char) highest, prefix + (char) (highest + stride))); + } + return ranges; + } + + @Override + public void setParameters( + KV element, PreparedStatement preparedStatement) { + try { + preparedStatement.setString(1, element.getKey()); + preparedStatement.setString(2, element.getValue()); + } catch (SQLException e) { + throw new RuntimeException(e); + } + } + + @Override + public KV> mapRow(ResultSet resultSet) throws Exception { + if (resultSet.getMetaData().getColumnCount() == 3) { + return KV.of( + resultSet.getLong(3), KV.of(resultSet.getString(1), resultSet.getString(2))); + } else { + return KV.of(0L, KV.of(resultSet.getString(1), resultSet.getString(2))); + } + } + }; + + PCollection rows = + pipeline.apply( + JdbcIO.readWithPartitions(helper) + .withDataSourceConfiguration(DATA_SOURCE_CONFIGURATION) + .withRowMapper(new JdbcTestHelper.CreateTestRowOfNameAndId()) + .withTable(READ_TABLE_NAME) + .withNumPartitions(5) + .withPartitionColumn("name")); + PAssert.thatSingleton(rows.apply("Count All", Count.globally())).isEqualTo(1000L); + pipeline.run(); + } } diff --git a/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcUtilTest.java b/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcUtilTest.java index 360b579195cf..5b2e9f27f0a8 100644 --- a/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcUtilTest.java +++ b/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcUtilTest.java @@ -35,7 +35,6 @@ import java.util.ArrayList; import java.util.List; import java.util.Random; -import org.apache.beam.sdk.io.jdbc.JdbcUtil.JdbcReadWithPartitionsHelper; import org.apache.beam.sdk.schemas.Schema; import org.apache.beam.sdk.values.KV; import org.apache.beam.sdk.values.TypeDescriptor; @@ -189,8 +188,7 @@ public void testStringPartitioningWithMultiletter() { @Test public void testDatetimePartitioningWithSingleKey() { JdbcReadWithPartitionsHelper helper = - JdbcUtil.JdbcReadWithPartitionsHelper.getPartitionsHelper( - TypeDescriptor.of(DateTime.class)); + JdbcUtil.getPartitionsHelper(TypeDescriptor.of(DateTime.class)); DateTime onlyPoint = DateTime.now(); List> expectedRanges = Lists.newArrayList(KV.of(onlyPoint, onlyPoint.plusMillis(1))); @@ -207,8 +205,7 @@ public void testDatetimePartitioningWithSingleKey() { @Test public void testDatetimePartitioningWithMultiKey() { JdbcReadWithPartitionsHelper helper = - JdbcUtil.JdbcReadWithPartitionsHelper.getPartitionsHelper( - TypeDescriptor.of(DateTime.class)); + JdbcUtil.getPartitionsHelper(TypeDescriptor.of(DateTime.class)); DateTime lastPoint = DateTime.now(); // At least 10ms in the past, or more. 
DateTime firstPoint = lastPoint.minusMillis(10 + new Random().nextInt(Integer.MAX_VALUE)); @@ -222,7 +219,7 @@ public void testDatetimePartitioningWithMultiKey() { @Test public void testLongPartitioningWithSingleKey() { JdbcReadWithPartitionsHelper helper = - JdbcUtil.JdbcReadWithPartitionsHelper.getPartitionsHelper(TypeDescriptors.longs()); + JdbcUtil.getPartitionsHelper(TypeDescriptors.longs()); List> expectedRanges = Lists.newArrayList(KV.of(12L, 13L)); List> ranges = Lists.newArrayList(helper.calculateRanges(12L, 12L, 10L)); // It is not possible to generate any more than one range, because the lower and upper range are @@ -236,7 +233,7 @@ public void testLongPartitioningWithSingleKey() { @Test public void testLongPartitioningNotEnoughRanges() { JdbcReadWithPartitionsHelper helper = - JdbcUtil.JdbcReadWithPartitionsHelper.getPartitionsHelper(TypeDescriptors.longs()); + JdbcUtil.getPartitionsHelper(TypeDescriptors.longs()); // The minimum stride is one, which is what causes this sort of partitioning. List> expectedRanges = Lists.newArrayList(KV.of(12L, 14L), KV.of(14L, 16L), KV.of(16L, 18L), KV.of(18L, 21L));
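
The test above exercises a String-based helper; the same contract applies to any comparable
column type. Below is a minimal self-contained sketch for an Integer partition column
(illustrative only: IntegerPartitionsHelper and its range arithmetic are not part of this patch).
calculateRanges splits the discovered [lowerBound, upperBound] interval into roughly equal
ranges, setParameters binds one range to the two ? placeholders of the generated
"select * from <table> where <column> >= ? and <column> < ?" query, and mapRow reads the
bound-discovery result back as KV(count, KV(min, max)), mirroring the patch's test.

import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.List;
import org.apache.beam.sdk.io.jdbc.JdbcReadWithPartitionsHelper;
import org.apache.beam.sdk.values.KV;

/** Illustrative sketch, not part of this patch: equal-width ranges over an Integer column. */
public class IntegerPartitionsHelper implements JdbcReadWithPartitionsHelper<Integer> {

  @Override
  public Iterable<KV<Integer, Integer>> calculateRanges(
      Integer lowerBound, Integer upperBound, Long partitions) {
    List<KV<Integer, Integer>> ranges = new ArrayList<>();
    // Stride of at least 1 so the loop always advances; the final range may overshoot
    // upperBound, which is harmless because range ends are queried exclusively.
    int stride = (int) Math.max(1L, (upperBound - lowerBound) / partitions);
    for (int i = lowerBound; i <= upperBound; i += stride) {
      ranges.add(KV.of(i, i + stride));
    }
    return ranges;
  }

  @Override
  public void setParameters(KV<Integer, Integer> element, PreparedStatement preparedStatement) {
    // Binds one range to the two '?' placeholders of the generated partition query.
    try {
      preparedStatement.setInt(1, element.getKey());
      preparedStatement.setInt(2, element.getValue());
    } catch (SQLException e) {
      throw new RuntimeException(e);
    }
  }

  @Override
  public KV<Long, KV<Integer, Integer>> mapRow(ResultSet resultSet) throws Exception {
    // The bound-discovery query returns (min, max) and, when a count is selected, a third
    // column holding the row count; tolerate both shapes as the patch's test does.
    if (resultSet.getMetaData().getColumnCount() == 3) {
      return KV.of(resultSet.getLong(3), KV.of(resultSet.getInt(1), resultSet.getInt(2)));
    }
    return KV.of(0L, KV.of(resultSet.getInt(1), resultSet.getInt(2)));
  }
}

Such a helper is wired in the same way as the test's String helper, e.g.
JdbcIO.<MyRow, Integer>readWithPartitions(new IntegerPartitionsHelper()) followed by the usual
withDataSourceConfiguration/withTable/withPartitionColumn/withRowMapper builder calls.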