Skip to content

Commit

Permalink
Move responsibility to run a command from WinRMOperator to WinRMHook (a…
Browse files Browse the repository at this point in the history
…pache#43646)

* refactor: Moved responsibility to run a command away from WinRMOperator to WinRMHook and also made WinRMHook closable

* refactor: Reformatted exception message in WinRMOperator

* refactor: command parameter of run method in WinRMHook must be specified

* refactor: Changed return type of run method in WinRMHook

* refactor: WinRMHook cannot be closable as it doesn't have the winrm_client instance

* refactor: Reorganized imports in WinRMHook

* refactor: Added unit tests for new run method in WinRMHook

* refactor: Reorganized imports in TestWinRMHook

---------

Co-authored-by: David Blain <[email protected]>
  • Loading branch information
dabla and davidblain-infrabel authored Nov 6, 2024
1 parent cd75707 commit 464f7c4
Show file tree
Hide file tree
Showing 3 changed files with 167 additions and 62 deletions.
72 changes: 72 additions & 0 deletions providers/src/airflow/providers/microsoft/winrm/hooks/winrm.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@

from __future__ import annotations

from base64 import b64encode
from contextlib import suppress

from winrm.exceptions import WinRMOperationTimeoutError
from winrm.protocol import Protocol

from airflow.exceptions import AirflowException
Expand Down Expand Up @@ -218,3 +222,71 @@ def get_conn(self):
raise AirflowException(error_msg)

return self.client

def run(
    self,
    command: str,
    ps_path: str | None = None,
    output_encoding: str = "utf-8",
    return_output: bool = True,
) -> tuple[int, list[bytes], list[bytes]]:
    """
    Run a command on the remote host and stream its output to the task log.

    :param command: command to execute on remote host.
    :param ps_path: path to powershell, `powershell` for v5.1- and `pwsh` for v6+.
        If specified, it will execute the command as powershell script.
    :param output_encoding: the encoding used to decode stdout and stderr.
    :param return_output: Whether to accumulate and return the stdout or not.
    :return: returns a tuple containing return_code, stdout and stderr in order.
    :raises AirflowException: if starting the command, polling its output or
        cleaning it up fails; the original exception is attached as the cause.
    """
    winrm_client = self.get_conn()

    try:
        if ps_path is not None:
            self.log.info("Running command as powershell script: '%s'...", command)
            # The script is passed base64-encoded (UTF-16LE) via -encodedcommand.
            encoded_ps = b64encode(command.encode("utf_16_le")).decode("ascii")
            command_id = self.winrm_protocol.run_command(  # type: ignore[attr-defined]
                winrm_client, f"{ps_path} -encodedcommand {encoded_ps}"
            )
        else:
            self.log.info("Running command: '%s'...", command)
            command_id = self.winrm_protocol.run_command(  # type: ignore[attr-defined]
                winrm_client, command
            )

        # See: https://github.com/diyan/pywinrm/blob/master/winrm/protocol.py
        stdout_buffer: list[bytes] = []
        stderr_buffer: list[bytes] = []
        command_done = False
        while not command_done:
            # this is an expected error when waiting for a long-running process, just silently retry
            with suppress(WinRMOperationTimeoutError):
                (
                    stdout,
                    stderr,
                    return_code,
                    command_done,
                ) = self.winrm_protocol._raw_get_command_output(  # type: ignore[attr-defined]
                    winrm_client, command_id
                )

                # Only buffer stdout if we need to so that we minimize memory usage.
                # stderr is always buffered, regardless of return_output.
                if return_output:
                    stdout_buffer.append(stdout)
                stderr_buffer.append(stderr)

                # NOTE(review): each chunk is decoded on its own; presumably a multi-byte
                # character split across two chunks would fail to decode here - confirm.
                for line in stdout.decode(output_encoding).splitlines():
                    self.log.info(line)
                for line in stderr.decode(output_encoding).splitlines():
                    self.log.warning(line)

        self.winrm_protocol.cleanup_command(  # type: ignore[attr-defined]
            winrm_client, command_id
        )

        return return_code, stdout_buffer, stderr_buffer
    except Exception as e:
        # Chain explicitly so the traceback shows the underlying failure as the cause.
        raise AirflowException(f"WinRM operator error: {e}") from e
    finally:
        # The shell is closed on success and on failure alike.
        self.winrm_protocol.close_shell(winrm_client)  # type: ignore[attr-defined]
73 changes: 12 additions & 61 deletions providers/src/airflow/providers/microsoft/winrm/operators/winrm.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,6 @@
from base64 import b64encode
from typing import TYPE_CHECKING, Sequence

from winrm.exceptions import WinRMOperationTimeoutError

from airflow.configuration import conf
from airflow.exceptions import AirflowException
from airflow.models import BaseOperator
Expand Down Expand Up @@ -90,68 +88,21 @@ def execute(self, context: Context) -> list | str:
if not self.command:
raise AirflowException("No command specified so nothing to execute here.")

winrm_client = self.winrm_hook.get_conn()

try:
if self.ps_path is not None:
self.log.info("Running command as powershell script: '%s'...", self.command)
encoded_ps = b64encode(self.command.encode("utf_16_le")).decode("ascii")
command_id = self.winrm_hook.winrm_protocol.run_command( # type: ignore[attr-defined]
winrm_client, f"{self.ps_path} -encodedcommand {encoded_ps}"
)
else:
self.log.info("Running command: '%s'...", self.command)
command_id = self.winrm_hook.winrm_protocol.run_command( # type: ignore[attr-defined]
winrm_client, self.command
)

# See: https://github.com/diyan/pywinrm/blob/master/winrm/protocol.py
stdout_buffer = []
stderr_buffer = []
command_done = False
while not command_done:
try:
(
stdout,
stderr,
return_code,
command_done,
) = self.winrm_hook.winrm_protocol._raw_get_command_output( # type: ignore[attr-defined]
winrm_client, command_id
)

# Only buffer stdout if we need to so that we minimize memory usage.
if self.do_xcom_push:
stdout_buffer.append(stdout)
stderr_buffer.append(stderr)

for line in stdout.decode(self.output_encoding).splitlines():
self.log.info(line)
for line in stderr.decode(self.output_encoding).splitlines():
self.log.warning(line)
except WinRMOperationTimeoutError:
# this is an expected error when waiting for a
# long-running process, just silently retry
pass

self.winrm_hook.winrm_protocol.cleanup_command( # type: ignore[attr-defined]
winrm_client, command_id
)
self.winrm_hook.winrm_protocol.close_shell(winrm_client) # type: ignore[attr-defined]

except Exception as e:
raise AirflowException(f"WinRM operator error: {e}")
return_code, stdout_buffer, stderr_buffer = self.winrm_hook.run(
command=self.command,
ps_path=self.ps_path,
output_encoding=self.output_encoding,
return_output=self.do_xcom_push,
)

if return_code == 0:
# returning output if do_xcom_push is set
enable_pickling = conf.getboolean("core", "enable_xcom_pickling")

if enable_pickling:
return stdout_buffer
else:
return b64encode(b"".join(stdout_buffer)).decode(self.output_encoding)
else:
stderr_output = b"".join(stderr_buffer).decode(self.output_encoding)
error_msg = (
f"Error running cmd: {self.command}, return code: {return_code}, error: {stderr_output}"
)
raise AirflowException(error_msg)
return b64encode(b"".join(stdout_buffer)).decode(self.output_encoding)

stderr_output = b"".join(stderr_buffer).decode(self.output_encoding)
error_msg = f"Error running cmd: {self.command}, return code: {return_code}, error: {stderr_output}"
raise AirflowException(error_msg)
84 changes: 83 additions & 1 deletion providers/tests/microsoft/winrm/hooks/test_winrm.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
# under the License.
from __future__ import annotations

from unittest.mock import patch
from unittest.mock import MagicMock, patch

import pytest

Expand Down Expand Up @@ -119,3 +119,85 @@ def test_get_conn_no_endpoint(self, mock_protocol):
winrm_hook.get_conn()

assert f"http://{winrm_hook.remote_host}:{winrm_hook.remote_port}/wsman" == winrm_hook.endpoint

@patch("airflow.providers.microsoft.winrm.hooks.winrm.Protocol", autospec=True)
@patch(
    "airflow.providers.microsoft.winrm.hooks.winrm.WinRMHook.get_connection",
    return_value=Connection(
        login="username",
        password="password",
        host="remote_host",
        extra="""{
"endpoint": "endpoint",
"remote_port": 123,
"transport": "plaintext",
"service": "service",
"keytab": "keytab",
"ca_trust_path": "ca_trust_path",
"cert_pem": "cert_pem",
"cert_key_pem": "cert_key_pem",
"server_cert_validation": "validate",
"kerberos_delegation": "true",
"read_timeout_sec": 124,
"operation_timeout_sec": 123,
"kerberos_hostname_override": "kerberos_hostname_override",
"message_encryption": "auto",
"credssp_disable_tlsv1_2": "true",
"send_cbt": "false"
}""",
    ),
)
def test_run_with_stdout(self, mock_get_connection, mock_protocol):
    """``run`` with default arguments returns the exit code, stdout and stderr chunks."""
    hook = WinRMHook(ssh_conn_id="conn_id")

    # Stub the protocol instance so a single poll yields output and reports completion.
    protocol = mock_protocol.return_value
    protocol.run_command = MagicMock(return_value="command_id")
    protocol._raw_get_command_output = MagicMock(return_value=(b"stdout", b"stderr", 0, True))

    exit_code, out_chunks, err_chunks = hook.run("dir")

    assert exit_code == 0
    assert out_chunks == [b"stdout"]
    assert err_chunks == [b"stderr"]

@patch("airflow.providers.microsoft.winrm.hooks.winrm.Protocol", autospec=True)
@patch(
    "airflow.providers.microsoft.winrm.hooks.winrm.WinRMHook.get_connection",
    return_value=Connection(
        login="username",
        password="password",
        host="remote_host",
        extra="""{
"endpoint": "endpoint",
"remote_port": 123,
"transport": "plaintext",
"service": "service",
"keytab": "keytab",
"ca_trust_path": "ca_trust_path",
"cert_pem": "cert_pem",
"cert_key_pem": "cert_key_pem",
"server_cert_validation": "validate",
"kerberos_delegation": "true",
"read_timeout_sec": 124,
"operation_timeout_sec": 123,
"kerberos_hostname_override": "kerberos_hostname_override",
"message_encryption": "auto",
"credssp_disable_tlsv1_2": "true",
"send_cbt": "false"
}""",
    ),
)
def test_run_without_stdout(self, mock_get_connection, mock_protocol):
    """``run`` with ``return_output=False`` drops stdout but still returns stderr."""
    hook = WinRMHook(ssh_conn_id="conn_id")

    # Stub the protocol instance so a single poll yields output and reports completion.
    protocol = mock_protocol.return_value
    protocol.run_command = MagicMock(return_value="command_id")
    protocol._raw_get_command_output = MagicMock(return_value=(b"stdout", b"stderr", 0, True))

    exit_code, out_chunks, err_chunks = hook.run("dir", return_output=False)

    assert exit_code == 0
    assert not out_chunks
    assert err_chunks == [b"stderr"]

0 comments on commit 464f7c4

Please sign in to comment.