Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Could not start a Geth process with only WS or IPC #2377

Merged
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@
"eth-typing",
"eth-utils",
"hexbytes",
"py-geth>=5.0.0-beta.2,<6",
"py-geth>=5.1.0,<6",
"trie>=3.0.1,<4", # Peer: stricter pin needed for uv support.
"web3[tester]>=6.17.2,<7",
# ** Dependencies maintained by ApeWorX **
Expand Down
1 change: 0 additions & 1 deletion src/ape_ethereum/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -1352,7 +1352,6 @@ def uri(self) -> str:

# Use value from config file
network_config: dict = (config or {}).get(self.network.name) or DEFAULT_SETTINGS

if "url" in network_config:
raise ConfigError("Unknown provider setting 'url'. Did you mean 'uri'?")
elif "http_uri" in network_config:
Expand Down
163 changes: 123 additions & 40 deletions src/ape_node/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from pathlib import Path
from subprocess import DEVNULL, PIPE, Popen
from typing import TYPE_CHECKING, Any, Optional, Union
from urllib.parse import urlparse

from eth_utils import add_0x_prefix, to_hex
from evmchains import get_random_rpc
Expand All @@ -13,7 +14,6 @@
from pydantic_settings import SettingsConfigDict
from requests.exceptions import ConnectionError
from web3.middleware import geth_poa_middleware as ExtraDataToPOAMiddleware
from yarl import URL

from ape.api.config import PluginConfig
from ape.api.providers import SubprocessProvider, TestProviderAPI
Expand Down Expand Up @@ -90,13 +90,17 @@ def create_genesis_data(alloc: Alloc, chain_id: int) -> "GenesisDataTypedDict":
class GethDevProcess(BaseGethProcess):
"""
A developer-configured geth that only exists until disconnected.
(Implementation detail of the local node provider).
"""

