Skip to content

Commit

Permalink
fmt
Browse files Browse the repository at this point in the history
  • Loading branch information
baskaryan committed Jan 18, 2024
1 parent 5238461 commit 8b368a7
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 10 deletions.
10 changes: 5 additions & 5 deletions docs/docs/use_cases/sql/large_db.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@
},
{
"cell_type": "code",
"execution_count": 49,
"execution_count": 59,
"id": "0ccb0bf5-c580-428f-9cde-a58772ae784e",
"metadata": {},
"outputs": [
Expand All @@ -202,7 +202,7 @@
"[Table(name='Music')]"
]
},
"execution_count": 49,
"execution_count": 59,
"metadata": {},
"output_type": "execute_result"
}
Expand All @@ -219,7 +219,7 @@
},
{
"cell_type": "code",
"execution_count": 50,
"execution_count": 60,
"id": "ae4899fc-6f8a-4b10-983c-9e3fef4a7bb9",
"metadata": {},
"outputs": [
Expand All @@ -229,7 +229,7 @@
"['Album', 'Artist', 'Genre', 'MediaType', 'Playlist', 'PlaylistTrack', 'Track']"
]
},
"execution_count": 50,
"execution_count": 60,
"metadata": {},
"output_type": "execute_result"
}
Expand Down Expand Up @@ -258,7 +258,7 @@
" return tables\n",
"\n",
"\n",
"table_chain = category_chain | get_tables\n",
"table_chain = category_chain | get_tables # noqa\n",
"table_chain.invoke({\"input\": \"What are all the genres of Alanis Morisette songs\"})"
]
},
Expand Down
16 changes: 11 additions & 5 deletions libs/langchain/langchain/chains/sql_database/query.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from typing import List, Optional, TypedDict, Union
from typing import Any, Dict, List, Optional, TypedDict, Union

from langchain_community.utilities.sql_database import SQLDatabase
from langchain_core.language_models import BaseLanguageModel
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import BasePromptTemplate
from langchain_core.runnables import Runnable, RunnableParallel, RunnablePassthrough
from langchain_core.runnables import Runnable, RunnablePassthrough

from langchain.chains.sql_database.prompt import PROMPT, SQL_PROMPTS

Expand All @@ -31,7 +31,7 @@ def create_sql_query_chain(
db: SQLDatabase,
prompt: Optional[BasePromptTemplate] = None,
k: int = 5,
) -> Runnable[Union[SQLInput, SQLInputWithTables], str]:
) -> Runnable[Union[SQLInput, SQLInputWithTables, Dict[str, Any]], str]:
"""Create a chain that generates SQL queries.
*Security Note*: This chain generates SQL queries for the given database.
Expand Down Expand Up @@ -128,8 +128,14 @@ def create_sql_query_chain(
),
}
return (
RunnablePassthrough.assign(**inputs)
| (lambda x: {k: v for k, v in x.items() if k not in ("question", "table_names_to_use")})
RunnablePassthrough.assign(**inputs) # type: ignore
| (
lambda x: {
k: v
for k, v in x.items()
if k not in ("question", "table_names_to_use")
}
)
| prompt_to_use.partial(top_k=str(k))
| llm.bind(stop=["\nSQLResult:"])
| StrOutputParser()
Expand Down

0 comments on commit 8b368a7

Please sign in to comment.