From 518d394119af0afe302a2b5b4f406af330e5078f Mon Sep 17 00:00:00 2001 From: Tim Zhou <5866950+ttzhou@users.noreply.github.com> Date: Wed, 27 Nov 2024 22:50:18 -0500 Subject: [PATCH] Allow `json_result_force_utf8_decoding` specification in `providers.snowflake.hooks.SnowflakeHook` extra dict (#44264) * Allow json_result_force_utf8_decoding specification in SnowflakeHook extra dict * Use a set for the not in --- .../providers/snowflake/hooks/snowflake.py | 17 +++++++++++++- .../tests/snowflake/hooks/test_snowflake.py | 22 +++++++++++++++++++ 2 files changed, 38 insertions(+), 1 deletion(-) diff --git a/providers/src/airflow/providers/snowflake/hooks/snowflake.py b/providers/src/airflow/providers/snowflake/hooks/snowflake.py index e957c0623cb03..fdf75939bdacf 100644 --- a/providers/src/airflow/providers/snowflake/hooks/snowflake.py +++ b/providers/src/airflow/providers/snowflake/hooks/snowflake.py @@ -201,6 +201,9 @@ def _get_conn_params(self) -> dict[str, str | None]: region = self._get_field(extra_dict, "region") or "" role = self._get_field(extra_dict, "role") or "" insecure_mode = _try_to_boolean(self._get_field(extra_dict, "insecure_mode")) + json_result_force_utf8_decoding = _try_to_boolean( + self._get_field(extra_dict, "json_result_force_utf8_decoding") + ) schema = conn.schema or "" client_request_mfa_token = _try_to_boolean(self._get_field(extra_dict, "client_request_mfa_token")) @@ -225,6 +228,9 @@ def _get_conn_params(self) -> dict[str, str | None]: if insecure_mode: conn_config["insecure_mode"] = insecure_mode + if json_result_force_utf8_decoding: + conn_config["json_result_force_utf8_decoding"] = json_result_force_utf8_decoding + if client_request_mfa_token: conn_config["client_request_mfa_token"] = client_request_mfa_token @@ -302,7 +308,13 @@ def _conn_params_to_sqlalchemy_uri(self, conn_params: dict) -> str: for k, v in conn_params.items() if v and k - not in ["session_parameters", "insecure_mode", "private_key", "client_request_mfa_token"] + not in 
{ + "session_parameters", + "insecure_mode", + "private_key", + "client_request_mfa_token", + "json_result_force_utf8_decoding", + } } ) @@ -324,6 +336,9 @@ def get_sqlalchemy_engine(self, engine_kwargs=None): if "insecure_mode" in conn_params: engine_kwargs.setdefault("connect_args", {}) engine_kwargs["connect_args"]["insecure_mode"] = True + if "json_result_force_utf8_decoding" in conn_params: + engine_kwargs.setdefault("connect_args", {}) + engine_kwargs["connect_args"]["json_result_force_utf8_decoding"] = True for key in ["session_parameters", "private_key"]: if conn_params.get(key): engine_kwargs.setdefault("connect_args", {}) diff --git a/providers/tests/snowflake/hooks/test_snowflake.py b/providers/tests/snowflake/hooks/test_snowflake.py index b7c9382654be0..d75f1a4baf14c 100644 --- a/providers/tests/snowflake/hooks/test_snowflake.py +++ b/providers/tests/snowflake/hooks/test_snowflake.py @@ -138,6 +138,7 @@ class TestPytestSnowflakeHook: "extra__snowflake__region": "af_region", "extra__snowflake__role": "af_role", "extra__snowflake__insecure_mode": "True", + "extra__snowflake__json_result_force_utf8_decoding": "True", "extra__snowflake__client_request_mfa_token": "True", }, }, @@ -158,6 +159,7 @@ class TestPytestSnowflakeHook: "user": "user", "warehouse": "af_wh", "insecure_mode": True, + "json_result_force_utf8_decoding": True, "client_request_mfa_token": True, }, ), @@ -171,6 +173,7 @@ class TestPytestSnowflakeHook: "extra__snowflake__region": "af_region", "extra__snowflake__role": "af_role", "extra__snowflake__insecure_mode": "False", + "extra__snowflake__json_result_force_utf8_decoding": "False", "extra__snowflake__client_request_mfa_token": "False", }, }, @@ -247,6 +250,7 @@ class TestPytestSnowflakeHook: "extra": { **BASE_CONNECTION_KWARGS["extra"], "extra__snowflake__insecure_mode": False, + "extra__snowflake__json_result_force_utf8_decoding": True, "extra__snowflake__client_request_mfa_token": False, }, }, @@ -266,6 +270,7 @@ class 
TestPytestSnowflakeHook: "session_parameters": None, "user": "user", "warehouse": "af_wh", + "json_result_force_utf8_decoding": True, }, ), ], @@ -473,6 +478,23 @@ def test_get_sqlalchemy_engine_should_support_insecure_mode(self): ) assert mock_create_engine.return_value == conn + def test_get_sqlalchemy_engine_should_support_json_result_force_utf8_decoding(self): + connection_kwargs = deepcopy(BASE_CONNECTION_KWARGS) + connection_kwargs["extra"]["extra__snowflake__json_result_force_utf8_decoding"] = "True" + + with ( + mock.patch.dict("os.environ", AIRFLOW_CONN_TEST_CONN=Connection(**connection_kwargs).get_uri()), + mock.patch("airflow.providers.snowflake.hooks.snowflake.create_engine") as mock_create_engine, + ): + hook = SnowflakeHook(snowflake_conn_id="test_conn") + conn = hook.get_sqlalchemy_engine() + mock_create_engine.assert_called_once_with( + "snowflake://user:pw@airflow.af_region/db/public" + "?application=AIRFLOW&authenticator=snowflake&role=af_role&warehouse=af_wh", + connect_args={"json_result_force_utf8_decoding": True}, + ) + assert mock_create_engine.return_value == conn + def test_get_sqlalchemy_engine_should_support_session_parameters(self): connection_kwargs = deepcopy(BASE_CONNECTION_KWARGS) connection_kwargs["extra"]["session_parameters"] = {"TEST_PARAM": "AA", "TEST_PARAM_B": 123}