diff --git a/providers/src/airflow/providers/microsoft/winrm/hooks/winrm.py b/providers/src/airflow/providers/microsoft/winrm/hooks/winrm.py index 96abdf8e9bf06..961e37ba3fe40 100644 --- a/providers/src/airflow/providers/microsoft/winrm/hooks/winrm.py +++ b/providers/src/airflow/providers/microsoft/winrm/hooks/winrm.py @@ -19,6 +19,10 @@ from __future__ import annotations +from base64 import b64encode +from contextlib import suppress + +from winrm.exceptions import WinRMOperationTimeoutError from winrm.protocol import Protocol from airflow.exceptions import AirflowException @@ -218,3 +222,71 @@ def get_conn(self): raise AirflowException(error_msg) return self.client + + def run( + self, + command: str, + ps_path: str | None = None, + output_encoding: str = "utf-8", + return_output: bool = True, + ) -> tuple[int, list[bytes], list[bytes]]: + """ + Run a command. + + :param command: command to execute on remote host. + :param ps_path: path to powershell, `powershell` for v5.1- and `pwsh` for v6+. + If specified, it will execute the command as powershell script. + :param output_encoding: the encoding used to decode stout and stderr. + :param return_output: Whether to accumulate and return the stdout or not. + :return: returns a tuple containing return_code, stdout and stderr in order. + """ + winrm_client = self.get_conn() + + try: + if ps_path is not None: + self.log.info("Running command as powershell script: '%s'...", command) + encoded_ps = b64encode(command.encode("utf_16_le")).decode("ascii") + command_id = self.winrm_protocol.run_command( # type: ignore[attr-defined] + winrm_client, f"{ps_path} -encodedcommand {encoded_ps}" + ) + else: + self.log.info("Running command: '%s'...", command) + command_id = self.winrm_protocol.run_command( # type: ignore[attr-defined] + winrm_client, command + ) + + # See: https://github.com/diyan/pywinrm/blob/master/winrm/protocol.py + stdout_buffer = [] + stderr_buffer = [] + command_done = False + while not command_done: + # this is an expected error when waiting for a long-running process, just silently retry + with suppress(WinRMOperationTimeoutError): + ( + stdout, + stderr, + return_code, + command_done, + ) = self.winrm_protocol._raw_get_command_output( # type: ignore[attr-defined] + winrm_client, command_id + ) + + # Only buffer stdout if we need to so that we minimize memory usage. + if return_output: + stdout_buffer.append(stdout) + stderr_buffer.append(stderr) + + for line in stdout.decode(output_encoding).splitlines(): + self.log.info(line) + for line in stderr.decode(output_encoding).splitlines(): + self.log.warning(line) + + self.winrm_protocol.cleanup_command( # type: ignore[attr-defined] + winrm_client, command_id + ) + + return return_code, stdout_buffer, stderr_buffer + except Exception as e: + raise AirflowException(f"WinRM operator error: {e}") + finally: + self.winrm_protocol.close_shell(winrm_client) # type: ignore[attr-defined] diff --git a/providers/src/airflow/providers/microsoft/winrm/operators/winrm.py b/providers/src/airflow/providers/microsoft/winrm/operators/winrm.py index 3b61afb195b82..0662333c7886c 100644 --- a/providers/src/airflow/providers/microsoft/winrm/operators/winrm.py +++ b/providers/src/airflow/providers/microsoft/winrm/operators/winrm.py @@ -21,8 +21,6 @@ from base64 import b64encode from typing import TYPE_CHECKING, Sequence -from winrm.exceptions import WinRMOperationTimeoutError - from airflow.configuration import conf from airflow.exceptions import AirflowException from airflow.models import BaseOperator @@ -90,68 +88,21 @@ def execute(self, context: Context) -> list | str: if not self.command: raise AirflowException("No command specified so nothing to execute here.") - winrm_client = self.winrm_hook.get_conn() - - try: - if self.ps_path is not None: - self.log.info("Running command as powershell script: '%s'...", self.command) - encoded_ps = b64encode(self.command.encode("utf_16_le")).decode("ascii") - command_id = self.winrm_hook.winrm_protocol.run_command( # type: ignore[attr-defined] - winrm_client, f"{self.ps_path} -encodedcommand {encoded_ps}" - ) - else: - self.log.info("Running command: '%s'...", self.command) - command_id = self.winrm_hook.winrm_protocol.run_command( # type: ignore[attr-defined] - winrm_client, self.command - ) - - # See: https://github.com/diyan/pywinrm/blob/master/winrm/protocol.py - stdout_buffer = [] - stderr_buffer = [] - command_done = False - while not command_done: - try: - ( - stdout, - stderr, - return_code, - command_done, - ) = self.winrm_hook.winrm_protocol._raw_get_command_output( # type: ignore[attr-defined] - winrm_client, command_id - ) - - # Only buffer stdout if we need to so that we minimize memory usage. - if self.do_xcom_push: - stdout_buffer.append(stdout) - stderr_buffer.append(stderr) - - for line in stdout.decode(self.output_encoding).splitlines(): - self.log.info(line) - for line in stderr.decode(self.output_encoding).splitlines(): - self.log.warning(line) - except WinRMOperationTimeoutError: - # this is an expected error when waiting for a - # long-running process, just silently retry - pass - - self.winrm_hook.winrm_protocol.cleanup_command( # type: ignore[attr-defined] - winrm_client, command_id - ) - self.winrm_hook.winrm_protocol.close_shell(winrm_client) # type: ignore[attr-defined] - - except Exception as e: - raise AirflowException(f"WinRM operator error: {e}") + return_code, stdout_buffer, stderr_buffer = self.winrm_hook.run( + command=self.command, + ps_path=self.ps_path, + output_encoding=self.output_encoding, + return_output=self.do_xcom_push, + ) if return_code == 0: # returning output if do_xcom_push is set enable_pickling = conf.getboolean("core", "enable_xcom_pickling") + if enable_pickling: return stdout_buffer - else: - return b64encode(b"".join(stdout_buffer)).decode(self.output_encoding) - else: - stderr_output = b"".join(stderr_buffer).decode(self.output_encoding) - error_msg = ( - f"Error running cmd: {self.command}, return code: {return_code}, error: {stderr_output}" - ) - raise AirflowException(error_msg) + return b64encode(b"".join(stdout_buffer)).decode(self.output_encoding) + + stderr_output = b"".join(stderr_buffer).decode(self.output_encoding) + error_msg = f"Error running cmd: {self.command}, return code: {return_code}, error: {stderr_output}" + raise AirflowException(error_msg) diff --git a/providers/tests/microsoft/winrm/hooks/test_winrm.py b/providers/tests/microsoft/winrm/hooks/test_winrm.py index 83411ccf9ccf2..7c2223cadc3f4 100644 --- a/providers/tests/microsoft/winrm/hooks/test_winrm.py +++ b/providers/tests/microsoft/winrm/hooks/test_winrm.py @@ -17,7 +17,7 @@ # under the License. from __future__ import annotations -from unittest.mock import patch +from unittest.mock import MagicMock, patch import pytest @@ -119,3 +119,85 @@ def test_get_conn_no_endpoint(self, mock_protocol): winrm_hook.get_conn() assert f"http://{winrm_hook.remote_host}:{winrm_hook.remote_port}/wsman" == winrm_hook.endpoint + + @patch("airflow.providers.microsoft.winrm.hooks.winrm.Protocol", autospec=True) + @patch( + "airflow.providers.microsoft.winrm.hooks.winrm.WinRMHook.get_connection", + return_value=Connection( + login="username", + password="password", + host="remote_host", + extra="""{ + "endpoint": "endpoint", + "remote_port": 123, + "transport": "plaintext", + "service": "service", + "keytab": "keytab", + "ca_trust_path": "ca_trust_path", + "cert_pem": "cert_pem", + "cert_key_pem": "cert_key_pem", + "server_cert_validation": "validate", + "kerberos_delegation": "true", + "read_timeout_sec": 124, + "operation_timeout_sec": 123, + "kerberos_hostname_override": "kerberos_hostname_override", + "message_encryption": "auto", + "credssp_disable_tlsv1_2": "true", + "send_cbt": "false" + }""", + ), + ) + def test_run_with_stdout(self, mock_get_connection, mock_protocol): + winrm_hook = WinRMHook(ssh_conn_id="conn_id") + + mock_protocol.return_value.run_command = MagicMock(return_value="command_id") + mock_protocol.return_value._raw_get_command_output = MagicMock( + return_value=(b"stdout", b"stderr", 0, True) + ) + + return_code, stdout_buffer, stderr_buffer = winrm_hook.run("dir") + + assert return_code == 0 + assert stdout_buffer == [b"stdout"] + assert stderr_buffer == [b"stderr"] + + @patch("airflow.providers.microsoft.winrm.hooks.winrm.Protocol", autospec=True) + @patch( + "airflow.providers.microsoft.winrm.hooks.winrm.WinRMHook.get_connection", + return_value=Connection( + login="username", + password="password", + host="remote_host", + extra="""{ + "endpoint": "endpoint", + "remote_port": 123, + "transport": "plaintext", + "service": "service", + "keytab": "keytab", + "ca_trust_path": "ca_trust_path", + "cert_pem": "cert_pem", + "cert_key_pem": "cert_key_pem", + "server_cert_validation": "validate", + "kerberos_delegation": "true", + "read_timeout_sec": 124, + "operation_timeout_sec": 123, + "kerberos_hostname_override": "kerberos_hostname_override", + "message_encryption": "auto", + "credssp_disable_tlsv1_2": "true", + "send_cbt": "false" + }""", + ), + ) + def test_run_without_stdout(self, mock_get_connection, mock_protocol): + winrm_hook = WinRMHook(ssh_conn_id="conn_id") + + mock_protocol.return_value.run_command = MagicMock(return_value="command_id") + mock_protocol.return_value._raw_get_command_output = MagicMock( + return_value=(b"stdout", b"stderr", 0, True) + ) + + return_code, stdout_buffer, stderr_buffer = winrm_hook.run("dir", return_output=False) + + assert return_code == 0 + assert not stdout_buffer + assert stderr_buffer == [b"stderr"]