diff --git a/adafruit_connection_manager.py b/adafruit_connection_manager.py index cc70f3f..93fa5c6 100644 --- a/adafruit_connection_manager.py +++ b/adafruit_connection_manager.py @@ -56,6 +56,10 @@ def __init__(self, socket: CircuitPythonSocketType, tls_mode: int) -> None: self.recv = socket.recv self.close = socket.close self.recv_into = socket.recv_into + if hasattr(socket, "_interface"): + self._interface = socket._interface + if hasattr(socket, "_socket_pool"): + self._socket_pool = socket._socket_pool def connect(self, address: Tuple[str, int]) -> None: """Connect wrapper to add non-standard mode parameter""" @@ -93,7 +97,6 @@ def create_fake_ssl_context( * `Adafruit AirLift FeatherWing – ESP32 WiFi Co-Processor `_ """ - socket_pool.set_interface(iface) return _FakeSSLContext(iface) @@ -121,12 +124,15 @@ def get_radio_socketpool(radio): ssl_context = ssl.create_default_context() elif class_name == "ESP_SPIcontrol": - import adafruit_esp32spi.adafruit_esp32spi_socket as pool # pylint: disable=import-outside-toplevel + import adafruit_esp32spi.adafruit_esp32spi_socketpool as socketpool # pylint: disable=import-outside-toplevel + pool = socketpool.SocketPool(radio) ssl_context = create_fake_ssl_context(pool, radio) elif class_name == "WIZNET5K": - import adafruit_wiznet5k.adafruit_wiznet5k_socket as pool # pylint: disable=import-outside-toplevel + import adafruit_wiznet5k.adafruit_wiznet5k_socketpool as socketpool # pylint: disable=import-outside-toplevel + + pool = socketpool.SocketPool(radio) # Note: SSL/TLS connections are not supported by the Wiznet5k library at this time ssl_context = create_fake_ssl_context(pool, radio) diff --git a/tests/conftest.py b/tests/conftest.py index 2d9bb0a..06457cb 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -14,18 +14,39 @@ def set_interface(iface): """Helper to set the global internet interface""" +class SocketPool: + name = None + + def __init__(self, *args, **kwargs): + pass + + @property + def __name__(self): + return self.name + + +class ESP32SPI_SocketPool(SocketPool): # pylint: disable=too-few-public-methods + name = "adafruit_esp32spi_socketpool" + + +class WIZNET5K_SocketPool(SocketPool): # pylint: disable=too-few-public-methods + name = "adafruit_wiznet5k_socketpool" + + socketpool_module = type(sys)("socketpool") socketpool_module.SocketPool = mocket.MocketPool sys.modules["socketpool"] = socketpool_module esp32spi_module = type(sys)("adafruit_esp32spi") -esp32spi_socket_module = type(sys)("adafruit_esp32spi_socket") -esp32spi_socket_module.set_interface = set_interface +esp32spi_socket_module = type(sys)("adafruit_esp32spi_socketpool") +esp32spi_socket_module.SocketPool = ESP32SPI_SocketPool sys.modules["adafruit_esp32spi"] = esp32spi_module -sys.modules["adafruit_esp32spi.adafruit_esp32spi_socket"] = esp32spi_socket_module +sys.modules["adafruit_esp32spi.adafruit_esp32spi_socketpool"] = esp32spi_socket_module wiznet5k_module = type(sys)("adafruit_wiznet5k") -wiznet5k_socket_module = type(sys)("adafruit_wiznet5k_socket") -wiznet5k_socket_module.set_interface = set_interface +wiznet5k_socketpool_module = type(sys)("adafruit_wiznet5k_socketpool") +wiznet5k_socketpool_module.SocketPool = WIZNET5K_SocketPool sys.modules["adafruit_wiznet5k"] = wiznet5k_module -sys.modules["adafruit_wiznet5k.adafruit_wiznet5k_socket"] = wiznet5k_socket_module +sys.modules["adafruit_wiznet5k.adafruit_wiznet5k_socketpool"] = ( + wiznet5k_socketpool_module +) diff --git a/tests/get_radio_test.py b/tests/get_radio_test.py index ea80f7e..cbd49e8 100644 --- a/tests/get_radio_test.py +++ b/tests/get_radio_test.py @@ -21,13 +21,13 @@ def test_get_radio_socketpool_wifi(): def test_get_radio_socketpool_esp32spi(): radio = mocket.MockRadio.ESP_SPIcontrol() socket_pool = adafruit_connection_manager.get_radio_socketpool(radio) - assert socket_pool.__name__ == "adafruit_esp32spi_socket" + assert socket_pool.__name__ == "adafruit_esp32spi_socketpool" def test_get_radio_socketpool_wiznet5k(): radio = mocket.MockRadio.WIZNET5K() socket_pool = adafruit_connection_manager.get_radio_socketpool(radio) - assert socket_pool.__name__ == "adafruit_wiznet5k_socket" + assert socket_pool.__name__ == "adafruit_wiznet5k_socketpool" def test_get_radio_socketpool_unsupported():