Skip to content

Commit

Permalink
Added output_processor parameter to SQLQueryOperator and fixed bug with return_single_query_results handler when None is passed as split_statements (apache#44781)
Browse files Browse the repository at this point in the history

* refactor: Added output_processor parameter to SQLQueryOperator

* refactor: Reformatted SQLQueryOperator

* refactor: Fixed return_single_query_results when None is passed as split_statements

* refactor: Reformatted SQLExecuteOperator

* refactor: Added blank line

---------

Co-authored-by: David Blain <[email protected]>
  • Loading branch information
dabla and davidblain-infrabel authored Dec 9, 2024
1 parent 771d56b commit bb259ea
Show file tree
Hide file tree
Showing 6 changed files with 52 additions and 5 deletions.
4 changes: 3 additions & 1 deletion providers/src/airflow/providers/common/sql/hooks/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,9 @@ def return_single_query_results(sql: str | Iterable[str], return_last: bool, spl
:param split_statements: whether to split string statements.
:return: True if the hook should return single query results
"""
return isinstance(sql, str) and (return_last or not split_statements)
if split_statements is not None:
return isinstance(sql, str) and (return_last or not split_statements)
return isinstance(sql, str) and return_last


def fetch_all_handler(cursor) -> list[tuple] | None:
Expand Down
2 changes: 1 addition & 1 deletion providers/src/airflow/providers/common/sql/hooks/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@
be removed in the future. Please import it from 'airflow.providers.common.sql.hooks.handlers'."""


def return_single_query_results(sql: str | Iterable[str], return_last: bool, split_statements: bool):
def return_single_query_results(sql: str | Iterable[str], return_last: bool, split_statements: bool | None):
warnings.warn(WARNING_MESSAGE.format("return_single_query_results"), DeprecationWarning, stacklevel=2)

from airflow.providers.common.sql.hooks import handlers
Expand Down
24 changes: 21 additions & 3 deletions providers/src/airflow/providers/common/sql/operators/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,10 @@ def _get_failed_checks(checks, col=None):
}


def default_output_processor(results: list[Any], descriptions: list[Sequence[Sequence] | None]) -> list[Any]:
    """
    Return the query results unchanged.

    This is the pass-through processor that ``SQLExecuteQueryOperator`` falls back to
    when no ``output_processor`` is supplied.

    :param results: the results handed over by the operator.
    :param descriptions: the descriptions reported by the hook; accepted only so the
        signature matches custom processors, and otherwise ignored.
    :return: the very same ``results`` object that was passed in.
    """
    return results


class BaseSQLOperator(BaseOperator):
"""
This is a base class for generic SQL Operator to get a DB Hook.
Expand Down Expand Up @@ -210,6 +214,8 @@ class SQLExecuteQueryOperator(BaseSQLOperator):
:param autocommit: (optional) if True, each command is automatically committed (default: False).
:param parameters: (optional) the parameters to render the SQL query with.
:param handler: (optional) the function that will be applied to the cursor (default: fetch_all_handler).
:param output_processor: (optional) the function that will be applied to the result
(default: default_output_processor).
:param split_statements: (optional) if split single SQL string into statements. By default, defers
to the default value in the ``run`` method of the configured hook.
:param conn_id: the connection ID used to connect to the database
Expand All @@ -235,6 +241,13 @@ def __init__(
autocommit: bool = False,
parameters: Mapping | Iterable | None = None,
handler: Callable[[Any], list[tuple] | None] = fetch_all_handler,
output_processor: (
Callable[
[list[Any], list[Sequence[Sequence] | None]],
list[Any] | tuple[list[Sequence[Sequence] | None], list],
]
| None
) = None,
conn_id: str | None = None,
database: str | None = None,
split_statements: bool | None = None,
Expand All @@ -247,11 +260,14 @@ def __init__(
self.autocommit = autocommit
self.parameters = parameters
self.handler = handler
self._output_processor = output_processor or default_output_processor
self.split_statements = split_statements
self.return_last = return_last
self.show_return_value_in_logs = show_return_value_in_logs

def _process_output(self, results: list[Any], descriptions: list[Sequence[Sequence] | None]) -> list[Any]:
def _process_output(
self, results: list[Any], descriptions: list[Sequence[Sequence] | None]
) -> list[Any] | tuple[list[Sequence[Sequence] | None], list]:
"""
Process output before it is returned by the operator.
Expand All @@ -270,7 +286,7 @@ def _process_output(self, results: list[Any], descriptions: list[Sequence[Sequen
"""
if self.show_return_value_in_logs:
self.log.info("Operator output is: %s", results)
return results
return self._output_processor(results, descriptions)

def _should_run_output_processing(self) -> bool:
return self.do_xcom_push
Expand All @@ -297,7 +313,9 @@ def execute(self, context):
# single query results are going to be returned, and we return the first element
# of the list in this case from the (always) list returned by _process_output
return self._process_output([output], hook.descriptions)[-1]
return self._process_output(output, hook.descriptions)
result = self._process_output(output, hook.descriptions)
self.log.info("result: %s", result)
return result

def prepare_template(self) -> None:
"""Parse template file for attribute parameters."""
Expand Down
7 changes: 7 additions & 0 deletions providers/src/airflow/providers/common/sql/operators/sql.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,13 @@ class SQLExecuteQueryOperator(BaseSQLOperator):
autocommit: bool = False,
parameters: Mapping | Iterable | None = None,
handler: Callable[[Any], list[tuple] | None] = ...,
output_processor: (
Callable[
[list[Any], list[Sequence[Sequence] | None]],
list[Any] | tuple[list[Sequence[Sequence] | None], list],
]
| None
) = None,
conn_id: str | None = None,
database: str | None = None,
split_statements: bool | None = None,
Expand Down
1 change: 1 addition & 0 deletions providers/tests/common/sql/hooks/test_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ class TestHandlers:
def test_return_single_query_results(self):
assert return_single_query_results("SELECT 1", return_last=True, split_statements=False)
assert return_single_query_results("SELECT 1", return_last=False, split_statements=False)
assert return_single_query_results("SELECT 1", return_last=False, split_statements=None) is False
assert return_single_query_results(["SELECT 1"], return_last=True, split_statements=False) is False
assert return_single_query_results(["SELECT 1"], return_last=False, split_statements=False) is False
assert return_single_query_results("SELECT 1", return_last=False, split_statements=True) is False
Expand Down
19 changes: 19 additions & 0 deletions providers/tests/common/sql/operators/test_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,25 @@ def test_dont_xcom_push(self, mock_get_db_hook, mock_process_output):
)
mock_process_output.assert_not_called()

@mock.patch.object(SQLExecuteQueryOperator, "get_db_hook")
def test_output_processor(self, mock_get_db_hook):
    """Check that a custom ``output_processor`` determines what ``execute`` returns."""
    data = [(1, "Alice"), (2, "Bob")]

    # Stub the hook so that no database is needed: ``run`` yields the rows and
    # ``descriptions`` carries the value handed to the processor as its second argument.
    mock_hook = MagicMock()
    mock_hook.run.return_value = data
    mock_hook.descriptions = ("id", "name")
    mock_get_db_hook.return_value = mock_hook

    # The processor swaps the argument order, so the operator's return value
    # shows that both arguments reached it.
    operator = self._construct_operator(
        sql="SELECT * FROM users;",
        output_processor=lambda results, descriptions: (descriptions, results),
        return_last=False,
    )
    descriptions, result = operator.execute(context=MagicMock())

    assert descriptions == ("id", "name")
    assert result == [(1, "Alice"), (2, "Bob")]


class TestColumnCheckOperator:
valid_column_mapping = {
Expand Down

0 comments on commit bb259ea

Please sign in to comment.