diff --git a/container-search/abi-spec.json b/container-search/abi-spec.json index 80fe11b7174b..8fbf12b16b41 100644 --- a/container-search/abi-spec.json +++ b/container-search/abi-spec.json @@ -9227,7 +9227,7 @@ "methods" : [ "public void (ai.vespa.search.llm.LlmSearcherConfig, com.yahoo.component.provider.ComponentRegistry)", "public com.yahoo.search.Result search(com.yahoo.search.Query, com.yahoo.search.searchchain.Execution)", - "protected com.yahoo.search.Result complete(com.yahoo.search.Query, ai.vespa.llm.completion.Prompt)", + "protected com.yahoo.search.Result complete(com.yahoo.search.Query, ai.vespa.llm.completion.Prompt, com.yahoo.search.Result, com.yahoo.search.searchchain.Execution)", "public java.lang.String getPrompt(com.yahoo.search.Query)", "public java.lang.String getPropertyPrefix()", "public java.lang.String lookupProperty(java.lang.String, com.yahoo.search.Query)", diff --git a/container-search/src/main/java/ai/vespa/search/llm/LLMSearcher.java b/container-search/src/main/java/ai/vespa/search/llm/LLMSearcher.java index 4c39506ed961..d0d2cd3a4429 100755 --- a/container-search/src/main/java/ai/vespa/search/llm/LLMSearcher.java +++ b/container-search/src/main/java/ai/vespa/search/llm/LLMSearcher.java @@ -14,11 +14,14 @@ import com.yahoo.search.Query; import com.yahoo.search.Result; import com.yahoo.search.Searcher; +import com.yahoo.search.rendering.JsonRenderer; import com.yahoo.search.result.ErrorMessage; import com.yahoo.search.result.EventStream; import com.yahoo.search.result.HitGroup; import com.yahoo.search.searchchain.Execution; +import com.yahoo.text.Utf8; +import java.io.ByteArrayOutputStream; import java.util.List; import java.util.concurrent.RejectedExecutionException; import java.util.function.Function; @@ -38,6 +41,10 @@ public class LLMSearcher extends Searcher { private static final String API_KEY_HEADER = "X-LLM-API-KEY"; private static final String STREAM_PROPERTY = "stream"; private static final String PROMPT_PROPERTY = "prompt"; + private static final String INCLUDE_PROMPT_IN_RESULT = "includePrompt"; + private static final String INCLUDE_HITS_IN_RESULT = "includeHits"; + + private final JsonRenderer jsonRenderer; private final String propertyPrefix; private final boolean stream; @@ -50,11 +57,13 @@ public LLMSearcher(LlmSearcherConfig config, ComponentRegistry la this.languageModelId = config.providerId(); this.languageModel = findLanguageModel(languageModelId, languageModels); this.propertyPrefix = config.propertyPrefix(); + + this.jsonRenderer = new JsonRenderer(); } @Override public Result search(Query query, Execution execution) { - return complete(query, StringPrompt.from(getPrompt(query))); + return complete(query, StringPrompt.from(getPrompt(query)), null, execution); } private LanguageModel findLanguageModel(String providerId, ComponentRegistry languageModels) @@ -81,30 +90,37 @@ private LanguageModel findLanguageModel(String providerId, ComponentRegistry lookupProperty(s, query)); var stream = lookupPropertyBool(STREAM_PROPERTY, query, this.stream); // query value overwrites config try { - return stream ? completeAsync(query, prompt, options) : completeSync(query, prompt, options); + if (stream) { + return completeAsync(query, prompt, options, result, execution); + } + return completeSync(query, prompt, options, result, execution); } catch (RejectedExecutionException e) { return new Result(query, new ErrorMessage(429, e.getMessage())); } } private boolean shouldAddPrompt(Query query) { - return query.getTrace().getLevel() >= 1; + var includePrompt = lookupPropertyBool(INCLUDE_PROMPT_IN_RESULT, query, false); + return query.getTrace().getLevel() >= 1 || includePrompt; } private boolean shouldAddTokenStats(Query query) { return query.getTrace().getLevel() >= 1; } - private Result completeAsync(Query query, Prompt prompt, InferenceParameters options) { + private Result completeAsync(Query query, Prompt prompt, InferenceParameters options, Result result, Execution execution) { final EventStream eventStream = new EventStream(); if (shouldAddPrompt(query)) { eventStream.add(prompt.asString(), "prompt"); } + if (shouldAddHits(query) && result != null) { + eventStream.add(renderHits(result, execution), "hits"); + } final TokenStats tokenStats = new TokenStats(); languageModel.completeAsync(prompt, options, completion -> { @@ -143,12 +159,15 @@ private void handleException(EventStream eventStream, Throwable exception) { eventStream.error(languageModelId, new ErrorMessage(errorCode, exception.getMessage())); } - private Result completeSync(Query query, Prompt prompt, InferenceParameters options) { + private Result completeSync(Query query, Prompt prompt, InferenceParameters options, Result result, Execution execution) { EventStream eventStream = new EventStream(); if (shouldAddPrompt(query)) { eventStream.add(prompt.asString(), "prompt"); } + if (shouldAddHits(query) && result != null) { + eventStream.add(renderHits(result, execution), "hits"); + } List completions = languageModel.complete(prompt, options); eventStream.add(completions.get(0).text(), "completion"); @@ -200,6 +219,18 @@ public String getApiKeyHeader(Query query) { return lookupPropertyWithOrWithoutPrefix(API_KEY_HEADER, p -> query.getHttpRequest().getHeader(p)); } + private boolean shouldAddHits(Query query) { + return lookupPropertyBool(INCLUDE_HITS_IN_RESULT, query, false); + } + + private String renderHits(Result results, Execution execution) { + var bs = new ByteArrayOutputStream(); + var renderer = jsonRenderer.clone(); + renderer.init(); + renderer.renderResponse(bs, results, execution, null).join(); // wait for renderer to complete + return Utf8.toString(bs.toByteArray()); + } + private static class TokenStats { private final long start; diff --git a/container-search/src/main/java/ai/vespa/search/llm/RAGSearcher.java b/container-search/src/main/java/ai/vespa/search/llm/RAGSearcher.java index cba153d881d6..cdf57922bce2 100755 --- a/container-search/src/main/java/ai/vespa/search/llm/RAGSearcher.java +++ b/container-search/src/main/java/ai/vespa/search/llm/RAGSearcher.java @@ -37,7 +37,7 @@ public RAGSearcher(LlmSearcherConfig config, ComponentRegistry la public Result search(Query query, Execution execution) { Result result = execution.search(query); execution.fill(result); - return complete(query, buildPrompt(query, result)); + return complete(query, buildPrompt(query, result), result, execution); } protected Prompt buildPrompt(Query query, Result result) { diff --git a/container-search/src/main/java/com/yahoo/search/rendering/EventRenderer.java b/container-search/src/main/java/com/yahoo/search/rendering/EventRenderer.java index 88a1e6c1485e..ffbb63514f1b 100644 --- a/container-search/src/main/java/com/yahoo/search/rendering/EventRenderer.java +++ b/container-search/src/main/java/com/yahoo/search/rendering/EventRenderer.java @@ -79,13 +79,16 @@ public void data(Data data) throws IOException { generator.writeRaw("event: " + event.type() + "\n"); } generator.writeRaw("data: "); - generator.writeStartObject(); - generator.writeStringField(event.type(), event.toString()); - generator.writeEndObject(); + if (event.type().equals("hits")) { + generator.writeRaw(event.toString()); + } else { + generator.writeStartObject(); + generator.writeStringField(event.type(), event.toString()); + generator.writeEndObject(); + } generator.writeRaw("\n\n"); generator.flush(); } - // Todo: support other types of data such as search results (hits), timing and trace } @Override diff --git a/container-search/src/test/java/ai/vespa/search/llm/RAGSearcherTest.java b/container-search/src/test/java/ai/vespa/search/llm/RAGSearcherTest.java index d6b66b1a8c6f..13b5f540a3af 100755 --- a/container-search/src/test/java/ai/vespa/search/llm/RAGSearcherTest.java +++ b/container-search/src/test/java/ai/vespa/search/llm/RAGSearcherTest.java @@ -115,7 +115,7 @@ static Result runMockSearch(Searcher searcher, Map parameters) { return execution.search(query); } - private static Searcher createRAGSearcher(Map llms) { + static Searcher createRAGSearcher(Map llms) { var config = new LlmSearcherConfig.Builder().stream(false).build(); ComponentRegistry models = new ComponentRegistry<>(); llms.forEach((key, value) -> models.register(ComponentId.fromString(key), value)); diff --git a/container-search/src/test/java/ai/vespa/search/llm/RAGWithEventRendererTest.java b/container-search/src/test/java/ai/vespa/search/llm/RAGWithEventRendererTest.java new file mode 100644 index 000000000000..40aba0d5b6ac --- /dev/null +++ b/container-search/src/test/java/ai/vespa/search/llm/RAGWithEventRendererTest.java @@ -0,0 +1,77 @@ +package ai.vespa.search.llm; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.yahoo.search.Result; +import com.yahoo.search.rendering.EventRenderer; +import com.yahoo.search.searchchain.Execution; +import com.yahoo.text.Utf8; +import org.junit.jupiter.api.Test; + +import java.io.ByteArrayOutputStream; +import java.util.Map; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class RAGWithEventRendererTest { + + @Test + public void testPromptAndHitsAreRendered() throws Exception { + var params = Map.of( + "query", "why are ducks better than cats?", + "llm.stream", "false", + "llm.includePrompt", "true", + "llm.includeHits", "true" + ); + var llm = LLMSearcherTest.createLLMClient(); + var searcher = RAGSearcherTest.createRAGSearcher(Map.of("mock", llm)); + var results = RAGSearcherTest.runMockSearch(searcher, params); + + var result = render(results); + + var promptEvent = extractEvent(result, "prompt"); + assertNotNull(promptEvent); + assertTrue(promptEvent.has("prompt")); + + var resultsEvent = extractEvent(result, "hits"); + assertNotNull(resultsEvent); + assertTrue(resultsEvent.has("root")); + assertEquals(2, resultsEvent.get("root").get("children").size()); + } + + private JsonNode extractEvent(String result, String eventName) throws JsonProcessingException { + var lines = result.split("\n"); + for (int i = 0; i < lines.length; i++) { + if (lines[i].startsWith("event: " + eventName)) { + var data = lines[i + 1].substring("data: ".length()).trim(); + ObjectMapper objectMapper = new ObjectMapper(); + return objectMapper.readTree(data); + } + } + return null; + } + + private String render(Result r) throws InterruptedException, ExecutionException { + var execution = new Execution(Execution.Context.createContextStub()); + return render(execution, r); + } + + private String render(Execution execution, Result r) throws ExecutionException, InterruptedException { + var renderer = new EventRenderer(); + try { + renderer.init(); + ByteArrayOutputStream bs = new ByteArrayOutputStream(); + CompletableFuture f = renderer.renderResponse(bs, r, execution, null); + assertTrue(f.get()); + return Utf8.toString(bs.toByteArray()); + } finally { + renderer.deconstruct(); + } + } + +} diff --git a/container-search/src/test/java/com/yahoo/search/rendering/EventRendererTestCase.java b/container-search/src/test/java/com/yahoo/search/rendering/EventRendererTestCase.java index 2cfb6552379d..f6f6f40bdaeb 100644 --- a/container-search/src/test/java/com/yahoo/search/rendering/EventRendererTestCase.java +++ b/container-search/src/test/java/com/yahoo/search/rendering/EventRendererTestCase.java @@ -232,7 +232,7 @@ public void testResultRenderingIsSkipped() throws ExecutionException, Interrupte event: end """; - assertEquals(expected, result); // Todo: support other types of data such as search results (hits), timing and trace + assertEquals(expected, result); } static HitGroup newHitGroup(EventStream eventStream, String id) {