From 82102c99b3f6ff5c9bdb676e19ef7b9ec299a64f Mon Sep 17 00:00:00 2001 From: Abdul Date: Fri, 1 Dec 2023 19:26:16 -0800 Subject: [PATCH] langchain[patch]: Running SQLDatabaseChain adds prefix "SQLQuery:\n" (#14058) - **Issue:** https://github.com/langchain-ai/langchain/issues/12077 --------- Co-authored-by: Abdul Kader Maliyakkal --- libs/experimental/langchain_experimental/sql/base.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/libs/experimental/langchain_experimental/sql/base.py b/libs/experimental/langchain_experimental/sql/base.py index e14989db83771..ce46e8e04a452 100644 --- a/libs/experimental/langchain_experimental/sql/base.py +++ b/libs/experimental/langchain_experimental/sql/base.py @@ -17,6 +17,7 @@ from langchain_experimental.pydantic_v1 import Extra, Field, root_validator INTERMEDIATE_STEPS_KEY = "intermediate_steps" +SQL_QUERY = "SQLQuery:" class SQLDatabaseChain(Chain): @@ -110,7 +111,7 @@ def _call( run_manager: Optional[CallbackManagerForChainRun] = None, ) -> Dict[str, Any]: _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager() - input_text = f"{inputs[self.input_key]}\nSQLQuery:" + input_text = f"{inputs[self.input_key]}\n{SQL_QUERY}" _run_manager.on_text(input_text, verbose=self.verbose) # If not present, then defaults to None which is all tables. table_names_to_use = inputs.get("table_names_to_use") @@ -140,6 +141,8 @@ def _call( sql_cmd ) # output: sql generation (no checker) intermediate_steps.append({"sql_cmd": sql_cmd}) # input: sql exec + if SQL_QUERY in sql_cmd: + sql_cmd = sql_cmd.split(SQL_QUERY)[1].strip() result = self.database.run(sql_cmd) intermediate_steps.append(str(result)) # output: sql exec else: