add in_main_chain=1 to the SQL query that just asks for heights (#18932)

extend the BlockStore unit tests and add in_main_chain=1 to the SQL queries that ask only for heights
arvidn authored Dec 3, 2024
1 parent 98f7f88 commit df52ee6
Showing 2 changed files with 114 additions and 10 deletions.
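For context on why the new predicate matters: heights are not unique in the full_blocks table, because orphaned fork blocks are stored at the same heights as main-chain blocks, so a query that filters on height alone can return orphans. Below is a minimal, self-contained sketch of that effect (the full_blocks table and its height and in_main_chain columns come from the diff below; the miniature three-column schema is otherwise made up for illustration):

    import sqlite3

    # Miniature stand-in for the full_blocks table: height is NOT unique,
    # since an orphaned fork block shares its height with a main-chain block.
    conn = sqlite3.connect(":memory:")
    conn.execute("CREATE TABLE full_blocks(header_hash BLOB, height INT, in_main_chain INT)")
    conn.executemany(
        "INSERT INTO full_blocks VALUES(?, ?, ?)",
        [
            (b"main-5", 5, 1),    # main-chain block at height 5
            (b"orphan-5", 5, 0),  # orphaned block at the same height
        ],
    )

    # Filtering on height alone picks up the orphan too:
    rows = conn.execute(
        "SELECT header_hash FROM full_blocks WHERE height >= 5 AND height <= 5"
    ).fetchall()
    assert {r[0] for r in rows} == {b"main-5", b"orphan-5"}

    # With the predicate this commit adds, only the main-chain block comes back:
    rows = conn.execute(
        "SELECT header_hash FROM full_blocks WHERE height >= 5 AND height <= 5 AND in_main_chain=1"
    ).fetchall()
    assert rows == [(b"main-5",)]

Note that get_full_blocks_at is deliberately left height-only; its new docstring below states that it returns orphans as well.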
106 changes: 102 additions & 4 deletions chia/_tests/core/full_node/stores/test_block_store.py
@@ -15,7 +15,7 @@

from chia._tests.blockchain.blockchain_test_utils import _validate_and_add_block
from chia._tests.util.db_connection import DBConnection, PathDBConnection
from chia.consensus.blockchain import Blockchain
from chia.consensus.blockchain import AddBlockResult, Blockchain
from chia.consensus.default_constants import DEFAULT_CONSTANTS
from chia.consensus.full_block_to_block_record import header_block_to_sub_block_record
from chia.full_node.block_store import BlockStore
@@ -133,6 +133,87 @@ async def test_block_store(tmp_dir: Path, db_version: int, bt: BlockTools, use_c
assert br.header_hash == b.header_hash


@pytest.mark.limit_consensus_modes(reason="save time")
@pytest.mark.anyio
async def test_get_full_blocks_at(
tmp_dir: Path, db_version: int, bt: BlockTools, use_cache: bool, default_400_blocks: list[FullBlock]
) -> None:
blocks = bt.get_consecutive_blocks(10)
alt_blocks = default_400_blocks[:10]

async with DBConnection(2) as db_wrapper:
# Use a different file for the blockchain
coin_store = await CoinStore.create(db_wrapper)
block_store = await BlockStore.create(db_wrapper, use_cache=use_cache)
bc = await Blockchain.create(coin_store, block_store, bt.constants, tmp_dir, 2)

count = 0
for b1, b2 in zip(blocks, alt_blocks):
await _validate_and_add_block(bc, b1)
await _validate_and_add_block(bc, b2, expected_result=AddBlockResult.ADDED_AS_ORPHAN)
ret = await block_store.get_full_blocks_at([uint32(count)])
assert set(ret) == set([b1, b2])
count += 1
ret = await block_store.get_full_blocks_at([uint32(c) for c in range(count)])
assert len(ret) == count * 2
assert set(ret) == set(blocks[:count] + alt_blocks[:count])


@pytest.mark.limit_consensus_modes(reason="save time")
@pytest.mark.anyio
async def test_get_block_records_in_range(
bt: BlockTools, tmp_dir: Path, use_cache: bool, default_400_blocks: list[FullBlock]
) -> None:
blocks = bt.get_consecutive_blocks(10)
alt_blocks = default_400_blocks[:10]

async with DBConnection(2) as db_wrapper:
# Use a different file for the blockchain
coin_store = await CoinStore.create(db_wrapper)
block_store = await BlockStore.create(db_wrapper, use_cache=use_cache)
bc = await Blockchain.create(coin_store, block_store, bt.constants, tmp_dir, 2)

count = 0
for b1, b2 in zip(blocks, alt_blocks):
await _validate_and_add_block(bc, b1)
await _validate_and_add_block(bc, b2, expected_result=AddBlockResult.ADDED_AS_ORPHAN)
# the range is inclusive
ret = await block_store.get_block_records_in_range(count, count)
assert len(ret) == 1
assert b1.header_hash in ret
ret = await block_store.get_block_records_in_range(0, count)
count += 1
assert len(ret) == count
assert list(ret.keys()) == [b.header_hash for b in blocks[:count]]


@pytest.mark.limit_consensus_modes(reason="save time")
@pytest.mark.anyio
async def test_get_block_bytes_in_range_in_main_chain(
bt: BlockTools, tmp_dir: Path, use_cache: bool, default_400_blocks: list[FullBlock]
) -> None:
blocks = bt.get_consecutive_blocks(10)
alt_blocks = default_400_blocks[:10]

async with DBConnection(2) as db_wrapper:
# Use a different file for the blockchain
coin_store = await CoinStore.create(db_wrapper)
block_store = await BlockStore.create(db_wrapper, use_cache=use_cache)
bc = await Blockchain.create(coin_store, block_store, bt.constants, tmp_dir, 2)

count = 0
for b1, b2 in zip(blocks, alt_blocks):
await _validate_and_add_block(bc, b1)
await _validate_and_add_block(bc, b2, expected_result=AddBlockResult.ADDED_AS_ORPHAN)
# the range is inclusive
ret = await block_store.get_block_bytes_in_range(count, count)
assert ret == [bytes(b1)]
ret = await block_store.get_block_bytes_in_range(0, count)
count += 1
assert len(ret) == count
assert set(ret) == set([bytes(b) for b in blocks[:count]])


@pytest.mark.limit_consensus_modes(reason="save time")
@pytest.mark.anyio
async def test_deadlock(tmp_dir: Path, db_version: int, bt: BlockTools, use_cache: bool) -> None:
@@ -168,8 +249,9 @@ async def test_deadlock(tmp_dir: Path, db_version: int, bt: BlockTools, use_cach

@pytest.mark.limit_consensus_modes(reason="save time")
@pytest.mark.anyio
async def test_rollback(bt: BlockTools, tmp_dir: Path, use_cache: bool) -> None:
async def test_rollback(bt: BlockTools, tmp_dir: Path, use_cache: bool, default_400_blocks: list[FullBlock]) -> None:
blocks = bt.get_consecutive_blocks(10)
alt_blocks = default_400_blocks[:10]

async with DBConnection(2) as db_wrapper:
# Use a different file for the blockchain
@@ -179,8 +261,9 @@ async def test_rollback(bt: BlockTools, tmp_dir: Path, use_cache: bool) -> None:

# insert all blocks
count = 0
for block in blocks:
await _validate_and_add_block(bc, block)
for b1, b2 in zip(blocks, alt_blocks):
await _validate_and_add_block(bc, b1)
await _validate_and_add_block(bc, b2, expected_result=AddBlockResult.ADDED_AS_ORPHAN)
count += 1
ret = await block_store.get_random_not_compactified(count)
assert len(ret) == count
@@ -195,6 +278,13 @@ async def test_rollback(bt: BlockTools, tmp_dir: Path, use_cache: bool) -> None:
rows = list(await cursor.fetchall())
assert len(rows) == 1
assert rows[0][0]
for block in alt_blocks:
async with conn.execute(
"SELECT in_main_chain FROM full_blocks WHERE header_hash=?", (block.header_hash,)
) as cursor:
rows = list(await cursor.fetchall())
assert len(rows) == 1
assert not rows[0][0]

await block_store.rollback(5)

@@ -210,6 +300,14 @@ async def test_rollback(bt: BlockTools, tmp_dir: Path, use_cache: bool) -> None:
assert len(rows) == 1
assert rows[0][0] == (count <= 5)
count += 1
for block in alt_blocks:
async with conn.execute(
"SELECT in_main_chain FROM full_blocks WHERE header_hash=? ORDER BY height",
(block.header_hash,),
) as cursor:
rows = list(await cursor.fetchall())
assert len(rows) == 1
assert not rows[0][0]


@pytest.mark.limit_consensus_modes(reason="save time")
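The rollback assertions above read the in_main_chain column straight from the database. In the same spirit, a small helper sketch (assuming aiosqlite, which the DBConnection machinery used by these tests wraps) makes the post-rollback expectation explicit: after block_store.rollback(5), only main-chain blocks at heights 0 through 5 should keep the flag.

    import aiosqlite


    async def main_chain_heights(conn: aiosqlite.Connection) -> set[int]:
        # Heights currently flagged as part of the main chain. After
        # block_store.rollback(5), the test above expects this set to be
        # {0, 1, 2, 3, 4, 5}: every block above height 5 is demoted.
        async with conn.execute("SELECT height FROM full_blocks WHERE in_main_chain=1") as cursor:
            return {row[0] for row in await cursor.fetchall()}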
18 changes: 12 additions & 6 deletions chia/full_node/block_store.py
@@ -226,6 +226,9 @@ async def get_full_block_bytes(self, header_hash: bytes32) -> Optional[bytes]:
return None

async def get_full_blocks_at(self, heights: list[uint32]) -> list[FullBlock]:
"""
Returns all blocks at the given heights, including orphans.
"""
if len(heights) == 0:
return []

@@ -439,13 +442,15 @@ async def get_block_records_in_range(
) -> dict[bytes32, BlockRecord]:
"""
Returns a dictionary with all blocks in range between start and stop
if present.
if present. Only blocks that are part of the main chain (leading up to
the current peak) are returned, i.e. no orphan blocks.
"""

ret: dict[bytes32, BlockRecord] = {}
async with self.db_wrapper.reader_no_transaction() as conn:
async with conn.execute(
"SELECT header_hash,block_record FROM full_blocks WHERE height >= ? AND height <= ?",
"SELECT header_hash,block_record FROM full_blocks "
"WHERE height >= ? AND height <= ? AND in_main_chain=1",
(start, stop),
) as cursor:
for row in await cursor.fetchall():
@@ -462,13 +467,14 @@ async def get_block_bytes_in_range(
) -> list[bytes]:
"""
Returns a list with all full blocks in range between start and stop
if present.
if present. Only includes blocks in the main chain (leading up to the
current peak); no orphan blocks.
"""

assert self.db_wrapper.db_version == 2
async with self.db_wrapper.reader_no_transaction() as conn:
async with conn.execute(
"SELECT block FROM full_blocks WHERE height >= ? AND height <= ? and in_main_chain=1",
"SELECT block FROM full_blocks WHERE height >= ? AND height <= ? AND in_main_chain=1",
(start, stop),
) as cursor:
rows: list[sqlite3.Row] = list(await cursor.fetchall())
@@ -494,7 +500,7 @@ async def get_block_records_close_to_peak(
) -> tuple[dict[bytes32, BlockRecord], Optional[bytes32]]:
"""
Returns a dictionary with all blocks that have height >= peak height - blocks_n, as well as the
peak header hash.
peak header hash. Only blocks that are part of the main chain (the current peak) are included.
"""

peak = await self.get_peak()
Expand All @@ -504,7 +510,7 @@ async def get_block_records_close_to_peak(
ret: dict[bytes32, BlockRecord] = {}
async with self.db_wrapper.reader_no_transaction() as conn:
async with conn.execute(
"SELECT header_hash, block_record FROM full_blocks WHERE height >= ?",
"SELECT header_hash, block_record FROM full_blocks WHERE height >= ? AND in_main_chain=1",
(peak[1] - blocks_n,),
) as cursor:
for row in await cursor.fetchall():
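Taken together, the docstrings above draw one consistent line: get_full_blocks_at is keyed on height alone and returns orphans, while the range accessors filter on in_main_chain=1 and return at most one block per height. Here is a hedged sketch of that invariant as a standalone checker, assuming a BlockStore populated the way the tests above populate it (one main-chain block plus one orphan per height); the import paths for FullBlock and uint32 are guesses at the usual chia-blockchain layout, everything else comes from the diff.

    from chia.full_node.block_store import BlockStore
    from chia.types.full_block import FullBlock
    from chia.util.ints import uint32


    async def check_height_invariants(
        block_store: BlockStore, main: FullBlock, orphan: FullBlock, height: int
    ) -> None:
        # Height-keyed lookup: returns the main-chain block AND the orphan.
        blocks = await block_store.get_full_blocks_at([uint32(height)])
        assert {b.header_hash for b in blocks} == {main.header_hash, orphan.header_hash}

        # Range accessors carry AND in_main_chain=1: the orphan is excluded.
        records = await block_store.get_block_records_in_range(height, height)
        assert list(records.keys()) == [main.header_hash]
        assert await block_store.get_block_bytes_in_range(height, height) == [bytes(main)]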
