Skip to content

Commit

Permalink
fix: Could not start a Geth process with only WS or IPC (#2377)
Browse files Browse the repository at this point in the history
  • Loading branch information
antazoey authored Nov 25, 2024
1 parent a25349e commit 0210854
Show file tree
Hide file tree
Showing 4 changed files with 192 additions and 47 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@
"eth-typing",
"eth-utils",
"hexbytes",
"py-geth>=5.0.0-beta.2,<6",
"py-geth>=5.1.0,<6",
"trie>=3.0.1,<4", # Peer: stricter pin needed for uv support.
"web3[tester]>=6.17.2,<7",
# ** Dependencies maintained by ApeWorX **
Expand Down
1 change: 0 additions & 1 deletion src/ape_ethereum/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -1352,7 +1352,6 @@ def uri(self) -> str:

# Use value from config file
network_config: dict = (config or {}).get(self.network.name) or DEFAULT_SETTINGS

if "url" in network_config:
raise ConfigError("Unknown provider setting 'url'. Did you mean 'uri'?")
elif "http_uri" in network_config:
Expand Down
163 changes: 123 additions & 40 deletions src/ape_node/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from pathlib import Path
from subprocess import DEVNULL, PIPE, Popen
from typing import TYPE_CHECKING, Any, Optional, Union
from urllib.parse import urlparse

from eth_utils import add_0x_prefix, to_hex
from evmchains import get_random_rpc
Expand All @@ -13,7 +14,6 @@
from pydantic_settings import SettingsConfigDict
from requests.exceptions import ConnectionError
from web3.middleware import geth_poa_middleware as ExtraDataToPOAMiddleware
from yarl import URL

from ape.api.config import PluginConfig
from ape.api.providers import SubprocessProvider, TestProviderAPI
Expand Down Expand Up @@ -90,13 +90,17 @@ def create_genesis_data(alloc: Alloc, chain_id: int) -> "GenesisDataTypedDict":
class GethDevProcess(BaseGethProcess):
"""
A developer-configured geth that only exists until disconnected.
(Implementation detail of the local node provider).
"""

def __init__(
self,
data_dir: Path,
hostname: str = DEFAULT_HOSTNAME,
port: int = DEFAULT_PORT,
hostname: Optional[str] = None,
port: Optional[int] = None,
ipc_path: Optional[Path] = None,
ws_hostname: Optional[str] = None,
ws_port: Optional[str] = None,
mnemonic: str = DEFAULT_TEST_MNEMONIC,
number_of_accounts: int = DEFAULT_NUMBER_OF_TEST_ACCOUNTS,
chain_id: int = DEFAULT_TEST_CHAIN_ID,
Expand All @@ -111,23 +115,33 @@ def __init__(
raise NodeSoftwareNotInstalledError()

self._data_dir = data_dir
self._hostname = hostname
self._port = port
self.is_running = False
self._auto_disconnect = auto_disconnect

geth_kwargs = construct_test_chain_kwargs(
data_dir=self.data_dir,
geth_executable=executable,
rpc_addr=hostname,
rpc_port=f"{port}",
network_id=f"{chain_id}",
ws_enabled=False,
ws_addr=None,
ws_origins=None,
ws_port=None,
ws_api=None,
)
kwargs_ctor: dict = {
"data_dir": self.data_dir,
"geth_executable": executable,
"network_id": f"{chain_id}",
}
if hostname is not None:
kwargs_ctor["rpc_addr"] = hostname
if port is not None:
kwargs_ctor["rpc_port"] = f"{port}"
if ws_hostname:
kwargs_ctor["ws_enabled"] = True
kwargs_ctor["ws_addr"] = ws_hostname
if ws_port:
kwargs_ctor["ws_enabled"] = True
kwargs_ctor["ws_port"] = f"{ws_port}"
if ipc_path is not None:
kwargs_ctor["ipc_path"] = f"{ipc_path}"
if not kwargs_ctor.get("ws_enabled"):
kwargs_ctor["ws_api"] = None
kwargs_ctor["ws_enabled"] = False
kwargs_ctor["ws_addr"] = None
kwargs_ctor["ws_port"] = None

geth_kwargs = construct_test_chain_kwargs(**kwargs_ctor)

# Ensure a clean data-dir.
self._clean()
Expand All @@ -147,45 +161,99 @@ def __init__(

@classmethod
def from_uri(cls, uri: str, data_folder: Path, **kwargs):
parsed_uri = URL(uri)

if parsed_uri.host not in ("localhost", "127.0.0.1"):
raise ConnectionError(f"Unable to start Geth on non-local host {parsed_uri.host}.")

port = parsed_uri.port if parsed_uri.port is not None else DEFAULT_PORT
mnemonic = kwargs.get("mnemonic", DEFAULT_TEST_MNEMONIC)
number_of_accounts = kwargs.get("number_of_accounts", DEFAULT_NUMBER_OF_TEST_ACCOUNTS)
balance = kwargs.get("initial_balance", DEFAULT_TEST_ACCOUNT_BALANCE)
extra_accounts = [a.lower() for a in kwargs.get("extra_funded_accounts", [])]

return cls(
data_folder,
auto_disconnect=kwargs.get("auto_disconnect", True),
executable=kwargs.get("executable"),
extra_funded_accounts=extra_accounts,
hd_path=kwargs.get("hd_path", DEFAULT_TEST_HD_PATH),
hostname=parsed_uri.host,
initial_balance=balance,
mnemonic=mnemonic,
number_of_accounts=number_of_accounts,
port=port,
)
process_kwargs = {
"auto_disconnect": kwargs.get("auto_disconnect", True),
"executable": kwargs.get("executable"),
"extra_funded_accounts": extra_accounts,
"hd_path": kwargs.get("hd_path", DEFAULT_TEST_HD_PATH),
"initial_balance": balance,
"mnemonic": mnemonic,
"number_of_accounts": number_of_accounts,
}

parsed_uri = urlparse(uri)
if not parsed_uri.netloc:
path = Path(parsed_uri.path)
if path.suffix == ".ipc":
# Was given an IPC path.
process_kwargs["ipc_path"] = path

else:
raise ConnectionError(f"Unrecognized path type: '{path}'.")

elif hostname := parsed_uri.hostname:
if hostname not in ("localhost", "127.0.0.1"):
raise ConnectionError(
f"Unable to start Geth on non-local host {parsed_uri.hostname}."
)

if parsed_uri.scheme.startswith("ws"):
process_kwargs["ws_hostname"] = hostname
process_kwargs["ws_port"] = parsed_uri.port or DEFAULT_PORT
elif parsed_uri.scheme.startswith("http"):
process_kwargs["hostname"] = hostname or DEFAULT_HOSTNAME
process_kwargs["port"] = parsed_uri.port or DEFAULT_PORT
else:
raise ConnectionError(f"Unsupported scheme: '{parsed_uri.scheme}'.")

return cls(data_folder, **process_kwargs)

@property
def data_dir(self) -> str:
return f"{self._data_dir}"

@property
def _hostname(self) -> Optional[str]:
return self.geth_kwargs.get("rpc_addr")

@property
def _port(self) -> Optional[str]:
return self.geth_kwargs.get("rpc_port")

@property
def _ws_hostname(self) -> Optional[str]:
return self.geth_kwargs.get("ws_addr")

@property
def _ws_port(self) -> Optional[str]:
return self.geth_kwargs.get("ws_port")

def connect(self, timeout: int = 60):
home = str(Path.home())
ipc_path = self.ipc_path.replace(home, "$HOME")
logger.info(f"Starting geth (HTTP='{self._hostname}:{self._port}', IPC={ipc_path}).")
self._log_connection()
self.start()
self.wait_for_rpc(timeout=timeout)

# Register atexit handler to make sure disconnect is called for normal object lifecycle.
if self._auto_disconnect:
atexit.register(self.disconnect)

def _log_connection(self):
home = str(Path.home())
ipc_path = self.ipc_path.replace(home, "$HOME")

http_log = ""
if self._hostname:
http_log = f"HTTP={self._hostname}"
if port := self._port:
http_log = f"{http_log}:{port}"

ipc_log = f"IPC={ipc_path}"

ws_log = ""
if self._ws_hostname:
ws_log = f"WS={self._ws_hostname}"
if port := self._ws_port:
ws_log = f"{ws_log}:{port}"

connection_logs = ", ".join(x for x in (http_log, ipc_log, ws_log) if x)

logger.info(f"Starting geth ({connection_logs}).")

def start(self):
if self.is_running:
return
Expand Down Expand Up @@ -230,6 +298,21 @@ class EthereumNetworkConfig(PluginConfig):

model_config = SettingsConfigDict(extra="allow")

@field_validator("local", mode="before")
@classmethod
def _validate_local(cls, value):
value = value or {}
if not value:
return {**DEFAULT_SETTINGS.copy(), "chain_id": DEFAULT_TEST_CHAIN_ID}

if "chain_id" not in value:
value["chain_id"] = DEFAULT_TEST_CHAIN_ID
if "uri" not in value and "ipc_path" in value or "ws_uri" in value or "http_uri" in value:
# No need to add default HTTP URI if was given only IPC Path
return {**{k: v for k, v in DEFAULT_SETTINGS.items() if k != "uri"}, **value}

return {**DEFAULT_SETTINGS, **value}


class EthereumNodeConfig(PluginConfig):
"""
Expand Down Expand Up @@ -384,8 +467,8 @@ def _create_process(self) -> GethDevProcess:
extra_accounts = list({a.lower() for a in extra_accounts})
test_config["extra_funded_accounts"] = extra_accounts
test_config["initial_balance"] = self.test_config.balance

return GethDevProcess.from_uri(self.uri, self.data_dir, **test_config)
uri = self.ws_uri or self.uri
return GethDevProcess.from_uri(uri, self.data_dir, **test_config)

def disconnect(self):
# Must disconnect process first.
Expand Down
73 changes: 68 additions & 5 deletions tests/functional/geth/test_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,11 @@ def web3_factory(mocker):
return mocker.patch("ape_ethereum.provider._create_web3")


@pytest.fixture
def process_factory_patch(mocker):
return mocker.patch("ape_node.provider.GethDevProcess.from_uri")


@pytest.fixture
def tx_for_call(geth_contract):
return DynamicFeeTransaction.model_validate(
Expand Down Expand Up @@ -129,6 +134,7 @@ def test_uri_non_dev_and_not_configured(mocker, ethereum):
assert actual == expected


@geth_process_test
def test_uri_invalid(geth_provider, project, ethereum):
settings = geth_provider.provider_settings
geth_provider.provider_settings = {}
Expand Down Expand Up @@ -255,7 +261,18 @@ def test_connect_to_chain_that_started_poa(mock_web3, web3_factory, ethereum):


@geth_process_test
def test_connect_using_only_ipc_for_uri(project, networks, geth_provider):
def test_connect_using_only_ipc_for_uri_already_connected(project, networks, geth_provider):
"""
Shows we can remote-connect to a node that is already running when it exposes its IPC path.
"""
ipc_path = geth_provider.ipc_path
with project.temp_config(node={"ethereum": {"local": {"uri": f"{ipc_path}"}}}):
with networks.ethereum.local.use_provider("node") as node:
assert node.uri == f"{ipc_path}"


@geth_process_test
def test_connect_using_ipc(process_factory_patch, project, networks, geth_provider):
ipc_path = geth_provider.ipc_path
with project.temp_config(node={"ethereum": {"local": {"uri": f"{ipc_path}"}}}):
with networks.ethereum.local.use_provider("node") as node:
Expand Down Expand Up @@ -794,20 +811,66 @@ def test_trace_approach_config(project):


@geth_process_test
def test_start(mocker, convert, project, geth_provider):
def test_start(process_factory_patch, convert, project, geth_provider):
amount = convert("100_000 ETH", int)
spy = mocker.spy(GethDevProcess, "from_uri")

with project.temp_config(test={"balance": amount}):
try:
geth_provider.start()
except Exception:
pass # Exceptions are fine here.

actual = spy.call_args[1]["balance"]
assert actual == amount
actual = process_factory_patch.call_args[1]["balance"]
assert actual == amount


@geth_process_test
@pytest.mark.parametrize("key", ("uri", "ws_uri"))
def test_start_from_ws_uri(process_factory_patch, project, geth_provider, key):
uri = "ws://localhost:5677"

with project.temp_config(node={"ethereum": {"local": {key: uri}}}):
try:
geth_provider.start()
except Exception:
pass # Exceptions are fine here.

actual = process_factory_patch.call_args[0][0] # First "arg"
assert actual == uri


@geth_process_test
def test_auto_mine(geth_provider):
assert geth_provider.auto_mine is True


@geth_process_test
def test_geth_dev_from_uri_http(data_folder):
geth_dev = GethDevProcess.from_uri("http://localhost:6799", data_folder)
kwargs = geth_dev.geth_kwargs
assert kwargs["rpc_addr"] == "localhost"
assert kwargs["rpc_port"] == "6799"
assert kwargs["ws_enabled"] is False
assert kwargs.get("ws_api") is None
assert kwargs.get("ws_addr") is None
assert kwargs.get("ws_port") is None


@geth_process_test
def test_geth_dev_from_uri_ws(data_folder):
geth_dev = GethDevProcess.from_uri("ws://localhost:6799", data_folder)
kwargs = geth_dev.geth_kwargs
assert kwargs.get("rpc_addr") is None
assert kwargs["ws_enabled"] is True
assert kwargs["ws_addr"] == "localhost"
assert kwargs["ws_port"] == "6799"


@geth_process_test
def test_geth_dev_from_uri_ipc(data_folder):
geth_dev = GethDevProcess.from_uri("path/to/geth.ipc", data_folder)
kwargs = geth_dev.geth_kwargs
assert kwargs["ipc_path"] == "path/to/geth.ipc"
assert kwargs.get("ws_api") is None
assert kwargs.get("ws_addr") is None
assert kwargs.get("rpc_addr") is None

0 comments on commit 0210854

Please sign in to comment.