Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CBMA code and test scripts - Original Changes #348

Merged
merged 18 commits into from
Dec 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file.
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
import socket
import ssl
import sys
sys.path.insert(0, '../')
from tools.verification_tools import *
from tools.custom_logger import CustomLogger
from tools.utils import mac_to_ipv6, get_mac_addr
import glob
import random
import time

MAX_RETRIES = 5
MIN_WAIT_TIME = 1 # seconds
MAX_WAIT_TIME = 3 # seconds

logger_instance = CustomLogger("authClient")



class AuthClient:
def __init__(self, interface, server_mac, server_port, cert_path, ca_path, mua):
self.sslServerIP = mac_to_ipv6(server_mac)
self.sslServerPort = server_port
self.CERT_PATH = cert_path
self.interface = interface
self.secure_client_socket = None
self.logger = logger_instance.get_logger()
self.ca = ca_path
self.mymac = get_mac_addr(self.interface)
self.server_mac = server_mac
self.mua = mua

def establish_connection(self):
# Create an SSL context
context = ssl.SSLContext(ssl.PROTOCOL_TLSv1_2)
context.verify_mode = ssl.CERT_REQUIRED

# Uncomment to enable Certificate Revocation List (CRL) check
# context.verify_flags = ssl.VERIFY_CRL_CHECK_LEAF

context.load_verify_locations(glob.glob(self.ca)[0])
context.load_cert_chain(
certfile=glob.glob(f"{self.CERT_PATH}/macsec_{self.mymac.replace(':', '')}.crt")[0],
keyfile=glob.glob(f"{self.CERT_PATH}/macsec_{self.mymac.replace(':', '')}.key")[0],
)

# Detect if the server IP is IPv4 or IPv6 and create a socket accordingly
if ":" in self.sslServerIP:
clientSocket = socket.socket(socket.AF_INET6, socket.SOCK_STREAM)
else:
clientSocket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)

# Make the client socket suitable for secure communication
self.secure_client_socket = context.wrap_socket(clientSocket)
try:
result = self.connection(self.secure_client_socket)
if result['authenticated']:
self.mua.auth_pass(secure_client_socket=self.secure_client_socket, client_mac=self.server_mac)
else:
self.mua.auth_fail(client_mac=self.server_mac)
except Exception as e:
self.logger.error("Define better this exception.", exc_info=True)
self.mua.auth_fail(client_mac=self.server_mac)
# finally:
# # Close the socket
# secureClientSocket.close()

def connection(self, secureClientSocket):
result = {
'IP': self.sslServerIP,
'authenticated': False
}

try:
self.to_validate(secureClientSocket, result)
except Exception as e:
self.logger.error("An error occurred during the connection process.", exc_info=True)

finally:
return result

def to_validate(self, secureClientSocket, result):
# If the IP is a link-local IPv6 address, connect it with the interface index
retries = 0
while retries < MAX_RETRIES:
try:
if self.sslServerIP.startswith("fe80"):
secureClientSocket.connect(
(self.sslServerIP, self.sslServerPort, 0, socket.if_nametoindex(self.interface)))
else:
secureClientSocket.connect((self.sslServerIP, self.sslServerPort))
break # break out of loop if connection is successful
except ConnectionRefusedError:
retries += 1
if retries < MAX_RETRIES:
wait_time = random.uniform(MIN_WAIT_TIME, MAX_WAIT_TIME)
self.logger.info(f"Connection refused. Retrying in {wait_time:.2f} seconds...")
time.sleep(wait_time)
else:
self.logger.error("Exceeded maximum retry attempts. Unable to connect to server.")
#raise ServerConnectionRefusedError("Unable to connect to server socket")
return

server_cert = secureClientSocket.getpeercert(binary_form=True)
if not server_cert:
self.logger.error("Unable to get the server certificate", exc_info=True)
#raise CertificateNoPresentError("Unable to get the server certificate")
return

result['authenticated'] = verify_cert(server_cert, self.ca, self.sslServerIP, self.interface, self.logger)

# # Safe to proceed with the communication, even if the certificate is not authenticated
# msgReceived = secureClientSocket.recv(1024)
# logger.info(f"Secure communication received from server: {msgReceived.decode()}")


if __name__ == "__main__":
# IP address and the port number of the server
sslServerIP = "127.0.0.1"
sslServerPort = 15001
CERT_PATH = '../../../certificates' # Change this to the actual path of your certificates

auth_client = AuthClient(sslServerIP, sslServerPort, CERT_PATH)
auth_client.establish_connection()



Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
import socket
import ssl
import threading
from tools.utils import *
import sys

sys.path.insert(0, '../')
from tools.verification_tools import *
from tools.custom_logger import CustomLogger
from tools.utils import wait_for_interface_to_be_pingable
import glob

logger_instance = CustomLogger("Server")
logger = logger_instance.get_logger()


