From a1fce092e0ade141208f4bd36bbb9c62664ea7f9 Mon Sep 17 00:00:00 2001 From: Michel Davit Date: Fri, 22 Nov 2024 10:28:44 +0100 Subject: [PATCH] Refactor partitionByKey to get clearer error (#5527) --- .../scala/com/spotify/scio/values/SCollection.scala | 7 ++----- .../com/spotify/scio/values/SCollectionTest.scala | 11 +++++++++++ 2 files changed, 13 insertions(+), 5 deletions(-) diff --git a/scio-core/src/main/scala/com/spotify/scio/values/SCollection.scala b/scio-core/src/main/scala/com/spotify/scio/values/SCollection.scala index 245818eed9..c3fa33cf9a 100644 --- a/scio-core/src/main/scala/com/spotify/scio/values/SCollection.scala +++ b/scio-core/src/main/scala/com/spotify/scio/values/SCollection.scala @@ -382,11 +382,8 @@ sealed trait SCollection[T] extends PCollectionWrapper[T] { * @group collection */ def partitionByKey[U](partitionKeys: Set[U])(f: T => U): Map[U, SCollection[T]] = { - val partitionKeysIndexed = partitionKeys.toIndexedSeq - - partitionKeysIndexed - .zip(partition(partitionKeys.size, (t: T) => partitionKeysIndexed.indexOf(f(t)))) - .toMap + val partitions = partitionKeys.zipWithIndex.toMap + partitionKeys.zip(partition(partitionKeys.size, x => partitions(f(x)))).toMap } /** diff --git a/scio-core/src/test/scala/com/spotify/scio/values/SCollectionTest.scala b/scio-core/src/test/scala/com/spotify/scio/values/SCollectionTest.scala index 3f1c082869..25664539f2 100644 --- a/scio-core/src/test/scala/com/spotify/scio/values/SCollectionTest.scala +++ b/scio-core/src/test/scala/com/spotify/scio/values/SCollectionTest.scala @@ -40,6 +40,7 @@ import scala.jdk.CollectionConverters._ import com.spotify.scio.coders.{Beam, Coder, MaterializedCoder} import com.spotify.scio.options.ScioOptions import com.spotify.scio.schemas.Schema +import org.apache.beam.sdk.Pipeline.PipelineExecutionException import org.apache.beam.sdk.coders.{NullableCoder, StringUtf8Coder} import java.nio.charset.StandardCharsets @@ -249,6 +250,16 @@ class SCollectionTest extends PipelineSpec { m("b") should containInAnyOrder(Seq("b4", "b5")) m("c") should containInAnyOrder(Seq("c6")) } + + val e = the[PipelineExecutionException] thrownBy { + runWithContext { sc => + sc + .parallelize(Seq("x")) + .partitionByKey(Set("a", "b", "c"))(_.substring(0, 1)) + } + } + e.getCause shouldBe a[NoSuchElementException] + e.getCause.getMessage shouldBe "key not found: x" } it should "support hashPartition() based on Object.hashCode()" in {