def __init__(
self,
data_dir: Path,
hostname: str = DEFAULT_HOSTNAME,
port: int = DEFAULT_PORT,
hostname: Optional[str] = None,
port: Optional[int] = None,
ipc_path: Optional[Path] = None,
ws_hostname: Optional[str] = None,
ws_port: Optional[str] = None,
mnemonic: str = DEFAULT_TEST_MNEMONIC,
number_of_accounts: int = DEFAULT_NUMBER_OF_TEST_ACCOUNTS,
chain_id: int = DEFAULT_TEST_CHAIN_ID,
Expand All @@ -111,23 +115,33 @@ def __init__(
raise NodeSoftwareNotInstalledError()

self._data_dir = data_dir
self._hostname = hostname
self._port = port
self.is_running = False
self._auto_disconnect = auto_disconnect

geth_kwargs = construct_test_chain_kwargs(
data_dir=self.data_dir,
geth_executable=executable,
rpc_addr=hostname,
rpc_port=f"{port}",
network_id=f"{chain_id}",
ws_enabled=False,
ws_addr=None,
ws_origins=None,
ws_port=None,
ws_api=None,
)
kwargs_ctor: dict = {
"data_dir": self.data_dir,
"geth_executable": executable,
"network_id": f"{chain_id}",
}
if hostname is not None:
kwargs_ctor["rpc_addr"] = hostname
if port is not None:
kwargs_ctor["rpc_port"] = f"{port}"
if ws_hostname:
kwargs_ctor["ws_enabled"] = True
kwargs_ctor["ws_addr"] = ws_hostname
if ws_port:
kwargs_ctor["ws_enabled"] = True
kwargs_ctor["ws_port"] = f"{ws_port}"
if ipc_path is not None:
kwargs_ctor["ipc_path"] = f"{ipc_path}"
if not kwargs_ctor.get("ws_enabled"):
kwargs_ctor["ws_api"] = None
kwargs_ctor["ws_enabled"] = False
kwargs_ctor["ws_addr"] = None
kwargs_ctor["ws_port"] = None

geth_kwargs = construct_test_chain_kwargs(**kwargs_ctor)

# Ensure a clean data-dir.
self._clean()
Expand All @@ -147,45 +161,99 @@ def __init__(

@classmethod
def from_uri(cls, uri: str, data_folder: Path, **kwargs):
parsed_uri = URL(uri)

if parsed_uri.host not in ("localhost", "127.0.0.1"):
raise ConnectionError(f"Unable to start Geth on non-local host {parsed_uri.host}.")

port = parsed_uri.port if parsed_uri.port is not None else DEFAULT_PORT
mnemonic = kwargs.get("mnemonic", DEFAULT_TEST_MNEMONIC)
number_of_accounts = kwargs.get("number_of_accounts", DEFAULT_NUMBER_OF_TEST_ACCOUNTS)
balance = kwargs.get("initial_balance", DEFAULT_TEST_ACCOUNT_BALANCE)
extra_accounts = [a.lower() for a in kwargs.get("extra_funded_accounts", [])]

return cls(
data_folder,
auto_disconnect=kwargs.get("auto_disconnect", True),
executable=kwargs.get("executable"),
extra_funded_accounts=extra_accounts,
hd_path=kwargs.get("hd_path", DEFAULT_TEST_HD_PATH),
hostname=parsed_uri.host,
initial_balance=balance,
mnemonic=mnemonic,
number_of_accounts=number_of_accounts,
port=port,
)
process_kwargs = {
"auto_disconnect": kwargs.get("auto_disconnect", True),
"executable": kwargs.get("executable"),
"extra_funded_accounts": extra_accounts,
"hd_path": kwargs.get("hd_path", DEFAULT_TEST_HD_PATH),
"initial_balance": balance,
"mnemonic": mnemonic,
"number_of_accounts": number_of_accounts,
}

parsed_uri = urlparse(uri)
if not parsed_uri.netloc:
path = Path(parsed_uri.path)
if path.suffix == ".ipc":
# Was given an IPC path.
process_kwargs["ipc_path"] = path

else:
raise ConnectionError(f"Unrecognized path type: '{path}'.")

elif hostname := parsed_uri.hostname:
if hostname not in ("localhost", "127.0.0.1"):
raise ConnectionError(
f"Unable to start Geth on non-local host {parsed_uri.hostname}."
)

if parsed_uri.scheme.startswith("ws"):
process_kwargs["ws_hostname"] = hostname
process_kwargs["ws_port"] = parsed_uri.port or DEFAULT_PORT
elif parsed_uri.scheme.startswith("http"):
process_kwargs["hostname"] = hostname or DEFAULT_HOSTNAME
process_kwargs["port"] = parsed_uri.port or DEFAULT_PORT
else:
raise ConnectionError(f"Unsupported scheme: '{parsed_uri.scheme}'.")

return cls(data_folder, **process_kwargs)

@property
def data_dir(self) -> str:
return f"{self._data_dir}"

@property
def _hostname(self) -> Optional[str]:
return self.geth_kwargs.get("rpc_addr")

@property
def _port(self) -> Optional[str]:
return self.geth_kwargs.get("rpc_port")

@property
def _ws_hostname(self) -> Optional[str]:
return self.geth_kwargs.get("ws_addr")

@property
def _ws_port(self) -> Optional[str]:
return self.geth_kwargs.get("ws_port")

def connect(self, timeout: int = 60):
home = str(Path.home())
ipc_path = self.ipc_path.replace(home, "$HOME")
logger.info(f"Starting geth (HTTP='{self._hostname}:{self._port}', IPC={ipc_path}).")
self._log_connection()
self.start()
self.wait_for_rpc(timeout=timeout)

# Register atexit handler to make sure disconnect is called for normal object lifecycle.
if self._auto_disconnect:
atexit.register(self.disconnect)

def _log_connection(self):
home = str(Path.home())
ipc_path = self.ipc_path.replace(home, "$HOME")

http_log = ""
if self._hostname:
http_log = f"HTTP={self._hostname}"
if port := self._port:
http_log = f"{http_log}:{port}"

ipc_log = f"IPC={ipc_path}"

ws_log = ""
if self._ws_hostname:
ws_log = f"WS={self._ws_hostname}"
if port := self._ws_port:
ws_log = f"{ws_log}:{port}"

connection_logs = ", ".join(x for x in (http_log, ipc_log, ws_log) if x)

logger.info(f"Starting geth ({connection_logs}).")

def start(self):
if self.is_running:
return
Expand Down Expand Up @@ -230,6 +298,21 @@ class EthereumNetworkConfig(PluginConfig):

model_config = SettingsConfigDict(extra="allow")

@field_validator("local", mode="before")
@classmethod
def _validate_local(cls, value):
value = value or {}
if not value:
return {**DEFAULT_SETTINGS.copy(), "chain_id": DEFAULT_TEST_CHAIN_ID}

if "chain_id" not in value:
value["chain_id"] = DEFAULT_TEST_CHAIN_ID
if "uri" not in value and "ipc_path" in value or "ws_uri" in value or "http_uri" in value:
# No need to add default HTTP URI if was given only IPC Path
return {**{k: v for k, v in DEFAULT_SETTINGS.items() if k != "uri"}, **value}

return {**DEFAULT_SETTINGS, **value}


class EthereumNodeConfig(PluginConfig):
"""
Expand Down Expand Up @@ -384,8 +467,8 @@ def _create_process(self) -> GethDevProcess:
extra_accounts = list({a.lower() for a in extra_accounts})
test_config["extra_funded_accounts"] = extra_accounts
test_config["initial_balance"] = self.test_config.balance

return GethDevProcess.from_uri(self.uri, self.data_dir, **test_config)
uri = self.ws_uri or self.uri
return GethDevProcess.from_uri(uri, self.data_dir, **test_config)

def disconnect(self):
# Must disconnect process first.
Expand Down
73 changes: 68 additions & 5 deletions tests/functional/geth/test_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,11 @@ def web3_factory(mocker):
return mocker.patch("ape_ethereum.provider._create_web3")


@pytest.fixture
def process_factory_patch(mocker):
return mocker.patch("ape_node.provider.GethDevProcess.from_uri")


@pytest.fixture
def tx_for_call(geth_contract):
return DynamicFeeTransaction.model_validate(
Expand Down Expand Up @@ -129,6 +134,7 @@ def test_uri_non_dev_and_not_configured(mocker, ethereum):
assert actual == expected


@geth_process_test
def test_uri_invalid(geth_provider, project, ethereum):
settings = geth_provider.provider_settings
geth_provider.provider_settings = {}
Expand Down Expand Up @@ -255,7 +261,18 @@ def test_connect_to_chain_that_started_poa(mock_web3, web3_factory, ethereum):


@geth_process_test
def test_connect_using_only_ipc_for_uri(project, networks, geth_provider):
def test_connect_using_only_ipc_for_uri_already_connected(project, networks, geth_provider):
"""
Shows we can remote-connect to a node that is already running when it exposes its IPC path.
"""
ipc_path = geth_provider.ipc_path
with project.temp_config(node={"ethereum": {"local": {"uri": f"{ipc_path}"}}}):
with networks.ethereum.local.use_provider("node") as node:
assert node.uri == f"{ipc_path}"


@geth_process_test
def test_connect_using_ipc(process_factory_patch, project, networks, geth_provider):
ipc_path = geth_provider.ipc_path
with project.temp_config(node={"ethereum": {"local": {"uri": f"{ipc_path}"}}}):
with networks.ethereum.local.use_provider("node") as node:
Expand Down Expand Up @@ -794,20 +811,66 @@ def test_trace_approach_config(project):


@geth_process_test
def test_start(mocker, convert, project, geth_provider):
def test_start(process_factory_patch, convert, project, geth_provider):
amount = convert("100_000 ETH", int)
spy = mocker.spy(GethDevProcess, "from_uri")

with project.temp_config(test={"balance": amount}):
try:
geth_provider.start()
except Exception:
pass # Exceptions are fine here.

actual = spy.call_args[1]["balance"]
assert actual == amount
actual = process_factory_patch.call_args[1]["balance"]
assert actual == amount


@geth_process_test
@pytest.mark.parametrize("key", ("uri", "ws_uri"))
def test_start_from_ws_uri(process_factory_patch, project, geth_provider, key):
uri = "ws://localhost:5677"

with project.temp_config(node={"ethereum": {"local": {key: uri}}}):
try:
geth_provider.start()
except Exception:
pass # Exceptions are fine here.

actual = process_factory_patch.call_args[0][0] # First "arg"
assert actual == uri


@geth_process_test
def test_auto_mine(geth_provider):
assert geth_provider.auto_mine is True


@geth_process_test
def test_geth_dev_from_uri_http(data_folder):
geth_dev = GethDevProcess.from_uri("http://localhost:6799", data_folder)
kwargs = geth_dev.geth_kwargs
assert kwargs["rpc_addr"] == "localhost"
assert kwargs["rpc_port"] == "6799"
assert kwargs["ws_enabled"] is False
assert kwargs.get("ws_api") is None
assert kwargs.get("ws_addr") is None
assert kwargs.get("ws_port") is None


@geth_process_test
def test_geth_dev_from_uri_ws(data_folder):
geth_dev = GethDevProcess.from_uri("ws://localhost:6799", data_folder)
kwargs = geth_dev.geth_kwargs
assert kwargs.get("rpc_addr") is None
assert kwargs["ws_enabled"] is True
assert kwargs["ws_addr"] == "localhost"
assert kwargs["ws_port"] == "6799"


@geth_process_test
def test_geth_dev_from_uri_ipc(data_folder):
geth_dev = GethDevProcess.from_uri("path/to/geth.ipc", data_folder)
kwargs = geth_dev.geth_kwargs
assert kwargs["ipc_path"] == "path/to/geth.ipc"
assert kwargs.get("ws_api") is None
assert kwargs.get("ws_addr") is None
assert kwargs.get("rpc_addr") is None
Loading