diff --git a/temporal-sdk/src/main/java/io/temporal/internal/worker/ShutdownManager.java b/temporal-sdk/src/main/java/io/temporal/internal/worker/ShutdownManager.java index f133e4d85..61c5c85a2 100644 --- a/temporal-sdk/src/main/java/io/temporal/internal/worker/ShutdownManager.java +++ b/temporal-sdk/src/main/java/io/temporal/internal/worker/ShutdownManager.java @@ -77,6 +77,23 @@ public CompletableFuture waitForSemaphorePermitsReleaseUntimed( return future; } + /** + * waitForStickyQueueBalancer -> disableNormalPoll -> timed wait for graceful completion of + * sticky workflows + */ + public CompletableFuture waitForStickyQueueBalancer( + StickyQueueBalancer balancer, Duration timeout) { + CompletableFuture future = new CompletableFuture<>(); + balancer.disableNormalPoll(); + scheduledExecutorService.schedule( + () -> { + future.complete(null); + }, + timeout.toMillis(), + TimeUnit.MILLISECONDS); + return future; + } + /** * Wait for {@code executorToShutdown} to terminate. Only completes the returned CompletableFuture * when the executor is terminated. diff --git a/temporal-sdk/src/main/java/io/temporal/internal/worker/SingleWorkerOptions.java b/temporal-sdk/src/main/java/io/temporal/internal/worker/SingleWorkerOptions.java index cedaa7877..3a6f998dd 100644 --- a/temporal-sdk/src/main/java/io/temporal/internal/worker/SingleWorkerOptions.java +++ b/temporal-sdk/src/main/java/io/temporal/internal/worker/SingleWorkerOptions.java @@ -58,6 +58,7 @@ public static final class Builder { private long defaultDeadlockDetectionTimeout; private Duration maxHeartbeatThrottleInterval; private Duration defaultHeartbeatThrottleInterval; + private Duration drainStickyTaskQueueTimeout; private Builder() {} @@ -80,6 +81,7 @@ private Builder(SingleWorkerOptions options) { this.defaultHeartbeatThrottleInterval = options.getDefaultHeartbeatThrottleInterval(); this.buildId = options.getBuildId(); this.useBuildIdForVersioning = options.isUsingBuildIdForVersioning(); + this.drainStickyTaskQueueTimeout = options.getDrainStickyTaskQueueTimeout(); } public Builder setIdentity(String identity) { @@ -161,6 +163,11 @@ public Builder setUseBuildIdForVersioning(boolean useBuildIdForVersioning) { return this; } + public Builder setStickyTaskQueueDrainTimeout(Duration drainStickyTaskQueueTimeout) { + this.drainStickyTaskQueueTimeout = drainStickyTaskQueueTimeout; + return this; + } + public SingleWorkerOptions build() { PollerOptions pollerOptions = this.pollerOptions; if (pollerOptions == null) { @@ -177,6 +184,11 @@ public SingleWorkerOptions build() { metricsScope = new NoopScope(); } + Duration drainStickyTaskQueueTimeout = this.drainStickyTaskQueueTimeout; + if (drainStickyTaskQueueTimeout == null) { + drainStickyTaskQueueTimeout = Duration.ofSeconds(0); + } + return new SingleWorkerOptions( this.identity, this.binaryChecksum, @@ -192,7 +204,8 @@ public SingleWorkerOptions build() { this.stickyQueueScheduleToStartTimeout, this.defaultDeadlockDetectionTimeout, this.maxHeartbeatThrottleInterval, - this.defaultHeartbeatThrottleInterval); + this.defaultHeartbeatThrottleInterval, + drainStickyTaskQueueTimeout); } } @@ -211,6 +224,7 @@ public SingleWorkerOptions build() { private final long defaultDeadlockDetectionTimeout; private final Duration maxHeartbeatThrottleInterval; private final Duration defaultHeartbeatThrottleInterval; + private final Duration drainStickyTaskQueueTimeout; private SingleWorkerOptions( String identity, @@ -227,7 +241,8 @@ private SingleWorkerOptions( Duration stickyQueueScheduleToStartTimeout, long defaultDeadlockDetectionTimeout, Duration maxHeartbeatThrottleInterval, - Duration defaultHeartbeatThrottleInterval) { + Duration defaultHeartbeatThrottleInterval, + Duration drainStickyTaskQueueTimeout) { this.identity = identity; this.binaryChecksum = binaryChecksum; this.buildId = buildId; @@ -243,6 +258,7 @@ private SingleWorkerOptions( this.defaultDeadlockDetectionTimeout = defaultDeadlockDetectionTimeout; this.maxHeartbeatThrottleInterval = maxHeartbeatThrottleInterval; this.defaultHeartbeatThrottleInterval = defaultHeartbeatThrottleInterval; + this.drainStickyTaskQueueTimeout = drainStickyTaskQueueTimeout; } public String getIdentity() { @@ -265,6 +281,10 @@ public boolean isUsingBuildIdForVersioning() { return useBuildIdForVersioning; } + public Duration getDrainStickyTaskQueueTimeout() { + return drainStickyTaskQueueTimeout; + } + public DataConverter getDataConverter() { return dataConverter; } diff --git a/temporal-sdk/src/main/java/io/temporal/internal/worker/StickyQueueBalancer.java b/temporal-sdk/src/main/java/io/temporal/internal/worker/StickyQueueBalancer.java index 33c6b6779..087a7b614 100644 --- a/temporal-sdk/src/main/java/io/temporal/internal/worker/StickyQueueBalancer.java +++ b/temporal-sdk/src/main/java/io/temporal/internal/worker/StickyQueueBalancer.java @@ -21,6 +21,7 @@ package io.temporal.internal.worker; import io.temporal.api.enums.v1.TaskQueueKind; +import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; import javax.annotation.concurrent.ThreadSafe; @@ -30,6 +31,7 @@ public class StickyQueueBalancer { private final boolean stickyQueueEnabled; private final AtomicInteger stickyPollers = new AtomicInteger(0); private final AtomicInteger normalPollers = new AtomicInteger(0); + private final AtomicBoolean disableNormalPoll = new AtomicBoolean(false); private volatile long stickyBacklogSize = 0; @@ -43,6 +45,10 @@ public StickyQueueBalancer(int pollersCount, boolean stickyQueueEnabled) { */ public TaskQueueKind makePoll() { if (stickyQueueEnabled) { + if (disableNormalPoll.get()) { + stickyPollers.incrementAndGet(); + return TaskQueueKind.TASK_QUEUE_KIND_STICKY; + } // If pollersCount >= stickyBacklogSize > 0 we want to go back to a normal ratio to avoid a // situation that too many pollers (all of them in the worst case) will open only sticky queue // polls observing a stickyBacklogSize == 1 for example (which actually can be 0 already at @@ -83,4 +89,12 @@ public void finishPoll(TaskQueueKind taskQueueKind, long backlogSize) { stickyBacklogSize = backlogSize; } } + + public void disableNormalPoll() { + disableNormalPoll.set(true); + } + + public int getNormalPollerCount() { + return normalPollers.get(); + } } diff --git a/temporal-sdk/src/main/java/io/temporal/internal/worker/SyncActivityWorker.java b/temporal-sdk/src/main/java/io/temporal/internal/worker/SyncActivityWorker.java index 0634e0a5a..8e1537b7f 100644 --- a/temporal-sdk/src/main/java/io/temporal/internal/worker/SyncActivityWorker.java +++ b/temporal-sdk/src/main/java/io/temporal/internal/worker/SyncActivityWorker.java @@ -98,7 +98,7 @@ public boolean start() { @Override public CompletableFuture shutdown(ShutdownManager shutdownManager, boolean interruptTasks) { return shutdownManager - // we want to shutdown heartbeatExecutor before activity worker, so in-flight activities + // we want to shut down heartbeatExecutor before activity worker, so in-flight activities // could get an ActivityWorkerShutdownException from their heartbeat .shutdownExecutor(heartbeatExecutor, this + "#heartbeatExecutor", Duration.ofSeconds(5)) .thenCompose(r -> worker.shutdown(shutdownManager, interruptTasks)) diff --git a/temporal-sdk/src/main/java/io/temporal/internal/worker/WorkflowWorker.java b/temporal-sdk/src/main/java/io/temporal/internal/worker/WorkflowWorker.java index 96958d662..1ca14fc4f 100644 --- a/temporal-sdk/src/main/java/io/temporal/internal/worker/WorkflowWorker.java +++ b/temporal-sdk/src/main/java/io/temporal/internal/worker/WorkflowWorker.java @@ -78,6 +78,8 @@ final class WorkflowWorker implements SuspendableWorker { // Currently the implementation looks safe without volatile, but it's brittle. @Nonnull private SuspendableWorker poller = new NoopWorker(); + private StickyQueueBalancer stickyQueueBalancer; + public WorkflowWorker( @Nonnull WorkflowServiceStubs service, @Nonnull String namespace, @@ -118,7 +120,7 @@ public boolean start() { options.getTaskExecutorThreadPoolSize(), workerMetricsScope, true); - StickyQueueBalancer stickyQueueBalancer = + stickyQueueBalancer = new StickyQueueBalancer( options.getPollerOptions().getPollThreadCount(), stickyTaskQueueName != null); @@ -153,8 +155,21 @@ public boolean start() { @Override public CompletableFuture shutdown(ShutdownManager shutdownManager, boolean interruptTasks) { String semaphoreName = this + "#executorSlotsSemaphore"; - return poller - .shutdown(shutdownManager, interruptTasks) + + boolean stickyQueueBalancerDrainEnabled = + !interruptTasks + && !options.getDrainStickyTaskQueueTimeout().isZero() + && stickyTaskQueueName != null + && stickyQueueBalancer != null; + + return CompletableFuture.completedFuture(null) + .thenCompose( + ignore -> + stickyQueueBalancerDrainEnabled + ? shutdownManager.waitForStickyQueueBalancer( + stickyQueueBalancer, options.getDrainStickyTaskQueueTimeout()) + : CompletableFuture.completedFuture(null)) + .thenCompose(ignore -> poller.shutdown(shutdownManager, interruptTasks)) .thenCompose( ignore -> !interruptTasks diff --git a/temporal-sdk/src/main/java/io/temporal/worker/Worker.java b/temporal-sdk/src/main/java/io/temporal/worker/Worker.java index 4255368a5..64408b871 100644 --- a/temporal-sdk/src/main/java/io/temporal/worker/Worker.java +++ b/temporal-sdk/src/main/java/io/temporal/worker/Worker.java @@ -543,6 +543,7 @@ private static SingleWorkerOptions toWorkflowWorkerOptions( PollerOptions.newBuilder().setPollThreadCount(maxConcurrentWorkflowTaskPollers).build()) .setTaskExecutorThreadPoolSize(options.getMaxConcurrentWorkflowTaskExecutionSize()) .setStickyQueueScheduleToStartTimeout(stickyQueueScheduleToStartTimeout) + .setStickyTaskQueueDrainTimeout(options.getStickyTaskQueueDrainTimeout()) .setDefaultDeadlockDetectionTimeout(options.getDefaultDeadlockDetectionTimeout()) .setMetricsScope(metricsScope.tagged(tags)) .build(); diff --git a/temporal-sdk/src/main/java/io/temporal/worker/WorkerOptions.java b/temporal-sdk/src/main/java/io/temporal/worker/WorkerOptions.java index 2c85b3e76..c02331faf 100644 --- a/temporal-sdk/src/main/java/io/temporal/worker/WorkerOptions.java +++ b/temporal-sdk/src/main/java/io/temporal/worker/WorkerOptions.java @@ -24,6 +24,7 @@ import com.google.common.base.Preconditions; import io.temporal.common.Experimental; +import io.temporal.serviceclient.WorkflowServiceStubsOptions; import java.time.Duration; import java.util.Objects; import javax.annotation.Nonnull; @@ -45,6 +46,8 @@ public static WorkerOptions getDefaultInstance() { static final Duration DEFAULT_STICKY_SCHEDULE_TO_START_TIMEOUT = Duration.ofSeconds(5); + static final Duration DEFAULT_STICKY_TASK_QUEUE_DRAIN_TIMEOUT = Duration.ofSeconds(0); + private static final WorkerOptions DEFAULT_INSTANCE; static { @@ -78,6 +81,7 @@ public static final class Builder { private boolean disableEagerExecution; private String buildId; private boolean useBuildIdForVersioning; + private Duration stickyTaskQueueDrainTimeout; private Builder() {} @@ -100,6 +104,7 @@ private Builder(WorkerOptions o) { this.disableEagerExecution = o.disableEagerExecution; this.useBuildIdForVersioning = o.useBuildIdForVersioning; this.buildId = o.buildId; + this.stickyTaskQueueDrainTimeout = o.stickyTaskQueueDrainTimeout; } /** @@ -349,6 +354,22 @@ public Builder setBuildId(String buildId) { return this; } + /** + * During graceful shutdown, as when calling {@link WorkerFactory#shutdown()}, if the workflow + * cache is enabled, this timeout controls how long to wait for the sticky task queue to drain + * before shutting down the worker. If set the worker will stop making new poll requests on the + * normal task queue, but will continue to poll the sticky task queue until the timeout is + * reached. This value should always be greater than clients rpc long poll timeout, which can be + * set via {@link WorkflowServiceStubsOptions.Builder#setRpcLongPollTimeout(Duration)}. + * + *

Default is not to wait. + */ + @Experimental + public Builder setStickyTaskQueueDrainTimeout(Duration stickyTaskQueueDrainTimeout) { + this.stickyTaskQueueDrainTimeout = stickyTaskQueueDrainTimeout; + return this; + } + public WorkerOptions build() { return new WorkerOptions( maxWorkerActivitiesPerSecond, @@ -365,7 +386,8 @@ public WorkerOptions build() { stickyQueueScheduleToStartTimeout, disableEagerExecution, useBuildIdForVersioning, - buildId); + buildId, + stickyTaskQueueDrainTimeout); } public WorkerOptions validateAndBuildWithDefaults() { @@ -396,6 +418,9 @@ public WorkerOptions validateAndBuildWithDefaults() { buildId != null && !buildId.isEmpty(), "buildId must be set non-empty if useBuildIdForVersioning is set true"); } + Preconditions.checkState( + stickyTaskQueueDrainTimeout == null || !stickyTaskQueueDrainTimeout.isNegative(), + "negative stickyTaskQueueDrainTimeout"); return new WorkerOptions( maxWorkerActivitiesPerSecond, @@ -430,7 +455,10 @@ public WorkerOptions validateAndBuildWithDefaults() { : stickyQueueScheduleToStartTimeout, disableEagerExecution, useBuildIdForVersioning, - buildId); + buildId, + stickyTaskQueueDrainTimeout == null + ? DEFAULT_STICKY_TASK_QUEUE_DRAIN_TIMEOUT + : stickyTaskQueueDrainTimeout); } } @@ -449,6 +477,7 @@ public WorkerOptions validateAndBuildWithDefaults() { private final boolean disableEagerExecution; private final boolean useBuildIdForVersioning; private final String buildId; + private final Duration stickyTaskQueueDrainTimeout; private WorkerOptions( double maxWorkerActivitiesPerSecond, @@ -465,7 +494,8 @@ private WorkerOptions( @Nonnull Duration stickyQueueScheduleToStartTimeout, boolean disableEagerExecution, boolean useBuildIdForVersioning, - String buildId) { + String buildId, + Duration stickyTaskQueueDrainTimeout) { this.maxWorkerActivitiesPerSecond = maxWorkerActivitiesPerSecond; this.maxConcurrentActivityExecutionSize = maxConcurrentActivityExecutionSize; this.maxConcurrentWorkflowTaskExecutionSize = maxConcurrentWorkflowExecutionSize; @@ -481,6 +511,7 @@ private WorkerOptions( this.disableEagerExecution = disableEagerExecution; this.useBuildIdForVersioning = useBuildIdForVersioning; this.buildId = buildId; + this.stickyTaskQueueDrainTimeout = stickyTaskQueueDrainTimeout; } public double getMaxWorkerActivitiesPerSecond() { @@ -560,6 +591,10 @@ public String getBuildId() { return buildId; } + public Duration getStickyTaskQueueDrainTimeout() { + return stickyTaskQueueDrainTimeout; + } + @Override public boolean equals(Object o) { if (this == o) return true; @@ -579,7 +614,8 @@ && compare(that.maxTaskQueueActivitiesPerSecond, maxTaskQueueActivitiesPerSecond && Objects.equals(stickyQueueScheduleToStartTimeout, that.stickyQueueScheduleToStartTimeout) && disableEagerExecution == that.disableEagerExecution && useBuildIdForVersioning == that.useBuildIdForVersioning - && Objects.equals(that.buildId, buildId); + && Objects.equals(that.buildId, buildId) + && Objects.equals(stickyTaskQueueDrainTimeout, that.stickyTaskQueueDrainTimeout); } @Override @@ -599,7 +635,8 @@ public int hashCode() { stickyQueueScheduleToStartTimeout, disableEagerExecution, useBuildIdForVersioning, - buildId); + buildId, + stickyTaskQueueDrainTimeout); } @Override @@ -635,6 +672,8 @@ public String toString() { + useBuildIdForVersioning + ", buildId='" + buildId + + ", stickyTaskQueueDrainTimeout='" + + stickyTaskQueueDrainTimeout + '}'; } } diff --git a/temporal-sdk/src/test/java/io/temporal/worker/shutdown/StickyWorkflowDrainShutdownTest.java b/temporal-sdk/src/test/java/io/temporal/worker/shutdown/StickyWorkflowDrainShutdownTest.java new file mode 100644 index 000000000..50a529b32 --- /dev/null +++ b/temporal-sdk/src/test/java/io/temporal/worker/shutdown/StickyWorkflowDrainShutdownTest.java @@ -0,0 +1,95 @@ +/* + * Copyright (C) 2022 Temporal Technologies, Inc. All Rights Reserved. + * + * Copyright (C) 2012-2016 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Modifications copyright (C) 2017 Uber Technologies, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this material except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.temporal.worker.shutdown; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +import io.temporal.client.WorkflowClient; +import io.temporal.client.WorkflowStub; +import io.temporal.serviceclient.WorkflowServiceStubsOptions; +import io.temporal.testing.internal.SDKTestWorkflowRule; +import io.temporal.worker.WorkerOptions; +import io.temporal.workflow.Workflow; +import io.temporal.workflow.shared.TestWorkflows.TestWorkflow1; +import java.time.Duration; +import java.util.concurrent.TimeUnit; +import org.junit.Rule; +import org.junit.Test; + +public class StickyWorkflowDrainShutdownTest { + private static final Duration DRAIN_TIME = Duration.ofSeconds(7); + + @Rule + public SDKTestWorkflowRule testWorkflowRule = + SDKTestWorkflowRule.newBuilder() + .setWorkflowTypes(TestWorkflowImpl.class) + .setUseTimeskipping(false) + .setWorkerOptions( + WorkerOptions.newBuilder().setStickyTaskQueueDrainTimeout(DRAIN_TIME).build()) + .setWorkflowServiceStubsOptions( + WorkflowServiceStubsOptions.newBuilder() + .setRpcLongPollTimeout(Duration.ofSeconds(5)) + .build()) + .build(); + + @Test + public void testShutdown() { + TestWorkflow1 workflow = testWorkflowRule.newWorkflowStub(TestWorkflow1.class); + WorkflowClient.start(workflow::execute, null); + testWorkflowRule.getTestEnvironment().shutdown(); + long startTime = System.currentTimeMillis(); + testWorkflowRule.getTestEnvironment().awaitTermination(10, TimeUnit.SECONDS); + long endTime = System.currentTimeMillis(); + assertTrue("Drain time should be respected", endTime - startTime > DRAIN_TIME.toMillis()); + assertTrue(testWorkflowRule.getTestEnvironment().getWorkerFactory().isTerminated()); + // Workflow should complete successfully since the drain time is longer than the workflow + // execution time + assertEquals("Success", workflow.execute(null)); + } + + @Test + public void testShutdownNow() { + TestWorkflow1 workflow = testWorkflowRule.newWorkflowStub(TestWorkflow1.class); + WorkflowClient.start(workflow::execute, null); + long startTime = System.currentTimeMillis(); + testWorkflowRule.getTestEnvironment().shutdownNow(); + long endTime = System.currentTimeMillis(); + testWorkflowRule.getTestEnvironment().awaitTermination(10, TimeUnit.SECONDS); + assertTrue( + "Drain time does not need to be respected", endTime - startTime < DRAIN_TIME.toMillis()); + assertTrue(testWorkflowRule.getTestEnvironment().getWorkerFactory().isTerminated()); + // Cleanup workflow that will not finish + WorkflowStub untyped = WorkflowStub.fromTyped(workflow); + untyped.terminate("terminate"); + } + + public static class TestWorkflowImpl implements TestWorkflow1 { + + @Override + public String execute(String now) { + for (int i = 0; i < 5; i++) { + Workflow.sleep(1000); + } + return "Success"; + } + } +}