Skip to content

Commit

Permalink
Merge branch 'master' into dynamically-updating-rate-limits
Browse files Browse the repository at this point in the history
  • Loading branch information
EdwinHeuver92 committed Dec 17, 2023
1 parent 72bcaa9 commit 7f91cc2
Show file tree
Hide file tree
Showing 11 changed files with 211 additions and 221 deletions.
2 changes: 1 addition & 1 deletion bucket4j-spring-boot-starter/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
<project.reporting.outputEncoding>UTF-8</project.reporting.outputEncoding>
<java.version>17</java.version>
<redisson.version>3.23.5</redisson.version>
<redisson.version>3.24.3</redisson.version>
<ignite-core.version>2.15.0</ignite-core.version>
</properties>
<dependencies>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
import io.github.bucket4j.Bandwidth;
import io.github.bucket4j.BucketConfiguration;
import io.github.bucket4j.ConfigurationBuilder;
import io.github.bucket4j.Refill;
import lombok.extern.slf4j.Slf4j;

/**
Expand Down Expand Up @@ -138,14 +137,14 @@ private ConfigurationBuilder prepareBucket4jConfigurationBuilder(RateLimit rl) {
long refillCapacity = bandWidth.getRefillCapacity() != null ? bandWidth.getRefillCapacity() : bandWidth.getCapacity();
var refillPeriod = Duration.of(bandWidth.getTime(), bandWidth.getUnit());
var bucket4jBandWidth = switch(bandWidth.getRefillSpeed()) {
case GREEDY -> Bandwidth.classic(capacity, Refill.greedy(refillCapacity, refillPeriod)).withId(bandWidth.getId());
case INTERVAL -> Bandwidth.classic(capacity, Refill.intervally(refillCapacity, refillPeriod)).withId(bandWidth.getId());
default -> throw new IllegalStateException("Unsupported Refill type: " + bandWidth.getRefillSpeed());
};
case GREEDY -> Bandwidth.builder().capacity(capacity).refillGreedy(refillCapacity, refillPeriod).id(bandWidth.getId());
case INTERVAL -> Bandwidth.builder().capacity(capacity).refillIntervally(refillCapacity, refillPeriod).id(bandWidth.getId());
};

if(bandWidth.getInitialCapacity() != null) {
bucket4jBandWidth = bucket4jBandWidth.withInitialTokens(bandWidth.getInitialCapacity());
bucket4jBandWidth = bucket4jBandWidth.initialTokens(bandWidth.getInitialCapacity());
}
configBuilder = configBuilder.addLimit(bucket4jBandWidth);
configBuilder = configBuilder.addLimit(bucket4jBandWidth.build());
}
return configBuilder;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@

import static java.nio.charset.StandardCharsets.UTF_8;

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;

import com.giffing.bucket4j.spring.boot.starter.context.ConsumptionProbeHolder;
import com.giffing.bucket4j.spring.boot.starter.context.RateLimitCheck;
import org.springframework.core.io.buffer.DataBuffer;
import org.springframework.http.server.reactive.ServerHttpRequest;
import org.springframework.http.server.reactive.ServerHttpResponse;
Expand All @@ -22,9 +24,9 @@
@Data
@Slf4j
public class AbstractReactiveFilter {

private FilterConfiguration<ServerHttpRequest> filterConfig;

public AbstractReactiveFilter(FilterConfiguration<ServerHttpRequest> filterConfig) {
this.filterConfig = filterConfig;
}
Expand All @@ -36,37 +38,28 @@ public void setFilterConfig(FilterConfiguration<ServerHttpRequest> filterConfig)
protected boolean urlMatches(ServerHttpRequest request) {
return request.getURI().getPath().matches(filterConfig.getUrl());
}

protected Mono<Void> chainWithRateLimitCheck(ServerWebExchange exchange, ReactiveFilterChain chain) {
log.debug("reate-limit-check;method:{};uri:{}", exchange.getRequest().getMethod(), exchange.getRequest().getURI());
ServerHttpRequest request = exchange.getRequest();
ServerHttpResponse response = exchange.getResponse();
List<Mono<ConsumptionProbe>> asyncConsumptionProbes = filterConfig.getRateLimitChecks()
.stream()
.map(rl -> rl.rateLimit(request))
.filter(cph -> cph != null && cph.getConsumptionProbeCompletableFuture() != null)
.map(cph -> Mono.fromFuture(cph.getConsumptionProbeCompletableFuture()))
.toList();
List<Mono<ConsumptionProbe>> asyncConsumptionProbes = new ArrayList<>();
for (RateLimitCheck<ServerHttpRequest> rlc : filterConfig.getRateLimitChecks()) {
ConsumptionProbeHolder cph = rlc.rateLimit(request);
if(cph != null && cph.getConsumptionProbeCompletableFuture() != null){
asyncConsumptionProbes.add(Mono.fromFuture(cph.getConsumptionProbeCompletableFuture()));
if(filterConfig.getStrategy() == RateLimitConditionMatchingStrategy.FIRST){
break;
}
}
}
if(asyncConsumptionProbes.isEmpty()) {
return chain.apply(exchange);
}
AtomicInteger consumptionProbeCounter = new AtomicInteger(0);
return Flux
.concat(asyncConsumptionProbes)
//.takeWhile(Objects::nonNull)
.doOnNext(cp -> consumptionProbeCounter.incrementAndGet())
.takeWhile(cp -> shouldTakeMoreConsumptionProbe(consumptionProbeCounter))
.reduce(this::reduceConsumptionProbe)
.flatMap(consumptionProbe -> handleConsumptionProbe(exchange, chain, response, consumptionProbe));

}

protected boolean shouldTakeMoreConsumptionProbe(AtomicInteger consumptionProbeCounter) {
boolean shouldTakeMore = filterConfig.getStrategy().equals(RateLimitConditionMatchingStrategy.ALL)
||
(filterConfig.getStrategy().equals(RateLimitConditionMatchingStrategy.FIRST) && consumptionProbeCounter.get() == 1);
log.debug("take-more-probes:{};probe-index:{};matching-strategy:{}", shouldTakeMore, consumptionProbeCounter.get(), filterConfig.getStrategy());
return shouldTakeMore;
.concat(asyncConsumptionProbes)
.reduce(this::reduceConsumptionProbe)
.flatMap(consumptionProbe -> handleConsumptionProbe(exchange, chain, response, consumptionProbe));
}

protected ConsumptionProbe reduceConsumptionProbe(ConsumptionProbe x, ConsumptionProbe y) {
Expand All @@ -76,23 +69,23 @@ protected ConsumptionProbe reduceConsumptionProbe(ConsumptionProbe x, Consumptio
} else if(!y.isConsumed()) {
result = y;
} else {
result = x.getRemainingTokens() < y.getRemainingTokens() ? x : y;
result = x.getRemainingTokens() < y.getRemainingTokens() ? x : y;
}
log.debug("reduce-probes;result-isConsumed:{};result-getremainingTokens:{};x-isConsumed:{};x-getremainingTokens{};y-isConsumed:{};y-getremainingTokens{}",
result.isConsumed(), result.getRemainingTokens(),
x.isConsumed(), x.getRemainingTokens(),
y.isConsumed(), y.getRemainingTokens());
return result;
}

protected Mono<Void> handleConsumptionProbe(ServerWebExchange exchange, ReactiveFilterChain chain,
ServerHttpResponse response, ConsumptionProbe cp) {
log.debug("probe-results;isConsumed:{};remainingTokens:{};nanosToWaitForRefill:{};nanosToWaitForReset:{}",
cp.isConsumed(),
cp.getRemainingTokens(),
cp.getNanosToWaitForRefill(),
log.debug("probe-results;isConsumed:{};remainingTokens:{};nanosToWaitForRefill:{};nanosToWaitForReset:{}",
cp.isConsumed(),
cp.getRemainingTokens(),
cp.getNanosToWaitForRefill(),
cp.getNanosToWaitForReset());

if(!cp.isConsumed()) {
if(Boolean.FALSE.equals(filterConfig.getHideHttpResponseHeaders())) {
filterConfig.getHttpResponseHeaders().forEach(response.getHeaders()::addIfAbsent);
Expand All @@ -101,7 +94,7 @@ protected Mono<Void> handleConsumptionProbe(ServerWebExchange exchange, Reactive
response.setStatusCode(filterConfig.getHttpStatusCode());
response.getHeaders().set("Content-Type", filterConfig.getHttpContentType());
DataBuffer buffer = exchange.getResponse().bufferFactory().wrap(filterConfig.getHttpResponseBody().getBytes(UTF_8));
return response.writeWith(Flux.just(buffer));
return response.writeWith(Flux.just(buffer));
} else {
return Mono.error(new ReactiveRateLimitException("HTTP ResponseBody is null"));
}
Expand All @@ -112,6 +105,4 @@ protected Mono<Void> handleConsumptionProbe(ServerWebExchange exchange, Reactive
}
return chain.apply(exchange);
}


}
Original file line number Diff line number Diff line change
Expand Up @@ -39,52 +39,52 @@
class SpringCloudGatewayRateLimitFilterTest {

private GlobalFilter filter;
private FilterConfiguration configuration;
private RateLimitCheck rateLimitCheck1;
private RateLimitCheck rateLimitCheck2;
private RateLimitCheck rateLimitCheck3;
private FilterConfiguration<ServerHttpRequest> configuration;
private RateLimitCheck<ServerHttpRequest> rateLimitCheck1;
private RateLimitCheck<ServerHttpRequest> rateLimitCheck2;
private RateLimitCheck<ServerHttpRequest> rateLimitCheck3;

private ServerWebExchange exchange;
private GatewayFilterChain chain;


private ServerHttpResponse serverHttpResponse;

@BeforeEach
public void setup() throws URISyntaxException {
rateLimitCheck1 = mock(RateLimitCheck.class);
rateLimitCheck2 = mock(RateLimitCheck.class);
rateLimitCheck3 = mock(RateLimitCheck.class);

exchange = Mockito.mock(ServerWebExchange.class);
ServerHttpRequest serverHttpRequest = Mockito.mock(ServerHttpRequest.class);
URI uri = new URI("url");
when(serverHttpRequest.getURI()).thenReturn(uri);
public void setup() throws URISyntaxException {
rateLimitCheck1 = mock(RateLimitCheck.class);
rateLimitCheck2 = mock(RateLimitCheck.class);
rateLimitCheck3 = mock(RateLimitCheck.class);

exchange = Mockito.mock(ServerWebExchange.class);

ServerHttpRequest serverHttpRequest = Mockito.mock(ServerHttpRequest.class);
URI uri = new URI("url");
when(serverHttpRequest.getURI()).thenReturn(uri);
when(exchange.getRequest()).thenReturn(serverHttpRequest);

serverHttpResponse = Mockito.mock(ServerHttpResponse.class);
when(exchange.getResponse()).thenReturn(serverHttpResponse);
when(exchange.getResponse()).thenReturn(serverHttpResponse);

chain = Mockito.mock(GatewayFilterChain.class);
when(chain.filter(exchange)).thenReturn(Mono.empty());
configuration = new FilterConfiguration();
configuration.setRateLimitChecks(Arrays.asList(rateLimitCheck1, rateLimitCheck2, rateLimitCheck3));
configuration.setUrl(".*");
filter = new SpringCloudGatewayRateLimitFilter(configuration);
}

@Test

configuration = new FilterConfiguration<>();
configuration.setRateLimitChecks(Arrays.asList(rateLimitCheck1, rateLimitCheck2, rateLimitCheck3));
configuration.setUrl(".*");
filter = new SpringCloudGatewayRateLimitFilter(configuration);
}

@Test
void should_throw_rate_limit_exception_with_no_remaining_tokens() {

configuration.setStrategy(RateLimitConditionMatchingStrategy.FIRST);

rateLimitConfig(0L, rateLimitCheck1);
HttpHeaders httpHeaders = Mockito.mock(HttpHeaders.class);
when(serverHttpResponse.getHeaders()).thenReturn(httpHeaders);
AtomicBoolean hasRateLimitError = new AtomicBoolean(false);
rateLimitConfig(0L, rateLimitCheck1);
HttpHeaders httpHeaders = Mockito.mock(HttpHeaders.class);
when(serverHttpResponse.getHeaders()).thenReturn(httpHeaders);

AtomicBoolean hasRateLimitError = new AtomicBoolean(false);
Mono<Void> result = filter.filter(exchange, chain)
.onErrorResume(ReactiveRateLimitException.class, (e) -> {
hasRateLimitError.set(true);
Expand All @@ -93,63 +93,62 @@ void should_throw_rate_limit_exception_with_no_remaining_tokens() {
result.subscribe();
Assertions.assertTrue(hasRateLimitError.get());
}

@Test
void should_execute_all_checks_when_using_RateLimitConditionMatchingStrategy_All() throws URISyntaxException {
configuration.setStrategy(RateLimitConditionMatchingStrategy.ALL);

rateLimitConfig(30L, rateLimitCheck1);
rateLimitConfig(0L, rateLimitCheck2);
rateLimitConfig(0L, rateLimitCheck3);

HttpHeaders httpHeaders = Mockito.mock(HttpHeaders.class);
when(serverHttpResponse.getHeaders()).thenReturn(httpHeaders);
final ArgumentCaptor<String> captor = ArgumentCaptor.forClass(String.class);
Mono<Void> result = filter.filter(exchange, chain);
assertThrows(ReactiveRateLimitException.class, () -> {
result.block();
});

configuration.setStrategy(RateLimitConditionMatchingStrategy.ALL);

rateLimitConfig(30L, rateLimitCheck1);
rateLimitConfig(0L, rateLimitCheck2);
rateLimitConfig(0L, rateLimitCheck3);

HttpHeaders httpHeaders = Mockito.mock(HttpHeaders.class);
when(serverHttpResponse.getHeaders()).thenReturn(httpHeaders);
final ArgumentCaptor<String> captor = ArgumentCaptor.forClass(String.class);

Mono<Void> result = filter.filter(exchange, chain);
assertThrows(ReactiveRateLimitException.class, () -> {
result.block();
});

verify(rateLimitCheck1, times(1)).rateLimit(any());
verify(rateLimitCheck2, times(1)).rateLimit(any());
verify(rateLimitCheck3, times(1)).rateLimit(any());
verify(rateLimitCheck2, times(1)).rateLimit(any());
verify(rateLimitCheck3, times(1)).rateLimit(any());
}

@Test
void should_execute_only_one_check_when_using_RateLimitConditionMatchingStrategy_FIRST() {
configuration.setStrategy(RateLimitConditionMatchingStrategy.FIRST);

rateLimitConfig(30L, rateLimitCheck1);
rateLimitConfig(0L, rateLimitCheck2);
rateLimitConfig(10L, rateLimitCheck3);
HttpHeaders httpHeaders = Mockito.mock(HttpHeaders.class);
when(serverHttpResponse.getHeaders()).thenReturn(httpHeaders);
final ArgumentCaptor<String> captor = ArgumentCaptor.forClass(String.class);
configuration.setStrategy(RateLimitConditionMatchingStrategy.FIRST);

rateLimitConfig(30L, rateLimitCheck1);
rateLimitConfig(0L, rateLimitCheck2);
rateLimitConfig(10L, rateLimitCheck3);

HttpHeaders httpHeaders = Mockito.mock(HttpHeaders.class);
when(serverHttpResponse.getHeaders()).thenReturn(httpHeaders);
final ArgumentCaptor<String> captor = ArgumentCaptor.forClass(String.class);

Mono<Void> result = filter.filter(exchange, chain);
result.block();
verify(httpHeaders, times(1)).set(any(), captor.capture());

List<String> values = captor.getAllValues();
Assertions.assertEquals("30", values.stream().findFirst().get());
verify(rateLimitCheck1, times(1)).rateLimit(any());
verify(rateLimitCheck2, times(1)).rateLimit(any());
verify(rateLimitCheck3, times(1)).rateLimit(any());

verify(httpHeaders, times(1)).set(any(), captor.capture());

List<String> values = captor.getAllValues();
Assertions.assertEquals("30", values.stream().findFirst().get());

verify(rateLimitCheck1, times(1)).rateLimit(any());
verify(rateLimitCheck2, times(0)).rateLimit(any());
verify(rateLimitCheck3, times(0)).rateLimit(any());
}

private void rateLimitConfig(Long remainingTokens, RateLimitCheck rateLimitCheck) {
private void rateLimitConfig(Long remainingTokens, RateLimitCheck<ServerHttpRequest> rateLimitCheck) {
ConsumptionProbeHolder consumptionHolder = Mockito.mock(ConsumptionProbeHolder.class);
ConsumptionProbe probe = Mockito.mock(ConsumptionProbe.class);
when(probe.isConsumed()).thenReturn(remainingTokens > 0 ? true : false);
ConsumptionProbe probe = Mockito.mock(ConsumptionProbe.class);
when(probe.isConsumed()).thenReturn(remainingTokens > 0);
when(probe.getRemainingTokens()).thenReturn(remainingTokens);
when(consumptionHolder.getConsumptionProbeCompletableFuture())
.thenReturn(CompletableFuture.completedFuture(probe));
when(rateLimitCheck.rateLimit(any())).thenReturn(consumptionHolder);
.thenReturn(CompletableFuture.completedFuture(probe));
when(rateLimitCheck.rateLimit(any())).thenReturn(consumptionHolder);
}

}
Loading

0 comments on commit 7f91cc2

Please sign in to comment.