diff --git a/adafruit_connection_manager.py b/adafruit_connection_manager.py index c8c209e..5b8a10c 100644 --- a/adafruit_connection_manager.py +++ b/adafruit_connection_manager.py @@ -58,6 +58,10 @@ def __init__(self, socket: CircuitPythonSocketType, tls_mode: int) -> None: self.recv = socket.recv self.close = socket.close self.recv_into = socket.recv_into + # For sockets that come from software socketpools (like the esp32api), they track + # the interface and socket pool. We need to make sure the clones do as well + self._interface = getattr(socket, "_interface", None) + self._socket_pool = getattr(socket, "_socket_pool", None) def connect(self, address: Tuple[str, int]) -> None: """Connect wrapper to add non-standard mode parameter""" @@ -94,7 +98,10 @@ def create_fake_ssl_context( * `Adafruit AirLift FeatherWing – ESP32 WiFi Co-Processor `_ """ - socket_pool.set_interface(iface) + if hasattr(socket_pool, "set_interface"): + # this is to manually support legacy hardware like the fona + socket_pool.set_interface(iface) + return _FakeSSLContext(iface) @@ -104,6 +111,13 @@ def create_fake_ssl_context( _global_ssl_contexts = {} +def _get_radio_hash_key(radio): + try: + return hash(radio) + except TypeError: + return radio.__class__.__name__ + + def get_radio_socketpool(radio): """Helper to get a socket pool for common boards. @@ -113,8 +127,9 @@ def get_radio_socketpool(radio): * Using the ESP32 WiFi Co-Processor (like the Adafruit AirLift) * Using a WIZ5500 (Like the Adafruit Ethernet FeatherWing) """ - class_name = radio.__class__.__name__ - if class_name not in _global_socketpools: + key = _get_radio_hash_key(radio) + if key not in _global_socketpools: + class_name = radio.__class__.__name__ if class_name == "Radio": import ssl # pylint: disable=import-outside-toplevel @@ -124,12 +139,15 @@ def get_radio_socketpool(radio): ssl_context = ssl.create_default_context() elif class_name == "ESP_SPIcontrol": - import adafruit_esp32spi.adafruit_esp32spi_socket as pool # pylint: disable=import-outside-toplevel + import adafruit_esp32spi.adafruit_esp32spi_socketpool as socketpool # pylint: disable=import-outside-toplevel + pool = socketpool.SocketPool(radio) ssl_context = create_fake_ssl_context(pool, radio) elif class_name == "WIZNET5K": - import adafruit_wiznet5k.adafruit_wiznet5k_socket as pool # pylint: disable=import-outside-toplevel + import adafruit_wiznet5k.adafruit_wiznet5k_socketpool as socketpool # pylint: disable=import-outside-toplevel + + pool = socketpool.SocketPool(radio) # Note: At this time, SSL/TLS connections are not supported by older # versions of the Wiznet5k library or on boards withouut the ssl module @@ -141,7 +159,6 @@ def get_radio_socketpool(radio): import ssl # pylint: disable=import-outside-toplevel ssl_context = ssl.create_default_context() - pool.set_interface(radio) except ImportError: # if SSL not on board, default to fake_ssl_context pass @@ -152,11 +169,11 @@ def get_radio_socketpool(radio): else: raise AttributeError(f"Unsupported radio class: {class_name}") - _global_key_by_socketpool[pool] = class_name - _global_socketpools[class_name] = pool - _global_ssl_contexts[class_name] = ssl_context + _global_key_by_socketpool[pool] = key + _global_socketpools[key] = pool + _global_ssl_contexts[key] = ssl_context - return _global_socketpools[class_name] + return _global_socketpools[key] def get_radio_ssl_context(radio): @@ -168,9 +185,8 @@ def get_radio_ssl_context(radio): * Using the ESP32 WiFi Co-Processor (like the Adafruit AirLift) * Using a WIZ5500 (Like the Adafruit Ethernet FeatherWing) """ - class_name = radio.__class__.__name__ get_radio_socketpool(radio) - return _global_ssl_contexts[class_name] + return _global_ssl_contexts[_get_radio_hash_key(radio)] # main class diff --git a/tests/conftest.py b/tests/conftest.py index ef6c96d..22128f9 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -10,9 +10,31 @@ import pytest -# pylint: disable=unused-argument -def set_interface(iface): - """Helper to set the global internet interface""" +class SocketPool: + name = None + + def __init__(self, *args, **kwargs): + pass + + @property + def __name__(self): + return self.name + + +class ESP32SPI_SocketPool(SocketPool): # pylint: disable=too-few-public-methods + name = "adafruit_esp32spi_socketpool" + + +class WIZNET5K_SocketPool(SocketPool): # pylint: disable=too-few-public-methods + name = "adafruit_wiznet5k_socketpool" + SOCK_STREAM = 0x21 + + +class WIZNET5K_With_SSL_SocketPool( + SocketPool +): # pylint: disable=too-few-public-methods + name = "adafruit_wiznet5k_socketpool" + SOCK_STREAM = 0x1 @pytest.fixture @@ -25,41 +47,45 @@ def circuitpython_socketpool_module(): @pytest.fixture -def adafruit_esp32spi_socket_module(): +def adafruit_esp32spi_socketpool_module(): esp32spi_module = type(sys)("adafruit_esp32spi") - esp32spi_socket_module = type(sys)("adafruit_esp32spi_socket") - esp32spi_socket_module.set_interface = set_interface + esp32spi_socket_module = type(sys)("adafruit_esp32spi_socketpool") + esp32spi_socket_module.SocketPool = ESP32SPI_SocketPool sys.modules["adafruit_esp32spi"] = esp32spi_module - sys.modules["adafruit_esp32spi.adafruit_esp32spi_socket"] = esp32spi_socket_module + sys.modules["adafruit_esp32spi.adafruit_esp32spi_socketpool"] = ( + esp32spi_socket_module + ) yield del sys.modules["adafruit_esp32spi"] - del sys.modules["adafruit_esp32spi.adafruit_esp32spi_socket"] + del sys.modules["adafruit_esp32spi.adafruit_esp32spi_socketpool"] @pytest.fixture -def adafruit_wiznet5k_socket_module(): +def adafruit_wiznet5k_socketpool_module(): wiznet5k_module = type(sys)("adafruit_wiznet5k") - wiznet5k_socket_module = type(sys)("adafruit_wiznet5k_socket") - wiznet5k_socket_module.set_interface = set_interface - wiznet5k_socket_module.SOCK_STREAM = 0x21 + wiznet5k_socketpool_module = type(sys)("adafruit_wiznet5k_socketpool") + wiznet5k_socketpool_module.SocketPool = WIZNET5K_SocketPool sys.modules["adafruit_wiznet5k"] = wiznet5k_module - sys.modules["adafruit_wiznet5k.adafruit_wiznet5k_socket"] = wiznet5k_socket_module + sys.modules["adafruit_wiznet5k.adafruit_wiznet5k_socketpool"] = ( + wiznet5k_socketpool_module + ) yield del sys.modules["adafruit_wiznet5k"] - del sys.modules["adafruit_wiznet5k.adafruit_wiznet5k_socket"] + del sys.modules["adafruit_wiznet5k.adafruit_wiznet5k_socketpool"] @pytest.fixture -def adafruit_wiznet5k_with_ssl_socket_module(): +def adafruit_wiznet5k_with_ssl_socketpool_module(): wiznet5k_module = type(sys)("adafruit_wiznet5k") - wiznet5k_socket_module = type(sys)("adafruit_wiznet5k_socket") - wiznet5k_socket_module.set_interface = set_interface - wiznet5k_socket_module.SOCK_STREAM = 1 + wiznet5k_socketpool_module = type(sys)("adafruit_wiznet5k_socketpool") + wiznet5k_socketpool_module.SocketPool = WIZNET5K_With_SSL_SocketPool sys.modules["adafruit_wiznet5k"] = wiznet5k_module - sys.modules["adafruit_wiznet5k.adafruit_wiznet5k_socket"] = wiznet5k_socket_module + sys.modules["adafruit_wiznet5k.adafruit_wiznet5k_socketpool"] = ( + wiznet5k_socketpool_module + ) yield del sys.modules["adafruit_wiznet5k"] - del sys.modules["adafruit_wiznet5k.adafruit_wiznet5k_socket"] + del sys.modules["adafruit_wiznet5k.adafruit_wiznet5k_socketpool"] @pytest.fixture(autouse=True) diff --git a/tests/connection_manager_close_all_test.py b/tests/connection_manager_close_all_test.py index c0fa498..2057be0 100644 --- a/tests/connection_manager_close_all_test.py +++ b/tests/connection_manager_close_all_test.py @@ -88,7 +88,7 @@ def test_connection_manager_close_all_untracked(): def test_connection_manager_close_all_single_release_references_false( # pylint: disable=unused-argument - circuitpython_socketpool_module, adafruit_esp32spi_socket_module + circuitpython_socketpool_module, adafruit_esp32spi_socketpool_module ): radio_wifi = mocket.MockRadio.Radio() radio_esp = mocket.MockRadio.ESP_SPIcontrol() @@ -131,7 +131,7 @@ def test_connection_manager_close_all_single_release_references_false( # pylint def test_connection_manager_close_all_single_release_references_true( # pylint: disable=unused-argument - circuitpython_socketpool_module, adafruit_esp32spi_socket_module + circuitpython_socketpool_module, adafruit_esp32spi_socketpool_module ): radio_wifi = mocket.MockRadio.Radio() radio_esp = mocket.MockRadio.ESP_SPIcontrol() diff --git a/tests/get_connection_manager_test.py b/tests/get_connection_manager_test.py index 324d032..c5f7817 100644 --- a/tests/get_connection_manager_test.py +++ b/tests/get_connection_manager_test.py @@ -19,7 +19,7 @@ def test_get_connection_manager(): def test_different_connection_manager_different_pool( # pylint: disable=unused-argument - circuitpython_socketpool_module, adafruit_esp32spi_socket_module + circuitpython_socketpool_module, adafruit_esp32spi_socketpool_module ): radio_wifi = mocket.MockRadio.Radio() radio_esp = mocket.MockRadio.ESP_SPIcontrol() diff --git a/tests/get_radio_test.py b/tests/get_radio_test.py index 5c43ad1..9844e9e 100644 --- a/tests/get_radio_test.py +++ b/tests/get_radio_test.py @@ -13,6 +13,18 @@ import adafruit_connection_manager +def test__get_radio_hash_key(): + radio = mocket.MockRadio.Radio() + assert adafruit_connection_manager._get_radio_hash_key(radio) == hash(radio) + + +def test__get_radio_hash_key_not_hashable(): + radio = mocket.MockRadio.Radio() + + with mock.patch("builtins.hash", side_effect=TypeError()): + assert adafruit_connection_manager._get_radio_hash_key(radio) == "Radio" + + def test_get_radio_socketpool_wifi( # pylint: disable=unused-argument circuitpython_socketpool_module, ): @@ -23,21 +35,21 @@ def test_get_radio_socketpool_wifi( # pylint: disable=unused-argument def test_get_radio_socketpool_esp32spi( # pylint: disable=unused-argument - adafruit_esp32spi_socket_module, + adafruit_esp32spi_socketpool_module, ): radio = mocket.MockRadio.ESP_SPIcontrol() socket_pool = adafruit_connection_manager.get_radio_socketpool(radio) - assert socket_pool.__name__ == "adafruit_esp32spi_socket" + assert socket_pool.__name__ == "adafruit_esp32spi_socketpool" assert socket_pool in adafruit_connection_manager._global_socketpools.values() def test_get_radio_socketpool_wiznet5k( # pylint: disable=unused-argument - adafruit_wiznet5k_socket_module, + adafruit_wiznet5k_socketpool_module, ): radio = mocket.MockRadio.WIZNET5K() with mock.patch("sys.implementation", return_value=[9, 0, 0]): socket_pool = adafruit_connection_manager.get_radio_socketpool(radio) - assert socket_pool.__name__ == "adafruit_wiznet5k_socket" + assert socket_pool.__name__ == "adafruit_wiznet5k_socketpool" assert socket_pool in adafruit_connection_manager._global_socketpools.values() @@ -68,7 +80,7 @@ def test_get_radio_ssl_context_wifi( # pylint: disable=unused-argument def test_get_radio_ssl_context_esp32spi( # pylint: disable=unused-argument - adafruit_esp32spi_socket_module, + adafruit_esp32spi_socketpool_module, ): radio = mocket.MockRadio.ESP_SPIcontrol() ssl_context = adafruit_connection_manager.get_radio_ssl_context(radio) @@ -77,7 +89,7 @@ def test_get_radio_ssl_context_esp32spi( # pylint: disable=unused-argument def test_get_radio_ssl_context_wiznet5k( # pylint: disable=unused-argument - adafruit_wiznet5k_socket_module, + adafruit_wiznet5k_socketpool_module, ): radio = mocket.MockRadio.WIZNET5K() with mock.patch("sys.implementation", return_value=[9, 0, 0]): diff --git a/tests/get_socket_test.py b/tests/get_socket_test.py index 6be48f0..9abbf98 100644 --- a/tests/get_socket_test.py +++ b/tests/get_socket_test.py @@ -218,7 +218,7 @@ def test_get_socket_runtime_error_ties_again_only_once(): def test_fake_ssl_context_connect( # pylint: disable=unused-argument - adafruit_esp32spi_socket_module, + adafruit_esp32spi_socketpool_module, ): mock_pool = mocket.MocketPool() mock_socket_1 = mocket.Mocket() @@ -237,7 +237,7 @@ def test_fake_ssl_context_connect( # pylint: disable=unused-argument def test_fake_ssl_context_connect_error( # pylint: disable=unused-argument - adafruit_esp32spi_socket_module, + adafruit_esp32spi_socketpool_module, ): mock_pool = mocket.MocketPool() mock_socket_1 = mocket.Mocket() diff --git a/tests/ssl_context_test.py b/tests/ssl_context_test.py index de6d4f1..2f2e370 100644 --- a/tests/ssl_context_test.py +++ b/tests/ssl_context_test.py @@ -15,7 +15,7 @@ def test_connect_esp32spi_https( # pylint: disable=unused-argument - adafruit_esp32spi_socket_module, + adafruit_esp32spi_socketpool_module, ): mock_pool = mocket.MocketPool() mock_socket_1 = mocket.Mocket() @@ -48,7 +48,7 @@ def test_connect_wifi_https( # pylint: disable=unused-argument def test_connect_wiznet5k_https_not_supported( # pylint: disable=unused-argument - adafruit_wiznet5k_socket_module, + adafruit_wiznet5k_socketpool_module, ): mock_pool = mocket.MocketPool() radio = mocket.MockRadio.WIZNET5K() @@ -66,7 +66,7 @@ def test_connect_wiznet5k_https_not_supported( # pylint: disable=unused-argumen def test_connect_wiznet5k_https_supported( # pylint: disable=unused-argument - adafruit_wiznet5k_with_ssl_socket_module, + adafruit_wiznet5k_with_ssl_socketpool_module, ): radio = mocket.MockRadio.WIZNET5K() with mock.patch("sys.implementation", (None, WIZNET5K_SSL_SUPPORT_VERSION)):