diff --git a/spring-beans/src/main/java/org/springframework/beans/factory/annotation/InitDestroyAnnotationBeanPostProcessor.java b/spring-beans/src/main/java/org/springframework/beans/factory/annotation/InitDestroyAnnotationBeanPostProcessor.java index 708064488acf..05ae7b7d4cb6 100644 --- a/spring-beans/src/main/java/org/springframework/beans/factory/annotation/InitDestroyAnnotationBeanPostProcessor.java +++ b/spring-beans/src/main/java/org/springframework/beans/factory/annotation/InitDestroyAnnotationBeanPostProcessor.java @@ -31,25 +31,34 @@ import java.util.Map; import java.util.Set; import java.util.concurrent.ConcurrentHashMap; -import java.util.stream.Stream; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; +import org.springframework.aot.generate.GeneratedMethod; +import org.springframework.aot.generate.GenerationContext; +import org.springframework.aot.hint.ExecutableMode; import org.springframework.beans.BeansException; import org.springframework.beans.factory.BeanCreationException; +import org.springframework.beans.factory.aot.BeanFactoryInitializationAotContribution; +import org.springframework.beans.factory.aot.BeanFactoryInitializationAotProcessor; +import org.springframework.beans.factory.aot.BeanFactoryInitializationCode; import org.springframework.beans.factory.aot.BeanRegistrationAotContribution; import org.springframework.beans.factory.aot.BeanRegistrationAotProcessor; +import org.springframework.beans.factory.config.ConfigurableListableBeanFactory; import org.springframework.beans.factory.config.DestructionAwareBeanPostProcessor; +import org.springframework.beans.factory.support.DefaultListableBeanFactory; import org.springframework.beans.factory.support.MergedBeanDefinitionPostProcessor; import org.springframework.beans.factory.support.RegisteredBean; import org.springframework.beans.factory.support.RootBeanDefinition; import org.springframework.core.Ordered; import org.springframework.core.PriorityOrdered; import org.springframework.core.annotation.AnnotationUtils; +import org.springframework.javapoet.ClassName; +import org.springframework.javapoet.CodeBlock; +import org.springframework.javapoet.MethodSpec; import org.springframework.lang.Nullable; import org.springframework.util.ClassUtils; -import org.springframework.util.CollectionUtils; import org.springframework.util.ReflectionUtils; /** @@ -85,7 +94,7 @@ */ @SuppressWarnings("serial") public class InitDestroyAnnotationBeanPostProcessor implements DestructionAwareBeanPostProcessor, - MergedBeanDefinitionPostProcessor, BeanRegistrationAotProcessor, PriorityOrdered, Serializable { + MergedBeanDefinitionPostProcessor, BeanRegistrationAotProcessor, BeanFactoryInitializationAotProcessor, PriorityOrdered, Serializable { private final transient LifecycleMetadata emptyLifecycleMetadata = new LifecycleMetadata(Object.class, Collections.emptyList(), Collections.emptyList()) { @@ -188,15 +197,22 @@ public BeanRegistrationAotContribution processAheadOfTime(RegisteredBean registe RootBeanDefinition beanDefinition = registeredBean.getMergedBeanDefinition(); beanDefinition.resolveDestroyMethodIfNecessary(); LifecycleMetadata metadata = findLifecycleMetadata(beanDefinition, registeredBean.getBeanClass()); - if (!CollectionUtils.isEmpty(metadata.initMethods)) { - String[] initMethodNames = safeMerge(beanDefinition.getInitMethodNames(), metadata.initMethods); - beanDefinition.setInitMethodNames(initMethodNames); - } - if (!CollectionUtils.isEmpty(metadata.destroyMethods)) { - String[] destroyMethodNames = safeMerge(beanDefinition.getDestroyMethodNames(), metadata.destroyMethods); - beanDefinition.setDestroyMethodNames(destroyMethodNames); + return (generationContext, beanRegistrationCode) -> { + metadata.initMethods.forEach(lm -> registerLifecycleMethodForInvoke(generationContext, lm)); + metadata.destroyMethods.forEach(lm -> registerLifecycleMethodForInvoke(generationContext, lm)); + }; + } + + private void registerLifecycleMethodForInvoke(GenerationContext generationContext, LifecycleMethod lm) { + generationContext.getRuntimeHints().reflection().registerMethod(lm.getMethod(), ExecutableMode.INVOKE); + } + + @Override + public BeanFactoryInitializationAotContribution processAheadOfTime(ConfigurableListableBeanFactory beanFactory) { + if (this.initAnnotationTypes.isEmpty() && this.destroyAnnotationTypes.isEmpty()) { + return null; } - return null; + return new BeanFactoryAotContribution(this.initAnnotationTypes, this.destroyAnnotationTypes); } private LifecycleMetadata findLifecycleMetadata(RootBeanDefinition beanDefinition, Class beanClass) { @@ -205,13 +221,6 @@ private LifecycleMetadata findLifecycleMetadata(RootBeanDefinition beanDefinitio return metadata; } - private static String[] safeMerge(@Nullable String[] existingNames, Collection detectedMethods) { - Stream detectedNames = detectedMethods.stream().map(LifecycleMethod::getIdentifier); - Stream mergedNames = (existingNames != null ? - Stream.concat(detectedNames, Stream.of(existingNames)) : detectedNames); - return mergedNames.distinct().toArray(String[]::new); - } - @Override public Object postProcessBeforeInitialization(Object bean, String beanName) throws BeansException { LifecycleMetadata metadata = findLifecycleMetadata(bean.getClass()); @@ -486,4 +495,33 @@ private static boolean isPrivateOrNotVisible(Method method, Class beanClass) } + private record BeanFactoryAotContribution( + Set> initAnnotationTypes, + Set> destroyAnnotationTypes) implements BeanFactoryInitializationAotContribution { + private static final String BEAN_FACTORY_PARAMETER_NAME = "beanFactory"; + private static final String POST_PROCESSOR_PARAMETER_NAME = "postProcessor"; + + @Override + public void applyTo(GenerationContext generationContext, BeanFactoryInitializationCode beanFactoryInitializationCode) { + // to generate a unique name in case there are multiple InitDestroyAnnotationBeanPostProcessor-s + String[] methodNameParts = {"addInitDestroyBeanPostProcessorMethod"}; + GeneratedMethod generatedMethod = beanFactoryInitializationCode.getMethods() + .add(methodNameParts, this::generateAddInitDestroyBeanPostProcessorMethod); + beanFactoryInitializationCode.addInitializer(generatedMethod.toMethodReference()); + } + + private void generateAddInitDestroyBeanPostProcessorMethod(MethodSpec.Builder method) { + method.addJavadoc("Apply known externally managed init/destroy annotation bean post processors"); + method.addModifiers(javax.lang.model.element.Modifier.PRIVATE); + method.addParameter(DefaultListableBeanFactory.class, BEAN_FACTORY_PARAMETER_NAME); + + CodeBlock.Builder code = CodeBlock.builder(); + code.addStatement("$T $L = new $T()", + InitDestroyAnnotationBeanPostProcessor.class, POST_PROCESSOR_PARAMETER_NAME, InitDestroyAnnotationBeanPostProcessor.class); + this.initAnnotationTypes.forEach(type -> code.addStatement("$L.addInitAnnotationType($T.class)", POST_PROCESSOR_PARAMETER_NAME, ClassName.get(type))); + this.destroyAnnotationTypes.forEach(type -> code.addStatement("$L.addDestroyAnnotationType($T.class)", POST_PROCESSOR_PARAMETER_NAME, ClassName.get(type))); + code.addStatement("$L.addBeanPostProcessor($L)", BEAN_FACTORY_PARAMETER_NAME, POST_PROCESSOR_PARAMETER_NAME); + method.addCode(code.build()); + } + } } diff --git a/spring-beans/src/test/java/org/springframework/beans/factory/annotation/InitDestroyAnnotationBeanPostProcessorTests.java b/spring-beans/src/test/java/org/springframework/beans/factory/annotation/InitDestroyAnnotationBeanPostProcessorTests.java index be1fbaaea641..624802ea9755 100644 --- a/spring-beans/src/test/java/org/springframework/beans/factory/annotation/InitDestroyAnnotationBeanPostProcessorTests.java +++ b/spring-beans/src/test/java/org/springframework/beans/factory/annotation/InitDestroyAnnotationBeanPostProcessorTests.java @@ -20,6 +20,7 @@ import org.springframework.beans.factory.DisposableBean; import org.springframework.beans.factory.InitializingBean; +import org.springframework.beans.factory.aot.BeanRegistrationAotContribution; import org.springframework.beans.factory.support.AbstractBeanDefinition; import org.springframework.beans.factory.support.DefaultListableBeanFactory; import org.springframework.beans.factory.support.RegisteredBean; @@ -57,10 +58,11 @@ void processAheadOfTimeWhenNoCallbackDoesNotMutateRootBeanDefinition() { @Test void processAheadOfTimeWhenHasInitDestroyAnnotationsAddsMethodNames() { RootBeanDefinition beanDefinition = new RootBeanDefinition(InitDestroyBean.class); - processAheadOfTime(beanDefinition); + BeanRegistrationAotContribution beanRegistrationAotContribution = processAheadOfTime(beanDefinition); + assertThat(beanRegistrationAotContribution).isNotNull(); RootBeanDefinition mergedBeanDefinition = getMergedBeanDefinition(); - assertThat(mergedBeanDefinition.getInitMethodNames()).containsExactly("initMethod"); - assertThat(mergedBeanDefinition.getDestroyMethodNames()).containsExactly("destroyMethod"); + assertThat(mergedBeanDefinition.getInitMethodNames()).isNull(); + assertThat(mergedBeanDefinition.getDestroyMethodNames()).isNull(); } @Test @@ -68,10 +70,11 @@ void processAheadOfTimeWhenHasInitDestroyAnnotationsAndCustomDefinedMethodNamesA RootBeanDefinition beanDefinition = new RootBeanDefinition(InitDestroyBean.class); beanDefinition.setInitMethodName("customInitMethod"); beanDefinition.setDestroyMethodNames("customDestroyMethod"); - processAheadOfTime(beanDefinition); + BeanRegistrationAotContribution beanRegistrationAotContribution = processAheadOfTime(beanDefinition); + assertThat(beanRegistrationAotContribution).isNotNull(); RootBeanDefinition mergedBeanDefinition = getMergedBeanDefinition(); - assertThat(mergedBeanDefinition.getInitMethodNames()).containsExactly("initMethod", "customInitMethod"); - assertThat(mergedBeanDefinition.getDestroyMethodNames()).containsExactly("destroyMethod", "customDestroyMethod"); + assertThat(mergedBeanDefinition.getInitMethodNames()).containsExactly("customInitMethod"); + assertThat(mergedBeanDefinition.getDestroyMethodNames()).containsExactly("customDestroyMethod"); } @Test @@ -108,10 +111,11 @@ void processAheadOfTimeWhenHasInferredDestroyMethodAndNoCandidateDoesNotMutateRo @Test void processAheadOfTimeWhenHasMultipleInitDestroyAnnotationsAddsAllMethodNames() { RootBeanDefinition beanDefinition = new RootBeanDefinition(MultiInitDestroyBean.class); - processAheadOfTime(beanDefinition); + BeanRegistrationAotContribution beanRegistrationAotContribution = processAheadOfTime(beanDefinition); + assertThat(beanRegistrationAotContribution).isNotNull(); RootBeanDefinition mergedBeanDefinition = getMergedBeanDefinition(); - assertThat(mergedBeanDefinition.getInitMethodNames()).containsExactly("initMethod", "anotherInitMethod"); - assertThat(mergedBeanDefinition.getDestroyMethodNames()).containsExactly("anotherDestroyMethod", "destroyMethod"); + assertThat(mergedBeanDefinition.getInitMethodNames()).isNull(); + assertThat(mergedBeanDefinition.getDestroyMethodNames()).isNull(); } @Test @@ -125,27 +129,24 @@ void processAheadOfTimeWithMultipleLevelsOfPublicAndPrivateInitAndDestroyMethods // to ensure that it will be tracked as such even though it has the same // name as DisposableBean#destroy(). beanDefinition.setDestroyMethodNames("destroy", "customDestroy"); - processAheadOfTime(beanDefinition); + BeanRegistrationAotContribution beanRegistrationAotContribution = processAheadOfTime(beanDefinition); + assertThat(beanRegistrationAotContribution).isNotNull(); RootBeanDefinition mergedBeanDefinition = getMergedBeanDefinition(); assertSoftly(softly -> { softly.assertThat(mergedBeanDefinition.getInitMethodNames()).containsExactly( - CustomAnnotatedPrivateInitDestroyBean.class.getName() + ".privateInit", // fully-qualified private method - CustomAnnotatedPrivateSameNameInitDestroyBean.class.getName() + ".privateInit", // fully-qualified private method "afterPropertiesSet", "customInit" ); softly.assertThat(mergedBeanDefinition.getDestroyMethodNames()).containsExactly( - CustomAnnotatedPrivateSameNameInitDestroyBean.class.getName() + ".privateDestroy", // fully-qualified private method - CustomAnnotatedPrivateInitDestroyBean.class.getName() + ".privateDestroy", // fully-qualified private method "destroy", "customDestroy" ); }); } - private void processAheadOfTime(RootBeanDefinition beanDefinition) { + private BeanRegistrationAotContribution processAheadOfTime(RootBeanDefinition beanDefinition) { RegisteredBean registeredBean = registerBean(beanDefinition); - assertThat(createAotBeanPostProcessor().processAheadOfTime(registeredBean)).isNull(); + return createAotBeanPostProcessor().processAheadOfTime(registeredBean); } private RegisteredBean registerBean(RootBeanDefinition beanDefinition) { diff --git a/spring-context/src/test/java/org/springframework/context/annotation/InitDestroyMethodLifecycleTests.java b/spring-context/src/test/java/org/springframework/context/annotation/InitDestroyMethodLifecycleTests.java index fe78bd41de00..874d4bb20af3 100644 --- a/spring-context/src/test/java/org/springframework/context/annotation/InitDestroyMethodLifecycleTests.java +++ b/spring-context/src/test/java/org/springframework/context/annotation/InitDestroyMethodLifecycleTests.java @@ -185,17 +185,17 @@ void jakartaAnnotationsWithCustomSameMethodNamesWithAotProcessingAndAotRuntime() CustomAnnotatedPrivateSameNameInitDestroyBean bean = aotApplicationContext.getBean("lifecycleTestBean", beanClass); assertThat(bean.initMethods).as("init-methods").containsExactly( - "afterPropertiesSet", "@PostConstruct.privateCustomInit1", "@PostConstruct.sameNameCustomInit1", + "afterPropertiesSet", "customInit" ); aotApplicationContext.close(); assertThat(bean.destroyMethods).as("destroy-methods").containsExactly( - "destroy", "@PreDestroy.sameNameCustomDestroy1", "@PreDestroy.privateCustomDestroy1", + "destroy", "customDestroy" ); }); @@ -220,17 +220,17 @@ void jakartaAnnotationsWithPackagePrivateInitDestroyMethodsWithAotProcessingAndA SubPackagePrivateInitDestroyBean bean = aotApplicationContext.getBean("lifecycleTestBean", beanClass); assertThat(bean.initMethods).as("init-methods").containsExactly( - "InitializingBean.afterPropertiesSet", "PackagePrivateInitDestroyBean.postConstruct", "SubPackagePrivateInitDestroyBean.postConstruct", + "InitializingBean.afterPropertiesSet", "initMethod" ); aotApplicationContext.close(); assertThat(bean.destroyMethods).as("destroy-methods").containsExactly( - "DisposableBean.destroy", "SubPackagePrivateInitDestroyBean.preDestroy", "PackagePrivateInitDestroyBean.preDestroy", + "DisposableBean.destroy", "destroyMethod" ); });