Skip to content

Commit

Permalink
Add a test for the internally tagged enum canary
Browse files Browse the repository at this point in the history
  • Loading branch information
juntyr committed Oct 1, 2023
1 parent e32bdeb commit a95d46c
Show file tree
Hide file tree
Showing 2 changed files with 95 additions and 6 deletions.
12 changes: 7 additions & 5 deletions src/de/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@ mod tag;
mod tests;
mod value;

const SERDE_CONTENT_CANARY: &str = "serde::__private::de::content::Content";
const SERDE_TAG_KEY_CANARY: &str = "serde::__private::de::content::TagOrContent";

/// The RON deserializer.
///
/// If you just want to simply deserialize a value,
Expand Down Expand Up @@ -165,9 +168,8 @@ impl<'de> Deserializer<'de> {
{
// HACK: switch to JSON enum semantics for JSON content
// Robust impl blocked on https://github.com/serde-rs/serde/pull/2420
let is_serde_content = std::any::type_name::<V::Value>()
== "serde::__private::de::content::Content"
|| std::any::type_name::<V::Value>() == "serde::__private::de::content::TagOrContent";
let is_serde_content = std::any::type_name::<V::Value>() == SERDE_CONTENT_CANARY
|| std::any::type_name::<V::Value>() == SERDE_TAG_KEY_CANARY;

let old_serde_content_newtype = self.serde_content_newtype;
self.serde_content_newtype = false;
Expand Down Expand Up @@ -849,7 +851,7 @@ impl<'de, 'a> de::MapAccess<'de> for CommaSeparated<'a, 'de> {
{
if self.has_element()? {
self.inside_internally_tagged_enum =
std::any::type_name::<K::Value>() == "serde::__private::de::content::TagOrContent";
std::any::type_name::<K::Value>() == SERDE_TAG_KEY_CANARY;

match self.terminator {
Terminator::Struct => guard_recursion! { self.de =>
Expand All @@ -875,7 +877,7 @@ impl<'de, 'a> de::MapAccess<'de> for CommaSeparated<'a, 'de> {
self.de.parser.skip_ws()?;

let res = if self.inside_internally_tagged_enum
&& std::any::type_name::<V::Value>() != "serde::__private::de::content::Content"
&& std::any::type_name::<V::Value>() != SERDE_CONTENT_CANARY
{
guard_recursion! { self.de =>
seed.deserialize(&mut tag::Deserializer::new(&mut *self.de))?
Expand Down
89 changes: 88 additions & 1 deletion tests/449_tagged_enum.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,94 @@ fn test_serde_content_hack() {
assert_eq!(
std::any::type_name::<serde::__private::de::Content>(),
"serde::__private::de::content::Content"
)
);
}

#[test]
fn test_serde_internally_tagged_hack() {
const SERDE_CONTENT_CANARY: &str = "serde::__private::de::content::Content";
const SERDE_TAG_KEY_CANARY: &str = "serde::__private::de::content::TagOrContent";

struct Deserializer {
tag_key: Option<String>,
tag_value: String,
field_key: Option<String>,
field_value: i32,
}

impl<'de> serde::Deserializer<'de> for Deserializer {
type Error = ron::Error;

fn deserialize_any<V>(self, visitor: V) -> Result<V::Value, Self::Error>
where
V: serde::de::Visitor<'de>,
{
visitor.visit_map(self)
}

// GRCOV_EXCL_START
serde::forward_to_deserialize_any! {
bool i8 i16 i32 i64 i128 u8 u16 u32 u64 u128 f32 f64 char str string
bytes byte_buf option unit unit_struct newtype_struct seq tuple
tuple_struct map struct enum identifier ignored_any
}
// GRCOV_EXCL_STOP
}

impl<'de> serde::de::MapAccess<'de> for Deserializer {
type Error = ron::Error;

fn next_key_seed<K>(&mut self, seed: K) -> Result<Option<K::Value>, Self::Error>
where
K: serde::de::DeserializeSeed<'de>,
{
assert_eq!(std::any::type_name::<K::Value>(), SERDE_TAG_KEY_CANARY);

if let Some(tag_key) = self.tag_key.take() {
return seed
.deserialize(serde::de::value::StringDeserializer::new(tag_key))
.map(Some);
}

if let Some(field_key) = self.field_key.take() {
return seed
.deserialize(serde::de::value::StringDeserializer::new(field_key))
.map(Some);
}

Ok(None)
}

fn next_value_seed<V>(&mut self, seed: V) -> Result<V::Value, Self::Error>
where
V: serde::de::DeserializeSeed<'de>,
{
if self.field_key.is_some() {
assert_ne!(std::any::type_name::<V::Value>(), SERDE_CONTENT_CANARY);
return seed.deserialize(serde::de::value::StrDeserializer::new(&self.tag_value));
}

assert_eq!(std::any::type_name::<V::Value>(), SERDE_CONTENT_CANARY);

seed.deserialize(serde::de::value::I32Deserializer::new(self.field_value))
}
}

#[derive(PartialEq, Debug, Deserialize)]
#[serde(tag = "tag")]
enum InternallyTagged {
A { hi: i32 },
}

assert_eq!(
InternallyTagged::deserialize(Deserializer {
tag_key: Some(String::from("tag")),
tag_value: String::from("A"),
field_key: Some(String::from("hi")),
field_value: 42,
}),
Ok(InternallyTagged::A { hi: 42 })
);
}

#[test]
Expand Down

0 comments on commit a95d46c

Please sign in to comment.