diff --git a/.gitignore b/.gitignore index 6700b0d..00e4267 100644 --- a/.gitignore +++ b/.gitignore @@ -46,6 +46,7 @@ __pycache__/ .mypy_cache/ .pytest_cache/ .eggs/ +*.egg-info/ # Environments .env diff --git a/caraml-store-spark/build.gradle b/caraml-store-spark/build.gradle index fe70d4b..703bf7a 100644 --- a/caraml-store-spark/build.gradle +++ b/caraml-store-spark/build.gradle @@ -51,6 +51,7 @@ dependencies { } } implementation "org.json4s:json4s-ext_$scalaVersion:3.7.0-M5" + implementation 'com.bucket4j:bucket4j-core:8.5.0' testImplementation "org.scalatest:scalatest_$scalaVersion:3.2.2" testImplementation "org.scalacheck:scalacheck_$scalaVersion:1.14.3" testImplementation "com.dimafeng:testcontainers-scala-scalatest_$scalaVersion:0.40.12" diff --git a/caraml-store-spark/src/main/scala/dev/caraml/spark/BasePipeline.scala b/caraml-store-spark/src/main/scala/dev/caraml/spark/BasePipeline.scala index 49579f4..cb854aa 100644 --- a/caraml-store-spark/src/main/scala/dev/caraml/spark/BasePipeline.scala +++ b/caraml-store-spark/src/main/scala/dev/caraml/spark/BasePipeline.scala @@ -27,6 +27,8 @@ object BasePipeline { .set("spark.redis.ssl", ssl.toString) .set("spark.redis.properties.maxJitter", properties.maxJitterSeconds.toString) .set("spark.redis.properties.pipelineSize", properties.pipelineSize.toString) + .set("spark.redis.properties.enableRateLimit", properties.enableRateLimit.toString) + .set("spark.redis.properties.ratePerSecondLimit", properties.ratePerSecondLimit.toString) case BigTableConfig(projectId, instanceId) => conf .set("spark.bigtable.projectId", projectId) diff --git a/caraml-store-spark/src/main/scala/dev/caraml/spark/IngestionJobConfig.scala b/caraml-store-spark/src/main/scala/dev/caraml/spark/IngestionJobConfig.scala index 1c2ff69..a13524d 100644 --- a/caraml-store-spark/src/main/scala/dev/caraml/spark/IngestionJobConfig.scala +++ b/caraml-store-spark/src/main/scala/dev/caraml/spark/IngestionJobConfig.scala @@ -22,7 +22,9 @@ case class RedisConfig( case class RedisWriteProperties( maxJitterSeconds: Int = 3600, pipelineSize: Int = 250, - ttlSeconds: Long = 0L + ttlSeconds: Long = 0L, + enableRateLimit: Boolean = false, + ratePerSecondLimit: Int = 50000 ) case class BigTableConfig(projectId: String, instanceId: String) extends StoreConfig diff --git a/caraml-store-spark/src/main/scala/dev/caraml/spark/stores/redis/RedisSinkRelation.scala b/caraml-store-spark/src/main/scala/dev/caraml/spark/stores/redis/RedisSinkRelation.scala index 0b565e0..a28c69e 100644 --- a/caraml-store-spark/src/main/scala/dev/caraml/spark/stores/redis/RedisSinkRelation.scala +++ b/caraml-store-spark/src/main/scala/dev/caraml/spark/stores/redis/RedisSinkRelation.scala @@ -2,6 +2,7 @@ package dev.caraml.spark.stores.redis import dev.caraml.spark.RedisWriteProperties import dev.caraml.spark.utils.TypeConversion +import io.github.bucket4j.{Bandwidth, Bucket} import org.apache.commons.codec.digest.DigestUtils import org.apache.spark.metrics.source.RedisSinkMetricSource import org.apache.spark.sql.sources.{BaseRelation, InsertableRelation} @@ -9,6 +10,8 @@ import org.apache.spark.sql.types.StructType import org.apache.spark.sql.{DataFrame, Row, SQLContext} import org.apache.spark.{SparkConf, SparkEnv} +import java.time.Duration.ofSeconds + /** * High-level writer to Redis. Relies on `Persistence` implementation for actual storage layout. * Here we define general flow: @@ -40,9 +43,22 @@ class RedisSinkRelation(override val sqlContext: SQLContext, config: SparkRedisC lazy val properties: RedisWriteProperties = RedisWriteProperties( maxJitterSeconds = sparkConf.get("spark.redis.properties.maxJitter").toInt, pipelineSize = sparkConf.get("spark.redis.properties.pipelineSize").toInt, - ttlSeconds = config.entityMaxAge + ttlSeconds = config.entityMaxAge, + enableRateLimit = sparkConf.get("spark.redis.properties.enableRateLimit").toBoolean, + ratePerSecondLimit = sparkConf.get("spark.redis.properties.ratePerSecondLimit").toInt ) + lazy private val rateLimitBucket: Bucket = Bucket + .builder() + .addLimit( + Bandwidth + .builder() + .capacity(properties.ratePerSecondLimit) + .refillIntervally(properties.ratePerSecondLimit, ofSeconds(1)) + .build() + ) + .build() + override def insert(data: DataFrame, overwrite: Boolean): Unit = { data.foreachPartition { partition: Iterator[Row] => java.security.Security.setProperty("networkaddress.cache.ttl", "3"); @@ -52,6 +68,9 @@ class RedisSinkRelation(override val sqlContext: SQLContext, config: SparkRedisC // grouped iterator to only allocate memory for a portion of rows partition.grouped(properties.pipelineSize).foreach { batch => + if (properties.enableRateLimit) { + rateLimitBucket.asBlocking().consume(batch.length) + } val rowsWithKey: Seq[(String, Row)] = batch.map(row => dataKeyId(row) -> row) pipelineProvider.withPipeline(pipeline => { diff --git a/caraml-store-spark/src/test/scala/dev/caraml/spark/BatchPipelineSpec.scala b/caraml-store-spark/src/test/scala/dev/caraml/spark/BatchPipelineSpec.scala index f87992d..1fcf1ec 100644 --- a/caraml-store-spark/src/test/scala/dev/caraml/spark/BatchPipelineSpec.scala +++ b/caraml-store-spark/src/test/scala/dev/caraml/spark/BatchPipelineSpec.scala @@ -45,6 +45,8 @@ class BatchPipelineSpec extends SparkSpec with ForAllTestContainer { .set("spark.metrics.conf.*.sink.statsd.port", statsDStub.port.toString) .set("spark.redis.properties.maxJitter", "0") .set("spark.redis.properties.pipelineSize", "250") + .set("spark.redis.properties.enableRateLimit", "false") + .set("spark.redis.properties.ratePerSecondLimit", "50000") trait Scope { val jedis = new Jedis("localhost", container.mappedPort(6379)) diff --git a/caraml-store-spark/src/test/scala/dev/caraml/spark/StreamingPipelineSpec.scala b/caraml-store-spark/src/test/scala/dev/caraml/spark/StreamingPipelineSpec.scala index b46f132..717c7ff 100644 --- a/caraml-store-spark/src/test/scala/dev/caraml/spark/StreamingPipelineSpec.scala +++ b/caraml-store-spark/src/test/scala/dev/caraml/spark/StreamingPipelineSpec.scala @@ -43,6 +43,8 @@ class StreamingPipelineSpec extends SparkSpec with ForAllTestContainer { .set("spark.sql.streaming.checkpointLocation", generateTempPath("checkpoint")) .set("spark.redis.properties.maxJitter", "0") .set("spark.redis.properties.pipelineSize", "250") + .set("spark.redis.properties.enableRateLimit", "false") + .set("spark.redis.properties.ratePerSecondLimit", "50000") trait KafkaPublisher { val props = new Properties()