From 7ca2adea3d8030f8a36991ab75d7c1a9962bd231 Mon Sep 17 00:00:00 2001 From: Jakub Zajkowski Date: Sun, 3 Nov 2024 22:36:00 +0100 Subject: [PATCH] Adding checks on GetRequest::Record check to make sure that we don't pass empty key fetches to lmdb storage (which results in a non-recoverable crash). Also changed DataReader and VersionedDatabases to reflect the same check. --- binary_port/Cargo.toml | 4 +- binary_port/src/record_id.rs | 9 +++ node/src/components/binary_port.rs | 6 ++ node/src/components/binary_port/tests.rs | 66 ++++++++++++++++--- .../lmdb/indexed_lmdb_block_store.rs | 3 + .../block_store/lmdb/versioned_databases.rs | 12 ++++ 6 files changed, 90 insertions(+), 10 deletions(-) diff --git a/binary_port/Cargo.toml b/binary_port/Cargo.toml index 4f8940b532..f53f98b930 100644 --- a/binary_port/Cargo.toml +++ b/binary_port/Cargo.toml @@ -21,13 +21,13 @@ schemars = { version = "0.8.16", features = ["preserve_order", "impl_json_schema bincode = "1.3.3" rand = "0.8.3" tokio-util = { version = "0.6.4", features = ["codec"] } +strum = "0.26.2" +strum_macros = "0.26.4" [dev-dependencies] casper-types = { path = "../types", features = ["datasize", "json-schema", "std", "testing"] } serde_json = "1" serde_test = "1" -strum = "0.26.2" -strum_macros = "0.26.4" [package.metadata.docs.rs] all-features = true diff --git a/binary_port/src/record_id.rs b/binary_port/src/record_id.rs index 649a574fea..9721b64391 100644 --- a/binary_port/src/record_id.rs +++ b/binary_port/src/record_id.rs @@ -6,10 +6,15 @@ use serde::Serialize; #[cfg(test)] use casper_types::testing::TestRng; +#[cfg(any(feature = "testing", test))] +use strum::IntoEnumIterator; +#[cfg(any(feature = "testing", test))] +use strum_macros::EnumIter; /// An identifier of a record type. #[derive(Copy, Clone, Debug, Eq, PartialEq, Hash, Serialize)] #[repr(u16)] +#[cfg_attr(any(feature = "testing", test), derive(EnumIter))] pub enum RecordId { /// Refers to `BlockHeader` record. BlockHeader = 0, @@ -44,6 +49,10 @@ impl RecordId { _ => unreachable!(), } } + #[cfg(any(feature = "testing", test))] + pub fn all() -> impl Iterator { + RecordId::iter() + } } impl TryFrom for RecordId { diff --git a/node/src/components/binary_port.rs b/node/src/components/binary_port.rs index 9522318e41..bce4dcb165 100644 --- a/node/src/components/binary_port.rs +++ b/node/src/components/binary_port.rs @@ -247,6 +247,9 @@ where key, } if RecordId::try_from(record_type_tag) == Ok(RecordId::Transfer) => { metrics.binary_port_get_record_count.inc(); + if key.is_empty() { + return BinaryResponse::new_empty(protocol_version); + } let Ok(block_hash) = bytesrepr::deserialize_from_slice(&key) else { debug!("received an incorrectly serialized key for a transfer record"); return BinaryResponse::new_error(ErrorCode::BadRequest, protocol_version); @@ -267,6 +270,9 @@ where key, } => { metrics.binary_port_get_record_count.inc(); + if key.is_empty() { + return BinaryResponse::new_empty(protocol_version); + } match RecordId::try_from(record_type_tag) { Ok(record_id) => { let Some(db_bytes) = effect_builder.get_raw_data(record_id, key).await else { diff --git a/node/src/components/binary_port/tests.rs b/node/src/components/binary_port/tests.rs index 000254fa68..539cd28779 100644 --- a/node/src/components/binary_port/tests.rs +++ b/node/src/components/binary_port/tests.rs @@ -1,11 +1,13 @@ use std::fmt::{self, Display, Formatter}; use derive_more::From; +use either::Either; use rand::Rng; use serde::Serialize; use casper_binary_port::{ BinaryRequest, BinaryResponse, GetRequest, GlobalStateEntityQualifier, GlobalStateRequest, + RecordId, }; use casper_types::{ @@ -55,7 +57,7 @@ struct TestCase { allow_request_get_all_values: bool, allow_request_get_trie: bool, allow_request_speculative_exec: bool, - request_generator: fn(&mut TestRng) -> BinaryRequest, + request_generator: Either BinaryRequest, BinaryRequest>, } #[tokio::test] @@ -66,21 +68,21 @@ async fn should_enqueue_requests_for_enabled_functions() { allow_request_get_all_values: ENABLED, allow_request_get_trie: rng.gen(), allow_request_speculative_exec: rng.gen(), - request_generator: |_| all_values_request(), + request_generator: Either::Left(|_| all_values_request()), }; let get_trie_enabled = TestCase { allow_request_get_all_values: rng.gen(), allow_request_get_trie: ENABLED, allow_request_speculative_exec: rng.gen(), - request_generator: |_| trie_request(), + request_generator: Either::Left(|_| trie_request()), }; let try_speculative_exec_enabled = TestCase { allow_request_get_all_values: rng.gen(), allow_request_get_trie: rng.gen(), allow_request_speculative_exec: ENABLED, - request_generator: try_speculative_exec_request, + request_generator: Either::Left(try_speculative_exec_request), }; for test_case in [ @@ -110,21 +112,21 @@ async fn should_return_error_for_disabled_functions() { allow_request_get_all_values: DISABLED, allow_request_get_trie: rng.gen(), allow_request_speculative_exec: rng.gen(), - request_generator: |_| all_values_request(), + request_generator: Either::Left(|_| all_values_request()), }; let get_trie_disabled = TestCase { allow_request_get_all_values: rng.gen(), allow_request_get_trie: DISABLED, allow_request_speculative_exec: rng.gen(), - request_generator: |_| trie_request(), + request_generator: Either::Left(|_| trie_request()), }; let try_speculative_exec_disabled = TestCase { allow_request_get_all_values: rng.gen(), allow_request_get_trie: rng.gen(), allow_request_speculative_exec: DISABLED, - request_generator: try_speculative_exec_request, + request_generator: Either::Left(try_speculative_exec_request), }; for test_case in [ @@ -148,6 +150,38 @@ async fn should_return_error_for_disabled_functions() { } } +#[tokio::test] +async fn should_return_empty_response_when_fetching_empty_key() { + let mut rng = TestRng::new(); + + let test_cases: Vec = record_requests_with_empty_keys() + .into_iter() + .map(|request| TestCase { + allow_request_get_all_values: DISABLED, + allow_request_get_trie: DISABLED, + allow_request_speculative_exec: DISABLED, + request_generator: Either::Right(request), + }) + .collect(); + + for test_case in test_cases { + let (receiver, mut runner) = run_test_case(test_case, &mut rng).await; + + let result = tokio::select! { + result = receiver => result.expect("expected successful response"), + _ = runner.crank_until( + &mut rng, + got_contract_runtime_request, + Duration::from_secs(10), + ) => { + panic!("expected receiver to complete first") + } + }; + assert_eq!(result.error_code(), 0); + assert!(result.payload().is_empty()); + } +} + async fn run_test_case( TestCase { allow_request_get_all_values, @@ -192,8 +226,12 @@ async fn run_test_case( .await; let (sender, receiver) = oneshot::channel(); + let request = match request_generator { + Either::Left(f) => f(rng), + Either::Right(v) => v, + }; let event = BinaryPortEvent::HandleRequest { - request: request_generator(rng), + request, responder: Responder::without_shutdown(sender), }; @@ -389,6 +427,18 @@ fn all_values_request() -> BinaryRequest { )))) } +#[cfg(test)] +fn record_requests_with_empty_keys() -> Vec { + let mut data = Vec::new(); + for record_id in RecordId::all() { + data.push(BinaryRequest::Get(GetRequest::Record { + record_type_tag: record_id.into(), + key: vec![], + })) + } + data +} + fn trie_request() -> BinaryRequest { BinaryRequest::Get(GetRequest::Trie { trie_key: Digest::hash([1u8; 32]), diff --git a/storage/src/block_store/lmdb/indexed_lmdb_block_store.rs b/storage/src/block_store/lmdb/indexed_lmdb_block_store.rs index 464e74b634..8fdff0f7e4 100644 --- a/storage/src/block_store/lmdb/indexed_lmdb_block_store.rs +++ b/storage/src/block_store/lmdb/indexed_lmdb_block_store.rs @@ -895,6 +895,9 @@ impl<'t> DataReader<(DbTableId, Vec), DbRawBytesSpec> &self, (id, key): (DbTableId, Vec), ) -> Result, BlockStoreError> { + if key.is_empty() { + return Ok(None); + } let store = &self.block_store.block_store; let res = match id { DbTableId::BlockHeader => store.block_header_dbs.get_raw(&self.txn, &key), diff --git a/storage/src/block_store/lmdb/versioned_databases.rs b/storage/src/block_store/lmdb/versioned_databases.rs index b4d5fa2866..3cc6e1ebb6 100644 --- a/storage/src/block_store/lmdb/versioned_databases.rs +++ b/storage/src/block_store/lmdb/versioned_databases.rs @@ -175,6 +175,9 @@ where txn: &Tx, key: &[u8], ) -> Result, LmdbExtError> { + if key.is_empty() { + return Ok(None); + } let value = txn.get(self.current, &key); match value { Ok(raw_bytes) => Ok(Some(DbRawBytesSpec::new_current(raw_bytes))), @@ -584,4 +587,13 @@ mod tests { .for_each_value_in_legacy(&mut txn, &mut visitor) .unwrap(); } + + #[test] + fn should_get_on_empty_key() { + let fixture = Fixture::new(); + let txn = fixture.env.begin_ro_txn().unwrap(); + let key = vec![]; + let res = fixture.dbs.get_raw(&txn, &key); + assert!(matches!(res, Ok(None))); + } }