diff --git a/bucket4j-spring-boot-starter-context/pom.xml b/bucket4j-spring-boot-starter-context/pom.xml index 27f1a9ef..07fb0b41 100644 --- a/bucket4j-spring-boot-starter-context/pom.xml +++ b/bucket4j-spring-boot-starter-context/pom.xml @@ -48,6 +48,8 @@ jakarta.servlet jakarta.servlet-api + provided + true diff --git a/bucket4j-spring-boot-starter-context/src/main/java/com/giffing/bucket4j/spring/boot/starter/context/constraintvalidations/Bucket4JConfigurationPredicateNameValidator.java b/bucket4j-spring-boot-starter-context/src/main/java/com/giffing/bucket4j/spring/boot/starter/context/constraintvalidations/Bucket4JConfigurationPredicateNameValidator.java index 9d22a190..3332075a 100644 --- a/bucket4j-spring-boot-starter-context/src/main/java/com/giffing/bucket4j/spring/boot/starter/context/constraintvalidations/Bucket4JConfigurationPredicateNameValidator.java +++ b/bucket4j-spring-boot-starter-context/src/main/java/com/giffing/bucket4j/spring/boot/starter/context/constraintvalidations/Bucket4JConfigurationPredicateNameValidator.java @@ -5,11 +5,11 @@ import java.util.stream.Collectors; import java.util.stream.Stream; -import jakarta.servlet.http.HttpServletRequest; import jakarta.validation.ConstraintValidator; import jakarta.validation.ConstraintValidatorContext; import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.core.GenericTypeResolver; import org.springframework.http.server.reactive.ServerHttpRequest; import com.giffing.bucket4j.spring.boot.starter.context.ExecutePredicate; @@ -23,9 +23,18 @@ public class Bucket4JConfigurationPredicateNameValidator implements ConstraintVa private final Map>> filterPredicates = new HashMap<>(); @Autowired - public Bucket4JConfigurationPredicateNameValidator( - List> servletPredicates, - List> webfluxPredicates) { + public Bucket4JConfigurationPredicateNameValidator(List> executePredicates) { + List> servletPredicates = new ArrayList<>(); + List> webfluxPredicates = new ArrayList<>(); + executePredicates.forEach(x -> { + Class genericType = GenericTypeResolver.resolveTypeArgument(x.getClass(), ExecutePredicate.class); + if(genericType == null) return; + if(genericType.getName().equals("jakarta.servlet.http.HttpServletRequest")){ + servletPredicates.add(x); + } else if (genericType == ServerHttpRequest.class){ + webfluxPredicates.add(x); + } + }); filterPredicates.put(FilterMethod.SERVLET, servletPredicates.stream() .collect(Collectors.toMap(ExecutePredicate::name, Function.identity()))); diff --git a/bucket4j-spring-boot-starter/src/test/java/com/giffing/bucket4j/spring/boot/starter/config/filter/predicate/ConfigPredicateNameValidatorTest.java b/bucket4j-spring-boot-starter/src/test/java/com/giffing/bucket4j/spring/boot/starter/config/filter/predicate/ConfigPredicateNameValidatorTest.java index d1cafa84..bbcb99c1 100644 --- a/bucket4j-spring-boot-starter/src/test/java/com/giffing/bucket4j/spring/boot/starter/config/filter/predicate/ConfigPredicateNameValidatorTest.java +++ b/bucket4j-spring-boot-starter/src/test/java/com/giffing/bucket4j/spring/boot/starter/config/filter/predicate/ConfigPredicateNameValidatorTest.java @@ -34,15 +34,18 @@ class ConfigPredicateNameValidatorTest { - List> servletPredicates; - List> webfluxPredicates; + List> executePredicates; Bucket4JConfigurationPredicateNameValidator validator; @BeforeEach void setup() { - servletPredicates = Arrays.asList(new ServletPathExecutePredicate(), new ServletMethodPredicate()); - webfluxPredicates = Arrays.asList(new WebfluxPathExecutePredicate(), new WebfluxHeaderExecutePredicate()); - validator = new Bucket4JConfigurationPredicateNameValidator(servletPredicates, webfluxPredicates); + executePredicates = Arrays.asList( + //Servlet predicates + new ServletPathExecutePredicate(), new ServletMethodPredicate(), + //Webflux predicates + new WebfluxPathExecutePredicate(), new WebfluxHeaderExecutePredicate() + ); + validator = new Bucket4JConfigurationPredicateNameValidator(executePredicates); } /** @@ -144,10 +147,11 @@ void testInvalidWebfluxWithServletPredicate() { */ @Test void customPredicateTest() { - List> customServletPredicates = new ArrayList<>(this.servletPredicates); - customServletPredicates.add(new CustomTestPredicate()); + List> includingCustomPredicate = new ArrayList<>(this.executePredicates); + includingCustomPredicate.add(new CustomTestPredicate()); + Bucket4JConfigurationPredicateNameValidator customPredicateValidator = - new Bucket4JConfigurationPredicateNameValidator(customServletPredicates, webfluxPredicates); + new Bucket4JConfigurationPredicateNameValidator(includingCustomPredicate); List executePredicates = List.of("CUSTOM-QUERY=custom-servlet"); List skipPredicates = List.of();