diff --git a/temporal-sdk/src/main/java/io/temporal/internal/context/ContextThreadLocal.java b/temporal-sdk/src/main/java/io/temporal/internal/context/ContextThreadLocal.java index 2e00d3f8d..688022dcc 100644 --- a/temporal-sdk/src/main/java/io/temporal/internal/context/ContextThreadLocal.java +++ b/temporal-sdk/src/main/java/io/temporal/internal/context/ContextThreadLocal.java @@ -32,7 +32,7 @@ public class ContextThreadLocal { private static final WorkflowThreadLocal> contextPropagators = - WorkflowThreadLocal.withInitial( + WorkflowThreadLocal.withCachedInitial( new Supplier>() { @Override public List get() { diff --git a/temporal-sdk/src/main/java/io/temporal/internal/sync/RunnerLocalInternal.java b/temporal-sdk/src/main/java/io/temporal/internal/sync/RunnerLocalInternal.java index 9bacc5f0c..c3ed400a0 100644 --- a/temporal-sdk/src/main/java/io/temporal/internal/sync/RunnerLocalInternal.java +++ b/temporal-sdk/src/main/java/io/temporal/internal/sync/RunnerLocalInternal.java @@ -24,24 +24,26 @@ import java.util.function.Supplier; public final class RunnerLocalInternal { - private T supplierResult = null; - private boolean supplierCalled = false; - - Optional invokeSupplier(Supplier supplier) { - if (!supplierCalled) { - T result = supplier.get(); - supplierCalled = true; - supplierResult = result; - return Optional.ofNullable(result); - } else { - return Optional.ofNullable(supplierResult); - } + + private final boolean useCaching; + + public RunnerLocalInternal() { + this.useCaching = false; + } + + public RunnerLocalInternal(boolean useCaching) { + this.useCaching = useCaching; } public T get(Supplier supplier) { Optional> result = DeterministicRunnerImpl.currentThreadInternal().getRunner().getRunnerLocal(this); - return result.orElseGet(() -> invokeSupplier(supplier)).orElse(null); + T out = result.orElseGet(() -> Optional.ofNullable(supplier.get())).orElse(null); + if (!result.isPresent() && useCaching) { + // This is the first time we've tried fetching this, and caching is enabled. Store it. + set(out); + } + return out; } public void set(T value) { diff --git a/temporal-sdk/src/main/java/io/temporal/internal/sync/WorkflowThreadLocalInternal.java b/temporal-sdk/src/main/java/io/temporal/internal/sync/WorkflowThreadLocalInternal.java index 122616222..79fe1db5c 100644 --- a/temporal-sdk/src/main/java/io/temporal/internal/sync/WorkflowThreadLocalInternal.java +++ b/temporal-sdk/src/main/java/io/temporal/internal/sync/WorkflowThreadLocalInternal.java @@ -25,24 +25,25 @@ public final class WorkflowThreadLocalInternal { - private T supplierResult = null; - private boolean supplierCalled = false; - - Optional invokeSupplier(Supplier supplier) { - if (!supplierCalled) { - T result = supplier.get(); - supplierCalled = true; - supplierResult = result; - return Optional.ofNullable(result); - } else { - return Optional.ofNullable(supplierResult); - } + private final boolean useCaching; + + public WorkflowThreadLocalInternal() { + this(false); + } + + public WorkflowThreadLocalInternal(boolean useCaching) { + this.useCaching = useCaching; } public T get(Supplier supplier) { Optional> result = DeterministicRunnerImpl.currentThreadInternal().getThreadLocal(this); - return result.orElseGet(() -> invokeSupplier(supplier)).orElse(null); + T out = result.orElseGet(() -> Optional.ofNullable(supplier.get())).orElse(null); + if (!result.isPresent() && useCaching) { + // This is the first time we've tried fetching this, and caching is enabled. Store it. + set(out); + } + return out; } public void set(T value) { diff --git a/temporal-sdk/src/main/java/io/temporal/workflow/WorkflowLocal.java b/temporal-sdk/src/main/java/io/temporal/workflow/WorkflowLocal.java index 65194818c..94b3635b2 100644 --- a/temporal-sdk/src/main/java/io/temporal/workflow/WorkflowLocal.java +++ b/temporal-sdk/src/main/java/io/temporal/workflow/WorkflowLocal.java @@ -33,7 +33,7 @@ *
{@code
  * public class Workflow {
  *
- *   private static final WorkflowLocal signaled = WorkflowLocal.withInitial(() -> false);
+ *   private static final WorkflowLocal signaled = WorkflowLocal.withCachedInitial(() -> false);
  *
  *   public static boolean isSignaled() {
  *     return signaled.get();
@@ -49,19 +49,51 @@
  */
 public final class WorkflowLocal {
 
-  private final RunnerLocalInternal impl = new RunnerLocalInternal<>();
+  private final RunnerLocalInternal impl;
   private final Supplier supplier;
 
-  private WorkflowLocal(Supplier supplier) {
+  private WorkflowLocal(Supplier supplier, boolean useCaching) {
     this.supplier = Objects.requireNonNull(supplier);
+    this.impl = new RunnerLocalInternal<>(useCaching);
   }
 
   public WorkflowLocal() {
     this.supplier = () -> null;
+    this.impl = new RunnerLocalInternal<>(false);
   }
 
+  /**
+   * Create an instance that returns the value returned by the given {@code Supplier} when {@link
+   * #set(S)} has not yet been called in the Workflow. Note that the value returned by the {@code
+   * Supplier} is not stored in the {@code WorkflowLocal} implicitly; repeatedly calling {@link
+   * #get()} will always re-execute the {@code Supplier} until you call {@link #set(S)} for the
+   * first time. If you want the value returned by the {@code Supplier} to be stored in the {@code
+   * WorkflowLocal}, use {@link #withCachedInitial(Supplier)} instead.
+   *
+   * @param supplier Callback that will be executed whenever {@link #get()} is called, until {@link
+   *     #set(S)} is called for the first time.
+   * @return A {@code WorkflowLocal} instance.
+   * @param  The type stored in the {@code WorkflowLocal}.
+   * @deprecated Because the non-caching behavior of this API is typically not desirable, it's
+   *     recommend to use {@link #withCachedInitial(Supplier)} instead.
+   */
+  @Deprecated
   public static  WorkflowLocal withInitial(Supplier supplier) {
-    return new WorkflowLocal<>(supplier);
+    return new WorkflowLocal<>(supplier, false);
+  }
+
+  /**
+   * Create an instance that returns the value returned by the given {@code Supplier} when {@link
+   * #set(S)} has not yet been called in the Workflow, and then stores the returned value inside the
+   * {@code WorkflowLocal}.
+   *
+   * @param supplier Callback that will be executed when {@link #get()} is called for the first
+   *     time, if {@link #set(S)} has not already been called.
+   * @return A {@code WorkflowLocal} instance.
+   * @param  The type stored in the {@code WorkflowLocal}.
+   */
+  public static  WorkflowLocal withCachedInitial(Supplier supplier) {
+    return new WorkflowLocal<>(supplier, true);
   }
 
   public T get() {
diff --git a/temporal-sdk/src/main/java/io/temporal/workflow/WorkflowThreadLocal.java b/temporal-sdk/src/main/java/io/temporal/workflow/WorkflowThreadLocal.java
index 921b31810..1704b7083 100644
--- a/temporal-sdk/src/main/java/io/temporal/workflow/WorkflowThreadLocal.java
+++ b/temporal-sdk/src/main/java/io/temporal/workflow/WorkflowThreadLocal.java
@@ -27,19 +27,52 @@
 /** {@link ThreadLocal} analog for workflow code. */
 public final class WorkflowThreadLocal {
 
-  private final WorkflowThreadLocalInternal impl = new WorkflowThreadLocalInternal<>();
+  private final WorkflowThreadLocalInternal impl;
   private final Supplier supplier;
 
-  private WorkflowThreadLocal(Supplier supplier) {
+  private WorkflowThreadLocal(Supplier supplier, boolean useCaching) {
     this.supplier = Objects.requireNonNull(supplier);
+    this.impl = new WorkflowThreadLocalInternal<>(useCaching);
   }
 
   public WorkflowThreadLocal() {
     this.supplier = () -> null;
+    this.impl = new WorkflowThreadLocalInternal<>(false);
   }
 
+  /**
+   * Create an instance that returns the value returned by the given {@code Supplier} when {@link
+   * #set(S)} has not yet been called in the thread. Note that the value returned by the {@code
+   * Supplier} is not stored in the {@code WorkflowThreadLocal} implicitly; repeatedly calling
+   * {@link #get()} will always re-execute the {@code Supplier} until you call {@link #set(S)} for
+   * the first time. This differs from the behavior of {@code ThreadLocal}. If you want the value
+   * returned by the {@code Supplier} to be stored in the {@code WorkflowThreadLocal}, which matches
+   * the behavior of {@code ThreadLocal}, use {@link #withCachedInitial(Supplier)} instead.
+   *
+   * @param supplier Callback that will be executed whenever {@link #get()} is called, until {@link
+   *     #set(S)} is called for the first time.
+   * @return A {@code WorkflowThreadLocal} instance.
+   * @param  The type stored in the {@code WorkflowThreadLocal}.
+   * @deprecated Because the non-caching behavior of this API is typically not desirable, it's
+   *     recommend to use {@link #withCachedInitial(Supplier)} instead.
+   */
+  @Deprecated
   public static  WorkflowThreadLocal withInitial(Supplier supplier) {
-    return new WorkflowThreadLocal<>(supplier);
+    return new WorkflowThreadLocal<>(supplier, false);
+  }
+
+  /**
+   * Create an instance that returns the value returned by the given {@code Supplier} when {@link
+   * #set(S)} has not yet been called in the Workflow, and then stores the returned value inside the
+   * {@code WorkflowThreadLocal}.
+   *
+   * @param supplier Callback that will be executed when {@link #get()} is called for the first
+   *     time, if {@link #set(S)} has not already been called.
+   * @return A {@code WorkflowThreadLocal} instance.
+   * @param  The type stored in the {@code WorkflowThreadLocal}.
+   */
+  public static  WorkflowThreadLocal withCachedInitial(Supplier supplier) {
+    return new WorkflowThreadLocal<>(supplier, true);
   }
 
   public T get() {
diff --git a/temporal-sdk/src/test/java/io/temporal/internal/sync/DeterministicRunnerTest.java b/temporal-sdk/src/test/java/io/temporal/internal/sync/DeterministicRunnerTest.java
index 98ff1b931..5ba73b8e8 100644
--- a/temporal-sdk/src/test/java/io/temporal/internal/sync/DeterministicRunnerTest.java
+++ b/temporal-sdk/src/test/java/io/temporal/internal/sync/DeterministicRunnerTest.java
@@ -945,14 +945,14 @@ private static Supplier getStringSupplier(AtomicInteger supplierCalls) {
   }
 
   @Test
-  public void testSupplierCalledOnce() {
+  public void testSupplierCalledOnceWithCaching() {
     AtomicInteger supplierCalls = new AtomicInteger();
     DeterministicRunnerImpl d =
         new DeterministicRunnerImpl(
             threadPool::submit,
             DummySyncWorkflowContext.newDummySyncWorkflowContext(),
             () -> {
-              RunnerLocalInternal runnerLocalInternal = new RunnerLocalInternal<>();
+              RunnerLocalInternal runnerLocalInternal = new RunnerLocalInternal<>(true);
               runnerLocalInternal.get(getStringSupplier(supplierCalls));
               runnerLocalInternal.get(getStringSupplier(supplierCalls));
               runnerLocalInternal.get(getStringSupplier(supplierCalls));
@@ -963,4 +963,24 @@ public void testSupplierCalledOnce() {
             });
     d.runUntilAllBlocked(DeterministicRunner.DEFAULT_DEADLOCK_DETECTION_TIMEOUT_MS);
   }
+
+  @Test
+  public void testSupplierCalledMultipleWithoutCaching() {
+    AtomicInteger supplierCalls = new AtomicInteger();
+    DeterministicRunnerImpl d =
+        new DeterministicRunnerImpl(
+            threadPool::submit,
+            DummySyncWorkflowContext.newDummySyncWorkflowContext(),
+            () -> {
+              RunnerLocalInternal runnerLocalInternal = new RunnerLocalInternal<>(false);
+              runnerLocalInternal.get(getStringSupplier(supplierCalls));
+              runnerLocalInternal.get(getStringSupplier(supplierCalls));
+              runnerLocalInternal.get(getStringSupplier(supplierCalls));
+              assertEquals(
+                  "supplier default value",
+                  runnerLocalInternal.get(getStringSupplier(supplierCalls)));
+              assertEquals(4, supplierCalls.get());
+            });
+    d.runUntilAllBlocked(DeterministicRunner.DEFAULT_DEADLOCK_DETECTION_TIMEOUT_MS);
+  }
 }
diff --git a/temporal-sdk/src/test/java/io/temporal/workflow/WorkflowLocalsTest.java b/temporal-sdk/src/test/java/io/temporal/workflow/WorkflowLocalsTest.java
index df1674ae1..847d563fa 100644
--- a/temporal-sdk/src/test/java/io/temporal/workflow/WorkflowLocalsTest.java
+++ b/temporal-sdk/src/test/java/io/temporal/workflow/WorkflowLocalsTest.java
@@ -21,10 +21,13 @@
 package io.temporal.workflow;
 
 import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotSame;
 import static org.junit.Assert.assertNull;
+import static org.junit.Assert.assertSame;
 
 import io.temporal.testing.internal.SDKTestWorkflowRule;
 import io.temporal.workflow.shared.TestWorkflows.TestWorkflow1;
+import io.temporal.workflow.shared.TestWorkflows.TestWorkflowReturnString;
 import java.time.Duration;
 import java.util.concurrent.atomic.AtomicInteger;
 import org.junit.Assert;
@@ -47,9 +50,11 @@ public void testWorkflowLocals() {
 
   public static class TestWorkflowLocals implements TestWorkflow1 {
 
+    @SuppressWarnings("deprecation")
     private final WorkflowThreadLocal threadLocal =
         WorkflowThreadLocal.withInitial(() -> 2);
 
+    @SuppressWarnings("deprecation")
     private final WorkflowLocal workflowLocal = WorkflowLocal.withInitial(() -> 5);
 
     @Override
@@ -84,12 +89,15 @@ public static class TestWorkflowLocalsSupplierReuse implements TestWorkflow1 {
     private final AtomicInteger localCalls = new AtomicInteger(0);
     private final AtomicInteger threadLocalCalls = new AtomicInteger(0);
 
+    @SuppressWarnings("deprecation")
     private final WorkflowThreadLocal workflowThreadLocal =
         WorkflowThreadLocal.withInitial(
             () -> {
               threadLocalCalls.addAndGet(1);
               return null;
             });
+
+    @SuppressWarnings("deprecation")
     private final WorkflowLocal workflowLocal =
         WorkflowLocal.withInitial(
             () -> {
@@ -131,4 +139,93 @@ public void testWorkflowLocalsSupplierReuse() {
     String result = workflowStub.execute(testWorkflowRule.getTaskQueue());
     Assert.assertEquals("ok", result);
   }
+
+  @SuppressWarnings("deprecation")
+  static final WorkflowThreadLocal threadLocal =
+      WorkflowThreadLocal.withInitial(() -> new AtomicInteger(2));
+
+  @SuppressWarnings("deprecation")
+  static final WorkflowLocal workflowLocal =
+      WorkflowLocal.withInitial(() -> new AtomicInteger(5));
+
+  static final WorkflowThreadLocal threadLocalCached =
+      WorkflowThreadLocal.withCachedInitial(() -> new AtomicInteger(2));
+
+  static final WorkflowLocal workflowLocalCached =
+      WorkflowLocal.withCachedInitial(() -> new AtomicInteger(5));
+
+  public static class TestInit implements TestWorkflowReturnString {
+
+    @Override
+    public String execute() {
+      assertEquals(2, threadLocal.get().getAndSet(3));
+      assertEquals(5, workflowLocal.get().getAndSet(6));
+      assertEquals(2, threadLocalCached.get().getAndSet(3));
+      assertEquals(5, workflowLocalCached.get().getAndSet(6));
+      String out = Workflow.newChildWorkflowStub(TestWorkflow1.class).execute("ign");
+      assertEquals("ok", out);
+      return "result="
+          + threadLocal.get().get()
+          + ", "
+          + workflowLocal.get().get()
+          + ", "
+          + threadLocalCached.get().get()
+          + ", "
+          + workflowLocalCached.get().get();
+    }
+  }
+
+  public static class TestChildInit implements TestWorkflow1 {
+
+    @Override
+    public String execute(String arg1) {
+      assertEquals(2, threadLocal.get().getAndSet(8));
+      assertEquals(5, workflowLocal.get().getAndSet(0));
+      return "ok";
+    }
+  }
+
+  @Rule
+  public SDKTestWorkflowRule testWorkflowRuleInitialValueNotShared =
+      SDKTestWorkflowRule.newBuilder()
+          .setWorkflowTypes(TestInit.class, TestChildInit.class)
+          .build();
+
+  @Test
+  public void testWorkflowInitialNotShared() {
+    TestWorkflowReturnString workflowStub =
+        testWorkflowRuleInitialValueNotShared.newWorkflowStubTimeoutOptions(
+            TestWorkflowReturnString.class);
+    String result = workflowStub.execute();
+    Assert.assertEquals("result=2, 5, 3, 6", result);
+  }
+
+  public static class TestCaching implements TestWorkflow1 {
+
+    @Override
+    public String execute(String arg1) {
+      assertNotSame(threadLocal.get(), threadLocal.get());
+      assertNotSame(workflowLocal.get(), workflowLocal.get());
+      threadLocal.set(threadLocal.get());
+      workflowLocal.set(workflowLocal.get());
+      assertSame(threadLocal.get(), threadLocal.get());
+      assertSame(workflowLocal.get(), workflowLocal.get());
+
+      assertSame(threadLocalCached.get(), threadLocalCached.get());
+      assertSame(workflowLocalCached.get(), workflowLocalCached.get());
+      return "ok";
+    }
+  }
+
+  @Rule
+  public SDKTestWorkflowRule testWorkflowRuleCaching =
+      SDKTestWorkflowRule.newBuilder().setWorkflowTypes(TestCaching.class).build();
+
+  @Test
+  public void testWorkflowLocalCaching() {
+    TestWorkflow1 workflowStub =
+        testWorkflowRuleCaching.newWorkflowStubTimeoutOptions(TestWorkflow1.class);
+    String out = workflowStub.execute("ign");
+    assertEquals("ok", out);
+  }
 }