[DERCBOT-1037] Use of PromptTemplate And Rewrite the RAG chain using LCEL #1772

Open · wants to merge 3 commits into base: master
12 changes: 8 additions & 4 deletions bot/engine/src/main/kotlin/engine/config/RAGAnswerHandler.kt
@@ -181,10 +181,14 @@ object RAGAnswerHandler : AbstractProactiveAnswerHandler {
         query = RAGQuery(
             history = getDialogHistory(dialog),
             questionAnsweringLlmSetting = ragConfiguration.llmSetting,
-            questionAnsweringPromptInputs = mapOf(
-                "question" to action.toString(),
-                "locale" to userPreferences.locale.displayLanguage,
-                "no_answer" to ragConfiguration.noAnswerSentence
-            ),
+            questionAnsweringPrompt = PromptTemplate(
+                formatter = Formatter.F_STRING.id,
+                template = ragConfiguration.llmSetting.prompt,
+                inputs = mapOf(
+                    "question" to action.toString(),
+                    "locale" to userPreferences.locale.displayLanguage,
+                    "no_answer" to ragConfiguration.noAnswerSentence
+                )
+            ),
             embeddingQuestionEmSetting = ragConfiguration.emSetting,
             documentIndexName = indexName,
@@ -23,10 +23,10 @@ import ai.tock.genai.orchestratorcore.models.vectorstore.VectorStoreSetting

 data class RAGQuery(
     // val condenseQuestionLlmSetting: LLMSetting,
-    // val condenseQuestionPromptInputs: Map<String, String>,
+    // val condenseQuestionPrompt: PromptTemplate,
     val history: List<ChatMessage> = emptyList(),
     val questionAnsweringLlmSetting: LLMSetting,
-    val questionAnsweringPromptInputs: Map<String, String>,
+    val questionAnsweringPrompt: PromptTemplate,
     val embeddingQuestionEmSetting: EMSetting,
     val documentIndexName: String,
     val documentSearchParams: DocumentSearchParamsBase,
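The `PromptTemplate` type that replaces the raw prompt-inputs map is referenced on both the Kotlin and Python sides but its definition is not part of this excerpt. A minimal Pydantic sketch of the shape implied by the `formatter`/`template`/`inputs` usage above; the field names come from the diff, the types and defaults are assumptions:

```python
from typing import Any

from pydantic import BaseModel, Field


class PromptTemplate(BaseModel):
    """A prompt template: the template text, the syntax it is written in,
    and the key-value inputs used to render it."""

    formatter: str = Field(description="Template syntax, e.g. 'f-string' or 'jinja2'.")
    template: str = Field(description='The raw template text.', min_length=1)
    inputs: dict[str, Any] = Field(
        default_factory=dict, description='Values substituted into the template.'
    )
```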
1,366 changes: 691 additions & 675 deletions gen-ai/orchestrator-server/src/main/python/server/poetry.lock

Large diffs are not rendered by default.

22 changes: 11 additions & 11 deletions gen-ai/orchestrator-server/src/main/python/server/pyproject.toml
@@ -8,23 +8,23 @@ packages = [{include = "gen_ai_orchestrator", from = "src"}]

 [tool.poetry.dependencies]
 python = "^3.10"
-uvicorn = "^0.31.1"
-pydantic-settings="^2.5.2"
-fastapi = "^0.115.0"
-langchain = "^0.3.3"
-langchain-community = "^0.3.2"
-langchain-openai = "^0.2.2"
+uvicorn = "^0.32.0"
+pydantic-settings="^2.6.0"
+fastapi = "^0.115.3"
+langchain = "^0.3.4"
+langchain-community = "^0.3.3"
+langchain-openai = "^0.2.3"
 tiktoken = "^0.8.0"
 opensearch-py = "^2.7.1"
 path = "^17.0.0"
 colorlog = "^6.8.2"
-boto3 = "^1.35.37"
+boto3 = "^1.35.48"
 urllib3 = "^2.2.3"
 jinja2 = "^3.1.4"
-langfuse = "^2.52.0"
+langfuse = "^2.52.2"
 httpx-auth-awssigv4 = "^0.1.4"
 langchain-postgres = "^0.0.12"
-google-cloud-secret-manager = "^2.20.2"
+google-cloud-secret-manager = "^2.21.0"
 psycopg = {extras = ["binary"], version = "^3.2.3"}

@@ -35,8 +35,8 @@ bandit = {version = "^1.7.7", extras = ["json"]}
 cyclonedx-bom = "^4.1.4"

 [tool.poetry.group.test.dependencies]
-tox = "^4.11.4"
-coverage = "^7.4.0"
+tox = "^4.23.2"
+coverage = "^7.6.4"
 pytest = "^7.4.4"
 pytest-asyncio = "^0.23.6"

@@ -37,8 +37,3 @@ class BaseLLMSetting(BaseModel):
         ge=0,
         le=2,
     )
-    prompt: str = Field(
-        description='The prompt to generate completions for.',
-        examples=['How to learn to ride a bike without wheels!'],
-        min_length=1,
-    )
@@ -131,20 +131,17 @@ class RagQuery(BaseQuery):
     history: list[ChatMessage] = Field(
         description="Conversation history, used to reformulate the user's question."
     )
-    question_answering_prompt_inputs: Any = Field(
-        description='Key-value inputs for the llm prompt when used as a template. Please note that the '
-        'chat_history field must not be specified here, it will be override by the history field',
-    )
     # condense_question_llm_setting: LLMSetting =
     # Field(description="LLM setting, used to condense the user's question.")
-    # condense_question_prompt_inputs: Any = (
-    #     Field(
-    #         description='Key-value inputs for the condense question llm prompt, when used as a template.',
-    #     ),
+    # condense_question_prompt: PromptTemplate = Field(
+    #     description='Prompt template, used to create a prompt with inputs for jinja and f-string formats'
+    # )
     question_answering_llm_setting: LLMSetting = Field(
         description='LLM setting, used to perform a QA Prompt.'
     )
+    question_answering_prompt: PromptTemplate = Field(
+        description='Prompt template, used to create a prompt with inputs for jinja and f-string formats'
+    )

     model_config = {
         'json_schema_extra': {
@@ -164,7 +164,11 @@
                     'value': 'ab7***************************A1IV4B',
                 },
                 'temperature': 1.2,
-                'prompt': """Use the following context to answer the question at the end.
+                'model': 'gpt-3.5-turbo',
+            },
+            'question_answering_prompt': {
+                'formatter': 'f-string',
+                'template': """Use the following context to answer the question at the end.
 If you don't know the answer, just say {no_answer}.

 Context:
@@ -174,12 +174,11 @@
 {question}

 Answer in {locale}:""",
-                'model': 'gpt-3.5-turbo',
-            },
-            'question_answering_prompt_inputs': {
-                'question': 'How to get started playing guitar ?',
-                'no_answer': "Sorry, I don't know.",
-                'locale': 'French',
+                'inputs': {
+                    'question': 'How to get started playing guitar ?',
+                    'no_answer': "Sorry, I don't know.",
+                    'locale': 'French',
+                }
             },
             'embedding_question_em_setting': {
                 'provider': 'OpenAI',
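With the template and its inputs now travelling together, the orchestrator can defer rendering to LangChain, whose `PromptTemplate.from_template` accepts both 'f-string' and 'jinja2' through its `template_format` argument (the class is imported as `LangChainPromptTemplate` in the service diff below). A sketch under that assumption; `render_prompt` is illustrative, not a helper from this PR:

```python
from langchain_core.prompts import PromptTemplate as LangChainPromptTemplate


def render_prompt(formatter: str, template: str, inputs: dict) -> str:
    # template_format selects the parsing syntax: 'f-string' or 'jinja2'.
    lc_template = LangChainPromptTemplate.from_template(template, template_format=formatter)
    return lc_template.format(**inputs)


rendered = render_prompt(
    formatter='f-string',
    template='Answer in {locale}: {question}',
    inputs={'locale': 'French', 'question': 'How to get started playing guitar ?'},
)
```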
@@ -16,23 +16,14 @@

 import logging
 import time
-from typing import Optional

-from jinja2 import Template, TemplateError
 from langchain_core.output_parsers import NumberedListOutputParser
 from langchain_core.prompts import PromptTemplate as LangChainPromptTemplate
 from langchain_core.runnables import RunnableConfig

-from gen_ai_orchestrator.errors.exceptions.exceptions import (
-    GenAIPromptTemplateException,
-)
 from gen_ai_orchestrator.errors.handlers.openai.openai_exception_handler import (
     openai_exception_handler,
 )
-from gen_ai_orchestrator.models.errors.errors_models import ErrorInfo
 from gen_ai_orchestrator.models.observability.observability_trace import ObservabilityTrace
-from gen_ai_orchestrator.models.prompt.prompt_formatter import PromptFormatter
-from gen_ai_orchestrator.models.prompt.prompt_template import PromptTemplate
 from gen_ai_orchestrator.routers.requests.requests import (
     SentenceGenerationQuery,
 )
@@ -42,6 +33,7 @@
 from gen_ai_orchestrator.services.langchain.factories.langchain_factory import (
     get_llm_factory, create_observability_callback_handler,
 )
+from gen_ai_orchestrator.services.utils.prompt_utility import validate_prompt_template

 logger = logging.getLogger(__name__)

@@ -90,29 +82,3 @@ async def generate_and_split_sentences(
     )

     return SentenceGenerationResponse(sentences=sentences)
-
-
-def validate_prompt_template(prompt: PromptTemplate):
-    """
-    Prompt template validation
-
-    Args:
-        prompt: The prompt template
-
-    Returns:
-        Nothing.
-    Raises:
-        GenAIPromptTemplateException: if template is incorrect
-    """
-    if PromptFormatter.JINJA2 == prompt.formatter:
-        try:
-            Template(prompt.template).render(prompt.inputs)
-        except TemplateError as exc:
-            logger.error('Prompt completion - template validation failed!')
-            logger.error(exc)
-            raise GenAIPromptTemplateException(
-                ErrorInfo(
-                    error=exc.__class__.__name__,
-                    cause=str(exc),
-                )
-            )
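The function removed above is imported earlier in this same file from `gen_ai_orchestrator.services.utils.prompt_utility`, so it presumably moves there unchanged. Reassembled from the removed lines and their removed imports, the relocated module would look roughly like this:

```python
# Sketch of gen_ai_orchestrator/services/utils/prompt_utility.py
# (path taken from the import added in this PR; module contents inferred
# from the function body removed above).
import logging

from jinja2 import Template, TemplateError

from gen_ai_orchestrator.errors.exceptions.exceptions import GenAIPromptTemplateException
from gen_ai_orchestrator.models.errors.errors_models import ErrorInfo
from gen_ai_orchestrator.models.prompt.prompt_formatter import PromptFormatter
from gen_ai_orchestrator.models.prompt.prompt_template import PromptTemplate

logger = logging.getLogger(__name__)


def validate_prompt_template(prompt: PromptTemplate):
    """Validate a prompt template by rendering it once, so syntax errors
    surface before the chain runs.

    Raises:
        GenAIPromptTemplateException: if the Jinja2 template is invalid.
    """
    if PromptFormatter.JINJA2 == prompt.formatter:
        try:
            Template(prompt.template).render(prompt.inputs)
        except TemplateError as exc:
            logger.error('Prompt completion - template validation failed!')
            logger.error(exc)
            raise GenAIPromptTemplateException(
                ErrorInfo(error=exc.__class__.__name__, cause=str(exc))
            )
```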
@@ -0,0 +1,61 @@
# Copyright (C) 2023-2024 Credit Mutuel Arkea
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Retriever callback handler for LangChain."""

import logging
from typing import Any, Dict, Optional

from langchain.callbacks.base import BaseCallbackHandler
from langchain_core.messages import SystemMessage, AIMessage
from langchain_core.prompt_values import ChatPromptValue, StringPromptValue

logger = logging.getLogger(__name__)


class RAGCallbackHandler(BaseCallbackHandler):
"""Customized RAG callback handler that retrieves data from the chain execution."""

records: Dict[str, Any] = {
'chat_prompt': None,
'chat_chain_output': None,
'rag_prompt': None,
'rag_chain_output': None,
'documents': None,
}

def on_chain_start(
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
) -> None:
"""Print out that we are entering a chain."""

if kwargs['name'] == 'chat_chain_output' and isinstance(inputs, AIMessage):
self.records['chat_chain_output'] = inputs.content

if kwargs['name'] == 'rag_chain_output' and isinstance(inputs, AIMessage):
self.records['rag_chain_output'] = inputs.content

if kwargs['name'] == 'RunnableAssign<answer>' and 'documents' in inputs:
self.records['documents'] = inputs['documents']

def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
"""Print out that we finished a chain.""" # if outputs is instance of StringPromptValue

if isinstance(outputs, ChatPromptValue):
self.records['chat_prompt'] = next(
(msg.content for msg in outputs.messages if isinstance(msg, SystemMessage)), None
)

if isinstance(outputs, StringPromptValue):
self.records['rag_prompt'] = outputs.text
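The rewritten LCEL chain itself is not part of this excerpt, but the step names the handler matches on (`chat_chain_output`, `rag_chain_output`, `RunnableAssign<answer>`) hint at how it is meant to be wired. A minimal, self-contained sketch of attaching the handler to an LCEL pipeline, with a fake LLM and illustrative step names:

```python
from langchain_core.messages import AIMessage
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnableLambda

# Stand-in for a chat model so the sketch runs without credentials.
fake_llm = RunnableLambda(lambda prompt_value: AIMessage(content='Voici la réponse.'))

# Naming a step 'rag_chain_output' is what lets the handler match kwargs['name']
# in on_chain_start and capture the AIMessage flowing into that step.
chain = (
    PromptTemplate.from_template('Answer in {locale}: {question}')
    | fake_llm
    | RunnableLambda(lambda message: message).with_config(run_name='rag_chain_output')
)

handler = RAGCallbackHandler()
chain.invoke(
    {'locale': 'French', 'question': 'How to get started playing guitar ?'},
    config={'callbacks': [handler]},
)

print(handler.records['rag_prompt'])        # rendered prompt text (StringPromptValue)
print(handler.records['rag_chain_output'])  # content of the AIMessage
```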