Skip to content

Commit

Permalink
feat: allow SQLDatabase to run queries with parameters
Browse files Browse the repository at this point in the history
  • Loading branch information
danbianchini committed Jan 3, 2024
1 parent 73da8f8 commit 9cadb1d
Showing 1 changed file with 13 additions and 8 deletions.
21 changes: 13 additions & 8 deletions libs/community/langchain_community/utilities/sql_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from __future__ import annotations

import warnings
from typing import Any, Dict, Iterable, List, Literal, Optional, Sequence
from typing import Any, Dict, Iterable, List, Literal, Optional, Sequence, Union

import sqlalchemy
from langchain_core.utils import get_from_env
Expand Down Expand Up @@ -376,14 +376,15 @@ def _get_sample_rows(self, table: Table) -> str:
def _execute(
self,
command: str,
fetch: Literal["all", "one"] = "all",
fetch: Union[Literal["all"], Literal["one"]] = "all",
query_params: Optional[Dict[str, Any]] = None,
) -> Sequence[Dict[str, Any]]:
"""
Executes SQL command through underlying engine.
If the statement returns no rows, an empty list is returned.
"""
with self._engine.begin() as connection: # type: Connection
with self._engine.begin() as connection:
if self._schema is not None:
if self.dialect == "snowflake":
connection.exec_driver_sql(
Expand Down Expand Up @@ -411,7 +412,7 @@ def _execute(
pass
else: # postgresql and other compatible dialects
connection.exec_driver_sql("SET search_path TO %s", (self._schema,))
cursor = connection.execute(text(command))
cursor = connection.execute(text(command), parameters=query_params)
if cursor.returns_rows:
if fetch == "all":
result = [x._asdict() for x in cursor.fetchall()]
Expand All @@ -426,15 +427,17 @@ def _execute(
def run(
self,
command: str,
fetch: Literal["all", "one"] = "all",
fetch: Union[Literal["all"], Literal["one"]] = "all",
include_columns: bool = False,
query_params: Optional[Dict[str, Any]] = None,
) -> str:
"""Execute a SQL command and return a string representing the results.
Optionally, query_params may be provided to parameterize the query.
If the statement returns rows, a string of the results is returned.
If the statement returns no rows, an empty string is returned.
"""
result = self._execute(command, fetch)
result = self._execute(command, fetch, query_params)

res = [
{
Expand Down Expand Up @@ -471,18 +474,20 @@ def get_table_info_no_throw(self, table_names: Optional[List[str]] = None) -> st
def run_no_throw(
self,
command: str,
fetch: Literal["all", "one"] = "all",
fetch: Union[Literal["all"], Literal["one"]] = "all",
include_columns: bool = False,
query_params: Optional[Dict[str, Any]] = None,
) -> str:
"""Execute a SQL command and return a string representing the results.
Optionally, query_params may be provided to parameterize the query.
If the statement returns rows, a string of the results is returned.
If the statement returns no rows, an empty string is returned.
If the statement throws an error, the error message is returned.
"""
try:
return self.run(command, fetch, include_columns)
return self.run(command, fetch, include_columns, query_params)
except SQLAlchemyError as e:
"""Format the error message"""
return f"Error: {e}"

0 comments on commit 9cadb1d

Please sign in to comment.