diff --git a/chia/consensus/blockchain.py b/chia/consensus/blockchain.py index 71e5d49c58a9..4218ae76c86a 100644 --- a/chia/consensus/blockchain.py +++ b/chia/consensus/blockchain.py @@ -414,7 +414,8 @@ async def add_block( try: # Always add the block to the database - async with self.block_store.db_wrapper.writer(): + # This call should be wrapped with a writer to force all coin reads to the same connection. + async with self.block_store.db_wrapper.writer_maybe_transaction(): # Perform the DB operations to update the state, and rollback if something goes wrong await self.block_store.add_full_block(header_hash, block, block_record) records, state_change_summary = await self._reconsider_peak(block_record, genesis, fork_info) diff --git a/chia/full_node/coin_store.py b/chia/full_node/coin_store.py index b7cdbba85858..dc74361dd8a4 100644 --- a/chia/full_node/coin_store.py +++ b/chia/full_node/coin_store.py @@ -1,6 +1,8 @@ from __future__ import annotations +import asyncio import dataclasses +import inspect import logging import sqlite3 import time @@ -19,7 +21,6 @@ from chia.util.batches import to_batches from chia.util.db_wrapper import SQLITE_MAX_VARIABLE_NUMBER, DBWrapper2 from chia.util.ints import uint32, uint64 -from chia.util.lru_cache import LRUCache log = logging.getLogger(__name__) @@ -32,13 +33,12 @@ class CoinStore: """ db_wrapper: DBWrapper2 - coins_added_at_height_cache: LRUCache[uint32, list[CoinRecord]] @classmethod async def create(cls, db_wrapper: DBWrapper2) -> CoinStore: if db_wrapper.db_version != 2: raise RuntimeError(f"CoinStore does not support database schema v{db_wrapper.db_version}") - self = CoinStore(db_wrapper, LRUCache(100)) + self = CoinStore(db_wrapper) async with self.db_wrapper.writer_maybe_transaction() as conn: log.info("DB: Creating coin store tables and indexes.") @@ -156,6 +156,20 @@ async def get_coin_records(self, names: Collection[bytes32]) -> list[CoinRecord] coins: list[CoinRecord] = [] async with self.db_wrapper.reader_no_transaction() as conn: + if conn!=self.db_wrapper._write_connection: + task=asyncio.current_task() + log.info( + f"get_coin_records not using _current_writer {task.get_name()} {conn}" + ) + frame = inspect.currentframe().f_back + log.info(f" Trace 1 {inspect.getframeinfo(frame).filename} {frame.f_lineno}") + frame = frame.f_back + log.info(f" Trace 2 {inspect.getframeinfo(frame).filename} {frame.f_lineno}") + frame = frame.f_back + log.info(f" Trace 3 {inspect.getframeinfo(frame).filename} {frame.f_lineno}") + frame = frame.f_back + log.info(f" Trace 4 {inspect.getframeinfo(frame).filename} {frame.f_lineno}") + cursors: list[Cursor] = [] for batch in to_batches(names, SQLITE_MAX_VARIABLE_NUMBER): names_db: tuple[Any, ...] = tuple(batch.entries) @@ -177,10 +191,6 @@ async def get_coin_records(self, names: Collection[bytes32]) -> list[CoinRecord] return coins async def get_coins_added_at_height(self, height: uint32) -> list[CoinRecord]: - coins_added: Optional[list[CoinRecord]] = self.coins_added_at_height_cache.get(height) - if coins_added is not None: - return coins_added - async with self.db_wrapper.reader_no_transaction() as conn: async with conn.execute( "SELECT confirmed_index, spent_index, coinbase, puzzle_hash, " @@ -192,7 +202,6 @@ async def get_coins_added_at_height(self, height: uint32) -> list[CoinRecord]: for row in rows: coin = self.row_to_coin(row) coins.append(CoinRecord(coin, row[0], row[1], row[2], row[6])) - self.coins_added_at_height_cache.put(height, coins) return coins async def get_coins_removed_at_height(self, height: uint32) -> list[CoinRecord]: @@ -566,7 +575,6 @@ async def rollback_to_block(self, block_index: int) -> list[CoinRecord]: coin_changes[record.name] = record await conn.execute("UPDATE coin_record SET spent_index=0 WHERE spent_index>?", (block_index,)) - self.coins_added_at_height_cache = LRUCache(self.coins_added_at_height_cache.capacity) return list(coin_changes.values()) # Store CoinRecord in DB diff --git a/chia/full_node/full_node.py b/chia/full_node/full_node.py index de0c9f3a7671..13a95d447309 100644 --- a/chia/full_node/full_node.py +++ b/chia/full_node/full_node.py @@ -317,14 +317,17 @@ async def manage(self) -> AsyncIterator[None]: ) async with self.blockchain.priority_mutex.acquire(priority=BlockchainMutexPriority.high): pending_tx = await self.mempool_manager.new_peak(self.blockchain.get_tx_peak(), None) - assert len(pending_tx.items) == 0 # no pending transactions when starting up + assert len(pending_tx.items) == 0 # no pending transactions when starting up - full_peak: Optional[FullBlock] = await self.blockchain.get_full_peak() - assert full_peak is not None - state_change_summary = StateChangeSummary(peak, uint32(max(peak.height - 1, 0)), [], [], [], []) - ppp_result: PeakPostProcessingResult = await self.peak_post_processing( - full_peak, state_change_summary, None - ) + full_peak: Optional[FullBlock] = await self.blockchain.get_full_peak() + assert full_peak is not None + state_change_summary = StateChangeSummary(peak, uint32(max(peak.height - 1, 0)), [], [], [], []) + + # Must be called under priority_mutex + ppp_result: PeakPostProcessingResult = await self.peak_post_processing( + full_peak, state_change_summary, None + ) + # Can be used outside of priority_mutex await self.peak_post_processing_2(full_peak, None, state_change_summary, ppp_result) if self.config["send_uncompact_interval"] != 0: sanitize_weight_proof_only = False @@ -626,6 +629,7 @@ async def short_sync_batch(self, peer: WSChiaConnection, start_height: uint32, t response = await peer.call_api(FullNodeAPI.request_blocks, request) if not response: raise ValueError(f"Error short batch syncing, invalid/no response for {height}-{end_height}") + async with self.blockchain.priority_mutex.acquire(priority=BlockchainMutexPriority.high): state_change_summary: Optional[StateChangeSummary] prev_b = None @@ -637,9 +641,13 @@ async def short_sync_batch(self, peer: WSChiaConnection, start_height: uint32, t self.constants, new_slot, prev_b, self.blockchain ) vs = ValidationState(ssi, diff, None) - success, state_change_summary = await self.add_block_batch( - response.blocks, peer_info, fork_info, vs - ) + + # Wrap add_block_batch with writer to ensure all writes and reads are on same connection. + # add_block_batch should only be called under priority_mutex so this will not deadlock. + async with self.block_store.db_wrapper.writer() as conn: + success, state_change_summary = await self.add_block_batch( + response.blocks, peer_info, fork_info, vs + ) if not success: raise ValueError(f"Error short batch syncing, failed to validate blocks {height}-{end_height}") if state_change_summary is not None: @@ -651,7 +659,6 @@ async def short_sync_batch(self, peer: WSChiaConnection, start_height: uint32, t state_change_summary, peer, ) - await self.peak_post_processing_2(peak_fb, peer, state_change_summary, ppp_result) except Exception: # Still do post processing after cancel (or exception) peak_fb = await self.blockchain.get_full_peak() @@ -660,6 +667,9 @@ async def short_sync_batch(self, peer: WSChiaConnection, start_height: uint32, t raise finally: self.log.info(f"Added blocks {height}-{end_height}") + if state_change_summary is not None and peak_fb is not None: + # Call outside of priority_mutex to encourage concurrency + await self.peak_post_processing_2(peak_fb, peer, state_change_summary, ppp_result) finally: self.sync_store.batch_syncing.remove(peer.peer_node_id) return True @@ -1352,16 +1362,20 @@ async def ingest_blocks( block_rate_height = start_height pre_validation_results = list(await asyncio.gather(*futures)) - # The ValidationState object (vs) is an in-out parameter. the add_block_batch() - # call will update it - state_change_summary, err = await self.add_prevalidated_blocks( - blockchain, - blocks, - pre_validation_results, - fork_info, - peer.peer_info, - vs, - ) + + # Wrap add_prevalidated_blocks with writer to ensure all writes and reads are on same connection. + # add_prevalidated_blocks should only be called under priority_mutex so this will not deadlock. + async with self.block_store.db_wrapper.writer() as conn: + # The ValidationState object (vs) is an in-out parameter. the add_block_batch() + # call will update it + state_change_summary, err = await self.add_prevalidated_blocks( + blockchain, + blocks, + pre_validation_results, + fork_info, + peer.peer_info, + vs, + ) if err is not None: await peer.close(600) raise ValueError(f"Failed to validate block batch {start_height} to {end_height}: {err}") @@ -1731,7 +1745,10 @@ async def _finish_sync(self, fork_point: Optional[uint32]) -> None: ppp_result: PeakPostProcessingResult = await self.peak_post_processing( peak_fb, state_change_summary, None ) - await self.peak_post_processing_2(peak_fb, None, state_change_summary, ppp_result) + + if peak_fb is not None: + # Call outside of priority_mutex to encourage concurrency + await self.peak_post_processing_2(peak_fb, None, state_change_summary, ppp_result) if peak is not None and self.weight_proof_handler is not None: await self.weight_proof_handler.get_proof_of_weight(peak.header_hash) @@ -2083,6 +2100,9 @@ async def add_block( ppp_result: Optional[PeakPostProcessingResult] = None async with ( self.blockchain.priority_mutex.acquire(priority=BlockchainMutexPriority.high), + # Wrap with writer to ensure all writes and reads are on same connection. + # add_prevalidated_blocks should only be called under priority_mutex so this will not deadlock. + self.block_store.db_wrapper.writer(), enable_profiler(self.profile_block_validation) as pr, ): # After acquiring the lock, check again, because another asyncio thread might have added it diff --git a/chia/util/db_wrapper.py b/chia/util/db_wrapper.py index b1d2b05e07b0..4fa7323fb649 100644 --- a/chia/util/db_wrapper.py +++ b/chia/util/db_wrapper.py @@ -74,7 +74,8 @@ async def _create_connection( log_file: Optional[TextIO] = None, name: Optional[str] = None, ) -> aiosqlite.Connection: - connection = await aiosqlite.connect(database=database, uri=uri) + # To avoid https://github.com/python/cpython/issues/118172 + connection = await aiosqlite.connect(database=database, uri=uri, cached_statements=0) if log_file is not None: await connection.set_trace_callback(functools.partial(sql_trace_callback, file=log_file, name=name))