Skip to content

Commit

Permalink
Add support for virtual workflow threads (#2297)
Browse files Browse the repository at this point in the history
Add support for virtual workflow threads
  • Loading branch information
Quinn-With-Two-Ns authored Nov 4, 2024
1 parent 37081cc commit f6bf576
Show file tree
Hide file tree
Showing 27 changed files with 902 additions and 72 deletions.
11 changes: 10 additions & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,9 @@ jobs:
- name: Set up Java
uses: actions/setup-java@v4
with:
java-version: "11"
java-version: |
21
11
distribution: "temurin"

- name: Set up Gradle
Expand All @@ -79,6 +81,13 @@ jobs:
USE_DOCKER_SERVICE: true
run: ./gradlew --no-daemon test -x checkLicenseMain -x checkLicenses -x spotlessCheck -x spotlessApply -x spotlessJava

- name: Run virtual thread tests
env:
USER: unittest
TEMPORAL_SERVICE_ADDRESS: localhost:7233
USE_DOCKER_SERVICE: true
run: ./gradlew --no-daemon :temporal-sdk:virtualThreadTests -x checkLicenseMain -x checkLicenses -x spotlessCheck -x spotlessApply -x spotlessJava

- name: Publish Test Report
uses: mikepenz/action-junit-report@v4
if: success() || failure() # always run even if the previous step fails
Expand Down
105 changes: 105 additions & 0 deletions temporal-sdk/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,36 @@ dependencies {
testImplementation group: 'ch.qos.logback', name: 'logback-classic', version: "${logbackVersion}"
}

// Temporal SDK supports Java 8 or later so to support virtual threads
// we need to compile the code with Java 21 and package it in a multi-release jar.
sourceSets {
java21 {
java {
srcDirs = ['src/main/java21']
}
}
}

dependencies {
java21Implementation files(sourceSets.main.output.classesDirs) { builtBy compileJava }
}

tasks.named('compileJava21Java') {
javaCompiler = javaToolchains.compilerFor {
languageVersion = JavaLanguageVersion.of(21)
}
options.release = 21
}

jar {
into('META-INF/versions/21') {
from sourceSets.java21.output
}
manifest.attributes(
'Multi-Release': 'true'
)
}

task registerNamespace(type: JavaExec) {
getMainClass().set('io.temporal.internal.docker.RegisterTestNamespace')
classpath = sourceSets.test.runtimeClasspath
Expand All @@ -49,4 +79,79 @@ task testResourceIndependent(type: Test) {
includeCategories 'io.temporal.worker.IndependentResourceBasedTests'
maxParallelForks = 1
}
}

// To test the virtual thread support we need to run a separate test suite with Java 21
testing {
suites {
// Common setup for all test suites
configureEach {
useJUnit(junitVersion)
dependencies {
implementation project()
implementation "ch.qos.logback:logback-classic:${logbackVersion}"
implementation project(':temporal-testing')

implementation "junit:junit:${junitVersion}"
implementation "org.mockito:mockito-core:${mockitoVersion}"
implementation 'pl.pragmatists:JUnitParams:1.1.1'
implementation("com.jayway.jsonpath:json-path:$jsonPathVersion"){
exclude group: 'org.slf4j', module: 'slf4j-api'
}
}
targets {
all {
testTask.configure {
testLogging {
events 'passed', 'skipped', 'failed'
exceptionFormat 'full'
// Uncomment the following line if you want to see test logs in gradlew run.
showStandardStreams true
}
}
}
}
}

virtualThreadTests(JvmTestSuite) {
targets {
all {
testTask.configure {
javaLauncher = javaToolchains.launcherFor {
languageVersion = JavaLanguageVersion.of(21)
}
shouldRunAfter(test)
}
}
}
}

// Run the same test as the normal test task with virtual threads
testsWithVirtualThreads(JvmTestSuite) {
// Use the same source and resources as the main test set
sources {
java {
srcDirs = ['src/test/java']
}
resources {
srcDirs = ["src/test/resources"]
}
}

targets {
all {
testTask.configure {
javaLauncher = javaToolchains.launcherFor {
languageVersion = JavaLanguageVersion.of(21)
}
environment("USE_VIRTUAL_THREADS", "false")
}
}
}
}
}
}

tasks.named('check') {
dependsOn(testing.suites.virtualThreadTests)
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
/*
* Copyright (C) 2022 Temporal Technologies, Inc. All Rights Reserved.
*
* Copyright (C) 2012-2016 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Modifications copyright (C) 2017 Uber Technologies, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this material except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package io.temporal.internal.task;

/**
* Function interface for {@link VirtualThreadDelegate#newVirtualThreadExecutor(ThreadConfigurator)}
* called for every thread created.
*/
@FunctionalInterface
public interface ThreadConfigurator {
/** Invoked for every thread created by {@link VirtualThreadDelegate#newVirtualThreadExecutor}. */
void configure(Thread t);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
/*
* Copyright (C) 2022 Temporal Technologies, Inc. All Rights Reserved.
*
* Copyright (C) 2012-2016 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Modifications copyright (C) 2017 Uber Technologies, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this material except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package io.temporal.internal.task;

import java.util.concurrent.ExecutorService;

/**
* Internal delegate for virtual thread handling on JDK 21. This is a dummy version for reachability
* on JDK <21.
*/
public final class VirtualThreadDelegate {
public static ExecutorService newVirtualThreadExecutor(ThreadConfigurator configurator) {
throw new UnsupportedOperationException("Virtual threads not supported on JDK <21");
}

private VirtualThreadDelegate() {}
}
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,8 @@ public boolean start() {
new TaskHandlerImpl(handler),
pollerOptions,
slotSupplier.maximumSlots().orElse(Integer.MAX_VALUE),
true);
true,
options.isUsingVirtualThreads());
poller =
new Poller<>(
options.getIdentity(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -688,7 +688,8 @@ public boolean start() {
new AttemptTaskHandlerImpl(handler),
pollerOptions,
slotSupplier.maximumSlots().orElse(Integer.MAX_VALUE),
false);
false,
options.isUsingVirtualThreads());

this.workerMetricsScope.counter(MetricsType.WORKER_START_COUNTER).inc(1);
this.slotQueue.start();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,8 @@ public boolean start() {
new TaskHandlerImpl(handler),
pollerOptions,
slotSupplier.maximumSlots().orElse(Integer.MAX_VALUE),
true);
true,
options.isUsingVirtualThreads());
poller =
new Poller<>(
options.getIdentity(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,10 @@

import com.google.common.base.Preconditions;
import io.temporal.internal.logging.LoggerTag;
import io.temporal.internal.task.VirtualThreadDelegate;
import java.util.Objects;
import java.util.concurrent.*;
import java.util.concurrent.atomic.AtomicInteger;
import javax.annotation.Nonnull;
import org.slf4j.MDC;

Expand All @@ -41,7 +43,7 @@ public interface TaskHandler<TT> {
private final TaskHandler<T> handler;
private final PollerOptions pollerOptions;

private final ThreadPoolExecutor taskExecutor;
private final ExecutorService taskExecutor;
private final String pollThreadNamePrefix;

PollTaskExecutor(
Expand All @@ -51,35 +53,46 @@ public interface TaskHandler<TT> {
@Nonnull TaskHandler<T> handler,
@Nonnull PollerOptions pollerOptions,
int workerTaskSlots,
boolean synchronousQueue) {
boolean synchronousQueue,
boolean useVirtualThreads) {
this.namespace = Objects.requireNonNull(namespace);
this.taskQueue = Objects.requireNonNull(taskQueue);
this.identity = Objects.requireNonNull(identity);
this.handler = Objects.requireNonNull(handler);
this.pollerOptions = Objects.requireNonNull(pollerOptions);

this.taskExecutor =
new ThreadPoolExecutor(
// for SynchronousQueue we can afford to set it to 0, because the queue is always full
// or empty
// for LinkedBlockingQueue we have to set slots to workerTaskSlots to avoid situation
// when the queue grows, but the amount of threads is not, because the queue is not (and
// never) full
synchronousQueue ? 0 : workerTaskSlots,
workerTaskSlots,
10,
TimeUnit.SECONDS,
synchronousQueue ? new SynchronousQueue<>() : new LinkedBlockingQueue<>());
this.taskExecutor.allowCoreThreadTimeOut(true);

this.pollThreadNamePrefix =
pollerOptions.getPollThreadNamePrefix().replaceFirst("Poller", "Executor");

this.taskExecutor.setThreadFactory(
new ExecutorThreadFactory(
pollerOptions.getPollThreadNamePrefix().replaceFirst("Poller", "Executor"),
pollerOptions.getUncaughtExceptionHandler()));
this.taskExecutor.setRejectedExecutionHandler(new BlockCallerPolicy());
// If virtual threads are enabled, we use a virtual thread executor.
if (useVirtualThreads) {
AtomicInteger threadIndex = new AtomicInteger();
this.taskExecutor =
VirtualThreadDelegate.newVirtualThreadExecutor(
(t) -> {
t.setName(this.pollThreadNamePrefix + ": " + threadIndex.incrementAndGet());
t.setUncaughtExceptionHandler(pollerOptions.getUncaughtExceptionHandler());
});
} else {
ThreadPoolExecutor threadPoolTaskExecutor =
new ThreadPoolExecutor(
// for SynchronousQueue we can afford to set it to 0, because the queue is always full
// or empty
// for LinkedBlockingQueue we have to set slots to workerTaskSlots to avoid situation
// when the queue grows, but the amount of threads is not, because the queue is not
// (and
// never) full
synchronousQueue ? 0 : workerTaskSlots,
workerTaskSlots,
10,
TimeUnit.SECONDS,
synchronousQueue ? new SynchronousQueue<>() : new LinkedBlockingQueue<>());
threadPoolTaskExecutor.allowCoreThreadTimeOut(true);
threadPoolTaskExecutor.setThreadFactory(
new ExecutorThreadFactory(
this.pollThreadNamePrefix, pollerOptions.getUncaughtExceptionHandler()));
threadPoolTaskExecutor.setRejectedExecutionHandler(new BlockCallerPolicy());
this.taskExecutor = threadPoolTaskExecutor;
}
}

@Override
Expand Down
46 changes: 31 additions & 15 deletions temporal-sdk/src/main/java/io/temporal/internal/worker/Poller.java
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,12 @@
import io.grpc.StatusRuntimeException;
import io.temporal.internal.BackoffThrottler;
import io.temporal.internal.common.GrpcUtils;
import io.temporal.internal.task.VirtualThreadDelegate;
import io.temporal.worker.MetricsType;
import java.time.Duration;
import java.util.Objects;
import java.util.concurrent.*;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
Expand Down Expand Up @@ -57,7 +59,7 @@ interface ThrowingRunnable {
private final PollTask<T> pollTask;
private final PollerOptions pollerOptions;
private static final Logger log = LoggerFactory.getLogger(Poller.class);
private ThreadPoolExecutor pollExecutor;
private ExecutorService pollExecutor;
private final Scope workerMetricsScope;

private final AtomicReference<CountDownLatch> suspendLatch = new AtomicReference<>();
Expand Down Expand Up @@ -97,20 +99,34 @@ public boolean start() {
pollerOptions.getMaximumPollRatePerSecond(),
pollerOptions.getMaximumPollRateIntervalMilliseconds());
}

// It is important to pass blocking queue of at least options.getPollThreadCount() capacity. As
// task enqueues next task the buffering is needed to queue task until the previous one releases
// a thread.
pollExecutor =
new ThreadPoolExecutor(
pollerOptions.getPollThreadCount(),
pollerOptions.getPollThreadCount(),
1,
TimeUnit.SECONDS,
new ArrayBlockingQueue<>(pollerOptions.getPollThreadCount()));
pollExecutor.setThreadFactory(
new ExecutorThreadFactory(
pollerOptions.getPollThreadNamePrefix(), pollerOptions.getUncaughtExceptionHandler()));
// If virtual threads are enabled, we use a virtual thread executor.
if (pollerOptions.isUsingVirtualThreads()) {
AtomicInteger threadIndex = new AtomicInteger();
pollExecutor =
VirtualThreadDelegate.newVirtualThreadExecutor(
(t) -> {
// TODO: Consider using a more descriptive name for the thread.
t.setName(
pollerOptions.getPollThreadNamePrefix() + ": " + threadIndex.incrementAndGet());
t.setUncaughtExceptionHandler(uncaughtExceptionHandler);
});
} else {
// It is important to pass blocking queue of at least options.getPollThreadCount() capacity.
// As task enqueues next task the buffering is needed to queue task until the previous one
// releases a thread.
ThreadPoolExecutor threadPoolPoller =
new ThreadPoolExecutor(
pollerOptions.getPollThreadCount(),
pollerOptions.getPollThreadCount(),
1,
TimeUnit.SECONDS,
new ArrayBlockingQueue<>(pollerOptions.getPollThreadCount()));
threadPoolPoller.setThreadFactory(
new ExecutorThreadFactory(
pollerOptions.getPollThreadNamePrefix(),
pollerOptions.getUncaughtExceptionHandler()));
pollExecutor = threadPoolPoller;
}

for (int i = 0; i < pollerOptions.getPollThreadCount(); i++) {
pollExecutor.execute(new PollLoopTask(new PollExecutionTask()));
Expand Down
Loading

0 comments on commit f6bf576

Please sign in to comment.