Skip to content

Commit

Permalink
Deserialize boundedbtreeset (#781)
Browse files Browse the repository at this point in the history
* deserialize for btreeset

* tests added

* run tests under std for bounded-btree-set

* insert after bound check, new tests for bound check
  • Loading branch information
ozgunozerk authored Aug 31, 2023
1 parent f9cfc79 commit a8a85a4
Show file tree
Hide file tree
Showing 2 changed files with 123 additions and 6 deletions.
112 changes: 110 additions & 2 deletions bounded-collections/src/bounded_btree_set.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,11 @@ use crate::{Get, TryCollect};
use alloc::collections::BTreeSet;
use codec::{Compact, Decode, Encode, MaxEncodedLen};
use core::{borrow::Borrow, marker::PhantomData, ops::Deref};
#[cfg(feature = "serde")]
use serde::{
de::{Error, SeqAccess, Visitor},
Deserialize, Deserializer, Serialize,
};

/// A bounded set based on a B-Tree.
///
Expand All @@ -29,9 +34,67 @@ use core::{borrow::Borrow, marker::PhantomData, ops::Deref};
///
/// Unlike a standard `BTreeSet`, there is an enforced upper limit to the number of items in the
/// set. All internal operations ensure this bound is respected.
#[cfg_attr(feature = "serde", derive(Serialize), serde(transparent))]
#[derive(Encode, scale_info::TypeInfo)]
#[scale_info(skip_type_params(S))]
pub struct BoundedBTreeSet<T, S>(BTreeSet<T>, PhantomData<S>);
pub struct BoundedBTreeSet<T, S>(BTreeSet<T>, #[cfg_attr(feature = "serde", serde(skip_serializing))] PhantomData<S>);

#[cfg(feature = "serde")]
impl<'de, T, S: Get<u32>> Deserialize<'de> for BoundedBTreeSet<T, S>
where
T: Ord + Deserialize<'de>,
S: Clone,
{
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
// Create a visitor to visit each element in the sequence
struct BTreeSetVisitor<T, S>(std::marker::PhantomData<(T, S)>);

impl<'de, T, S> Visitor<'de> for BTreeSetVisitor<T, S>
where
T: Ord + Deserialize<'de>,
S: Get<u32> + Clone,
{
type Value = BTreeSet<T>;

fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
formatter.write_str("a sequence")
}

fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
where
A: SeqAccess<'de>,
{
let size = seq.size_hint().unwrap_or(0);
let max = match usize::try_from(S::get()) {
Ok(n) => n,
Err(_) => return Err(A::Error::custom("can't convert to usize")),
};
if size > max {
Err(A::Error::custom("out of bounds"))
} else {
let mut values = BTreeSet::new();

while let Some(value) = seq.next_element()? {
if values.len() >= max {
return Err(A::Error::custom("out of bounds"))
}
values.insert(value);
}

Ok(values)
}
}
}

let visitor: BTreeSetVisitor<T, S> = BTreeSetVisitor(PhantomData);
deserializer
.deserialize_seq(visitor)
.map(|v| BoundedBTreeSet::<T, S>::try_from(v).map_err(|_| Error::custom("out of bounds")))?
}
}

impl<T, S> Decode for BoundedBTreeSet<T, S>
where
Expand Down Expand Up @@ -333,7 +396,7 @@ where
}
}

#[cfg(test)]
#[cfg(all(test, feature = "std"))]
mod test {
use super::*;
use crate::ConstU32;
Expand Down Expand Up @@ -522,4 +585,49 @@ mod test {
set: BoundedBTreeSet<String, ConstU32<16>>,
}
}

#[test]
fn test_serializer() {
let mut c = BoundedBTreeSet::<u32, ConstU32<6>>::new();
c.try_insert(0).unwrap();
c.try_insert(1).unwrap();
c.try_insert(2).unwrap();

assert_eq!(serde_json::json!(&c).to_string(), r#"[0,1,2]"#);
}

#[test]
fn test_deserializer() {
let c: Result<BoundedBTreeSet<u32, ConstU32<6>>, serde_json::error::Error> = serde_json::from_str(r#"[0,1,2]"#);
assert!(c.is_ok());
let c = c.unwrap();

assert_eq!(c.len(), 3);
assert!(c.contains(&0));
assert!(c.contains(&1));
assert!(c.contains(&2));
}

#[test]
fn test_deserializer_bound() {
let c: Result<BoundedBTreeSet<u32, ConstU32<3>>, serde_json::error::Error> = serde_json::from_str(r#"[0,1,2]"#);
assert!(c.is_ok());
let c = c.unwrap();

assert_eq!(c.len(), 3);
assert!(c.contains(&0));
assert!(c.contains(&1));
assert!(c.contains(&2));
}

#[test]
fn test_deserializer_failed() {
let c: Result<BoundedBTreeSet<u32, ConstU32<4>>, serde_json::error::Error> =
serde_json::from_str(r#"[0,1,2,3,4]"#);

match c {
Err(msg) => assert_eq!(msg.to_string(), "out of bounds at line 1 column 11"),
_ => unreachable!("deserializer must raise error"),
}
}
}
17 changes: 13 additions & 4 deletions bounded-collections/src/bounded_vec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -87,10 +87,10 @@ where
let mut values = Vec::with_capacity(size);

while let Some(value) = seq.next_element()? {
values.push(value);
if values.len() > max {
if values.len() >= max {
return Err(A::Error::custom("out of bounds"))
}
values.push(value);
}

Ok(values)
Expand Down Expand Up @@ -1187,10 +1187,19 @@ mod test {
assert_eq!(c[2], 2);
}

#[test]
fn test_deserializer_bound() {
let c: BoundedVec<u32, ConstU32<3>> = serde_json::from_str(r#"[0,1,2]"#).unwrap();

assert_eq!(c.len(), 3);
assert_eq!(c[0], 0);
assert_eq!(c[1], 1);
assert_eq!(c[2], 2);
}

#[test]
fn test_deserializer_failed() {
let c: Result<BoundedVec<u32, ConstU32<4>>, serde_json::error::Error> =
serde_json::from_str(r#"[0,1,2,3,4,5]"#);
let c: Result<BoundedVec<u32, ConstU32<4>>, serde_json::error::Error> = serde_json::from_str(r#"[0,1,2,3,4]"#);

match c {
Err(msg) => assert_eq!(msg.to_string(), "out of bounds at line 1 column 11"),
Expand Down

0 comments on commit a8a85a4

Please sign in to comment.