From 5791f1ad48df28ae0ea6b8f2dced30ba88d467d0 Mon Sep 17 00:00:00 2001
From: maico
Date: Wed, 24 Apr 2024 13:29:24 +0200
Subject: [PATCH] Resolve periodic socket.timeout causing control channel message drops

---
Review notes (placed after the separator, so not part of the commit message):
 * Bytes read from the control channel are accumulated and decoded once, so a
   multi-byte UTF-8 character split across two recv() calls no longer raises
   UnicodeDecodeError (which the broad except would log and drop).
 * The listening socket is also passed as select()'s exceptional-condition
   list; with an empty list the "error state" branch could never execute.
 * Both copies of get_server_request() are textually identical again (log
   message wording, typing import order).

 .../R/scripts/server_listener.py              | 36 +++++++++++--------
 .../python/scripts/launch_ipykernel.py        | 36 +++++++++++--------
 2 files changed, 42 insertions(+), 30 deletions(-)

diff --git a/etc/kernel-launchers/R/scripts/server_listener.py b/etc/kernel-launchers/R/scripts/server_listener.py
index 2ca755c6..bfb04179 100644
--- a/etc/kernel-launchers/R/scripts/server_listener.py
+++ b/etc/kernel-launchers/R/scripts/server_listener.py
@@ -5,9 +5,11 @@
 import logging
 import os
 import random
+import select
 import socket
 import uuid
 from threading import Thread
+from typing import Any, Dict, Optional
 
 from Cryptodome.Cipher import AES, PKCS1_v1_5
 from Cryptodome.PublicKey import RSA
@@ -169,30 +171,34 @@ def _get_candidate_port(lower_port, upper_port):
     return random.randint(lower_port, upper_port)
 
 
-def get_server_request(sock):
+def get_server_request(sock: socket.socket) -> Optional[Dict[str, Any]]:
     """Gets a request from the server and returns the corresponding dictionary.
 
     This code also exists in the Python kernel-launcher's launch_ipykernel.py script.
     """
-    conn = None
-    data = ""
-    request_info = None
+    conn: Optional[socket.socket] = None
     try:
-        conn, addr = sock.accept()
-        while True:
-            buffer = conn.recv(1024).decode("utf-8")
-            if not buffer:  # send is complete
-                request_info = json.loads(data)
-                break
-            data = data + buffer  # append what we received until we get no more...
-    except Exception as e:
-        if type(e) is not socket.timeout:
-            raise e
+        # Enterprise gateway establishes a new connection for every request.
+        # Wait until a new connection is made, before accepting and reading.
+        input_ready, _, except_ready = select.select([sock], [], [sock])
+        for sock_ready in input_ready:
+            conn, addr = sock_ready.accept()
+            logger.info(f"Accepted connection on control channel from {addr=}")
+            data: bytes = b""  # decode once at the end: a UTF-8 char may straddle recv() chunks
+            while buffer := conn.recv(1024):
+                data = data + buffer  # append what we received until we get no more...
+            return json.loads(data.decode("utf-8")) if data else None
+
+        if except_ready:
+            logger.error("Control channel socket reported error state")
+
+    except Exception:
+        logger.exception("Control channel socket closed unexpectedly")
     finally:
         if conn:
             conn.close()
 
-    return request_info
+    return None
 
 
 def server_listener(sock, parent_pid):
diff --git a/etc/kernel-launchers/python/scripts/launch_ipykernel.py b/etc/kernel-launchers/python/scripts/launch_ipykernel.py
index 4ed8b3b1..de55ad0c 100644
--- a/etc/kernel-launchers/python/scripts/launch_ipykernel.py
+++ b/etc/kernel-launchers/python/scripts/launch_ipykernel.py
@@ -6,12 +6,14 @@
 import logging
 import os
 import random
+import select
 import signal
 import socket
 import tempfile
 import uuid
 from multiprocessing import Process
 from threading import Thread
+from typing import Any, Dict, Optional
 
 from Cryptodome.Cipher import AES, PKCS1_v1_5
 from Cryptodome.PublicKey import RSA
@@ -344,30 +346,34 @@ def _get_candidate_port(lower_port, upper_port):
     return random.randint(lower_port, upper_port)
 
 
-def get_server_request(sock):
+def get_server_request(sock: socket.socket) -> Optional[Dict[str, Any]]:
     """Gets a request from the server and returns the corresponding dictionary.
 
     This code also exists in the R kernel-launcher's server_listener.py script.
     """
-    conn = None
-    data = ""
-    request_info = None
+    conn: Optional[socket.socket] = None
     try:
-        conn, addr = sock.accept()
-        while True:
-            buffer = conn.recv(1024).decode("utf-8")
-            if not buffer:  # send is complete
-                request_info = json.loads(data)
-                break
-            data = data + buffer  # append what we received until we get no more...
-    except Exception as e:
-        if type(e) is not socket.timeout:
-            raise e
+        # Enterprise gateway establishes a new connection for every request.
+        # Wait until a new connection is made, before accepting and reading.
+        input_ready, _, except_ready = select.select([sock], [], [sock])
+        for sock_ready in input_ready:
+            conn, addr = sock_ready.accept()
+            logger.info(f"Accepted connection on control channel from {addr=}")
+            data: bytes = b""  # decode once at the end: a UTF-8 char may straddle recv() chunks
+            while buffer := conn.recv(1024):
+                data = data + buffer  # append what we received until we get no more...
+            return json.loads(data.decode("utf-8")) if data else None
+
+        if except_ready:
+            logger.error("Control channel socket reported error state")
+
+    except Exception:
+        logger.exception("Control channel socket closed unexpectedly")
     finally:
         if conn:
             conn.close()
 
-    return request_info
+    return None
 
 
 def cancel_spark_jobs(sig, frame):