diff --git a/cluster-sharding-typed/src/main/scala/org/apache/pekko/cluster/sharding/typed/internal/ClusterShardingImpl.scala b/cluster-sharding-typed/src/main/scala/org/apache/pekko/cluster/sharding/typed/internal/ClusterShardingImpl.scala index 525718d2277..026611912c7 100644 --- a/cluster-sharding-typed/src/main/scala/org/apache/pekko/cluster/sharding/typed/internal/ClusterShardingImpl.scala +++ b/cluster-sharding-typed/src/main/scala/org/apache/pekko/cluster/sharding/typed/internal/ClusterShardingImpl.scala @@ -20,7 +20,6 @@ import java.util.concurrent.CompletionStage import java.util.concurrent.ConcurrentHashMap import scala.concurrent.Future -import scala.runtime.AbstractPartialFunction import org.apache.pekko import pekko.actor.ActorRefProvider @@ -173,16 +172,10 @@ import pekko.util.JavaDurationConverters._ allocationStrategy: Option[ShardAllocationStrategy]): ActorRef[E] = { val extractorAdapter = new ExtractorAdapter(extractor) - // !!!important is only applicable if you know that isDefinedAt(x) is always called before apply(x) (with the same x) - val extractEntityId: ShardRegion.ExtractEntityId = new AbstractPartialFunction[Any, (String, Any)] { - var cache: String = _ - - override def isDefinedAt(msg: Any): Boolean = { - cache = extractorAdapter.entityId(msg) - cache != null - } - - override def apply(x: Any): (String, Any) = (cache, extractorAdapter.unwrapMessage(x)) + val extractEntityId: ShardRegion.ExtractEntityId = { + // TODO is it possible to avoid the double evaluation of entityId + case message if extractorAdapter.entityId(message) != null => + (extractorAdapter.entityId(message), extractorAdapter.unwrapMessage(message)) } val extractShardId: ShardRegion.ExtractShardId = { message => extractorAdapter.entityId(message) match { diff --git a/cluster-sharding/src/main/scala/org/apache/pekko/cluster/sharding/ClusterSharding.scala b/cluster-sharding/src/main/scala/org/apache/pekko/cluster/sharding/ClusterSharding.scala index 835d2fcfbb2..9278839da3d 100755 --- a/cluster-sharding/src/main/scala/org/apache/pekko/cluster/sharding/ClusterSharding.scala +++ b/cluster-sharding/src/main/scala/org/apache/pekko/cluster/sharding/ClusterSharding.scala @@ -19,7 +19,6 @@ import java.util.concurrent.ConcurrentHashMap import scala.collection.immutable import scala.concurrent.Await -import scala.runtime.AbstractPartialFunction import scala.util.control.NonFatal import org.apache.pekko @@ -430,26 +429,15 @@ class ClusterSharding(system: ExtendedActorSystem) extends Extension { typeName, _ => entityProps, settings, - extractEntityId = extractEntityIdFromExtractor(messageExtractor), + extractEntityId = { + case msg if messageExtractor.entityId(msg) ne null => + (messageExtractor.entityId(msg), messageExtractor.entityMessage(msg)) + }, extractShardId = msg => messageExtractor.shardId(msg), allocationStrategy = allocationStrategy, handOffStopMessage = handOffStopMessage) } - // !!!important is only applicable if you know that isDefinedAt(x) is always called before apply(x) (with the same x) - private def extractEntityIdFromExtractor( - messageExtractor: ShardRegion.MessageExtractor): ShardRegion.ExtractEntityId = - new AbstractPartialFunction[Any, (String, Any)] { - var cache: String = _ - - override def isDefinedAt(msg: Any): Boolean = { - cache = messageExtractor.entityId(msg) - cache != null - } - - override def apply(x: Any): (String, Any) = (cache, messageExtractor.entityMessage(x)) - } - /** * Java/Scala API: Register a named entity type by defining the [[pekko.actor.Props]] of the entity actor * and functions to extract entity and shard identifier from messages. The [[ShardRegion]] actor @@ -624,12 +612,11 @@ class ClusterSharding(system: ExtendedActorSystem) extends Extension { dataCenter: Optional[String], messageExtractor: ShardRegion.MessageExtractor): ActorRef = { - startProxy( - typeName, - Option(role.orElse(null)), - Option(dataCenter.orElse(null)), - extractEntityId = extractEntityIdFromExtractor(messageExtractor), - msg => messageExtractor.shardId(msg)) + startProxy(typeName, Option(role.orElse(null)), Option(dataCenter.orElse(null)), + extractEntityId = { + case msg if messageExtractor.entityId(msg) ne null => + (messageExtractor.entityId(msg), messageExtractor.entityMessage(msg)) + }, extractShardId = msg => messageExtractor.shardId(msg)) }