From ddd3f2bc3a5cf2d0bfcd384adc8525a3321115d0 Mon Sep 17 00:00:00 2001 From: Dan O'Reilly Date: Thu, 28 Sep 2023 15:28:39 -0400 Subject: [PATCH] Fix caching in WorkflowLocal/WorkflowThreadLocal (#1876) Reverted caching changes made to WorkflowLocal/WorkflowThreadLocal, which broke backwards compatibility and accidentally shared values between Workflows/Threads. Re-implemented caching as an optional feature, and deprecated the factory methods that created non-caching instances. --- .../internal/context/ContextThreadLocal.java | 2 +- .../internal/sync/RunnerLocalInternal.java | 28 +++--- .../sync/WorkflowThreadLocalInternal.java | 27 +++--- .../io/temporal/workflow/WorkflowLocal.java | 38 +++++++- .../workflow/WorkflowThreadLocal.java | 39 +++++++- .../temporal/workflow/WorkflowLocalsTest.java | 97 +++++++++++++++++++ 6 files changed, 198 insertions(+), 33 deletions(-) diff --git a/temporal-sdk/src/main/java/io/temporal/internal/context/ContextThreadLocal.java b/temporal-sdk/src/main/java/io/temporal/internal/context/ContextThreadLocal.java index 2e00d3f8db..688022dccf 100644 --- a/temporal-sdk/src/main/java/io/temporal/internal/context/ContextThreadLocal.java +++ b/temporal-sdk/src/main/java/io/temporal/internal/context/ContextThreadLocal.java @@ -32,7 +32,7 @@ public class ContextThreadLocal { private static final WorkflowThreadLocal> contextPropagators = - WorkflowThreadLocal.withInitial( + WorkflowThreadLocal.withCachedInitial( new Supplier>() { @Override public List get() { diff --git a/temporal-sdk/src/main/java/io/temporal/internal/sync/RunnerLocalInternal.java b/temporal-sdk/src/main/java/io/temporal/internal/sync/RunnerLocalInternal.java index 9bacc5f0cd..c3ed400a05 100644 --- a/temporal-sdk/src/main/java/io/temporal/internal/sync/RunnerLocalInternal.java +++ b/temporal-sdk/src/main/java/io/temporal/internal/sync/RunnerLocalInternal.java @@ -24,24 +24,26 @@ import java.util.function.Supplier; public final class RunnerLocalInternal { - private T supplierResult = null; - private boolean supplierCalled = false; - - Optional invokeSupplier(Supplier supplier) { - if (!supplierCalled) { - T result = supplier.get(); - supplierCalled = true; - supplierResult = result; - return Optional.ofNullable(result); - } else { - return Optional.ofNullable(supplierResult); - } + + private final boolean useCaching; + + public RunnerLocalInternal() { + this.useCaching = false; + } + + public RunnerLocalInternal(boolean useCaching) { + this.useCaching = useCaching; } public T get(Supplier supplier) { Optional> result = DeterministicRunnerImpl.currentThreadInternal().getRunner().getRunnerLocal(this); - return result.orElseGet(() -> invokeSupplier(supplier)).orElse(null); + T out = result.orElseGet(() -> Optional.ofNullable(supplier.get())).orElse(null); + if (!result.isPresent() && useCaching) { + // This is the first time we've tried fetching this, and caching is enabled. Store it. + set(out); + } + return out; } public void set(T value) { diff --git a/temporal-sdk/src/main/java/io/temporal/internal/sync/WorkflowThreadLocalInternal.java b/temporal-sdk/src/main/java/io/temporal/internal/sync/WorkflowThreadLocalInternal.java index 1226162224..79fe1db5c8 100644 --- a/temporal-sdk/src/main/java/io/temporal/internal/sync/WorkflowThreadLocalInternal.java +++ b/temporal-sdk/src/main/java/io/temporal/internal/sync/WorkflowThreadLocalInternal.java @@ -25,24 +25,25 @@ public final class WorkflowThreadLocalInternal { - private T supplierResult = null; - private boolean supplierCalled = false; - - Optional invokeSupplier(Supplier supplier) { - if (!supplierCalled) { - T result = supplier.get(); - supplierCalled = true; - supplierResult = result; - return Optional.ofNullable(result); - } else { - return Optional.ofNullable(supplierResult); - } + private final boolean useCaching; + + public WorkflowThreadLocalInternal() { + this(false); + } + + public WorkflowThreadLocalInternal(boolean useCaching) { + this.useCaching = useCaching; } public T get(Supplier supplier) { Optional> result = DeterministicRunnerImpl.currentThreadInternal().getThreadLocal(this); - return result.orElseGet(() -> invokeSupplier(supplier)).orElse(null); + T out = result.orElseGet(() -> Optional.ofNullable(supplier.get())).orElse(null); + if (!result.isPresent() && useCaching) { + // This is the first time we've tried fetching this, and caching is enabled. Store it. + set(out); + } + return out; } public void set(T value) { diff --git a/temporal-sdk/src/main/java/io/temporal/workflow/WorkflowLocal.java b/temporal-sdk/src/main/java/io/temporal/workflow/WorkflowLocal.java index 65194818ca..373aa76f68 100644 --- a/temporal-sdk/src/main/java/io/temporal/workflow/WorkflowLocal.java +++ b/temporal-sdk/src/main/java/io/temporal/workflow/WorkflowLocal.java @@ -49,19 +49,51 @@ */ public final class WorkflowLocal { - private final RunnerLocalInternal impl = new RunnerLocalInternal<>(); + private final RunnerLocalInternal impl; private final Supplier supplier; - private WorkflowLocal(Supplier supplier) { + private WorkflowLocal(Supplier supplier, boolean useCaching) { this.supplier = Objects.requireNonNull(supplier); + this.impl = new RunnerLocalInternal<>(useCaching); } public WorkflowLocal() { this.supplier = () -> null; + this.impl = new RunnerLocalInternal<>(false); } + /** + * Create an instance that returns the value returned by the given {@code Supplier} when {@link + * #set(S)} has not yet been called in the Workflow. Note that the value returned by the {@code + * Supplier} is not stored in the {@code WorkflowLocal} implicitly; repeatedly calling {@link + * #get()} will always re-execute the {@code Supplier} until you call {@link #set(S)} for the + * first time. If you want the value returned by the {@code Supplier} to be stored in the {@code + * WorkflowLocal}, use {@link #withCachedInitial(Supplier)} instead. + * + * @param supplier Callback that will be executed whenever {@link #get()} is called, until {@link + * #set(S)} is called for the first time. + * @return A {@code WorkflowLocal} instance. + * @param The type stored in the {@code WorkflowLocal}. + * @deprecated Because the non-caching behavior of this API is typically not desirable, it's + * recommend to use {@link #withCachedInitial(Supplier)} instead. + */ + @Deprecated public static WorkflowLocal withInitial(Supplier supplier) { - return new WorkflowLocal<>(supplier); + return new WorkflowLocal<>(supplier, false); + } + + /** + * Create an instance that returns the value returned by the given {@code Supplier} when {@link + * #set(S)} has not yet been called in the Workflow, and then stores the returned value inside the + * {@code WorkflowLocal}. + * + * @param supplier Callback that will be executed when {@link #get()} is called for the first + * time, if {@link #set(S)} has not already been called. + * @return A {@code WorkflowLocal} instance. + * @param The type stored in the {@code WorkflowLocal}. + */ + public static WorkflowLocal withCachedInitial(Supplier supplier) { + return new WorkflowLocal<>(supplier, true); } public T get() { diff --git a/temporal-sdk/src/main/java/io/temporal/workflow/WorkflowThreadLocal.java b/temporal-sdk/src/main/java/io/temporal/workflow/WorkflowThreadLocal.java index 921b31810e..1704b70835 100644 --- a/temporal-sdk/src/main/java/io/temporal/workflow/WorkflowThreadLocal.java +++ b/temporal-sdk/src/main/java/io/temporal/workflow/WorkflowThreadLocal.java @@ -27,19 +27,52 @@ /** {@link ThreadLocal} analog for workflow code. */ public final class WorkflowThreadLocal { - private final WorkflowThreadLocalInternal impl = new WorkflowThreadLocalInternal<>(); + private final WorkflowThreadLocalInternal impl; private final Supplier supplier; - private WorkflowThreadLocal(Supplier supplier) { + private WorkflowThreadLocal(Supplier supplier, boolean useCaching) { this.supplier = Objects.requireNonNull(supplier); + this.impl = new WorkflowThreadLocalInternal<>(useCaching); } public WorkflowThreadLocal() { this.supplier = () -> null; + this.impl = new WorkflowThreadLocalInternal<>(false); } + /** + * Create an instance that returns the value returned by the given {@code Supplier} when {@link + * #set(S)} has not yet been called in the thread. Note that the value returned by the {@code + * Supplier} is not stored in the {@code WorkflowThreadLocal} implicitly; repeatedly calling + * {@link #get()} will always re-execute the {@code Supplier} until you call {@link #set(S)} for + * the first time. This differs from the behavior of {@code ThreadLocal}. If you want the value + * returned by the {@code Supplier} to be stored in the {@code WorkflowThreadLocal}, which matches + * the behavior of {@code ThreadLocal}, use {@link #withCachedInitial(Supplier)} instead. + * + * @param supplier Callback that will be executed whenever {@link #get()} is called, until {@link + * #set(S)} is called for the first time. + * @return A {@code WorkflowThreadLocal} instance. + * @param The type stored in the {@code WorkflowThreadLocal}. + * @deprecated Because the non-caching behavior of this API is typically not desirable, it's + * recommend to use {@link #withCachedInitial(Supplier)} instead. + */ + @Deprecated public static WorkflowThreadLocal withInitial(Supplier supplier) { - return new WorkflowThreadLocal<>(supplier); + return new WorkflowThreadLocal<>(supplier, false); + } + + /** + * Create an instance that returns the value returned by the given {@code Supplier} when {@link + * #set(S)} has not yet been called in the Workflow, and then stores the returned value inside the + * {@code WorkflowThreadLocal}. + * + * @param supplier Callback that will be executed when {@link #get()} is called for the first + * time, if {@link #set(S)} has not already been called. + * @return A {@code WorkflowThreadLocal} instance. + * @param The type stored in the {@code WorkflowThreadLocal}. + */ + public static WorkflowThreadLocal withCachedInitial(Supplier supplier) { + return new WorkflowThreadLocal<>(supplier, true); } public T get() { diff --git a/temporal-sdk/src/test/java/io/temporal/workflow/WorkflowLocalsTest.java b/temporal-sdk/src/test/java/io/temporal/workflow/WorkflowLocalsTest.java index df1674ae1b..847d563fa5 100644 --- a/temporal-sdk/src/test/java/io/temporal/workflow/WorkflowLocalsTest.java +++ b/temporal-sdk/src/test/java/io/temporal/workflow/WorkflowLocalsTest.java @@ -21,10 +21,13 @@ package io.temporal.workflow; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotSame; import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertSame; import io.temporal.testing.internal.SDKTestWorkflowRule; import io.temporal.workflow.shared.TestWorkflows.TestWorkflow1; +import io.temporal.workflow.shared.TestWorkflows.TestWorkflowReturnString; import java.time.Duration; import java.util.concurrent.atomic.AtomicInteger; import org.junit.Assert; @@ -47,9 +50,11 @@ public void testWorkflowLocals() { public static class TestWorkflowLocals implements TestWorkflow1 { + @SuppressWarnings("deprecation") private final WorkflowThreadLocal threadLocal = WorkflowThreadLocal.withInitial(() -> 2); + @SuppressWarnings("deprecation") private final WorkflowLocal workflowLocal = WorkflowLocal.withInitial(() -> 5); @Override @@ -84,12 +89,15 @@ public static class TestWorkflowLocalsSupplierReuse implements TestWorkflow1 { private final AtomicInteger localCalls = new AtomicInteger(0); private final AtomicInteger threadLocalCalls = new AtomicInteger(0); + @SuppressWarnings("deprecation") private final WorkflowThreadLocal workflowThreadLocal = WorkflowThreadLocal.withInitial( () -> { threadLocalCalls.addAndGet(1); return null; }); + + @SuppressWarnings("deprecation") private final WorkflowLocal workflowLocal = WorkflowLocal.withInitial( () -> { @@ -131,4 +139,93 @@ public void testWorkflowLocalsSupplierReuse() { String result = workflowStub.execute(testWorkflowRule.getTaskQueue()); Assert.assertEquals("ok", result); } + + @SuppressWarnings("deprecation") + static final WorkflowThreadLocal threadLocal = + WorkflowThreadLocal.withInitial(() -> new AtomicInteger(2)); + + @SuppressWarnings("deprecation") + static final WorkflowLocal workflowLocal = + WorkflowLocal.withInitial(() -> new AtomicInteger(5)); + + static final WorkflowThreadLocal threadLocalCached = + WorkflowThreadLocal.withCachedInitial(() -> new AtomicInteger(2)); + + static final WorkflowLocal workflowLocalCached = + WorkflowLocal.withCachedInitial(() -> new AtomicInteger(5)); + + public static class TestInit implements TestWorkflowReturnString { + + @Override + public String execute() { + assertEquals(2, threadLocal.get().getAndSet(3)); + assertEquals(5, workflowLocal.get().getAndSet(6)); + assertEquals(2, threadLocalCached.get().getAndSet(3)); + assertEquals(5, workflowLocalCached.get().getAndSet(6)); + String out = Workflow.newChildWorkflowStub(TestWorkflow1.class).execute("ign"); + assertEquals("ok", out); + return "result=" + + threadLocal.get().get() + + ", " + + workflowLocal.get().get() + + ", " + + threadLocalCached.get().get() + + ", " + + workflowLocalCached.get().get(); + } + } + + public static class TestChildInit implements TestWorkflow1 { + + @Override + public String execute(String arg1) { + assertEquals(2, threadLocal.get().getAndSet(8)); + assertEquals(5, workflowLocal.get().getAndSet(0)); + return "ok"; + } + } + + @Rule + public SDKTestWorkflowRule testWorkflowRuleInitialValueNotShared = + SDKTestWorkflowRule.newBuilder() + .setWorkflowTypes(TestInit.class, TestChildInit.class) + .build(); + + @Test + public void testWorkflowInitialNotShared() { + TestWorkflowReturnString workflowStub = + testWorkflowRuleInitialValueNotShared.newWorkflowStubTimeoutOptions( + TestWorkflowReturnString.class); + String result = workflowStub.execute(); + Assert.assertEquals("result=2, 5, 3, 6", result); + } + + public static class TestCaching implements TestWorkflow1 { + + @Override + public String execute(String arg1) { + assertNotSame(threadLocal.get(), threadLocal.get()); + assertNotSame(workflowLocal.get(), workflowLocal.get()); + threadLocal.set(threadLocal.get()); + workflowLocal.set(workflowLocal.get()); + assertSame(threadLocal.get(), threadLocal.get()); + assertSame(workflowLocal.get(), workflowLocal.get()); + + assertSame(threadLocalCached.get(), threadLocalCached.get()); + assertSame(workflowLocalCached.get(), workflowLocalCached.get()); + return "ok"; + } + } + + @Rule + public SDKTestWorkflowRule testWorkflowRuleCaching = + SDKTestWorkflowRule.newBuilder().setWorkflowTypes(TestCaching.class).build(); + + @Test + public void testWorkflowLocalCaching() { + TestWorkflow1 workflowStub = + testWorkflowRuleCaching.newWorkflowStubTimeoutOptions(TestWorkflow1.class); + String out = workflowStub.execute("ign"); + assertEquals("ok", out); + } }