Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
langchain4j committed Sep 15, 2024
1 parent 8ae6b5b commit fd8ed4b
Show file tree
Hide file tree
Showing 6 changed files with 120 additions and 2 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
package dev.langchain4j.reactor;

import dev.langchain4j.service.TokenStream;
import dev.langchain4j.spi.services.TokenStreamAdapter;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Sinks;

import java.lang.reflect.ParameterizedType;
import java.lang.reflect.Type;

public class TokenStreamToFluxAdapter implements TokenStreamAdapter {

@Override
public boolean canAdaptTokenStreamTo(Type type) {
if (type instanceof ParameterizedType parameterizedType) {
if (parameterizedType.getRawType() == Flux.class) {
Type[] typeArguments = parameterizedType.getActualTypeArguments();
return typeArguments.length == 1 && typeArguments[0] == String.class;
}
}
return false;
}

@Override
public Object adapt(TokenStream tokenStream) {
Sinks.Many<String> sink = Sinks.many().unicast().onBackpressureBuffer();
tokenStream.onNext(sink::tryEmitNext)
.onComplete(aiMessageResponse -> sink.tryEmitComplete())
.onError(sink::tryEmitError)
.start();
return sink.asFlux();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
dev.langchain4j.reactor.TokenStreamToFluxAdapter
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
package dev.langchain4j.reactor;

import dev.langchain4j.model.chat.StreamingChatLanguageModel;
import dev.langchain4j.model.chat.mock.StreamingChatModelMock;
import dev.langchain4j.service.AiServices;
import org.junit.jupiter.api.Test;
import reactor.core.publisher.Flux;
import reactor.test.StepVerifier;

import java.util.List;

public class AiServiceWithFluxTest {

interface Assistant {

Flux<String> stream(String userMessage);
}

@Test
void should_stream() {

// given
List<String> tokens = List.of("The", " capital", " of", " Germany", " is", " Berlin", ".");

StreamingChatLanguageModel model = StreamingChatModelMock.thatAlwaysStreams(tokens);

Assistant assistant = AiServices.builder(Assistant.class)
.streamingChatLanguageModel(model)
.build();

// when
Flux<String> flux = assistant.stream("What is the capital of Germany?");

// then
StepVerifier.create(flux)
.expectNextSequence(tokens)
.verifyComplete();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
package dev.langchain4j.reactor;

import org.junit.jupiter.api.Test;
import reactor.core.publisher.Flux;

import java.lang.reflect.Type;

import static org.assertj.core.api.Assertions.assertThat;

class TokenStreamToFluxAdapterTest {

interface Assistant {

Flux<String> fluxOfString();

Flux flux();

Flux<Object> fluxOfObject();
}

@Test
void test_canAdapt() {

TokenStreamToFluxAdapter adapter = new TokenStreamToFluxAdapter();

assertThat(adapter.canAdaptTokenStreamTo(getReturnTypeOfMethod("fluxOfString"))).isTrue();

assertThat(adapter.canAdaptTokenStreamTo(getReturnTypeOfMethod("flux"))).isFalse();
assertThat(adapter.canAdaptTokenStreamTo(getReturnTypeOfMethod("fluxOfObject"))).isFalse();
}

private static Type getReturnTypeOfMethod(String methodName) {
try {
return Assistant.class.getDeclaredMethod(methodName).getGenericReturnType();
} catch (NoSuchMethodException e) {
throw new RuntimeException(e);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,4 @@
* this attribute specifies the names of beans containing methods annotated with {@link Tool} that should be used by this AI Service.
*/
String[] tools() default {};

// TODO support Flux return type for AI Service method(s) (for streaming)
}
8 changes: 8 additions & 0 deletions pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@
<module>langchain4j-redis-spring-boot-starter</module>
<module>langchain4j-qianfan-spring-boot-starter</module>
<module>langchain4j-milvus-spring-boot-starter</module>

<module>langchain4j-reactor</module>
</modules>

<properties>
Expand Down Expand Up @@ -56,6 +58,12 @@
<version>${spring.boot.version}</version>
</dependency>

<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-webflux</artifactId>
<version>${spring.boot.version}</version>
</dependency>

<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-autoconfigure-processor</artifactId>
Expand Down

0 comments on commit fd8ed4b

Please sign in to comment.