diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/BigtableIO.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/BigtableIO.java
index ed16fe174e71..6dbfddd320f1 100644
--- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/BigtableIO.java
+++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/BigtableIO.java
@@ -2047,6 +2047,9 @@ public enum ExistingPipelineOptions {
   public abstract static class ReadChangeStream
       extends PTransform<PBegin, PCollection<KV<ByteString, ChangeStreamMutation>>> {
 
+    private static final Duration DEFAULT_BACKLOG_REPLICATION_ADJUSTMENT =
+        Duration.standardSeconds(30);
+
     static ReadChangeStream create() {
       BigtableConfig config = BigtableConfig.builder().setValidate(true).build();
       BigtableConfig metadataTableconfig = BigtableConfig.builder().setValidate(true).build();
@@ -2075,6 +2078,8 @@ static ReadChangeStream create() {
 
     abstract @Nullable Boolean getCreateOrUpdateMetadataTable();
 
+    abstract @Nullable Duration getBacklogReplicationAdjustment();
+
     abstract ReadChangeStream.Builder toBuilder();
 
     /**
@@ -2259,6 +2264,26 @@ public ReadChangeStream withCreateOrUpdateMetadataTable(boolean shouldCreate) {
       return toBuilder().setCreateOrUpdateMetadataTable(shouldCreate).build();
     }
 
+    /**
+     * Returns a new {@link BigtableIO.ReadChangeStream} that overrides the replication delay
+     * adjustment duration with the provided duration.
+     *
+     * <p>Backlog is calculated for each partition using watermarkLag * throughput. Replication
+     * delay holds back the watermark for each partition. This can cause the backlog to stay
+     * persistently above dataflow's downscaling threshold (10 seconds) even when a pipeline is
+     * caught up.
+     *
+     * <p>This adjusts the backlog downward to account for this. For unreplicated instances it can
+     * be set to zero to upscale as quickly as possible.
+     *
+     * <p>Optional: defaults to 30 seconds.
+     *
+     * <p>Does not modify this object.
+     */
+    public ReadChangeStream withBacklogReplicationAdjustment(Duration adjustment) {
+      return toBuilder().setBacklogReplicationAdjustment(adjustment).build();
+    }
+
     @Override
     public PCollection<KV<ByteString, ChangeStreamMutation>> expand(PBegin input) {
       checkArgument(
@@ -2312,6 +2337,10 @@ public PCollection<KV<ByteString, ChangeStreamMutation>> expand(PBegin input) {
       if (getCreateOrUpdateMetadataTable() != null) {
         shouldCreateOrUpdateMetadataTable = getCreateOrUpdateMetadataTable();
       }
+      Duration backlogReplicationAdjustment = getBacklogReplicationAdjustment();
+      if (backlogReplicationAdjustment == null) {
+        backlogReplicationAdjustment = DEFAULT_BACKLOG_REPLICATION_ADJUSTMENT;
+      }
       ActionFactory actionFactory = new ActionFactory();
       ChangeStreamMetrics metrics = new ChangeStreamMetrics();
 
@@ -2356,7 +2385,8 @@ public PCollection<KV<ByteString, ChangeStreamMutation>> expand(PBegin input) {
       DetectNewPartitionsDoFn detectNewPartitionsDoFn =
           new DetectNewPartitionsDoFn(getEndTime(), actionFactory, daoFactory, metrics);
       ReadChangeStreamPartitionDoFn readChangeStreamPartitionDoFn =
-          new ReadChangeStreamPartitionDoFn(daoFactory, actionFactory, metrics);
+          new ReadChangeStreamPartitionDoFn(
+              daoFactory, actionFactory, metrics, backlogReplicationAdjustment);
 
       PCollection<KV<ByteString, ChangeStreamMutation>> readChangeStreamOutput =
           input
@@ -2397,6 +2427,8 @@ abstract ReadChangeStream.Builder setExistingPipelineOptions(
 
       abstract ReadChangeStream.Builder setCreateOrUpdateMetadataTable(boolean shouldCreate);
 
+      abstract ReadChangeStream.Builder setBacklogReplicationAdjustment(Duration adjustment);
+
       abstract ReadChangeStream build();
     }
   }
diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/dofn/ReadChangeStreamPartitionDoFn.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/dofn/ReadChangeStreamPartitionDoFn.java
index 92590a7e4b89..826710d9c588 100644
--- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/dofn/ReadChangeStreamPartitionDoFn.java
+++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/dofn/ReadChangeStreamPartitionDoFn.java
@@ -43,6 +43,7 @@
 import org.apache.beam.sdk.transforms.splittabledofn.RestrictionTracker;
 import org.apache.beam.sdk.transforms.splittabledofn.WatermarkEstimators.Manual;
 import org.apache.beam.sdk.values.KV;
+import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting;
 import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions;
 import org.joda.time.Duration;
 import org.joda.time.Instant;
@@ -63,15 +64,32 @@ public class ReadChangeStreamPartitionDoFn
   private final DaoFactory daoFactory;
   private final ChangeStreamMetrics metrics;
   private final ActionFactory actionFactory;
+  private final Duration backlogReplicationAdjustment;
   private SizeEstimator<KV<ByteString, ChangeStreamRecord>> sizeEstimator;
   private ReadChangeStreamPartitionAction readChangeStreamPartitionAction;
+  private final SerializableSupplier<Instant> clock;
 
   public ReadChangeStreamPartitionDoFn(
-      DaoFactory daoFactory, ActionFactory actionFactory, ChangeStreamMetrics metrics) {
+      DaoFactory daoFactory,
+      ActionFactory actionFactory,
+      ChangeStreamMetrics metrics,
+      Duration backlogReplicationAdjustment) {
+    this(daoFactory, actionFactory, metrics, backlogReplicationAdjustment, Instant::now);
+  }
+
+  @VisibleForTesting
+  ReadChangeStreamPartitionDoFn(
+      DaoFactory daoFactory,
+      ActionFactory actionFactory,
+      ChangeStreamMetrics metrics,
+      Duration backlogReplicationAdjustment,
+      SerializableSupplier<Instant> clock) {
     this.daoFactory = daoFactory;
     this.metrics = metrics;
     this.actionFactory = actionFactory;
+    this.backlogReplicationAdjustment = backlogReplicationAdjustment;
     this.sizeEstimator = new NullSizeEstimator<>();
+    this.clock = clock;
   }
 
   @GetInitialWatermarkEstimatorState
@@ -126,12 +144,15 @@ public double getSize(@Restriction StreamProgress streamProgress) {
     // this to count against the backlog and prevent scaling down, so we estimate heartbeat backlog
     // using the time we most recently processed a heartbeat. Otherwise, (for mutations) we use the
     // watermark.
-    Duration processingTimeLag =
-        Duration.millis(
-            Instant.now().getMillis() - streamProgress.getLastRunTimestamp().getMillis());
-    Duration watermarkLag = Duration.millis(Instant.now().getMillis() - lowWatermark.getMillis());
+    long processingTimeLagMillis =
+        clock.get().getMillis() - streamProgress.getLastRunTimestamp().getMillis();
+    Duration watermarkLag = Duration.millis(clock.get().getMillis() - lowWatermark.getMillis());
+    // Remove the backlogReplicationAdjustment from watermarkLag to allow replicated instances to
+    // downscale more easily.
+    long adjustedWatermarkLagMillis =
+        Math.max(0, watermarkLag.minus(backlogReplicationAdjustment).getMillis());
     long lagInMillis =
-        (streamProgress.isHeartbeat() ? processingTimeLag : watermarkLag).getMillis();
+        streamProgress.isHeartbeat() ? processingTimeLagMillis : adjustedWatermarkLagMillis;
     // Return the estimated bytes per second throughput multiplied by the amount of known work
     // outstanding (watermark lag). Cap at max double to avoid overflow.
     double estimatedSize =
diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/dofn/SerializableSupplier.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/dofn/SerializableSupplier.java
new file mode 100644
index 000000000000..2b09adbc75dd
--- /dev/null
+++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/dofn/SerializableSupplier.java
@@ -0,0 +1,25 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.io.gcp.bigtable.changestreams.dofn;
+
+import java.io.Serializable;
+import java.util.function.Supplier;
+
+/** Union of Supplier and Serializable interfaces to allow serialized supplier for testing. */
+@FunctionalInterface
+interface SerializableSupplier<T> extends Supplier<T>, Serializable {}
diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/dofn/ReadChangeStreamPartitionDoFnTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/dofn/ReadChangeStreamPartitionDoFnTest.java
index b89b2bf15aa3..d4f9da768088 100644
--- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/dofn/ReadChangeStreamPartitionDoFnTest.java
+++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/dofn/ReadChangeStreamPartitionDoFnTest.java
@@ -58,20 +58,23 @@ public class ReadChangeStreamPartitionDoFnTest {
   private ChangeStreamDao changeStreamDao;
   private MetadataTableDao metadataTableDao;
   private CoderSizeEstimator<KV<ByteString, ChangeStreamRecord>> sizeEstimator;
+  private DaoFactory daoFactory;
+  private ActionFactory actionFactory;
+  private ChangeStreamMetrics metrics;
   private ReadChangeStreamPartitionDoFn doFn;
 
   @Before
   public void setup() throws IOException {
     Duration heartbeatDuration = Duration.standardSeconds(1);
 
-    DaoFactory daoFactory = mock(DaoFactory.class);
+    daoFactory = mock(DaoFactory.class);
     changeStreamDao = mock(ChangeStreamDao.class);
     metadataTableDao = mock(MetadataTableDao.class);
     when(daoFactory.getChangeStreamDao()).thenReturn(changeStreamDao);
     when(daoFactory.getMetadataTableDao()).thenReturn(metadataTableDao);
     when(daoFactory.getChangeStreamName()).thenReturn("test-id");
-    ActionFactory actionFactory = mock(ActionFactory.class);
-    ChangeStreamMetrics metrics = mock(ChangeStreamMetrics.class);
+    actionFactory = mock(ActionFactory.class);
+    metrics = mock(ChangeStreamMetrics.class);
     sizeEstimator = mock(CoderSizeEstimator.class);
 
     ChangeStreamAction changeStreamAction = new ChangeStreamAction(metrics);
@@ -93,7 +96,7 @@ public void setup() throws IOException {
             sizeEstimator))
         .thenReturn(readChangeStreamPartitionAction);
 
-    doFn = new ReadChangeStreamPartitionDoFn(daoFactory, actionFactory, metrics);
+    doFn = new ReadChangeStreamPartitionDoFn(daoFactory, actionFactory, metrics, Duration.ZERO);
     doFn.setSizeEstimator(sizeEstimator);
   }
 
@@ -182,4 +185,42 @@ public void testGetSizeCantBeNegative() throws IOException {
             true));
     assertEquals(0, heartbeatEstimate, 0);
   }
+
+  @Test
+  public void backlogReplicationAdjustment() throws IOException {
+    SerializableSupplier<Instant> mockClock = () -> Instant.ofEpochSecond(1000);
+    doFn =
+        new ReadChangeStreamPartitionDoFn(
+            daoFactory, actionFactory, metrics, Duration.standardSeconds(30), mockClock);
+    long mutationSize = 100L;
+    when(sizeEstimator.sizeOf(any())).thenReturn(mutationSize);
+    doFn.setSizeEstimator(sizeEstimator);
+
+    Range.ByteStringRange partitionRange = Range.ByteStringRange.create("", "");
+    ChangeStreamContinuationToken testToken =
+        ChangeStreamContinuationToken.create(partitionRange, "test");
+    doFn.setup();
+
+    double mutationEstimate10Second =
+        doFn.getSize(
+            new StreamProgress(
+                testToken,
+                mockClock.get().minus(Duration.standardSeconds(10)),
+                BigDecimal.valueOf(1000),
+                mockClock.get().minus(Duration.standardSeconds(10)),
+                false));
+    // With 30s backlogReplicationAdjustment we should have no backlog when watermarkLag is < 30s
+    assertEquals(0, mutationEstimate10Second, 0);
+
+    double mutationEstimateOneMinute =
+        doFn.getSize(
+            new StreamProgress(
+                testToken,
+                mockClock.get().minus(Duration.standardSeconds(60)),
+                BigDecimal.valueOf(1000),
+                mockClock.get().minus(Duration.standardSeconds(60)),
+                false));
+    // We ignore the first 30s of backlog so this should be throughput * (60 - 30)
+    assertEquals(1000 * 30, mutationEstimateOneMinute, 0);
+  }
 }
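
For readers who want to try the new knob, below is a minimal usage sketch, not part of the patch. The project, instance, and table IDs are placeholders; the only call this patch introduces is withBacklogReplicationAdjustment, which the Javadoc above suggests setting to zero on unreplicated instances so the pipeline upscales as quickly as possible.

import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.io.gcp.bigtable.BigtableIO;
import org.joda.time.Duration;

public class ChangeStreamExample {
  public static void main(String[] args) {
    Pipeline pipeline = Pipeline.create();
    pipeline.apply(
        // Placeholder IDs; only withBacklogReplicationAdjustment is new in this patch.
        BigtableIO.readChangeStream()
            .withProjectId("my-project")
            .withInstanceId("my-instance")
            .withTableId("my-table")
            // Zero adjustment: no replication delay to discount on an unreplicated instance.
            .withBacklogReplicationAdjustment(Duration.ZERO));
    pipeline.run().waitUntilFinish();
  }
}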
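To see why the test expects these numbers, here is the backlog arithmetic behind the new getSize path, worked with the test's values (1000 bytes/s throughput, 30s adjustment). This is an illustrative sketch of the formula only; the class and variable names are invented for the example.

import org.joda.time.Duration;

public class BacklogMathSketch {
  public static void main(String[] args) {
    double throughputBytesPerSecond = 1000; // BigDecimal.valueOf(1000) in the test
    Duration adjustment = Duration.standardSeconds(30); // backlogReplicationAdjustment

    for (Duration watermarkLag :
        new Duration[] {Duration.standardSeconds(10), Duration.standardSeconds(60)}) {
      // Same adjustment as getSize: subtract the replication allowance, clamping at zero.
      long adjustedLagMillis = Math.max(0, watermarkLag.minus(adjustment).getMillis());
      double estimatedBacklogBytes = throughputBytesPerSecond * adjustedLagMillis / 1000;
      System.out.println(
          watermarkLag.getStandardSeconds() + "s lag -> " + estimatedBacklogBytes + " bytes");
    }
    // Prints: 10s lag -> 0.0 bytes (the 30s adjustment absorbs it entirely)
    //         60s lag -> 30000.0 bytes, i.e. throughput * (60 - 30), matching the test.
  }
}

Injecting the clock as a SerializableSupplier<Instant>, rather than calling Instant.now() inline, is what makes this arithmetic deterministic enough for the unit test to assert exact values.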