From e46d19ca57e27167185498b9e920309fc9e261c7 Mon Sep 17 00:00:00 2001
From: Khor Shu Heng
Date: Wed, 18 Oct 2023 11:18:31 +0800
Subject: [PATCH] feat: Add support for local rate limiter

---
 .gitignore                                     |  1 +
 caraml-store-spark/build.gradle                |  1 +
 .../scala/dev/caraml/spark/BasePipeline.scala  |  2 ++
 .../dev/caraml/spark/IngestionJobConfig.scala  |  4 +++-
 .../stores/redis/RedisSinkRelation.scala       | 21 ++++++++++++++++++-
 5 files changed, 27 insertions(+), 2 deletions(-)

diff --git a/.gitignore b/.gitignore
index 6700b0d..00e4267 100644
--- a/.gitignore
+++ b/.gitignore
@@ -46,6 +46,7 @@ __pycache__/
 .mypy_cache/
 .pytest_cache/
 .eggs/
+*.egg-info/
 
 # Environments
 .env
diff --git a/caraml-store-spark/build.gradle b/caraml-store-spark/build.gradle
index fe70d4b..703bf7a 100644
--- a/caraml-store-spark/build.gradle
+++ b/caraml-store-spark/build.gradle
@@ -51,6 +51,7 @@ dependencies {
        }
    }
    implementation "org.json4s:json4s-ext_$scalaVersion:3.7.0-M5"
+   implementation 'com.bucket4j:bucket4j-core:8.5.0'
    testImplementation "org.scalatest:scalatest_$scalaVersion:3.2.2"
    testImplementation "org.scalacheck:scalacheck_$scalaVersion:1.14.3"
    testImplementation "com.dimafeng:testcontainers-scala-scalatest_$scalaVersion:0.40.12"
diff --git a/caraml-store-spark/src/main/scala/dev/caraml/spark/BasePipeline.scala b/caraml-store-spark/src/main/scala/dev/caraml/spark/BasePipeline.scala
index 49579f4..cb854aa 100644
--- a/caraml-store-spark/src/main/scala/dev/caraml/spark/BasePipeline.scala
+++ b/caraml-store-spark/src/main/scala/dev/caraml/spark/BasePipeline.scala
@@ -27,6 +27,8 @@ object BasePipeline {
          .set("spark.redis.ssl", ssl.toString)
          .set("spark.redis.properties.maxJitter", properties.maxJitterSeconds.toString)
          .set("spark.redis.properties.pipelineSize", properties.pipelineSize.toString)
+         .set("spark.redis.properties.enableRateLimit", properties.enableRateLimit.toString)
+         .set("spark.redis.properties.ratePerSecondLimit", properties.ratePerSecondLimit.toString)
      case BigTableConfig(projectId, instanceId) =>
        conf
          .set("spark.bigtable.projectId", projectId)
diff --git a/caraml-store-spark/src/main/scala/dev/caraml/spark/IngestionJobConfig.scala b/caraml-store-spark/src/main/scala/dev/caraml/spark/IngestionJobConfig.scala
index 1c2ff69..a13524d 100644
--- a/caraml-store-spark/src/main/scala/dev/caraml/spark/IngestionJobConfig.scala
+++ b/caraml-store-spark/src/main/scala/dev/caraml/spark/IngestionJobConfig.scala
@@ -22,7 +22,9 @@ case class RedisConfig(
 case class RedisWriteProperties(
     maxJitterSeconds: Int = 3600,
     pipelineSize: Int = 250,
-    ttlSeconds: Long = 0L
+    ttlSeconds: Long = 0L,
+    enableRateLimit: Boolean = false,
+    ratePerSecondLimit: Int = 50000
 )
 
 case class BigTableConfig(projectId: String, instanceId: String) extends StoreConfig
diff --git a/caraml-store-spark/src/main/scala/dev/caraml/spark/stores/redis/RedisSinkRelation.scala b/caraml-store-spark/src/main/scala/dev/caraml/spark/stores/redis/RedisSinkRelation.scala
index 0b565e0..a28c69e 100644
--- a/caraml-store-spark/src/main/scala/dev/caraml/spark/stores/redis/RedisSinkRelation.scala
+++ b/caraml-store-spark/src/main/scala/dev/caraml/spark/stores/redis/RedisSinkRelation.scala
@@ -2,6 +2,7 @@ package dev.caraml.spark.stores.redis
 
 import dev.caraml.spark.RedisWriteProperties
 import dev.caraml.spark.utils.TypeConversion
+import io.github.bucket4j.{Bandwidth, Bucket}
 import org.apache.commons.codec.digest.DigestUtils
 import org.apache.spark.metrics.source.RedisSinkMetricSource
 import org.apache.spark.sql.sources.{BaseRelation, InsertableRelation}
@@ -9,6 +10,8 @@ import org.apache.spark.sql.types.StructType
 import org.apache.spark.sql.{DataFrame, Row, SQLContext}
 import org.apache.spark.{SparkConf, SparkEnv}
 
+import java.time.Duration.ofSeconds
+
 /**
   * High-level writer to Redis. Relies on `Persistence` implementation for actual storage layout.
   * Here we define general flow:
@@ -40,9 +43,22 @@ class RedisSinkRelation(override val sqlContext: SQLContext, config: SparkRedisC
   lazy val properties: RedisWriteProperties = RedisWriteProperties(
     maxJitterSeconds = sparkConf.get("spark.redis.properties.maxJitter").toInt,
     pipelineSize = sparkConf.get("spark.redis.properties.pipelineSize").toInt,
-    ttlSeconds = config.entityMaxAge
+    ttlSeconds = config.entityMaxAge,
+    enableRateLimit = sparkConf.get("spark.redis.properties.enableRateLimit").toBoolean,
+    ratePerSecondLimit = sparkConf.get("spark.redis.properties.ratePerSecondLimit").toInt
   )
 
+  lazy private val rateLimitBucket: Bucket = Bucket
+    .builder()
+    .addLimit(
+      Bandwidth
+        .builder()
+        .capacity(properties.ratePerSecondLimit)
+        .refillIntervally(properties.ratePerSecondLimit, ofSeconds(1))
+        .build()
+    )
+    .build()
+
   override def insert(data: DataFrame, overwrite: Boolean): Unit = {
     data.foreachPartition { partition: Iterator[Row] =>
       java.security.Security.setProperty("networkaddress.cache.ttl", "3");
@@ -52,6 +68,9 @@ class RedisSinkRelation(override val sqlContext: SQLContext, config: SparkRedisC
 
       // grouped iterator to only allocate memory for a portion of rows
       partition.grouped(properties.pipelineSize).foreach { batch =>
+        if (properties.enableRateLimit) {
+          rateLimitBucket.asBlocking().consume(batch.length)
+        }
         val rowsWithKey: Seq[(String, Row)] = batch.map(row => dataKeyId(row) -> row)
         pipelineProvider.withPipeline(pipeline => {