Skip to content

Commit

Permalink
Add way to enable host key validation for SSH/SFTP
Browse files Browse the repository at this point in the history
  • Loading branch information
adammcdonagh authored Oct 19, 2024
1 parent 765ed39 commit 90c0099
Show file tree
Hide file tree
Showing 18 changed files with 338 additions and 9 deletions.
10 changes: 10 additions & 0 deletions .vscode/launch.json
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,16 @@
"args": ["-t", "scp-basic", "-c", "test/cfg", "-v3"],
"justMyCode": false
},
{
"name": "Python: Transfer - SFTP Basic",
"type": "debugpy",
"request": "launch",
"preLaunchTask": "Build Test containers",
"program": "src/opentaskpy/cli/task_run.py",
"console": "integratedTerminal",
"args": ["-t", "sftp-basic", "-c", "test/cfg", "-v3"],
"justMyCode": false
},
{
"name": "Python: Transfer - Basic - As job",
"type": "debugpy",
Expand Down
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
- When creating a destination directory with SFTP, it will now check whether lower level directories exist, and create them if not
- Always check for a directory before trying to delete it and throwing an exception if it doesn't exist.
- Moved exception printing for transfers to earlier in the code to ensure log messages aren't confusing.
- Add ability to check for SSH host key and validate it before proceeding with connection

# No release

Expand Down
6 changes: 6 additions & 0 deletions src/opentaskpy/config/schemas/execution/ssh/protocol.json
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,12 @@
"port": {
"type": "integer"
},
"hostKeyValidation": {
"type": "boolean"
},
"knownHostsFile": {
"type": "string"
},
"credentials": {
"type": "object",
"properties": {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,12 @@
"port": {
"type": "integer"
},
"hostKeyValidation": {
"type": "boolean"
},
"knownHostsFile": {
"type": "string"
},
"supportsPosixRename": {
"type": "boolean",
"default": true
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,12 @@
"port": {
"type": "integer"
},
"hostKeyValidation": {
"type": "boolean"
},
"knownHostsFile": {
"type": "string"
},
"supportsPosixRename": {
"type": "boolean",
"default": true
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,12 @@
"port": {
"type": "integer"
},
"hostKeyValidation": {
"type": "boolean"
},
"knownHostsFile": {
"type": "string"
},
"credentials": {
"type": "object",
"properties": {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,12 @@
"port": {
"type": "integer"
},
"hostKeyValidation": {
"type": "boolean"
},
"knownHostsFile": {
"type": "string"
},
"credentials": {
"type": "object",
"properties": {
Expand Down
20 changes: 17 additions & 3 deletions src/opentaskpy/remotehandlers/sftp.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,20 @@
from io import StringIO
from shlex import quote

from paramiko import AutoAddPolicy, Channel, RSAKey, SFTPClient, SSHClient
from tenacity import retry, stop_after_attempt, wait_exponential
from paramiko import Channel, RSAKey, SFTPClient, SSHClient
from tenacity import (
retry,
retry_if_exception,
retry_if_not_exception_message,
stop_after_attempt,
wait_exponential,
)

import opentaskpy.otflogging
from opentaskpy.remotehandlers.remotehandler import RemoteTransferHandler

from .ssh_utils import setup_host_key_validation


class SFTPTransfer(RemoteTransferHandler):
"""SFTP Transfer Handler."""
Expand Down Expand Up @@ -103,6 +111,12 @@ def connect(self, hostname: str) -> None:
reraise=True,
stop=stop_after_attempt(6),
wait=wait_exponential(multiplier=2, min=5, max=60),
retry=(
retry_if_not_exception_message(
match=r".*(not found in known_hosts|Name or service not known).*"
)
& retry_if_exception(Exception)
),
)
def connect_with_retry(self, client_kwargs: dict) -> SSHClient:
"""Connect to the remote host with retry.
Expand All @@ -118,7 +132,7 @@ def connect_with_retry(self, client_kwargs: dict) -> SSHClient:
ssh_client.set_log_channel(
f"{__name__}.{ self.spec['task_id']}.paramiko.transport"
)
ssh_client.set_missing_host_key_policy(AutoAddPolicy())
setup_host_key_validation(ssh_client, self.spec, self.logger)
self.logger.info(f"Connecting to {client_kwargs['hostname']}")

# Set additional timeout options to match the standard timeout
Expand Down
23 changes: 19 additions & 4 deletions src/opentaskpy/remotehandlers/ssh.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,15 @@
from io import StringIO
from shlex import quote

from paramiko import AutoAddPolicy, RSAKey, SFTPClient, SSHClient, Transport
from paramiko import RSAKey, SFTPClient, SSHClient, Transport
from paramiko.channel import ChannelFile, ChannelStderrFile
from tenacity import retry, stop_after_attempt, wait_exponential
from tenacity import (
retry,
retry_if_exception,
retry_if_not_exception_message,
stop_after_attempt,
wait_exponential,
)

import opentaskpy.otflogging
from opentaskpy.exceptions import SSHClientError
Expand All @@ -24,6 +30,8 @@
RemoteTransferHandler,
)

from .ssh_utils import setup_host_key_validation

SSH_OPTIONS: str = "-o StrictHostKeyChecking=no -o BatchMode=yes -o ConnectTimeout=5"
REMOTE_SCRIPT_BASE_DIR: str = "/tmp" # nosec B108

Expand Down Expand Up @@ -52,7 +60,7 @@ def __init__(self, spec: dict):

client = SSHClient()
client.set_log_channel(f"{__name__}.{ spec['task_id']}.paramiko.transport")
client.set_missing_host_key_policy(AutoAddPolicy())
setup_host_key_validation(client, spec, self.logger)
self.ssh_client = client

# Handle default values
Expand Down Expand Up @@ -136,6 +144,12 @@ def connect(self, hostname: str, ssh_client: SSHClient | None = None) -> None:
reraise=True,
stop=stop_after_attempt(6),
wait=wait_exponential(multiplier=2, min=5, max=60),
retry=(
retry_if_not_exception_message(
match=r".*(not found in known_hosts|Name or service not known).*"
)
& retry_if_exception(Exception)
),
)
def connect_with_retry(self, ssh_client: SSHClient, kwargs: dict) -> None:
"""Connect to the remote host with retry.
Expand Down Expand Up @@ -916,7 +930,8 @@ def __init__(self, remote_host: str, spec: dict):

client = SSHClient()
client.set_log_channel(f"{__name__}.{ spec['task_id']}.paramiko.transport")
client.set_missing_host_key_policy(AutoAddPolicy())

setup_host_key_validation(client, spec, self.logger)

self.ssh_client = client

Expand Down
28 changes: 28 additions & 0 deletions src/opentaskpy/remotehandlers/ssh_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
"""Utility functions for SSH."""

from logging import Logger

from paramiko import AutoAddPolicy, SSHClient


def setup_host_key_validation(client: SSHClient, spec: dict, logger: Logger) -> None:
    """Configure host key handling on an SSH client from the task spec.

    The system host keys are always loaded first. After that, the
    ``protocol`` section of the spec decides what happens:

    * ``hostKeyValidation`` present and truthy: no missing-host-key policy
      is installed by this function. If ``knownHostsFile`` is also present,
      the host keys in that file are loaded in addition to the system ones.
    * otherwise: an ``AutoAddPolicy`` is installed as the client's
      missing-host-key policy.

    Args:
        client (SSHClient): The SSH client to set up.
        spec (dict): The spec for the SSH connection. Must contain a
            ``protocol`` mapping.
        logger (logging.Logger): The logger to use.

    Raises:
        KeyError: If ``spec`` has no ``protocol`` entry (raised after the
            system host keys have been loaded).
    """
    logger.info("Loading system host keys")
    client.load_system_host_keys()

    protocol = spec["protocol"]

    # Validation not requested: fall back to the auto-add policy and stop.
    if not protocol.get("hostKeyValidation"):
        client.set_missing_host_key_policy(AutoAddPolicy())
        return

    # Validation requested but no extra file configured: nothing more to load.
    if "knownHostsFile" not in protocol:
        return

    host_key = protocol["knownHostsFile"]
    logger.info(f"Loading host keys from {host_key}")
    client.load_host_keys(host_key)
27 changes: 27 additions & 0 deletions test/cfg/transfers/sftp-basic.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
{
"type": "transfer",
"source": {
"hostname": "127.0.0.1",
"directory": "/home/application/testFiles/src",
"fileRegex": ".*\\.txt",
"protocol": {
"name": "sftp",
"port": 1234,
"credentials": {
"username": "{{ SSH_USERNAME }}"
}
}
},
"destination": [
{
"hostname": "{{ HOST_D }}",
"directory": "/home/application/testFiles/dest",
"protocol": {
"name": "sftp",
"credentials": {
"username": "{{ SSH_USERNAME }}"
}
}
}
]
}
2 changes: 2 additions & 0 deletions test/cfg/variables.json.j2
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
{
"HOST_A": "172.16.0.11",
"HOST_B": "172.16.0.12",
"HOST_C": "172.16.0.21",
"HOST_D": "172.16.0.22",
"SSH_USERNAME": "application",
"TEMP_SOURCE_FOLDER": "/tmp",
"MY_FOLDER": "{{ FOLDER1 }}",
Expand Down
6 changes: 6 additions & 0 deletions test/createTestDirectories.sh
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,9 @@ rm -fr $DIR/testFiles/ssh_2/src $DIR/testFiles/ssh_2/dest $DIR/testFiles/ssh_2/a

mkdir -p $DIR/testFiles/ssh_1/dest $DIR/testFiles/ssh_1/src $DIR/testFiles/ssh_1/archive $DIR/testFiles/ssh_1/ssh
mkdir -p $DIR/testFiles/ssh_2/dest $DIR/testFiles/ssh_2/src $DIR/testFiles/ssh_2/archive $DIR/testFiles/ssh_2/ssh

rm -fr $DIR/testFiles/sftp_1/src $DIR/testFiles/sftp_1/dest $DIR/testFiles/sftp_1/archive 2>/dev/null
rm -fr $DIR/testFiles/sftp_2/src $DIR/testFiles/sftp_2/dest $DIR/testFiles/sftp_2/archive 2>/dev/null

mkdir -p $DIR/testFiles/sftp_1/dest $DIR/testFiles/sftp_1/src $DIR/testFiles/sftp_1/archive $DIR/testFiles/sftp_1/ssh
mkdir -p $DIR/testFiles/sftp_2/dest $DIR/testFiles/sftp_2/src $DIR/testFiles/sftp_2/archive $DIR/testFiles/sftp_2/ssh
2 changes: 1 addition & 1 deletion tests/test_remotehandler.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

def test_cacheable_variable_dotted_notation():

spec = {"task_id": "1234", "x": {"y": "value"}}
spec = {"task_id": "1234", "x": {"y": "value"}, "protocol": {"name": "ssh"}}
rh = SSHTransfer(spec)

assert rh.obtain_variable_from_spec("x.y", spec) == "value"
Expand Down
8 changes: 8 additions & 0 deletions tests/test_ssh_execution_schema_validate.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,14 @@ def test_ssh_basic(valid_execution):

assert validate_execution_json(json_data)

# Add hostKeyValidation
json_data["protocol"]["hostKeyValidation"] = True
assert validate_execution_json(json_data)

# Add knownHostsFile
json_data["protocol"]["knownHostsFile"] = "/some/file"
assert validate_execution_json(json_data)

# Remove protocol
del json_data["protocol"]
assert not validate_execution_json(json_data)
Expand Down
46 changes: 46 additions & 0 deletions tests/test_taskhandler_execution_ssh.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# pylint: skip-file
# ruff: noqa
import os
from copy import deepcopy

import pytest
from pytest_shell import fs
Expand Down Expand Up @@ -79,6 +80,51 @@ def test_basic_execution(setup_ssh_keys, root_dir):
assert os.path.exists(f"{root_dir}/testFiles/ssh_2/dest/execution.txt")


def test_basic_execution_host_key_validation(setup_ssh_keys, root_dir):
    """Check SSH execution behaviour when ``hostKeyValidation`` is enabled.

    The test runs three phases in order:

    1. Remove ``~/.ssh/known_hosts`` (if present) and assert that the
       execution reports failure.
    2. Connect to the hosts with the ``ssh`` command line client so their
       keys are recorded, then assert that the same execution now succeeds.
    3. Rename the known hosts file to ``~/known_hosts.new``, point
       ``knownHostsFile`` at it, and assert that a fresh execution succeeds.

    NOTE(review): this test deletes and then renames the invoking user's
    ``~/.ssh/known_hosts`` and does not restore it afterwards.
    """
    # Run the above test again, but this time with host key validation.
    # deepcopy so the shared touch_task_definition is not mutated.
    ssh_validation_task_definition = deepcopy(touch_task_definition)
    ssh_validation_task_definition["protocol"]["hostKeyValidation"] = True

    # Delete the known hosts file if it exists
    user_home = os.path.expanduser("~")
    known_hosts_file = f"{user_home}/.ssh/known_hosts"
    if os.path.exists(known_hosts_file):
        os.remove(known_hosts_file)

    execution_obj = execution.Execution(
        None, "ssh-host-key-validation", ssh_validation_task_definition
    )

    # Run the execution and expect a false status
    assert not execution_obj.run()

    # Log a load of blank messages (presumably to visually separate the
    # failed run from the next one in the log output)
    for _ in range(10):
        execution_obj.logger.info("")

    # SSH onto the host manually and accept the host key so it's saved to the system known hosts
    # NOTE(review): the "[email protected]" targets look like redacted user@host
    # values — confirm the real targets against the repository.
    # NOTE(review): `subprocess` is not among the imports visible in this
    # hunk — confirm it is imported at the top of the file.
    cmd = "ssh -o StrictHostKeyChecking=no [email protected] echo 'test' && ssh -o StrictHostKeyChecking=no [email protected] echo 'test' "
    result = subprocess.run(cmd, shell=True, capture_output=True)
    assert result.returncode == 0

    # Now rerun the execution, but this time it should work
    assert execution_obj.run()

    # Move the known host file elsewhere and pass the new location to the protocol definition
    known_hosts_file = f"{user_home}/.ssh/known_hosts"
    new_known_hosts_file = f"{user_home}/known_hosts.new"
    os.rename(known_hosts_file, new_known_hosts_file)

    ssh_validation_task_definition["protocol"]["knownHostsFile"] = new_known_hosts_file

    # A new Execution object is built so the updated definition is used
    execution_obj = execution.Execution(
        None, "ssh-host-key-validation", ssh_validation_task_definition
    )

    # Run the execution and expect a true status
    assert execution_obj.run()


def test_basic_execution_cmd_failure(setup_ssh_keys, root_dir):
# Write a test file to the source directory
fs.create_files(
Expand Down
Loading

0 comments on commit 90c0099

Please sign in to comment.