From bb259eaa670240ead9bb9964e9f0b0e19f0f5cde Mon Sep 17 00:00:00 2001 From: David Blain Date: Mon, 9 Dec 2024 22:49:47 +0100 Subject: [PATCH] Added output_processor parameter to SQLExecuteQueryOperator and fixed bug with return_single_query_results handler when None is passed as split_statements (#44781) * refactor: Added output_processor parameter to SQLExecuteQueryOperator * refactor: Reformatted SQLExecuteQueryOperator * refactor: Fixed return_single_query_results when None is passed as split_statements * refactor: Reformatted SQLExecuteQueryOperator * refactor: Added white line --------- Co-authored-by: David Blain --- .../providers/common/sql/hooks/handlers.py | 4 +++- .../airflow/providers/common/sql/hooks/sql.py | 2 +- .../providers/common/sql/operators/sql.py | 24 ++++++++++++++++--- .../providers/common/sql/operators/sql.pyi | 7 ++++++ .../tests/common/sql/hooks/test_handlers.py | 1 + .../tests/common/sql/operators/test_sql.py | 19 +++++++++++++++ 6 files changed, 52 insertions(+), 5 deletions(-) diff --git a/providers/src/airflow/providers/common/sql/hooks/handlers.py b/providers/src/airflow/providers/common/sql/hooks/handlers.py index 3636cc214d213..b399dc0023f6f 100644 --- a/providers/src/airflow/providers/common/sql/hooks/handlers.py +++ b/providers/src/airflow/providers/common/sql/hooks/handlers.py @@ -44,7 +44,9 @@ def return_single_query_results(sql: str | Iterable[str], return_last: bool, spl :param split_statements: whether to split string statements. 
:return: True if the hook should return single query results """ - return isinstance(sql, str) and (return_last or not split_statements) + if split_statements is not None: + return isinstance(sql, str) and (return_last or not split_statements) + return isinstance(sql, str) and return_last def fetch_all_handler(cursor) -> list[tuple] | None: diff --git a/providers/src/airflow/providers/common/sql/hooks/sql.py b/providers/src/airflow/providers/common/sql/hooks/sql.py index bd8780a750dbd..f4d107f0c5f3e 100644 --- a/providers/src/airflow/providers/common/sql/hooks/sql.py +++ b/providers/src/airflow/providers/common/sql/hooks/sql.py @@ -59,7 +59,7 @@ be removed in the future. Please import it from 'airflow.providers.common.sql.hooks.handlers'.""" -def return_single_query_results(sql: str | Iterable[str], return_last: bool, split_statements: bool): +def return_single_query_results(sql: str | Iterable[str], return_last: bool, split_statements: bool | None): warnings.warn(WARNING_MESSAGE.format("return_single_query_results"), DeprecationWarning, stacklevel=2) from airflow.providers.common.sql.hooks import handlers diff --git a/providers/src/airflow/providers/common/sql/operators/sql.py b/providers/src/airflow/providers/common/sql/operators/sql.py index 56b14fe66b22b..7ea887461e7c1 100644 --- a/providers/src/airflow/providers/common/sql/operators/sql.py +++ b/providers/src/airflow/providers/common/sql/operators/sql.py @@ -116,6 +116,10 @@ def _get_failed_checks(checks, col=None): } +def default_output_processor(results: list[Any], descriptions: list[Sequence[Sequence] | None]) -> list[Any]: + return results + + class BaseSQLOperator(BaseOperator): """ This is a base class for generic SQL Operator to get a DB Hook. @@ -210,6 +214,8 @@ class SQLExecuteQueryOperator(BaseSQLOperator): :param autocommit: (optional) if True, each command is automatically committed (default: False). :param parameters: (optional) the parameters to render the SQL query with. 
:param handler: (optional) the function that will be applied to the cursor (default: fetch_all_handler). + :param output_processor: (optional) the function that will be applied to the result + (default: default_output_processor). :param split_statements: (optional) if split single SQL string into statements. By default, defers to the default value in the ``run`` method of the configured hook. :param conn_id: the connection ID used to connect to the database @@ -235,6 +241,13 @@ def __init__( autocommit: bool = False, parameters: Mapping | Iterable | None = None, handler: Callable[[Any], list[tuple] | None] = fetch_all_handler, + output_processor: ( + Callable[ + [list[Any], list[Sequence[Sequence] | None]], + list[Any] | tuple[list[Sequence[Sequence] | None], list], + ] + | None + ) = None, conn_id: str | None = None, database: str | None = None, split_statements: bool | None = None, @@ -247,11 +260,14 @@ def __init__( self.autocommit = autocommit self.parameters = parameters self.handler = handler + self._output_processor = output_processor or default_output_processor self.split_statements = split_statements self.return_last = return_last self.show_return_value_in_logs = show_return_value_in_logs - def _process_output(self, results: list[Any], descriptions: list[Sequence[Sequence] | None]) -> list[Any]: + def _process_output( + self, results: list[Any], descriptions: list[Sequence[Sequence] | None] + ) -> list[Any] | tuple[list[Sequence[Sequence] | None], list]: """ Process output before it is returned by the operator. 
@@ -270,7 +286,7 @@ def _process_output(self, results: list[Any], descriptions: list[Sequence[Sequen """ if self.show_return_value_in_logs: self.log.info("Operator output is: %s", results) - return results + return self._output_processor(results, descriptions) def _should_run_output_processing(self) -> bool: return self.do_xcom_push @@ -297,7 +313,9 @@ def execute(self, context): # single query results are going to be returned, and we return the first element # of the list in this case from the (always) list returned by _process_output return self._process_output([output], hook.descriptions)[-1] - return self._process_output(output, hook.descriptions) + result = self._process_output(output, hook.descriptions) + self.log.info("result: %s", result) + return result def prepare_template(self) -> None: """Parse template file for attribute parameters.""" diff --git a/providers/src/airflow/providers/common/sql/operators/sql.pyi b/providers/src/airflow/providers/common/sql/operators/sql.pyi index 6921e3411ea01..6f89fc8b6ebb2 100644 --- a/providers/src/airflow/providers/common/sql/operators/sql.pyi +++ b/providers/src/airflow/providers/common/sql/operators/sql.pyi @@ -78,6 +78,13 @@ class SQLExecuteQueryOperator(BaseSQLOperator): autocommit: bool = False, parameters: Mapping | Iterable | None = None, handler: Callable[[Any], list[tuple] | None] = ..., + output_processor: ( + Callable[ + [list[Any], list[Sequence[Sequence] | None]], + list[Any] | tuple[list[Sequence[Sequence] | None], list], + ] + | None + ) = None, conn_id: str | None = None, database: str | None = None, split_statements: bool | None = None, diff --git a/providers/tests/common/sql/hooks/test_handlers.py b/providers/tests/common/sql/hooks/test_handlers.py index 9adf8df67c82c..8fd3ed8b65f18 100644 --- a/providers/tests/common/sql/hooks/test_handlers.py +++ b/providers/tests/common/sql/hooks/test_handlers.py @@ -30,6 +30,7 @@ class TestHandlers: def test_return_single_query_results(self): assert 
return_single_query_results("SELECT 1", return_last=True, split_statements=False) assert return_single_query_results("SELECT 1", return_last=False, split_statements=False) + assert return_single_query_results("SELECT 1", return_last=False, split_statements=None) is False assert return_single_query_results(["SELECT 1"], return_last=True, split_statements=False) is False assert return_single_query_results(["SELECT 1"], return_last=False, split_statements=False) is False assert return_single_query_results("SELECT 1", return_last=False, split_statements=True) is False diff --git a/providers/tests/common/sql/operators/test_sql.py b/providers/tests/common/sql/operators/test_sql.py index 85ab77a0aecca..133d51ac75753 100644 --- a/providers/tests/common/sql/operators/test_sql.py +++ b/providers/tests/common/sql/operators/test_sql.py @@ -148,6 +148,25 @@ def test_dont_xcom_push(self, mock_get_db_hook, mock_process_output): ) mock_process_output.assert_not_called() + @mock.patch.object(SQLExecuteQueryOperator, "get_db_hook") + def test_output_processor(self, mock_get_db_hook): + data = [(1, "Alice"), (2, "Bob")] + + mock_hook = MagicMock() + mock_hook.run.return_value = data + mock_hook.descriptions = ("id", "name") + mock_get_db_hook.return_value = mock_hook + + operator = self._construct_operator( + sql="SELECT * FROM users;", + output_processor=lambda results, descriptions: (descriptions, results), + return_last=False, + ) + descriptions, result = operator.execute(context=MagicMock()) + + assert descriptions == ("id", "name") + assert result == [(1, "Alice"), (2, "Bob")] + class TestColumnCheckOperator: valid_column_mapping = {