From 02108545153db302b6eab399d7a4e95d2578ea75 Mon Sep 17 00:00:00 2001 From: antazoey Date: Tue, 26 Nov 2024 06:13:31 +0700 Subject: [PATCH] fix: Could not start a Geth process with only WS or IPC (#2377) --- setup.py | 2 +- src/ape_ethereum/provider.py | 1 - src/ape_node/provider.py | 163 +++++++++++++++++++------ tests/functional/geth/test_provider.py | 73 ++++++++++- 4 files changed, 192 insertions(+), 47 deletions(-) diff --git a/setup.py b/setup.py index fffa3f86fe..37ae1886f4 100644 --- a/setup.py +++ b/setup.py @@ -123,7 +123,7 @@ "eth-typing", "eth-utils", "hexbytes", - "py-geth>=5.0.0-beta.2,<6", + "py-geth>=5.1.0,<6", "trie>=3.0.1,<4", # Peer: stricter pin needed for uv support. "web3[tester]>=6.17.2,<7", # ** Dependencies maintained by ApeWorX ** diff --git a/src/ape_ethereum/provider.py b/src/ape_ethereum/provider.py index 58bba04dde..eba0c2d3cc 100644 --- a/src/ape_ethereum/provider.py +++ b/src/ape_ethereum/provider.py @@ -1352,7 +1352,6 @@ def uri(self) -> str: # Use value from config file network_config: dict = (config or {}).get(self.network.name) or DEFAULT_SETTINGS - if "url" in network_config: raise ConfigError("Unknown provider setting 'url'. Did you mean 'uri'?") elif "http_uri" in network_config: diff --git a/src/ape_node/provider.py b/src/ape_node/provider.py index aa7fb0f2ee..e7b49415ac 100644 --- a/src/ape_node/provider.py +++ b/src/ape_node/provider.py @@ -3,6 +3,7 @@ from pathlib import Path from subprocess import DEVNULL, PIPE, Popen from typing import TYPE_CHECKING, Any, Optional, Union +from urllib.parse import urlparse from eth_utils import add_0x_prefix, to_hex from evmchains import get_random_rpc @@ -13,7 +14,6 @@ from pydantic_settings import SettingsConfigDict from requests.exceptions import ConnectionError from web3.middleware import geth_poa_middleware as ExtraDataToPOAMiddleware -from yarl import URL from ape.api.config import PluginConfig from ape.api.providers import SubprocessProvider, TestProviderAPI @@ -90,13 +90,17 @@ def create_genesis_data(alloc: Alloc, chain_id: int) -> "GenesisDataTypedDict": class GethDevProcess(BaseGethProcess): """ A developer-configured geth that only exists until disconnected. + (Implementation detail of the local node provider). """ def __init__( self, data_dir: Path, - hostname: str = DEFAULT_HOSTNAME, - port: int = DEFAULT_PORT, + hostname: Optional[str] = None, + port: Optional[int] = None, + ipc_path: Optional[Path] = None, + ws_hostname: Optional[str] = None, + ws_port: Optional[str] = None, mnemonic: str = DEFAULT_TEST_MNEMONIC, number_of_accounts: int = DEFAULT_NUMBER_OF_TEST_ACCOUNTS, chain_id: int = DEFAULT_TEST_CHAIN_ID, @@ -111,23 +115,33 @@ def __init__( raise NodeSoftwareNotInstalledError() self._data_dir = data_dir - self._hostname = hostname - self._port = port self.is_running = False self._auto_disconnect = auto_disconnect - geth_kwargs = construct_test_chain_kwargs( - data_dir=self.data_dir, - geth_executable=executable, - rpc_addr=hostname, - rpc_port=f"{port}", - network_id=f"{chain_id}", - ws_enabled=False, - ws_addr=None, - ws_origins=None, - ws_port=None, - ws_api=None, - ) + kwargs_ctor: dict = { + "data_dir": self.data_dir, + "geth_executable": executable, + "network_id": f"{chain_id}", + } + if hostname is not None: + kwargs_ctor["rpc_addr"] = hostname + if port is not None: + kwargs_ctor["rpc_port"] = f"{port}" + if ws_hostname: + kwargs_ctor["ws_enabled"] = True + kwargs_ctor["ws_addr"] = ws_hostname + if ws_port: + kwargs_ctor["ws_enabled"] = True + kwargs_ctor["ws_port"] = f"{ws_port}" + if ipc_path is not None: + kwargs_ctor["ipc_path"] = f"{ipc_path}" + if not kwargs_ctor.get("ws_enabled"): + kwargs_ctor["ws_api"] = None + kwargs_ctor["ws_enabled"] = False + kwargs_ctor["ws_addr"] = None + kwargs_ctor["ws_port"] = None + + geth_kwargs = construct_test_chain_kwargs(**kwargs_ctor) # Ensure a clean data-dir. self._clean() @@ -147,38 +161,70 @@ def __init__( @classmethod def from_uri(cls, uri: str, data_folder: Path, **kwargs): - parsed_uri = URL(uri) - - if parsed_uri.host not in ("localhost", "127.0.0.1"): - raise ConnectionError(f"Unable to start Geth on non-local host {parsed_uri.host}.") - - port = parsed_uri.port if parsed_uri.port is not None else DEFAULT_PORT mnemonic = kwargs.get("mnemonic", DEFAULT_TEST_MNEMONIC) number_of_accounts = kwargs.get("number_of_accounts", DEFAULT_NUMBER_OF_TEST_ACCOUNTS) balance = kwargs.get("initial_balance", DEFAULT_TEST_ACCOUNT_BALANCE) extra_accounts = [a.lower() for a in kwargs.get("extra_funded_accounts", [])] - return cls( - data_folder, - auto_disconnect=kwargs.get("auto_disconnect", True), - executable=kwargs.get("executable"), - extra_funded_accounts=extra_accounts, - hd_path=kwargs.get("hd_path", DEFAULT_TEST_HD_PATH), - hostname=parsed_uri.host, - initial_balance=balance, - mnemonic=mnemonic, - number_of_accounts=number_of_accounts, - port=port, - ) + process_kwargs = { + "auto_disconnect": kwargs.get("auto_disconnect", True), + "executable": kwargs.get("executable"), + "extra_funded_accounts": extra_accounts, + "hd_path": kwargs.get("hd_path", DEFAULT_TEST_HD_PATH), + "initial_balance": balance, + "mnemonic": mnemonic, + "number_of_accounts": number_of_accounts, + } + + parsed_uri = urlparse(uri) + if not parsed_uri.netloc: + path = Path(parsed_uri.path) + if path.suffix == ".ipc": + # Was given an IPC path. + process_kwargs["ipc_path"] = path + + else: + raise ConnectionError(f"Unrecognized path type: '{path}'.") + + elif hostname := parsed_uri.hostname: + if hostname not in ("localhost", "127.0.0.1"): + raise ConnectionError( + f"Unable to start Geth on non-local host {parsed_uri.hostname}." + ) + + if parsed_uri.scheme.startswith("ws"): + process_kwargs["ws_hostname"] = hostname + process_kwargs["ws_port"] = parsed_uri.port or DEFAULT_PORT + elif parsed_uri.scheme.startswith("http"): + process_kwargs["hostname"] = hostname or DEFAULT_HOSTNAME + process_kwargs["port"] = parsed_uri.port or DEFAULT_PORT + else: + raise ConnectionError(f"Unsupported scheme: '{parsed_uri.scheme}'.") + + return cls(data_folder, **process_kwargs) @property def data_dir(self) -> str: return f"{self._data_dir}" + @property + def _hostname(self) -> Optional[str]: + return self.geth_kwargs.get("rpc_addr") + + @property + def _port(self) -> Optional[str]: + return self.geth_kwargs.get("rpc_port") + + @property + def _ws_hostname(self) -> Optional[str]: + return self.geth_kwargs.get("ws_addr") + + @property + def _ws_port(self) -> Optional[str]: + return self.geth_kwargs.get("ws_port") + def connect(self, timeout: int = 60): - home = str(Path.home()) - ipc_path = self.ipc_path.replace(home, "$HOME") - logger.info(f"Starting geth (HTTP='{self._hostname}:{self._port}', IPC={ipc_path}).") + self._log_connection() self.start() self.wait_for_rpc(timeout=timeout) @@ -186,6 +232,28 @@ def connect(self, timeout: int = 60): if self._auto_disconnect: atexit.register(self.disconnect) + def _log_connection(self): + home = str(Path.home()) + ipc_path = self.ipc_path.replace(home, "$HOME") + + http_log = "" + if self._hostname: + http_log = f"HTTP={self._hostname}" + if port := self._port: + http_log = f"{http_log}:{port}" + + ipc_log = f"IPC={ipc_path}" + + ws_log = "" + if self._ws_hostname: + ws_log = f"WS={self._ws_hostname}" + if port := self._ws_port: + ws_log = f"{ws_log}:{port}" + + connection_logs = ", ".join(x for x in (http_log, ipc_log, ws_log) if x) + + logger.info(f"Starting geth ({connection_logs}).") + def start(self): if self.is_running: return @@ -230,6 +298,21 @@ class EthereumNetworkConfig(PluginConfig): model_config = SettingsConfigDict(extra="allow") + @field_validator("local", mode="before") + @classmethod + def _validate_local(cls, value): + value = value or {} + if not value: + return {**DEFAULT_SETTINGS.copy(), "chain_id": DEFAULT_TEST_CHAIN_ID} + + if "chain_id" not in value: + value["chain_id"] = DEFAULT_TEST_CHAIN_ID + if "uri" not in value and "ipc_path" in value or "ws_uri" in value or "http_uri" in value: + # No need to add default HTTP URI if was given only IPC Path + return {**{k: v for k, v in DEFAULT_SETTINGS.items() if k != "uri"}, **value} + + return {**DEFAULT_SETTINGS, **value} + class EthereumNodeConfig(PluginConfig): """ @@ -384,8 +467,8 @@ def _create_process(self) -> GethDevProcess: extra_accounts = list({a.lower() for a in extra_accounts}) test_config["extra_funded_accounts"] = extra_accounts test_config["initial_balance"] = self.test_config.balance - - return GethDevProcess.from_uri(self.uri, self.data_dir, **test_config) + uri = self.ws_uri or self.uri + return GethDevProcess.from_uri(uri, self.data_dir, **test_config) def disconnect(self): # Must disconnect process first. diff --git a/tests/functional/geth/test_provider.py b/tests/functional/geth/test_provider.py index 93e8c550ca..3b7a8bee7c 100644 --- a/tests/functional/geth/test_provider.py +++ b/tests/functional/geth/test_provider.py @@ -45,6 +45,11 @@ def web3_factory(mocker): return mocker.patch("ape_ethereum.provider._create_web3") +@pytest.fixture +def process_factory_patch(mocker): + return mocker.patch("ape_node.provider.GethDevProcess.from_uri") + + @pytest.fixture def tx_for_call(geth_contract): return DynamicFeeTransaction.model_validate( @@ -129,6 +134,7 @@ def test_uri_non_dev_and_not_configured(mocker, ethereum): assert actual == expected +@geth_process_test def test_uri_invalid(geth_provider, project, ethereum): settings = geth_provider.provider_settings geth_provider.provider_settings = {} @@ -255,7 +261,18 @@ def test_connect_to_chain_that_started_poa(mock_web3, web3_factory, ethereum): @geth_process_test -def test_connect_using_only_ipc_for_uri(project, networks, geth_provider): +def test_connect_using_only_ipc_for_uri_already_connected(project, networks, geth_provider): + """ + Shows we can remote-connect to a node that is already running when it exposes its IPC path. + """ + ipc_path = geth_provider.ipc_path + with project.temp_config(node={"ethereum": {"local": {"uri": f"{ipc_path}"}}}): + with networks.ethereum.local.use_provider("node") as node: + assert node.uri == f"{ipc_path}" + + +@geth_process_test +def test_connect_using_ipc(process_factory_patch, project, networks, geth_provider): ipc_path = geth_provider.ipc_path with project.temp_config(node={"ethereum": {"local": {"uri": f"{ipc_path}"}}}): with networks.ethereum.local.use_provider("node") as node: @@ -794,9 +811,8 @@ def test_trace_approach_config(project): @geth_process_test -def test_start(mocker, convert, project, geth_provider): +def test_start(process_factory_patch, convert, project, geth_provider): amount = convert("100_000 ETH", int) - spy = mocker.spy(GethDevProcess, "from_uri") with project.temp_config(test={"balance": amount}): try: @@ -804,10 +820,57 @@ def test_start(mocker, convert, project, geth_provider): except Exception: pass # Exceptions are fine here. - actual = spy.call_args[1]["balance"] - assert actual == amount + actual = process_factory_patch.call_args[1]["balance"] + assert actual == amount + + +@geth_process_test +@pytest.mark.parametrize("key", ("uri", "ws_uri")) +def test_start_from_ws_uri(process_factory_patch, project, geth_provider, key): + uri = "ws://localhost:5677" + + with project.temp_config(node={"ethereum": {"local": {key: uri}}}): + try: + geth_provider.start() + except Exception: + pass # Exceptions are fine here. + + actual = process_factory_patch.call_args[0][0] # First "arg" + assert actual == uri @geth_process_test def test_auto_mine(geth_provider): assert geth_provider.auto_mine is True + + +@geth_process_test +def test_geth_dev_from_uri_http(data_folder): + geth_dev = GethDevProcess.from_uri("http://localhost:6799", data_folder) + kwargs = geth_dev.geth_kwargs + assert kwargs["rpc_addr"] == "localhost" + assert kwargs["rpc_port"] == "6799" + assert kwargs["ws_enabled"] is False + assert kwargs.get("ws_api") is None + assert kwargs.get("ws_addr") is None + assert kwargs.get("ws_port") is None + + +@geth_process_test +def test_geth_dev_from_uri_ws(data_folder): + geth_dev = GethDevProcess.from_uri("ws://localhost:6799", data_folder) + kwargs = geth_dev.geth_kwargs + assert kwargs.get("rpc_addr") is None + assert kwargs["ws_enabled"] is True + assert kwargs["ws_addr"] == "localhost" + assert kwargs["ws_port"] == "6799" + + +@geth_process_test +def test_geth_dev_from_uri_ipc(data_folder): + geth_dev = GethDevProcess.from_uri("path/to/geth.ipc", data_folder) + kwargs = geth_dev.geth_kwargs + assert kwargs["ipc_path"] == "path/to/geth.ipc" + assert kwargs.get("ws_api") is None + assert kwargs.get("ws_addr") is None + assert kwargs.get("rpc_addr") is None