Skip to content

Commit

Permalink
fix #1606 AiServicesAutoConfig is unable to detect AiService in the p…
Browse files Browse the repository at this point in the history
…ackage specified by @componentscan (#35)

AiServicesAutoConfig is unable to detect AiService in the package
specified by @componentscan

Changes Made:

- Replaced the original use of Reflections with a combination of
`BeanDefinitionRegistryPostProcessor` and
`ClassPathBeanDefinitionScanner` to register interfaces.
- Scanned all the packages specified by @componentscan and those
obtained from AutoConfigurationPackages as basePackages.


Relevant Links:
- langchain4j/langchain4j#1606
- langchain4j/langchain4j#1593
  • Loading branch information
qing-wq authored Aug 21, 2024
1 parent 17e2470 commit 34d0962
Show file tree
Hide file tree
Showing 8 changed files with 205 additions and 31 deletions.
6 changes: 0 additions & 6 deletions langchain4j-spring-boot-starter/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,6 @@
<version>${project.version}</version>
</dependency>

<dependency>
<groupId>org.reflections</groupId>
<artifactId>reflections</artifactId>
<version>0.10.2</version>
</dependency>

<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter</artifactId>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
package dev.langchain4j.service.spring;

import org.springframework.beans.BeansException;
import org.springframework.beans.factory.config.ConfigurableListableBeanFactory;
import org.springframework.beans.factory.support.BeanDefinitionRegistry;
import org.springframework.beans.factory.support.BeanDefinitionRegistryPostProcessor;
import org.springframework.boot.autoconfigure.AutoConfigurationPackages;
import org.springframework.context.annotation.ComponentScan;
import org.springframework.stereotype.Component;

import java.util.Collections;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Set;

@Component
public class AiServiceScannerProcessor implements BeanDefinitionRegistryPostProcessor {

@Override
public void postProcessBeanDefinitionRegistry(BeanDefinitionRegistry registry) throws BeansException {
ClassPathAiServiceScanner classPathInterfaceScanner = new ClassPathAiServiceScanner(registry, false);
classPathInterfaceScanner.registerFilters();
Set<String> basePackages = getBasePackages((ConfigurableListableBeanFactory) registry);
for (String basePackage : basePackages) {
classPathInterfaceScanner.scan(basePackage);
}
}

private Set<String> getBasePackages(ConfigurableListableBeanFactory beanFactory) {
Set<String> basePackages = new LinkedHashSet<>();

List<String> autoConfigPackages = AutoConfigurationPackages.get(beanFactory);
basePackages.addAll(autoConfigPackages);

String[] beanNames = beanFactory.getBeanNamesForAnnotation(ComponentScan.class);
for (String beanName : beanNames) {
Class<?> beanClass = beanFactory.getType(beanName);
if (beanClass != null) {
ComponentScan componentScan = beanClass.getAnnotation(ComponentScan.class);
if (componentScan != null) {
Collections.addAll(basePackages, componentScan.value());
Collections.addAll(basePackages, componentScan.basePackages());
for (Class<?> basePackageClass : componentScan.basePackageClasses()) {
basePackages.add(basePackageClass.getPackage().getName());
}
}
}
}

return basePackages;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,12 @@
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
import dev.langchain4j.rag.RetrievalAugmentor;
import dev.langchain4j.rag.content.retriever.ContentRetriever;
import org.reflections.Reflections;
import org.reflections.util.ConfigurationBuilder;
import org.springframework.beans.MutablePropertyValues;
import org.springframework.beans.factory.config.BeanDefinition;
import org.springframework.beans.factory.config.BeanFactoryPostProcessor;
import org.springframework.beans.factory.config.ConfigurableListableBeanFactory;
import org.springframework.beans.factory.config.RuntimeBeanReference;
import org.springframework.beans.factory.support.BeanDefinitionRegistry;
import org.springframework.beans.factory.support.GenericBeanDefinition;
import org.springframework.beans.factory.support.ManagedList;
import org.springframework.boot.autoconfigure.SpringBootApplication;
import org.springframework.context.annotation.Bean;

import java.lang.reflect.Method;
Expand Down Expand Up @@ -60,13 +55,9 @@ BeanFactoryPostProcessor aiServicesRegisteringBeanFactoryPostProcessor() {
}
}

findAiServices(beanFactory).forEach(aiServiceClass -> {

if (beanFactory.getBeanNamesForType(aiServiceClass).length > 0) {
// User probably wants to configure AI Service bean manually
// TODO or better fail because user should not annotate it with @AiService then?
return;
}
String[] aiServices = beanFactory.getBeanNamesForAnnotation(AiService.class);
for (String aiService : aiServices) {
Class<?> aiServiceClass = beanFactory.getType(aiService);

GenericBeanDefinition aiServiceBeanDefinition = new GenericBeanDefinition();
aiServiceBeanDefinition.setBeanClass(AiServiceFactory.class);
Expand Down Expand Up @@ -144,21 +135,12 @@ BeanFactoryPostProcessor aiServicesRegisteringBeanFactoryPostProcessor() {
}

BeanDefinitionRegistry registry = (BeanDefinitionRegistry) beanFactory;
registry.registerBeanDefinition(lowercaseFirstLetter(aiServiceClass.getSimpleName()), aiServiceBeanDefinition);
});
registry.removeBeanDefinition(aiService);
registry.registerBeanDefinition(lowercaseFirstLetter(aiService), aiServiceBeanDefinition);
}
};
}

