Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix caching in WorkflowLocal/WorkflowThreadLocal (#1876) #1878

Merged
merged 1 commit into from
Oct 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
public class ContextThreadLocal {

private static final WorkflowThreadLocal<List<ContextPropagator>> contextPropagators =
WorkflowThreadLocal.withInitial(
WorkflowThreadLocal.withCachedInitial(
new Supplier<List<ContextPropagator>>() {
@Override
public List<ContextPropagator> get() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,24 +24,26 @@
import java.util.function.Supplier;

public final class RunnerLocalInternal<T> {
private T supplierResult = null;
private boolean supplierCalled = false;

Optional<T> invokeSupplier(Supplier<? extends T> supplier) {
if (!supplierCalled) {
T result = supplier.get();
supplierCalled = true;
supplierResult = result;
return Optional.ofNullable(result);
} else {
return Optional.ofNullable(supplierResult);
}

private final boolean useCaching;

public RunnerLocalInternal() {
this.useCaching = false;
}

public RunnerLocalInternal(boolean useCaching) {
this.useCaching = useCaching;
}

public T get(Supplier<? extends T> supplier) {
Optional<Optional<T>> result =
DeterministicRunnerImpl.currentThreadInternal().getRunner().getRunnerLocal(this);
return result.orElseGet(() -> invokeSupplier(supplier)).orElse(null);
T out = result.orElseGet(() -> Optional.ofNullable(supplier.get())).orElse(null);
if (!result.isPresent() && useCaching) {
// This is the first time we've tried fetching this, and caching is enabled. Store it.
set(out);
}
return out;
}

public void set(T value) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,24 +25,25 @@

public final class WorkflowThreadLocalInternal<T> {

private T supplierResult = null;
private boolean supplierCalled = false;

Optional<T> invokeSupplier(Supplier<? extends T> supplier) {
if (!supplierCalled) {
T result = supplier.get();
supplierCalled = true;
supplierResult = result;
return Optional.ofNullable(result);
} else {
return Optional.ofNullable(supplierResult);
}
private final boolean useCaching;

public WorkflowThreadLocalInternal() {
this(false);
}

public WorkflowThreadLocalInternal(boolean useCaching) {
this.useCaching = useCaching;
}

public T get(Supplier<? extends T> supplier) {
Optional<Optional<T>> result =
DeterministicRunnerImpl.currentThreadInternal().getThreadLocal(this);
return result.orElseGet(() -> invokeSupplier(supplier)).orElse(null);
T out = result.orElseGet(() -> Optional.ofNullable(supplier.get())).orElse(null);
if (!result.isPresent() && useCaching) {
// This is the first time we've tried fetching this, and caching is enabled. Store it.
set(out);
}
return out;
}

public void set(T value) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
* <pre>{@code
* public class Workflow {
*
* private static final WorkflowLocal<Boolean> signaled = WorkflowLocal.withInitial(() -> false);
* private static final WorkflowLocal<Boolean> signaled = WorkflowLocal.withCachedInitial(() -> false);
*
* public static boolean isSignaled() {
* return signaled.get();
Expand All @@ -49,19 +49,51 @@
*/
public final class WorkflowLocal<T> {

private final RunnerLocalInternal<T> impl = new RunnerLocalInternal<>();
private final RunnerLocalInternal<T> impl;
private final Supplier<? extends T> supplier;

private WorkflowLocal(Supplier<? extends T> supplier) {
private WorkflowLocal(Supplier<? extends T> supplier, boolean useCaching) {
this.supplier = Objects.requireNonNull(supplier);
this.impl = new RunnerLocalInternal<>(useCaching);
}

public WorkflowLocal() {
this.supplier = () -> null;
this.impl = new RunnerLocalInternal<>(false);
}

/**
* Create an instance that returns the value returned by the given {@code Supplier} when {@link
* #set(S)} has not yet been called in the Workflow. Note that the value returned by the {@code
* Supplier} is not stored in the {@code WorkflowLocal} implicitly; repeatedly calling {@link
* #get()} will always re-execute the {@code Supplier} until you call {@link #set(S)} for the
* first time. If you want the value returned by the {@code Supplier} to be stored in the {@code
* WorkflowLocal}, use {@link #withCachedInitial(Supplier)} instead.
*
* @param supplier Callback that will be executed whenever {@link #get()} is called, until {@link
* #set(S)} is called for the first time.
* @return A {@code WorkflowLocal} instance.
* @param <S> The type stored in the {@code WorkflowLocal}.
* @deprecated Because the non-caching behavior of this API is typically not desirable, it's
* recommend to use {@link #withCachedInitial(Supplier)} instead.
*/
@Deprecated
public static <S> WorkflowLocal<S> withInitial(Supplier<? extends S> supplier) {
return new WorkflowLocal<>(supplier);
return new WorkflowLocal<>(supplier, false);
}

/**
* Create an instance that returns the value returned by the given {@code Supplier} when {@link
* #set(S)} has not yet been called in the Workflow, and then stores the returned value inside the
* {@code WorkflowLocal}.
*
* @param supplier Callback that will be executed when {@link #get()} is called for the first
* time, if {@link #set(S)} has not already been called.
* @return A {@code WorkflowLocal} instance.
* @param <S> The type stored in the {@code WorkflowLocal}.
*/
public static <S> WorkflowLocal<S> withCachedInitial(Supplier<? extends S> supplier) {
return new WorkflowLocal<>(supplier, true);
}

public T get() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,19 +27,52 @@
/** {@link ThreadLocal} analog for workflow code. */
public final class WorkflowThreadLocal<T> {

private final WorkflowThreadLocalInternal<T> impl = new WorkflowThreadLocalInternal<>();
private final WorkflowThreadLocalInternal<T> impl;
private final Supplier<? extends T> supplier;

private WorkflowThreadLocal(Supplier<? extends T> supplier) {
private WorkflowThreadLocal(Supplier<? extends T> supplier, boolean useCaching) {
this.supplier = Objects.requireNonNull(supplier);
this.impl = new WorkflowThreadLocalInternal<>(useCaching);
}

public WorkflowThreadLocal() {
this.supplier = () -> null;
this.impl = new WorkflowThreadLocalInternal<>(false);
}

/**
* Create an instance that returns the value returned by the given {@code Supplier} when {@link
* #set(S)} has not yet been called in the thread. Note that the value returned by the {@code
* Supplier} is not stored in the {@code WorkflowThreadLocal} implicitly; repeatedly calling
* {@link #get()} will always re-execute the {@code Supplier} until you call {@link #set(S)} for
* the first time. This differs from the behavior of {@code ThreadLocal}. If you want the value
* returned by the {@code Supplier} to be stored in the {@code WorkflowThreadLocal}, which matches
* the behavior of {@code ThreadLocal}, use {@link #withCachedInitial(Supplier)} instead.
*
* @param supplier Callback that will be executed whenever {@link #get()} is called, until {@link
* #set(S)} is called for the first time.
* @return A {@code WorkflowThreadLocal} instance.
* @param <S> The type stored in the {@code WorkflowThreadLocal}.
* @deprecated Because the non-caching behavior of this API is typically not desirable, it's
* recommend to use {@link #withCachedInitial(Supplier)} instead.
*/
@Deprecated
public static <S> WorkflowThreadLocal<S> withInitial(Supplier<? extends S> supplier) {
return new WorkflowThreadLocal<>(supplier);
return new WorkflowThreadLocal<>(supplier, false);
}

/**
* Create an instance that returns the value returned by the given {@code Supplier} when {@link
* #set(S)} has not yet been called in the Workflow, and then stores the returned value inside the
* {@code WorkflowThreadLocal}.
*
* @param supplier Callback that will be executed when {@link #get()} is called for the first
* time, if {@link #set(S)} has not already been called.
* @return A {@code WorkflowThreadLocal} instance.
* @param <S> The type stored in the {@code WorkflowThreadLocal}.
*/
public static <S> WorkflowThreadLocal<S> withCachedInitial(Supplier<? extends S> supplier) {
return new WorkflowThreadLocal<>(supplier, true);
}

public T get() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -945,14 +945,14 @@ private static Supplier<String> getStringSupplier(AtomicInteger supplierCalls) {
}

@Test
public void testSupplierCalledOnce() {
public void testSupplierCalledOnceWithCaching() {
AtomicInteger supplierCalls = new AtomicInteger();
DeterministicRunnerImpl d =
new DeterministicRunnerImpl(
threadPool::submit,
DummySyncWorkflowContext.newDummySyncWorkflowContext(),
() -> {
RunnerLocalInternal<String> runnerLocalInternal = new RunnerLocalInternal<>();
RunnerLocalInternal<String> runnerLocalInternal = new RunnerLocalInternal<>(true);
runnerLocalInternal.get(getStringSupplier(supplierCalls));
runnerLocalInternal.get(getStringSupplier(supplierCalls));
runnerLocalInternal.get(getStringSupplier(supplierCalls));
Expand All @@ -963,4 +963,24 @@ public void testSupplierCalledOnce() {
});
d.runUntilAllBlocked(DeterministicRunner.DEFAULT_DEADLOCK_DETECTION_TIMEOUT_MS);
}

@Test
public void testSupplierCalledMultipleWithoutCaching() {
AtomicInteger supplierCalls = new AtomicInteger();
DeterministicRunnerImpl d =
new DeterministicRunnerImpl(
threadPool::submit,
DummySyncWorkflowContext.newDummySyncWorkflowContext(),
() -> {
RunnerLocalInternal<String> runnerLocalInternal = new RunnerLocalInternal<>(false);
runnerLocalInternal.get(getStringSupplier(supplierCalls));
runnerLocalInternal.get(getStringSupplier(supplierCalls));
runnerLocalInternal.get(getStringSupplier(supplierCalls));
assertEquals(
"supplier default value",
runnerLocalInternal.get(getStringSupplier(supplierCalls)));
assertEquals(4, supplierCalls.get());
});
d.runUntilAllBlocked(DeterministicRunner.DEFAULT_DEADLOCK_DETECTION_TIMEOUT_MS);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,13 @@
package io.temporal.workflow;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotSame;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertSame;

import io.temporal.testing.internal.SDKTestWorkflowRule;
import io.temporal.workflow.shared.TestWorkflows.TestWorkflow1;
import io.temporal.workflow.shared.TestWorkflows.TestWorkflowReturnString;
import java.time.Duration;
import java.util.concurrent.atomic.AtomicInteger;
import org.junit.Assert;
Expand All @@ -47,9 +50,11 @@ public void testWorkflowLocals() {

public static class TestWorkflowLocals implements TestWorkflow1 {

@SuppressWarnings("deprecation")
private final WorkflowThreadLocal<Integer> threadLocal =
WorkflowThreadLocal.withInitial(() -> 2);

@SuppressWarnings("deprecation")
private final WorkflowLocal<Integer> workflowLocal = WorkflowLocal.withInitial(() -> 5);

@Override
Expand Down Expand Up @@ -84,12 +89,15 @@ public static class TestWorkflowLocalsSupplierReuse implements TestWorkflow1 {
private final AtomicInteger localCalls = new AtomicInteger(0);
private final AtomicInteger threadLocalCalls = new AtomicInteger(0);

@SuppressWarnings("deprecation")
private final WorkflowThreadLocal<Integer> workflowThreadLocal =
WorkflowThreadLocal.withInitial(
() -> {
threadLocalCalls.addAndGet(1);
return null;
});

@SuppressWarnings("deprecation")
private final WorkflowLocal<Integer> workflowLocal =
WorkflowLocal.withInitial(
() -> {
Expand Down Expand Up @@ -131,4 +139,93 @@ public void testWorkflowLocalsSupplierReuse() {
String result = workflowStub.execute(testWorkflowRule.getTaskQueue());
Assert.assertEquals("ok", result);
}

@SuppressWarnings("deprecation")
static final WorkflowThreadLocal<AtomicInteger> threadLocal =
WorkflowThreadLocal.withInitial(() -> new AtomicInteger(2));

@SuppressWarnings("deprecation")
static final WorkflowLocal<AtomicInteger> workflowLocal =
WorkflowLocal.withInitial(() -> new AtomicInteger(5));

static final WorkflowThreadLocal<AtomicInteger> threadLocalCached =
WorkflowThreadLocal.withCachedInitial(() -> new AtomicInteger(2));

static final WorkflowLocal<AtomicInteger> workflowLocalCached =
WorkflowLocal.withCachedInitial(() -> new AtomicInteger(5));

public static class TestInit implements TestWorkflowReturnString {

@Override
public String execute() {
assertEquals(2, threadLocal.get().getAndSet(3));
assertEquals(5, workflowLocal.get().getAndSet(6));
assertEquals(2, threadLocalCached.get().getAndSet(3));
assertEquals(5, workflowLocalCached.get().getAndSet(6));
String out = Workflow.newChildWorkflowStub(TestWorkflow1.class).execute("ign");
assertEquals("ok", out);
return "result="
+ threadLocal.get().get()
+ ", "
+ workflowLocal.get().get()
+ ", "
+ threadLocalCached.get().get()
+ ", "
+ workflowLocalCached.get().get();
}
}

public static class TestChildInit implements TestWorkflow1 {

@Override
public String execute(String arg1) {
assertEquals(2, threadLocal.get().getAndSet(8));
assertEquals(5, workflowLocal.get().getAndSet(0));
return "ok";
}
}

@Rule
public SDKTestWorkflowRule testWorkflowRuleInitialValueNotShared =
SDKTestWorkflowRule.newBuilder()
.setWorkflowTypes(TestInit.class, TestChildInit.class)
.build();

@Test
public void testWorkflowInitialNotShared() {
TestWorkflowReturnString workflowStub =
testWorkflowRuleInitialValueNotShared.newWorkflowStubTimeoutOptions(
TestWorkflowReturnString.class);
String result = workflowStub.execute();
Assert.assertEquals("result=2, 5, 3, 6", result);
}

public static class TestCaching implements TestWorkflow1 {

@Override
public String execute(String arg1) {
assertNotSame(threadLocal.get(), threadLocal.get());
assertNotSame(workflowLocal.get(), workflowLocal.get());
threadLocal.set(threadLocal.get());
workflowLocal.set(workflowLocal.get());
assertSame(threadLocal.get(), threadLocal.get());
assertSame(workflowLocal.get(), workflowLocal.get());

assertSame(threadLocalCached.get(), threadLocalCached.get());
assertSame(workflowLocalCached.get(), workflowLocalCached.get());
return "ok";
}
}

@Rule
public SDKTestWorkflowRule testWorkflowRuleCaching =
SDKTestWorkflowRule.newBuilder().setWorkflowTypes(TestCaching.class).build();

@Test
public void testWorkflowLocalCaching() {
TestWorkflow1 workflowStub =
testWorkflowRuleCaching.newWorkflowStubTimeoutOptions(TestWorkflow1.class);
String out = workflowStub.execute("ign");
assertEquals("ok", out);
}
}
Loading