From c35934e06291a765dec85d2f2a77a64d3973b97a Mon Sep 17 00:00:00 2001
From: Yan Savitski
Date: Wed, 6 Nov 2024 13:06:45 +0100
Subject: [PATCH] [Search] [Playground] [Bug] Remove token clipping (#199055)

- Remove the token pruning functionality, as it has a large cost and causes OOMs on serverless.
- Make gpt-4o the default OpenAI model.
- When the context is over the model limit, show a better error to the user.
- Update tests.

---------

Co-authored-by: Joseph McElroy
---
 .../search_playground/common/models.ts        |  16 +-
 .../public/hooks/use_llms_models.test.ts      |  22 +--
 .../server/lib/conversational_chain.test.ts   | 169 +++++++-----------
 .../server/lib/conversational_chain.ts        |  38 ++--
 .../search_playground/server/lib/errors.ts    |  18 ++
 .../search_playground/server/routes.test.ts   |  25 +++
 .../search_playground/server/routes.ts        |  17 ++
 .../page_objects/search_playground_page.ts    |   2 +-
 8 files changed, 161 insertions(+), 146 deletions(-)
 create mode 100644 x-pack/plugins/search_playground/server/lib/errors.ts

diff --git a/x-pack/plugins/search_playground/common/models.ts b/x-pack/plugins/search_playground/common/models.ts
index ca27c29e20533..85bf5ddfb0970 100644
--- a/x-pack/plugins/search_playground/common/models.ts
+++ b/x-pack/plugins/search_playground/common/models.ts
@@ -8,12 +8,6 @@ import { ModelProvider, LLMs } from './types';
 
 export const MODELS: ModelProvider[] = [
-  {
-    name: 'OpenAI GPT-3.5 Turbo',
-    model: 'gpt-3.5-turbo',
-    promptTokenLimit: 16385,
-    provider: LLMs.openai,
-  },
   {
     name: 'OpenAI GPT-4o',
     model: 'gpt-4o',
@@ -26,6 +20,12 @@ export const MODELS: ModelProvider[] = [
     promptTokenLimit: 128000,
     provider: LLMs.openai,
   },
+  {
+    name: 'OpenAI GPT-3.5 Turbo',
+    model: 'gpt-3.5-turbo',
+    promptTokenLimit: 16385,
+    provider: LLMs.openai,
+  },
   {
     name: 'Anthropic Claude 3 Haiku',
     model: 'anthropic.claude-3-haiku-20240307-v1:0',
@@ -40,13 +40,13 @@ export const MODELS: ModelProvider[] = [
   },
   {
     name: 'Google Gemini 1.5 Pro',
-    model: 'gemini-1.5-pro-001',
+    model: 'gemini-1.5-pro-002',
     promptTokenLimit: 2097152,
     provider: LLMs.gemini,
   },
   {
     name: 'Google Gemini 1.5 Flash',
-    model: 'gemini-1.5-flash-001',
+    model: 'gemini-1.5-flash-002',
     promptTokenLimit: 2097152,
     provider: LLMs.gemini,
   },
diff --git a/x-pack/plugins/search_playground/public/hooks/use_llms_models.test.ts b/x-pack/plugins/search_playground/public/hooks/use_llms_models.test.ts
index ebce3883a471b..c529a9d4b9aa6 100644
--- a/x-pack/plugins/search_playground/public/hooks/use_llms_models.test.ts
+++ b/x-pack/plugins/search_playground/public/hooks/use_llms_models.test.ts
@@ -41,11 +41,11 @@ describe('useLLMsModels Hook', () => {
        connectorType: LLMs.openai,
        disabled: false,
        icon: expect.any(Function),
-       id: 'connectorId1OpenAI GPT-3.5 Turbo ',
-       name: 'OpenAI GPT-3.5 Turbo ',
-       showConnectorName: false,
-       value: 'gpt-3.5-turbo',
-       promptTokenLimit: 16385,
+       id: 'connectorId1OpenAI GPT-4o ',
+       name: 'OpenAI GPT-4o ',
+       showConnectorName: false,
+       value: 'gpt-4o',
+       promptTokenLimit: 128000,
      },
      {
        connectorId: 'connectorId1',
@@ -53,10 +53,10 @@ describe('useLLMsModels Hook', () => {
        connectorType: LLMs.openai,
        disabled: false,
        icon: expect.any(Function),
-       id: 'connectorId1OpenAI GPT-4o ',
-       name: 'OpenAI GPT-4o ',
-       showConnectorName: false,
-       value: 'gpt-4o',
+       id: 'connectorId1OpenAI GPT-4 Turbo ',
+       name: 'OpenAI GPT-4 Turbo ',
+       showConnectorName: false,
+       value: 'gpt-4-turbo',
        promptTokenLimit: 128000,
      },
      {
@@ -65,11 +65,11 @@ describe('useLLMsModels Hook', () => {
        connectorType: LLMs.openai,
        disabled: false,
        icon: expect.any(Function),
-       id: 'connectorId1OpenAI GPT-4 Turbo ',
-       name: 'OpenAI GPT-4 Turbo ',
+       id: 'connectorId1OpenAI GPT-3.5 Turbo ',
+       name: 'OpenAI GPT-3.5 Turbo ',
        showConnectorName: false,
-       value: 'gpt-4-turbo',
-       promptTokenLimit: 128000,
+       value: 'gpt-3.5-turbo',
+       promptTokenLimit: 16385,
      },
      {
        connectorId: 'connectorId2',
diff --git a/x-pack/plugins/search_playground/server/lib/conversational_chain.test.ts b/x-pack/plugins/search_playground/server/lib/conversational_chain.test.ts
index 13959e4455c29..5a56598e7387b 100644
--- a/x-pack/plugins/search_playground/server/lib/conversational_chain.test.ts
+++ b/x-pack/plugins/search_playground/server/lib/conversational_chain.test.ts
@@ -9,9 +9,8 @@ import type { Client } from '@elastic/elasticsearch';
 import { BaseChatModel } from '@langchain/core/language_models/chat_models';
 import { ChatPromptTemplate } from '@langchain/core/prompts';
 import { FakeListChatModel, FakeStreamingLLM } from '@langchain/core/utils/testing';
-import { experimental_StreamData } from 'ai';
 import { createAssist as Assist } from '../utils/assist';
-import { ConversationalChain, clipContext } from './conversational_chain';
+import { ConversationalChain, contextLimitCheck } from './conversational_chain';
 import { ChatMessage, MessageRole } from '../types';
 
 describe('conversational chain', () => {
@@ -30,16 +29,20 @@ describe('conversational chain', () => {
   }: {
     responses: string[];
     chat: ChatMessage[];
-    expectedFinalAnswer: string;
-    expectedDocs: any;
-    expectedTokens: any;
-    expectedSearchRequest: any;
+    expectedFinalAnswer?: string;
+    expectedDocs?: any;
+    expectedTokens?: any;
+    expectedSearchRequest?: any;
     contentField?: Record<string, string>;
     isChatModel?: boolean;
     docs?: any;
     expectedHasClipped?: boolean;
     modelLimit?: number;
   }) => {
+    if (expectedHasClipped) {
+      expect.assertions(1);
+    }
+
     const searchMock = jest.fn().mockImplementation(() => {
       return {
         hits: {
@@ -101,44 +104,52 @@ describe('conversational chain', () => {
       questionRewritePrompt: 'rewrite question {question} using {context}"',
     });
 
-    const stream = await conversationalChain.stream(aiClient, chat);
+    try {
+      const stream = await conversationalChain.stream(aiClient, chat);
 
-    const streamToValue: string[] = await new Promise((resolve, reject) => {
-      const reader = stream.getReader();
-      const textDecoder = new TextDecoder();
-      const chunks: string[] = [];
+      const streamToValue: string[] = await new Promise((resolve, reject) => {
+        const reader = stream.getReader();
+        const textDecoder = new TextDecoder();
+        const chunks: string[] = [];
 
-      const read = () => {
-        reader.read().then(({ done, value }) => {
-          if (done) {
-            resolve(chunks);
-          } else {
-            chunks.push(textDecoder.decode(value));
-            read();
-          }
-        }, reject);
-      };
-      read();
-    });
+        const read = () => {
+          reader.read().then(({ done, value }) => {
+            if (done) {
+              resolve(chunks);
+            } else {
+              chunks.push(textDecoder.decode(value));
+              read();
+            }
+          }, reject);
+        };
+        read();
+      });
 
-    const textValue = streamToValue
-      .filter((v) => v[0] === '0')
-      .reduce((acc, v) => acc + v.replace(/0:"(.*)"\n/, '$1'), '');
-    expect(textValue).toEqual(expectedFinalAnswer);
+      const textValue = streamToValue
+        .filter((v) => v[0] === '0')
+        .reduce((acc, v) => acc + v.replace(/0:"(.*)"\n/, '$1'), '');
+      expect(textValue).toEqual(expectedFinalAnswer);
 
-    const annotations = streamToValue
-      .filter((v) => v[0] === '8')
-      .map((entry) => entry.replace(/8:(.*)\n/, '$1'), '')
-      .map((entry) => JSON.parse(entry))
-      .reduce((acc, v) => acc.concat(v), []);
+      const annotations = streamToValue
+        .filter((v) => v[0] === '8')
+        .map((entry) => entry.replace(/8:(.*)\n/, '$1'), '')
+        .map((entry) => JSON.parse(entry))
+        .reduce((acc, v) => acc.concat(v), []);
 
-    const docValues = annotations.filter((v: { type: string }) => v.type === 'retrieved_docs');
-    const tokens = annotations.filter((v: { type: string }) => v.type.endsWith('_token_count'));
-    const hasClipped = !!annotations.some((v: { type: string }) => v.type === 'context_clipped');
-    expect(docValues).toEqual(expectedDocs);
-    expect(tokens).toEqual(expectedTokens);
-    expect(hasClipped).toEqual(expectedHasClipped);
-    expect(searchMock.mock.calls[0]).toEqual(expectedSearchRequest);
+      const docValues = annotations.filter((v: { type: string }) => v.type === 'retrieved_docs');
+      const tokens = annotations.filter((v: { type: string }) => v.type.endsWith('_token_count'));
+      const hasClipped = !!annotations.some((v: { type: string }) => v.type === 'context_clipped');
+      expect(docValues).toEqual(expectedDocs);
+      expect(tokens).toEqual(expectedTokens);
+      expect(hasClipped).toEqual(expectedHasClipped);
+      expect(searchMock.mock.calls[0]).toEqual(expectedSearchRequest);
+    } catch (error) {
+      if (expectedHasClipped) {
+        expect(error).toMatchInlineSnapshot(`[ContextLimitError: Context exceeds the model limit]`);
+      } else {
+        throw error;
+      }
+    }
   };
 
   it('should be able to create a conversational chain', async () => {
@@ -470,102 +481,56 @@ describe('conversational chain', () => {
        },
      ],
      modelLimit: 100,
-     expectedFinalAnswer: 'the final answer',
-     expectedDocs: [
-       {
-         documents: [
-           {
-             metadata: { _id: '1', _index: 'index' },
-             pageContent: expect.any(String),
-           },
-           {
-             metadata: { _id: '1', _index: 'website' },
-             pageContent: expect.any(String),
-           },
-         ],
-         type: 'retrieved_docs',
-       },
-     ],
-     // Even with body_content of 1000, the token count should be below or equal to model limit of 100
-     expectedTokens: [
-       { type: 'context_token_count', count: 63 },
-       { type: 'prompt_token_count', count: 97 },
-     ],
      expectedHasClipped: true,
-     expectedSearchRequest: [
-       {
-         method: 'POST',
-         path: '/index,website/_search',
-         body: { query: { match: { field: 'rewrite "the" question' } }, size: 3 },
-       },
-     ],
      isChatModel: false,
    });
  }, 10000);
 
-  describe('clipContext', () => {
+  describe('contextLimitCheck', () => {
     const prompt = ChatPromptTemplate.fromTemplate(
       'you are a QA bot {question} {chat_history} {context}'
     );
 
+    afterEach(() => {
+      jest.clearAllMocks();
+    });
+
     it('should return the input as is if modelLimit is undefined', async () => {
       const input = {
         context: 'This is a test context.',
         question: 'This is a test question.',
         chat_history: 'This is a test chat history.',
       };
+      jest.spyOn(prompt, 'format');
+      const result = await contextLimitCheck(undefined, prompt)(input);
 
-      const data = new experimental_StreamData();
-      const appendMessageAnnotationSpy = jest.spyOn(data, 'appendMessageAnnotation');
-
-      const result = await clipContext(undefined, prompt, data)(input);
-      expect(result).toEqual(input);
-      expect(appendMessageAnnotationSpy).not.toHaveBeenCalled();
+      expect(result).toBe(input);
+      expect(prompt.format).not.toHaveBeenCalled();
     });
 
-    it('should not clip context if within modelLimit', async () => {
+    it('should return the input if within modelLimit', async () => {
       const input = {
         context: 'This is a test context.',
         question: 'This is a test question.',
         chat_history: 'This is a test chat history.',
       };
-      const data = new experimental_StreamData();
-      const appendMessageAnnotationSpy = jest.spyOn(data, 'appendMessageAnnotation');
-      const result = await clipContext(10000, prompt, data)(input);
+      jest.spyOn(prompt, 'format');
+      const result = await contextLimitCheck(10000, prompt)(input);
       expect(result).toEqual(input);
-      expect(appendMessageAnnotationSpy).not.toHaveBeenCalled();
+      expect(prompt.format).toHaveBeenCalledWith(input);
     });
 
     it('should clip context if exceeds modelLimit', async () => {
+      expect.assertions(1);
       const input = {
         context: 'This is a test context.\nThis is another line.\nAnd another one.',
         question: 'This is a test question.',
         chat_history: 'This is a test chat history.',
       };
-      const data = new experimental_StreamData();
-      const appendMessageAnnotationSpy = jest.spyOn(data, 'appendMessageAnnotation');
-      const result = await clipContext(33, prompt, data)(input);
-      expect(result.context).toBe('This is a test context.\nThis is another line.');
-      expect(appendMessageAnnotationSpy).toHaveBeenCalledWith({
-        type: 'context_clipped',
-        count: 4,
-      });
-    });
-    it('exit when context becomes empty', async () => {
-      const input = {
-        context: 'This is a test context.\nThis is another line.\nAnd another one.',
-        question: 'This is a test question.',
-        chat_history: 'This is a test chat history.',
-      };
-      const data = new experimental_StreamData();
-      const appendMessageAnnotationSpy = jest.spyOn(data, 'appendMessageAnnotation');
-      const result = await clipContext(1, prompt, data)(input);
-      expect(result.context).toBe('');
-      expect(appendMessageAnnotationSpy).toHaveBeenCalledWith({
-        type: 'context_clipped',
-        count: 15,
-      });
+      await expect(contextLimitCheck(33, prompt)(input)).rejects.toMatchInlineSnapshot(
+        `[ContextLimitError: Context exceeds the model limit]`
+      );
     });
   });
 });
diff --git a/x-pack/plugins/search_playground/server/lib/conversational_chain.ts b/x-pack/plugins/search_playground/server/lib/conversational_chain.ts
index 922f672bda5c6..dcd1f4189bc75 100644
--- a/x-pack/plugins/search_playground/server/lib/conversational_chain.ts
+++ b/x-pack/plugins/search_playground/server/lib/conversational_chain.ts
@@ -25,6 +25,7 @@ import { renderTemplate } from '../utils/render_template';
 import { AssistClient } from '../utils/assist';
 import { getCitations } from '../utils/get_citations';
 import { getTokenEstimate, getTokenEstimateFromMessages } from './token_tracking';
+import { ContextLimitError } from './errors';
 
 interface RAGOptions {
   index: string;
@@ -88,37 +89,26 @@ position: ${i + 1}
   return serializedDocs.join('\n');
 };
 
-export function clipContext(
+export function contextLimitCheck(
   modelLimit: number | undefined,
-  prompt: ChatPromptTemplate,
-  data: experimental_StreamData
+  prompt: ChatPromptTemplate
 ): (input: ContextInputs) => Promise<ContextInputs> {
   return async (input) => {
     if (!modelLimit) return input;
 
-    let context = input.context;
-    const clippedContext = [];
-    while (
-      getTokenEstimate(await prompt.format({ ...input, context })) > modelLimit &&
-      context.length > 0
-    ) {
-      // remove the last paragraph
-      const lines = context.split('\n');
-      clippedContext.push(lines.pop());
-      context = lines.join('\n');
-    }
+    const stringPrompt = await prompt.format(input);
+    const approxPromptTokens = getTokenEstimate(stringPrompt);
+    const aboveContextLimit = approxPromptTokens > modelLimit;
 
-    if (clippedContext.length > 0) {
-      data.appendMessageAnnotation({
-        type: 'context_clipped',
-        count: getTokenEstimate(clippedContext.join('\n')),
-      });
+    if (aboveContextLimit) {
+      throw new ContextLimitError(
+        'Context exceeds the model limit',
+        modelLimit,
+        approxPromptTokens
+      );
     }
 
-    return {
-      ...input,
-      context,
-    };
+    return input;
   };
 }
 
@@ -205,7 +195,7 @@ class ConversationalChainFn {
         });
         return inputs;
       }),
-      RunnableLambda.from(clipContext(this.options?.rag?.inputTokensLimit, prompt, data)),
+      RunnableLambda.from(contextLimitCheck(this.options?.rag?.inputTokensLimit, prompt)),
       RunnableLambda.from(registerContextTokenCounts(data)),
       prompt,
       this.options.model.withConfig({ metadata: { type: 'question_answer_qa' } }),
diff --git a/x-pack/plugins/search_playground/server/lib/errors.ts b/x-pack/plugins/search_playground/server/lib/errors.ts
new file mode 100644
index 0000000000000..38441b607a64a
--- /dev/null
+++ b/x-pack/plugins/search_playground/server/lib/errors.ts
@@ -0,0 +1,18 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+export class ContextLimitError extends Error {
+  public modelLimit: number;
+  public currentTokens: number;
+
+  constructor(message: string, modelLimit: number, currentTokens: number) {
+    super(message);
+    this.name = 'ContextLimitError';
+    this.modelLimit = modelLimit;
+    this.currentTokens = currentTokens;
+  }
+}
diff --git a/x-pack/plugins/search_playground/server/routes.test.ts b/x-pack/plugins/search_playground/server/routes.test.ts
index 018b1420a46cf..fca1adab5862b 100644
--- a/x-pack/plugins/search_playground/server/routes.test.ts
+++ b/x-pack/plugins/search_playground/server/routes.test.ts
@@ -12,6 +12,7 @@ import { MockRouter } from '../__mocks__/router.mock';
 import { ConversationalChain } from './lib/conversational_chain';
 import { getChatParams } from './lib/get_chat_params';
 import { createRetriever, defineRoutes } from './routes';
+import { ContextLimitError } from './lib/errors';
 
 jest.mock('./lib/get_chat_params', () => ({
   getChatParams: jest.fn(),
 }));
@@ -100,5 +101,29 @@ describe('Search Playground routes', () => {
         },
       });
     });
+
+    it('responds with context error message if there is ContextLimitError', async () => {
+      (getChatParams as jest.Mock).mockResolvedValue({ model: 'open-ai' });
+      (ConversationalChain as jest.Mock).mockImplementation(() => {
+        return {
+          stream: jest
+            .fn()
+            .mockRejectedValue(
+              new ContextLimitError('Context exceeds the model limit', 16385, 24000)
+            ),
+        };
+      });
+
+      await mockRouter.callRoute({
+        body: mockRequestBody,
+      });
+
+      expect(mockRouter.response.badRequest).toHaveBeenCalledWith({
+        body: {
+          message:
+            "Your request uses 24000 input tokens. This exceeds the model token limit of 16385 tokens. Please try using a different model that's capable of accepting larger prompts or reducing the prompt by decreasing the size of the context documents. If you are unsure, please see our documentation.",
+        },
+      });
+    });
   });
 });
diff --git a/x-pack/plugins/search_playground/server/routes.ts b/x-pack/plugins/search_playground/server/routes.ts
index c26a342aace49..3cdebe11c02c2 100644
--- a/x-pack/plugins/search_playground/server/routes.ts
+++ b/x-pack/plugins/search_playground/server/routes.ts
@@ -8,6 +8,7 @@ import { schema } from '@kbn/config-schema';
 import type { Logger } from '@kbn/logging';
 import { IRouter, StartServicesAccessor } from '@kbn/core/server';
+import { i18n } from '@kbn/i18n';
 import { sendMessageEvent, SendMessageEventData } from './analytics/events';
 import { fetchFields } from './lib/fetch_query_source_fields';
 import { AssistClientOptionsWithClient, createAssist as Assist } from './utils/assist';
@@ -23,6 +24,7 @@ import { getChatParams } from './lib/get_chat_params';
 import { fetchIndices } from './lib/fetch_indices';
 import { isNotNullish } from '../common/is_not_nullish';
 import { MODELS } from '../common/models';
+import { ContextLimitError } from './lib/errors';
 
 export function createRetriever(esQuery: string) {
   return (question: string) => {
@@ -157,6 +159,21 @@ export function defineRoutes({
           isCloud: cloud?.isCloudEnabled ?? false,
         });
       } catch (e) {
+        if (e instanceof ContextLimitError) {
+          return response.badRequest({
+            body: {
+              message: i18n.translate(
+                'xpack.searchPlayground.serverErrors.exceedsModelTokenLimit',
+                {
+                  defaultMessage:
+                    "Your request uses {approxPromptTokens} input tokens. This exceeds the model token limit of {modelLimit} tokens. Please try using a different model that's capable of accepting larger prompts or reducing the prompt by decreasing the size of the context documents. If you are unsure, please see our documentation.",
+                  values: { modelLimit: e.modelLimit, approxPromptTokens: e.currentTokens },
+                }
+              ),
+            },
+          });
+        }
+
         logger.error('Failed to create the chat stream', e);
 
         if (typeof e === 'object') {
diff --git a/x-pack/test/functional/page_objects/search_playground_page.ts b/x-pack/test/functional/page_objects/search_playground_page.ts
index 9b44addce9e25..3a47da067097f 100644
--- a/x-pack/test/functional/page_objects/search_playground_page.ts
+++ b/x-pack/test/functional/page_objects/search_playground_page.ts
@@ -146,7 +146,7 @@ export function SearchPlaygroundPageProvider({ getService }: FtrProviderContext)
       const model = await testSubjects.find('summarizationModelSelect');
       const defaultModel = await model.getVisibleText();
 
-      expect(defaultModel).to.equal('OpenAI GPT-3.5 Turbo');
+      expect(defaultModel).to.equal('OpenAI GPT-4o');
       expect(defaultModel).not.to.be.empty();
 
       expect(