From 48aec83d76cf3472e13558bf87e3c8cfc570228c Mon Sep 17 00:00:00 2001 From: Matt Whitlock Date: Sat, 4 Nov 2023 19:12:49 -0400 Subject: [PATCH 01/12] wallet: remove a dead store in get_index_cache_and_increment --- src/jmclient/wallet.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/jmclient/wallet.py b/src/jmclient/wallet.py index 094a326d7..123686fb0 100644 --- a/src/jmclient/wallet.py +++ b/src/jmclient/wallet.py @@ -2233,7 +2233,6 @@ def _set_index_cache(self, mixdepth, address_type, index): self._index_cache[mixdepth][address_type] = index def get_index_cache_and_increment(self, mixdepth, address_type): - index = self._index_cache[mixdepth][address_type] cur_index = self._index_cache[mixdepth][address_type] self._set_index_cache(mixdepth, address_type, cur_index + 1) return cur_index From 8245271d7f8c4632427a3750f8f6f6025b486191 Mon Sep 17 00:00:00 2001 From: Matt Whitlock Date: Thu, 9 Nov 2023 18:09:27 -0500 Subject: [PATCH 02/12] wallet: avoid IndexError in _is_my_bip32_path --- src/jmclient/wallet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/jmclient/wallet.py b/src/jmclient/wallet.py index 123686fb0..caede524c 100644 --- a/src/jmclient/wallet.py +++ b/src/jmclient/wallet.py @@ -2208,7 +2208,7 @@ def _get_key_from_path(self, path): self._ENGINE def _is_my_bip32_path(self, path): - return path[0] == self._key_ident + return len(path) > 0 and path[0] == self._key_ident def is_standard_wallet_script(self, path): return self._is_my_bip32_path(path) From 574c29e899c8bb0d52e8446b51b3a4a3137958a6 Mon Sep 17 00:00:00 2001 From: Matt Whitlock Date: Fri, 10 Nov 2023 04:20:07 -0500 Subject: [PATCH 03/12] wallet: hoist get_script_from_path default impl into BaseWallet --- src/jmclient/wallet.py | 17 ++++------------- 1 file changed, 4 insertions(+), 13 deletions(-) diff --git a/src/jmclient/wallet.py b/src/jmclient/wallet.py index caede524c..728751fc1 100644 --- a/src/jmclient/wallet.py +++ b/src/jmclient/wallet.py @@ -921,7 +921,8 @@ def get_script_from_path(self, path): returns: script """ - raise NotImplementedError() + priv, engine = self._get_key_from_path(path) + return engine.key_to_script(priv) def get_script(self, mixdepth, address_type, index): path = self.get_path(mixdepth, address_type, index) @@ -1932,13 +1933,6 @@ def get_details(self, path): return super().get_details(path) return path[1], 'imported', path[2] - def get_script_from_path(self, path): - if not self._is_imported_path(path): - return super().get_script_from_path(path) - - priv, engine = self._get_key_from_path(path) - return engine.key_to_script(priv) - class BIP39WalletMixin(object): """ @@ -2116,7 +2110,7 @@ def _get_supported_address_types(cls): def get_script_from_path(self, path): if not self._is_my_bip32_path(path): - raise WalletError("unable to get script for unknown key path") + return super().get_script_from_path(path) md, address_type, index = self.get_details(path) @@ -2132,10 +2126,7 @@ def get_script_from_path(self, path): #concept of a "next address" cant be used return self.get_new_script_override_disable(md, address_type) - priv, engine = self._get_key_from_path(path) - script = engine.key_to_script(priv) - - return script + return super().get_script_from_path(path) def get_path(self, mixdepth=None, address_type=None, index=None): if mixdepth is not None: From 2c38a813fc1bd97061759305b51b69daebe01ca4 Mon Sep 17 00:00:00 2001 From: Matt Whitlock Date: Fri, 10 Nov 2023 04:12:50 -0500 Subject: [PATCH 04/12] wallet: delete redundant get_script and get_addr methods Their implementations were identical to those in the superclass. --- src/jmclient/wallet.py | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/src/jmclient/wallet.py b/src/jmclient/wallet.py index 728751fc1..06f99feff 100644 --- a/src/jmclient/wallet.py +++ b/src/jmclient/wallet.py @@ -2234,10 +2234,6 @@ def get_script_and_update_map(self, *args): self._script_map[script] = path return script - def get_script(self, mixdepth, address_type, index): - path = self.get_path(mixdepth, address_type, index) - return self.get_script_from_path(path) - @deprecated def get_key(self, mixdepth, address_type, index): path = self.get_path(mixdepth, address_type, index) @@ -2527,14 +2523,6 @@ def get_details(self, path): def _get_default_used_indices(self): return {x: [0, 0, 0, 0] for x in range(self.max_mixdepth + 1)} - def get_script(self, mixdepth, address_type, index): - path = self.get_path(mixdepth, address_type, index) - return self.get_script_from_path(path) - - def get_addr(self, mixdepth, address_type, index): - script = self.get_script(mixdepth, address_type, index) - return self.script_to_addr(script) - def add_burner_output(self, path, txhex, block_height, merkle_branch, block_index, write=True): """ From b58ac679cbca46ff7f3e7f8f63c722b516609434 Mon Sep 17 00:00:00 2001 From: Matt Whitlock Date: Fri, 10 Nov 2023 05:03:48 -0500 Subject: [PATCH 05/12] wallet: drop _get_addr_int_ext; replace with calls to get_new_addr --- src/jmclient/wallet.py | 13 ++----------- 1 file changed, 2 insertions(+), 11 deletions(-) diff --git a/src/jmclient/wallet.py b/src/jmclient/wallet.py index 06f99feff..4ae1405e0 100644 --- a/src/jmclient/wallet.py +++ b/src/jmclient/wallet.py @@ -543,29 +543,20 @@ def get_key_from_addr(self, addr): privkey = self._get_key_from_path(path)[0] return privkey - def _get_addr_int_ext(self, address_type, mixdepth): - if address_type == self.ADDRESS_TYPE_EXTERNAL: - script = self.get_external_script(mixdepth) - elif address_type == self.ADDRESS_TYPE_INTERNAL: - script = self.get_internal_script(mixdepth) - else: - assert 0 - return self.script_to_addr(script) - def get_external_addr(self, mixdepth): """ Return an address suitable for external distribution, including funding the wallet from other sources, or receiving payments or donations. JoinMarket will never generate these addresses for internal use. """ - return self._get_addr_int_ext(self.ADDRESS_TYPE_EXTERNAL, mixdepth) + return self.get_new_addr(mixdepth, self.ADDRESS_TYPE_EXTERNAL) def get_internal_addr(self, mixdepth): """ Return an address for internal usage, as change addresses and when participating in transactions initiated by other parties. """ - return self._get_addr_int_ext(self.ADDRESS_TYPE_INTERNAL, mixdepth) + return self.get_new_addr(mixdepth, self.ADDRESS_TYPE_INTERNAL) def get_external_script(self, mixdepth): return self.get_new_script(mixdepth, self.ADDRESS_TYPE_EXTERNAL) From fc1e00058b5f9058d58b5e7a87a6fff57ff4b18a Mon Sep 17 00:00:00 2001 From: Matt Whitlock Date: Sat, 4 Nov 2023 16:41:17 -0400 Subject: [PATCH 06/12] wallet_showutxos: use O(1) check for frozen instead of O(n) utxo_d = [] for k, v in disabled.items(): utxo_d.append(k) {'frozen': True if u in utxo_d else False} The above was inefficient. Replace with: {'frozen': u in disabled} Checking for existence of a key in a dict takes time proportional to O(1), whereas checking for existence of an element in a list takes time proportional to O(n). --- src/jmclient/wallet_utils.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/jmclient/wallet_utils.py b/src/jmclient/wallet_utils.py index 3e29be2c5..d59096b60 100644 --- a/src/jmclient/wallet_utils.py +++ b/src/jmclient/wallet_utils.py @@ -431,9 +431,6 @@ def wallet_showutxos(wallet_service, showprivkey): includeconfs=True) for md in utxos: (enabled, disabled) = get_utxos_enabled_disabled(wallet_service, md) - utxo_d = [] - for k, v in disabled.items(): - utxo_d.append(k) for u, av in utxos[md].items(): success, us = utxo_to_utxostr(u) assert success @@ -453,7 +450,7 @@ def wallet_showutxos(wallet_service, showprivkey): 'external': False, 'mixdepth': mixdepth, 'confirmations': av['confs'], - 'frozen': True if u in utxo_d else False} + 'frozen': u in disabled} if showprivkey: unsp[us]['privkey'] = wallet_service.get_wif_path(av['path']) if locktime: From 184d76f7f7ac81b0e566c5c5f2ea89d9f7da2258 Mon Sep 17 00:00:00 2001 From: Matt Whitlock Date: Sat, 4 Nov 2023 16:22:14 -0400 Subject: [PATCH 07/12] wallet: add get_{balance,utxos}_at_mixdepth methods Sometimes calling code is only interested in the balance or UTXOs at a single mixdepth. In these cases, it is wasteful to get the balance or UTXOs at all mixdepths, only to throw away the returned information about all but the single mixdepth of interest. Implement new methods in BaseWallet to get the balance or UTXOs at a single mixdepth. Also, correct an apparent oversight due to apparently misplaced indentation: the maxheight parameter of get_balance_by_mixdepth was ignored unless the include_disabled parameter was passed as False. It appears that the intention was for include_disabled and maxheight to be independent filters on the returned information. --- src/jmclient/wallet.py | 90 ++++++++++++++++++------------- test/jmclient/test_utxomanager.py | 24 ++++----- test/jmclient/test_wallet.py | 2 +- 3 files changed, 64 insertions(+), 52 deletions(-) diff --git a/src/jmclient/wallet.py b/src/jmclient/wallet.py index 4ae1405e0..cbaceb623 100644 --- a/src/jmclient/wallet.py +++ b/src/jmclient/wallet.py @@ -19,6 +19,7 @@ from decimal import Decimal from numbers import Integral from math import exp +from typing import Any, Dict, Optional, Tuple from .configure import jm_single @@ -280,32 +281,28 @@ def select_utxos(self, mixdepth, amount, utxo_filter=(), select_fn=None, 'value': utxos[s['utxo']][1]} for s in selected} - def get_balance_by_mixdepth(self, max_mixdepth=float('Inf'), - include_disabled=True, maxheight=None): - """ By default this returns a dict of aggregated bitcoin - balance per mixdepth: {0: N sats, 1: M sats, ...} for all - currently available mixdepths. - If max_mixdepth is set it will return balances only up - to that mixdepth. + def get_balance_at_mixdepth(self, mixdepth: int, + include_disabled: bool = True, + maxheight: Optional[int] = None) -> int: + """ By default this returns aggregated bitcoin balance at mixdepth. To get only enabled balance, set include_disabled=False. To get balances only with a certain number of confs, use maxheight. """ - balance_dict = collections.defaultdict(int) - for mixdepth, utxomap in self._utxo.items(): - if mixdepth > max_mixdepth: - continue - if not include_disabled: - utxomap = {k: v for k, v in utxomap.items( - ) if not self.is_disabled(*k)} - if maxheight is not None: - utxomap = {k: v for k, v in utxomap.items( - ) if v[2] <= maxheight} - value = sum(x[1] for x in utxomap.values()) - balance_dict[mixdepth] = value - return balance_dict - - def get_utxos_by_mixdepth(self): - return deepcopy(self._utxo) + utxomap = self._utxo.get(mixdepth) + if not utxomap: + return 0 + if not include_disabled: + utxomap = {k: v for k, v in utxomap.items( + ) if not self.is_disabled(*k)} + if maxheight is not None: + utxomap = {k: v for k, v in utxomap.items( + ) if v[2] <= maxheight} + return sum(x[1] for x in utxomap.values()) + + def get_utxos_at_mixdepth(self, mixdepth: int) -> \ + Dict[Tuple[bytes, int], Tuple[Tuple, int, int]]: + utxomap = self._utxo.get(mixdepth) + return deepcopy(utxomap) if utxomap else {} def __eq__(self, o): return self._utxo == o._utxo and \ @@ -836,10 +833,19 @@ def get_balance_by_mixdepth(self, verbose=True, confirmations, set maxheight to max acceptable blockheight. returns: {mixdepth: value} """ + balances = collections.defaultdict(int) + for md in range(self.mixdepth + 1): + balances[md] = self.get_balance_at_mixdepth(md, verbose=verbose, + include_disabled=include_disabled, maxheight=maxheight) + return balances + + def get_balance_at_mixdepth(self, mixdepth, + verbose: bool = True, + include_disabled: bool = False, + maxheight: Optional[int] = None) -> int: # TODO: verbose - return self._utxos.get_balance_by_mixdepth(max_mixdepth=self.mixdepth, - include_disabled=include_disabled, - maxheight=maxheight) + return self._utxos.get_balance_at_mixdepth(mixdepth, + include_disabled=include_disabled, maxheight=maxheight) def get_utxos_by_mixdepth(self, include_disabled=False, includeheight=False): """ @@ -850,25 +856,35 @@ def get_utxos_by_mixdepth(self, include_disabled=False, includeheight=False): {'script': bytes, 'path': tuple, 'value': int}}} (if `includeheight` is True, adds key 'height': int) """ - mix_utxos = self._utxos.get_utxos_by_mixdepth() - script_utxos = collections.defaultdict(dict) - for md, data in mix_utxos.items(): - if md > self.mixdepth: - continue + for md in range(self.mixdepth + 1): + script_utxos[md] = self.get_utxos_at_mixdepth(md, + include_disabled=include_disabled, includeheight=includeheight) + return script_utxos + + def get_utxos_at_mixdepth(self, mixdepth: int, + include_disabled: bool = False, + includeheight: bool = False) -> \ + Dict[Tuple[bytes, int], Dict[str, Any]]: + script_utxos = {} + if 0 <= mixdepth <= self.mixdepth: + data = self._utxos.get_utxos_at_mixdepth(mixdepth) for utxo, (path, value, height) in data.items(): if not include_disabled and self._utxos.is_disabled(*utxo): continue script = self.get_script_from_path(path) addr = self.get_address_from_path(path) label = self.get_address_label(addr) - script_utxos[md][utxo] = {'script': script, - 'path': path, - 'value': value, - 'address': addr, - 'label': label} + script_utxo = { + 'script': script, + 'path': path, + 'value': value, + 'address': addr, + 'label': label, + } if includeheight: - script_utxos[md][utxo]['height'] = height + script_utxo['height'] = height + script_utxos[utxo] = script_utxo return script_utxos diff --git a/test/jmclient/test_utxomanager.py b/test/jmclient/test_utxomanager.py index 2d3023f14..1bd97e1ca 100644 --- a/test/jmclient/test_utxomanager.py +++ b/test/jmclient/test_utxomanager.py @@ -56,14 +56,12 @@ def test_utxomanager_persist(setup_env_nodeps): assert not um.is_disabled(txid, index+2) um.disable_utxo(txid, index+2) - utxos = um.get_utxos_by_mixdepth() - assert len(utxos[mixdepth]) == 1 - assert len(utxos[mixdepth+1]) == 2 - assert len(utxos[mixdepth+2]) == 0 + assert len(um.get_utxos_at_mixdepth(mixdepth)) == 1 + assert len(um.get_utxos_at_mixdepth(mixdepth+1)) == 2 + assert len(um.get_utxos_at_mixdepth(mixdepth+2)) == 0 - balances = um.get_balance_by_mixdepth() - assert balances[mixdepth] == value - assert balances[mixdepth+1] == value * 2 + assert um.get_balance_at_mixdepth(mixdepth) == value + assert um.get_balance_at_mixdepth(mixdepth+1) == value * 2 um.remove_utxo(txid, index, mixdepth) assert um.have_utxo(txid, index) == False @@ -79,14 +77,12 @@ def test_utxomanager_persist(setup_env_nodeps): assert um.have_utxo(txid, index) == False assert um.have_utxo(txid, index+1) == mixdepth + 1 - utxos = um.get_utxos_by_mixdepth() - assert len(utxos[mixdepth]) == 0 - assert len(utxos[mixdepth+1]) == 1 + assert len(um.get_utxos_at_mixdepth(mixdepth)) == 0 + assert len(um.get_utxos_at_mixdepth(mixdepth+1)) == 1 - balances = um.get_balance_by_mixdepth() - assert balances[mixdepth] == 0 - assert balances[mixdepth+1] == value - assert balances[mixdepth+2] == 0 + assert um.get_balance_at_mixdepth(mixdepth) == 0 + assert um.get_balance_at_mixdepth(mixdepth+1) == value + assert um.get_balance_at_mixdepth(mixdepth+2) == 0 def test_utxomanager_select(setup_env_nodeps): diff --git a/test/jmclient/test_wallet.py b/test/jmclient/test_wallet.py index 86d5d8e6b..ab68e7235 100644 --- a/test/jmclient/test_wallet.py +++ b/test/jmclient/test_wallet.py @@ -477,7 +477,7 @@ def test_get_bbm(setup_wallet): wallet = get_populated_wallet(amount, num_tx) # disable a utxo and check we can correctly report # balance with the disabled flag off: - utxo_1 = list(wallet._utxos.get_utxos_by_mixdepth()[0].keys())[0] + utxo_1 = list(wallet._utxos.get_utxos_at_mixdepth(0).keys())[0] wallet.disable_utxo(*utxo_1) balances = wallet.get_balance_by_mixdepth(include_disabled=True) assert balances[0] == num_tx * amount From 77f0194a37075fd747fbdfad68fd3381d1494f49 Mon Sep 17 00:00:00 2001 From: Matt Whitlock Date: Sat, 4 Nov 2023 16:37:51 -0400 Subject: [PATCH 08/12] wallet_utils: use new get_utxos_at_mixdepth method Rather than evaluating wallet_service.get_utxos_by_mixdepth()[md], instead evaluate wallet_service.get_utxos_at_mixdepth(md). This way we're not computing a bunch of data that we'll immediately discard. --- src/jmclient/wallet_utils.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/jmclient/wallet_utils.py b/src/jmclient/wallet_utils.py index d59096b60..c1165ff0c 100644 --- a/src/jmclient/wallet_utils.py +++ b/src/jmclient/wallet_utils.py @@ -407,8 +407,8 @@ def get_imported_privkey_branch(wallet_service, m, showprivkey): addr = wallet_service.get_address_from_path(path) script = wallet_service.get_script_from_path(path) balance = 0.0 - for data in wallet_service.get_utxos_by_mixdepth( - include_disabled=True)[m].values(): + for data in wallet_service.get_utxos_at_mixdepth(m, + include_disabled=True).values(): if script == data['script']: balance += data['value'] status = ('used' if balance > 0.0 else 'empty') @@ -1276,8 +1276,8 @@ def output_utxos(utxos, status, start=0): def get_utxos_enabled_disabled(wallet_service, md): """ Returns dicts for enabled and disabled separately """ - utxos_enabled = wallet_service.get_utxos_by_mixdepth()[md] - utxos_all = wallet_service.get_utxos_by_mixdepth(include_disabled=True)[md] + utxos_enabled = wallet_service.get_utxos_at_mixdepth(md) + utxos_all = wallet_service.get_utxos_at_mixdepth(md, include_disabled=True) utxos_disabled_keyset = set(utxos_all).difference(set(utxos_enabled)) utxos_disabled = {} for u in utxos_disabled_keyset: From 64f18bce18fca78258b04da75b11a9fa6d807f4e Mon Sep 17 00:00:00 2001 From: Matt Whitlock Date: Sat, 4 Nov 2023 17:01:09 -0400 Subject: [PATCH 09/12] get_imported_privkey_branch: use O(m+n) algorithm instead of O(m*n) The algorithm in get_imported_privkey_branch was O(m*n): for each imported path, it was iterating over the entire set of UTXOs. Rewrite the algorithm to make one pass over the set of UTXOs up front to compute the balance of each script (O(m)) and then, separately, one pass over the set of imported paths to pluck out the balance for each path (O(n)). --- src/jmclient/wallet_utils.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/jmclient/wallet_utils.py b/src/jmclient/wallet_utils.py index c1165ff0c..786637fcf 100644 --- a/src/jmclient/wallet_utils.py +++ b/src/jmclient/wallet_utils.py @@ -7,7 +7,7 @@ from datetime import datetime, timedelta from optparse import OptionParser from numbers import Integral -from collections import Counter +from collections import Counter, defaultdict from itertools import islice, chain from jmclient import (get_network, WALLET_IMPLEMENTATIONS, Storage, podle, jm_single, WalletError, BaseWallet, VolatileStorage, @@ -403,15 +403,15 @@ def get_tx_info(txid, tx_cache=None): def get_imported_privkey_branch(wallet_service, m, showprivkey): entries = [] + balance_by_script = defaultdict(int) + for data in wallet_service.get_utxos_at_mixdepth(m, + include_disabled=True).values(): + balance_by_script[data['script']] += data['value'] for path in wallet_service.yield_imported_paths(m): addr = wallet_service.get_address_from_path(path) script = wallet_service.get_script_from_path(path) - balance = 0.0 - for data in wallet_service.get_utxos_at_mixdepth(m, - include_disabled=True).values(): - if script == data['script']: - balance += data['value'] - status = ('used' if balance > 0.0 else 'empty') + balance = balance_by_script.get(script, 0) + status = ('used' if balance else 'empty') if showprivkey: wip_privkey = wallet_service.get_wif_path(path) else: From 01ec2a41817c8e9c6f6af4bc6e6ad86a687e63ff Mon Sep 17 00:00:00 2001 From: Matt Whitlock Date: Sat, 4 Nov 2023 18:42:03 -0400 Subject: [PATCH 10/12] wallet: add _addr_map, paralleling _script_map Hoist _populate_script_map from BIP32Wallet into BaseWallet, rename it to _populate_maps, and have it populate the new _addr_map in addition to the existing _script_map. Have the constructor of each concrete wallet subclass pass to _populate_maps the paths it contributes. Additionally, do not implement yield_known_paths by iterating over _script_map, but rather have each wallet subclass contribute its own paths to the generator returned by yield_known_paths. --- src/jmclient/wallet.py | 117 ++++++++++++++++++----------------- test/jmclient/test_wallet.py | 12 ++-- 2 files changed, 65 insertions(+), 64 deletions(-) diff --git a/src/jmclient/wallet.py b/src/jmclient/wallet.py index cbaceb623..27c1d554b 100644 --- a/src/jmclient/wallet.py +++ b/src/jmclient/wallet.py @@ -382,6 +382,8 @@ def __init__(self, storage, gap_limit=6, merge_algorithm_name=None, # {script: path}, should always hold mappings for all "known" keys self._script_map = {} + # {address: path}, should always hold mappings for all "known" keys + self._addr_map = {} self._load_storage() @@ -535,8 +537,7 @@ def get_key_from_addr(self, addr): """ There should be no reason for code outside the wallet to need a privkey. """ - script = self._ENGINE.address_to_script(addr) - path = self.script_to_path(script) + path = self.addr_to_path(addr) privkey = self._get_key_from_path(path)[0] return privkey @@ -1046,8 +1047,8 @@ def is_known_addr(self, addr): returns: bool """ - script = self.addr_to_script(addr) - return script in self._script_map + assert isinstance(addr, str) + return addr in self._addr_map def is_known_script(self, script): """ @@ -1062,8 +1063,8 @@ def is_known_script(self, script): return script in self._script_map def get_addr_mixdepth(self, addr): - script = self.addr_to_script(addr) - return self.get_script_mixdepth(script) + path = self.addr_to_path(addr) + return self._get_mixdepth_from_path(path) def get_script_mixdepth(self, script): path = self.script_to_path(script) @@ -1076,16 +1077,27 @@ def yield_known_paths(self): returns: path generator """ - for s in self._script_map.values(): - yield s + for md in range(self.max_mixdepth + 1): + for path in self.yield_imported_paths(md): + yield path + + def _populate_maps(self, paths): + for path in paths: + script = self.get_script_from_path(path) + self._script_map[script] = path + self._addr_map[self.script_to_addr(script)] = path def addr_to_path(self, addr): - script = self.addr_to_script(addr) - return self.script_to_path(script) + assert isinstance(addr, str) + path = self._addr_map.get(addr) + assert path is not None + return path def script_to_path(self, script): - assert script in self._script_map - return self._script_map[script] + assert isinstance(script, bytes) + path = self._script_map.get(script) + assert path is not None + return path def set_next_index(self, mixdepth, address_type, index, force=False): """ @@ -1775,12 +1787,7 @@ def _load_storage(self): for md, keys in self._storage.data[self._IMPORTED_STORAGE_KEY].items(): md = int(md) self._imported[md] = keys - for index, (key, key_type) in enumerate(keys): - if not key: - # imported key was removed - continue - assert key_type in self._ENGINES - self._cache_imported_key(md, key, key_type, index) + self._populate_maps(self.yield_imported_paths(md)) def save(self): import_data = {} @@ -1849,8 +1856,8 @@ def remove_imported_key(self, script=None, address=None, path=None): raise Exception("Only one of script|address|path may be given.") if address: - script = self.addr_to_script(address) - if script: + path = self.addr_to_path(address) + elif script: path = self.script_to_path(script) if not path: @@ -1863,18 +1870,18 @@ def remove_imported_key(self, script=None, address=None, path=None): if not script: script = self.get_script_from_path(path) + if not address: + address = self.script_to_addr(script) # we need to retain indices self._imported[path[1]][path[2]] = (b'', -1) del self._script_map[script] + del self._addr_map[address] def _cache_imported_key(self, mixdepth, privkey, key_type, index): - engine = self._ENGINES[key_type] path = (self._IMPORTED_ROOT_PATH, mixdepth, index) - - self._script_map[engine.key_to_script(privkey)] = path - + self._populate_maps((path,)) return path def _get_mixdepth_from_path(self, path): @@ -2010,6 +2017,7 @@ class BIP32Wallet(BaseWallet): def __init__(self, storage, **kwargs): self._entropy = None + self._key_ident = None # {mixdepth: {type: index}} with type being 0/1 corresponding # to external/internal addresses self._index_cache = None @@ -2028,7 +2036,7 @@ def __init__(self, storage, **kwargs): # used to verify paths for sanity checking and for wallet id creation self._key_ident = b'' # otherwise get_bip32_* won't work self._key_ident = self._get_key_ident() - self._populate_script_map() + self._populate_maps(self.yield_known_bip32_paths()) self.disable_new_scripts = False @classmethod @@ -2074,13 +2082,14 @@ def _get_key_ident(self): self.get_bip32_priv_export(0, self.BIP32_EXT_ID).encode('ascii')).digest())\ .digest()[:3] - def _populate_script_map(self): + def yield_known_paths(self): + return chain(super().yield_known_paths(), self.yield_known_bip32_paths()) + + def yield_known_bip32_paths(self): for md in self._index_cache: for address_type in (self.BIP32_EXT_ID, self.BIP32_INT_ID): for i in range(self._index_cache[md][address_type]): - path = self.get_path(md, address_type, i) - script = self.get_script_from_path(path) - self._script_map[script] = path + yield self.get_path(md, address_type, i) def save(self): for md, data in self._index_cache.items(): @@ -2115,10 +2124,7 @@ def _derive_bip32_master_key(cls, seed): def _get_supported_address_types(cls): return (cls.BIP32_EXT_ID, cls.BIP32_INT_ID) - def get_script_from_path(self, path): - if not self._is_my_bip32_path(path): - return super().get_script_from_path(path) - + def _check_path(self, path): md, address_type, index = self.get_details(path) if not 0 <= md <= self.max_mixdepth: @@ -2131,10 +2137,19 @@ def get_script_from_path(self, path): and address_type != FidelityBondMixin.BIP32_TIMELOCK_ID: #special case for timelocked addresses because for them the #concept of a "next address" cant be used - return self.get_new_script_override_disable(md, address_type) + self._set_index_cache(md, address_type, current_index + 1) + self._populate_maps((path,)) + def get_script_from_path(self, path): + if self._is_my_bip32_path(path): + self._check_path(path) return super().get_script_from_path(path) + def get_address_from_path(self, path): + if self._is_my_bip32_path(path): + self._check_path(path) + return super().get_address_from_path(path) + def get_path(self, mixdepth=None, address_type=None, index=None): if mixdepth is not None: assert isinstance(mixdepth, Integral) @@ -2215,13 +2230,8 @@ def get_new_script(self, mixdepth, address_type): if self.disable_new_scripts: raise RuntimeError("Obtaining new wallet addresses " + "disabled, due to nohistory mode") - return self.get_new_script_override_disable(mixdepth, address_type) - - def get_new_script_override_disable(self, mixdepth, address_type): - # This is called by get_script_from_path and calls back there. We need to - # ensure all conditions match to avoid endless recursion. - index = self.get_index_cache_and_increment(mixdepth, address_type) - return self.get_script_and_update_map(mixdepth, address_type, index) + index = self._index_cache[mixdepth][address_type] + return self.get_script(mixdepth, address_type, index) def _set_index_cache(self, mixdepth, address_type, index): """ Ensures that any update to index_cache dict only applies @@ -2230,17 +2240,6 @@ def _set_index_cache(self, mixdepth, address_type, index): assert address_type in self._get_supported_address_types() self._index_cache[mixdepth][address_type] = index - def get_index_cache_and_increment(self, mixdepth, address_type): - cur_index = self._index_cache[mixdepth][address_type] - self._set_index_cache(mixdepth, address_type, cur_index + 1) - return cur_index - - def get_script_and_update_map(self, *args): - path = self.get_path(*args) - script = self.get_script_from_path(path) - self._script_map[script] = path - return script - @deprecated def get_key(self, mixdepth, address_type, index): path = self.get_path(mixdepth, address_type, index) @@ -2385,6 +2384,10 @@ class FidelityBondMixin(object): _BIP32_PUBKEY_PREFIX = "fbonds-mpk-" + def __init__(self, storage, **kwargs): + super().__init__(storage, **kwargs) + self._populate_maps(self.yield_fidelity_bond_paths()) + @classmethod def _time_number_to_timestamp(cls, timenumber): """ @@ -2444,14 +2447,14 @@ def get_xpub_from_fidelity_bond_master_pub_key(cls, mpk): else: return False - def _populate_script_map(self): - super()._populate_script_map() + def yield_known_paths(self): + return chain(super().yield_known_paths(), self.yield_fidelity_bond_paths()) + + def yield_fidelity_bond_paths(self): md = self.FIDELITY_BOND_MIXDEPTH address_type = self.BIP32_TIMELOCK_ID for timenumber in range(self.TIMENUMBER_COUNT): - path = self.get_path(md, address_type, timenumber) - script = self.get_script_from_path(path) - self._script_map[script] = path + yield self.get_path(md, address_type, timenumber) def add_utxo(self, txid, index, script, value, height=None): super().add_utxo(txid, index, script, value, height) diff --git a/test/jmclient/test_wallet.py b/test/jmclient/test_wallet.py index ab68e7235..45b23fa8e 100644 --- a/test/jmclient/test_wallet.py +++ b/test/jmclient/test_wallet.py @@ -17,7 +17,6 @@ wallet_gettimelockaddress, UnknownAddressForLabel from test_blockchaininterface import sync_test_wallet from freezegun import freeze_time -from bitcointx.wallet import CCoinAddressError pytestmark = pytest.mark.usefixtures("setup_regtest_bitcoind") @@ -264,9 +263,6 @@ def test_bip32_timelocked_addresses(setup_wallet, timenumber, address, wif): mixdepth = FidelityBondMixin.FIDELITY_BOND_MIXDEPTH address_type = FidelityBondMixin.BIP32_TIMELOCK_ID - #wallet needs to know about the script beforehand - wallet.get_script_and_update_map(mixdepth, address_type, timenumber) - assert address == wallet.get_addr(mixdepth, address_type, timenumber) assert wif == wallet.get_wif_path(wallet.get_path(mixdepth, address_type, timenumber)) @@ -287,7 +283,7 @@ def test_gettimelockaddress_method(setup_wallet, timenumber, locktime_string): m = FidelityBondMixin.FIDELITY_BOND_MIXDEPTH address_type = FidelityBondMixin.BIP32_TIMELOCK_ID - script = wallet.get_script_and_update_map(m, address_type, timenumber) + script = wallet.get_script(m, address_type, timenumber) addr = wallet.script_to_addr(script) addr_from_method = wallet_gettimelockaddress(wallet, locktime_string) @@ -456,7 +452,7 @@ def test_timelocked_output_signing(setup_wallet): wallet = SegwitWalletFidelityBonds(storage) timenumber = 0 - script = wallet.get_script_and_update_map( + script = wallet.get_script( FidelityBondMixin.FIDELITY_BOND_MIXDEPTH, FidelityBondMixin.BIP32_TIMELOCK_ID, timenumber) utxo = fund_wallet_addr(wallet, wallet.script_to_addr(script)) @@ -610,7 +606,9 @@ def test_address_labels(setup_wallet): wallet.get_address_label("2MzY5yyonUY7zpHspg7jB7WQs1uJxKafQe4") wallet.set_address_label("2MzY5yyonUY7zpHspg7jB7WQs1uJxKafQe4", "test") - with pytest.raises(CCoinAddressError): + # we no longer decode addresses just to see if we know about them, + # so we won't get a CCoinAddressError for invalid addresses + #with pytest.raises(CCoinAddressError): wallet.get_address_label("badaddress") wallet.set_address_label("badaddress", "test") From 5bc7eb4b8e3057ed3d60537c518f6f8918f4431d Mon Sep 17 00:00:00 2001 From: Matt Whitlock Date: Sat, 4 Nov 2023 18:56:46 -0400 Subject: [PATCH 11/12] wallet: add persistent cache, mapping path->(priv, pub, script, addr) Deriving private keys from BIP32 paths, public keys from private keys, scripts from public keys, and addresses from scripts are some of the most CPU-intensive tasks the wallet performs. Once the wallet inevitably accumulates thousands of used paths, startup times become painful due to needing to re-derive these data items for every used path in the wallet upon every startup. Introduce a persistent cache to avoid the need to re-derive these items every time the wallet is opened. Introduce _get_keypair_from_path and _get_pubkey_from_path methods to allow cached public keys to be used rather than always deriving them on the fly. Change many code paths that were calling CPU-intensive methods of BTCEngine so that instead they call _get_key_from_path, _get_keypair_from_path, _get_pubkey_from_path, get_script_from_path, and/or get_address_from_path, all of which can take advantage of the new cache. --- src/jmclient/wallet.py | 194 ++++++++++++++++++++++++++++++------ test/jmclient/test_taker.py | 9 +- 2 files changed, 171 insertions(+), 32 deletions(-) diff --git a/src/jmclient/wallet.py b/src/jmclient/wallet.py index 27c1d554b..a580f96fa 100644 --- a/src/jmclient/wallet.py +++ b/src/jmclient/wallet.py @@ -374,6 +374,7 @@ def __init__(self, storage, gap_limit=6, merge_algorithm_name=None, self._storage = storage self._utxos = None self._addr_labels = None + self._cache = None # highest mixdepth ever used in wallet, important for synching self.max_mixdepth = None # effective maximum mixdepth to be used by joinmarket @@ -388,6 +389,7 @@ def __init__(self, storage, gap_limit=6, merge_algorithm_name=None, self._load_storage() assert self._utxos is not None + assert self._cache is not None assert self.max_mixdepth is not None assert self.max_mixdepth >= 0 assert self.network in ('mainnet', 'testnet', 'signet') @@ -424,6 +426,7 @@ def _load_storage(self): self.network = self._storage.data[b'network'].decode('ascii') self._utxos = UTXOManager(self._storage, self.merge_algorithm) self._addr_labels = AddressLabelsManager(self._storage) + self._cache = self._storage.data.setdefault(b'cache', {}) def get_storage_location(self): """ Return the location of the @@ -564,21 +567,31 @@ def get_internal_script(self, mixdepth): @classmethod def addr_to_script(cls, addr): + """ + Try not to call this slow method. Instead, call addr_to_path, + followed by get_script_from_path, as those are cached. + """ return cls._ENGINE.address_to_script(addr) @classmethod def pubkey_to_script(cls, pubkey): + """ + Try not to call this slow method. Instead, call + get_script_from_path if possible, as that is cached. + """ return cls._ENGINE.pubkey_to_script(pubkey) @classmethod def pubkey_to_addr(cls, pubkey): + """ + Try not to call this slow method. Instead, call + get_address_from_path if possible, as that is cached. + """ return cls._ENGINE.pubkey_to_address(pubkey) def script_to_addr(self, script): - assert self.is_known_script(script) path = self.script_to_path(script) - engine = self._get_key_from_path(path)[1] - return engine.script_to_address(script) + return self.get_address_from_path(path) def get_script_code(self, script): """ @@ -589,8 +602,7 @@ def get_script_code(self, script): For non-segwit wallets, raises EngineError. """ path = self.script_to_path(script) - priv, engine = self._get_key_from_path(path) - pub = engine.privkey_to_pubkey(priv) + pub, engine = self._get_pubkey_from_path(path) return engine.pubkey_to_script_code(pub) @classmethod @@ -606,12 +618,20 @@ def get_key(self, mixdepth, address_type, index): raise NotImplementedError() def get_addr(self, mixdepth, address_type, index): - script = self.get_script(mixdepth, address_type, index) - return self.script_to_addr(script) + path = self.get_path(mixdepth, address_type, index) + return self.get_address_from_path(path) def get_address_from_path(self, path): - script = self.get_script_from_path(path) - return self.script_to_addr(script) + cache = self._get_cache_for_path(path) + addr = cache.get(b'A') + if addr is None: + engine = self._get_pubkey_from_path(path)[1] + script = self.get_script_from_path(path) + addr = engine.script_to_address(script) + cache[b'A'] = addr.encode('ascii') + else: + addr = addr.decode('ascii') + return addr def get_new_addr(self, mixdepth, address_type): """ @@ -929,8 +949,13 @@ def get_script_from_path(self, path): returns: script """ - priv, engine = self._get_key_from_path(path) - return engine.key_to_script(priv) + cache = self._get_cache_for_path(path) + script = cache.get(b'S') + if script is None: + pubkey, engine = self._get_pubkey_from_path(path) + script = engine.pubkey_to_script(pubkey) + cache[b'S'] = script + return script def get_script(self, mixdepth, address_type, index): path = self.get_path(mixdepth, address_type, index) @@ -939,6 +964,44 @@ def get_script(self, mixdepth, address_type, index): def _get_key_from_path(self, path): raise NotImplementedError() + def _get_keypair_from_path(self, path): + privkey, engine = self._get_key_from_path(path) + cache = self._get_cache_for_path(path) + pubkey = cache.get(b'P') + if pubkey is None: + pubkey = engine.privkey_to_pubkey(privkey) + cache[b'P'] = pubkey + return privkey, pubkey, engine + + def _get_pubkey_from_path(self, path): + privkey, pubkey, engine = self._get_keypair_from_path(path) + return pubkey, engine + + def _get_cache_keys_for_path(self, path): + return path[:1] + tuple(map(_int_to_bytestr, path[1:])) + + def _get_cache_for_path(self, path): + assert len(path) > 0 + cache = self._cache + for k in self._get_cache_keys_for_path(path): + cache = cache.setdefault(k, {}) + return cache + + def _delete_cache_for_path(self, path) -> bool: + assert len(path) > 0 + def recurse(cache, itr): + k = next(itr, None) + if k is None: + cache.clear() + else: + child = cache.get(k) + if not child or not recurse(child, itr): + return False + if not child: + del cache[k] + return True + return recurse(self._cache, iter(self._get_cache_keys_for_path(path))) + def get_path_repr(self, path): """ Get a human-readable representation of the wallet path. @@ -993,7 +1056,7 @@ def sign_message(self, message, path): signature as base64-encoded string """ priv, engine = self._get_key_from_path(path) - addr = engine.privkey_to_address(priv) + addr = self.get_address_from_path(path) return addr, engine.sign_message(priv, message) def get_wallet_name(self): @@ -1083,9 +1146,8 @@ def yield_known_paths(self): def _populate_maps(self, paths): for path in paths: - script = self.get_script_from_path(path) - self._script_map[script] = path - self._addr_map[self.script_to_addr(script)] = path + self._script_map[self.get_script_from_path(path)] = path + self._addr_map[self.get_address_from_path(path)] = path def addr_to_path(self, addr): assert isinstance(addr, str) @@ -1399,9 +1461,8 @@ def create_psbt_from_tx(self, tx, spent_outs=None, force_witness_utxo=True): # this happens when an input is provided but it's not in # this wallet; in this case, we cannot set the redeem script. continue - privkey, _ = self._get_key_from_path(path) - txinput.redeem_script = btc.pubkey_to_p2wpkh_script( - btc.privkey_to_pubkey(privkey)) + pubkey = self._get_pubkey_from_path(path)[0] + txinput.redeem_script = btc.pubkey_to_p2wpkh_script(pubkey) return new_psbt def sign_psbt(self, in_psbt, with_sign_result=False): @@ -1471,9 +1532,8 @@ def sign_psbt(self, in_psbt, with_sign_result=False): # this happens when an input is provided but it's not in # this wallet; in this case, we cannot set the redeem script. continue - privkey, _ = self._get_key_from_path(path) - txinput.redeem_script = btc.pubkey_to_p2wpkh_script( - btc.privkey_to_pubkey(privkey)) + pubkey = self._get_pubkey_from_path(path)[0] + txinput.redeem_script = btc.pubkey_to_p2wpkh_script(pubkey) # no else branch; any other form of scriptPubKey will just be # ignored. try: @@ -1871,13 +1931,14 @@ def remove_imported_key(self, script=None, address=None, path=None): if not script: script = self.get_script_from_path(path) if not address: - address = self.script_to_addr(script) + address = self.get_address_from_path(path) # we need to retain indices self._imported[path[1]][path[2]] = (b'', -1) del self._script_map[script] del self._addr_map[address] + self._delete_cache_for_path(path) def _cache_imported_key(self, mixdepth, privkey, key_type, index): path = (self._IMPORTED_ROOT_PATH, mixdepth, index) @@ -1916,7 +1977,7 @@ def _is_imported_path(cls, path): def is_standard_wallet_script(self, path): if self._is_imported_path(path): - engine = self._get_key_from_path(path)[1] + engine = self._get_pubkey_from_path(path)[1] return engine == self._ENGINE return super().is_standard_wallet_script(path) @@ -2164,7 +2225,6 @@ def get_path(self, mixdepth=None, address_type=None, index=None): assert isinstance(index, Integral) if address_type is None: raise Exception("address_type must be set if index is set") - assert index <= self._index_cache[mixdepth][address_type] assert index < self.BIP32_MAX_PATH_LEVEL return tuple(chain(self._get_bip32_export_path(mixdepth, address_type), (index,))) @@ -2216,9 +2276,32 @@ def _get_mixdepth_from_path(self, path): def _get_key_from_path(self, path): if not self._is_my_bip32_path(path): raise WalletError("Invalid path, unknown root: {}".format(path)) - - return self._ENGINE.derive_bip32_privkey(self._master_key, path), \ - self._ENGINE + cache = self._get_cache_for_path(path) + privkey = cache.get(b'p') + if privkey is None: + privkey = self._ENGINE.derive_bip32_privkey(self._master_key, path) + cache[b'p'] = privkey + return privkey, self._ENGINE + + def _get_keypair_from_path(self, path): + if not self._is_my_bip32_path(path): + return super()._get_keypair_from_path(path) + cache = self._get_cache_for_path(path) + privkey = cache.get(b'p') + if privkey is None: + privkey = self._ENGINE.derive_bip32_privkey(self._master_key, path) + cache[b'p'] = privkey + pubkey = cache.get(b'P') + if pubkey is None: + pubkey = self._ENGINE.privkey_to_pubkey(privkey) + cache[b'P'] = pubkey + return privkey, pubkey, self._ENGINE + + def _get_cache_keys_for_path(self, path): + if not self._is_my_bip32_path(path): + return super()._get_cache_keys_for_path(path) + return path[:1] + tuple([self._path_level_to_repr(lvl).encode('ascii') + for lvl in path[1:]]) def _is_my_bip32_path(self, path): return len(path) > 0 and path[0] == self._key_ident @@ -2431,8 +2514,7 @@ def is_timelocked_path(cls, path): def _get_key_ident(self): first_path = self.get_path(0, BIP32Wallet.BIP32_EXT_ID) - priv, engine = self._get_key_from_path(first_path) - pub = engine.privkey_to_pubkey(priv) + pub = self._get_pubkey_from_path(first_path)[0] return sha256(sha256(pub).digest()).digest()[:3] def is_standard_wallet_script(self, path): @@ -2483,11 +2565,37 @@ def _get_key_from_path(self, path): key_path = path[:-1] locktime = path[-1] engine = self._TIMELOCK_ENGINE - privkey = engine.derive_bip32_privkey(self._master_key, key_path) + cache = super()._get_cache_for_path(key_path) + privkey = cache.get(b'p') + if privkey is None: + privkey = engine.derive_bip32_privkey(self._master_key, key_path) + cache[b'p'] = privkey return (privkey, locktime), engine else: return super()._get_key_from_path(path) + def _get_keypair_from_path(self, path): + if not self.is_timelocked_path(path): + return super()._get_keypair_from_path(path) + key_path = path[:-1] + locktime = path[-1] + engine = self._TIMELOCK_ENGINE + cache = super()._get_cache_for_path(key_path) + privkey = cache.get(b'p') + if privkey is None: + privkey = engine.derive_bip32_privkey(self._master_key, key_path) + cache[b'p'] = privkey + pubkey = cache.get(b'P') + if pubkey is None: + pubkey = engine.privkey_to_pubkey(privkey) + cache[b'P'] = pubkey + return (privkey, locktime), (pubkey, locktime), engine + + def _get_cache_for_path(self, path): + if self.is_timelocked_path(path): + path = path[:-1] + return super()._get_cache_for_path(path) + def get_path(self, mixdepth=None, address_type=None, index=None): if address_type == None or address_type in (self.BIP32_EXT_ID, self.BIP32_INT_ID, self.BIP32_BURN_ID) or index == None: @@ -2632,6 +2740,32 @@ def _get_bip32_export_path(self, mixdepth=None, address_type=None): path = super()._get_bip32_export_path(mixdepth, address_type) return path + def _get_key_from_path(self, path): + raise WalletError("Cannot get a private key from a watch-only wallet") + + def _get_keypair_from_path(self, path): + raise WalletError("Cannot get a private key from a watch-only wallet") + + def _get_pubkey_from_path(self, path): + if not self._is_my_bip32_path(path): + return super()._get_pubkey_from_path(path) + if self.is_timelocked_path(path): + key_path = path[:-1] + locktime = path[-1] + cache = self._get_cache_for_path(key_path) + pubkey = cache.get(b'P') + if pubkey is None: + pubkey = self._TIMELOCK_ENGINE.derive_bip32_privkey( + self._master_key, key_path) + cache[b'P'] = pubkey + return (pubkey, locktime), self._TIMELOCK_ENGINE + cache = self._get_cache_for_path(path) + pubkey = cache.get(b'P') + if pubkey is None: + pubkey = self._ENGINE.derive_bip32_privkey(self._master_key, path) + cache[b'P'] = pubkey + return pubkey, self._ENGINE + WALLET_IMPLEMENTATIONS = { LegacyWallet.TYPE: LegacyWallet, diff --git a/test/jmclient/test_taker.py b/test/jmclient/test_taker.py index 8d20a4331..da902f396 100644 --- a/test/jmclient/test_taker.py +++ b/test/jmclient/test_taker.py @@ -121,6 +121,11 @@ def get_txtype(self): """ return 'p2wpkh' + def _get_key_from_path(self, path): + if path[0] == b'dummy': + return struct.pack(b'B', path[2] + 1)*32 + b'\x01', self._ENGINE + raise NotImplementedError() + def get_key_from_addr(self, addr): """usable addresses: privkey all 1s, 2s, 3s, ... :""" privs = [x*32 + b"\x01" for x in [struct.pack(b'B', y) for y in range(1,6)]] @@ -139,8 +144,8 @@ def get_key_from_addr(self, addr): return p raise ValueError("No such keypair") - def _is_my_bip32_path(self, path): - return True + def get_path_repr(self, path): + return '/'.join(map(str, path)) def is_standard_wallet_script(self, path): if path[0] == "nonstandard_path": From c3c10f1615631c3cca4d3c6be88c9c875e961d37 Mon Sep 17 00:00:00 2001 From: Matt Whitlock Date: Tue, 14 Nov 2023 22:08:29 -0500 Subject: [PATCH 12/12] wallet: implement optional cache validation Add a validate_cache parameter to the five principal caching methods: - _get_key_from_path - _get_keypair_from_path - _get_pubkey_from_path - get_script_from_path - get_address_from_path and to the five convenience methods that wrap the above: - get_script - get_addr - script_to_addr - get_new_script - get_new_addr The value of this new parameter defaults to False in all but the last two methods, where we are willing to sacrifice speed for the sake of extra confidence in the correctness of *new* scripts and addresses to be used for new deposits and new transactions. --- src/jmclient/wallet.py | 215 ++++++++++++++++++++++++------------ test/jmclient/test_taker.py | 9 +- 2 files changed, 149 insertions(+), 75 deletions(-) diff --git a/src/jmclient/wallet.py b/src/jmclient/wallet.py index a580f96fa..b417f60d9 100644 --- a/src/jmclient/wallet.py +++ b/src/jmclient/wallet.py @@ -589,9 +589,11 @@ def pubkey_to_addr(cls, pubkey): """ return cls._ENGINE.pubkey_to_address(pubkey) - def script_to_addr(self, script): + def script_to_addr(self, script, + validate_cache: bool = False): path = self.script_to_path(script) - return self.get_address_from_path(path) + return self.get_address_from_path(path, + validate_cache=validate_cache) def get_script_code(self, script): """ @@ -617,30 +619,42 @@ def pubkey_has_script(cls, pubkey, script): def get_key(self, mixdepth, address_type, index): raise NotImplementedError() - def get_addr(self, mixdepth, address_type, index): + def get_addr(self, mixdepth, address_type, index, + validate_cache: bool = False): path = self.get_path(mixdepth, address_type, index) - return self.get_address_from_path(path) + return self.get_address_from_path(path, + validate_cache=validate_cache) - def get_address_from_path(self, path): + def get_address_from_path(self, path, + validate_cache: bool = False): cache = self._get_cache_for_path(path) addr = cache.get(b'A') - if addr is None: - engine = self._get_pubkey_from_path(path)[1] - script = self.get_script_from_path(path) - addr = engine.script_to_address(script) - cache[b'A'] = addr.encode('ascii') - else: + if addr is not None: addr = addr.decode('ascii') + if addr is None or validate_cache: + engine = self._get_pubkey_from_path(path)[1] + script = self.get_script_from_path(path, + validate_cache=validate_cache) + new_addr = engine.script_to_address(script) + if addr is None: + addr = new_addr + cache[b'A'] = addr.encode('ascii') + elif addr != new_addr: + raise WalletError("Wallet cache validation failed") return addr - def get_new_addr(self, mixdepth, address_type): + def get_new_addr(self, mixdepth, address_type, + validate_cache: bool = True): """ use get_external_addr/get_internal_addr """ - script = self.get_new_script(mixdepth, address_type) - return self.script_to_addr(script) + script = self.get_new_script(mixdepth, address_type, + validate_cache=validate_cache) + return self.script_to_addr(script, + validate_cache=validate_cache) - def get_new_script(self, mixdepth, address_type): + def get_new_script(self, mixdepth, address_type, + validate_cache: bool = True): raise NotImplementedError() def get_wif(self, mixdepth, address_type, index): @@ -938,7 +952,8 @@ def _get_merge_algorithm(cls, algorithm_name=None): def _get_mixdepth_from_path(self, path): raise NotImplementedError() - def get_script_from_path(self, path): + def get_script_from_path(self, path, + validate_cache: bool = False): """ internal note: This is the final sink for all operations that somehow need to derive a script. If anything goes wrong when deriving a @@ -951,30 +966,43 @@ def get_script_from_path(self, path): """ cache = self._get_cache_for_path(path) script = cache.get(b'S') - if script is None: - pubkey, engine = self._get_pubkey_from_path(path) - script = engine.pubkey_to_script(pubkey) - cache[b'S'] = script + if script is None or validate_cache: + pubkey, engine = self._get_pubkey_from_path(path, + validate_cache=validate_cache) + new_script = engine.pubkey_to_script(pubkey) + if script is None: + cache[b'S'] = script = new_script + elif script != new_script: + raise WalletError("Wallet cache validation failed") return script - def get_script(self, mixdepth, address_type, index): + def get_script(self, mixdepth, address_type, index, + validate_cache: bool = False): path = self.get_path(mixdepth, address_type, index) - return self.get_script_from_path(path) + return self.get_script_from_path(path, validate_cache=validate_cache) - def _get_key_from_path(self, path): + def _get_key_from_path(self, path, + validate_cache: bool = False): raise NotImplementedError() - def _get_keypair_from_path(self, path): - privkey, engine = self._get_key_from_path(path) + def _get_keypair_from_path(self, path, + validate_cache: bool = False): + privkey, engine = self._get_key_from_path(path, + validate_cache=validate_cache) cache = self._get_cache_for_path(path) pubkey = cache.get(b'P') - if pubkey is None: - pubkey = engine.privkey_to_pubkey(privkey) - cache[b'P'] = pubkey + if pubkey is None or validate_cache: + new_pubkey = engine.privkey_to_pubkey(privkey) + if pubkey is None: + cache[b'P'] = pubkey = new_pubkey + elif pubkey != new_pubkey: + raise WalletError("Wallet cache validation failed") return privkey, pubkey, engine - def _get_pubkey_from_path(self, path): - privkey, pubkey, engine = self._get_keypair_from_path(path) + def _get_pubkey_from_path(self, path, + validate_cache: bool = False): + privkey, pubkey, engine = self._get_keypair_from_path(path, + validate_cache=validate_cache) return pubkey, engine def _get_cache_keys_for_path(self, path): @@ -1952,9 +1980,11 @@ def _get_mixdepth_from_path(self, path): assert len(path) == 3 return path[1] - def _get_key_from_path(self, path): + def _get_key_from_path(self, path, + validate_cache: bool = False): if not self._is_imported_path(path): - return super()._get_key_from_path(path) + return super()._get_key_from_path(path, + validate_cache=validate_cache) assert len(path) == 3 md, i = path[1], path[2] @@ -2201,15 +2231,19 @@ def _check_path(self, path): self._set_index_cache(md, address_type, current_index + 1) self._populate_maps((path,)) - def get_script_from_path(self, path): + def get_script_from_path(self, path, + validate_cache: bool = False): if self._is_my_bip32_path(path): self._check_path(path) - return super().get_script_from_path(path) + return super().get_script_from_path(path, + validate_cache=validate_cache) - def get_address_from_path(self, path): + def get_address_from_path(self, path, + validate_cache: bool = False): if self._is_my_bip32_path(path): self._check_path(path) - return super().get_address_from_path(path) + return super().get_address_from_path(path, + validate_cache=validate_cache) def get_path(self, mixdepth=None, address_type=None, index=None): if mixdepth is not None: @@ -2273,28 +2307,40 @@ def _get_mixdepth_from_path(self, path): return path[len(self._get_bip32_base_path())] - def _get_key_from_path(self, path): + def _get_key_from_path(self, path, + validate_cache: bool = False): if not self._is_my_bip32_path(path): raise WalletError("Invalid path, unknown root: {}".format(path)) cache = self._get_cache_for_path(path) privkey = cache.get(b'p') - if privkey is None: - privkey = self._ENGINE.derive_bip32_privkey(self._master_key, path) - cache[b'p'] = privkey + if privkey is None or validate_cache: + new_privkey = self._ENGINE.derive_bip32_privkey(self._master_key, path) + if privkey is None: + cache[b'p'] = privkey = new_privkey + elif privkey != new_privkey: + raise WalletError("Wallet cache validation failed") return privkey, self._ENGINE - def _get_keypair_from_path(self, path): + def _get_keypair_from_path(self, path, + validate_cache: bool = False): if not self._is_my_bip32_path(path): - return super()._get_keypair_from_path(path) + return super()._get_keypair_from_path(path, + validate_cache=validate_cache) cache = self._get_cache_for_path(path) privkey = cache.get(b'p') - if privkey is None: - privkey = self._ENGINE.derive_bip32_privkey(self._master_key, path) - cache[b'p'] = privkey + if privkey is None or validate_cache: + new_privkey = self._ENGINE.derive_bip32_privkey(self._master_key, path) + if privkey is None: + cache[b'p'] = privkey = new_privkey + elif privkey != new_privkey: + raise WalletError("Wallet cache validation failed") pubkey = cache.get(b'P') - if pubkey is None: - pubkey = self._ENGINE.privkey_to_pubkey(privkey) - cache[b'P'] = pubkey + if pubkey is None or validate_cache: + new_pubkey = self._ENGINE.privkey_to_pubkey(privkey) + if pubkey is None: + cache[b'P'] = pubkey = new_pubkey + elif pubkey != new_pubkey: + raise WalletError("Wallet cache validation failed") return privkey, pubkey, self._ENGINE def _get_cache_keys_for_path(self, path): @@ -2309,12 +2355,14 @@ def _is_my_bip32_path(self, path): def is_standard_wallet_script(self, path): return self._is_my_bip32_path(path) - def get_new_script(self, mixdepth, address_type): + def get_new_script(self, mixdepth, address_type, + validate_cache: bool = True): if self.disable_new_scripts: raise RuntimeError("Obtaining new wallet addresses " + "disabled, due to nohistory mode") index = self._index_cache[mixdepth][address_type] - return self.get_script(mixdepth, address_type, index) + return self.get_script(mixdepth, address_type, index, + validate_cache=validate_cache) def _set_index_cache(self, mixdepth, address_type, index): """ Ensures that any update to index_cache dict only applies @@ -2560,35 +2608,47 @@ def get_bip32_pub_export(self, mixdepth=None, address_type=None): def _get_supported_address_types(cls): return (cls.BIP32_EXT_ID, cls.BIP32_INT_ID, cls.BIP32_TIMELOCK_ID, cls.BIP32_BURN_ID) - def _get_key_from_path(self, path): + def _get_key_from_path(self, path, + validate_cache: bool = False): if self.is_timelocked_path(path): key_path = path[:-1] locktime = path[-1] engine = self._TIMELOCK_ENGINE cache = super()._get_cache_for_path(key_path) privkey = cache.get(b'p') - if privkey is None: - privkey = engine.derive_bip32_privkey(self._master_key, key_path) - cache[b'p'] = privkey + if privkey is None or validate_cache: + new_privkey = engine.derive_bip32_privkey(self._master_key, key_path) + if privkey is None: + cache[b'p'] = privkey = new_privkey + elif privkey != new_privkey: + raise WalletError("Wallet cache validation failed") return (privkey, locktime), engine else: return super()._get_key_from_path(path) - def _get_keypair_from_path(self, path): + def _get_keypair_from_path(self, path, + validate_cache: bool = False): if not self.is_timelocked_path(path): - return super()._get_keypair_from_path(path) + return super()._get_keypair_from_path(path, + validate_cache=validate_cache) key_path = path[:-1] locktime = path[-1] engine = self._TIMELOCK_ENGINE cache = super()._get_cache_for_path(key_path) privkey = cache.get(b'p') - if privkey is None: - privkey = engine.derive_bip32_privkey(self._master_key, key_path) - cache[b'p'] = privkey + if privkey is None or validate_cache: + new_privkey = engine.derive_bip32_privkey(self._master_key, key_path) + if privkey is None: + cache[b'p'] = privkey = new_privkey + elif privkey != new_privkey: + raise WalletError("Wallet cache validation failed") pubkey = cache.get(b'P') - if pubkey is None: - pubkey = engine.privkey_to_pubkey(privkey) - cache[b'P'] = pubkey + if pubkey is None or validate_cache: + new_pubkey = engine.privkey_to_pubkey(privkey) + if pubkey is None: + cache[b'P'] = pubkey = new_pubkey + elif pubkey != new_pubkey: + raise WalletError("Wallet cache validation failed") return (privkey, locktime), (pubkey, locktime), engine def _get_cache_for_path(self, path): @@ -2740,30 +2800,41 @@ def _get_bip32_export_path(self, mixdepth=None, address_type=None): path = super()._get_bip32_export_path(mixdepth, address_type) return path - def _get_key_from_path(self, path): + def _get_key_from_path(self, path, + validate_cache: bool = False): raise WalletError("Cannot get a private key from a watch-only wallet") - def _get_keypair_from_path(self, path): + def _get_keypair_from_path(self, path, + validate_cache: bool = False): raise WalletError("Cannot get a private key from a watch-only wallet") - def _get_pubkey_from_path(self, path): + def _get_pubkey_from_path(self, path, + validate_cache: bool = False): if not self._is_my_bip32_path(path): - return super()._get_pubkey_from_path(path) + return super()._get_pubkey_from_path(path, + validate_cache=validate_cache) if self.is_timelocked_path(path): key_path = path[:-1] locktime = path[-1] cache = self._get_cache_for_path(key_path) pubkey = cache.get(b'P') - if pubkey is None: - pubkey = self._TIMELOCK_ENGINE.derive_bip32_privkey( + if pubkey is None or validate_cache: + new_pubkey = self._TIMELOCK_ENGINE.derive_bip32_privkey( self._master_key, key_path) - cache[b'P'] = pubkey + if pubkey is None: + cache[b'P'] = pubkey = new_pubkey + elif pubkey != new_pubkey: + raise WalletError("Wallet cache validation failed") return (pubkey, locktime), self._TIMELOCK_ENGINE cache = self._get_cache_for_path(path) pubkey = cache.get(b'P') - if pubkey is None: - pubkey = self._ENGINE.derive_bip32_privkey(self._master_key, path) - cache[b'P'] = pubkey + if pubkey is None or validate_cache: + new_pubkey = self._ENGINE.derive_bip32_privkey( + self._master_key, path) + if pubkey is None: + cache[b'P'] = pubkey = new_pubkey + elif pubkey != new_pubkey: + raise WalletError("Wallet cache validation failed") return pubkey, self._ENGINE diff --git a/test/jmclient/test_taker.py b/test/jmclient/test_taker.py index da902f396..7067382df 100644 --- a/test/jmclient/test_taker.py +++ b/test/jmclient/test_taker.py @@ -121,7 +121,8 @@ def get_txtype(self): """ return 'p2wpkh' - def _get_key_from_path(self, path): + def _get_key_from_path(self, path, + validate_cache: bool = False): if path[0] == b'dummy': return struct.pack(b'B', path[2] + 1)*32 + b'\x01', self._ENGINE raise NotImplementedError() @@ -152,10 +153,12 @@ def is_standard_wallet_script(self, path): return False return True - def script_to_addr(self, script): + def script_to_addr(self, script, + validate_cache: bool = False): if self.script_to_path(script)[0] == "nonstandard_path": return "dummyaddr" - return super().script_to_addr(script) + return super().script_to_addr(script, + validate_cache=validate_cache) def dummy_order_chooser():