private static Set<Class<?>> findAiServices(ConfigurableListableBeanFactory beanFactory) {
String[] applicationBean = beanFactory.getBeanNamesForAnnotation(SpringBootApplication.class);
BeanDefinition applicationBeanDefinition = beanFactory.getBeanDefinition(applicationBean[0]);
String basePackage = applicationBeanDefinition.getResolvableType().resolve().getPackage().getName();
Reflections reflections = new Reflections((new ConfigurationBuilder()).forPackage(basePackage));
Set<Class<?>> classes = reflections.getTypesAnnotatedWith(AiService.class);
classes.removeIf(clazz -> !clazz.getName().startsWith(basePackage));
return classes;
}

private static void addBeanReference(Class<?> beanType,
AiService aiServiceAnnotation,
String customBeanName,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
package dev.langchain4j.service.spring;

import org.springframework.beans.factory.annotation.AnnotatedBeanDefinition;
import org.springframework.beans.factory.support.BeanDefinitionRegistry;
import org.springframework.context.annotation.ClassPathBeanDefinitionScanner;
import org.springframework.core.type.filter.AnnotationTypeFilter;

public class ClassPathAiServiceScanner extends ClassPathBeanDefinitionScanner {

public ClassPathAiServiceScanner(BeanDefinitionRegistry registry, boolean useDefaultFilters) {
super(registry, useDefaultFilters);
}

@Override
protected boolean isCandidateComponent(AnnotatedBeanDefinition beanDefinition) {
return beanDefinition.getMetadata().isInterface() && beanDefinition.getMetadata().isIndependent();
}

public void registerFilters() {
addIncludeFilter(new AnnotationTypeFilter(AiService.class));
}
}
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
package dev.langchain4j.spring;

import dev.langchain4j.rag.spring.RagAutoConfig;
import dev.langchain4j.service.spring.AiServiceScannerProcessor;
import dev.langchain4j.service.spring.AiServicesAutoConfig;
import org.springframework.boot.autoconfigure.AutoConfiguration;
import org.springframework.context.annotation.Import;

@AutoConfiguration
@Import({
AiServicesAutoConfig.class,
RagAutoConfig.class
RagAutoConfig.class,
AiServiceScannerProcessor.class
})
public class LangChain4jAutoConfig {
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
package dev.langchain4j.service.spring.mode.automatic.differentPackage.package1;

import org.springframework.boot.SpringApplication;
import org.springframework.boot.autoconfigure.SpringBootApplication;

@SpringBootApplication
public class DifferentPackageAiServiceApplication {

public static void main(String[] args) {
SpringApplication.run(DifferentPackageAiServiceApplication.class, args);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
package dev.langchain4j.service.spring.mode.automatic.differentPackage.package1;

import dev.langchain4j.service.spring.AiServicesAutoConfig;
import dev.langchain4j.service.spring.mode.automatic.differentPackage.package2.DifferentPackageAiService;
import org.junit.jupiter.api.Test;
import org.springframework.boot.autoconfigure.AutoConfigurations;
import org.springframework.boot.test.context.runner.ApplicationContextRunner;
import org.springframework.context.annotation.ComponentScan;

import static dev.langchain4j.service.spring.mode.ApiKeys.OPENAI_API_KEY;
import static org.assertj.core.api.Assertions.assertThat;

class DifferentPackageAiServiceIT {

ApplicationContextRunner contextRunner = new ApplicationContextRunner()
.withConfiguration(AutoConfigurations.of(AiServicesAutoConfig.class));

@ComponentScan(value = "dev.langchain4j.service.spring.mode.automatic.differentPackage.package2")
static class ComponentScanWithValue {
}

@ComponentScan(basePackages = "dev.langchain4j.service.spring.mode.automatic.differentPackage.package2")
static class ComponentScanWithBasePackages {
}

@ComponentScan(basePackageClasses = dev.langchain4j.service.spring.mode.automatic.differentPackage.package2.DifferentPackageAiService.class)
static class ComponentScanWithBasePackageClasses {
}

@Test
void should_create_AI_service_that_use_componentScan_value() {

contextRunner
.withPropertyValues(
"langchain4j.open-ai.chat-model.api-key=" + OPENAI_API_KEY,
"langchain4j.open-ai.chat-model.max-tokens=20",
"langchain4j.open-ai.chat-model.temperature=0.0"
)
.withUserConfiguration(DifferentPackageAiServiceApplication.class)
.withUserConfiguration(ComponentScanWithValue.class)
.run(context -> {

// given
DifferentPackageAiService aiService = context.getBean(DifferentPackageAiService.class);

// when
String answer = aiService.chat("What is the capital of Germany?");

// then
assertThat(answer).containsIgnoringCase("Berlin");
});
}

@Test
void should_create_AI_service_that_use_componentScan_basePackages() {

contextRunner
.withPropertyValues(
"langchain4j.open-ai.chat-model.api-key=" + OPENAI_API_KEY,
"langchain4j.open-ai.chat-model.max-tokens=20",
"langchain4j.open-ai.chat-model.temperature=0.0"
)
.withUserConfiguration(DifferentPackageAiServiceApplication.class)
.withUserConfiguration(ComponentScanWithBasePackages.class)
.run(context -> {

// given
DifferentPackageAiService aiService = context.getBean(DifferentPackageAiService.class);

// when
String answer = aiService.chat("What is the capital of Germany?");

// then
assertThat(answer).containsIgnoringCase("Berlin");
});
}

@Test
void should_create_AI_service_that_use_componentScan_basePackageClasses() {

contextRunner
.withPropertyValues(
"langchain4j.open-ai.chat-model.api-key=" + OPENAI_API_KEY,
"langchain4j.open-ai.chat-model.max-tokens=20",
"langchain4j.open-ai.chat-model.temperature=0.0"
)
.withUserConfiguration(DifferentPackageAiServiceApplication.class)
.withUserConfiguration(ComponentScanWithBasePackageClasses.class)
.run(context -> {

// given
DifferentPackageAiService aiService = context.getBean(DifferentPackageAiService.class);

// when
String answer = aiService.chat("What is the capital of Germany?");

// then
assertThat(answer).containsIgnoringCase("Berlin");
});
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
package dev.langchain4j.service.spring.mode.automatic.differentPackage.package2;

import dev.langchain4j.service.spring.AiService;

@AiService
public interface DifferentPackageAiService {

String chat(String userMessage);
}

0 comments on commit 34d0962

Please sign in to comment.