Skip to content

Commit

Permalink
[FEATURE] support tools enhanced by AOP (#80)
Browse files Browse the repository at this point in the history
  • Loading branch information
catofdestruction authored Nov 22, 2024
1 parent 10c6cdc commit f934f16
Show file tree
Hide file tree
Showing 7 changed files with 188 additions and 2 deletions.
7 changes: 7 additions & 0 deletions langchain4j-spring-boot-starter/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,13 @@
<scope>test</scope>
</dependency>

<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-aop</artifactId>
<version>${spring.boot.version}</version>
<scope>test</scope>
</dependency>

<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-core</artifactId>
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package dev.langchain4j.service.spring;

import dev.langchain4j.agent.tool.Tool;
import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.memory.ChatMemory;
import dev.langchain4j.memory.chat.ChatMemoryProvider;
import dev.langchain4j.model.chat.ChatLanguageModel;
Expand All @@ -8,11 +10,20 @@
import dev.langchain4j.rag.RetrievalAugmentor;
import dev.langchain4j.rag.content.retriever.ContentRetriever;
import dev.langchain4j.service.AiServices;
import dev.langchain4j.service.tool.DefaultToolExecutor;
import dev.langchain4j.service.tool.ToolExecutor;
import org.springframework.beans.factory.FactoryBean;

import java.lang.reflect.Method;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import static dev.langchain4j.agent.tool.ToolSpecifications.toolSpecificationFrom;
import static dev.langchain4j.internal.Utils.isNullOrEmpty;
import static org.springframework.aop.framework.AopProxyUtils.ultimateTargetClass;
import static org.springframework.aop.support.AopUtils.isAopProxy;

class AiServiceFactory implements FactoryBean<Object> {

Expand Down Expand Up @@ -94,7 +105,13 @@ public Object getObject() {
}

if (!isNullOrEmpty(tools)) {
builder = builder.tools(tools);
for (Object tool : tools) {
if (isAopProxy(tool)) {
builder = builder.tools(aopEnhancedTools(tool));
} else {
builder = builder.tools(tool);
}
}
}

return builder.build();
Expand All @@ -120,4 +137,21 @@ public boolean isSingleton() {
* (such as java.io.Closeable.close()) will not be called automatically.
* Instead, a FactoryBean should implement DisposableBean and delegate any such close call to the underlying object.
*/

private Map<ToolSpecification, ToolExecutor> aopEnhancedTools(Object enhancedTool) {
Map<ToolSpecification, ToolExecutor> toolExecutors = new HashMap<>();
Class<?> originalToolClass = ultimateTargetClass(enhancedTool);
for (Method originalToolMethod : originalToolClass.getDeclaredMethods()) {
if (originalToolMethod.isAnnotationPresent(Tool.class)) {
Arrays.stream(enhancedTool.getClass().getDeclaredMethods())
.filter(m -> m.getName().equals(originalToolMethod.getName()))
.findFirst()
.ifPresent(enhancedMethod -> {
ToolSpecification toolSpecification = toolSpecificationFrom(originalToolMethod);
toolExecutors.put(toolSpecification, new DefaultToolExecutor(enhancedTool, enhancedMethod));
});
}
}
return toolExecutors;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,11 @@ BeanFactoryPostProcessor aiServicesRegisteringBeanFactoryPostProcessor() {
Set<String> tools = new HashSet<>();
for (String beanName : beanFactory.getBeanDefinitionNames()) {
try {
Class<?> beanClass = Class.forName(beanFactory.getBeanDefinition(beanName).getBeanClassName());
String beanClassName = beanFactory.getBeanDefinition(beanName).getBeanClassName();
if (beanClassName == null) {
continue;
}
Class<?> beanClass = Class.forName(beanClassName);
for (Method beanMethod : beanClass.getDeclaredMethods()) {
if (beanMethod.isAnnotationPresent(Tool.class)) {
tools.add(beanName);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,14 +1,22 @@
package dev.langchain4j.service.spring.mode.automatic.withTools;

import dev.langchain4j.service.spring.AiServicesAutoConfig;
import dev.langchain4j.service.spring.mode.automatic.withTools.aop.ToolObserverAspect;
import org.junit.jupiter.api.Test;
import org.springframework.boot.autoconfigure.AutoConfigurations;
import org.springframework.boot.test.context.runner.ApplicationContextRunner;

import static dev.langchain4j.service.spring.mode.ApiKeys.OPENAI_API_KEY;
import static dev.langchain4j.service.spring.mode.automatic.withTools.AopEnhancedTools.TOOL_OBSERVER_KEY;
import static dev.langchain4j.service.spring.mode.automatic.withTools.AopEnhancedTools.TOOL_OBSERVER_KEY_NAME_DESCRIPTION;
import static dev.langchain4j.service.spring.mode.automatic.withTools.AopEnhancedTools.TOOL_OBSERVER_PACKAGE_NAME;
import static dev.langchain4j.service.spring.mode.automatic.withTools.AopEnhancedTools.TOOL_OBSERVER_PACKAGE_NAME_DESCRIPTION;
import static dev.langchain4j.service.spring.mode.automatic.withTools.PackagePrivateTools.CURRENT_TIME;
import static dev.langchain4j.service.spring.mode.automatic.withTools.PublicTools.CURRENT_DATE;
import static org.assertj.core.api.Assertions.assertThat;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertTrue;

class AiServicesAutoConfigIT {

Expand Down Expand Up @@ -61,6 +69,46 @@ void should_create_AI_service_with_tool_that_is_package_private_method_in_packag
});
}

@Test
void should_create_AI_service_with_tool_which_is_enhanced_by_spring_aop() {
contextRunner
.withPropertyValues(
"langchain4j.open-ai.chat-model.api-key=" + OPENAI_API_KEY,
"langchain4j.open-ai.chat-model.temperature=0.0",
"langchain4j.open-ai.chat-model.log-requests=true",
"langchain4j.open-ai.chat-model.log-responses=true"
)
.withUserConfiguration(AiServiceWithToolsApplication.class)
.run(context -> {

// given
AiServiceWithTools aiService = context.getBean(AiServiceWithTools.class);

// when
String answer = aiService.chat("Which package is the @ToolObserver annotation located in? " +
"And what is the key of the @ToolObserver annotation?" +
"And What is the current time?");

System.out.println("Answer: " + answer);

// then should use AopEnhancedTools.getAspectPackage()
// & AopEnhancedTools.getToolObserverKey()
// & PackagePrivateTools.getCurrentTime()
assertThat(answer).contains(TOOL_OBSERVER_PACKAGE_NAME);
assertThat(answer).contains(TOOL_OBSERVER_KEY);
assertThat(answer).contains(String.valueOf(CURRENT_TIME.getMinute()));

// and AOP aspect should be called
// & only for getToolObserverKey() which is annotated with @ToolObserver
ToolObserverAspect aspect = context.getBean(ToolObserverAspect.class);
assertTrue(aspect.aspectHasBeenCalled());

assertEquals(1, aspect.getObservedTools().size());
assertTrue(aspect.getObservedTools().contains(TOOL_OBSERVER_KEY_NAME_DESCRIPTION));
assertFalse(aspect.getObservedTools().contains(TOOL_OBSERVER_PACKAGE_NAME_DESCRIPTION));
});
}

// TODO tools which are not @Beans?
// TODO negative cases
// TODO no @AiServices in app, just models
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
package dev.langchain4j.service.spring.mode.automatic.withTools;

import dev.langchain4j.agent.tool.Tool;
import dev.langchain4j.service.spring.mode.automatic.withTools.aop.ToolObserver;
import org.springframework.stereotype.Component;

@Component
public class AopEnhancedTools {

public static final String TOOL_OBSERVER_PACKAGE_NAME_DESCRIPTION =
"Find the package directory where @ToolObserver is located.";
public static final String TOOL_OBSERVER_PACKAGE_NAME = ToolObserver.class.getPackageName();

public static final String TOOL_OBSERVER_KEY_NAME_DESCRIPTION =
"Find the key name of @ToolObserver";
public static final String TOOL_OBSERVER_KEY = "AOP_ENHANCED_TOOLS_SUPPORT_@_1122";

@Tool(TOOL_OBSERVER_PACKAGE_NAME_DESCRIPTION)
public String getToolObserverPackageName() {
return TOOL_OBSERVER_PACKAGE_NAME;
}

@ToolObserver(key = TOOL_OBSERVER_KEY)
@Tool(TOOL_OBSERVER_KEY_NAME_DESCRIPTION)
public String getToolObserverKey() {
return TOOL_OBSERVER_KEY;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
package dev.langchain4j.service.spring.mode.automatic.withTools.aop;

import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;

@Target({ElementType.METHOD})
@Retention(RetentionPolicy.RUNTIME)
public @interface ToolObserver {

/**
* key just for example
*
* @return the key
*/
String key();
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
package dev.langchain4j.service.spring.mode.automatic.withTools.aop;

import dev.langchain4j.agent.tool.Tool;
import org.aspectj.lang.ProceedingJoinPoint;
import org.aspectj.lang.annotation.Around;
import org.aspectj.lang.annotation.Aspect;
import org.aspectj.lang.reflect.MethodSignature;
import org.springframework.stereotype.Component;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

@Aspect
@Component
public class ToolObserverAspect {

private final List<String> observedTools = new ArrayList<>();

@Around("@annotation(toolObserver)")
public Object around(ProceedingJoinPoint joinPoint, ToolObserver toolObserver) throws Throwable {
var signature = (MethodSignature) joinPoint.getSignature();
var method = signature.getMethod();
String methodName = method.getName();
if (method.isAnnotationPresent(Tool.class)) {
Tool toolAnnotation = method.getAnnotation(Tool.class);
observedTools.addAll(Arrays.asList(toolAnnotation.value()));
System.out.printf("Found @Tool %s for method: %s%n%n", Arrays.toString(toolAnnotation.value()), methodName);
}
Object result = joinPoint.proceed();
System.out.printf(" | key: %s%n | Method name: %s%n | Method arguments: %s%n | Return type: %s%n | Method return value: %s%n%n",
toolObserver.key(),
methodName,
Arrays.toString(joinPoint.getArgs()),
method.getReturnType().getName(),
result);
return result;
}

public boolean aspectHasBeenCalled() {
return !observedTools.isEmpty();
}

public List<String> getObservedTools() {
return observedTools;
}
}

0 comments on commit f934f16

Please sign in to comment.