Skip to content

Commit

Permalink
Fix caching in WorkflowLocal/WorkflowThreadLocal (temporalio#1876)
Browse files Browse the repository at this point in the history
Reverted caching changes made to WorkflowLocal/WorkflowThreadLocal,
which broke backwards compatibility and accidentally shared values
between Workflows/Threads. Re-implemented caching as an optional
feature, and deprecated the factory methods that created non-caching
instances.
  • Loading branch information
dano committed Oct 2, 2023
1 parent 717ee05 commit 9c0cc84
Show file tree
Hide file tree
Showing 7 changed files with 221 additions and 36 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
public class ContextThreadLocal {

private static final WorkflowThreadLocal<List<ContextPropagator>> contextPropagators =
WorkflowThreadLocal.withInitial(
WorkflowThreadLocal.withCachedInitial(
new Supplier<List<ContextPropagator>>() {
@Override
public List<ContextPropagator> get() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,24 +24,26 @@
import java.util.function.Supplier;

public final class RunnerLocalInternal<T> {
private T supplierResult = null;
private boolean supplierCalled = false;

Optional<T> invokeSupplier(Supplier<? extends T> supplier) {
if (!supplierCalled) {
T result = supplier.get();
supplierCalled = true;
supplierResult = result;
return Optional.ofNullable(result);
} else {
return Optional.ofNullable(supplierResult);
}

private final boolean useCaching;

public RunnerLocalInternal() {
this.useCaching = false;
}

public RunnerLocalInternal(boolean useCaching) {
this.useCaching = useCaching;
}

public T get(Supplier<? extends T> supplier) {
Optional<Optional<T>> result =
DeterministicRunnerImpl.currentThreadInternal().getRunner().getRunnerLocal(this);
return result.orElseGet(() -> invokeSupplier(supplier)).orElse(null);
T out = result.orElseGet(() -> Optional.ofNullable(supplier.get())).orElse(null);
if (!result.isPresent() && useCaching) {
// This is the first time we've tried fetching this, and caching is enabled. Store it.
set(out);
}
return out;
}

public void set(T value) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,24 +25,25 @@

public final class WorkflowThreadLocalInternal<T> {

private T supplierResult = null;
private boolean supplierCalled = false;

Optional<T> invokeSupplier(Supplier<? extends T> supplier) {
if (!supplierCalled) {
T result = supplier.get();
supplierCalled = true;
supplierResult = result;
return Optional.ofNullable(result);
} else {
return Optional.ofNullable(supplierResult);
}
private final boolean useCaching;

public WorkflowThreadLocalInternal() {
this(false);
}

public WorkflowThreadLocalInternal(boolean useCaching) {
this.useCaching = useCaching;
}

public T get(Supplier<? extends T> supplier) {
Optional<Optional<T>> result =
DeterministicRunnerImpl.currentThreadInternal().getThreadLocal(this);
return result.orElseGet(() -> invokeSupplier(supplier)).orElse(null);
T out = result.orElseGet(() -> Optional.ofNullable(supplier.get())).orElse(null);
if (!result.isPresent() && useCaching) {
// This is the first time we've tried fetching this, and caching is enabled. Store it.
set(out);
}
return out;
}

public void set(T value) {
Expand Down
40 changes: 36 additions & 4 deletions temporal-sdk/src/main/java/io/temporal/workflow/WorkflowLocal.java
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
* <pre>{@code
* public class Workflow {
*
* private static final WorkflowLocal<Boolean> signaled = WorkflowLocal.withInitial(() -> false);
* private static final WorkflowLocal<Boolean> signaled = WorkflowLocal.withCachedInitial(() -> false);
*
* public static boolean isSignaled() {
* return signaled.get();
Expand All @@ -49,19 +49,51 @@
*/
public final class WorkflowLocal<T> {

private final RunnerLocalInternal<T> impl = new RunnerLocalInternal<>();
private final RunnerLocalInternal<T> impl;
private final Supplier<? extends T> supplier;

private WorkflowLocal(Supplier<? extends T> supplier) {
private WorkflowLocal(Supplier<? extends T> supplier, boolean useCaching) {
this.supplier = Objects.requireNonNull(supplier);
this.impl = new RunnerLocalInternal<>(useCaching);
}

public WorkflowLocal() {
this.supplier = () -> null;
this.impl = new RunnerLocalInternal<>(false);
}

/**
* Create an instance that returns the value returned by the given {@code Supplier} when {@link
* #set(S)} has not yet been called in the Workflow. Note that the value returned by the {@code
* Supplier} is not stored in the {@code WorkflowLocal} implicitly; repeatedly calling {@link
* #get()} will always re-execute the {@code Supplier} until you call {@link #set(S)} for the
* first time. If you want the value returned by the {@code Supplier} to be stored in the {@code
* WorkflowLocal}, use {@link #withCachedInitial(Supplier)} instead.
*
* @param supplier Callback that will be executed whenever {@link #get()} is called, until {@link
* #set(S)} is called for the first time.
* @return A {@code WorkflowLocal} instance.
* @param <S> The type stored in the {@code WorkflowLocal}.
* @deprecated Because the non-caching behavior of this API is typically not desirable, it's
* recommend to use {@link #withCachedInitial(Supplier)} instead.
*/
@Deprecated
public static <S> WorkflowLocal<S> withInitial(Supplier<? extends S> supplier) {
return new WorkflowLocal<>(supplier);
return new WorkflowLocal<>(supplier, false);
}

/**
* Create an instance that returns the value returned by the given {@code Supplier} when {@link
* #set(S)} has not yet been called in the Workflow, and then stores the returned value inside the
* {@code WorkflowLocal}.
*
* @param supplier Callback that will be executed when {@link #get()} is called for the first
* time, if {@link #set(S)} has not already been called.
* @return A {@code WorkflowLocal} instance.
* @param <S> The type stored in the {@code WorkflowLocal}.
*/
public static <S> WorkflowLocal<S> withCachedInitial(Supplier<? extends S> supplier) {
return new WorkflowLocal<>(supplier, true);
}

public T get() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,19 +27,52 @@
/** {@link ThreadLocal} analog for workflow code. */
public final class WorkflowThreadLocal<T> {

private final WorkflowThreadLocalInternal<T> impl = new WorkflowThreadLocalInternal<>();
private final WorkflowThreadLocalInternal<T> impl;
private final Supplier<? extends T> supplier;

private WorkflowThreadLocal(Supplier<? extends T> supplier) {
private WorkflowThreadLocal(Supplier<? extends T> supplier, boolean useCaching) {
this.supplier = Objects.requireNonNull(supplier);
this.impl = new WorkflowThreadLocalInternal<>(useCaching);
}

public WorkflowThreadLocal() {
this.supplier = () -> null;
this.impl = new WorkflowThreadLocalInternal<>(false);
}

/**
* Create an instance that returns the value returned by the given {@code Supplier} when {@link
* #set(S)} has not yet been called in the thread. Note that the value returned by the {@code
* Supplier} is not stored in the {@code WorkflowThreadLocal} implicitly; repeatedly calling
* {@link #get()} will always re-execute the {@code Supplier} until you call {@link #set(S)} for
* the first time. This differs from the behavior of {@code ThreadLocal}. If you want the value
* returned by the {@code Supplier} to be stored in the {@code WorkflowThreadLocal}, which matches
* the behavior of {@code ThreadLocal}, use {@link #withCachedInitial(Supplier)} instead.
*
* @param supplier Callback that will be executed whenever {@link #get()} is called, until {@link
* #set(S)} is called for the first time.
* @return A {@code WorkflowThreadLocal} instance.
* @param <S> The type stored in the {@code WorkflowThreadLocal}.
* @deprecated Because the non-caching behavior of this API is typically not desirable, it's
* recommend to use {@link #withCachedInitial(Supplier)} instead.
*/
@Deprecated
public static <S> WorkflowThreadLocal<S> withInitial(Supplier<? extends S> supplier) {
return new WorkflowThreadLocal<>(supplier);
return new WorkflowThreadLocal<>(supplier, false);
}

/**
* Create an instance that returns the value returned by the given {@code Supplier} when {@link
* #set(S)} has not yet been called in the Workflow, and then stores the returned value inside the
* {@code WorkflowThreadLocal}.
*
* @param supplier Callback that will be executed when {@link #get()} is called for the first
* time, if {@link #set(S)} has not already been called.
* @return A {@code WorkflowThreadLocal} instance.
* @param <S> The type stored in the {@code WorkflowThreadLocal}.
*/
public static <S> WorkflowThreadLocal<S> withCachedInitial(Supplier<? extends S> supplier) {
return new WorkflowThreadLocal<>(supplier, true);
}

public T get() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -945,14 +945,14 @@ private static Supplier<String> getStringSupplier(AtomicInteger supplierCalls) {
}

@Test
public void testSupplierCalledOnce() {
public void testSupplierCalledOnceWithCaching() {
AtomicInteger supplierCalls = new AtomicInteger();
DeterministicRunnerImpl d =
new DeterministicRunnerImpl(
threadPool::submit,
DummySyncWorkflowContext.newDummySyncWorkflowContext(),
() -> {
RunnerLocalInternal<String> runnerLocalInternal = new RunnerLocalInternal<>();
RunnerLocalInternal<String> runnerLocalInternal = new RunnerLocalInternal<>(true);
runnerLocalInternal.get(getStringSupplier(supplierCalls));
runnerLocalInternal.get(getStringSupplier(supplierCalls));
runnerLocalInternal.get(getStringSupplier(supplierCalls));
Expand All @@ -963,4 +963,24 @@ public void testSupplierCalledOnce() {
});
d.runUntilAllBlocked(DeterministicRunner.DEFAULT_DEADLOCK_DETECTION_TIMEOUT_MS);
}

@Test
public void testSupplierCalledMultipleWithoutCaching() {
AtomicInteger supplierCalls = new AtomicInteger();
DeterministicRunnerImpl d =
new DeterministicRunnerImpl(
threadPool::submit,
DummySyncWorkflowContext.newDummySyncWorkflowContext(),
() -> {
RunnerLocalInternal<String> runnerLocalInternal = new RunnerLocalInternal<>(false);
runnerLocalInternal.get(getStringSupplier(supplierCalls));
runnerLocalInternal.get(getStringSupplier(supplierCalls));
runnerLocalInternal.get(getStringSupplier(supplierCalls));
assertEquals(
"supplier default value",
runnerLocalInternal.get(getStringSupplier(supplierCalls)));
assertEquals(4, supplierCalls.get());
});
d.runUntilAllBlocked(DeterministicRunner.DEFAULT_DEADLOCK_DETECTION_TIMEOUT_MS);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,13 @@
package io.temporal.workflow;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotSame;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertSame;

import io.temporal.testing.internal.SDKTestWorkflowRule;
import io.temporal.workflow.shared.TestWorkflows.TestWorkflow1;
import io.temporal.workflow.shared.TestWorkflows.TestWorkflowReturnString;
import java.time.Duration;
import java.util.concurrent.atomic.AtomicInteger;
import org.junit.Assert;
Expand All @@ -47,9 +50,11 @@ public void testWorkflowLocals() {

public static class TestWorkflowLocals implements TestWorkflow1 {

@SuppressWarnings("deprecation")
private final WorkflowThreadLocal<Integer> threadLocal =
WorkflowThreadLocal.withInitial(() -> 2);

@SuppressWarnings("deprecation")
private final WorkflowLocal<Integer> workflowLocal = WorkflowLocal.withInitial(() -> 5);

@Override
Expand Down Expand Up @@ -84,12 +89,15 @@ public static class TestWorkflowLocalsSupplierReuse implements TestWorkflow1 {
private final AtomicInteger localCalls = new AtomicInteger(0);
private final AtomicInteger threadLocalCalls = new AtomicInteger(0);

@SuppressWarnings("deprecation")
private final WorkflowThreadLocal<Integer> workflowThreadLocal =
WorkflowThreadLocal.withInitial(
() -> {
threadLocalCalls.addAndGet(1);
return null;
});

@SuppressWarnings("deprecation")
private final WorkflowLocal<Integer> workflowLocal =
WorkflowLocal.withInitial(
() -> {
Expand Down Expand Up @@ -131,4 +139,93 @@ public void testWorkflowLocalsSupplierReuse() {
String result = workflowStub.execute(testWorkflowRule.getTaskQueue());
Assert.assertEquals("ok", result);
}

@SuppressWarnings("deprecation")
static final WorkflowThreadLocal<AtomicInteger> threadLocal =
WorkflowThreadLocal.withInitial(() -> new AtomicInteger(2));

@SuppressWarnings("deprecation")
static final WorkflowLocal<AtomicInteger> workflowLocal =
WorkflowLocal.withInitial(() -> new AtomicInteger(5));

static final WorkflowThreadLocal<AtomicInteger> threadLocalCached =
WorkflowThreadLocal.withCachedInitial(() -> new AtomicInteger(2));

static final WorkflowLocal<AtomicInteger> workflowLocalCached =
WorkflowLocal.withCachedInitial(() -> new AtomicInteger(5));

public static class TestInit implements TestWorkflowReturnString {

@Override
public String execute() {
assertEquals(2, threadLocal.get().getAndSet(3));
assertEquals(5, workflowLocal.get().getAndSet(6));
assertEquals(2, threadLocalCached.get().getAndSet(3));
assertEquals(5, workflowLocalCached.get().getAndSet(6));
String out = Workflow.newChildWorkflowStub(TestWorkflow1.class).execute("ign");
assertEquals("ok", out);
return "result="
+ threadLocal.get().get()
+ ", "
+ workflowLocal.get().get()
+ ", "
+ threadLocalCached.get().get()
+ ", "
+ workflowLocalCached.get().get();
}
}

public static class TestChildInit implements TestWorkflow1 {

@Override
public String execute(String arg1) {
assertEquals(2, threadLocal.get().getAndSet(8));
assertEquals(5, workflowLocal.get().getAndSet(0));
return "ok";
}
}

@Rule
public SDKTestWorkflowRule testWorkflowRuleInitialValueNotShared =
SDKTestWorkflowRule.newBuilder()
.setWorkflowTypes(TestInit.class, TestChildInit.class)
.build();

@Test
public void testWorkflowInitialNotShared() {
TestWorkflowReturnString workflowStub =
testWorkflowRuleInitialValueNotShared.newWorkflowStubTimeoutOptions(
TestWorkflowReturnString.class);
String result = workflowStub.execute();
Assert.assertEquals("result=2, 5, 3, 6", result);
}

public static class TestCaching implements TestWorkflow1 {

@Override
public String execute(String arg1) {
assertNotSame(threadLocal.get(), threadLocal.get());
assertNotSame(workflowLocal.get(), workflowLocal.get());
threadLocal.set(threadLocal.get());
workflowLocal.set(workflowLocal.get());
assertSame(threadLocal.get(), threadLocal.get());
assertSame(workflowLocal.get(), workflowLocal.get());

assertSame(threadLocalCached.get(), threadLocalCached.get());
assertSame(workflowLocalCached.get(), workflowLocalCached.get());
return "ok";
}
}

@Rule
public SDKTestWorkflowRule testWorkflowRuleCaching =
SDKTestWorkflowRule.newBuilder().setWorkflowTypes(TestCaching.class).build();

@Test
public void testWorkflowLocalCaching() {
TestWorkflow1 workflowStub =
testWorkflowRuleCaching.newWorkflowStubTimeoutOptions(TestWorkflow1.class);
String out = workflowStub.execute("ign");
assertEquals("ok", out);
}
}

0 comments on commit 9c0cc84

Please sign in to comment.