From 5de4eb3bd7e1c87a94a7b48501ea86f6fc871862 Mon Sep 17 00:00:00 2001 From: Yiannis Marangos Date: Tue, 26 Nov 2024 16:44:35 +0200 Subject: [PATCH] add more complete CowStr impl --- proto/src/serializers/cow_str.rs | 140 ++++++++++++++++++++++++++++++- 1 file changed, 136 insertions(+), 4 deletions(-) diff --git a/proto/src/serializers/cow_str.rs b/proto/src/serializers/cow_str.rs index 574c5b52..2e1a8fa2 100644 --- a/proto/src/serializers/cow_str.rs +++ b/proto/src/serializers/cow_str.rs @@ -1,11 +1,71 @@ -use std::borrow::Cow; +//! Wrapper `Cow<'_, str>` for deserializing without allocation. +//! +//! This is a workaround for [serde's issue 1852](https://github.com/serde-rs/serde/issues/1852). + +use serde::{de, Deserialize, Deserializer, Serialize, Serializer}; +use std::borrow::{Cow, ToOwned}; use std::fmt::{self, Debug, Display, Formatter}; use std::ops::Deref; -use serde::Deserialize; +/// Wrapper `Cow<'_, str>` for deserializing without allocation. +#[derive(Default)] +pub struct CowStr<'a>(Cow<'a, str>); + +impl<'a> CowStr<'a> { + /// Convert into `Cow<'a, str>`. + pub fn into_inner(self) -> Cow<'a, str> { + self.0 + } +} + +impl<'de> Deserialize<'de> for CowStr<'de> { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + struct Visitor; + + impl<'de> serde::de::Visitor<'de> for Visitor { + type Value = CowStr<'de>; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str("a string") + } + + fn visit_borrowed_str(self, value: &'de str) -> Result + where + E: de::Error, + { + Ok(CowStr(Cow::Borrowed(value))) + } + + fn visit_str(self, value: &str) -> Result + where + E: de::Error, + { + Ok(CowStr(Cow::Owned(value.to_owned()))) + } -#[derive(Default, Deserialize)] -pub struct CowStr<'a>(#[serde(borrow)] Cow<'a, str>); + fn visit_string(self, value: String) -> Result + where + E: de::Error, + { + Ok(CowStr(Cow::Owned(value))) + } + } + + deserializer.deserialize_str(Visitor) + } +} + +impl<'a> Serialize for CowStr<'a> { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + serializer.serialize_str(&self.0) + } +} impl<'a> Debug for CowStr<'a> { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { @@ -38,3 +98,75 @@ impl<'a> AsRef<[u8]> for CowStr<'a> { self.0.as_bytes() } } + +impl<'a> From> for Cow<'a, str> { + fn from(value: CowStr<'a>) -> Self { + value.0 + } +} + +impl<'a> From> for CowStr<'a> { + fn from(value: Cow<'a, str>) -> Self { + CowStr(value) + } +} + +/// Serialize `Cow<'_, str>`. +pub fn serialize<'a, S>(value: &Cow<'a, str>, serializer: S) -> Result +where + S: Serializer, +{ + serializer.serialize_str(value) +} + +/// Deserialize `Cow<'_, str>`. +pub fn deserialize<'de, D>(deserializer: D) -> Result, D::Error> +where + D: Deserializer<'de>, +{ + CowStr::deserialize(deserializer).map(|value| value.into_inner()) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn borrowed() { + struct Test(u32); + + impl<'de> Deserialize<'de> for Test { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + let s = CowStr::deserialize(deserializer)?; + assert!(matches!(s.0, Cow::Borrowed(_))); + Ok(Test(s.parse().unwrap())) + } + } + + let v = serde_json::from_str::("\"2\"").unwrap(); + assert_eq!(v.0, 2); + } + + #[test] + fn owned() { + struct Test(u32); + + impl<'de> Deserialize<'de> for Test { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + let s = CowStr::deserialize(deserializer)?; + assert!(matches!(s.0, Cow::Owned(_))); + Ok(Test(s.parse().unwrap())) + } + } + + let json_value = serde_json::from_str::("\"2\"").unwrap(); + let v = serde_json::from_value::(json_value).unwrap(); + assert_eq!(v.0, 2); + } +}