diff --git a/modules/sc-mesh-secure-deployment/src/2_0/features/cbma/__init__.py b/modules/sc-mesh-secure-deployment/src/2_0/features/cbma/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/modules/sc-mesh-secure-deployment/src/2_0/features/cbma/auth/authClient.py b/modules/sc-mesh-secure-deployment/src/2_0/features/cbma/auth/authClient.py new file mode 100644 index 000000000..2877a5622 --- /dev/null +++ b/modules/sc-mesh-secure-deployment/src/2_0/features/cbma/auth/authClient.py @@ -0,0 +1,127 @@ +import socket +import ssl +import sys +sys.path.insert(0, '../') +from tools.verification_tools import * +from tools.custom_logger import CustomLogger +from tools.utils import mac_to_ipv6, get_mac_addr +import glob +import random +import time + +MAX_RETRIES = 5 +MIN_WAIT_TIME = 1 # seconds +MAX_WAIT_TIME = 3 # seconds + +logger_instance = CustomLogger("authClient") + + + +class AuthClient: + def __init__(self, interface, server_mac, server_port, cert_path, ca_path, mua): + self.sslServerIP = mac_to_ipv6(server_mac) + self.sslServerPort = server_port + self.CERT_PATH = cert_path + self.interface = interface + self.secure_client_socket = None + self.logger = logger_instance.get_logger() + self.ca = ca_path + self.mymac = get_mac_addr(self.interface) + self.server_mac = server_mac + self.mua = mua + + def establish_connection(self): + # Create an SSL context + context = ssl.SSLContext(ssl.PROTOCOL_TLSv1_2) + context.verify_mode = ssl.CERT_REQUIRED + + # Uncomment to enable Certificate Revocation List (CRL) check + # context.verify_flags = ssl.VERIFY_CRL_CHECK_LEAF + + context.load_verify_locations(glob.glob(self.ca)[0]) + context.load_cert_chain( + certfile=glob.glob(f"{self.CERT_PATH}/macsec_{self.mymac.replace(':', '')}.crt")[0], + keyfile=glob.glob(f"{self.CERT_PATH}/macsec_{self.mymac.replace(':', '')}.key")[0], + ) + + # Detect if the server IP is IPv4 or IPv6 and create a socket accordingly + if ":" in self.sslServerIP: + clientSocket = socket.socket(socket.AF_INET6, socket.SOCK_STREAM) + else: + clientSocket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + + # Make the client socket suitable for secure communication + self.secure_client_socket = context.wrap_socket(clientSocket) + try: + result = self.connection(self.secure_client_socket) + if result['authenticated']: + self.mua.auth_pass(secure_client_socket=self.secure_client_socket, client_mac=self.server_mac) + else: + self.mua.auth_fail(client_mac=self.server_mac) + except Exception as e: + self.logger.error("Define better this exception.", exc_info=True) + self.mua.auth_fail(client_mac=self.server_mac) + # finally: + # # Close the socket + # secureClientSocket.close() + + def connection(self, secureClientSocket): + result = { + 'IP': self.sslServerIP, + 'authenticated': False + } + + try: + self.to_validate(secureClientSocket, result) + except Exception as e: + self.logger.error("An error occurred during the connection process.", exc_info=True) + + finally: + return result + + def to_validate(self, secureClientSocket, result): + # If the IP is a link-local IPv6 address, connect it with the interface index + retries = 0 + while retries < MAX_RETRIES: + try: + if self.sslServerIP.startswith("fe80"): + secureClientSocket.connect( + (self.sslServerIP, self.sslServerPort, 0, socket.if_nametoindex(self.interface))) + else: + secureClientSocket.connect((self.sslServerIP, self.sslServerPort)) + break # break out of loop if connection is successful + except ConnectionRefusedError: + retries += 1 + if retries < MAX_RETRIES: + wait_time = random.uniform(MIN_WAIT_TIME, MAX_WAIT_TIME) + self.logger.info(f"Connection refused. Retrying in {wait_time:.2f} seconds...") + time.sleep(wait_time) + else: + self.logger.error("Exceeded maximum retry attempts. Unable to connect to server.") + #raise ServerConnectionRefusedError("Unable to connect to server socket") + return + + server_cert = secureClientSocket.getpeercert(binary_form=True) + if not server_cert: + self.logger.error("Unable to get the server certificate", exc_info=True) + #raise CertificateNoPresentError("Unable to get the server certificate") + return + + result['authenticated'] = verify_cert(server_cert, self.ca, self.sslServerIP, self.interface, self.logger) + + # # Safe to proceed with the communication, even if the certificate is not authenticated + # msgReceived = secureClientSocket.recv(1024) + # logger.info(f"Secure communication received from server: {msgReceived.decode()}") + + +if __name__ == "__main__": + # IP address and the port number of the server + sslServerIP = "127.0.0.1" + sslServerPort = 15001 + CERT_PATH = '../../../certificates' # Change this to the actual path of your certificates + + auth_client = AuthClient(sslServerIP, sslServerPort, CERT_PATH) + auth_client.establish_connection() + + + diff --git a/modules/sc-mesh-secure-deployment/src/2_0/features/cbma/auth/authServer.py b/modules/sc-mesh-secure-deployment/src/2_0/features/cbma/auth/authServer.py new file mode 100644 index 000000000..8a41e3a63 --- /dev/null +++ b/modules/sc-mesh-secure-deployment/src/2_0/features/cbma/auth/authServer.py @@ -0,0 +1,136 @@ +import socket +import ssl +import threading +from tools.utils import * +import sys + +sys.path.insert(0, '../') +from tools.verification_tools import * +from tools.custom_logger import CustomLogger +from tools.utils import wait_for_interface_to_be_pingable +import glob + +logger_instance = CustomLogger("Server") +logger = logger_instance.get_logger() + + +class AuthServer: + def __init__(self, interface, ip_address, port, cert_path, ca_path, mua): + threading.Thread.__init__(self) + self.running = True + self.ipAddress = ip_address + self.port = port + self.CERT_PATH = cert_path + self.ca = ca_path + self.interface = interface + self.mymac = get_mac_addr(self.interface) + # Create the SSL context here and set it as an instance variable + self.context = ssl.SSLContext(ssl.PROTOCOL_TLSv1_2) + self.context.verify_mode = ssl.CERT_REQUIRED + self.context.load_verify_locations(glob.glob(self.ca)[0]) + self.context.load_cert_chain( + certfile=glob.glob(f"{self.CERT_PATH}/macsec_{self.mymac.replace(':', '')}.crt")[0], + keyfile=glob.glob(f"{self.CERT_PATH}/macsec_{self.mymac.replace(':', '')}.key")[0], + ) + self.client_auth_results = {} + self.active_sockets = {} + self.client_auth_results_lock = threading.Lock() + self.active_sockets_lock = threading.Lock() + self.mua = mua + + def handle_client(self, client_connection, client_address): + client_mac = extract_mac_from_ipv6(client_address[0]) # TODO: check if it is safe to do so + print("------------------server---------------------") + if client_mac not in self.mua.connected_peers_status: + with self.mua.connected_peers_status_lock: + self.mua.connected_peers_status[client_mac] = ["ongoing",0] # Update status as ongoing, num of failed attempts = 0 + else: + with self.mua.connected_peers_status_lock: + self.mua.connected_peers_status[client_mac][0] = "ongoing" # Update status as ongoing, num of failed attempts = same as before + self.authenticate_client(client_connection, client_address, client_mac) + + def authenticate_client(self, client_connection, client_address, client_mac): + secure_client_socket = self.context.wrap_socket(client_connection, server_side=True) + try: + client_cert = secure_client_socket.getpeercert(binary_form=True) + if not client_cert: + logger.error(f"Unable to get the certificate from the client {client_address[0]}", exc_info=True) + raise CertificateNoPresentError("Unable to get the certificate from the client") + + auth = verify_cert(client_cert, self.ca, client_address[0], self.interface, logger) + with self.client_auth_results_lock: + self.client_auth_results[client_address[0]] = auth + if auth: + with self.active_sockets_lock: + self.active_sockets[client_address[0]] = secure_client_socket + self.mua.auth_pass(secure_client_socket=secure_client_socket, client_mac=client_mac) + else: + # Handle the case when authentication fails, maybe send an error message + self.mua.auth_fail(client_mac=client_mac) + # secure_client_socket.sendall(b"Authentication failed.") + except Exception as e: + logger.error(f"An error occurred while handling the client {client_address[0]}.", exc_info=True) + self.mua.auth_fail(client_mac=client_mac) + # finally: + # secure_client_socket.close() + + def get_secure_socket(self, client_address): + with self.active_sockets_lock: + return self.active_sockets.get(client_address) + + def get_client_auth_result(self, client_address): + with self.client_auth_results_lock: + return self.client_auth_results.get(client_address, None) + + def start_server(self): + if is_ipv4(self.ipAddress): + self.serverSocket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + self.serverSocket.bind((self.ipAddress, self.port)) + elif is_ipv6(self.ipAddress): + self.serverSocket = socket.socket(socket.AF_INET6, socket.SOCK_STREAM) + scope_id = socket.if_nametoindex(self.interface) + self.serverSocket.bind((self.ipAddress, int(self.port), 0, scope_id)) + else: + raise ValueError("Invalid IP address") + + self.serverSocket.listen() + self.serverSocket.settimeout(99999) # maybe we can remove timeout since server needs to be listening throughout + logger.info("Server listening") + + while self.running and not self.mua.shutdown_event.is_set(): + try: + client_connection, client_address = self.serverSocket.accept() + threading.Thread(target=self.handle_client, args=(client_connection, client_address)).start() + except socket.timeout: # In case we add a timeout later. + continue + except Exception as e: + if self.running: + logger.error("Unexpected error in server loop.", exc_info=True) + + def stop_server(self): + self.running = False + if is_ipv4(self.ipAddress): + serverSocket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + serverSocket.bind((self.ipAddress, self.port)) + elif is_ipv6(self.ipAddress): + serverSocket = socket.socket(socket.AF_INET6, socket.SOCK_STREAM) + scope_id = socket.if_nametoindex(self.interface) + serverSocket.bind((self.ipAddress, int(self.port), 0, scope_id)) + if hasattr(self, "serverSocket"): + self.serverSocket.close() + for sock in auth_server.active_sockets.values(): + sock.close() + +if __name__ == "__main__": + # IP address and the port number of the server + ipAddress = "127.0.0.1" + port = 15001 + CERT_PATH = '../../../certificates' # Change this to the actual path of your certificates + + auth_server = AuthServer(ipAddress, port, CERT_PATH) + auth_server.start_server() + + # Access the authentication result for a specific client + client_address = ("127.0.0.1", 12345) # Replace with the actual client address you want to check + auth_result = auth_server.get_client_auth_result(client_address) + print(f"Authentication result for {client_address}: {auth_result}") diff --git a/modules/sc-mesh-secure-deployment/src/2_0/features/cbma/cert_generation/ca_side.py b/modules/sc-mesh-secure-deployment/src/2_0/features/cbma/cert_generation/ca_side.py new file mode 100644 index 000000000..1f5eacdab --- /dev/null +++ b/modules/sc-mesh-secure-deployment/src/2_0/features/cbma/cert_generation/ca_side.py @@ -0,0 +1,136 @@ +""" +for running this code, it is necessary to add the pub key for ssh on the other node. +Steps: +1) ssh-keygen -t ecdsa +2) ssh-copy-id user@somedomain (to server) +""" +# TODO: change to ntas +import os +import subprocess +import socket +import threading +import time +import shutil +import argparse +import sys + +sys.path.insert(0, '../') +from tools.custom_logger import CustomLogger + +logger_instance = CustomLogger("ca-side") +logger = logger_instance.get_logger() + +# Create the CSR directory if it doesn't exist +csr_directory = "/tmp/request/" # Update this with the actual path +if not os.path.exists(csr_directory): + os.makedirs(csr_directory) + + +# Function to run a command and get its output +def run_command(command): + try: + return subprocess.run(command, shell=True, text=True, capture_output=True, check=True) + except subprocess.CalledProcessError as e: + logger.error(f"Error running command: {command}") + logger.error(e.stderr) + raise + + +# Function to generate certificates and send files +def generate_and_send_certificates(csr_filename, IPAddress): + try: + # Generate the certificate using the Bash script + generate_cert_script = "./generate_certificates.sh" + output = run_command(f"{generate_cert_script} {csr_filename} {csr_directory}") + logger.info(output.stdout) + # Rename certificate files + crt = f"{os.path.splitext(csr_filename)[0]}.crt" + crt_filename = crt.split(os.sep)[-1] + shutil.copy(f"certificates/{crt_filename}", f"{csr_directory}{crt_filename}") + ca_crt_filename = "ca.crt" + shutil.copy(f"certificates/{ca_crt_filename}", f"{csr_directory}{ca_crt_filename}") + + # Send CA.crt and signed certificate (crt) back to the client via SCP + files_to_send = [crt_filename, ca_crt_filename] + + for file_to_send in files_to_send: + scp_command = f"scp {csr_directory}{file_to_send} root@{IPAddress}:{csr_directory}" + run_command(scp_command) + + logger.info("Certificates sent successfully.") + except Exception as e: + logger.error(f"Error generating and sending certificates: {e}") + raise + + # Clean up + # os.remove(crt_filename) + # os.remove(ca_crt_filename) + + +# Function to handle client connection +def handle_client(client_socket, IPaddress): + try: + # Receive the CSR filename from the client + csr_filename = client_socket.recv(1024).decode() + if "CSR uploaded " in csr_filename: + csr_filename = csr_filename.split("CSR uploaded ")[1] + # Acknowledge the filename + client_socket.sendall(b"Filename received") + + # Verify the CSR filename in the monitoring directory + full_path = os.path.join(csr_directory, csr_filename) + if os.path.exists(full_path) and csr_filename.endswith(".csr") and csr_filename != '': + generate_and_send_certificates(full_path, IPaddress) + os.remove(full_path) # Remove the processed CSR file + except Exception as e: + logger.error(f"Error handling client connection: {e}") + + +# Function to monitor a directory for new CSR files +def monitor_csr_directory(directory, existing_files): + try: + while True: + for filename in os.listdir(directory): + if filename.endswith(".csr") and filename not in existing_files: + full_path = os.path.join(directory, filename) + existing_files.add(filename) + time.sleep(10) # Adjust the interval as needed + except Exception as e: + logger.error(f"Error monitoring CSR directory: {e}") + + +def main(): + try: + # Parse command-line arguments + parser = argparse.ArgumentParser(description="Server for handling CSR files") + parser.add_argument("--port", type=int, default=12345, help="Port to listen on") + args = parser.parse_args() + + # Server configuration + host = "0.0.0.0" + port = args.port + + server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + server_socket.bind((host, port)) + server_socket.listen(1) + + logger.info("Server listening on port %d", port) + + # Start the CSR monitoring thread + existing_files = set() + monitor_thread = threading.Thread(target=monitor_csr_directory, args=(csr_directory, existing_files,)) + monitor_thread.start() + + while True: + # Handle the client connection + client_socket, client_address = server_socket.accept() + logger.info("Accepted connection from %s", client_address) + handle_client(client_socket, client_address[0]) + client_socket.close() + + except Exception as e: + logger.error(f"Main loop error: {e}") + + +if __name__ == "__main__": + main() diff --git a/modules/sc-mesh-secure-deployment/src/2_0/features/cbma/cert_generation/client_side.py b/modules/sc-mesh-secure-deployment/src/2_0/features/cbma/cert_generation/client_side.py new file mode 100644 index 000000000..2ee842aea --- /dev/null +++ b/modules/sc-mesh-secure-deployment/src/2_0/features/cbma/cert_generation/client_side.py @@ -0,0 +1,101 @@ +""" +for running this code, it is necessary to add the pub key for ssh on the other node. +Steps: +1) ssh-keygen -t ecdsa +2) ssh-copy-id user@somedomain (to server) +""" +import subprocess +import socket +import os +import time +import glob +import shutil +import argparse + +import sys +sys.path.insert(0, '../') +from tools.utils import get_mac_addr + + +# Constants +REMOTE_PATH = "/tmp/request" +if not os.path.exists(REMOTE_PATH): + os.makedirs(REMOTE_PATH) + +LOCAL_PATH = "certificates/" +if not os.path.exists(LOCAL_PATH): + os.makedirs(LOCAL_PATH) + +CUSTOM_PORT = 12345 +CSR_SCRIPT_PATH = "./generate-csr.sh" + + +def run_command(command): + try: + return subprocess.run(command, shell=True, text=True, capture_output=True, check=True) + except subprocess.CalledProcessError as e: + print(f"Error running command: {command}. Error: {e.stderr}") + exit(1) + + +def is_server_reachable(server_ip, port=CUSTOM_PORT): + try: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.settimeout(2) + s.connect((server_ip, port)) + return True + except (socket.timeout, ConnectionRefusedError): + return False + + +def upload_file_to_server(local_file, server_ip, username, remote_path=REMOTE_PATH): + scp_command = f"scp -i ~/.ssh/id_rsa {local_file} {username}@{server_ip}:{remote_path}" + run_command(scp_command) + + +def are_files_received(csr_filename, path=REMOTE_PATH): + # Paths for the required files + ca_crt_path = os.path.join(path, "ca.crt") + csr_crt_path = os.path.join(path, f"{os.path.splitext(csr_filename)[0]}.crt") + + # Check if both files exist + return os.path.exists(ca_crt_path) and os.path.exists(csr_crt_path) + + +def main(interface): + server_ip = input("Enter the server IP address: ") + username = input("Enter your username: ") + + if not is_server_reachable(server_ip): + print("Error: The server is not reachable.") + exit(1) + + #run_command(CSR_SCRIPT_PATH) + subprocess.run([CSR_SCRIPT_PATH, interface], check=True) + mac_address = get_mac_addr(interface) + csr_filename = glob.glob(f"macsec_{mac_address.replace(':', '')}.csr")[0] + crt = csr_filename.split(".csr")[0]+".crt" + upload_file_to_server(csr_filename, server_ip, username) + + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as client_socket: + client_socket.connect((server_ip, CUSTOM_PORT)) + message = f"CSR uploaded {csr_filename}".encode() + client_socket.sendall(message) + + # This will allow the server some time to process and send the files. You can adjust the sleep time if needed. + time.sleep(10) + + if are_files_received(csr_filename): + print(f"The crt file ({crt}) and the ca.crt have been successfully received locally!") + for f in glob.glob(f"{REMOTE_PATH}/*"): + shutil.copy2(f, LOCAL_PATH) + shutil.copy2(csr_filename.split(".csr")[0]+".key", LOCAL_PATH) + else: + print("Failed to confirm the receipt of the CSR file and ca.crt locally.") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Client script to generate CSR and get it signed by CA for an interface") + parser.add_argument('--interface', required=True, help='Interface name: Eg. wlp1s0, halow1') + args = parser.parse_args() + main(args.interface) diff --git a/modules/sc-mesh-secure-deployment/src/2_0/features/cbma/cert_generation/generate-csr.sh b/modules/sc-mesh-secure-deployment/src/2_0/features/cbma/cert_generation/generate-csr.sh new file mode 100755 index 000000000..7f67b4e43 --- /dev/null +++ b/modules/sc-mesh-secure-deployment/src/2_0/features/cbma/cert_generation/generate-csr.sh @@ -0,0 +1,86 @@ +#!/bin/bash + +# Function to derive IPv6 address from MAC address +function mac_to_ipv6() { + # Remove any separators from the MAC address (e.g., colons, hyphens) + mac_address="${1//[:\-]/}" + mac_address="${mac_address,,}" + + # Split the MAC address into two equal halves + first_half="${mac_address:0:6}" + + # Convert the first octet from hexadecimal to binary + binary_first_octet=$(echo "obase=2; ibase=16; ${first_half:0:2}" | bc | xargs printf "%08d") + + # Invert the seventh bit (change 0 to 1 or 1 to 0) + inverted_seventh_bit=$(( 1 - $(echo "${binary_first_octet:6:1}") )) + + # Convert the modified binary back to hexadecimal + modified_first_octet=$(echo "obase=16; ibase=2; ${binary_first_octet:0:6}${inverted_seventh_bit}${binary_first_octet:7}" | bc) + + # Replace the original first octet with the modified one + modified_mac_address="${modified_first_octet}${mac_address:2}" + + line="${modified_mac_address:0:6}fffe${modified_mac_address:6}" + + # Add "ff:fe:" to the middle of the new MAC address + mac_with_fffe=$(echo "$line" | sed -r 's/(.{4})/\1:/g; s/:$//') + + echo "fe80::$mac_with_fffe" +} + +# Read the user input for the network interface, defaulting to "wlp1s0" if no input is provided +#read -p "Enter the network interface (default: wlp1s0): " network_interface +#network_interface=${network_interface:-wlp1s0} + +# Get the network interface from the command-line argument or use a default value +network_interface=${1:-wlp1s0} + + +# Parse the MAC address from the network interface using ip command +mac_address=$(ip link show $network_interface | awk '/ether/ {print $2}') + +id=${mac_address//:/} #mac address with no colon + +# Derive the IPv6 address from the MAC address (extended format) +ipv6_address=$(mac_to_ipv6 "$mac_address") + +# Generate the EC private key +openssl ecparam -name prime256v1 -genkey -noout -out macsec_"$id".key # for other certificates (eg ipsec) we need to verify if this exists + +# Generate the 256-bit random number from the fingerprint of the public key +random=$(openssl ec -in macsec_"$id".key -pubout -outform DER | sha256sum | awk '{print substr($1, 1, 30)}') + + +# Derive the second IPv6 address (mesh IPv6) from the ID (extended format) +mesh_ipv6="fe80::$(echo "$random" | cut -c1-4):$(echo "$random" | cut -c5-8):$(echo "$random" | cut -c9-12)" + + +# Create a CSR configuration file with the custom SANs +cat > csr.conf </dev/null; then + # Generate the Certificate Authority (CA) key and self-signed certificate + openssl ecparam -name prime256v1 -genkey -noout -out ca.key + openssl req -new -x509 -key ca.key -out ca.crt -days 365 -subj "/CN=TII" +fi + +# Sign the server CSR with the CA to get the server certificate +openssl x509 -req -in "$1" -CA ca.crt -CAkey ca.key -CAcreateserial -out "$(basename "$1" .csr).crt" -days 365 + +# Verify that the certificate file has been created +if [ -f "$(basename "$1" .csr).crt" ]; then + echo "Certificates have been generated successfully in the 'certificates' directory." +else + echo "Failed to generate the certificate. Please check the input CSR and CA files." +fi + +echo "Verifying certificates..." +openssl verify -CAfile ca.crt "$(basename "$1" .csr).crt" diff --git a/modules/sc-mesh-secure-deployment/src/2_0/features/cbma/cert_generation/test/test_ca_side.py b/modules/sc-mesh-secure-deployment/src/2_0/features/cbma/cert_generation/test/test_ca_side.py new file mode 100644 index 000000000..6a6602345 --- /dev/null +++ b/modules/sc-mesh-secure-deployment/src/2_0/features/cbma/cert_generation/test/test_ca_side.py @@ -0,0 +1,35 @@ +import sys +sys.path.insert(0, '../') +from ca_side import monitor_csr_directory, generate_and_send_certificates, handle_client +import pytest +from unittest.mock import MagicMock, patch, mock_open + + + +@pytest.fixture +def mock_socket(): + class MockSocket: + def __init__(self): + self.data = None + + def recv(self, num): + return b'CSR uploaded test.csr' + + def send(self, data): + self.data = data + return len(data) + + return MockSocket() + + +@patch('ca_side.os.path.exists', return_value=True) +@patch('ca_side.generate_and_send_certificates', return_value=None) +@patch('ca_side.os.remove', return_value=None) +@patch('ca_side.logging.error') +def test_handle_client(mock_log_error, mock_remove, mock_gen_send, mock_exists, mock_socket): + handle_client(mock_socket, "127.0.0.1") + + assert mock_socket.data == b'Filename received' + mock_gen_send.assert_called_once_with('/tmp/request/test.csr', '127.0.0.1') + mock_remove.assert_called_once_with('/tmp/request/test.csr') + mock_log_error.assert_not_called() diff --git a/modules/sc-mesh-secure-deployment/src/2_0/features/cbma/cert_generation/test/test_client_side.py b/modules/sc-mesh-secure-deployment/src/2_0/features/cbma/cert_generation/test/test_client_side.py new file mode 100644 index 000000000..04dbc9518 --- /dev/null +++ b/modules/sc-mesh-secure-deployment/src/2_0/features/cbma/cert_generation/test/test_client_side.py @@ -0,0 +1,54 @@ +import pytest +from unittest.mock import patch, Mock +import sys +import subprocess +import socket +sys.path.insert(0, '../') +import client_side + +# Sample data for the tests +sample_csr_output = "Some random text. CSR generated: test.csr" + +def test_run_command_successful(): + with patch('subprocess.run', return_value=Mock(stdout=sample_csr_output)): + result = client_side.run_command("ls") + assert result.stdout == sample_csr_output + +def test_run_command_error(): + with patch('subprocess.run', side_effect=subprocess.CalledProcessError(1, 'cmd', stderr="Error")): + with pytest.raises(SystemExit): + client_side.run_command("invalid_command") + +def test_is_server_reachable(): + with patch('socket.socket') as mock_socket: + instance = mock_socket.return_value.__enter__.return_value + instance.settimeout.return_value = None + instance.connect.return_value = None + + assert client_side.is_server_reachable("127.0.0.1") == True + +def test_is_server_unreachable_due_to_timeout(): + with patch('socket.socket') as mock_socket: + instance = mock_socket.return_value.__enter__.return_value + instance.connect.side_effect = socket.timeout + + assert client_side.is_server_reachable("192.0.2.0") == False + +def test_is_server_unreachable_due_to_refusal(): + with patch('socket.socket') as mock_socket: + instance = mock_socket.return_value.__enter__.return_value + instance.connect.side_effect = ConnectionRefusedError + + assert client_side.is_server_reachable("192.0.2.0") == False + +def test_upload_file_to_server(): + with patch('client_side.run_command', return_value=None): + client_side.upload_file_to_server("test.csr", "127.0.0.1", "user") + +def test_are_files_received(): + with patch('os.path.exists', side_effect=lambda x: True): + assert client_side.are_files_received("test.csr") + +def test_are_files_not_received(): + with patch('os.path.exists', side_effect=lambda x: False): + assert not client_side.are_files_received("test.csr") \ No newline at end of file diff --git a/modules/sc-mesh-secure-deployment/src/2_0/features/cbma/cleanup_cbma.sh b/modules/sc-mesh-secure-deployment/src/2_0/features/cbma/cleanup_cbma.sh new file mode 100644 index 000000000..37562bd2d --- /dev/null +++ b/modules/sc-mesh-secure-deployment/src/2_0/features/cbma/cleanup_cbma.sh @@ -0,0 +1,9 @@ +#!/bin/bash +# Deletes macsec links, batman and bridges created within cbma +ip macsec show | grep ': protect on validate' | awk -F: '{print $2}' | awk '{print $1}' | xargs -I {} ip link delete {} +ifconfig bat0 down +ifconfig bat1 down +batctl meshif bat0 interface destroy +batctl meshif bat1 interface destroy +ip link set br-upper down +brctl delbr br-upper \ No newline at end of file diff --git a/modules/sc-mesh-secure-deployment/src/2_0/features/cbma/delete_default_brlan_bat0.sh b/modules/sc-mesh-secure-deployment/src/2_0/features/cbma/delete_default_brlan_bat0.sh new file mode 100644 index 000000000..407db8230 --- /dev/null +++ b/modules/sc-mesh-secure-deployment/src/2_0/features/cbma/delete_default_brlan_bat0.sh @@ -0,0 +1,25 @@ +#!/bin/bash + +# Check if the bridge "br-lan" exists +if ip link show br-lan &> /dev/null; then + # Delete the bridge "br-lan" if it exists + echo "Deleting bridge br-lan..." + ip link set br-lan down + brctl delbr br-lan + echo "Bridge br-lan deleted." +else + echo "Bridge br-lan does not exist." +fi + +# Check if the batman interface "bat0" exists +if ip link show bat0 &> /dev/null; then + # Delete the batman interface "bat0" if it exists + echo "Deleting batman interface bat0..." + ip link set bat0 down + batctl meshif bat0 interface destroy + + + echo "Batman interface bat0 deleted." +else + echo "Batman interface bat0 does not exist." +fi \ No newline at end of file diff --git a/modules/sc-mesh-secure-deployment/src/2_0/features/cbma/macsec/macsec.py b/modules/sc-mesh-secure-deployment/src/2_0/features/cbma/macsec/macsec.py new file mode 100644 index 000000000..bc817252c --- /dev/null +++ b/modules/sc-mesh-secure-deployment/src/2_0/features/cbma/macsec/macsec.py @@ -0,0 +1,69 @@ +import subprocess +import sys +import os +import random +import threading + +script_dir = os.path.dirname(__file__) # path to macsec directory +sys.path.insert(0, f'{script_dir}/../') +from tools.custom_logger import CustomLogger + +logger_instance = CustomLogger("macsec") +class Macsec: + def __init__(self, level, interface, macsec_encryption): + self.level = level # Macsec level: "lower" or "upper" + self.interface = interface + self.available_ports = set(range(1, 2**16)) # 1 to 2^16-1 + self.used_ports = {} # client_mac: port + self.available_ports_lock = threading.Lock() + self.macsec_encryption = macsec_encryption # Flag to set macsec encrypt on or off + self.logger = logger_instance.get_logger() + + def set_macsec_tx(self, client_mac, my_macsec_key, my_port): + # Sets up macsec link and adds tx channel + macsec_interface = self.get_macsec_interface_name(client_mac) + try: + subprocess.run(["ip", "link", "add", "link", self.interface, macsec_interface, "type", "macsec", "port", str(my_port), "encrypt", self.macsec_encryption, "cipher", "gcm-aes-256"], check=True) + subprocess.run(["ip", "macsec", "add", macsec_interface, "tx", "sa", "0", "pn", "1", "on", "key", "01", my_macsec_key], check=True) + subprocess.run(["ip", "link", "set", macsec_interface, "up"], check=True) + subprocess.run(["ip", "macsec", "show"], check=True) + self.logger.info(f'{self.level} macsec tx channel set with {client_mac}') + except Exception as e: + self.logger.error(f'Error setting up {self.level} macsec tx channel with {client_mac}: {e}') + sys.exit(1) + + def set_macsec_rx(self, client_mac, client_macsec_key, client_port): + # Adds a rx channel with client_mac, with key id = client mac without ":" + macsec_interface = self.get_macsec_interface_name(client_mac) + try: + subprocess.run(["ip", "macsec", "add", macsec_interface, "rx", "port", str(client_port), "address", client_mac], check=True) + subprocess.run(["ip", "macsec", "add", macsec_interface, "rx", "port", str(client_port), "address", client_mac, "sa", "0", "pn", "1", "on", "key", client_mac.replace(":", ""), client_macsec_key], check=True) + subprocess.run(["ip", "macsec", "show"], check=True) + self.logger.info(f'{self.level} macsec rx channel set with {client_mac}') + except Exception as e: + self.logger.error(f'Error setting up {self.level} macsec rx channel with {client_mac}: {e}') + + def get_macsec_interface_name(self, client_mac): + if self.level == "lower": + return f"lms{client_mac.replace(':', '')}" + else: + return f"ums{client_mac.replace(':', '')}" + + + def assign_unique_port(self, client_mac): + with self.available_ports_lock: + if client_mac in self.used_ports: + return self.used_ports[client_mac] + if not self.available_ports: + raise ValueError("No available ports.") + port = random.sample(list(self.available_ports), 1)[0] + self.available_ports.remove(port) + self.used_ports[client_mac] = port + return port + + def release_port(self, client_mac): + with self.available_ports_lock: + if client_mac not in self.used_ports: + raise ValueError(f"Client {client_mac} is not in the list of used ports.") + self.available_ports.add(self.used_ports[client_mac]) + del self.used_ports[client_mac] \ No newline at end of file diff --git a/modules/sc-mesh-secure-deployment/src/2_0/features/cbma/multicast/__init__.py b/modules/sc-mesh-secure-deployment/src/2_0/features/cbma/multicast/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/modules/sc-mesh-secure-deployment/src/2_0/features/cbma/multicast/multicast.py b/modules/sc-mesh-secure-deployment/src/2_0/features/cbma/multicast/multicast.py new file mode 100644 index 000000000..ce49bff51 --- /dev/null +++ b/modules/sc-mesh-secure-deployment/src/2_0/features/cbma/multicast/multicast.py @@ -0,0 +1,111 @@ +import socket +import struct +import threading +import time +import argparse +import json +import sys +from queue import Queue, Empty +sys.path.insert(0, '../') +from tools.utils import get_mac_addr +from tools.custom_logger import CustomLogger + +logger_instance = CustomLogger("multicast") + + +class MulticastHandler: + def __init__(self, qeue, multicast_group, port, interface, shutdown_event=threading.Event()): + self.queue = qeue + self.multicast_group = multicast_group + self.port = port + self.interface = interface # Multicast interface. Set as radio name: in case of TLS for lower macsec, lower batman interface: in case of TLS for upper macsec + self.logger = logger_instance.get_logger() + self.excluded = [get_mac_addr(interface), f'{get_mac_addr(interface)}_server'] + self.shutdown_event = shutdown_event + + def send_multicast_message(self, data): + with socket.socket(socket.AF_INET6, socket.SOCK_DGRAM) as sock: + sock.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_MULTICAST_HOPS, 1) #TODO: check if multi-hop nodes receive multicast over bat0 + sock.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_MULTICAST_IF, socket.if_nametoindex(self.interface)) # Set multicast interface + + message = { + 'mac_address': data, + 'message_type': 'mac_announcement' + } + self.logger.info(f'Sending data {message} to {self.multicast_group}:{self.port} from interface {self.interface}') + sock.sendto(json.dumps(message).encode('utf-8'), (self.multicast_group, self.port)) + + def receive_multicast(self): + + with socket.socket(socket.AF_INET6, socket.SOCK_DGRAM, socket.IPPROTO_UDP) as sock: + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock.setsockopt(socket.SOL_SOCKET, 25, str(self.interface + '\0').encode('utf-8')) # Bind socket to interface + + # Bind to the wildcard address and desired port + sock.bind(('::', self.port)) + + # Set the multicast interface + index = socket.if_nametoindex(self.interface) + sock.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_MULTICAST_IF, index.to_bytes(4, byteorder='little')) + + # Construct the membership request + mreq = socket.inet_pton(socket.AF_INET6, self.multicast_group) + index.to_bytes(4, byteorder='little') + + # Add the membership to the socket + sock.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_JOIN_GROUP, mreq) + + self.logger.info(f"Listening for messages on {self.multicast_group}:{self.port}...") + while not self.shutdown_event.is_set(): + data, address = sock.recvfrom(1024) + decoded_data = json.loads(data.decode()) + if decoded_data['mac_address'] not in self.excluded: + self.logger.info(f'Received data {decoded_data} from {address} at interface {self.interface}') + if 'mac_address' in decoded_data: + self.queue.put(("MULTICAST", decoded_data['mac_address'])) + + def multicast_message(self): + self.receive_multicast() + + +def main(): + parser = argparse.ArgumentParser(description="IPv6 Multicast Sender/Receiver") + parser.add_argument('--mode', choices=['send', 'receive', 'both'], required=True, help='Run mode: send, receive, or both') + parser.add_argument('--address', default='ff02::1', help='Multicast IPv6 address (default: ff02::1)') + parser.add_argument('--port', type=int, default=12345, help='Port to use (default: 12345)') + parser.add_argument('--interface', default="wlp1s0", help='Multicast interface (default: wlp1s0)') + + args = parser.parse_args() + queue = Queue() + + multicast_handler = MulticastHandler(queue, args.address, args.port, args.interface) + message = get_mac_addr(args.interface) + + if args.mode == 'receive': + multicast_handler.receive_multicast() + elif args.mode == 'send': + multicast_handler.send_multicast_message(message) + elif args.mode == 'both': + receiver_thread = threading.Thread(target=multicast_handler.multicast_message) + receiver_thread.start() + + # Wait a bit for the receiver thread to start + time.sleep(2) + + multicast_handler.send_multicast_message(message) + + try: + while True: + source, data = queue.get(timeout=10) + if source == "MULTICAST": + multicast_handler.logger.info(f"Main thread received MAC: {data}") + + except Empty: + pass + + except KeyboardInterrupt: + multicast_handler.logger.info("Shutting down...") + sys.exit(0) + + +if __name__ == "__main__": + main() diff --git a/modules/sc-mesh-secure-deployment/src/2_0/features/cbma/mutauth.py b/modules/sc-mesh-secure-deployment/src/2_0/features/cbma/mutauth.py new file mode 100644 index 000000000..e6ae709b7 --- /dev/null +++ b/modules/sc-mesh-secure-deployment/src/2_0/features/cbma/mutauth.py @@ -0,0 +1,202 @@ +import os +import sys +path_to_cbma_dir = os.path.dirname(__file__) # Path to dir containing this script +sys.path.insert(0, path_to_cbma_dir) + +from auth.authServer import AuthServer +from auth.authClient import AuthClient +import threading +from multicast.multicast import MulticastHandler +from tools.monitoring_wpa import * +from tools.utils import * +from tools.custom_logger import CustomLogger +from macsec import macsec +import queue +import random +import json +from secure_channel.secchannel import SecMessageHandler +BEACON_TIME = 10 +MAX_CONSECUTIVE_NOT_RECEIVED = 2 +MULTICAST_ADDRESS = 'ff02::1' +TIMEOUT = 3 * BEACON_TIME +logger_instance = CustomLogger("mutAuth") + +class mutAuth(): + def __init__(self, in_queue, level, meshiface, port, batman_interface, path_to_certificate, path_to_ca, macsec_encryption, shutdown_event, lan_bridge_flag=False): + self.level = level + self.meshiface = meshiface + self.mymac = get_mac_addr(self.meshiface) + self.ipAddress = mac_to_ipv6(self.mymac) + self.port = port + self.CERT_PATH = path_to_certificate # Absolute path to certificates folder + self.CA_PATH = path_to_ca # Absolute path to ca certificate + self.in_queue = in_queue + self.logger = logger_instance.get_logger() + self.multicast_handler = MulticastHandler(self.in_queue, MULTICAST_ADDRESS, self.port, self.meshiface, shutdown_event) + self.stop_event = threading.Event() + self.sender_thread = threading.Thread(target=self._periodic_sender, args=()) + self.shutdown_event = shutdown_event # Add this to handle graceful shutdown + self.batman_interface = batman_interface + self.macsec_obj = macsec.Macsec(level=self.level, interface=self.meshiface, macsec_encryption=macsec_encryption) # Initialize macsec object + if self.level == "upper": + self.bridge_interface = "br-upper" # bridge for upper macsec interfaces + setup_bridge(self.bridge_interface) + add_interface_to_batman(interface_to_add=self.bridge_interface, batman_interface=self.batman_interface) + self.connected_peers_status = {} # key = client mac address, value = [status : ("ongoing", "authenticated", "not connected"), no of failed attempts]} + self.connected_peers_status_lock = threading.Lock() + self.maximum_num_failed_attempts = 3 # Maximum number of failed attempts for mutual authentication (can be changed) + self.lan_bridge_flag = lan_bridge_flag + self.macsec_setup_event = threading.Event() + + def check_mesh(self): + if not is_wpa_supplicant_running(): + logger.info("wpa_supplicant process is not running.") + run_wpa_supplicant(self.meshiface) + set_ipv6(self.meshiface, self.ipAddress) + else: + logger.info("wpa_supplicant process is running.") + + def _periodic_sender(self): + while not self.stop_event.is_set() and not self.shutdown_event.is_set(): + self.multicast_handler.send_multicast_message(self.mymac) + time.sleep(BEACON_TIME) + + def monitor_wpa_multicast(self): + #muthread = threading.Thread(target=self.multicast_handler.receive_multicast) + #muthread.start() + + while not self.shutdown_event.is_set(): + source, message = self.in_queue.get() + if source == "WPA": + self.logger.info("External node_connect event triggered!") + self.logger.info(f"Received MAC from WPA event: {message}") + handle_peer_connected_thread = threading.Thread(target=self.handle_wpa_multicast_event, args=(message,)) + handle_peer_connected_thread.start() + elif source == "MULTICAST": + self.logger.info(f"Received MAC on multicast: {message} at interface {self.meshiface}") + handle_peer_connected_thread = threading.Thread(target=self.handle_wpa_multicast_event, args=(message,)) + handle_peer_connected_thread.start() + + def handle_wpa_multicast_event(self, mac): + if mac not in self.connected_peers_status: + # There is no ongoing connection with peer yet + # Wait for random seconds + random_wait = random.uniform(0.5,3) # Wait between 0.5 to 3 seconds. Random waiting to avoid race condition + time.sleep(random_wait) + if mac not in self.connected_peers_status: + # Start as client + print("------------------client ---------------------") + with self.connected_peers_status_lock: + self.connected_peers_status[mac] = ["ongoing", 0] # Update status as ongoing, num of failed attempts = 0 + self.start_auth_client(mac) + elif self.connected_peers_status[mac][0] not in ["ongoing", "authenticated"]: + # If node does not have ongoing authentication or is not already authenticated or has not been blacklisted + # Wait for random seconds + random_wait = random.uniform(0.5,3) # Wait between 0.5 to 3 seconds. Random waiting to avoid race condition + time.sleep(random_wait) + if self.connected_peers_status[mac][0] not in ["ongoing", "authenticated"]: + # Start as client + print("------------------client ---------------------") + with self.connected_peers_status_lock: + self.connected_peers_status[mac][0] = "ongoing" # Update status as ongoing, num of failed attempts = same as before + self.start_auth_client(mac) + + def start_auth_server(self): + auth_server = AuthServer(self.meshiface, self.ipAddress, self.port, self.CERT_PATH, self.CA_PATH, self) + auth_server_thread = threading.Thread(target=auth_server.start_server) + auth_server_thread.start() + return auth_server_thread, auth_server + + def start_auth_client(self, server_mac): + cli = AuthClient(self.meshiface, server_mac, self.port, self.CERT_PATH, self.CA_PATH, self) + cli.establish_connection() + + def auth_pass(self, secure_client_socket, client_mac): + # Steps to execute if auth passes + with self.connected_peers_status_lock: + self.connected_peers_status[client_mac][0] = "authenticated" # Update status as authenticated, num of failed attempts = same as before + self.setup_macsec(secure_client_socket=secure_client_socket, client_mac=client_mac) + + def auth_fail(self, client_mac): + # Steps to execute if auth fails + with self.connected_peers_status_lock: + self.connected_peers_status[client_mac][1] = self.connected_peers_status[client_mac][1] + 1 # Increment number of failed attempt by 1 + self.connected_peers_status[client_mac][0] = "not connected" # Update status as not connected + def batman(self, batman_interface): + try: + batman_exec(batman_interface,"batman-adv") + except Exception as e: + logger.error(f'Error setting up bat0: {e}') + sys.exit(1) + + def setup_secchannel(self, secure_client_socket, my_macsec_param): + # Establish secure channel and exchange macsec key + secchan = SecMessageHandler(secure_client_socket, self.shutdown_event) + macsec_param_q = queue.Queue() # queue to store macsec parameters: macsec_key, port from client_secchan.receive_message + receiver_thread = threading.Thread(target=secchan.receive_message, args=(macsec_param_q,)) + receiver_thread.start() + print(f"Sending my macsec parameters: {my_macsec_param} to {secure_client_socket.getpeername()[0]}") + secchan.send_message(json.dumps(my_macsec_param)) + client_macsec_param = json.loads(macsec_param_q.get()) + return secchan, client_macsec_param + + def setup_macsec(self, secure_client_socket, client_mac): + # Setup macsec + # Compute macsec parameters + bytes_for_my_key = generate_random_bytes() # bytes for my key + bytes_for_client_key = generate_random_bytes() # bytes for client key + my_port = self.macsec_obj.assign_unique_port(client_mac) + my_macsec_param = {'bytes_for_my_key': bytes_for_my_key.hex(), 'bytes_for_client_key': bytes_for_client_key.hex(), 'port': my_port} # Bytes conveted into hex strings so that they can be dumped into json later + + # Establish secure channel and exchange bytes for macsec keys and port + secchan, client_macsec_param = self.setup_secchannel(secure_client_socket, my_macsec_param) + + # Compute keys by XORing bytes + my_macsec_key = xor_bytes(bytes_for_my_key, bytes.fromhex(client_macsec_param['bytes_for_client_key'])).hex() # XOR my bytes_for_my_key with client's bytes_for_client_key + client_macsec_key = xor_bytes(bytes_for_client_key, bytes.fromhex(client_macsec_param['bytes_for_my_key'])).hex() # XOR my bytes_for_client_key with client's bytes_for_my_key + + self.macsec_obj.set_macsec_tx(client_mac, my_macsec_key, my_port) # setup macsec tx channel + self.macsec_obj.set_macsec_rx(client_mac, client_macsec_key, client_macsec_param['port']) # setup macsec rx channel + self.add_to_batman(client_mac) + self.macsec_setup_event.set() + + def add_to_batman(self, client_mac): + # Adds macsec interface to batman + if self.level == "lower": + # TODO: change this to add lower macsec interfaces to a bridge before adding to batman (need to avoid bridge loops with ebtables) + # Add lower macsec interface to lower batman interface + add_interface_to_batman(interface_to_add=self.macsec_obj.get_macsec_interface_name(client_mac), batman_interface=self.batman_interface) + elif self.level == "upper": + add_interface_to_bridge(interface_to_add=self.macsec_obj.get_macsec_interface_name(client_mac), bridge_interface=self.bridge_interface) + + def setup_batman(self): + # Wait till a macsec interface is setup and added to batman before setting up batman interface + self.macsec_setup_event.wait() + if not is_interface_up(self.batman_interface): # Turn batman interface up if not up already + self.batman(self.batman_interface) + if self.level == "upper" and self.lan_bridge_flag: + # Add bat1 and ethernet interface to br-lan to connect external devices + bridge_interface = "br-lan" + if not is_interface_up(bridge_interface): + self.setup_bridge_over_batman(bridge_interface) + + def setup_bridge_over_batman(self, bridge_interface): + # TODO: Need to configure interfaces to add to br-lan (This is just for quick test) + subprocess.run(["brctl", "addbr", bridge_interface], check=True) + add_interface_to_bridge(interface_to_add=self.batman_interface, bridge_interface=bridge_interface) # Add bat1 to br-lan + add_interface_to_bridge(interface_to_add="eth1", bridge_interface=bridge_interface) # Add eth1 to br-lan + logger.info(f"Setting mac address of {bridge_interface} to be same as {self.batman_interface}..") + subprocess.run(["ip", "link", "set", "dev", bridge_interface, "address", get_mac_addr(self.batman_interface)], check=True) + subprocess.run(["ip", "link", "set", bridge_interface, "up"], check=True) + subprocess.run(["ifconfig", bridge_interface], check=True) + + + + def start(self): + # ... other starting procedures + self.sender_thread.start() + + def stop(self): + # Use this method to stop the periodic sender and other threads + self.stop_event.set() + self.sender_thread.join() \ No newline at end of file diff --git a/modules/sc-mesh-secure-deployment/src/2_0/features/cbma/secure_channel/secchannel.py b/modules/sc-mesh-secure-deployment/src/2_0/features/cbma/secure_channel/secchannel.py new file mode 100644 index 000000000..4a320a63a --- /dev/null +++ b/modules/sc-mesh-secure-deployment/src/2_0/features/cbma/secure_channel/secchannel.py @@ -0,0 +1,74 @@ +import queue +import ssl +import socket +import json +import threading +from tools.custom_logger import CustomLogger + +logger_instance = CustomLogger("SecChannel") + + +class SecMessageHandler: + def __init__(self, socket, shutdown_event=threading.Event()): + self.socket = socket + self.logger = logger_instance.get_logger() + self.callback = None + self.shutdown_event = shutdown_event + + def set_callback(self, callback): + """Set the callback function to be executed when a message is received.""" + self.callback = callback + + def _is_socket_active(self): + """Check if the socket is active.""" + try: + # Just a simple check; will rely on recv() in receive_message() for more detailed check. + return self.socket.fileno() != -1 + except socket.error: + return False + + def _is_ssl_socket(self): + return isinstance(self.socket, ssl.SSLSocket) + + def send_message(self, message): + """Send a message through the socket.""" + if not self._is_ssl_socket(): + self.logger.error("Socket is not SSL enabled.") + return + + if not self._is_socket_active(): + self.logger.error("Socket is not active.") + return + + try: + self.socket.sendall(message.encode()) + self.logger.info(f"Sent: {message} to {self.socket.getpeername()[0]}") + except Exception as e: + self.logger.error(f"Error sending message to {self.socket.getpeername()[0]}.", exc_info=True) + + def receive_message(self, macsec_param_q=queue.Queue()): + """Continuously receive messages from the socket.""" + if not self._is_ssl_socket(): + self.logger.error("Socket is not SSL enabled.") + return + try: + while not self.shutdown_event.is_set(): + # No need to check _is_socket_active here, rely on recv's result. + data = self.socket.recv(1024).decode() + if not data: + self.logger.warning("Connection closed or socket not active.") + break + elif data == "GOODBYE": # trigger to close it + self.logger.info("Other end signaled end of communication.") + break + else: + self.logger.info(f"Received: {data} from {self.socket.getpeername()[0]}") + if self.callback: + self.callback(data) # Execute the callback with the received data + if 'bytes_for_my_key' in data and 'bytes_for_client_key' in data and 'port' in data: # if received data has macsec parameters, put it in queue + macsec_param_q.put(data) + + except socket.timeout: + self.logger.warning("Connection timed out. Ending communication.") + except Exception as e: + self.logger.error("Error receiving message.", exc_info=True) \ No newline at end of file diff --git a/modules/sc-mesh-secure-deployment/src/2_0/features/cbma/setup_cbma.py b/modules/sc-mesh-secure-deployment/src/2_0/features/cbma/setup_cbma.py new file mode 100644 index 000000000..eff858a31 --- /dev/null +++ b/modules/sc-mesh-secure-deployment/src/2_0/features/cbma/setup_cbma.py @@ -0,0 +1,143 @@ +from mutauth import * +from tools.monitoring_wpa import WPAMonitor +import argparse +import subprocess + +shutdown_event = threading.Event() +cbma_threads = [] +file_dir = os.path.dirname(__file__) # Path to dir containing this script + +def setup_macsec(level, interface_name, port, batman_interface, path_to_certificate, path_to_ca, macsec_encryption, wpa_supplicant_control_path = None): + ''' + Sets up macsec links between peers for an interface and adds the macsec interfaces to batman_interface + ''' + + wait_for_interface_to_be_up(interface_name) # Wait for interface to be up, if not already + in_queue = queue.Queue() # Queue to store wpa peer connected messages/ multicast messages on interface for cbma + mua = mutAuth(in_queue, level=level, meshiface=interface_name, port=port, + batman_interface=batman_interface, path_to_certificate=path_to_certificate, path_to_ca=path_to_ca, macsec_encryption=macsec_encryption, shutdown_event=shutdown_event) + # Wait for wireless interface to be pingable before starting mtls server, multicast + wait_for_interface_to_be_pingable(mua.meshiface, mua.ipAddress) + # Start server to facilitate client auth requests, monitor ongoing auths and start client request if there is a new peer/ server baecon + auth_server_thread, auth_server = mua.start_auth_server() + + if wpa_supplicant_control_path: + # Start monitoring wpa for new peer connection + wpa_ctrl_instance = WPAMonitor(wpa_supplicant_control_path) + wpa_thread = threading.Thread(target=wpa_ctrl_instance.start_monitoring, args=(in_queue, shutdown_event)) + wpa_thread.start() + + # Start multicast receiver that listens to multicasts from other mesh nodes + receiver_thread = threading.Thread(target=mua.multicast_handler.receive_multicast) + monitor_thread = threading.Thread(target=mua.monitor_wpa_multicast) + receiver_thread.start() + monitor_thread.start() + # Send periodic multicasts + mua.sender_thread.start() + return mua + +def cbma(level, interface_name, port, batman_interface, path_to_certificate, path_to_ca, macsec_encryption, wpa_supplicant_control_path=None): + ''' + Sets up macsec and batman for the specified interface and level + level: MACSec/ CBMA level. "lower" or "upper" + interface_name: Name of the interface (physical interface if lower level, bat0 if upper level) + port: Port number for mutual authentication and multicast. Can use 15001 for lower level and 15002 for upper level + batman_interface: Batman interface name. bat0 for lower level, bat1 for upper level + path_to_certificate: Path to folder containing certificates + path_to_ca: Path to ca certificate + macsec_encryption: Encryption flag for macsec. "on" or "off" + wpa_supplicant_control_path: Path to wpa supplicant control (if any) + ''' + mutauth_obj = setup_macsec(level, interface_name, port, batman_interface, path_to_certificate, path_to_ca, macsec_encryption, wpa_supplicant_control_path) + mutauth_obj.setup_batman() + return mutauth_obj + +def main(): + ''' + Example of setting up cbma for interfaces wlp1s0, eth1 + + Prerequisite: test certificates generation + 1. Connect CSls/ CMs to your PC with ethernet + 2. Run cbma/cert_generation/ca_side.py in your PC + 3. Run cbma/cert_generation/client_side.py --interface {interface} in your CSLs/ CMs for each interface that you want to apply cbma to + ''' + # Change mode of scripts to executable (Needs to be done once initially) + subprocess.run(['chmod', '+x', f'{file_dir}/delete_default_brlan_bat0.sh']) + subprocess.run(['chmod', '+x', f'{file_dir}/cleanup_cbma.sh']) + + # Delete default bat0 and br-lan if they come up by default + subprocess.run([f'{file_dir}/delete_default_brlan_bat0.sh']) + + # Apply firewall rules that only allows macsec traffic and cbma configuration traffic + # TODO: nft needs to be enabled on the images + #apply_nft_rules(rules_file=f'{file_dir}/tools/firewall.nft') + + # Start cbma lower for each interface/ radio by calling cbma(), which in turn calls setup_macsec(), followed by setup_batman() + # For example, for wlp1s0: + cbma_wlp1s0 = threading.Thread( + target=cbma, + args=( + "lower", # level + "wlp1s0", # interface_name + 15001, # port + "bat0", # batman_interface + f'{file_dir}/cert_generation/certificates', # path_to_certificate + f'{file_dir}/cert_generation/certificates/ca.crt', # path_to_ca + "off", # macsec_encryption (can be "on" if required) + '/var/run/wpa_supplicant_id0/wlp1s0' # wpa_supplicant_control_path + ), + ) + cbma_threads.append(cbma_wlp1s0) + cbma_wlp1s0.start() + + # Similarly, for eth1 + cbma_eth1 = threading.Thread( + target=cbma, + args=( + "lower", + "eth1", + 15001, + "bat0", + f'{file_dir}/cert_generation/certificates', + f'{file_dir}/cert_generation/certificates/ca.crt', + "off", + None + ), + ) + #cbma_eth1 = threading.Thread(target=cbma, args=("lower", "eth1", 15001, "bat0", f'{file_dir}/cert_generation/certificates', f'{file_dir}/cert_generation/certificates/ca.crt', "off", None)) + cbma_threads.append(cbma_eth1) + cbma_eth1.start() + # Repeat the same for other interfaces/ radios by changing the interface_name and wpa_supplicant_control_path (if any). The port number can be reused for the lower level + # setup_batman will setup bat0 once (for whichever interface this is called first) by setting the mac address of bat0 same as that of the physical interface + # This is because right now, we reuse certificates from one of the physical interfaces for upper macsec over bat0 + # setup_batman may be modified later if required + + # Start cbma upper for bat0 by calling cbma() for lower batman interface (bat0) + # This only needs to be called once + # path_to_certificates and path_to_ca can be changed as required + # batman interface = bat1, port number should be different from that used for lower cbma + cbma_bat0 = threading.Thread( + target=cbma, + args=( + "upper", + "bat0", + 15002, + "bat1", + f'{file_dir}/cert_generation/certificates', + f'{file_dir}/cert_generation/certificates/ca.crt', + "on", + None + ), + ) + cbma_threads.append(cbma_bat0) + cbma_bat0.start() + +def stop(): + shutdown_event.set() + for cbma_thread in cbma_threads: + cbma_thread.join() + # Delete macsec links, batman interfaces and bridges created within CBMA + subprocess.run([f'{file_dir}/cleanup_cbma.sh']) + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/modules/sc-mesh-secure-deployment/src/2_0/features/cbma/tools/__init__.py b/modules/sc-mesh-secure-deployment/src/2_0/features/cbma/tools/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/modules/sc-mesh-secure-deployment/src/2_0/features/cbma/tools/custom_logger.py b/modules/sc-mesh-secure-deployment/src/2_0/features/cbma/tools/custom_logger.py new file mode 100644 index 000000000..bb26c764e --- /dev/null +++ b/modules/sc-mesh-secure-deployment/src/2_0/features/cbma/tools/custom_logger.py @@ -0,0 +1,53 @@ +import logging +import sys +import os + + +class CustomLogger: + def __init__(self, role): + self.role = role + self._ensure_log_directory_exists() + self.logger = self._setup_logger() + + @staticmethod + def _ensure_log_directory_exists(): + if not os.path.exists("logs"): + os.makedirs("logs") + + def _setup_logger(self): + # Create a custom logger + logger = logging.getLogger(f"{self.role}") + logger.setLevel(logging.INFO) + + # Create file handler + file_path = os.path.join("logs", f'{self.role}.log') + file_handler = logging.FileHandler(file_path, encoding='utf-8') + file_handler.setLevel(logging.INFO) + + # Create console handler + console_handler = logging.StreamHandler(sys.stdout) + console_handler.setLevel(logging.INFO) + + # Create a formatter + formatter = logging.Formatter( + f'[%(asctime)s] [{self.role}] %(levelname)s %(message)s' + ) + file_handler.setFormatter(formatter) + console_handler.setFormatter(formatter) + + # Add the handlers to the logger + logger.addHandler(file_handler) + logger.addHandler(console_handler) + + return logger + + def get_logger(self): + return self.logger + + +# Usage example: +# logger_instance = AuthLogger("Server") +# logger = logger_instance.get_logger() +# logger.info("This is an info message!") + + diff --git a/modules/sc-mesh-secure-deployment/src/2_0/features/cbma/tools/firewall.nft b/modules/sc-mesh-secure-deployment/src/2_0/features/cbma/tools/firewall.nft new file mode 100644 index 000000000..1f231d408 --- /dev/null +++ b/modules/sc-mesh-secure-deployment/src/2_0/features/cbma/tools/firewall.nft @@ -0,0 +1,43 @@ +#!/usr/sbin/nft -f + +table inet filter { + chain input { + type filter hook input priority 0; policy drop; + + # Accept MACSec traffic (Ethernet type 0x88e5) + meta protocol 0x88e5 accept + + # Accept EAP-TLS (EAPoL) traffic (Ethernet type 0x888E) + meta protocol 0x888E accept + + # Accept IPv6 traffic (Ethernet type 0x86DD) + meta protocol 0x86DD accept + + # Accept AH traffic (IP type 50) + ip protocol ah accept + + # Accept ESP traffic (IP type 51) + ip protocol esp accept + + # Accept IPv6 NDP (ICMPv6) traffic + ip6 nexthdr icmpv6 icmpv6 type { nd-neighbor-solicit, nd-neighbor-advert, nd-router-solicit, nd-router-advert, redirect } accept + + # Accept IPv6 TCP traffic on port 15001, 15002 + ip6 nexthdr tcp tcp dport 15001 accept + ip6 nexthdr tcp tcp dport 15002 accept + + # Accept traffic on lo and eth interfaces + iif "lo" accept + iif "eth0" accept + iif "eth1" accept + + # Allow SSH traffic + tcp dport 22 accept + + # Allow HTTPS traffic + tcp dport 443 accept + + # Allow HTTP traffic (If you want to allow non-encrypted traffic as well) + tcp dport 80 accept + } +} \ No newline at end of file diff --git a/modules/sc-mesh-secure-deployment/src/2_0/features/cbma/tools/monitoring_wpa.py b/modules/sc-mesh-secure-deployment/src/2_0/features/cbma/tools/monitoring_wpa.py new file mode 100644 index 000000000..faa5163a4 --- /dev/null +++ b/modules/sc-mesh-secure-deployment/src/2_0/features/cbma/tools/monitoring_wpa.py @@ -0,0 +1,95 @@ + +#from socket.python3.tools.wpactrl import WpaCtrl +import os +import time +import contextlib +from tools.wpactrl import WpaCtrl +from .custom_logger import CustomLogger + +class WPAMonitor: + def __init__(self, ctrl_path): + self.ctrl_path = ctrl_path + self.logger = self._setup_logger() + + def _setup_logger(self): + logger_instance = CustomLogger("wpa_monitor") + return logger_instance.get_logger() + + def _wait_for_ctrl_interface(self): + waiting_message_printed = False + while not os.path.exists(self.ctrl_path): + if not waiting_message_printed: + self.logger.info("Waiting for the wpa_supplicant control interface file to be created...") + waiting_message_printed = True + time.sleep(1) + + return WpaCtrl(self.ctrl_path) + + def start_monitoring(self, queue, shutdown_event): + instance = self._wait_for_ctrl_interface() + with instance as ctrl: + ctrl.attach() + while not shutdown_event.is_set(): + if ctrl.pending(): + response = ctrl.recv() + decoded_response = response.decode().strip() + self._handle_event(decoded_response, queue) + + def _handle_event(self, decoded_response, queue): + # Check for the MESH-PEER-CONNECTED event + if "MESH-PEER-CONNECTED" in decoded_response: + mac_address = decoded_response.split()[-1] + event = f"MESH-PEER-CONNECTED {mac_address}" + self.logger.info(event) + queue.put(("WPA", mac_address)) + + # Check for the MESH-PEER-DISCONNECTED event + elif "MESH-PEER-DISCONNECTED" in decoded_response: + mac_address = decoded_response.split()[-1] + event = f"MESH-PEER-DISCONNECTED {mac_address}" + self.logger.info(event) + + # Uncomment the next line if you want to log other events or for debugging purposes + # self.logger.debug(f"< {decoded_response}") + +# Usage +# queue = some_queue_structure +# monitor = WPAMonitor("/path/to/ctrl") +# monitor.start_monitoring(queue) + + +# +# +# def create_wpa_ctrl_instance(ctrl_path): +# waiting_message_printed = False +# +# while not os.path.exists(ctrl_path): +# if not waiting_message_printed: +# logger.info("Waiting for the wpa_supplicant control interface file to be created...") +# waiting_message_printed = True +# +# time.sleep(1) +# +# return WpaCtrl(ctrl_path) +# +# def process_events(ctrl, queue): +# with contextlib.suppress(KeyboardInterrupt): +# while True: +# if ctrl.pending(): +# response = ctrl.recv() +# decoded_response = response.decode().strip() +# +# # Check for the MESH-PEER-CONNECTED event +# if "MESH-PEER-CONNECTED" in decoded_response: +# mac_address = decoded_response.split()[-1] +# event = f"MESH-PEER-CONNECTED {mac_address}" +# logger.info(event) +# queue.put(mac_address) +# +# # Check for the MESH-PEER-DISCONNECTED event +# if "MESH-PEER-DISCONNECTED" in decoded_response: +# mac_address = decoded_response.split()[-1] +# event = f"MESH-PEER-DISCONNECTED {mac_address}" +# logger.info(event) +# +# #print("<", decoded_response) \ No newline at end of file diff --git a/modules/sc-mesh-secure-deployment/src/2_0/features/cbma/tools/utils.py b/modules/sc-mesh-secure-deployment/src/2_0/features/cbma/tools/utils.py new file mode 100644 index 000000000..a356aa5ad --- /dev/null +++ b/modules/sc-mesh-secure-deployment/src/2_0/features/cbma/tools/utils.py @@ -0,0 +1,346 @@ +import subprocess +import configparser +import sys +import os +import time +import shutil +import queue +import socket +import ipaddress +import re +from .custom_logger import CustomLogger +#sys.path.insert(0, '../') +path_to_tools_dir = os.path.dirname(__file__) # Path to dir containing this script + +logger_instance = CustomLogger("utils") +logger = logger_instance.get_logger() + +def is_wpa_supplicant_running(): + try: + # Running the command and decoding the output + output = subprocess.check_output(['ps', 'ax']).decode('utf-8') + # Check for wpa_supplicant in the output + processes = [line for line in output.splitlines() if 'wpa_supplicant' in line] + # Filter out any lines containing 'grep' + processes = [proc for proc in processes if 'grep' not in proc] + return len(processes) > 0 + except Exception as e: + # Handle exceptions based on your requirements + return False + + +def run_wpa_supplicant(wifidev): + ''' + maybe this should be executed from the mesh-11s.sh + but we will need to modify the batmat part + ''' + conf_file = "/var/run/wpa_supplicant-11s.conf" + log_file = "/tmp/wpa_supplicant_11s.log" + shutil.copy(f'{path_to_tools_dir}/wpa_supplicant-11s.conf', conf_file) # TODO: change in mesh_com, this is only for testing + + # Build the command with all the arguments + command = [ + "wpa_supplicant", + "-i", wifidev, + "-c", conf_file, + "-D", "nl80211", + "-C", "/var/run/wpa_supplicant/", + "-B", + "-f", log_file + ] + + try: + # Run the wpa_supplicant command as a subprocess + result = subprocess.run(command, check=True) + if result.returncode != 0: + logger.info(f"Error executing command: {result.args}. Return code: {result.returncode}") + else: + logger.info("wpa_supplicant process started successfully.") + except subprocess.CalledProcessError as e: + logger.info(f"Error starting wpa_supplicant process: {e}") + +def mesh_service(): + # Check if mesh provisioning is done + # we are assuming that the file exists on /opt/mesh.conf + if os.path.exists("/opt/S9011sMesh"): + # Start Mesh service + try: + logger.info("starting 11s mesh service") + result = subprocess.run(["/opt/S9011sMesh", "start"], check=True, capture_output=True, text=True) + time.sleep(2) + except subprocess.CalledProcessError as e: + logger.info(f"Error executing command: {e.cmd}. Return code: {e.returncode}") + logger.info(f"Error output: {e.stderr}") + +# Call the function + + +def killall(interface): + try: + # Kill wpa_supplicant + result = subprocess.run(['killall', 'wpa_supplicant'], capture_output=True, text=True) + if result.returncode != 0: + if "no process killed" in result.stderr: + print("wpa_supplicant process was not running.") + else: + print(f"Error executing command: {result.args}. Return code: {result.returncode}") + + # Bring the interface down + subprocess.run(['ifconfig', interface, 'down'], check=True) + + # Bring the interface back up + subprocess.run(['ifconfig', interface, 'up'], check=True) + + except subprocess.CalledProcessError as e: + print(f"Error executing command: {e.cmd}. Return code: {e.returncode}") + +def apply_nft_rules(rules_file="firewall.nft"): + try: + # Run the nft command to apply the rules from the specified file + subprocess.run(['nft', '-f', rules_file], check=True) + logger.info("nftables rules applied successfully.") + except subprocess.CalledProcessError as e: + logger.error(f"Error applying nftables rules: {e}") + + +def modify_conf_file(conf_file_path, new_values): + config = configparser.ConfigParser() + config.read(conf_file_path) + + for section, options in new_values.items(): + for option, value in options.items(): + config.set(section, option, value) + + with open(conf_file_path, 'w') as configfile: + config.write(configfile) + +def batman_exec(batman_interface, routing_algo): + if routing_algo != "batman-adv": + #TODO here should be OLSR + return + try: + run_batman(batman_interface) + except subprocess.CalledProcessError as e: + logger.error(f"Error: {e}") + + +def run_batman(batman_interface): + logger.info(f"Setting mac address of {batman_interface} to be same as wlp1s0..") + subprocess.run(["ip", "link", "set", "dev", batman_interface, "address", get_mac_addr("wlp1s0")], check=True) + + logger.info(f"Setting {batman_interface} up..") + # Run the ifconfig batman_interface up command + subprocess.run(["ifconfig", batman_interface, "up"], check=True) + + logger.info(f"Setting {batman_interface} mtu size") + # Run the ifconfig batman_interface mtu 1460 command + subprocess.run(["ifconfig", batman_interface, "mtu", "1460"], check=True) + + # Run the ifconfig batman_interface command to show the interface information + subprocess.run(["ifconfig", batman_interface], check=True) + +""" +def mac_to_ipv6(mac_address): + # Remove any separators from the MAC address (e.g., colons, hyphens) + mac_address = mac_address.replace(":", "").replace("-", "").lower() + + # Split the MAC address into two equal halves + first_half = mac_address[:6] + + # Convert the first octet from hexadecimal to binary + binary_first_octet = bin(int(first_half[:2], 16))[2:].zfill(8) + + + # Invert the seventh bit (change 0 to 1 or 1 to 0) + inverted_seventh_bit = "1" if binary_first_octet[6] == "0" else "0" + + + # Convert the modified binary back to hexadecimal + modified_first_octet = hex(int(binary_first_octet[:6] + inverted_seventh_bit + binary_first_octet[7:], 2))[2:] + + + # Replace the original first octet with the modified one + modified_mac_address = modified_first_octet + mac_address[2:] + + + line = f"{modified_mac_address[:5]}fffe{modified_mac_address[5:]}" + + # Add "ff:fe:" to the middle of the new MAC address +# mac_with_fffe = ":".join(a + b for a, b in zip(a[::2], a[1::2])) + mac_with_fffe = ":".join([line[:3], line[3:7], line[7:11], line[11:]]) + + return f"fe80::{mac_with_fffe}" +""" + +def mac_to_ipv6(mac): + # Split MAC address and insert ff:fe in the middle + mac_parts = mac.split(":") + eui_64 = mac_parts[:3] + ['ff', 'fe'] + mac_parts[3:] + + # Modify the 7th bit (Universal/Local bit) + eui_64[0] = format(int(eui_64[0], 16) ^ 0x02, '02x') + + # Combine parts to form EUI-64 part of IPv6 + eui_64_combined = "".join(eui_64) + + return f"fe80::{eui_64_combined[:4]}:{eui_64_combined[4:8]}:{eui_64_combined[8:12]}:{eui_64_combined[12:16]}" + +def get_mac_from_ipv6(ipv6_address, interface): + try: + # Send pings to the IPv6 address to prompt an NDP exchange + # Increase the ping count, if needed, for better reliability + subprocess.run(['ping6', '-c', '1', f'{ipv6_address}%{interface}'], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, timeout=5) + + # Check the neighbor cache + output = subprocess.check_output(['ip', '-6', 'neigh', 'show', ipv6_address], text=True) + + # Extract MAC address + mac_search = re.search(r"(([0-9a-f]{2}:){5}[0-9a-f]{2})", output, re.IGNORECASE) + if mac_search: + return mac_search.group(1) + except (subprocess.CalledProcessError, subprocess.TimeoutExpired) as e: + logger.error(f"An error occurred: {e}") + except re.error as re_e: + logger.error(f"Regex error: {re_e}") + return None + + +def extract_mac_from_ipv6(ipv6_address): #assuming link local address + # Parse and expand the IPv6 address to its full form + full_ipv6_address = ipaddress.IPv6Address(ipv6_address).exploded + + # Extract the EUI-64 identifier from the IPv6 address + eui_64_parts = full_ipv6_address.split(":")[4:8] + + # Join them together into a single hex string + eui_64_hex = "".join(eui_64_parts) + + # Separate the bytes that make up the EUI-64 identifier + eui_64_bytes = bytes.fromhex(eui_64_hex) + + # Extract the MAC address from the EUI-64 identifier + mac_bytes = bytearray(6) + mac_bytes[0] = eui_64_bytes[0] ^ 0x02 # Flip the universal/local bit + mac_bytes[1:3] = eui_64_bytes[1:3] + mac_bytes[3:5] = eui_64_bytes[5:7] + mac_bytes[5] = eui_64_bytes[7] + + return ":".join(f"{byte:02x}" for byte in mac_bytes) + + +def get_mac_addr(EXPECTED_INTERFACE): + """ + got it from common/tools/field_test_logger/wifi_info.py + """ + try: + with open(f"/sys/class/net/{EXPECTED_INTERFACE}/address", 'r') as f: + value = f.readline() + return value.strip() + except Exception: + return "NaN" + + +def set_ipv6(interface, ipv6): + command = ["ip", "-6", "addr", "add", f"{ipv6}/64", "dev", interface] + try: + result = subprocess.run(command, capture_output=True, text=True, check=True) + print(result.stdout) + except subprocess.CalledProcessError as e: + print(f"Command failed with error: {e.stderr}") + +def is_ipv4(ip): + try: + socket.inet_pton(socket.AF_INET, ip) + return True + except socket.error: + return False + +def is_ipv6(ip): + try: + socket.inet_pton(socket.AF_INET6, ip) + return True + except socket.error: + return False + +def generate_random_bytes(byte_size=32): + return os.urandom(byte_size) + +def is_interface_pingable(interface_name, ip_address): + # Return true if pingable + try: + if is_ipv4(ip_address): + ping_output = subprocess.check_output(['ping', '-c', '1', '-w', '1', ip_address], stderr=subprocess.STDOUT,universal_newlines=True) + return "1 packets transmitted, 1 received" in ping_output + elif is_ipv6(ip_address): + ping_output = subprocess.check_output(['ping', '-c', '1', '-w', '1', f'{ip_address}%{interface_name}'], stderr=subprocess.STDOUT, universal_newlines=True) + return "1 packets transmitted, 1 received" in ping_output + else: + raise ValueError("Invalid IP address") + except subprocess.CalledProcessError as e: + return False + +def wait_for_interface_to_be_pingable(interface_name, ipv6_address): + waiting_message_printed = False + while not is_interface_pingable(interface_name, ipv6_address): + # Waiting till interface is pingable + if not waiting_message_printed: + logger.info(f'Waiting for {interface_name} to be reachable..') + waiting_message_printed = True + time.sleep(1) + +def is_interface_up(interface_name): + # Check if interface is up + try: + output = subprocess.check_output(['ifconfig', interface_name], stderr=subprocess.DEVNULL) # Error suppressed as the command throws an error when inyterface is not present + return 'inet' in output.decode() + except subprocess.CalledProcessError: + return False + +def wait_for_interface_to_be_up(interface_name): + waiting_message_printed = False + while not is_interface_up(interface_name): + # Waiting till interface is up + if not waiting_message_printed: + logger.info(f'Waiting for {interface_name} to be up..') + waiting_message_printed = True + time.sleep(1) + +def xor_bytes(byte1, byte2, byte_size=32): + # Trim the bytes if they are longer than byte_size + if len(byte1) > byte_size: + byte1 = byte1[:byte_size] + if len(byte2) > byte_size: + byte2 = byte2[:byte_size] + + # Pad bytes to the required length + if len(byte1) < byte_size: + byte1 = byte1.rjust(byte_size, b'\x00') + if len(byte2) < byte_size: + byte2 = byte2.rjust(byte_size, b'\x00') + + # Return bit-wise XOR of byte1 and byte2 + return bytes(a ^ b for a, b in zip(byte1, byte2)) + +def add_interface_to_batman(interface_to_add, batman_interface): + # Add interface to batman + try: + subprocess.run(["batctl", "meshif", batman_interface, "if", "add", interface_to_add], check=True) + logger.info(f'Added interface {interface_to_add} to {batman_interface}') + except Exception as e: + logger.error(f'Error adding interface {interface_to_add} to {batman_interface}: {e}') + +def add_interface_to_bridge(interface_to_add, bridge_interface): + try: + subprocess.run(["brctl", "addif", bridge_interface, interface_to_add], check=True) + logger.info(f'Added interface {interface_to_add} to {bridge_interface}') + except Exception as e: + logger.error(f'Error adding interface {interface_to_add} to {bridge_interface}: {e}') + +def setup_bridge(bridge_interface): + # Set a bridge interface up + try: + subprocess.run(["brctl", "addbr", bridge_interface], check=True) + subprocess.run(["ip", "link", "set", bridge_interface, "up"], check=True) + logger.info(f'Setup bridge {bridge_interface}') + except Exception as e: + logger.error(f'Error setting up bridge {bridge_interface}: {e}') diff --git a/modules/sc-mesh-secure-deployment/src/2_0/features/cbma/tools/verification_tools.py b/modules/sc-mesh-secure-deployment/src/2_0/features/cbma/tools/verification_tools.py new file mode 100644 index 000000000..6fdbd0def --- /dev/null +++ b/modules/sc-mesh-secure-deployment/src/2_0/features/cbma/tools/verification_tools.py @@ -0,0 +1,112 @@ +from datetime import datetime +import OpenSSL +from OpenSSL import crypto +from .utils import get_mac_from_ipv6 + + +# Custom exceptions for certificate verification +class CertificateExpiredError(Exception): + pass + + +class CertificateActivationError(Exception): + pass + + +class CertificateHostnameError(Exception): + pass + + +class CertificateIssuerError(Exception): + pass + + +class CertificateVerificationError(Exception): + pass + + +class CertificateNoPresentError(Exception): + pass + + +class CertificateDifferentCN(Exception): + pass + +class ServerConnectionRefusedError(Exception): + pass + +def verify_cert(cert, ca_cert, IPaddress, interface, logging): + try: + return validation(cert, ca_cert, IPaddress, interface, logging) + except (CertificateExpiredError, CertificateHostnameError, CertificateIssuerError, ValueError) as e: + logging.error(f"Certificate verification failed with {IPaddress}.", exc_info=True) + return False + except Exception as e: + logging.error(f"An unexpected error occurred during certificate verification with {IPaddress}.", exc_info=True) + return False + + +def validation(cert, ca_cert, IPaddress, interface, logging): + # Load the DER certificate into an OpenSSL certificate object + x509 = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_ASN1, cert) + + # Get the certificate expiration date and activation date using OpenSSL methods + expiration_date_str = x509.get_notAfter().decode('utf-8') + activation_date_str = x509.get_notBefore().decode('utf-8') + + expiration_date = datetime.strptime(expiration_date_str, '%Y%m%d%H%M%SZ') + activation_date = datetime.strptime(activation_date_str, '%Y%m%d%H%M%SZ') + current_date = datetime.now() + + # Check if the certificate has expired + if expiration_date < current_date: + logging.error(f"Certificate of {IPaddress} has expired.", exc_info=True) + raise CertificateExpiredError("Certificate has expired.") + + if activation_date > current_date: + logging.error(f"Client certificate not yet active for {IPaddress}.", exc_info=True) + raise CertificateExpiredError("Client certificate not yet active") + + # Extract the actual ID from CN + common_name = x509.get_subject().CN + if common_name != get_mac_from_ipv6(IPaddress, interface): + logging.error(f"CN does not match the MAC Address for {IPaddress}", exc_info=True) + raise CertificateDifferentCN("CN does not match the MAC Address.") + + if _verify_certificate_chain(cert, ca_cert, logging): + # Extract the public key from the certificate + pub_key_der = OpenSSL.crypto.dump_publickey(OpenSSL.crypto.FILETYPE_ASN1, x509.get_pubkey()) + # If the client certificate has passed all verifications, you can print or log a success message + logging.info(f"Certificate verification successful for {IPaddress}.") + return True + + else: + raise CertificateVerificationError("Verification of certificate chain failed.") + + +def _verify_certificate_chain(cert, trusted_certs, logging): + """ + this function is not being used right now, because we only have one certificate (ca.crt), not a full chain (CA, Interm CA, etc) + but I left the code for further adaptation + """ + try: + x509 = crypto.load_certificate(crypto.FILETYPE_ASN1, cert) + cert_pem = crypto.dump_certificate(crypto.FILETYPE_PEM, x509) + + store = crypto.X509Store() + + # Check if trusted_certs is a list, if not make it a list + if not isinstance(trusted_certs, list): + trusted_certs = [trusted_certs] + + for _cert in trusted_certs: + cert_data = crypto.load_certificate(crypto.FILETYPE_PEM, open(_cert).read()) + store.add_cert(cert_data) + + store_ctx = crypto.X509StoreContext(store, x509) + store_ctx.verify_certificate() + + return True + except Exception as e: + logging.error(f"Certificate chain verification failed: {e}", exc_info=True) + return False diff --git a/modules/sc-mesh-secure-deployment/src/2_0/features/cbma/tools/wpa_supplicant-11s.conf b/modules/sc-mesh-secure-deployment/src/2_0/features/cbma/tools/wpa_supplicant-11s.conf new file mode 100644 index 000000000..44b1cb4fd --- /dev/null +++ b/modules/sc-mesh-secure-deployment/src/2_0/features/cbma/tools/wpa_supplicant-11s.conf @@ -0,0 +1,20 @@ +ctrl_interface=DIR=/var/run/wpa_supplicant +# use 'ap_scan=2' on all devices connected to the network +# this is unnecessary if you only want the network to be created when no other networks.. +ap_scan=1 +country=AE +p2p_disabled=1 +mesh_max_inactivity=50 +network={ + ssid="gold" + bssid=00:11:22:33:44:55 + mode=5 + frequency=5805 + psk="123456789" + key_mgmt=SAE + ieee80211w=2 + mesh_fwding=0 + # 11b rates dropped (for better performance) + mesh_basic_rates=60 90 120 180 240 360 480 540 +} + diff --git a/modules/sc-mesh-secure-deployment/src/2_0/features/cbma/tools/wpactrl.py b/modules/sc-mesh-secure-deployment/src/2_0/features/cbma/tools/wpactrl.py new file mode 100644 index 000000000..0b1e9baf8 --- /dev/null +++ b/modules/sc-mesh-secure-deployment/src/2_0/features/cbma/tools/wpactrl.py @@ -0,0 +1,90 @@ +import os +import socket +import select + +WPA_CTRL_MAX_REPLY_LEN = 4096 + +class WpaCtrl: + counter = 0 + + def __init__(self, ctrl_path): + self.ctrl_path = ctrl_path + self.socket = None + self.local_path = f"/tmp/wpa_ctrl_{os.getpid()}-{WpaCtrl.counter}" + WpaCtrl.counter += 1 + + def __enter__(self): + self.socket = socket.socket(socket.AF_UNIX, socket.SOCK_DGRAM, 0) + self.socket.bind(self.local_path) + self.socket.connect(self.ctrl_path) + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + os.unlink(self.local_path) + self.socket.close() + + def request(self, cmd, msg_cb=None, reply_len=WPA_CTRL_MAX_REPLY_LEN): + self.socket.sendall(cmd) + + while True: + rlist, _, _ = select.select([self.socket], [], [], 2) + + if not rlist or self.socket not in rlist: + raise TimeoutError("Timed out waiting for response") + + data = self.socket.recv(reply_len) + + if not data or data[0] != b'<': + return data + + if msg_cb: + msg_cb(data) + + def attach(self): + return self._attach_helper(True) + + def detach(self): + return self._attach_helper(False) + + def _attach_helper(self, attach): + ret = self.request(b'ATTACH' if attach else b'DETACH') + return ret == b'OK\n' if isinstance(ret, bytes) else ret + + def recv(self, reply_len=WPA_CTRL_MAX_REPLY_LEN): + return self.socket.recv(reply_len) + + def pending(self): + rlist, _, _ = select.select([self.socket], [], [], 0) + return self.socket in rlist + + def get_fd(self): + return self.socket.fileno() + +# import contextlib +# from wpacrtl import WpaCtrl +# +# with WpaCtrl("/var/run/wpa_supplicant/wlp1s0") as ctrl: +# ctrl.attach() +# +# try: +# while True: +# if ctrl.pending(): +# response = ctrl.recv() +# decoded_response = response.decode().strip() +# +# # Check for the MESH-PEER-CONNECTED event +# if "MESH-PEER-CONNECTED" in decoded_response: +# mac_address = decoded_response.split()[-1] +# event = f"MESH-PEER-CONNECTED {mac_address}" +# print(event) +# +# # Check for the MESH-PEER-DISCONNECTED event +# if "<3>MESH-PEER-DISCONNECTED" in decoded_response: +# mac_address = decoded_response.split()[-1] +# event = f"MESH-PEER-DISCONNECTED {mac_address}" +# print(event) +# +# print("<", decoded_response) +# +# except KeyboardInterrupt: +# pass \ No newline at end of file diff --git a/modules/sc-mesh-secure-deployment/src/2_0/test_features/cbma/add_syspath.py b/modules/sc-mesh-secure-deployment/src/2_0/test_features/cbma/add_syspath.py new file mode 100644 index 000000000..089cbd293 --- /dev/null +++ b/modules/sc-mesh-secure-deployment/src/2_0/test_features/cbma/add_syspath.py @@ -0,0 +1,11 @@ +import os +import sys + +# Get the path to dir containing current script +current_dir = os.path.dirname(os.path.abspath(__file__)) + +# Construct the path to the cbma directory +cbma_dir = os.path.join(current_dir, "../../features/cbma") + +# Add the cbma dir to sys.path +sys.path.insert(0, cbma_dir) \ No newline at end of file diff --git a/modules/sc-mesh-secure-deployment/src/2_0/test_features/cbma/functional_tests/cbma_check_bat1_neighbor.sh b/modules/sc-mesh-secure-deployment/src/2_0/test_features/cbma/functional_tests/cbma_check_bat1_neighbor.sh new file mode 100644 index 000000000..e0160311c --- /dev/null +++ b/modules/sc-mesh-secure-deployment/src/2_0/test_features/cbma/functional_tests/cbma_check_bat1_neighbor.sh @@ -0,0 +1,23 @@ +#!/bin/bash + +# Run the batctl command and capture the output +output=$(batctl meshif bat1 n 2>&1) # Redirect stderr to stdout + +# Check for error message +if echo "$output" | grep -q "Error - interface bat1 is not present or not a batman-adv interface"; then + echo "Error: bat1 interface is not present or not a batman-adv interface." + echo "Fail" + exit 1 +fi + +# Count the number of lines in the output +line_count=$(echo "$output" | wc -l) + +# Check if the line count is greater than 2 (header line + neighbor info) +if [ "$line_count" -gt 2 ]; then + echo "Neighbor(s) found." + echo "Pass" +else + echo "No neighbor found." + echo "Fail" +fi diff --git a/modules/sc-mesh-secure-deployment/src/2_0/test_features/cbma/functional_tests/cbma_check_macsec.sh b/modules/sc-mesh-secure-deployment/src/2_0/test_features/cbma/functional_tests/cbma_check_macsec.sh new file mode 100644 index 000000000..891bed752 --- /dev/null +++ b/modules/sc-mesh-secure-deployment/src/2_0/test_features/cbma/functional_tests/cbma_check_macsec.sh @@ -0,0 +1,67 @@ +#!/bin/bash + +# Run the command and capture the output +output=$(ip macsec show 2>&1) + +# Check for error in the output +if echo "$output" | grep -q -i "error"; then + echo "Error detected in 'ip macsec show' output" + echo "Fail" + exit 1 +fi + +# Function to check for TXSC and RXSC +check_txsc_rxsc() { + local interface=$1 + local output=$2 + + # Extract the relevant block for the interface + local interface_block=$(echo "$output" | awk -v iface="$interface:" '$0 ~ iface {flag=1; next} /offload:/ {flag=0} flag') + + # Check for TXSC and RXSC in the extracted block + if echo "$interface_block" | grep -q "TXSC" && echo "$interface_block" | grep -q "RXSC"; then + return 0 # Found both TXSC and RXSC + else + return 1 # Missing TXSC or RXSC + fi +} + +# Function to check interfaces with a specific prefix +check_interfaces_with_prefix() { + local prefix=$1 + local output="$2" + local fail=0 + local found=0 + + # Extract interfaces starting with the prefix + local interfaces=$(echo "$output" | grep -oE "${prefix}[[:alnum:]]+:" | tr -d ':') + + for interface in $interfaces; do + found=1 + if ! check_txsc_rxsc "$interface" "$output"; then + echo "Interface $interface does not have both TXSC and RXSC" + fail=1 + fi + done + + if [[ $found -eq 0 ]]; then + echo "No interface found with prefix $prefix" + fail=1 + fi + + return $fail +} + +# Check for lms* and ums* interfaces +if ! check_interfaces_with_prefix "lms" "$output"; then + echo "Fail" + exit 1 +fi + +if ! check_interfaces_with_prefix "ums" "$output"; then + echo "Fail" + exit 1 +fi + +# If script reaches this point, all checks passed +echo "Pass" diff --git a/modules/sc-mesh-secure-deployment/src/2_0/test_features/cbma/test_authClient.py b/modules/sc-mesh-secure-deployment/src/2_0/test_features/cbma/test_authClient.py new file mode 100644 index 000000000..5cfb2bdaf --- /dev/null +++ b/modules/sc-mesh-secure-deployment/src/2_0/test_features/cbma/test_authClient.py @@ -0,0 +1,130 @@ +import pytest +from unittest.mock import patch, MagicMock, call + +import add_syspath +from auth.authClient import AuthClient + +class MockLogger: + def __init__(self): + self.info = MagicMock() + self.error = MagicMock() + self.debug = MagicMock() + +# Mock Logger +@pytest.fixture +def mock_logger(): + with patch('auth.authClient.logger_instance.get_logger', return_value=MockLogger()) as mock_log: + yield { + 'mock_log': mock_log + } + +# Mock external module functions +@pytest.fixture +def mock_dependencies(): + with patch('auth.authClient.verify_cert', return_value=True) as mock_verify, \ + patch('auth.authClient.mac_to_ipv6', return_value="::1") as mock_mac_ipv6, \ + patch('auth.authClient.get_mac_addr', return_value="00:00:00:00:00:00") as mock_get_mac: + yield { + 'mock_verify': mock_verify, + 'mock_mac_ipv6': mock_mac_ipv6, + 'mock_get_mac': mock_get_mac + } + +# Mock ssl and socket related functions +@pytest.fixture +def mock_socket(): + with patch('auth.authClient.socket.socket') as mock_sock, \ + patch('auth.authClient.ssl.SSLContext') as mock_ssl_context: + mock_sock_instance = MagicMock() + mock_sock.return_value = mock_sock_instance + + mock_context_instance = MagicMock() + mock_ssl_context.return_value = mock_context_instance + + mock_sec_sock_instance = MagicMock() + mock_context_instance.wrap_socket.return_value = mock_sec_sock_instance + + yield { + 'mock_sock': mock_sock, + 'mock_sock_instance': mock_sock_instance, + 'mock_ssl_context': mock_ssl_context, + 'mock_context_instance': mock_context_instance, + 'mock_sec_sock_instance': mock_sec_sock_instance + } + +# Mock glob.glob +@pytest.fixture +def mock_glob(monkeypatch): + def fake_glob(pattern): + if 'ca.crt' in pattern: + return ['dummy_path/ca.crt'] + elif 'crt' in pattern: + return ['dummy_path/dummy_cert.crt'] + elif 'key' in pattern: + return ['dummy_path/dummy_key.key'] + return ['dummy_path'] + + monkeypatch.setattr("auth.authClient.glob.glob", fake_glob) + +# Mock sleep +@pytest.fixture +def mock_time_sleep(monkeypatch): + def fake_sleep(*args, **kwargs): + pass # Do nothing + + monkeypatch.setattr("auth.authClient.time.sleep", fake_sleep) + + +def test_auth_client_establish_connection(mock_dependencies, mock_socket, mock_glob, mock_logger): + client = AuthClient("some_interface", "some_server_mac", 15001, "path/to/cert", "path/to/cert/ca.crt", MagicMock()) + + client.establish_connection() + + # Add assertions to verify correct methods were called + mock_socket['mock_sock'].assert_called_once() + mock_socket['mock_context_instance'].load_verify_locations.assert_called_once() + mock_socket['mock_context_instance'].load_cert_chain.assert_called_once() + mock_socket['mock_sec_sock_instance'].connect.assert_called_once() + mock_dependencies['mock_verify'].assert_called_once() + + +def test_auth_client_connection_failure(mock_dependencies, mock_socket, mock_glob, mock_time_sleep, mock_logger): + client = AuthClient("some_interface", "some_server_mac", 15001, "path/to/cert", "path/to/cert/ca.crt", MagicMock()) + + mock_socket['mock_sec_sock_instance'].connect.side_effect = ConnectionRefusedError() + + client.establish_connection() + + # Add assertions to verify the connection retry behavior + assert mock_socket['mock_sec_sock_instance'].connect.call_count == 5 # Max retries + +def test_auth_client_certificate_verification_pass(mock_dependencies, mock_socket, mock_glob, mock_logger): + client = AuthClient("some_interface", "some_server_mac", 15001, "path/to/cert", "path/to/cert/ca.crt", MagicMock()) + + mock_dependencies['mock_verify'].return_value = True + client.establish_connection() + # Assert mua.auth_pass was called + client.mua.auth_pass.assert_called_once() + # Assert mua.auth_fail was NOT called + client.mua.auth_fail.assert_not_called() + +def test_auth_client_certificate_verification_failure(mock_dependencies, mock_socket, mock_glob, mock_logger): + client = AuthClient("some_interface", "some_server_mac", 15001, "path/to/cert", "path/to/cert/ca.crt", MagicMock()) + + mock_dependencies['mock_verify'].return_value = False + client.establish_connection() + # Assert mua.auth_fail was called + client.mua.auth_fail.assert_called_once() + # Assert mua.auth_pass was NOT called + client.mua.auth_pass.assert_not_called() + +# Remove log directory during teardown +import shutil +import os +def remove_logs_directory(): + logs_directory = 'logs' + if os.path.exists(logs_directory) and os.path.isdir(logs_directory): + shutil.rmtree(logs_directory) + +def teardown_module(): + remove_logs_directory() \ No newline at end of file diff --git a/modules/sc-mesh-secure-deployment/src/2_0/test_features/cbma/test_authServer.py b/modules/sc-mesh-secure-deployment/src/2_0/test_features/cbma/test_authServer.py new file mode 100644 index 000000000..6e824b393 --- /dev/null +++ b/modules/sc-mesh-secure-deployment/src/2_0/test_features/cbma/test_authServer.py @@ -0,0 +1,133 @@ +import pytest +from unittest.mock import patch, MagicMock, call +import shutil +import os +import socket +import ssl +import threading + +import add_syspath +from auth.authServer import AuthServer + +class MockLogger: + def __init__(self): + self.info = MagicMock() + self.error = MagicMock() + self.debug = MagicMock() + +# Mock Logger +@pytest.fixture +def mock_logger(): + with patch('auth.authServer.logger', return_value=MockLogger()) as mock_log: + yield { + 'mock_log': mock_log + } + +# Mock external module functions +@pytest.fixture +def mock_dependencies(): + with patch('auth.authServer.verify_cert', return_value=True) as mock_verify, \ + patch('auth.authServer.get_mac_addr', return_value="00:00:00:00:00:00") as mock_get_mac, \ + patch('auth.authServer.ssl.SSLContext', autospec=True) as mock_ssl_context: + + mock_ssl_instance = MagicMock() + mock_ssl_context.return_value = mock_ssl_instance + mock_ssl_instance.load_verify_locations = MagicMock() # Mocking the method + + yield { + 'mock_verify': mock_verify, + 'mock_get_mac': mock_get_mac, + 'mock_ssl_context': mock_ssl_context, + 'mock_ssl_instance': mock_ssl_instance + } + +# Mock socket related functions +@pytest.fixture +def mock_socket(): + with patch('auth.authServer.socket.socket') as mock_sock: + mock_sock_instance = MagicMock() + mock_sock.return_value = mock_sock_instance + yield { + 'mock_sock': mock_sock, + 'mock_sock_instance': mock_sock_instance + } + +# Mock glob.glob +@pytest.fixture +def mock_glob(monkeypatch): + def fake_glob(pattern): + if 'ca.crt' in pattern: + return ['dummy_path/ca.crt'] + elif 'crt' in pattern: + return ['dummy_path/dummy_cert.crt'] + elif 'key' in pattern: + return ['dummy_path/dummy_key.key'] + return ['dummy_path'] + + monkeypatch.setattr("auth.authServer.glob.glob", fake_glob) + + +def test_auth_server_handle_client(mock_dependencies, mock_socket, mock_glob, mock_logger): + server = AuthServer("some_interface", "127.0.0.1", 15001, "path/to/cert", "path/to/cert/ca.crt", MagicMock()) + + fake_client_connection = MagicMock() + fake_client_address = ('::1', 12345) + server.handle_client(fake_client_connection, fake_client_address) + + # Add assertions to verify the correct methods were called + mock_dependencies['mock_verify'].assert_called_once() + + +def test_auth_server_start_stop(mock_dependencies, mock_socket, mock_glob, mock_logger): + server = AuthServer("some_interface", "127.0.0.1", 15001, "path/to/cert", "path/to/cert/ca.crt", MagicMock()) + + # Mock accept to raise socket.timeout just once + mock_socket['mock_sock_instance'].accept.side_effect = [socket.timeout] + + # Temporarily set server.running to False to stop the loop after one iteration + with patch.object(server, 'running', new=False): + # Start the server in a separate thread + thread = threading.Thread(target=server.start_server) + thread.start() + + # Join the thread to make sure it's finished before proceeding with assertions + thread.join() + + # Add assertions to verify the server started and stopped correctly + mock_socket['mock_sock_instance'].bind.assert_called() + mock_socket['mock_sock_instance'].listen.assert_called_once() + + +def test_authenticate_client_verification_pass(mock_dependencies, mock_socket, mock_glob, mock_logger): + server = AuthServer("some_interface", '::1', 15001, "path/to/cert", "path/to/cert/ca.crt", MagicMock()) + fake_client_connection = MagicMock() + fake_client_address = ('::1', 12345) + # Mock verify_cert to return True + mock_dependencies['mock_verify'].return_value = True + # Call authenticate_client + server.authenticate_client(fake_client_connection, fake_client_address, "fake_client_mac") + # Assert mua.auth_pass was called + server.mua.auth_pass.assert_called_once() + # Assert mua.auth_fail was NOT called + server.mua.auth_fail.assert_not_called() + +def test_authenticate_client_verification_fail(mock_dependencies, mock_socket, mock_glob, mock_logger): + server = AuthServer("some_interface", '::1', 15001, "path/to/cert", "path/to/cert/ca.crt", MagicMock()) + fake_client_connection = MagicMock() + fake_client_address = ('::1', 12345) + # Mock verify_cert to return False + mock_dependencies['mock_verify'].return_value = False + # Call authenticate_client + server.authenticate_client(fake_client_connection, fake_client_address, "fake_client_mac") + # Assert mua.auth_fail was called + server.mua.auth_fail.assert_called_once() + # Assert mua.auth_pass was NOT called + server.mua.auth_pass.assert_not_called() + +def remove_logs_directory(): + logs_directory = 'logs' + if os.path.exists(logs_directory) and os.path.isdir(logs_directory): + shutil.rmtree(logs_directory) + +def teardown_module(): + remove_logs_directory() diff --git a/modules/sc-mesh-secure-deployment/src/2_0/test_features/cbma/test_macsec.py b/modules/sc-mesh-secure-deployment/src/2_0/test_features/cbma/test_macsec.py new file mode 100644 index 000000000..f274c54b7 --- /dev/null +++ b/modules/sc-mesh-secure-deployment/src/2_0/test_features/cbma/test_macsec.py @@ -0,0 +1,86 @@ +import pytest +from unittest.mock import patch, call + +import add_syspath +from macsec.macsec import Macsec + + +@pytest.fixture +def macsec_instance(): + return Macsec(level="lower", interface="wlp1s0", macsec_encryption="off") + +def test_set_macsec_tx(macsec_instance): + client_mac = "aa:bb:cc:dd:ee:ff" + my_macsec_key = "mykey" + my_port = 12345 + + with patch("subprocess.run") as mock_run: + macsec_instance.set_macsec_tx(client_mac, my_macsec_key, my_port) + + expected_calls = [ + call(["ip", "link", "add", "link", macsec_instance.interface, "lmsaabbccddeeff", "type", "macsec", "port", str(my_port), "encrypt", macsec_instance.macsec_encryption, "cipher", "gcm-aes-256"], check=True), + call(["ip", "macsec", "add", "lmsaabbccddeeff", "tx", "sa", "0", "pn", "1", "on", "key", "01", my_macsec_key], check=True), + call(["ip", "link", "set", "lmsaabbccddeeff", "up"], check=True), + call(["ip", "macsec", "show"], check=True) + ] + + assert mock_run.call_args_list == expected_calls + +def test_set_macsec_rx(macsec_instance): + client_mac = "aa:bb:cc:dd:ee:ff" + client_macsec_key = "clientkey" + client_port = 12345 + + with patch("subprocess.run") as mock_run: + macsec_instance.set_macsec_rx(client_mac, client_macsec_key, client_port) + + expected_calls = [ + call(["ip", "macsec", "add", "lmsaabbccddeeff", "rx", "port", str(client_port), "address", client_mac], check=True), + call(["ip", "macsec", "add", "lmsaabbccddeeff", "rx", "port", str(client_port), "address", client_mac, "sa", "0", "pn", "1", "on", "key", client_mac.replace(":", ""), client_macsec_key], check=True), + call(["ip", "macsec", "show"], check=True) + ] + + assert mock_run.call_args_list == expected_calls + +def test_macsec_interface_name(macsec_instance): + assert macsec_instance.get_macsec_interface_name("aa:bb:cc:dd:ee:ff") == "lmsaabbccddeeff" + macsec_instance.level = "upper" + assert macsec_instance.get_macsec_interface_name("aa:bb:cc:dd:ee:ff") == "umsaabbccddeeff" + + +def test_assign_unique_port(macsec_instance): + client_mac = "aa:bb:cc:dd:ee:ff" + assigned_port = macsec_instance.assign_unique_port(client_mac) + + assert assigned_port in range(1, 2 ** 16) # Assert assigned port is in range + assert client_mac in macsec_instance.used_ports # Assert that record for client_mac has been added in used_ports + assert assigned_port == macsec_instance.used_ports[client_mac] # Assert that assigned_port has been recorded correctly for client_mac in used ports + assert assigned_port not in macsec_instance.available_ports # Assert that assigned port is not in available ports + + +def test_release_port(macsec_instance): + macsec_instance = Macsec(level="lower", interface="wlp1s0", macsec_encryption="off") + client_mac = "aa:bb:cc:dd:ee:ff" + assigned_port = macsec_instance.assign_unique_port(client_mac) + macsec_instance.release_port(client_mac) + + assert client_mac not in macsec_instance.used_ports # assert that client has been removed from used ports + assert assigned_port in macsec_instance.available_ports # assert that the assigned port has been released back as available port + + +def test_release_port_error(macsec_instance): + macsec_instance = Macsec(level="lower", interface="wlp1s0", macsec_encryption="off") + with pytest.raises(ValueError, match=r"Client .* is not in the list of used ports."): + macsec_instance.release_port("00:11:22:33:44:55") + + +# Remove log directory during teardown +import shutil +import os +def remove_logs_directory(): + logs_directory = 'logs' + if os.path.exists(logs_directory) and os.path.isdir(logs_directory): + shutil.rmtree(logs_directory) + +def teardown_module(): + remove_logs_directory() diff --git a/modules/sc-mesh-secure-deployment/src/2_0/test_features/cbma/test_monitoring_wpa.py b/modules/sc-mesh-secure-deployment/src/2_0/test_features/cbma/test_monitoring_wpa.py new file mode 100644 index 000000000..051aa7a15 --- /dev/null +++ b/modules/sc-mesh-secure-deployment/src/2_0/test_features/cbma/test_monitoring_wpa.py @@ -0,0 +1,32 @@ +import queue +import pytest +from unittest.mock import Mock, patch, MagicMock + +import add_syspath +from tools.monitoring_wpa import WPAMonitor + +@pytest.fixture +def mock_ctrl_path(): + return "/path/to/ctrl" + + + +def test_handle_event_handles_mesh_peer_connected(mock_ctrl_path): + event_queue = queue.Queue() + monitor = WPAMonitor(mock_ctrl_path) + monitor._handle_event("MESH-PEER-CONNECTED some_mac_address", event_queue) + + # Assert that the event was logged and added to the queue + #mock_queue.put.assert_called_once_with(("WPA", "some_mac_address")) + assert event_queue.get() == ("WPA", "some_mac_address") + +# Remove log directory during teardown +import shutil +import os +def remove_logs_directory(): + logs_directory = 'logs' + if os.path.exists(logs_directory) and os.path.isdir(logs_directory): + shutil.rmtree(logs_directory) + +def teardown_module(): + remove_logs_directory() \ No newline at end of file diff --git a/modules/sc-mesh-secure-deployment/src/2_0/test_features/cbma/test_multicast.py b/modules/sc-mesh-secure-deployment/src/2_0/test_features/cbma/test_multicast.py new file mode 100644 index 000000000..8262e104d --- /dev/null +++ b/modules/sc-mesh-secure-deployment/src/2_0/test_features/cbma/test_multicast.py @@ -0,0 +1,109 @@ +from unittest.mock import patch, MagicMock +import json +import pytest +import queue + +import add_syspath +from multicast.multicast import MulticastHandler + +class MockLogger: + def __init__(self): + self.info = MagicMock() + self.error = MagicMock() + self.debug = MagicMock() + +# Mock Logger +@pytest.fixture +def mock_logger(): + with patch('multicast.multicast.logger_instance.get_logger', return_value=MockLogger()) as mock_log: + yield { + 'mock_log': mock_log + } +@patch('socket.socket') +def test_send_multicast_message(mock_socket, mock_logger): + # Setup + mock_queue = MagicMock() + mock_sock_instance = MagicMock() + mock_socket.return_value.__enter__.return_value = mock_sock_instance + + # Create a MulticastHandler instance + handler = MulticastHandler(mock_queue, 'ff02::1', 10000, 'lo') + + # Data to be sent + data = 'some_mac_address' + + # Expected message + expected_message = { + 'mac_address': data, + 'message_type': 'mac_announcement' + } + + # Call send_multicast_message + handler.send_multicast_message(data) + + # Assertions + mock_sock_instance.sendto.assert_called_once_with(json.dumps(expected_message).encode('utf-8'), ('ff02::1', 10000)) + + +@patch('socket.socket') +def test_receive_multicast(mock_socket, mock_logger): + # Setup + test_queue = queue.Queue() + mock_sock_instance = MagicMock() + mock_socket.return_value.__enter__.return_value = mock_sock_instance + + # Mock recvfrom() to return a message and address + received_message = { + 'mac_address': 'some_mac', + 'message_type': 'mac_announcement' + } + + # Set recvfrom() to return the message once and then raise an exception to break the loop + mock_sock_instance.recvfrom.side_effect = [ + (json.dumps(received_message).encode('utf-8'), ('::1', 10000)), + KeyboardInterrupt + ] + + # Create a MulticastHandler instance + # Mock get_mac_addr() to return an excluded MAC address + with patch('multicast.multicast.get_mac_addr', return_value='excluded_mac'): + handler = MulticastHandler(test_queue, 'ff02::1', 10000, 'lo') + + with pytest.raises(KeyboardInterrupt): + handler.multicast_message() + + # Check if the message is added to the queue + assert not test_queue.empty() + queue_item = test_queue.get_nowait() + assert queue_item == ("MULTICAST", 'some_mac') + +@patch('socket.socket') +def test_receive_multicast_excluded_mac(mock_socket, mock_logger): + # Setup + test_queue = queue.Queue() + mock_sock_instance = MagicMock() + mock_socket.return_value.__enter__.return_value = mock_sock_instance + + # Mock recvfrom() to return a message and address + received_message = { + 'mac_address': 'excluded_mac', + 'message_type': 'mac_announcement' + } + + # Set recvfrom() to return the message once and then raise an exception to break the loop + mock_sock_instance.recvfrom.side_effect = [ + (json.dumps(received_message).encode('utf-8'), ('::1', 10000)), + KeyboardInterrupt + ] + + # Create a MulticastHandler instance + # Mock get_mac_addr() to return an excluded MAC address + with patch('multicast.multicast.get_mac_addr', return_value='excluded_mac'): + handler = MulticastHandler(test_queue, 'ff02::1', 10000, 'lo') + + + with pytest.raises(KeyboardInterrupt): + handler.multicast_message() + + # Check that excluded mac is not added to queue + assert test_queue.empty() diff --git a/modules/sc-mesh-secure-deployment/src/2_0/test_features/cbma/test_secchannel.py b/modules/sc-mesh-secure-deployment/src/2_0/test_features/cbma/test_secchannel.py new file mode 100644 index 000000000..3bd45d66d --- /dev/null +++ b/modules/sc-mesh-secure-deployment/src/2_0/test_features/cbma/test_secchannel.py @@ -0,0 +1,123 @@ +import pytest +from unittest.mock import MagicMock, patch +import ssl +import json +import queue +import logging +import add_syspath +from secure_channel.secchannel import SecMessageHandler + +class MockLogger: + def __init__(self): + self.info = MagicMock() + self.error = MagicMock() + self.debug = MagicMock() + +# Mock Logger +@pytest.fixture +def mock_logger(monkeypatch): + with patch('secure_channel.secchannel.logger_instance.get_logger', return_value=MockLogger()) as mock_log: + yield { + 'mock_log': mock_log + } + +@pytest.fixture +def sec_message_handler(mock_logger): + mock_socket = MagicMock(spec=ssl.SSLSocket) + handler = SecMessageHandler(mock_socket) + return handler + +def test_set_callback(sec_message_handler, mock_logger): + mock_callback = MagicMock() + sec_message_handler.set_callback(mock_callback) + assert sec_message_handler.callback == mock_callback + + +def test_is_socket_active_true(sec_message_handler, monkeypatch, mock_logger): + monkeypatch.setattr(sec_message_handler.socket, 'fileno', lambda: 5) + assert sec_message_handler._is_socket_active() + + +def test_is_socket_active_false(sec_message_handler, monkeypatch, mock_logger): + monkeypatch.setattr(sec_message_handler.socket, 'fileno', lambda: -1) + assert not sec_message_handler._is_socket_active() + + +def test_is_ssl_socket(sec_message_handler, monkeypatch, mock_logger): + assert sec_message_handler._is_ssl_socket() + + +def test_send_message_no_ssl(sec_message_handler, monkeypatch, mock_logger): + monkeypatch.setattr(sec_message_handler.socket, '__class__', MagicMock()) + sec_message_handler.send_message("test_message") + sec_message_handler.socket.sendall.assert_not_called() + + +def test_send_message_ssl_inactive(sec_message_handler, monkeypatch, mock_logger): + monkeypatch.setattr(sec_message_handler.socket, 'fileno', lambda: -1) + sec_message_handler.send_message("test_message") + sec_message_handler.socket.sendall.assert_not_called() + + +def test_send_message_successful(sec_message_handler, monkeypatch, mock_logger): + mock_ssl_socket = MagicMock(spec=ssl.SSLSocket) + mock_ssl_socket.fileno.return_value = 5 + mock_ssl_socket.sendall = MagicMock() + monkeypatch.setattr(sec_message_handler, 'socket', mock_ssl_socket) + + sec_message_handler.send_message("test_message") + mock_ssl_socket.sendall.assert_called_once_with("test_message".encode()) + + +def test_receive_message_with_macsec_params(sec_message_handler, monkeypatch): + mock_message = json.dumps({ + 'bytes_for_my_key': 'some_value', + 'bytes_for_client_key': 'some_other_value', + 'port': 12345 + }) + + mock_ssl_socket = MagicMock(spec=ssl.SSLSocket) + mock_ssl_socket.getpeername.return_value = ('127.0.0.1', 54321) + mock_ssl_socket.recv.side_effect = [mock_message.encode('utf-8'), b"GOODBYE"] + + monkeypatch.setattr(sec_message_handler, 'socket', mock_ssl_socket) + + test_queue = queue.Queue() + + sec_message_handler.receive_message(macsec_param_q=test_queue) + + assert not test_queue.empty() + received_data = test_queue.get_nowait() + assert received_data == mock_message + + +def test_receive_message_successful(monkeypatch, caplog): + # Mocking the necessary methods and attributes + mock_ssl_socket = MagicMock(spec=ssl.SSLSocket) + mock_ssl_socket.getpeername.return_value = ('127.0.0.1', 12345) + + # Simulate the reception of three messages and then the GOODBYE message + mock_ssl_socket.recv.side_effect = [b"test_message_1", b"test_message_2", b"test_message_3", b"GOODBYE"] + sec_message_handler = SecMessageHandler(mock_ssl_socket) + #monkeypatch.setattr(sec_message_handler, 'socket', mock_ssl_socket) + + # You can also capture log outputs to check if the expected logs were created. + with caplog.at_level(logging.INFO): + sec_message_handler.receive_message() + + # Check that the logger contains the expected logs + assert "Received: test_message_1 from 127.0.0.1" in caplog.text + assert "Received: test_message_2 from 127.0.0.1" in caplog.text + assert "Received: test_message_3 from 127.0.0.1" in caplog.text + assert "Other end signaled end of communication." in caplog.text + +# Remove log directory during teardown +import shutil +import os +def remove_logs_directory(): + logs_directory = 'logs' + if os.path.exists(logs_directory) and os.path.isdir(logs_directory): + shutil.rmtree(logs_directory) + +def teardown_module(): + remove_logs_directory() \ No newline at end of file diff --git a/modules/sc-mesh-secure-deployment/src/2_0/test_features/cbma/test_utils.py b/modules/sc-mesh-secure-deployment/src/2_0/test_features/cbma/test_utils.py new file mode 100644 index 000000000..c78410c84 --- /dev/null +++ b/modules/sc-mesh-secure-deployment/src/2_0/test_features/cbma/test_utils.py @@ -0,0 +1,189 @@ +import pytest +from unittest.mock import patch, Mock, call + +import add_syspath +from tools.utils import * + +def test_is_ipv4(): + assert is_ipv4("192.168.1.1") + assert not is_ipv4("fe80::230:1aff:fe4f:cf3c") + +def test_is_ipv6(): + assert is_ipv6("fe80::230:1aff:fe4f:cf3c") + assert not is_ipv6("192.168.1.1") + +def test_mac_to_ipv6(): + assert mac_to_ipv6("00:30:1a:4f:cf:3c") == "fe80::0230:1aff:fe4f:cf3c" + assert mac_to_ipv6("04:f0:21:9e:6b:39") == "fe80::06f0:21ff:fe9e:6b39" + +def test_extract_mac_from_ipv6(): + assert extract_mac_from_ipv6("fe80::230:1aff:fe4f:cf3c") == "00:30:1a:4f:cf:3c" + assert extract_mac_from_ipv6("fe80::06f0:21ff:fe9e:6b39") == "04:f0:21:9e:6b:39" + +@patch('tools.utils.subprocess.check_output') +def test_is_wpa_supplicant_running(mock_check_output): + # Case where wpa_supplicant is running + mock_check_output.return_value = b'1234 ? 00:00:02 wpa_supplicant' + assert is_wpa_supplicant_running() + + # Case where wpa_supplicant is not running + mock_check_output.return_value = b'' + assert not is_wpa_supplicant_running() + +@pytest.mark.parametrize('rules_file', ['custom1.nft', 'custom2.nft']) +def test_apply_nft_rules(rules_file): + # Mock subprocess.run to simulate successful execution + with patch('tools.utils.subprocess.run', autospec=True) as mock_run: + # Mock logger.info + with patch('tools.utils.logger.info', autospec=True) as mock_logger_info: + apply_nft_rules(rules_file) + mock_run.assert_called_once_with(['nft', '-f', rules_file], check=True) +def test_batman_exec(): + with patch('tools.utils.run_batman', autospec=True) as mock_run_batman: + batman_exec("bat0", "batman-adv") + # assert that run_batman was called once with the correct interface + mock_run_batman.assert_called_once_with("bat0") + +def test_run_batman(): + # Mocking subprocess.run to always return success + with patch('tools.utils.subprocess.run', autospec=True) as mock_run: + # Mocking get_mac_addr to return a dummy MAC address + with patch('tools.utils.get_mac_addr', return_value='00:11:22:33:44:55', autospec=True) as mock_get_mac_addr: + # Mock logger.info + with patch('tools.utils.logger.info', autospec=True) as mock_logger_info: + run_batman("bat0") + + # Verifying get_mac_addr was called with correct interface + mock_get_mac_addr.assert_called_once_with("wlp1s0") + + # Verifying subprocess.run was called with the expected commands + mock_run.assert_any_call(["ip", "link", "set", "dev", "bat0", "address", "00:11:22:33:44:55"], check=True) + mock_run.assert_any_call(["ifconfig", "bat0", "up"], check=True) + mock_run.assert_any_call(["ifconfig", "bat0", "mtu", "1460"], check=True) + mock_run.assert_any_call(["ifconfig", "bat0"], check=True) + +def test_set_ipv6(): + # Dummy result to mimic the subprocess.run successful result + mock_result = Mock() + mock_result.stdout = "Command executed successfully" + + # Mocking subprocess.run to return the dummy result + with patch('tools.utils.subprocess.run', autospec=True) as mock_run: + # Mock logger.info + with patch('tools.utils.logger.info', autospec=True) as mock_logger_info: + set_ipv6("bat0", "2001:db8::1") + + # Verifying subprocess.run was called with the expected command + mock_run.assert_called_once_with(["ip", "-6", "addr", "add", "2001:db8::1/64", "dev", "bat0"], capture_output=True, text=True, check=True) + +def test_is_interface_pingable_ipv4_success(): + with patch('tools.utils.subprocess.check_output', return_value="1 packets transmitted, 1 received", autospec=True) as mock_check_output: + assert is_interface_pingable("bat0", "192.168.1.1") + mock_check_output.assert_called_once_with(['ping', '-c', '1', '-w', '1', "192.168.1.1"], stderr=subprocess.STDOUT, universal_newlines=True) + +def test_is_interface_pingable_ipv6_success(): + with patch('tools.utils.subprocess.check_output', return_value="1 packets transmitted, 1 received", autospec=True) as mock_check_output: + assert is_interface_pingable("bat0", "fe80::230:1aff:fe4f:cf3c") + mock_check_output.assert_called_once_with(['ping', '-c', '1', '-w', '1', "fe80::230:1aff:fe4f:cf3c%bat0"], stderr=subprocess.STDOUT, universal_newlines=True) + +def test_is_interface_pingable_failure(): + with patch('tools.utils.subprocess.check_output', return_value="", autospec=True) as mock_check_output: + assert not is_interface_pingable("bat0", "fe80::230:1aff:fe4f:cf3c") + +def test_is_interface_pingable_invalid_ip(): + with pytest.raises(ValueError, match="Invalid IP address"): + is_interface_pingable("eth0", "invalid_ip") + +def test_wait_for_interface_pingable_success_after_tries(): + mock_pingable_side_effect = [False, False, True] # 2 failures followed by a success + with patch('tools.utils.is_interface_pingable', side_effect=mock_pingable_side_effect, autospec=True) as mock_is_pingable, \ + patch('tools.utils.logger.info', autospec=True) as mock_logger_info, \ + patch('tools.utils.time.sleep', autospec=True) as mock_sleep: + + wait_for_interface_to_be_pingable("bat0", "fe80::230:1aff:fe4f:cf3c") + + # Verify is_interface_pingable was called 3 times + assert mock_is_pingable.call_count == 3 + # Verify sleep was called twice + mock_sleep.assert_has_calls([call(1), call(1)]) + +def test_interface_up(): + with patch('subprocess.check_output', return_value=b'bat0: flags=4163 mtu 1460\ninet6 fe80::230:1aff:fe4f:cf3c prefixlen 64 scopeid 0x20\nether 00:30:1a:4f:cf:3c txqueuelen 1000 (Ethernet)', autospec=True) as mock_subprocess: + assert is_interface_up('bat0') + mock_subprocess.assert_called_once_with(['ifconfig', 'bat0'], stderr=subprocess.DEVNULL) + +def test_interface_down(): + with patch('subprocess.check_output', return_value=b'bat0: flags=4163 mtu 1460\nether 00:30:1a:4f:cf:3c txqueuelen 1000 (Ethernet)', autospec=True) as mock_subprocess: + assert not is_interface_up('bat0') + mock_subprocess.assert_called_once_with(['ifconfig', 'bat0'], stderr=subprocess.DEVNULL) + +def test_interface_not_present(): + with patch('subprocess.check_output', side_effect=subprocess.CalledProcessError(1, 'ifconfig bat0'), autospec=True) as mock_subprocess: + assert not is_interface_up('bat0') + mock_subprocess.assert_called_once_with(['ifconfig', 'bat0'], stderr=subprocess.DEVNULL) + +def test_wait_for_interface_to_be_up_success_after_tries(): + mock_interface_up_side_effect = [False, False, True] # 2 failures followed by a success + with patch('tools.utils.is_interface_up', side_effect=mock_interface_up_side_effect, autospec=True) as mock_is_interface_up, \ + patch('tools.utils.logger.info', autospec=True) as mock_logger_info, \ + patch('tools.utils.time.sleep', autospec=True) as mock_sleep: + + wait_for_interface_to_be_up("bat0") + + # Verify is_interface_pingable was called 3 times + assert mock_is_interface_up.call_count == 3 + # Verify sleep was called twice + mock_sleep.assert_has_calls([call(1), call(1)]) + +def test_xor_bytes_without_padding(): + result = xor_bytes(b'\x01\x02\x03', b'\x11\x01\x01', 3) + assert result == b'\x10\x03\x02' + +def test_xor_bytes_with_padding(): + result = xor_bytes(b'\x01', b'\x02', 3) + assert result == b'\x00\x00\x03' + + result = xor_bytes(b'\x01\x02\x03', b'\x01\x01\x01', 3) + assert result == b'\x00\x03\x02' + + result = xor_bytes(b'\x01\x02', b'\x01\x01\x01', 3) + assert result == b'\x01\x00\x03' + +def test_xor_bytes_with_oversized_input(): + # Test with bytes that are longer than byte_size + result = xor_bytes(b'\x01\x02\x03\x04', b'\x01\x01\x01\x01', 3) + assert result == b'\x00\x03\x02' + + result = xor_bytes(b'\x01\x02\x03\x04', b'\x01\x01\x01', 3) + assert result == b'\x00\x03\x02' + + result = xor_bytes(b'\x01\x02\x03', b'\x01\x01\x01\x01', 3) + assert result == b'\x00\x03\x02' + +def test_add_interface_to_batman_successful(): + with patch("subprocess.run") as mock_run, \ + patch('tools.utils.logger.info', autospec=True) as mock_logger_info: + add_interface_to_batman("wlan0", "bat0") + mock_run.assert_called_once_with(["batctl", "meshif", "bat0", "if", "add", "wlan0"], check=True) + +def test_add_interface_to_batman_failed(): + with patch("subprocess.run", side_effect=Exception("Dummy error")) as mock_run, \ + patch("tools.utils.logger.error") as mock_logger_error: + + add_interface_to_batman("wlan0", "bat0") + mock_run.assert_called_once_with(["batctl", "meshif", "bat0", "if", "add", "wlan0"], check=True) + mock_logger_error.assert_called_once_with("Error adding interface wlan0 to bat0: Dummy error") + +def test_add_interface_to_bridge_successful(): + with patch("subprocess.run") as mock_run, \ + patch('tools.utils.logger.info', autospec=True) as mock_logger_info: + add_interface_to_bridge("eth0", "br0") + mock_run.assert_called_once_with(["brctl", "addif", "br0", "eth0"], check=True) + +def test_add_interface_to_bridge_failed(): + with patch("subprocess.run", side_effect=Exception("Dummy error")) as mock_run, \ + patch("tools.utils.logger.error") as mock_logger_error: + + add_interface_to_bridge("eth0", "br0") + mock_run.assert_called_once_with(["brctl", "addif", "br0", "eth0"], check=True) + mock_logger_error.assert_called_once_with("Error adding interface eth0 to br0: Dummy error") \ No newline at end of file diff --git a/modules/sc-mesh-secure-deployment/src/2_0/test_features/cbma/test_verification_tools.py b/modules/sc-mesh-secure-deployment/src/2_0/test_features/cbma/test_verification_tools.py new file mode 100644 index 000000000..3ed927d72 --- /dev/null +++ b/modules/sc-mesh-secure-deployment/src/2_0/test_features/cbma/test_verification_tools.py @@ -0,0 +1,108 @@ +import datetime +from unittest import TestCase, mock +import pytest +import logging +import OpenSSL +logger = logging.getLogger("Test") +import add_syspath +from tools.verification_tools import verify_cert, CertificateDifferentCN, CertificateExpiredError, _verify_certificate_chain, validation + +# Get the current date and time +current_date = datetime.datetime.now() + +# Format the current date and time to match the format you desire +not_after = (current_date + datetime.timedelta(days=365 * 80)).strftime('%Y%m%d%H%M%SZ') # 80 years from now +not_before = (current_date - datetime.timedelta(days=365 * 2)).strftime('%Y%m%d%H%M%SZ') # 2 years ago + +# Mock OpenSSL certificate object +mock_cert_obj = mock.Mock() +mock_cert_obj.get_notAfter.return_value = not_after.encode() +mock_cert_obj.get_notBefore.return_value = not_before.encode() +mock_cert_obj.get_subject.return_value.CN = 'some_mac_address' + +# Mock data +cert = b'some_cert_data' +ca_cert = 'path_to_ca_cert' +IPaddress = 'some_ipv6_address' +interface = 'interface' + + +@mock.patch('tools.verification_tools._verify_certificate_chain', return_value=True) +@mock.patch('tools.verification_tools.OpenSSL.crypto.dump_publickey', return_value=b'some_public_key_data') +@mock.patch('tools.verification_tools.OpenSSL.crypto.load_certificate', return_value=mock_cert_obj) +@mock.patch('tools.verification_tools.get_mac_from_ipv6', return_value='some_mac_address') +def test_verify_cert_success(mock_get_mac, mock_load_certificate, mock_dump_publickey, mock_verify_chain): + assert verify_cert(cert, ca_cert, IPaddress, interface, mock.Mock()) is True + + +# Sample of ca cert +ca_cert_pem = b"""-----BEGIN CERTIFICATE----- +MIIBcTCCARegAwIBAgIUe3U5E1wplyB8SQ9XVAIJrCAPp74wCgYIKoZIzj0EAwIw +DjEMMAoGA1UEAwwDVElJMB4XDTIzMTAyNzA5MjQxM1oXDTI0MTAyNjA5MjQxM1ow +DjEMMAoGA1UEAwwDVElJMFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEb0ey66LW +i4IbPAXcAXg5sp56HXGLemQVGfTiE1Nlwxzu/PTOybOUZe70ssceR1gZo6YNtywj +HXMG+GA0f1+OOaNTMFEwHQYDVR0OBBYEFDgm8kYZpzMYxPJQmcFey89qzaoCMB8G +A1UdIwQYMBaAFDgm8kYZpzMYxPJQmcFey89qzaoCMA8GA1UdEwEB/wQFMAMBAf8w +CgYIKoZIzj0EAwIDSAAwRQIhAJ+eESqOmyJikkNshd0vE6GcgG/sXNyi94i9eoJe +0xi1AiBbAN95OGIsDC6kXm03l26kjGOEbTKMtfw98m1+SjjyXg== +-----END CERTIFICATE-----""" + +# Sample of ca signed certificate +cert_pem = b"""-----BEGIN CERTIFICATE----- +MIIBJDCBywIUSPf4jQ/6c2+pkOrQT+dd84Brk38wCgYIKoZIzj0EAwIwDjEMMAoG +A1UEAwwDVElJMB4XDTIzMTAyNzA5MjQzMloXDTI0MTAyNjA5MjQzMlowHDEaMBgG +A1UEAwwRMDQ6ZjA6MjE6OWU6NmI6MzkwWTATBgcqhkjOPQIBBggqhkjOPQMBBwNC +AAR9f+vKjysitibhyyVdXgHfszGmTnigLPo2jMikw6YvM78C3LBrXie4WGEKMxBU +1oZHrit1bf/UtKYJR7KjuDLtMAoGCCqGSM49BAMCA0gAMEUCIQDKbKoESzjZqbeY +W9hUVtMr6tBeOeQBu4k1Ob6yBI4gowIgZDMOUsItvJn1CmHGxT/q8FdPEG5/w4qL +a/RrY89BYTc= +-----END CERTIFICATE-----""" + + +@mock.patch('builtins.open', mock.mock_open(read_data=ca_cert_pem)) +def test_verify_certificate_chain_pass(): + # Convert the certificate from PEM to DER format + x509 = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, cert_pem) + cert_der = OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_ASN1, x509) + + # Run the function under test with valid certificate + result = _verify_certificate_chain(cert_der, 'path_to_ca_cert', mock.MagicMock()) + + # Assertions + assert result + +@mock.patch('builtins.open', mock.mock_open(read_data=ca_cert_pem)) +def test_verify_certificate_chain_failure(): + # Run the function under test with fake certificate + result = _verify_certificate_chain("Fake certificate", 'path_to_ca_cert', mock.MagicMock()) + + # Assertions + assert not result + + +@mock.patch('tools.verification_tools.OpenSSL.crypto.load_certificate', return_value=mock_cert_obj) +@mock.patch('tools.verification_tools.get_mac_from_ipv6', return_value='different_mac_address') +def test_certificate_different_cn(mock_get_mac, mock_load_certificate): + with pytest.raises(CertificateDifferentCN): # Checking that exception is raised + validation(cert, ca_cert, IPaddress, interface, mock.Mock()) + assert not verify_cert(cert, ca_cert, IPaddress, interface, mock.Mock()) # Assert that certificate verification fails + + +@mock.patch('tools.verification_tools.OpenSSL.crypto.load_certificate', return_value=mock_cert_obj) +def test_certificate_expired(mock_load_certificate): + # Change the 'notAfter' value to a date in the past + mock_cert_obj.get_notAfter.return_value = (current_date - datetime.timedelta(days=1)).strftime('%Y%m%d%H%M%SZ').encode() + with pytest.raises(CertificateExpiredError): # Checking that exception is raised + validation(cert, ca_cert, IPaddress, interface, mock.Mock()) + assert not verify_cert(cert, ca_cert, IPaddress, interface, mock.Mock()) # Assert that certificate verification fails + +# Remove log directory during teardown +import shutil +import os +def remove_logs_directory(): + logs_directory = 'logs' + if os.path.exists(logs_directory) and os.path.isdir(logs_directory): + shutil.rmtree(logs_directory) + +def teardown_module(): + remove_logs_directory() \ No newline at end of file