diff --git a/modules/sc-mesh-secure-deployment/src/nats/comms_nats_controller.py b/modules/sc-mesh-secure-deployment/src/nats/comms_nats_controller.py index c1857ff3c..81706b176 100644 --- a/modules/sc-mesh-secure-deployment/src/nats/comms_nats_controller.py +++ b/modules/sc-mesh-secure-deployment/src/nats/comms_nats_controller.py @@ -70,6 +70,7 @@ def stop(self): self.batman.thread_running = False # thread loop disabled self.t2.join() # wait for thread to finish + # pylint: disable=too-many-instance-attributes class CommsController: # pylint: disable=too-few-public-methods """ @@ -99,7 +100,8 @@ def __init__(self, server: str, port: str, interval: int = 1000): # logger for this module and derived from main logger self.logger = self.main_logger.getChild("controller") -class CommsCsa: # pylint: disable=too-few-public-methods + +class CommsCsa: # pylint: disable=too-few-public-methods """ Comms CSA class to storage settings for CSA for a state change """ @@ -117,6 +119,15 @@ async def main(server, port, keyfile=None, certfile=None, interval=1000): nats_client = NATS() csac = CommsCsa() + status, _, identity_dict = cc.command.get_identity() + + if status == "OK": + identity = identity_dict["identity"] + cc.logger.debug("Identity: %s", identity) + else: + cc.logger.error("Failed to get identity!") + return + async def stop(): await asyncio.sleep(1) asyncio.get_running_loop().stop() @@ -156,6 +167,7 @@ async def reconnected_cb(): reconnected_cb=reconnected_cb, disconnected_cb=disconnected_cb, max_reconnect_attempts=-1) + async def handle_settings_csa_post(ret): if ret == "OK": ret = "ACK" @@ -181,16 +193,16 @@ async def message_handler(message): cc.logger.debug("Received a message on '%s': %s", subject, data) ret, info, resp = "FAIL", "Not supported subject", "" - if subject == "comms.settings": + if subject == f"comms.settings.{identity}": ret, info = cc.settings.handle_mesh_settings(data) elif subject == "comms.settings_csa": ret, info, delay = cc.settings.handle_mesh_settings_csa(data) csac.delay = delay csac.ack_sent = "status" in data - elif subject == "comms.command": + elif subject == f"comms.command.{identity}" or subject == "comms.identity": ret, info, resp = cc.command.handle_command(data, cc) - elif subject == "comms.status": + elif subject == f"comms.status.{identity}": ret, info = "OK", "Returning current status" if subject == "comms.settings_csa": @@ -204,7 +216,7 @@ async def message_handler(message): 'visualisation_active': cc.comms_status.is_visualisation_active, 'mesh_radio_on': cc.comms_status.is_mesh_radio_on, 'ap_radio_on': cc.comms_status.is_ap_radio_on, - 'security_status': cc.comms_status.security_status } + 'security_status': cc.comms_status.security_status} if resp != "": response['data'] = resp @@ -212,10 +224,11 @@ async def message_handler(message): cc.logger.debug("Sending response: %s", str(response)[:1000]) await message.respond(json.dumps(response).encode("utf-8")) - await nats_client.subscribe("comms.settings", cb=message_handler) + await nats_client.subscribe(f"comms.settings.{identity}", cb=message_handler) await nats_client.subscribe("comms.settings_csa", cb=message_handler) - await nats_client.subscribe("comms.command", cb=message_handler) - await nats_client.subscribe("comms.status", cb=message_handler) + await nats_client.subscribe(f"comms.command.{identity}", cb=message_handler) + await nats_client.subscribe("comms.identity", cb=message_handler) + await nats_client.subscribe(f"comms.status.{identity}", cb=message_handler) cc.logger.debug("comms_nats_controller Listening for requests") while True: @@ -223,8 +236,8 @@ async def message_handler(message): try: if cc.telemetry.visualisation_enabled: msg = cc.telemetry.mesh_visual() - cc.logger.debug("Publishing comms.visual: %s", msg) - await nats_client.publish("comms.visual", msg.encode()) + cc.logger.debug(f"Publishing comms.visual.{identity}: %s", msg) + await nats_client.publish(f"comms.visual.{identity}", msg.encode()) except Exception as e: cc.logger.error("Error:", e) diff --git a/modules/sc-mesh-secure-deployment/src/nats/comms_nats_discovery.py b/modules/sc-mesh-secure-deployment/src/nats/comms_nats_discovery.py index 65825cfdf..df876fef0 100644 --- a/modules/sc-mesh-secure-deployment/src/nats/comms_nats_discovery.py +++ b/modules/sc-mesh-secure-deployment/src/nats/comms_nats_discovery.py @@ -4,18 +4,28 @@ import time import logging import argparse -import netifaces as ni +import netifaces as netifaces +import os +import textwrap -class NatsDiscovery: + +class NatsDiscovery: # pylint: disable=too-few-public-methods """ Nats Discovery class. Utilizes the batctl command to discover devices on the mesh network. """ - def __init__(self, role, key, cert): + def __init__(self, role, key, cert, servercert, ca): self.role = role self.key = key self.cert = cert + self.server_cert = servercert + self.cert_authority = ca self.leaf_port = 7422 self.seed_ip_address = "" + self.tls_required = False + + if os.path.exists(self.key) and os.path.exists(self.cert) \ + and os.path.exists(self.server_cert) and os.path.exists(self.cert_authority): + self.tls_required = True # base logger for discovery self.main_logger = logging.getLogger("nats") @@ -29,40 +39,41 @@ def __init__(self, role, key, cert): # logger for this module and derived from main logger self.logger = self.main_logger.getChild("discovery") - def __get_authorization_config(self) -> str: - """ - Get the authorization configuration for the nats-server configuration file. - :return: authorization configuration - """ - if self.key is not None or self.cert is not None: - authorization = f""" -tls {{ - cert_file: "{self.cert}" - key_file: "{self.key}" - timeout: 2 - verify: true -}} -""" - else: - authorization = "" - - return authorization - def __generate_seed_config(self) -> None: """ Generate the nats-server configuration file for the seed node. :return: None """ - config = f""" -listen: 0.0.0.0:4222 -leafnodes {{ - port: 7422 -}} -{self.__get_authorization_config()} -""" - with open('/var/run/nats.conf', 'w') as f: - f.write(config) + if self.tls_required: + config = textwrap.dedent(f""" + listen: 0.0.0.0:4222 + tls {{ + key_file: {self.key} + cert_file: {self.server_cert} + ca_file: {self.cert_authority} + verify: true + }} + leafnodes {{ + port: 7422 + tls {{ + key_file: {self.key} + cert_file: {self.server_cert} + ca_file: {self.cert_authority} + verify: true + }} + }} + """) + else: + config = textwrap.dedent(""" + listen: 0.0.0.0:4222 + leafnodes { + port: 7422 + } + """) + + with open('/var/run/nats.conf', 'w', encoding='UTF-8') as file_nats_conf: + file_nats_conf.write(config) def __generate_leaf_config(self, _seed_route) -> None: """ @@ -70,19 +81,45 @@ def __generate_leaf_config(self, _seed_route) -> None: :param _seed_route: seed node route :return: None """ - config = f""" -listen: 0.0.0.0:4222 -leafnodes {{ - remotes = [ - {{ - url: "nats://{_seed_route}" - }}, - ] -}} -{self.__get_authorization_config()} -""" - with open('/var/run/nats.conf', 'w') as f: - f.write(config) + if self.tls_required: + protocol = "tls" + config = textwrap.dedent(f""" + listen: 0.0.0.0:4222 + tls {{ + key_file: {self.key} + cert_file: {self.server_cert} + ca_file: {self.cert_authority} + verify: true + }} + leafnodes {{ + remotes = [ + {{ + url: "{protocol}://{_seed_route}" + tls {{ + key_file: {self.key} + cert_file: {self.cert} + ca_file: {self.cert_authority} + verify: true + }} + }} + ] + }} + """) + else: + protocol = "nats" + config = textwrap.dedent(f""" + listen: 0.0.0.0:4222 + leafnodes {{ + remotes = [ + {{ + url: "{protocol}://{_seed_route}" + }} + ] + }} + """) + + with open('/var/run/nats.conf', 'w', encoding='UTF-8') as file_nats_conf: + file_nats_conf.write(config) def __reload_nats_server_config(self) -> int: """ @@ -102,14 +139,14 @@ def __reload_nats_server_config(self) -> int: return 0 - def __update_configurations_and_restart(self, ip) -> None: + def __update_configurations_and_restart(self, ip_address) -> None: """ Update the nats-server configuration and restart the nats-server. - :param ip: ip address of the seed node + :param ip_address: ip address of the seed node :return: None """ self.logger.debug("Updating configurations and reloading nats-server configuration") - self.__generate_leaf_config(ip) + self.__generate_leaf_config(ip_address) # reload nats-server configuration ret = self.__reload_nats_server_config() @@ -124,12 +161,13 @@ def __get_mesh_macs() -> list: :return: list of mac addresses """ try: - ret = subprocess.run(["batctl", "o", "-H"], shell=False, check=True, capture_output=True) + ret = subprocess.run(["batctl", "o", "-H"], shell=False, + check=True, capture_output=True) if ret.returncode != 0: return [] - else: - macs = re.findall(r' \* (([0-9A-Fa-f]{2}[:-]){5}([0-9A-Fa-f]{2}))', ret.stdout.decode('utf-8')) - return [mac[0] for mac in macs] + macs = re.findall(r' \* (([0-9A-Fa-f]{2}[:-]){5}([0-9A-Fa-f]{2}))', + ret.stdout.decode('utf-8')) + return [mac[0] for mac in macs] except: return [] @@ -140,9 +178,9 @@ def __mac_to_ip(mac) -> str: :param mac: mac address :return: ip address """ - ip_br_lan = ni.ifaddresses('br-lan')[ni.AF_INET][0]['addr'].split(".")[0:-1] + ip_br_lan = netifaces.ifaddresses('br-lan')[netifaces.AF_INET][0]['addr'].split(".")[0:-1] ip_br_lan = ".".join(ip_br_lan) + "." - return ip_br_lan + str(int(mac.split(":")[-1],16)) + return ip_br_lan + str(int(mac.split(":")[-1], 16)) @staticmethod def __scan_port(host, port) -> bool: @@ -171,20 +209,19 @@ def run(self): self.__generate_seed_config() self.__reload_nats_server_config() return - else: - # create temporary leaf configuration for nats-server to start - self.__generate_leaf_config("192.168.1.2") + # create temporary leaf configuration for nats-server to start + self.__generate_leaf_config("192.168.1.2") while True: macs = self.__get_mesh_macs() self.logger.debug(f"{macs} len: {len(macs)}") for mac in macs: - ip = self.__mac_to_ip(mac) + ip_address = self.__mac_to_ip(mac) if self.seed_ip_address == "": - self.logger.debug(f"Scanning {ip}, {mac}") - if self.__scan_port(ip, self.leaf_port): - self.seed_ip_address = ip + self.logger.debug(f"Scanning {ip_address}, {mac}") + if self.__scan_port(ip_address, self.leaf_port): + self.seed_ip_address = ip_address gcs_found = 1 if gcs_found: @@ -195,17 +232,15 @@ def run(self): time.sleep(4) + if __name__ == "__main__": - """ - Main function. - :param args: command line arguments - :return: None - """ parser = argparse.ArgumentParser(description='NATS Discovery') parser.add_argument('-r', '--role', help='device role', required=True) parser.add_argument('-k', '--key', help='key file', required=False) - parser.add_argument('-c', '--cert', help='cert file', required=False) + parser.add_argument('-c', '--cert', help='client cert file', required=False) + parser.add_argument('-s', '--servercert', help='server cert file', required=False) + parser.add_argument('-a', '--ca', help='certificate authority file', required=False) args = parser.parse_args() - forrest = NatsDiscovery(args.role, args.key, args.cert) + forrest = NatsDiscovery(args.role, args.key, args.cert, args.servercert, args.ca) forrest.run() diff --git a/modules/sc-mesh-secure-deployment/src/nats/initd/S90nats_discovery b/modules/sc-mesh-secure-deployment/src/nats/initd/S90nats_discovery index 11e775493..7da763190 100644 --- a/modules/sc-mesh-secure-deployment/src/nats/initd/S90nats_discovery +++ b/modules/sc-mesh-secure-deployment/src/nats/initd/S90nats_discovery @@ -17,9 +17,11 @@ source_configuration KEY_FILE="/etc/ssl/private/comms_auth_private_key.pem" CERT_FILE="/etc/ssl/certs/comms_auth_cert.pem" +SERVER_CERT_FILE="/etc/ssl/certs/comms_server_cert.pem" +CA="/etc/ssl/certs/root-ca.cert.pem" if [ -e "$KEY_FILE" ] && [ -e "$CERT_FILE" ]; then - ARGS="--role $ROLE -k $KEY_FILE -c $CERT_FILE" + ARGS="--role $ROLE -k $KEY_FILE -c $CERT_FILE -s $SERVER_CERT_FILE -a $CA" else ARGS="--role $ROLE" fi diff --git a/modules/sc-mesh-secure-deployment/src/nats/initd/S90nats_server b/modules/sc-mesh-secure-deployment/src/nats/initd/S90nats_server index 0cd1c0550..c50f0ad8f 100644 --- a/modules/sc-mesh-secure-deployment/src/nats/initd/S90nats_server +++ b/modules/sc-mesh-secure-deployment/src/nats/initd/S90nats_server @@ -12,14 +12,7 @@ LOG_FILE=/opt/nats-server.log # shellcheck source=/dev/null [ -r "/etc/default/$DAEMON" ] && . "/etc/default/$DAEMON" -KEY_FILE="/etc/ssl/private/comms_auth_private_key.pem" -CERT_FILE="/etc/ssl/certs/comms_auth_cert.pem" - -if [ -e "$KEY_FILE" ] && [ -e "$CERT_FILE" ]; then - NATS_SERVER_ARGS="-l $LOG_FILE -c /var/run/nats.conf --tlsverify --tlscert=$CERT_FILE --tlskey=$KEY_FILE" -else - NATS_SERVER_ARGS="-l $LOG_FILE -c /var/run/nats.conf" -fi +NATS_SERVER_ARGS="-l $LOG_FILE -c /var/run/nats.conf" start() { echo "$NATS_SERVER_ARGS" diff --git a/modules/sc-mesh-secure-deployment/src/nats/scripts/cli_command_get_identity.py b/modules/sc-mesh-secure-deployment/src/nats/scripts/_cli_command_get_identity.py similarity index 51% rename from modules/sc-mesh-secure-deployment/src/nats/scripts/cli_command_get_identity.py rename to modules/sc-mesh-secure-deployment/src/nats/scripts/_cli_command_get_identity.py index 54c174e63..33f8d6e02 100644 --- a/modules/sc-mesh-secure-deployment/src/nats/scripts/cli_command_get_identity.py +++ b/modules/sc-mesh-secure-deployment/src/nats/scripts/_cli_command_get_identity.py @@ -1,31 +1,31 @@ import asyncio -import nats +import client import json -import base64 -import config async def main(): # Connect to NATS! - nc = await nats.connect(f"{config.MODULE_IP}:{config.MODULE_PORT}") + nc = await client.connect_nats() cmd_dict = {"api_version": 1, "cmd": "GET_IDENTITY"} cmd = json.dumps(cmd_dict) - rep = await nc.request("comms.command", - cmd.encode(), - timeout=2) + rep = await nc.request("comms.identity", + cmd.encode(), + timeout=2) print(rep.data) parameters = json.loads(rep.data.decode()) print(json.dumps(parameters, indent=2)) + if "identity" in parameters["data"]: + with open("identity.py", "w") as f: + f.write(f"MODULE_IDENTITY=\"{parameters['data']['identity']}\"\n") + else: + print("No identity received!!!!!!!!!!!") await nc.close() exit(0) if __name__ == '__main__': loop = asyncio.get_event_loop() loop.run_until_complete(main()) - loop.run_forever() loop.close() - - diff --git a/modules/sc-mesh-secure-deployment/src/nats/scripts/cli_command_get_config_hostapd.py b/modules/sc-mesh-secure-deployment/src/nats/scripts/cli_command_get_config_hostapd.py index 1b4151595..cfe38039f 100644 --- a/modules/sc-mesh-secure-deployment/src/nats/scripts/cli_command_get_config_hostapd.py +++ b/modules/sc-mesh-secure-deployment/src/nats/scripts/cli_command_get_config_hostapd.py @@ -1,5 +1,5 @@ import asyncio -import nats +import client import json import base64 import config @@ -7,27 +7,23 @@ async def main(): # Connect to NATS! - nc = await nats.connect(f"{config.MODULE_IP}:{config.MODULE_PORT}") + nc = await client.connect_nats() cmd_dict = {"api_version": 1, "cmd": "GET_CONFIG", "param": "HOSTAPD_CONFIG"} cmd = json.dumps(cmd_dict) - rep = await nc.request("comms.command", - cmd.encode(), - timeout=4) + rep = await nc.request(f"comms.command.{config.MODULE_IDENTITY}", + cmd.encode(), + timeout=4) parameters = json.loads(rep.data.decode()) if parameters["status"] == "OK": b64_data = base64.b64decode(parameters["data"].encode()) print(b64_data.decode()) elif parameters["status"] == "FAIL": print(parameters) - await nc.close() exit(0) if __name__ == '__main__': loop = asyncio.get_event_loop() loop.run_until_complete(main()) - loop.run_forever() loop.close() - - diff --git a/modules/sc-mesh-secure-deployment/src/nats/scripts/cli_command_get_config_wpa.py b/modules/sc-mesh-secure-deployment/src/nats/scripts/cli_command_get_config_wpa.py index 116a72cb6..1a22c0383 100644 --- a/modules/sc-mesh-secure-deployment/src/nats/scripts/cli_command_get_config_wpa.py +++ b/modules/sc-mesh-secure-deployment/src/nats/scripts/cli_command_get_config_wpa.py @@ -1,5 +1,5 @@ import asyncio -import nats +import client import json import base64 import config @@ -7,13 +7,13 @@ async def main(): # Connect to NATS! - nc = await nats.connect(f"{config.MODULE_IP}:{config.MODULE_PORT}") + nc = await client.connect_nats() cmd_dict = {"api_version": 1, "cmd": "GET_CONFIG", "param": "WPA_CONFIG"} cmd = json.dumps(cmd_dict) - rep = await nc.request("comms.command", - cmd.encode(), - timeout=4) + rep = await nc.request(f"comms.command.{config.MODULE_IDENTITY}", + cmd.encode(), + timeout=4) parameters = json.loads(rep.data.decode()) if parameters["status"] == "OK": b64_data = base64.b64decode(parameters["data"].encode()) @@ -27,7 +27,4 @@ async def main(): if __name__ == '__main__': loop = asyncio.get_event_loop() loop.run_until_complete(main()) - loop.run_forever() loop.close() - - diff --git a/modules/sc-mesh-secure-deployment/src/nats/scripts/cli_command_logs_controller.py b/modules/sc-mesh-secure-deployment/src/nats/scripts/cli_command_logs_controller.py index 41e25ec60..7e673c97c 100644 --- a/modules/sc-mesh-secure-deployment/src/nats/scripts/cli_command_logs_controller.py +++ b/modules/sc-mesh-secure-deployment/src/nats/scripts/cli_command_logs_controller.py @@ -1,5 +1,5 @@ import asyncio -import nats +import client import json import base64 import config @@ -7,11 +7,11 @@ async def main(): # Connect to NATS! - nc = await nats.connect(f"{config.MODULE_IP}:{config.MODULE_PORT}") + nc = await client.connect_nats() cmd_dict = {"api_version": 1, "cmd": "LOGS", "param": "CONTROLLER"} cmd = json.dumps(cmd_dict) - rep = await nc.request("comms.command", + rep = await nc.request(f"comms.command.{config.MODULE_IDENTITY}", cmd.encode(), timeout=2) print(rep.data) @@ -25,7 +25,4 @@ async def main(): if __name__ == '__main__': loop = asyncio.get_event_loop() loop.run_until_complete(main()) - loop.run_forever() loop.close() - - diff --git a/modules/sc-mesh-secure-deployment/src/nats/scripts/cli_command_logs_dmesg.py b/modules/sc-mesh-secure-deployment/src/nats/scripts/cli_command_logs_dmesg.py index 2a9356f09..fbac515cb 100644 --- a/modules/sc-mesh-secure-deployment/src/nats/scripts/cli_command_logs_dmesg.py +++ b/modules/sc-mesh-secure-deployment/src/nats/scripts/cli_command_logs_dmesg.py @@ -1,5 +1,5 @@ import asyncio -import nats +import client import json import base64 import config @@ -7,11 +7,11 @@ async def main(): # Connect to NATS! - nc = await nats.connect(f"{config.MODULE_IP}:{config.MODULE_PORT}") + nc = await client.connect_nats() cmd_dict = {"api_version": 1, "cmd": "LOGS", "param": "DMESG"} cmd = json.dumps(cmd_dict) - rep = await nc.request("comms.command", + rep = await nc.request(f"comms.command.{config.MODULE_IDENTITY}", cmd.encode(), timeout=2) print(rep.data) @@ -25,7 +25,4 @@ async def main(): if __name__ == '__main__': loop = asyncio.get_event_loop() loop.run_until_complete(main()) - loop.run_forever() loop.close() - - diff --git a/modules/sc-mesh-secure-deployment/src/nats/scripts/cli_command_logs_hostapd.py b/modules/sc-mesh-secure-deployment/src/nats/scripts/cli_command_logs_hostapd.py index 8daacf898..9f572cab2 100644 --- a/modules/sc-mesh-secure-deployment/src/nats/scripts/cli_command_logs_hostapd.py +++ b/modules/sc-mesh-secure-deployment/src/nats/scripts/cli_command_logs_hostapd.py @@ -1,5 +1,5 @@ import asyncio -import nats +import client import json import base64 import config @@ -7,13 +7,13 @@ async def main(): # Connect to NATS! - nc = await nats.connect(f"{config.MODULE_IP}:{config.MODULE_PORT}") + nc = await client.connect_nats() cmd_dict = {"api_version": 1, "cmd": "LOGS", "param": "HOSTAPD"} cmd = json.dumps(cmd_dict) - rep = await nc.request("comms.command", - cmd.encode(), - timeout=2) + rep = await nc.request(f"comms.command.{config.MODULE_IDENTITY}", + cmd.encode(), + timeout=2) print(rep.data) parameters = json.loads(rep.data.decode()) if parameters["data"] is None: @@ -26,7 +26,4 @@ async def main(): if __name__ == '__main__': loop = asyncio.get_event_loop() loop.run_until_complete(main()) - loop.run_forever() loop.close() - - diff --git a/modules/sc-mesh-secure-deployment/src/nats/scripts/cli_command_logs_wpa.py b/modules/sc-mesh-secure-deployment/src/nats/scripts/cli_command_logs_wpa.py index ea1a34e98..70c79d57c 100644 --- a/modules/sc-mesh-secure-deployment/src/nats/scripts/cli_command_logs_wpa.py +++ b/modules/sc-mesh-secure-deployment/src/nats/scripts/cli_command_logs_wpa.py @@ -1,5 +1,5 @@ import asyncio -import nats +import client import json import base64 import config @@ -7,13 +7,13 @@ async def main(): # Connect to NATS! - nc = await nats.connect(f"{config.MODULE_IP}:{config.MODULE_PORT}") + nc = await client.connect_nats() cmd_dict = {"api_version": 1, "cmd": "LOGS", "param": "WPA"} cmd = json.dumps(cmd_dict) - rep = await nc.request("comms.command", - cmd.encode(), - timeout=2) + rep = await nc.request(f"comms.command.{config.MODULE_IDENTITY}", + cmd.encode(), + timeout=2) print(rep.data) parameters = json.loads(rep.data.decode()) if parameters["data"] is None: @@ -28,7 +28,4 @@ async def main(): if __name__ == '__main__': loop = asyncio.get_event_loop() loop.run_until_complete(main()) - loop.run_forever() loop.close() - - diff --git a/modules/sc-mesh-secure-deployment/src/nats/scripts/cli_radio_off.py b/modules/sc-mesh-secure-deployment/src/nats/scripts/cli_radio_off.py index 6d781bc00..91c09a258 100644 --- a/modules/sc-mesh-secure-deployment/src/nats/scripts/cli_radio_off.py +++ b/modules/sc-mesh-secure-deployment/src/nats/scripts/cli_radio_off.py @@ -1,15 +1,16 @@ import asyncio -import nats +import client import json import config + async def main(): # Connect to NATS! - nc = await nats.connect(f"{config.MODULE_IP}:{config.MODULE_PORT}") + nc = await client.connect_nats() - cmd_dict = {"api_version": 1,"cmd": "DOWN"} + cmd_dict = {"api_version": 1, "cmd": "DOWN"} cmd = json.dumps(cmd_dict) - rep = await nc.request("comms.command", + rep = await nc.request(f"comms.command.{config.MODULE_IDENTITY}", cmd.encode(), timeout=2) parameters = json.loads(rep.data) print(parameters) @@ -20,5 +21,4 @@ async def main(): if __name__ == '__main__': loop = asyncio.get_event_loop() loop.run_until_complete(main()) - loop.run_forever() - loop.close() \ No newline at end of file + loop.close() diff --git a/modules/sc-mesh-secure-deployment/src/nats/scripts/cli_radio_on.py b/modules/sc-mesh-secure-deployment/src/nats/scripts/cli_radio_on.py index b2d01af81..7c875ac2d 100644 --- a/modules/sc-mesh-secure-deployment/src/nats/scripts/cli_radio_on.py +++ b/modules/sc-mesh-secure-deployment/src/nats/scripts/cli_radio_on.py @@ -1,15 +1,16 @@ import asyncio -import nats +import client import json import config + async def main(): # Connect to NATS! - nc = await nats.connect(f"{config.MODULE_IP}:{config.MODULE_PORT}") + nc = await client.connect_nats() - cmd_dict = {"api_version": 1,"cmd": "UP"} + cmd_dict = {"api_version": 1, "cmd": "UP"} cmd = json.dumps(cmd_dict) - rep = await nc.request("comms.command", + rep = await nc.request(f"comms.command.{config.MODULE_IDENTITY}", cmd.encode(), timeout=2) parameters = json.loads(rep.data) print(parameters) @@ -20,5 +21,4 @@ async def main(): if __name__ == '__main__': loop = asyncio.get_event_loop() loop.run_until_complete(main()) - loop.run_forever() - loop.close() \ No newline at end of file + loop.close() diff --git a/modules/sc-mesh-secure-deployment/src/nats/scripts/cli_settings_apply.py b/modules/sc-mesh-secure-deployment/src/nats/scripts/cli_settings_apply.py index 2be1c2f71..bbd7001e3 100644 --- a/modules/sc-mesh-secure-deployment/src/nats/scripts/cli_settings_apply.py +++ b/modules/sc-mesh-secure-deployment/src/nats/scripts/cli_settings_apply.py @@ -1,18 +1,17 @@ import asyncio -import nats +import client import json import config async def main(): # Connect to NATS! - nc = await nats.connect(f"{config.MODULE_IP}:{config.MODULE_PORT}") + nc = await client.connect_nats() cmd_dict = {"api_version": 1, "cmd": "APPLY"} cmd = json.dumps(cmd_dict) - rep = await nc.request("comms.command", - cmd.encode(), - timeout=5) + rep = await nc.request(f"comms.command.{config.MODULE_IDENTITY}", + cmd.encode(), timeout=10) parameters = json.loads(rep.data) print(parameters) @@ -23,5 +22,4 @@ async def main(): if __name__ == '__main__': loop = asyncio.get_event_loop() loop.run_until_complete(main()) - loop.run_forever() loop.close() diff --git a/modules/sc-mesh-secure-deployment/src/nats/scripts/cli_settings_request.py b/modules/sc-mesh-secure-deployment/src/nats/scripts/cli_settings_request.py index de850df06..eb37c41f6 100644 --- a/modules/sc-mesh-secure-deployment/src/nats/scripts/cli_settings_request.py +++ b/modules/sc-mesh-secure-deployment/src/nats/scripts/cli_settings_request.py @@ -1,19 +1,19 @@ import asyncio -import nats +import client import json import config async def main(): # Connect to NATS! - nc = await nats.connect(f"{config.MODULE_IP}:{config.MODULE_PORT}") + nc = await client.connect_nats() cmd_dict = {"api_version": 1, "ssid": "test_mesh", "key": "1234567890", "ap_mac": "00:11:22:33:44:55", "country": "FI", "frequency": "5220", "frequency_mcc": "2412", "routing": "batman-adv", "priority": "long_range", "ip": "192.168.1.2", "subnet": "255.255.255.0", "tx_power": "5", "mode": "mesh", "role": f"{config.MODULE_ROLE}"} cmd = json.dumps(cmd_dict) - rep = await nc.request("comms.settings", + rep = await nc.request(f"comms.settings.{config.MODULE_IDENTITY}", cmd.encode(), timeout=2) parameters = json.loads(rep.data) @@ -26,5 +26,4 @@ async def main(): if __name__ == '__main__': loop = asyncio.get_event_loop() loop.run_until_complete(main()) - loop.run_forever() loop.close() diff --git a/modules/sc-mesh-secure-deployment/src/nats/scripts/cli_settings_request_csa.py b/modules/sc-mesh-secure-deployment/src/nats/scripts/cli_settings_request_csa.py index d0fce2c43..8b3b61837 100644 --- a/modules/sc-mesh-secure-deployment/src/nats/scripts/cli_settings_request_csa.py +++ b/modules/sc-mesh-secure-deployment/src/nats/scripts/cli_settings_request_csa.py @@ -1,13 +1,12 @@ import asyncio -import nats +import client import json -import config async def main(): # Connect to NATS! - nc = await nats.connect(f"{config.MODULE_IP}:{config.MODULE_PORT}") - cmd_dict = {"frequency": "2472", "delay": "1", "amount": "2"} + nc = await client.connect_nats() + cmd_dict = {"frequency": "2412", "delay": "1", "amount": "2"} cmd = json.dumps(cmd_dict) rep = await nc.publish("comms.settings_csa", cmd.encode()) print(f"Published to comms.settings_csa: {cmd} ({rep})") @@ -18,5 +17,4 @@ async def main(): if __name__ == '__main__': loop = asyncio.get_event_loop() loop.run_until_complete(main()) - loop.run_forever() loop.close() diff --git a/modules/sc-mesh-secure-deployment/src/nats/scripts/cli_settings_request_mcc.py b/modules/sc-mesh-secure-deployment/src/nats/scripts/cli_settings_request_mcc.py index 129606e3f..841802a75 100644 --- a/modules/sc-mesh-secure-deployment/src/nats/scripts/cli_settings_request_mcc.py +++ b/modules/sc-mesh-secure-deployment/src/nats/scripts/cli_settings_request_mcc.py @@ -1,19 +1,19 @@ import asyncio -import nats +import client import json import config async def main(): # Connect to NATS! - nc = await nats.connect(f"{config.MODULE_IP}:{config.MODULE_PORT}") + nc = await client.connect_nats() cmd_dict = {"api_version": 1, "ssid": "test_mesh", "key": "1234567890", "ap_mac": "00:11:22:33:44:55", "country": "FI", "frequency": "5220", "frequency_mcc": "2412", "routing": "batman-adv", "priority": "long_range", "ip": "192.168.1.2", "subnet": "255.255.255.0", "tx_power": "5", "mode": "ap+mesh_mcc", "role": f"{config.MODULE_ROLE}"} cmd = json.dumps(cmd_dict) - rep = await nc.request("comms.settings", + rep = await nc.request(f"comms.settings.{config.MODULE_IDENTITY}", cmd.encode(), timeout=2) parameters = json.loads(rep.data) @@ -26,5 +26,4 @@ async def main(): if __name__ == '__main__': loop = asyncio.get_event_loop() loop.run_until_complete(main()) - loop.run_forever() loop.close() diff --git a/modules/sc-mesh-secure-deployment/src/nats/scripts/cli_settings_request_scc.py b/modules/sc-mesh-secure-deployment/src/nats/scripts/cli_settings_request_scc.py index 3ff3f1ce2..4bff03be5 100644 --- a/modules/sc-mesh-secure-deployment/src/nats/scripts/cli_settings_request_scc.py +++ b/modules/sc-mesh-secure-deployment/src/nats/scripts/cli_settings_request_scc.py @@ -1,19 +1,19 @@ import asyncio -import nats +import client import json import config async def main(): # Connect to NATS! - nc = await nats.connect(f"{config.MODULE_IP}:{config.MODULE_PORT}") + nc = await client.connect_nats() cmd_dict = {"api_version": 1, "ssid": "test_mesh", "key": "1234567890", "ap_mac": "00:11:22:33:44:55", "country": "FI", "frequency": "5220", "frequency_mcc": "2412", "routing": "batman-adv", "priority": "long_range", "ip": "192.168.1.2", "subnet": "255.255.255.0", "tx_power": "5", "mode": "ap+mesh_scc", "role": f"{config.MODULE_ROLE}"} cmd = json.dumps(cmd_dict) - rep = await nc.request("comms.settings", + rep = await nc.request(f"comms.settings.{config.MODULE_IDENTITY}", cmd.encode(), timeout=2) parameters = json.loads(rep.data) @@ -26,5 +26,4 @@ async def main(): if __name__ == '__main__': loop = asyncio.get_event_loop() loop.run_until_complete(main()) - loop.run_forever() loop.close() diff --git a/modules/sc-mesh-secure-deployment/src/nats/scripts/cli_settings_revoke.py b/modules/sc-mesh-secure-deployment/src/nats/scripts/cli_settings_revoke.py index f3e14a82d..a3767f215 100644 --- a/modules/sc-mesh-secure-deployment/src/nats/scripts/cli_settings_revoke.py +++ b/modules/sc-mesh-secure-deployment/src/nats/scripts/cli_settings_revoke.py @@ -1,15 +1,16 @@ import asyncio -import nats +import client import json import config + async def main(): # Connect to NATS! - nc = await nats.connect(f"{config.MODULE_IP}:{config.MODULE_PORT}") + nc = await client.connect_nats() cmd_dict = {"api_version": 1, "cmd": "REVOKE"} cmd = json.dumps(cmd_dict) - rep = await nc.request("comms.command", + rep = await nc.request(f"comms.command.{config.MODULE_IDENTITY}", cmd.encode(), timeout=4) parameters = json.loads(rep.data) @@ -21,5 +22,4 @@ async def main(): if __name__ == '__main__': loop = asyncio.get_event_loop() loop.run_until_complete(main()) - loop.run_forever() - loop.close() \ No newline at end of file + loop.close() diff --git a/modules/sc-mesh-secure-deployment/src/nats/scripts/cli_status_check.py b/modules/sc-mesh-secure-deployment/src/nats/scripts/cli_status_check.py index f63a8111c..947bf3af1 100644 --- a/modules/sc-mesh-secure-deployment/src/nats/scripts/cli_status_check.py +++ b/modules/sc-mesh-secure-deployment/src/nats/scripts/cli_status_check.py @@ -1,15 +1,15 @@ import asyncio -import nats +import client import json import config async def main(): # Connect to NATS! - nc = await nats.connect(f"{config.MODULE_IP}:{config.MODULE_PORT}") + nc = await client.connect_nats() cmd_dict = {"api_version": 1} cmd = json.dumps(cmd_dict) - rep = await nc.request("comms.status", + rep = await nc.request(f"comms.status.{config.MODULE_IDENTITY}", cmd.encode(), timeout=1) parameters = json.loads(rep.data) @@ -22,5 +22,4 @@ async def main(): if __name__ == '__main__': loop = asyncio.get_event_loop() loop.run_until_complete(main()) - loop.run_forever() - loop.close() \ No newline at end of file + loop.close() diff --git a/modules/sc-mesh-secure-deployment/src/nats/scripts/cli_subscribe_settings_csa.py b/modules/sc-mesh-secure-deployment/src/nats/scripts/cli_subscribe_settings_csa.py index 0a86ee3d2..ba3ac16a6 100644 --- a/modules/sc-mesh-secure-deployment/src/nats/scripts/cli_subscribe_settings_csa.py +++ b/modules/sc-mesh-secure-deployment/src/nats/scripts/cli_subscribe_settings_csa.py @@ -2,7 +2,8 @@ import signal import json from nats.aio.client import Client as NATS -import config +import client + async def run(loop): nc = NATS() @@ -17,7 +18,7 @@ async def closed_cb(): await nc.close() async def reconnected_cb(): - print(f"Connected to NATS ...") + print("Connected to NATS ...") async def subscribe_handler(msg): subject = msg.subject @@ -27,10 +28,9 @@ async def subscribe_handler(msg): subject=subject, reply=reply, data=data)) try: - await nc.connect(f"nats://{config.MODULE_IP}:{config.MODULE_PORT}", - reconnected_cb=reconnected_cb, - closed_cb=closed_cb, - max_reconnect_attempts=-1) + await client.connect(nc, reconnected_cb=reconnected_cb, + closed_cb=closed_cb, + max_reconnect_attempts=-1) except Exception as e: print(e) @@ -53,4 +53,4 @@ def signal_handler(): try: loop.run_forever() finally: - loop.close() \ No newline at end of file + loop.close() diff --git a/modules/sc-mesh-secure-deployment/src/nats/scripts/cli_visual_start.py b/modules/sc-mesh-secure-deployment/src/nats/scripts/cli_visual_start.py index 6c396a4dd..3bc777a0b 100644 --- a/modules/sc-mesh-secure-deployment/src/nats/scripts/cli_visual_start.py +++ b/modules/sc-mesh-secure-deployment/src/nats/scripts/cli_visual_start.py @@ -1,14 +1,15 @@ import asyncio -import nats +import client import json import config + async def main(): # Connect to NATS! - nc = await nats.connect(f"{config.MODULE_IP}:{config.MODULE_PORT}") + nc = await client.connect_nats() cmd_dict = {"api_version": 1, "cmd": "ENABLE_VISUALISATION", "interval": "1000"} cmd = json.dumps(cmd_dict) - rep = await nc.request("comms.command", + rep = await nc.request(f"comms.command.{config.MODULE_IDENTITY}", cmd.encode(), timeout=2) parameters = json.loads(rep.data.decode()) @@ -19,5 +20,4 @@ async def main(): if __name__ == '__main__': loop = asyncio.get_event_loop() loop.run_until_complete(main()) - loop.run_forever() - loop.close() \ No newline at end of file + loop.close() diff --git a/modules/sc-mesh-secure-deployment/src/nats/scripts/cli_visual_stop.py b/modules/sc-mesh-secure-deployment/src/nats/scripts/cli_visual_stop.py index a998ce6f9..3a98f87a7 100644 --- a/modules/sc-mesh-secure-deployment/src/nats/scripts/cli_visual_stop.py +++ b/modules/sc-mesh-secure-deployment/src/nats/scripts/cli_visual_stop.py @@ -1,15 +1,15 @@ import asyncio -import nats +import client import json import config async def main(): # Connect to NATS! - nc = await nats.connect(f"{config.MODULE_IP}:{config.MODULE_PORT}") + nc = await client.connect_nats() cmd_dict = {"api_version": 1, "cmd": "DISABLE_VISUALISATION", "interval": "1000"} cmd = json.dumps(cmd_dict) - rep = await nc.request("comms.command", + rep = await nc.request(f"comms.command.{config.MODULE_IDENTITY}", cmd.encode(), timeout=2) parameters = json.loads(rep.data) @@ -21,5 +21,4 @@ async def main(): if __name__ == '__main__': loop = asyncio.get_event_loop() loop.run_until_complete(main()) - loop.run_forever() loop.close() diff --git a/modules/sc-mesh-secure-deployment/src/nats/scripts/cli_visual_subscribe.py b/modules/sc-mesh-secure-deployment/src/nats/scripts/cli_visual_subscribe.py index 1e70fc1d2..557ff72a8 100644 --- a/modules/sc-mesh-secure-deployment/src/nats/scripts/cli_visual_subscribe.py +++ b/modules/sc-mesh-secure-deployment/src/nats/scripts/cli_visual_subscribe.py @@ -2,7 +2,8 @@ import signal import json from nats.aio.client import Client as NATS -import config +import client + async def run(loop): nc = NATS() @@ -17,7 +18,7 @@ async def closed_cb(): await nc.close() async def reconnected_cb(): - print(f"Connected to NATS ...") + print("Connected to NATS ...") async def subscribe_handler(msg): subject = msg.subject @@ -27,10 +28,9 @@ async def subscribe_handler(msg): subject=subject, reply=reply, data=data)) try: - await nc.connect(f"nats://{config.MODULE_IP}:{config.MODULE_PORT}", - reconnected_cb=reconnected_cb, - closed_cb=closed_cb, - max_reconnect_attempts=-1) + await client.connect(nc, reconnected_cb=reconnected_cb, + closed_cb=closed_cb, + max_reconnect_attempts=-1) except Exception as e: print(e) @@ -45,7 +45,7 @@ def signal_handler(): for sig in ('SIGINT', 'SIGTERM'): loop.add_signal_handler(getattr(signal, sig), signal_handler) - await nc.subscribe("comms.visual", "", cb=subscribe_handler) + await nc.subscribe("comms.visual.*", "", cb=subscribe_handler) if __name__ == '__main__': loop = asyncio.get_event_loop() @@ -53,4 +53,4 @@ def signal_handler(): try: loop.run_forever() finally: - loop.close() \ No newline at end of file + loop.close() diff --git a/modules/sc-mesh-secure-deployment/src/nats/scripts/client.py b/modules/sc-mesh-secure-deployment/src/nats/scripts/client.py new file mode 100644 index 000000000..52366244f --- /dev/null +++ b/modules/sc-mesh-secure-deployment/src/nats/scripts/client.py @@ -0,0 +1,51 @@ +import ssl +import os +import nats +from nats.aio.client import Client +import config + +client_cert = "/etc/ssl/certs/comms_auth_cert.pem" +key = "/etc/ssl/private/comms_auth_private_key.pem" +ca_cert = "/etc/ssl/certs/root-ca.cert.pem" + + +async def connect_nats(): + if os.path.exists(client_cert) and \ + os.path.exists(key) and \ + os.path.exists(ca_cert): + + ssl_context = ssl.create_default_context() + ssl_context.load_cert_chain(certfile=client_cert, keyfile=key) + ssl_context.load_verify_locations(cafile=ca_cert) + + nats_client = await nats.connect(f"{config.MODULE_IP}:{config.MODULE_PORT}", + tls=ssl_context) + + else: + nats_client = await nats.connect(f"{config.MODULE_IP}:{config.MODULE_PORT}") + + return nats_client + + +async def connect(nats_client: Client, recon_cb=None, closed_cb=None, max_recon_attempts=None): + if os.path.exists(client_cert) and \ + os.path.exists(key) and \ + os.path.exists(ca_cert): + + ssl_context = ssl.create_default_context() + ssl_context.load_cert_chain(certfile=client_cert, keyfile=key) + ssl_context.load_verify_locations(cafile=ca_cert) + + await nats_client.connect(f"{config.MODULE_IP}:{config.MODULE_PORT}", + tls=ssl_context, + reconnected_cb=recon_cb, + closed_cb=closed_cb, + max_reconnect_attempts=max_recon_attempts) + + else: + await nats_client.connect(f"{config.MODULE_IP}:{config.MODULE_PORT}", + reconnected_cb=recon_cb, + closed_cb=closed_cb, + max_reconnect_attempts=max_recon_attempts) + + return None diff --git a/modules/sc-mesh-secure-deployment/src/nats/scripts/config.py b/modules/sc-mesh-secure-deployment/src/nats/scripts/config.py index 7afb6144a..db201fd58 100644 --- a/modules/sc-mesh-secure-deployment/src/nats/scripts/config.py +++ b/modules/sc-mesh-secure-deployment/src/nats/scripts/config.py @@ -1,3 +1,19 @@ -MODULE_IP = "192.168.1.3" +""" +This file contains the configuration for the NATS client. +""" +try: + from identity import MODULE_IDENTITY as IDENTITY +except ImportError: + IDENTITY = "no_identity" + print("No identity found!!!!!!!!!!!!!!!!!!!!!!!!") + print("Please run _cli_command_get_identity.py to get the identity (creates identity.py)") + +MODULE_IP = "10.10.10.2" # or 192.168.1.x - brlan ip address MODULE_PORT = "4222" + +# IPv6 connection example +# MODULE_IP = "[2001:db8:1234:5678:2e0:4cff:fe68:7a73]" +# MODULE_PORT = "6222" + MODULE_ROLE = "drone" # drone, sleeve or gcs +MODULE_IDENTITY = IDENTITY # messages are sent to this device diff --git a/modules/sc-mesh-secure-deployment/src/nats/src/comms_command.py b/modules/sc-mesh-secure-deployment/src/nats/src/comms_command.py index f87d19d7d..f0d8bc723 100644 --- a/modules/sc-mesh-secure-deployment/src/nats/src/comms_command.py +++ b/modules/sc-mesh-secure-deployment/src/nats/src/comms_command.py @@ -43,7 +43,7 @@ class ConfigFiles: # pylint: disable=too-few-public-methods IDENTITY = "/opt/identity" -class Command: # pylint: disable=too-few-public-methods +class Command: # pylint: disable=too-few-public-methods, too-many-instance-attributes """ Command class """ @@ -111,7 +111,7 @@ def handle_command(self, msg: str, cc, csa=False, delay="0") -> (str, str, str): elif self.command == COMMAND.get_config: ret, info, data = self.__get_configs(self.param) elif self.command == COMMAND.get_identity: - ret, info, data = self.__get_identity() + ret, info, data = self.get_identity() else: ret, info = "FAIL", "Command not supported" return ret, info, data @@ -347,7 +347,7 @@ def __get_configs(self, param) -> (str, str, str): else: return "OK", f"{param}", file_b64.decode() - def __get_identity(self) -> (str, str, dict): + def get_identity(self) -> (str, str, dict): identity_dict = {} try: files = ConfigFiles() @@ -361,5 +361,5 @@ def __get_identity(self) -> (str, str, dict): except: return "FAIL", "Not able to get identity file", None - self.logger.debug("__get_identity done") + self.logger.debug("get_identity done") return "OK", "Identity and NATS URL", identity_dict \ No newline at end of file diff --git a/modules/sc-mesh-secure-deployment/src/nats/src/comms_hsm_controller.py b/modules/sc-mesh-secure-deployment/src/nats/src/comms_hsm_controller.py index 753d37d9a..38c70d2e6 100644 --- a/modules/sc-mesh-secure-deployment/src/nats/src/comms_hsm_controller.py +++ b/modules/sc-mesh-secure-deployment/src/nats/src/comms_hsm_controller.py @@ -30,6 +30,9 @@ from Crypto.Util.Padding import unpad from Crypto.Cipher import AES +# To get IP addresses for CSR SAN extension +import netifaces as ni + class CommsHSMController: """ @@ -43,6 +46,7 @@ def __init__(self, base_dir: str, board_version: float) -> None: self.__user_pin_file = self.__base_dir + "/hsm/user_pin" self.__so_pin_file = self.__base_dir + "/hsm/so_pin" self.__comms_board_version = board_version + self.__ip_addresses = ["10.10.10.2", "10.10.20.2"] self.__use_soft_hsm = False self.__login_required = False # CKF_LOGIN_REQUIRED @@ -52,10 +56,11 @@ def __init__(self, base_dir: str, board_version: float) -> None: self.__token_has_rng = False # CKF_RNG (Random Number Generator) # Used with SoftHSM and related PIN encryption - self.__token_label = "secccoms" + self.__token_label = "nats_token" self.__current_token_label = "" - # Define configuration file for openssl + # Define configuration files for openssl + self.__csr_cnf_file = self.__base_dir + "/csr.cnf" os.environ['OPENSSL_CONF'] = self.__base_dir + "/comms_openssl.cnf" self.__pkcs11_engine = "" @@ -146,8 +151,8 @@ def __recover_pin(self, filename): # Decrypt with AES-256 / CBC / PKCS7 Padding cipher = AES.new(key, AES.MODE_CBC, iv) return unpad(cipher.decrypt(ciphertext), 16).decode().split('\n')[0] - except: - print("No pin found") + except Exception as e: + print("No pin found, error:", e) return "" def __generate_pin(self): @@ -739,7 +744,7 @@ def generate_ec_keypair(self, keypair_id: str, label: str) -> bool: print("Generated EC keypair") return True - def create_csr_via_openssl(self, priv_key_id: str, device_id: str, filename: str): + def create_csr_via_openssl(self, priv_key_id: str, device_id: str, filename: str, is_server=False): """ Create a Certificate Signing Request (CSR) using OpenSSL with the provided arguments. @@ -748,6 +753,7 @@ def create_csr_via_openssl(self, priv_key_id: str, device_id: str, filename: str priv_key_id (str) -- Identifier number for the private key. subject (str) -- Subject to be used in CSR. filename (str) -- Output file name. + is_server (bool) -- Defines for what role CSR is created, Returns: bool: True if CSR generation is successful, False otherwise. @@ -759,23 +765,22 @@ def create_csr_via_openssl(self, priv_key_id: str, device_id: str, filename: str os.environ['PKCS11_PIN'] = self.__token_user_pin subject = "/CN=" + device_id - # SoftHSM engine used by default - pkcs_engine = 'pkcs11' - if not self.__use_soft_hsm: - pkcs_engine = 'e4sss' + # Generate CSR config file + self.__generate_csr_config_file(subject, is_server) command = [ 'openssl', 'req', '-new', - '-engine', pkcs_engine, + '-engine', self.__pkcs11_engine, '-keyform', 'engine', '-key', str(priv_key_id), '-passin', 'env:PKCS11_PIN', '-out', str(filename), - '-subj', str(subject) + '-subj', str(subject), + '-config', self.__csr_cnf_file ] - + print(command) # Run the command result = subprocess.run(command, capture_output=True, text=True) @@ -876,7 +881,148 @@ def __get_public_key_data(self, object_handle): return public_key, parameters, key_type - def create_csr(self, priv_key_id: str, device_id: str, filename: str): + def __create_csr_info(self, subject, public_key, algorithm_name, pub_key_params, is_server): + """ + Creates the csr info to be used in certificate request. + + Arguments: + subject (str) -- Identifies to be used as common name identifier. + public_key (any) -- Public key object instance + algorithm_name -- Algorithm name to be used in CSR + pub_key_params -- Public key paramters + is_server (bool) -- Boolean value to define is the CSR for server or + client usage. + Returns: + Handle to publick key object or None + + """ + x509_subject = x509.Name.build({"common_name": subject}) + + if is_server: + key_usage = "server_auth" + else: + key_usage = "client_auth" + + extensions = [("basic_constraints", + x509.BasicConstraints({"ca": False}), + False), + ("key_usage", + x509.KeyUsage({"digital_signature", + "key_encipherment"}), + True), + ("extended_key_usage", + x509.ExtKeyUsageSyntax([key_usage]), + False)] + + if is_server: + ip_list = self.__get_ip_addresses() + names = x509.GeneralNames() + for ip in ip_list: + names.append(x509.GeneralName("ip_address", ip)) + extensions.append(("subject_alt_name", names, False)) + + csr_info = csr.CertificationRequestInfo({ + "version": "v1", + "subject": x509_subject, + "subject_pk_info": { + "algorithm": { + "algorithm": algorithm_name, + "parameters": pub_key_params, + }, + "public_key": public_key + }, + "attributes": [{ + "type": "extension_request", + "values": [[self.__create_extension(x) for x in extensions]] + }] + }) + return csr_info + + def __create_extension(self, extension): + """ + Creates an ASN.1 certificate request extension structure + + Arguments: + extension (tuple): -- Tuple with three values: name of the + extension (str), value for the extension(str) + and boolean to express if extension should be + considered critical. + Returns: + Extension dictionary. + """ + name, value, critical = extension + return {"extn_id": name, + "extn_value": value, + "critical": critical} + + def __get_ip_addresses(self): + """ + Creates a list of IP addresses from device network interfaces. + + Arguments: + server (bool) -- Boolean value to determine is CSR meant + for server or client usage. + Returns: + List of IP addresses + """ + ip_address_set = set(self.__ip_addresses) + + interfaces = ni.interfaces() + for interface in interfaces: + addrs = ni.ifaddresses(interface) + + if ni.AF_INET in addrs: + ipv4_addresses = [addr['addr'] for addr in addrs[ni.AF_INET]] + for ipv4_addr in ipv4_addresses: + if ipv4_addr not in ip_address_set: + self.__ip_addresses.append(ipv4_addr) + + if ni.AF_INET6 in addrs: + ipv6_addresses = [addr['addr'].split('%')[0] for addr in addrs[ni.AF_INET6]] + for ipv6_addr in ipv6_addresses: + if ipv6_addr not in ip_address_set: + self.__ip_addresses.append(ipv6_addr) + + # Remove any Docker interfaces + self.__ip_addresses = [ip for ip in self.__ip_addresses if not ip.startswith("docker")] + + return self.__ip_addresses + + def __generate_csr_config_file(self, subject: str, server=False): + """ + Creates a configurarion file for CSR creation via openssl + + Arguments: + server (bool) -- Boolean value to determine is CSR meant + for server or client usage. + Returns: + bool: True if the config file was generated, False otherwise. + """ + try: + with open(self.__csr_cnf_file, 'w') as file: + file.write("[ req ]\n") + file.write("distinguished_name=req_distinguished_name\n") + file.write("req_extensions = v3_req\n") + file.write("\n[ req_distinguished_name ]\n") + file.write(subject+"\n") + file.write("\n[ v3_req ]\n") + file.write("basicConstraints = CA:FALSE\n") + file.write("keyUsage=digitalSignature, keyEncipherment\n") + if not server: + file.write("extendedKeyUsage = clientAuth\n") + else: + ip_addresses = self.__get_ip_addresses() + file.write("extendedKeyUsage = serverAuth\n") + file.write("subjectAltName = @alt_names\n") + file.write("\n[alt_names]\n") + for i, ip_address in enumerate(ip_addresses, start=1): + file.write(f"IP.{i} = {ip_address}\n") + return True + except Exception as e: + print("Error creating CSR config file:", e) + return False + + def create_csr(self, priv_key_id: str, device_id: str, filename: str, server=False): """ Create a Certificate Signing Request (CSR) using asn1crypto library. CSR is signed via HSM. @@ -885,7 +1031,8 @@ def create_csr(self, priv_key_id: str, device_id: str, filename: str): priv_key_id (str) -- Identifier number for the private key. subject (str) -- Subject to be used in CSR. filename (str) -- Output file name. - + server (bool) -- Boolean value to define is this CSR for server + or client usage. Returns: bool: True if CSR generation is successful, False otherwise. """ @@ -903,8 +1050,7 @@ def create_csr(self, priv_key_id: str, device_id: str, filename: str): if public_key_handle is None: return False - subject = x509.Name.build({"common_name": device_id}) - public_key, parameters, key_type = self.__get_public_key_data(public_key_handle) + public_key, pub_key_params, key_type = self.__get_public_key_data(public_key_handle) if key_type == PyKCS11.LowLevel.CKK_RSA: public_key = keys.RSAPublicKey.load(public_key) @@ -916,18 +1062,7 @@ def create_csr(self, priv_key_id: str, device_id: str, filename: str): algorithm_name = "ec" sign_algorithm_name = "ecdsa" - csr_info = csr.CertificationRequestInfo({ - 'version': 0, - 'subject': subject, - 'subject_pk_info': { - 'algorithm': { - 'algorithm': algorithm_name, - 'parameters': parameters, - }, - 'public_key': public_key - }, - 'attributes': csr.CRIAttributes([]) - }) + csr_info = self.__create_csr_info(device_id, public_key, algorithm_name, pub_key_params, server) # Sign the CSR Info signature = self.__pkcs11_session.sign(private_key_handle, diff --git a/modules/sc-mesh-secure-deployment/src/nats/src/comms_provisioning.py b/modules/sc-mesh-secure-deployment/src/nats/src/comms_provisioning.py index 318c34186..1788bc85d 100644 --- a/modules/sc-mesh-secure-deployment/src/nats/src/comms_provisioning.py +++ b/modules/sc-mesh-secure-deployment/src/nats/src/comms_provisioning.py @@ -15,8 +15,11 @@ import argparse import asyncio -import requests import os +import requests +import time +from hw_control import LedControl + from cryptography.x509 import load_pem_x509_certificate @@ -31,13 +34,15 @@ def __init__(self, server: str = "localhost", port: str = "80", outdir: str = "/ self.__session = None self.__outdir = outdir self.__device_id_file = self.__outdir + "/identity" - self.__csr_file = self.__outdir + "/csr/prov_csr.csr" + self.__client_csr_file = self.__outdir + "/csr/client_csr.csr" + self.__server_csr_file = self.__outdir + "/csr/server_csr.csr" self.__server_url = "http://" + server + ":" + port + "/api/mesh/provision" self.__device_id = self.__get_device_id() self.__pcb_version = self.__get_comms_pcb_version("/opt/hardware/comms_pcb_version") self.__auth_key_id = "99887766" self.__auth_key_label = "CommsDeviceAuth" + self.__server_cert_label = "CommsServerCert" self.__hsm_ctrl = comms_hsm_controller.CommsHSMController( self.__outdir, self.__pcb_version) self.__session = self.__hsm_ctrl.open_session() @@ -81,35 +86,43 @@ def do_provisioning(self): Returns: bool: True if provisioning is ready, False otherwise. """ - if self.__hsm_ctrl.get_certificate(self.__auth_key_id, self.__auth_key_label) is None: + if self.__hsm_ctrl.get_certificate(self.__auth_key_id, self.__auth_key_label) is None or \ + self.__hsm_ctrl.get_certificate(self.__auth_key_id, self.__server_cert_label) is None: if not self.__hsm_ctrl.has_private_key(self.__auth_key_id, self.__auth_key_label): - self.__hsm_ctrl.generate_rsa_keypair_via_openssl(self.__auth_key_id, self.__auth_key_label) - # self.__hsm_ctrl.generate_ec_keypair_via_openssl(self.__auth_key_id, self.__auth_key_label) - # self.__hsm_ctrl.generate_rsa_keypair(self.__auth_key_id, self.__auth_key_label) - # self.__hsm_ctrl.generate_ec_keypair(self.__auth_key_id, self.__auth_key_label) - - # Create certificate signing request - # csr_created = self.__hsm_ctrl.create_csr_via_openssl(priv_key_id=self.__auth_key_id, - # device_id=self.__device_id, - # filename=self.__csr_file) - csr_created = self.__hsm_ctrl.create_csr(priv_key_id=self.__auth_key_id, - device_id=self.__device_id, - filename=self.__csr_file) - if csr_created: - # TODO: Works with RSA key only for now. - # EC key needs to be tested once with such - # provisioning server that uses EC keys also. - req_status = self.__request_certificate() - return req_status + self.__hsm_ctrl.generate_rsa_keypair_via_openssl(self.__auth_key_id, + self.__auth_key_label) + # Create certificate signing request to get client certificate + client_csr_created = self.__hsm_ctrl.create_csr(priv_key_id=self.__auth_key_id, + device_id=self.__device_id, + filename=self.__client_csr_file) + if not client_csr_created: + print("Problem creating client CSR") + return False else: + client_req_status = self.__request_client_certificate() + if not client_req_status: + print("Problem getting client certificate") + return False + + # Create certificate signing request to get server certificate + server_csr_created = self.__hsm_ctrl.create_csr(priv_key_id=self.__auth_key_id, + device_id=self.__device_id, + filename=self.__server_csr_file, + server=True) + if not server_csr_created: + print("Problem creating server CSR") return False + else: + server_req_status = self.__request_server_certificate() + if not server_req_status: + print("Problem getting server certificate") + return False else: return True - def __request_certificate(self): - with open(self.__csr_file, 'rb') as file: + def __request_client_certificate(self): + with open(self.__client_csr_file, 'rb') as file: csr = file.read() - try: # Post the CSR to the signing server url = self.__server_url @@ -126,14 +139,6 @@ def __request_certificate(self): signed_certificate_data = response_json["certificate"] ca_certificate = response_json["caCertificate"] - print("### Printing received certificates from server: ###") - print(signed_certificate_data) - print("### Finished printing received certificates from server: ###") - - print("### Printing root CA certificate from server ###") - print(ca_certificate) - print("### Finished printing CA certificate from server ###") - # Save received certificates into filesystem certificate = "/etc/ssl/certs/comms_auth_cert.pem" root_certificate = "/etc/ssl/certs/root-ca.cert.pem" @@ -193,28 +198,103 @@ def __request_certificate(self): saved = self.__hsm_ctrl.save_certificate(certificate, self.__auth_key_id, label) return saved + def __request_server_certificate(self): + with open(self.__server_csr_file, 'rb') as file: + csr = file.read() + try: + # Post the CSR to the signing server + url = self.__server_url + headers = {"Content-Type": "application/json"} + payload = {"csr": csr.decode("utf-8")} + response = requests.post(url, headers=headers, json=payload, timeout=5) + response.raise_for_status() + except requests.exceptions.RequestException as e: + print(f"Connection error: {e}") + return False + + # Extract the signed certificate from the response + response_json = response.json() + signed_certificate_data = response_json["certificate"] + + # Save received certificates into filesystem + certificate = "/etc/ssl/certs/comms_server_cert.pem" + + # Create directories if they don't exist + os.makedirs(os.path.dirname(certificate), exist_ok=True) + + # Open the file in write mode and write the PEM data + try: + with open(certificate, "w") as file: + file.write(signed_certificate_data) + except Exception as e: + print(f"Error saving server certificate file: {str(e)}") + + # Save certificate to HSM: + signed_certificates = signed_certificate_data.split('-----END CERTIFICATE-----\n') + + # Remove any empty strings from the split operation + signed_certificates = [cert.strip() for cert in signed_certificates if cert.strip()] + + for index, cert_data in enumerate(signed_certificates): + cert_data = cert_data + "\n-----END CERTIFICATE-----\n" + + try: + # Read certificate as PEM format + certificate = load_pem_x509_certificate(cert_data.encode("utf-8")) + except Exception as e: + print("Could not load certificate at index {index}", e) + return False + + device_id = self.__device_id + cert_name = self.__server_cert_label + + if device_id in certificate.subject.rfc4514_string(): + label = cert_name + else: + label = cert_name + " Issuer " + str(index) + saved = self.__hsm_ctrl.save_certificate(certificate, self.__auth_key_id, label) + if not saved: + print("Failed to store certificate with label {label} in index {index}") + return False + return saved + -async def main(server, port, outdir): +async def main(server, port, outdir, timeout): """ main """ + led_status = LedControl() + led_status.provisioning_led_control("start") + # start time of provisioning + start_time = time.time() + prov_agent = CommsProvisioning(server, port, outdir) while True: + led_status.provisioning_led_control("active") status = prov_agent.do_provisioning() if status: prov_agent.close_session() + led_status.provisioning_led_control("stop") + break + + led_status.provisioning_led_control("start") + time.sleep(10) + # if provisioning takes more than 30 seconds, break out + if time.time() - start_time > int(timeout): + led_status.provisioning_led_control("fail") break - await asyncio.sleep(10) + time.sleep(5) if __name__ == '__main__': parser = argparse.ArgumentParser(description='Provisioning agent settings') parser.add_argument('-s', '--server', help='Provisioning Server IP', required=False) parser.add_argument('-p', '--port', help='Server port', required=False) parser.add_argument('-o', '--outdir', help='Output folder for files', required=False) + parser.add_argument('-t', '--timeout', help='Timeout for provisioning trial', required=False) args = parser.parse_args() loop = asyncio.new_event_loop() - loop.run_until_complete(main(args.server, args.port, args.outdir)) + loop.run_until_complete(main(args.server, args.port, args.outdir, args.timeout)) loop.close() diff --git a/modules/sc-mesh-secure-deployment/src/nats/src/hw_control.py b/modules/sc-mesh-secure-deployment/src/nats/src/hw_control.py new file mode 100644 index 000000000..8d6f0ffcd --- /dev/null +++ b/modules/sc-mesh-secure-deployment/src/nats/src/hw_control.py @@ -0,0 +1,138 @@ +""" +This module is used to control the provisioning LED on the SC Mesh Secure Deployment board. +""" +import os +import sys + +# pylint: disable=too-few-public-methods +class LedControl: + """ + This class is used to control the provisioning LED on the SC Mesh Secure Deployment board. + """ + + def __init__(self): + self.not_supported = False + try: + with open("/etc/comms_pcb_version", "r", encoding="utf8") as version_file: + self.comms_pcb_version = float(version_file.read().split("=")[-1].strip()) + except FileNotFoundError: + self.not_supported = True + + @staticmethod + def _write_to_file(path, file, value): + try: + with open(os.path.join(path, file), "w", encoding="utf-8") as engine: + engine.write(value) + except (FileNotFoundError, PermissionError): + pass + + def _led_control_0(self, state) -> None: + """ + Control the provisioning LED. + :param state: start, active, stop, fail + :return: None + """ + path = "/sys/class/leds/mesh" + trigger_seq = "timer" + trigger_none = "none" + start_seq = 1000 + active_seq = 100 + + trigger_used = "none" + seq_used = 0 + + if state == "start": + seq_used = start_seq + trigger_used = trigger_seq + elif state == "active": + seq_used = active_seq + trigger_used = trigger_seq + elif state == "stop": + seq_used = 0 + trigger_used = trigger_none + elif state == "fail": + seq_used = 1 + trigger_used = trigger_none + + # write to sys class led + self._write_to_file(path, "trigger", trigger_used) + self._write_to_file(path, "delay_off", str(seq_used)) + self._write_to_file(path, "delay_on", str(seq_used)) + self._write_to_file(path, "brightness", str(seq_used)) + + def _led_control_1(self, state) -> None: + """ + Control the provisioning LED. + :param state: start, active, stop, fail + :return: None + """ + path = "/sys/class/leds/rgb_leds:channel0/device" + led_matrix = "000000001" + code_used = "9d094000a001" + + # LASM compiler: + # .segment program1 ; segment begins + # mux_sel 9 ; select led 9 blue + # loop1: set_pwm 2Fh + # wait 0.4 + # wait 0.4 + # wait 0.2 + # set_pwm 00h + # wait 0.4 + # wait 0.4 + # wait 0.2 + # branch 0,loop1 + start_code = "9d09402f740074005a004000740074005a00a001" + # LASM compiler: + # .segment program1 ; segment begins + # mux_sel 9 ; select led 9 blue + # loop1: set_pwm 2Fh + # wait 0.1 + # set_pwm 00h + # wait 0.1 + # branch 0,loop1 + active_code = "9d09402f4c0040004c00a001" + + if state == "start": + code_used = start_code + elif state == "active": + code_used = active_code + elif state == "stop": + # pwm 00 + code_used = "9d094000a001" + elif state == "fail": + # pwm 2f + code_used = "9d09402fa001" + + # write to led driver + self._write_to_file(path, "engine3_mode", "disabled") + self._write_to_file(path, "engine3_mode", "load") + self._write_to_file(path, "engine3_load", code_used) + self._write_to_file(path, "engine3_leds", led_matrix) + self._write_to_file(path, "engine3_mode", "run") + + + def provisioning_led_control(self, state): + """ + Control the provisioning LED. + :param state: start, active, stop, fail + :return: None + """ + if self.not_supported: + return None + + if self.comms_pcb_version == 1: + self._led_control_1(state) + elif self.comms_pcb_version in (0.5, 0): + self._led_control_0(state) + +if __name__ == "__main__": + led = LedControl() + import time + led.provisioning_led_control("start") + time.sleep(5) + led.provisioning_led_control("active") + time.sleep(5) + led.provisioning_led_control("stop") + time.sleep(5) + led.provisioning_led_control("fail") diff --git a/modules/utils/docker/entrypoint_nats.sh b/modules/utils/docker/entrypoint_nats.sh index c81c02da3..77a6dc28d 100755 --- a/modules/utils/docker/entrypoint_nats.sh +++ b/modules/utils/docker/entrypoint_nats.sh @@ -33,17 +33,10 @@ else done echo "starting provisioning agent" - /opt/S90provisioning_agent start - loop_count=0 - while ps aux | grep [c]omms_provisioning >/dev/null; do - sleep 1 - ((loop_count++)) - if [ "$loop_count" -ge 30 ]; then - echo "Stopping provisioning agent due to timeout" - /opt/S90provisioning_agent stop - fi - done - + # blocks execution until provisioning is done or timeout (30s) + # IP address and port are passed as arguments and hardcoded. TODO: mDNS + python /opt/nats/src/comms_provisioning.py -t 30 -s 192.168.1.254 -p 8080 -o /opt > /opt/comms_provisioning.log 2>&1 + echo "Start nats server and client nodes" /opt/S90nats_discovery start