Skip to content

Commit

Permalink
[squashed] Tree hash caching and optimisations for Altair #2459
Browse files Browse the repository at this point in the history
Squashed commit of the following:

commit c6893f7
Author: Michael Sproul <[email protected]>
Date:   Wed Jul 14 18:37:26 2021 +1000

    Tree hash caching and optimisations for Altair
  • Loading branch information
michaelsproul committed Jul 15, 2021
1 parent 73b7de6 commit 21c1fcb
Show file tree
Hide file tree
Showing 11 changed files with 206 additions and 39 deletions.
1 change: 0 additions & 1 deletion beacon_node/beacon_chain/src/beacon_fork_choice_store.rs
Original file line number Diff line number Diff line change
Expand Up @@ -321,7 +321,6 @@ where
.deconstruct()
.0;

// FIXME(altair): could remove clone with by-value `balances` accessor
self.justified_balances = self
.store
.get_state(&justified_block.state_root(), Some(justified_block.slot()))
Expand Down
1 change: 0 additions & 1 deletion beacon_node/http_api/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -941,7 +941,6 @@ pub fn serve<T: BeaconChainTypes>(
blocking_json_task(move || {
block_id
.block(&chain)
// FIXME(altair): could avoid clone with by-value accessor
.map(|block| block.message().body().attestations().clone())
.map(api_types::GenericResponse::from)
})
Expand Down
3 changes: 1 addition & 2 deletions consensus/cached_tree_hash/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ pub type CacheArena = cache_arena::CacheArena<Hash256>;
pub use crate::cache::TreeHashCache;
pub use crate::impls::int_log;
use ethereum_types::H256 as Hash256;
use tree_hash::TreeHash;

#[derive(Debug, PartialEq, Clone)]
pub enum Error {
Expand All @@ -34,7 +33,7 @@ impl From<cache_arena::Error> for Error {
}

/// Trait for types which can make use of a cache to accelerate calculation of their tree hash root.
pub trait CachedTreeHash<Cache>: TreeHash {
pub trait CachedTreeHash<Cache> {
/// Create a new cache appropriate for use with values of this type.
fn new_tree_hash_cache(&self, arena: &mut CacheArena) -> Cache;

Expand Down
2 changes: 1 addition & 1 deletion consensus/state_processing/src/per_block_processing.rs
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ pub fn per_block_processing<T: EthSpec>(

process_randao(state, block, verify_signatures, spec)?;
process_eth1_data(state, block.body().eth1_data())?;
process_operations(state, block.body(), verify_signatures, spec)?;
process_operations(state, block.body(), proposer_index, verify_signatures, spec)?;

if let BeaconBlockRef::Altair(inner) = block {
process_sync_aggregate(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ use types::consts::altair::{PARTICIPATION_FLAG_WEIGHTS, PROPOSER_WEIGHT, WEIGHT_
pub fn process_operations<'a, T: EthSpec>(
state: &mut BeaconState<T>,
block_body: BeaconBlockBodyRef<'a, T>,
proposer_index: u64,
verify_signatures: VerifySignatures,
spec: &ChainSpec,
) -> Result<(), BlockProcessingError> {
Expand All @@ -26,7 +27,7 @@ pub fn process_operations<'a, T: EthSpec>(
verify_signatures,
spec,
)?;
process_attestations(state, block_body, verify_signatures, spec)?;
process_attestations(state, block_body, proposer_index, verify_signatures, spec)?;
process_deposits(state, block_body.deposits(), spec)?;
process_exits(state, block_body.voluntary_exits(), verify_signatures, spec)?;
Ok(())
Expand Down Expand Up @@ -85,21 +86,30 @@ pub mod altair {
pub fn process_attestations<T: EthSpec>(
state: &mut BeaconState<T>,
attestations: &[Attestation<T>],
proposer_index: u64,
verify_signatures: VerifySignatures,
spec: &ChainSpec,
) -> Result<(), BlockProcessingError> {
attestations
.iter()
.enumerate()
.try_for_each(|(i, attestation)| {
process_attestation(state, attestation, i, verify_signatures, spec)
process_attestation(
state,
attestation,
i,
proposer_index,
verify_signatures,
spec,
)
})
}

pub fn process_attestation<T: EthSpec>(
state: &mut BeaconState<T>,
attestation: &Attestation<T>,
att_index: usize,
proposer_index: u64,
verify_signatures: VerifySignatures,
spec: &ChainSpec,
) -> Result<(), BlockProcessingError> {
Expand Down Expand Up @@ -145,9 +155,7 @@ pub mod altair {
.safe_mul(WEIGHT_DENOMINATOR)?
.safe_div(PROPOSER_WEIGHT)?;
let proposer_reward = proposer_reward_numerator.safe_div(proposer_reward_denominator)?;
// FIXME(altair): optimise by passing in proposer_index
let proposer_index = state.get_beacon_proposer_index(state.slot(), spec)?;
increase_balance(state, proposer_index, proposer_reward)?;
increase_balance(state, proposer_index as usize, proposer_reward)?;
Ok(())
}
}
Expand Down Expand Up @@ -212,6 +220,7 @@ pub fn process_attester_slashings<T: EthSpec>(
pub fn process_attestations<'a, T: EthSpec>(
state: &mut BeaconState<T>,
block_body: BeaconBlockBodyRef<'a, T>,
proposer_index: u64,
verify_signatures: VerifySignatures,
spec: &ChainSpec,
) -> Result<(), BlockProcessingError> {
Expand All @@ -223,6 +232,7 @@ pub fn process_attestations<'a, T: EthSpec>(
altair::process_attestations(
state,
block_body.attestations(),
proposer_index,
verify_signatures,
spec,
)?;
Expand Down
7 changes: 7 additions & 0 deletions consensus/state_processing/src/per_block_processing/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -337,6 +337,7 @@ fn invalid_attestation_no_committee_for_index() {
let result = process_operations::process_attestations(
&mut state,
head_block.body(),
head_block.proposer_index(),
VerifySignatures::True,
&spec,
);
Expand Down Expand Up @@ -368,6 +369,7 @@ fn invalid_attestation_wrong_justified_checkpoint() {
let result = process_operations::process_attestations(
&mut state,
head_block.body(),
head_block.proposer_index(),
VerifySignatures::True,
&spec,
);
Expand Down Expand Up @@ -400,6 +402,7 @@ fn invalid_attestation_bad_aggregation_bitfield_len() {
let result = process_operations::process_attestations(
&mut state,
head_block.body(),
head_block.proposer_index(),
VerifySignatures::True,
&spec,
);
Expand All @@ -425,6 +428,7 @@ fn invalid_attestation_bad_signature() {
let result = process_operations::process_attestations(
&mut state,
head_block.body(),
head_block.proposer_index(),
VerifySignatures::True,
&spec,
);
Expand Down Expand Up @@ -456,6 +460,7 @@ fn invalid_attestation_included_too_early() {
let result = process_operations::process_attestations(
&mut state,
head_block.body(),
head_block.proposer_index(),
VerifySignatures::True,
&spec,
);
Expand Down Expand Up @@ -491,6 +496,7 @@ fn invalid_attestation_included_too_late() {
let result = process_operations::process_attestations(
&mut state,
head_block.body(),
head_block.proposer_index(),
VerifySignatures::True,
&spec,
);
Expand Down Expand Up @@ -522,6 +528,7 @@ fn invalid_attestation_target_epoch_slot_mismatch() {
let result = process_operations::process_attestations(
&mut state,
head_block.body(),
head_block.proposer_index(),
VerifySignatures::True,
&spec,
);
Expand Down
138 changes: 112 additions & 26 deletions consensus/types/src/beacon_state/tree_hash_cache.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
#![allow(clippy::indexing_slicing)]

use super::Error;
use crate::{BeaconState, EthSpec, Hash256, Slot, Unsigned, Validator};
use crate::{
BeaconState, EthSpec, Hash256, ParticipationFlags, ParticipationList, Slot, Unsigned, Validator,
};
use cached_tree_hash::{int_log, CacheArena, CachedTreeHash, TreeHashCache};
use rayon::prelude::*;
use ssz_derive::{Decode, Encode};
Expand Down Expand Up @@ -139,6 +141,9 @@ pub struct BeaconTreeHashCacheInner<T: EthSpec> {
randao_mixes: TreeHashCache,
slashings: TreeHashCache,
eth1_data_votes: Eth1DataVotesTreeHashCache<T>,
// Participation caches
previous_epoch_participation: ParticipationTreeHashCache,
current_epoch_participation: ParticipationTreeHashCache,
}

impl<T: EthSpec> BeaconTreeHashCacheInner<T> {
Expand All @@ -163,6 +168,11 @@ impl<T: EthSpec> BeaconTreeHashCacheInner<T> {
let mut slashings_arena = CacheArena::default();
let slashings = state.slashings().new_tree_hash_cache(&mut slashings_arena);

let previous_epoch_participation =
ParticipationTreeHashCache::new(state, BeaconState::previous_epoch_participation);
let current_epoch_participation =
ParticipationTreeHashCache::new(state, BeaconState::current_epoch_participation);

Self {
previous_state: None,
validators,
Expand All @@ -176,6 +186,8 @@ impl<T: EthSpec> BeaconTreeHashCacheInner<T> {
randao_mixes,
slashings,
eth1_data_votes: Eth1DataVotesTreeHashCache::new(state),
previous_epoch_participation,
current_epoch_participation,
}
}

Expand Down Expand Up @@ -264,31 +276,25 @@ impl<T: EthSpec> BeaconTreeHashCacheInner<T> {
)?;

// Participation
match state {
BeaconState::Base(state) => {
hasher.write(
state
.previous_epoch_attestations
.tree_hash_root()
.as_bytes(),
)?;
hasher.write(state.current_epoch_attestations.tree_hash_root().as_bytes())?;
}
// FIXME(altair): add a cache to accelerate hashing of these fields
BeaconState::Altair(state) => {
hasher.write(
state
.previous_epoch_participation
.tree_hash_root()
.as_bytes(),
)?;
hasher.write(
state
.current_epoch_participation
.tree_hash_root()
.as_bytes(),
)?;
}
if let BeaconState::Base(state) = state {
hasher.write(
state
.previous_epoch_attestations
.tree_hash_root()
.as_bytes(),
)?;
hasher.write(state.current_epoch_attestations.tree_hash_root().as_bytes())?;
} else {
hasher.write(
self.previous_epoch_participation
.recalculate_tree_hash_root(state.previous_epoch_participation()?)?
.as_bytes(),
)?;
hasher.write(
self.current_epoch_participation
.recalculate_tree_hash_root(state.current_epoch_participation()?)?
.as_bytes(),
)?;
}

hasher.write(state.justification_bits().tree_hash_root().as_bytes())?;
Expand Down Expand Up @@ -506,6 +512,60 @@ impl ParallelValidatorTreeHash {
}
}

#[derive(Debug, PartialEq, Clone)]
pub struct ParticipationTreeHashCache {
inner: Option<ParticipationTreeHashCacheInner>,
}

#[derive(Debug, PartialEq, Clone)]
pub struct ParticipationTreeHashCacheInner {
arena: CacheArena,
tree_hash_cache: TreeHashCache,
}

impl ParticipationTreeHashCache {
/// Initialize a new cache for the participation list returned by `field` (if any).
fn new<T: EthSpec>(
state: &BeaconState<T>,
field: impl FnOnce(
&BeaconState<T>,
) -> Result<
&VariableList<ParticipationFlags, T::ValidatorRegistryLimit>,
Error,
>,
) -> Self {
let inner = field(state).map(ParticipationTreeHashCacheInner::new).ok();
Self { inner }
}

/// Compute the tree hash root for the given `epoch_participation`.
///
/// This function will initialize the inner cache if necessary (e.g. when crossing the fork).
fn recalculate_tree_hash_root<N: Unsigned>(
&mut self,
epoch_participation: &VariableList<ParticipationFlags, N>,
) -> Result<Hash256, Error> {
let cache = self
.inner
.get_or_insert_with(|| ParticipationTreeHashCacheInner::new(epoch_participation));
ParticipationList::new(epoch_participation)
.recalculate_tree_hash_root(&mut cache.arena, &mut cache.tree_hash_cache)
.map_err(Into::into)
}
}

impl ParticipationTreeHashCacheInner {
fn new<N: Unsigned>(epoch_participation: &VariableList<ParticipationFlags, N>) -> Self {
let mut arena = CacheArena::default();
let tree_hash_cache =
ParticipationList::new(epoch_participation).new_tree_hash_cache(&mut arena);
ParticipationTreeHashCacheInner {
arena,
tree_hash_cache,
}
}
}

#[cfg(feature = "arbitrary-fuzz")]
impl<T: EthSpec> arbitrary::Arbitrary for BeaconTreeHashCache<T> {
fn arbitrary(_u: &mut arbitrary::Unstructured<'_>) -> arbitrary::Result<Self> {
Expand All @@ -516,6 +576,7 @@ impl<T: EthSpec> arbitrary::Arbitrary for BeaconTreeHashCache<T> {
#[cfg(test)]
mod test {
use super::*;
use crate::MainnetEthSpec;

#[test]
fn validator_node_count() {
Expand All @@ -524,4 +585,29 @@ mod test {
let _cache = v.new_tree_hash_cache(&mut arena);
assert_eq!(arena.backing_len(), NODES_PER_VALIDATOR);
}

#[test]
fn participation_flags() {
type N = <MainnetEthSpec as EthSpec>::ValidatorRegistryLimit;
let len = 65;
let mut test_flag = ParticipationFlags::default();
test_flag.add_flag(0).unwrap();
let epoch_participation = VariableList::<_, N>::new(vec![test_flag; len]).unwrap();

let mut cache = ParticipationTreeHashCache { inner: None };

let cache_root = cache
.recalculate_tree_hash_root(&epoch_participation)
.unwrap();
let recalc_root = cache
.recalculate_tree_hash_root(&epoch_participation)
.unwrap();

assert_eq!(cache_root, recalc_root, "recalculated root should match");
assert_eq!(
cache_root,
epoch_participation.tree_hash_root(),
"cached root should match uncached"
);
}
}
2 changes: 2 additions & 0 deletions consensus/types/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ pub mod slot_epoch_macros;
pub mod config_and_preset;
pub mod fork_context;
pub mod participation_flags;
pub mod participation_list;
pub mod preset;
pub mod slot_epoch;
pub mod subnet_id;
Expand Down Expand Up @@ -117,6 +118,7 @@ pub use crate::graffiti::{Graffiti, GRAFFITI_BYTES_LEN};
pub use crate::historical_batch::HistoricalBatch;
pub use crate::indexed_attestation::IndexedAttestation;
pub use crate::participation_flags::ParticipationFlags;
pub use crate::participation_list::ParticipationList;
pub use crate::pending_attestation::PendingAttestation;
pub use crate::preset::{AltairPreset, BasePreset};
pub use crate::proposer_slashing::ProposerSlashing;
Expand Down
4 changes: 4 additions & 0 deletions consensus/types/src/participation_flags.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,10 @@ impl ParticipationFlags {
let mask = 1u8.safe_shl(flag_index as u32)?;
Ok(self.bits & mask == mask)
}

pub fn into_u8(self) -> u8 {
self.bits
}
}

/// Decode implementation that transparently behaves like the inner `u8`.
Expand Down
Loading

0 comments on commit 21c1fcb

Please sign in to comment.