class AuthServer:
def __init__(self, interface, ip_address, port, cert_path, ca_path, mua):
threading.Thread.__init__(self)
self.running = True
self.ipAddress = ip_address
self.port = port
self.CERT_PATH = cert_path
self.ca = ca_path
self.interface = interface
self.mymac = get_mac_addr(self.interface)
# Create the SSL context here and set it as an instance variable
self.context = ssl.SSLContext(ssl.PROTOCOL_TLSv1_2)
self.context.verify_mode = ssl.CERT_REQUIRED
self.context.load_verify_locations(glob.glob(self.ca)[0])
self.context.load_cert_chain(
certfile=glob.glob(f"{self.CERT_PATH}/macsec_{self.mymac.replace(':', '')}.crt")[0],
keyfile=glob.glob(f"{self.CERT_PATH}/macsec_{self.mymac.replace(':', '')}.key")[0],
)
self.client_auth_results = {}
self.active_sockets = {}
self.client_auth_results_lock = threading.Lock()
self.active_sockets_lock = threading.Lock()
self.mua = mua

def handle_client(self, client_connection, client_address):
client_mac = extract_mac_from_ipv6(client_address[0]) # TODO: check if it is safe to do so
print("------------------server---------------------")
if client_mac not in self.mua.connected_peers_status:
with self.mua.connected_peers_status_lock:
self.mua.connected_peers_status[client_mac] = ["ongoing",0] # Update status as ongoing, num of failed attempts = 0
else:
with self.mua.connected_peers_status_lock:
self.mua.connected_peers_status[client_mac][0] = "ongoing" # Update status as ongoing, num of failed attempts = same as before
self.authenticate_client(client_connection, client_address, client_mac)

def authenticate_client(self, client_connection, client_address, client_mac):
secure_client_socket = self.context.wrap_socket(client_connection, server_side=True)
try:
client_cert = secure_client_socket.getpeercert(binary_form=True)
if not client_cert:
logger.error(f"Unable to get the certificate from the client {client_address[0]}", exc_info=True)
raise CertificateNoPresentError("Unable to get the certificate from the client")

auth = verify_cert(client_cert, self.ca, client_address[0], self.interface, logger)
with self.client_auth_results_lock:
self.client_auth_results[client_address[0]] = auth
if auth:
with self.active_sockets_lock:
self.active_sockets[client_address[0]] = secure_client_socket
self.mua.auth_pass(secure_client_socket=secure_client_socket, client_mac=client_mac)
else:
# Handle the case when authentication fails, maybe send an error message
self.mua.auth_fail(client_mac=client_mac)
# secure_client_socket.sendall(b"Authentication failed.")
except Exception as e:
logger.error(f"An error occurred while handling the client {client_address[0]}.", exc_info=True)
self.mua.auth_fail(client_mac=client_mac)
# finally:
# secure_client_socket.close()

def get_secure_socket(self, client_address):
with self.active_sockets_lock:
return self.active_sockets.get(client_address)

def get_client_auth_result(self, client_address):
with self.client_auth_results_lock:
return self.client_auth_results.get(client_address, None)

def start_server(self):
if is_ipv4(self.ipAddress):
self.serverSocket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self.serverSocket.bind((self.ipAddress, self.port))
elif is_ipv6(self.ipAddress):
self.serverSocket = socket.socket(socket.AF_INET6, socket.SOCK_STREAM)
scope_id = socket.if_nametoindex(self.interface)
self.serverSocket.bind((self.ipAddress, int(self.port), 0, scope_id))
else:
raise ValueError("Invalid IP address")

self.serverSocket.listen()
self.serverSocket.settimeout(99999) # maybe we can remove timeout since server needs to be listening throughout
logger.info("Server listening")

while self.running and not self.mua.shutdown_event.is_set():
try:
client_connection, client_address = self.serverSocket.accept()
threading.Thread(target=self.handle_client, args=(client_connection, client_address)).start()
except socket.timeout: # In case we add a timeout later.
continue
except Exception as e:
if self.running:
logger.error("Unexpected error in server loop.", exc_info=True)

def stop_server(self):
self.running = False
if is_ipv4(self.ipAddress):
serverSocket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
serverSocket.bind((self.ipAddress, self.port))
elif is_ipv6(self.ipAddress):
serverSocket = socket.socket(socket.AF_INET6, socket.SOCK_STREAM)
scope_id = socket.if_nametoindex(self.interface)
serverSocket.bind((self.ipAddress, int(self.port), 0, scope_id))
if hasattr(self, "serverSocket"):
self.serverSocket.close()
for sock in auth_server.active_sockets.values():
sock.close()

if __name__ == "__main__":
# IP address and the port number of the server
ipAddress = "127.0.0.1"
port = 15001
CERT_PATH = '../../../certificates' # Change this to the actual path of your certificates

auth_server = AuthServer(ipAddress, port, CERT_PATH)
auth_server.start_server()

# Access the authentication result for a specific client
client_address = ("127.0.0.1", 12345) # Replace with the actual client address you want to check
auth_result = auth_server.get_client_auth_result(client_address)
print(f"Authentication result for {client_address}: {auth_result}")
Loading
Loading