diff --git a/runners/core-java/src/main/java/org/apache/beam/runners/core/metrics/StringSetData.java b/runners/core-java/src/main/java/org/apache/beam/runners/core/metrics/StringSetData.java index 4fc5d3beca31..5f9bb6392ec2 100644 --- a/runners/core-java/src/main/java/org/apache/beam/runners/core/metrics/StringSetData.java +++ b/runners/core-java/src/main/java/org/apache/beam/runners/core/metrics/StringSetData.java @@ -20,8 +20,8 @@ import com.google.auto.value.AutoValue; import java.io.Serializable; import java.util.Arrays; -import java.util.HashSet; import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; import org.apache.beam.sdk.metrics.StringSetResult; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet; @@ -54,13 +54,13 @@ public static StringSetData create(Set set) { if (set.isEmpty()) { return empty(); } - HashSet combined = new HashSet<>(); + Set combined = ConcurrentHashMap.newKeySet(); long stringSize = addUntilCapacity(combined, 0L, set); return new AutoValue_StringSetData(combined, stringSize); } /** Returns a {@link StringSetData} which is made from the given set in place. */ - private static StringSetData createInPlace(HashSet set, long stringSize) { + private static StringSetData createInPlace(Set set, long stringSize) { return new AutoValue_StringSetData(set, stringSize); } @@ -76,11 +76,12 @@ public static StringSetData empty() { *

>Should only be used by {@link StringSetCell#add}. */ public StringSetData addAll(String... strings) { - HashSet combined; - if (this.stringSet() instanceof HashSet) { - combined = (HashSet) this.stringSet(); + Set combined; + if (this.stringSet() instanceof ConcurrentHashMap.KeySetView) { + combined = this.stringSet(); } else { - combined = new HashSet<>(this.stringSet()); + combined = ConcurrentHashMap.newKeySet(); + combined.addAll(this.stringSet()); } long stringSize = addUntilCapacity(combined, this.stringSize(), Arrays.asList(strings)); return StringSetData.createInPlace(combined, stringSize); @@ -95,7 +96,8 @@ public StringSetData combine(StringSetData other) { } else if (other.stringSet().isEmpty()) { return this; } else { - HashSet combined = new HashSet<>(this.stringSet()); + Set combined = ConcurrentHashMap.newKeySet(); + combined.addAll(this.stringSet()); long stringSize = addUntilCapacity(combined, this.stringSize(), other.stringSet()); return StringSetData.createInPlace(combined, stringSize); } @@ -105,7 +107,8 @@ public StringSetData combine(StringSetData other) { * Combines this {@link StringSetData} with others, all original StringSetData are left intact. */ public StringSetData combine(Iterable others) { - HashSet combined = new HashSet<>(this.stringSet()); + Set combined = ConcurrentHashMap.newKeySet(); + combined.addAll(this.stringSet()); long stringSize = this.stringSize(); for (StringSetData other : others) { stringSize = addUntilCapacity(combined, stringSize, other.stringSet()); @@ -120,7 +123,7 @@ public StringSetResult extractResult() { /** Add strings into set until reach capacity. Return the all string size of added set. */ private static long addUntilCapacity( - HashSet combined, long currentSize, Iterable others) { + Set combined, long currentSize, Iterable others) { if (currentSize > STRING_SET_SIZE_LIMIT) { // already at capacity return currentSize; diff --git a/runners/core-java/src/test/java/org/apache/beam/runners/core/metrics/StringSetCellTest.java b/runners/core-java/src/test/java/org/apache/beam/runners/core/metrics/StringSetCellTest.java index f78ed01603fb..9497bbe43d0e 100644 --- a/runners/core-java/src/test/java/org/apache/beam/runners/core/metrics/StringSetCellTest.java +++ b/runners/core-java/src/test/java/org/apache/beam/runners/core/metrics/StringSetCellTest.java @@ -20,7 +20,13 @@ import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.equalTo; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; +import java.time.Instant; +import java.util.ArrayList; +import java.util.List; +import java.util.Set; +import java.util.concurrent.atomic.AtomicBoolean; import org.apache.beam.sdk.metrics.MetricName; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet; import org.junit.Assert; @@ -94,4 +100,42 @@ public void testReset() { assertThat(stringSetCell.getCumulative(), equalTo(StringSetData.empty())); assertThat(stringSetCell.getDirty(), equalTo(new DirtyState())); } + + @Test(timeout = 5000) + public void testStringSetCellConcurrentAddRetrieval() throws InterruptedException { + StringSetCell cell = new StringSetCell(MetricName.named("namespace", "name")); + AtomicBoolean finished = new AtomicBoolean(false); + Thread increment = + new Thread( + () -> { + for (long i = 0; !finished.get(); ++i) { + cell.add(String.valueOf(i)); + try { + Thread.sleep(1); + } catch (InterruptedException e) { + break; + } + } + }); + increment.start(); + Instant start = Instant.now(); + try { + while (true) { + Set s = cell.getCumulative().stringSet(); + List snapshot = new ArrayList<>(s); + if (Instant.now().isAfter(start.plusSeconds(3)) && snapshot.size() > 0) { + finished.compareAndSet(false, true); + break; + } + } + } finally { + increment.interrupt(); + increment.join(); + } + + Set s = cell.getCumulative().stringSet(); + for (long i = 0; i < s.size(); ++i) { + assertTrue(s.contains(String.valueOf(i))); + } + } }