Skip to content

Commit

Permalink
Configuring ReRankingContentAggregator to Return Top Results (langcha…
Browse files Browse the repository at this point in the history
…in4j#2043)

## Issue
Closes langchain4j#1944 

## Change
Added `maxResults` setting to `ReRankingContentAggregator`

## General checklist
- [ ] There are no breaking changes
- [ ] I have added unit and integration tests for my change
- [ ] I have manually run all the unit and integration tests in the
module I have added/changed, and they are all green
- [ ] I have manually run all the unit and integration tests in the
[core](https://github.com/langchain4j/langchain4j/tree/main/langchain4j-core)
and
[main](https://github.com/langchain4j/langchain4j/tree/main/langchain4j)
modules, and they are all green
- [ ] I have added/updated the
[documentation](https://github.com/langchain4j/langchain4j/tree/main/docs/docs)
- [ ] I have added an example in the [examples
repo](https://github.com/langchain4j/langchain4j-examples) (only for
"big" features)
- [ ] I have added/updated [Spring Boot
starter(s)](https://github.com/langchain4j/langchain4j-spring) (if
applicable)
  • Loading branch information
omarmahamid authored Nov 8, 2024
1 parent 13d024e commit a48bf2c
Show file tree
Hide file tree
Showing 2 changed files with 89 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ public class ReRankingContentAggregator implements ContentAggregator {
private final ScoringModel scoringModel;
private final Function<Map<Query, Collection<List<Content>>>, Query> querySelector;
private final Double minScore;
private final Integer maxResults;

public ReRankingContentAggregator(ScoringModel scoringModel) {
this(scoringModel, DEFAULT_QUERY_SELECTOR, null);
Expand All @@ -70,9 +71,17 @@ public ReRankingContentAggregator(ScoringModel scoringModel) {
public ReRankingContentAggregator(ScoringModel scoringModel,
Function<Map<Query, Collection<List<Content>>>, Query> querySelector,
Double minScore) {
this(scoringModel, querySelector, minScore, null);
}

public ReRankingContentAggregator(ScoringModel scoringModel,
Function<Map<Query, Collection<List<Content>>>, Query> querySelector,
Double minScore,
Integer maxResults) {
this.scoringModel = ensureNotNull(scoringModel, "scoringModel");
this.querySelector = getOrDefault(querySelector, DEFAULT_QUERY_SELECTOR);
this.minScore = minScore;
this.maxResults = getOrDefault(maxResults, Integer.MAX_VALUE);
}

public static ReRankingContentAggregatorBuilder builder() {
Expand Down Expand Up @@ -130,13 +139,15 @@ protected List<Content> reRankAndFilter(List<Content> contents, Query query) {
.sorted(Map.Entry.<TextSegment, Double>comparingByValue().reversed())
.map(Map.Entry::getKey)
.map(Content::from)
.limit(maxResults)
.collect(toList());
}

public static class ReRankingContentAggregatorBuilder {
private ScoringModel scoringModel;
private Function<Map<Query, Collection<List<Content>>>, Query> querySelector;
private Double minScore;
private Integer maxResults;

ReRankingContentAggregatorBuilder() {
}
Expand All @@ -156,8 +167,13 @@ public ReRankingContentAggregatorBuilder minScore(Double minScore) {
return this;
}

public ReRankingContentAggregatorBuilder maxResults(Integer maxResults) {
this.maxResults = maxResults;
return this;
}

public ReRankingContentAggregator build() {
return new ReRankingContentAggregator(this.scoringModel, this.querySelector, this.minScore);
return new ReRankingContentAggregator(this.scoringModel, this.querySelector, this.minScore, this.maxResults);
}

public String toString() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,78 @@ void should_fuse_then_rerank_against_first_query_then_filter_by_minScore() {
.containsExactly(content1, content8, content2);
}

@Test
void test_should_got_max_results() {

// given
Function<Map<Query, Collection<List<Content>>>, Query> querySelector =
(q) -> q.keySet().iterator().next(); // always selects first query

double minScore = 0.4;

Query query1 = Query.from("query 1");

Content content1 = Content.from("content");
Content content2 = Content.from("content 2");

Content content3 = Content.from("content 3");
Content content4 = Content.from("content");

Query query2 = Query.from("query 2");

Content content5 = Content.from("content 5");
Content content6 = Content.from("content");

Content content7 = Content.from("content");
Content content8 = Content.from("content 8");

// LinkedHashMap is used to ensure a predictable order in the test
Map<Query, Collection<List<Content>>> queryToContents = new LinkedHashMap<>();
queryToContents.put(query1, asList(
asList(content1, content2),
asList(content3, content4)

));
queryToContents.put(query2, asList(
asList(content5, content6),
asList(content7, content8)
));

ScoringModel scoringModel = mock(ScoringModel.class);
when(scoringModel.scoreAll(
asList(
content1.textSegment(),
content3.textSegment(),
content5.textSegment(),
content2.textSegment(),
content8.textSegment()
), query1.text())).thenReturn(Response.from(
asList(
0.6,
0.2,
0.3,
0.4,
0.5
)));

ContentAggregator aggregator = ReRankingContentAggregator.builder()
.scoringModel(scoringModel)
.querySelector(querySelector)
.minScore(minScore)
.maxResults(2)
.build();

// when
List<Content> aggregated = aggregator.aggregate(queryToContents);

// then
assertThat(aggregated)
// content4, content6, content7 were fused with content1
// content3 and content5 were filtered out by minScore
// count2 filtered by maxResults
.containsExactly(content1, content8);
}

@ParameterizedTest
@MethodSource
void should_return_empty_list_when_there_is_no_content_to_rerank(
Expand Down

0 comments on commit a48bf2c

Please sign in to comment.