From fbc171361d67c85deca5553d42c856a16a494d92 Mon Sep 17 00:00:00 2001 From: Geoffry Song Date: Mon, 12 Feb 2024 11:47:49 -0800 Subject: [PATCH] Implement read support for proto2 extensions (#163) --- pb-jelly-gen/codegen/codegen.py | 161 +++++-- pb-jelly/src/extensions.rs | 133 ++++++ pb-jelly/src/helpers.rs | 6 +- pb-jelly/src/lib.rs | 59 ++- .../proto_pbtest/src/extensions.rs.expected | 403 ++++++++++++++++++ .../pb-jelly/proto_pbtest/src/lib.rs.expected | 1 + .../proto_pbtest/src/pbtest3.rs.expected | 12 +- .../proto/packages/pbtest/extensions.proto | 26 ++ pb-test/src/pbtest.rs | 37 ++ pb-test/src/verify_generated_files.rs | 2 +- 10 files changed, 793 insertions(+), 47 deletions(-) create mode 100644 pb-jelly/src/extensions.rs create mode 100644 pb-test/gen/pb-jelly/proto_pbtest/src/extensions.rs.expected create mode 100644 pb-test/proto/packages/pbtest/extensions.proto diff --git a/pb-jelly-gen/codegen/codegen.py b/pb-jelly-gen/codegen/codegen.py index e255fc0..41ebbf1 100755 --- a/pb-jelly-gen/codegen/codegen.py +++ b/pb-jelly-gen/codegen/codegen.py @@ -133,11 +133,7 @@ def escape_name(s: str) -> str: # https://github.com/protocolbuffers/protobuf/blob/master/src/google/protobuf/descriptor.proto SourceCodeLocation = List[int] ProtoTypes = Union[FileDescriptorProto, EnumDescriptorProto, DescriptorProto] -WalkRet = Tuple[ - List[Tuple[List[Text], EnumDescriptorProto, SourceCodeLocation]], - List[Tuple[List[Text], DescriptorProto, SourceCodeLocation]], -] -ModTree = DefaultDict[Text, DefaultDict[Text, Any]] +ModTree = DefaultDict[Text, "ModTree"] T = TypeVar("T") @@ -223,7 +219,7 @@ def __init__( self, ctx: "Context", proto_file: FileDescriptorProto, - msg_type: DescriptorProto, + msg_type: Optional[DescriptorProto], field: FieldDescriptorProto, ) -> None: self.ctx = ctx @@ -234,6 +230,7 @@ def __init__( self.oneof = ( field.HasField("oneof_index") and not field.proto3_optional + and msg_type is not None and msg_type.oneof_decl[field.oneof_index] ) @@ -784,7 
+781,7 @@ def write_comments(self, sci_loc: Optional[SourceCodeInfo.Location]) -> None: self.write_line_broken_text_with_prefix(sci_loc.trailing_comments, "///") def rust_type( - self, msg_type: DescriptorProto, field: FieldDescriptorProto + self, msg_type: Optional[DescriptorProto], field: FieldDescriptorProto ) -> RustType: return RustType(self.ctx, self.proto_file, msg_type, field) @@ -971,6 +968,11 @@ def gen_msg( name = "_".join(path + [msg_type.name]) escaped_name = escape_name(name) + preserve_unrecognized = msg_type.options.Extensions[ + extensions_pb2.preserve_unrecognized + ] + has_extensions = len(msg_type.extension_range) > 0 + oneof_fields: DefaultDict[Text, List[FieldDescriptorProto]] = defaultdict(list) proto3_optional_synthetic_oneofs: Set[int] = { field.oneof_index for field in msg_type.field if field.proto3_optional @@ -1024,9 +1026,12 @@ def gen_msg( % (escape_name(oneof.name), oneof_msg_name(name, oneof)) ) - if msg_type.options.Extensions[extensions_pb2.preserve_unrecognized]: + if preserve_unrecognized: self.write("pub _unrecognized: Vec,") + if has_extensions: + self.write("pub _extensions: ::pb_jelly::Unrecognized,") + # Generate any oneof enum structs for oneof in oneof_decls: self.write("#[derive(%s)]" % ", ".join(sorted(derives))) @@ -1124,10 +1129,10 @@ def gen_msg( self.write( "%s: %s," % (escape_name(oneof.name), typ.default(name)) ) - if msg_type.options.Extensions[ - extensions_pb2.preserve_unrecognized - ]: + if preserve_unrecognized: self.write("_unrecognized: Vec::new(),") + if has_extensions: + self.write("_extensions: ::pb_jelly::Unrecognized::default(),") with block(self, "lazy_static!"): self.write( @@ -1205,10 +1210,7 @@ def gen_msg( self.write('name: "%s",' % oneof.name) with block(self, "fn compute_size(&self) -> usize"): - if ( - len(msg_type.field) > 0 - or msg_type.options.Extensions[extensions_pb2.preserve_unrecognized] - ): + if len(msg_type.field) > 0 or preserve_unrecognized or has_extensions: self.write("let mut 
size = 0;") for field in msg_type.field: typ = self.rust_type(msg_type, field) @@ -1263,10 +1265,10 @@ def gen_msg( % field.name ) self.write("size += %s_size;" % field.name) - if msg_type.options.Extensions[ - extensions_pb2.preserve_unrecognized - ]: + if preserve_unrecognized: self.write("size += self._unrecognized.len();") + if has_extensions: + self.write("size += self._extensions.compute_size();") self.write("size") else: self.write("0") @@ -1341,17 +1343,16 @@ def gen_msg( ) self.write("::pb_jelly::varint::write(l as u64, w)?;") self.write("::pb_jelly::Message::serialize(val, w)?;") - if msg_type.options.Extensions[extensions_pb2.preserve_unrecognized]: + if preserve_unrecognized: self.write("w.write_all(&self._unrecognized)?;") + if has_extensions: + self.write("self._extensions.serialize(w)?;") self.write("Ok(())") with block( self, "fn deserialize(&mut self, mut buf: &mut B) -> ::std::io::Result<()>", ): - preserve_unrecognized = msg_type.options.Extensions[ - extensions_pb2.preserve_unrecognized - ] if preserve_unrecognized: self.write( "let mut unrecognized = ::pb_jelly::Unrecognized::default();" @@ -1461,6 +1462,15 @@ def gen_msg( "self.%s = %s;" % (escape_name(field.name), field_val) ) + if has_extensions: + pattern = " | ".join( + "{}..={}".format(r.start, r.end - 1) + for r in msg_type.extension_range + ) + with block(self, pattern + " =>"): + self.write( + "self._extensions.gather(field_number, typ, &mut buf)?;" + ) with block(self, "_ =>"): if preserve_unrecognized: self.write( @@ -1488,7 +1498,12 @@ def gen_msg( ) if preserve_unrecognized: - self.write("unrecognized.serialize(&mut self._unrecognized)?;") + self.write( + "self._unrecognized.reserve(unrecognized.compute_size());" + ) + self.write( + "unrecognized.serialize(&mut std::io::Cursor::new(&mut self._unrecognized))?;" + ) self.write("Ok(())") with block(self, "impl ::pb_jelly::Reflection for " + name): @@ -1594,9 +1609,58 @@ def gen_msg( with block(self, "_ =>"): 
self.write('panic!("unknown field name given")') + if has_extensions: + with block(self, "impl ::pb_jelly::extensions::Extensible for " + name): + with block( + self, + "fn _extensions(&self) -> &::pb_jelly::Unrecognized", + ): + self.write("&self._extensions") + + def gen_extension( + self, + path: List[Text], + extension_field: FieldDescriptorProto, + scl: SourceCodeLocation, + ) -> None: + crate, mod_parts = self.ctx.crate_from_proto_filename(self.proto_file.name) + + self.write_comments(self.source_code_info_by_scl.get(tuple(scl))) + name = ("_".join(path + [extension_field.name])).upper() + rust_type = self.rust_type(None, extension_field) + extendee = self.ctx.find(extension_field.extendee) + kind = ( + "RepeatedExtension" + if extension_field.label == FieldDescriptorProto.LABEL_REPEATED + else "SingularExtension" + ) + + self.write( + """pub const {name}: ::pb_jelly::extensions::{kind}<{extendee}, {field_type}> = + ::pb_jelly::extensions::{kind}::new( + {field_number}, + ::pb_jelly::wire_format::Type::{wire_format}, + "{raw_name}", + );""".format( + name=name, + extendee=extendee.rust_name(crate, mod_parts), + field_type=rust_type.rust_type(), + kind=kind, + field_number=extension_field.number, + wire_format=rust_type.wire_format(), + raw_name=extension_field.name, + ) + ) -def walk(proto: FileDescriptorProto) -> WalkRet: - enums, messages = [], [] + +def walk( + proto: FileDescriptorProto, +) -> Tuple[ + List[Tuple[List[Text], EnumDescriptorProto, SourceCodeLocation]], + List[Tuple[List[Text], DescriptorProto, SourceCodeLocation]], + List[Tuple[List[Text], FieldDescriptorProto, SourceCodeLocation]], +]: + enums, messages, extensions = [], [], [] def _walk( proto: ProtoTypes, parents: List[Text], scl_prefix: SourceCodeLocation @@ -1613,6 +1677,15 @@ def _walk( for i, nested_message in enumerate(proto.nested_type): ntfn = DescriptorProto.NESTED_TYPE_FIELD_NUMBER _walk(nested_message, parents + [proto.name], scl_prefix + [ntfn, i]) + + for i, nested_extension 
in enumerate(proto.extension): + extensions.append( + ( + parents + [proto.name], + nested_extension, + scl_prefix + [DescriptorProto.EXTENSION_FIELD_NUMBER, i], + ) + ) elif isinstance(proto, FileDescriptorProto): for i, enum_type in enumerate(proto.enum_type): etfn = FileDescriptorProto.ENUM_TYPE_FIELD_NUMBER @@ -1622,8 +1695,17 @@ def _walk( mtfn = FileDescriptorProto.MESSAGE_TYPE_FIELD_NUMBER _walk(message_type, parents, scl_prefix + [mtfn, i]) + for i, nested_extension in enumerate(proto.extension): + extensions.append( + ( + parents, + nested_extension, + scl_prefix + [FileDescriptorProto.EXTENSION_FIELD_NUMBER, i], + ) + ) + _walk(proto, [], []) - return enums, messages + return enums, messages, extensions M = TypeVar("M", DescriptorProto, EnumDescriptorProto) @@ -1755,6 +1837,11 @@ def calc_impls( if msg_type.typ.options.Extensions[extensions_pb2.preserve_unrecognized]: impls_copy = False # Preserve unparsed has a Vec which is not Copy + if len(msg_type.typ.extension_range) > 0: + # `Unrecognized` is neither Copy nor Eq + impls_eq = False + impls_copy = False + for field in msg_type.typ.field: typ = field.type rust_type = RustType(self, msg_type.proto_file, msg_type.typ, field) @@ -1810,6 +1897,7 @@ def calc_impls( if msg_type.typ.options.Extensions[ extensions_pb2.preserve_unrecognized ]: + # TODO: this check isn't really necessary, but it is useful assert field_type.typ.options.Extensions[ extensions_pb2.preserve_unrecognized ], ( @@ -1844,7 +1932,7 @@ def calc_impls( ) def feed(self, proto_file: FileDescriptorProto, to_generate: List[Text]) -> None: - enums, messages = walk(proto_file) + enums, messages, extensions = walk(proto_file) for name in to_generate: crate, _ = self.crate_from_proto_filename(name) @@ -1865,6 +1953,8 @@ def feed(self, proto_file: FileDescriptorProto, to_generate: List[Text]) -> None # so it suffices to examine one file at a time for the purposes of `box_recursive_fields` box_recursive_fields(message_types) + crate, _ = 
self.crate_from_proto_filename(proto_file.name) + for path, typ, _ in messages: msg_pt = ProtoType(self, proto_file, path, typ) @@ -1879,6 +1969,17 @@ def edges(type_name: Text) -> List[Text]: self.scc.process(msg_pt.proto_name(), edges, self.calc_impls) + if crate in self.deps_map: + for path, field, _ in extensions: + for type_name in [field.type_name, field.extendee]: + if type_name: + field_type = self.find(type_name) + dep_crate, _ = self.crate_from_proto_filename( + field_type.proto_file.name + ) + if dep_crate != crate: + self.deps_map[crate].add(dep_crate) + def find_enum(self, typename: Text) -> ProtoType[EnumDescriptorProto]: pt = self.find(typename) assert isinstance(pt.typ, EnumDescriptorProto) @@ -2124,7 +2225,7 @@ def add_mod(writer: CodeWriter) -> None: if writer.derive_serde: derive_serde = True - enums, messages = walk(proto_file) + enums, messages, extensions = walk(proto_file) for path, enum_typ, scl in enums: writer.gen_enum(path, enum_typ, scl) @@ -2134,6 +2235,10 @@ def add_mod(writer: CodeWriter) -> None: writer.gen_msg(path, msg_typ, scl) writer.write("") + for path, extension_field, scl in extensions: + writer.gen_extension(path, extension_field, scl) + writer.write("") + add_mod(writer=writer) # Note that output filenames must use "/" even on windows. It is part of the diff --git a/pb-jelly/src/extensions.rs b/pb-jelly/src/extensions.rs new file mode 100644 index 0000000..14282b3 --- /dev/null +++ b/pb-jelly/src/extensions.rs @@ -0,0 +1,133 @@ +use std::io; +use std::marker::PhantomData; + +use crate::{ + ensure_wire_format, + varint, + wire_format, + Message, + PbBufferReader, + Unrecognized, +}; + +/// Indicates that a message type has extension ranges defined. +/// See for details. +pub trait Extensible: Message { + /// Attempts to read the given extension field from `self`. + /// + /// Returns `Err(_)` if the field was found but could not be deserialized as the declared field type. 
+ fn get_extension>(&self, extension: E) -> io::Result { + extension.get(self) + } + + /// Returns a reference to the `_extensions` field. + /// This is intended to be implemented by generated code and isn't very useful for users of pb-jelly, + /// so it's doc(hidden). + #[doc(hidden)] + fn _extensions(&self) -> &Unrecognized; +} + +/// Abstracts over [SingularExtension]/[RepeatedExtension]. +pub trait Extension { + type Extendee: Extensible; + type Value; + fn get(&self, m: &Self::Extendee) -> io::Result; +} + +/// An extension field. See for details. +pub struct SingularExtension { + pub field_number: u32, + pub wire_format: wire_format::Type, + pub name: &'static str, + _phantom: PhantomData U>, +} + +impl SingularExtension { + pub const fn new(field_number: u32, wire_format: wire_format::Type, name: &'static str) -> Self { + Self { + field_number, + wire_format, + name, + _phantom: PhantomData, + } + } +} + +impl Copy for SingularExtension {} +impl Clone for SingularExtension { + fn clone(&self) -> Self { + *self + } +} + +impl Extension for SingularExtension { + type Extendee = T; + type Value = Option; + + fn get(&self, m: &Self::Extendee) -> io::Result> { + Ok(match m._extensions().get_singular_field(self.field_number) { + Some((field, wire_format)) => { + let mut buf = io::Cursor::new(field); + ensure_wire_format(wire_format, self.wire_format, self.name, self.field_number)?; + if wire_format == wire_format::Type::LengthDelimited { + // we don't actually need this since we already have the length of `field` + varint::read(&mut buf)?; + } + let mut msg = U::default(); + msg.deserialize(&mut buf)?; + Some(msg) + }, + None => None, + }) + } +} + +/// A `repeated` extension field. See for details. 
+pub struct RepeatedExtension { + pub field_number: u32, + pub wire_format: wire_format::Type, + pub name: &'static str, + _phantom: PhantomData U>, +} + +impl RepeatedExtension { + pub const fn new(field_number: u32, wire_format: wire_format::Type, name: &'static str) -> Self { + Self { + field_number, + wire_format, + name, + _phantom: PhantomData, + } + } +} + +impl Copy for RepeatedExtension {} +impl Clone for RepeatedExtension { + fn clone(&self) -> Self { + *self + } +} + +impl Extension for RepeatedExtension { + type Extendee = T; + type Value = Vec; + + fn get(&self, m: &Self::Extendee) -> io::Result> { + let mut result = vec![]; + let mut buf = io::Cursor::new(m._extensions().get_fields(self.field_number)); + while let Some((_field_number, wire_format)) = wire_format::read(&mut buf)? { + ensure_wire_format(wire_format, self.wire_format, self.name, self.field_number)?; + let mut msg = U::default(); + if wire_format == wire_format::Type::LengthDelimited { + let length = varint::read(&mut buf)?.expect("corrupted Unrecognized"); + msg.deserialize(&mut buf.split(length as usize))?; + } else { + // we rely on the fact that the appropriate `Message` impls for i32/Fixed32/etc. only read the prefix of + // `buf`. 
this is a little dirty + msg.deserialize(&mut buf)?; + } + result.push(msg); + } + Ok(result) + } +} diff --git a/pb-jelly/src/helpers.rs b/pb-jelly/src/helpers.rs index 0ab67a1..878d1e0 100644 --- a/pb-jelly/src/helpers.rs +++ b/pb-jelly/src/helpers.rs @@ -15,7 +15,7 @@ pub fn deserialize_packed( typ: wire_format::Type, expected_wire_format: wire_format::Type, msg_name: &'static str, - field_number: usize, + field_number: u32, out: &mut Vec, ) -> io::Result<()> { match typ { @@ -42,7 +42,7 @@ pub fn deserialize_length_delimited( buf: &mut B, typ: wire_format::Type, msg_name: &'static str, - field_number: usize, + field_number: u32, ) -> io::Result { ensure_wire_format(typ, wire_format::Type::LengthDelimited, msg_name, field_number)?; let len = varint::ensure_read(buf)?; @@ -57,7 +57,7 @@ pub fn deserialize_known_length( typ: wire_format::Type, expected_wire_format: wire_format::Type, msg_name: &'static str, - field_number: usize, + field_number: u32, ) -> io::Result { ensure_wire_format(typ, expected_wire_format, msg_name, field_number)?; let mut val: T = Default::default(); diff --git a/pb-jelly/src/lib.rs b/pb-jelly/src/lib.rs index ac72c90..070eec3 100644 --- a/pb-jelly/src/lib.rs +++ b/pb-jelly/src/lib.rs @@ -9,7 +9,10 @@ extern crate serde; use std::any::Any; use std::collections::BTreeMap; use std::default::Default; -use std::fmt::Debug; +use std::fmt::{ + self, + Debug, +}; use std::io::{ Cursor, Error, @@ -24,6 +27,7 @@ use bytes::buf::{ }; pub mod erased; +pub mod extensions; pub mod helpers; pub mod varint; pub mod wire_format; @@ -129,7 +133,7 @@ pub fn ensure_wire_format( format: wire_format::Type, expected: wire_format::Type, msg_name: &str, - field_number: usize, + field_number: u32, ) -> Result<()> { if format != expected { return Err(Error::new( @@ -148,29 +152,45 @@ pub fn unexpected_eof() -> Error { Error::new(ErrorKind::UnexpectedEof, "unexpected EOF") } -#[derive(Default)] +// XXX: arguably this should not impl PartialEq since we cannot 
canonicalize the unparsed field contents +#[derive(Clone, Default, PartialEq)] pub struct Unrecognized { by_field_number: BTreeMap>, } +impl fmt::Debug for Unrecognized { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_map() + .entries(self.by_field_number.keys().map(|k| (k, ..))) + .finish() + } +} + impl Unrecognized { - pub fn serialize(self, unrecognized_buf: &mut Vec) -> Result<()> { + pub fn new() -> Self { + Self::default() + } + + pub fn serialize(&self, unrecognized_buf: &mut impl PbBufferWriter) -> Result<()> { // Write out sorted by field number - unrecognized_buf.reserve(self.by_field_number.values().map(|v| v.len()).sum()); for serialized_field in self.by_field_number.values() { unrecognized_buf.write_all(&serialized_field)?; } Ok(()) } + pub fn compute_size(&self) -> usize { + self.by_field_number.values().map(|v| v.len()).sum() + } + pub fn gather(&mut self, field_number: u32, typ: wire_format::Type, buf: &mut B) -> Result<()> { - let mut unrecognized_buf = vec![]; + let unrecognized_buf = self.by_field_number.entry(field_number).or_default(); - wire_format::write(field_number, typ, &mut unrecognized_buf)?; + wire_format::write(field_number, typ, unrecognized_buf)?; let advance = match typ { wire_format::Type::Varint => { if let Some(num) = varint::read(buf)? { - varint::write(num, &mut unrecognized_buf)?; + varint::write(num, unrecognized_buf)?; } else { return Err(unexpected_eof()); }; @@ -181,7 +201,7 @@ impl Unrecognized { wire_format::Type::Fixed32 => 4, wire_format::Type::LengthDelimited => match varint::read(buf)? 
{ Some(n) => { - varint::write(n, &mut unrecognized_buf)?; + varint::write(n, unrecognized_buf)?; n as usize }, None => return Err(unexpected_eof()), @@ -194,10 +214,27 @@ impl Unrecognized { unrecognized_buf.put(buf.take(advance)); - self.by_field_number.insert(field_number, unrecognized_buf); - Ok(()) } + + pub(crate) fn get_singular_field(&self, field_number: u32) -> Option<(&[u8], wire_format::Type)> { + let mut buf = Cursor::new(&self.by_field_number.get(&field_number)?[..]); + let mut result = None; + // It's technically legal for a singular field to occur multiple times on the wire, + // so skip over all but the last instance. + while let Some((_field_number, wire_format)) = + wire_format::read(&mut buf).expect("self.by_field_number malformed") + { + result = Some((&buf.get_ref()[buf.position() as usize..], wire_format)); + + skip(wire_format, &mut buf).expect("self.by_field_number malformed"); + } + result + } + + pub(crate) fn get_fields(&self, field_number: u32) -> &[u8] { + self.by_field_number.get(&field_number).map_or(&[], Vec::as_ref) + } } pub fn skip(typ: wire_format::Type, buf: &mut B) -> Result<()> { diff --git a/pb-test/gen/pb-jelly/proto_pbtest/src/extensions.rs.expected b/pb-test/gen/pb-jelly/proto_pbtest/src/extensions.rs.expected new file mode 100644 index 0000000..62dbd38 --- /dev/null +++ b/pb-test/gen/pb-jelly/proto_pbtest/src/extensions.rs.expected @@ -0,0 +1,403 @@ +// @generated, do not edit +#[derive(Clone, Debug, PartialEq)] +pub struct Msg { + pub base_field: ::std::option::Option, + pub _extensions: ::pb_jelly::Unrecognized, +} +impl Msg { + pub fn has_base_field(&self) -> bool { + self.base_field.is_some() + } + pub fn set_base_field(&mut self, v: i32) { + self.base_field = Some(v); + } + pub fn get_base_field(&self) -> i32 { + self.base_field.unwrap_or(0) + } +} +impl ::std::default::Default for Msg { + fn default() -> Self { + Msg { + base_field: ::std::default::Default::default(), + _extensions: 
::pb_jelly::Unrecognized::default(), + } + } +} +lazy_static! { + pub static ref Msg_default: Msg = Msg::default(); +} +impl ::pb_jelly::Message for Msg { + fn descriptor(&self) -> ::std::option::Option<::pb_jelly::MessageDescriptor> { + Some(::pb_jelly::MessageDescriptor { + name: "Msg", + full_name: "pbtest.Msg", + fields: &[ + ::pb_jelly::FieldDescriptor { + name: "base_field", + full_name: "pbtest.Msg.base_field", + index: 0, + number: 250, + typ: ::pb_jelly::wire_format::Type::Varint, + label: ::pb_jelly::Label::Optional, + oneof_index: None, + }, + ], + oneofs: &[ + ], + }) + } + fn compute_size(&self) -> usize { + let mut size = 0; + let mut base_field_size = 0; + if let Some(ref val) = self.base_field { + let l = ::pb_jelly::Message::compute_size(val); + base_field_size += ::pb_jelly::wire_format::serialized_length(250); + base_field_size += l; + } + size += base_field_size; + size += self._extensions.compute_size(); + size + } + fn serialize(&self, w: &mut W) -> ::std::io::Result<()> { + if let Some(ref val) = self.base_field { + ::pb_jelly::wire_format::write(250, ::pb_jelly::wire_format::Type::Varint, w)?; + ::pb_jelly::Message::serialize(val, w)?; + } + self._extensions.serialize(w)?; + Ok(()) + } + fn deserialize(&mut self, mut buf: &mut B) -> ::std::io::Result<()> { + while let Some((field_number, typ)) = ::pb_jelly::wire_format::read(&mut buf)? 
{ + match field_number { + 250 => { + let val = ::pb_jelly::helpers::deserialize_known_length::(buf, typ, ::pb_jelly::wire_format::Type::Varint, "Msg", 250)?; + self.base_field = Some(val); + } + 100..=200 | 300..=536870911 => { + self._extensions.gather(field_number, typ, &mut buf)?; + } + _ => { + ::pb_jelly::skip(typ, &mut buf)?; + } + } + } + Ok(()) + } +} +impl ::pb_jelly::Reflection for Msg { + fn which_one_of(&self, oneof_name: &str) -> ::std::option::Option<&'static str> { + match oneof_name { + _ => { + panic!("unknown oneof name given"); + } + } + } + fn get_field_mut(&mut self, field_name: &str) -> ::pb_jelly::reflection::FieldMut<'_> { + match field_name { + "base_field" => { + ::pb_jelly::reflection::FieldMut::Value(self.base_field.get_or_insert_with(::std::default::Default::default)) + } + _ => { + panic!("unknown field name given") + } + } + } +} +impl ::pb_jelly::extensions::Extensible for Msg { + fn _extensions(&self) -> &::pb_jelly::Unrecognized { + &self._extensions + } +} + +#[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)] +pub struct FakeMsg { + pub base_field: ::std::option::Option, + pub singular_primitive: ::std::option::Option, + pub singular_message: ::std::option::Option, + pub repeated_primitive: ::std::vec::Vec, + pub repeated_message: ::std::vec::Vec, +} +impl FakeMsg { + pub fn has_base_field(&self) -> bool { + self.base_field.is_some() + } + pub fn set_base_field(&mut self, v: i32) { + self.base_field = Some(v); + } + pub fn get_base_field(&self) -> i32 { + self.base_field.unwrap_or(0) + } + pub fn has_singular_primitive(&self) -> bool { + self.singular_primitive.is_some() + } + pub fn set_singular_primitive(&mut self, v: i32) { + self.singular_primitive = Some(v); + } + pub fn get_singular_primitive(&self) -> i32 { + self.singular_primitive.unwrap_or(0) + } + pub fn has_singular_message(&self) -> bool { + self.singular_message.is_some() + } + pub fn set_singular_message(&mut self, v: super::pbtest3::ForeignMessage3) { + 
self.singular_message = Some(v); + } + pub fn take_singular_message(&mut self) -> super::pbtest3::ForeignMessage3 { + self.singular_message.take().unwrap_or_default() + } + pub fn get_singular_message(&self) -> &super::pbtest3::ForeignMessage3 { + self.singular_message.as_ref().unwrap_or(&super::pbtest3::ForeignMessage3_default) + } + pub fn set_repeated_primitive(&mut self, v: ::std::vec::Vec) { + self.repeated_primitive = v; + } + pub fn take_repeated_primitive(&mut self) -> ::std::vec::Vec { + ::std::mem::take(&mut self.repeated_primitive) + } + pub fn get_repeated_primitive(&self) -> &[i32] { + &self.repeated_primitive + } + pub fn mut_repeated_primitive(&mut self) -> &mut ::std::vec::Vec { + &mut self.repeated_primitive + } + pub fn set_repeated_message(&mut self, v: ::std::vec::Vec) { + self.repeated_message = v; + } + pub fn take_repeated_message(&mut self) -> ::std::vec::Vec { + ::std::mem::take(&mut self.repeated_message) + } + pub fn get_repeated_message(&self) -> &[super::pbtest3::ForeignMessage3] { + &self.repeated_message + } + pub fn mut_repeated_message(&mut self) -> &mut ::std::vec::Vec { + &mut self.repeated_message + } +} +impl ::std::default::Default for FakeMsg { + fn default() -> Self { + FakeMsg { + base_field: ::std::default::Default::default(), + singular_primitive: ::std::default::Default::default(), + singular_message: ::std::default::Default::default(), + repeated_primitive: ::std::default::Default::default(), + repeated_message: ::std::default::Default::default(), + } + } +} +lazy_static! 
{ + pub static ref FakeMsg_default: FakeMsg = FakeMsg::default(); +} +impl ::pb_jelly::Message for FakeMsg { + fn descriptor(&self) -> ::std::option::Option<::pb_jelly::MessageDescriptor> { + Some(::pb_jelly::MessageDescriptor { + name: "FakeMsg", + full_name: "pbtest.FakeMsg", + fields: &[ + ::pb_jelly::FieldDescriptor { + name: "base_field", + full_name: "pbtest.FakeMsg.base_field", + index: 0, + number: 250, + typ: ::pb_jelly::wire_format::Type::Varint, + label: ::pb_jelly::Label::Optional, + oneof_index: None, + }, + ::pb_jelly::FieldDescriptor { + name: "singular_primitive", + full_name: "pbtest.FakeMsg.singular_primitive", + index: 1, + number: 101, + typ: ::pb_jelly::wire_format::Type::Varint, + label: ::pb_jelly::Label::Optional, + oneof_index: None, + }, + ::pb_jelly::FieldDescriptor { + name: "singular_message", + full_name: "pbtest.FakeMsg.singular_message", + index: 2, + number: 301, + typ: ::pb_jelly::wire_format::Type::LengthDelimited, + label: ::pb_jelly::Label::Optional, + oneof_index: None, + }, + ::pb_jelly::FieldDescriptor { + name: "repeated_primitive", + full_name: "pbtest.FakeMsg.repeated_primitive", + index: 3, + number: 300, + typ: ::pb_jelly::wire_format::Type::Varint, + label: ::pb_jelly::Label::Repeated, + oneof_index: None, + }, + ::pb_jelly::FieldDescriptor { + name: "repeated_message", + full_name: "pbtest.FakeMsg.repeated_message", + index: 4, + number: 200, + typ: ::pb_jelly::wire_format::Type::LengthDelimited, + label: ::pb_jelly::Label::Repeated, + oneof_index: None, + }, + ], + oneofs: &[ + ], + }) + } + fn compute_size(&self) -> usize { + let mut size = 0; + let mut base_field_size = 0; + if let Some(ref val) = self.base_field { + let l = ::pb_jelly::Message::compute_size(val); + base_field_size += ::pb_jelly::wire_format::serialized_length(250); + base_field_size += l; + } + size += base_field_size; + let mut singular_primitive_size = 0; + if let Some(ref val) = self.singular_primitive { + let l = 
::pb_jelly::Message::compute_size(val); + singular_primitive_size += ::pb_jelly::wire_format::serialized_length(101); + singular_primitive_size += l; + } + size += singular_primitive_size; + let mut singular_message_size = 0; + if let Some(ref val) = self.singular_message { + let l = ::pb_jelly::Message::compute_size(val); + singular_message_size += ::pb_jelly::wire_format::serialized_length(301); + singular_message_size += ::pb_jelly::varint::serialized_length(l as u64); + singular_message_size += l; + } + size += singular_message_size; + let mut repeated_primitive_size = 0; + for val in &self.repeated_primitive { + let l = ::pb_jelly::Message::compute_size(val); + repeated_primitive_size += ::pb_jelly::wire_format::serialized_length(300); + repeated_primitive_size += l; + } + size += repeated_primitive_size; + let mut repeated_message_size = 0; + for val in &self.repeated_message { + let l = ::pb_jelly::Message::compute_size(val); + repeated_message_size += ::pb_jelly::wire_format::serialized_length(200); + repeated_message_size += ::pb_jelly::varint::serialized_length(l as u64); + repeated_message_size += l; + } + size += repeated_message_size; + size + } + fn serialize(&self, w: &mut W) -> ::std::io::Result<()> { + if let Some(ref val) = self.singular_primitive { + ::pb_jelly::wire_format::write(101, ::pb_jelly::wire_format::Type::Varint, w)?; + ::pb_jelly::Message::serialize(val, w)?; + } + for val in &self.repeated_message { + ::pb_jelly::wire_format::write(200, ::pb_jelly::wire_format::Type::LengthDelimited, w)?; + let l = ::pb_jelly::Message::compute_size(val); + ::pb_jelly::varint::write(l as u64, w)?; + ::pb_jelly::Message::serialize(val, w)?; + } + if let Some(ref val) = self.base_field { + ::pb_jelly::wire_format::write(250, ::pb_jelly::wire_format::Type::Varint, w)?; + ::pb_jelly::Message::serialize(val, w)?; + } + for val in &self.repeated_primitive { + ::pb_jelly::wire_format::write(300, ::pb_jelly::wire_format::Type::Varint, w)?; + 
::pb_jelly::Message::serialize(val, w)?; + } + if let Some(ref val) = self.singular_message { + ::pb_jelly::wire_format::write(301, ::pb_jelly::wire_format::Type::LengthDelimited, w)?; + let l = ::pb_jelly::Message::compute_size(val); + ::pb_jelly::varint::write(l as u64, w)?; + ::pb_jelly::Message::serialize(val, w)?; + } + Ok(()) + } + fn deserialize(&mut self, mut buf: &mut B) -> ::std::io::Result<()> { + while let Some((field_number, typ)) = ::pb_jelly::wire_format::read(&mut buf)? { + match field_number { + 250 => { + let val = ::pb_jelly::helpers::deserialize_known_length::(buf, typ, ::pb_jelly::wire_format::Type::Varint, "FakeMsg", 250)?; + self.base_field = Some(val); + } + 101 => { + let val = ::pb_jelly::helpers::deserialize_known_length::(buf, typ, ::pb_jelly::wire_format::Type::Varint, "FakeMsg", 101)?; + self.singular_primitive = Some(val); + } + 301 => { + let val = ::pb_jelly::helpers::deserialize_length_delimited::(buf, typ, "FakeMsg", 301)?; + self.singular_message = Some(val); + } + 300 => { + ::pb_jelly::helpers::deserialize_packed::(buf, typ, ::pb_jelly::wire_format::Type::Varint, "FakeMsg", 300, &mut self.repeated_primitive)?; + } + 200 => { + let val = ::pb_jelly::helpers::deserialize_length_delimited::(buf, typ, "FakeMsg", 200)?; + self.repeated_message.push(val); + } + _ => { + ::pb_jelly::skip(typ, &mut buf)?; + } + } + } + Ok(()) + } +} +impl ::pb_jelly::Reflection for FakeMsg { + fn which_one_of(&self, oneof_name: &str) -> ::std::option::Option<&'static str> { + match oneof_name { + _ => { + panic!("unknown oneof name given"); + } + } + } + fn get_field_mut(&mut self, field_name: &str) -> ::pb_jelly::reflection::FieldMut<'_> { + match field_name { + "base_field" => { + ::pb_jelly::reflection::FieldMut::Value(self.base_field.get_or_insert_with(::std::default::Default::default)) + } + "singular_primitive" => { + ::pb_jelly::reflection::FieldMut::Value(self.singular_primitive.get_or_insert_with(::std::default::Default::default)) + } + 
"singular_message" => { + ::pb_jelly::reflection::FieldMut::Value(self.singular_message.get_or_insert_with(::std::default::Default::default)) + } + "repeated_primitive" => { + unimplemented!("Repeated fields are not currently supported.") + } + "repeated_message" => { + unimplemented!("Repeated fields are not currently supported.") + } + _ => { + panic!("unknown field name given") + } + } + } +} + +pub const SINGULAR_PRIMITIVE: ::pb_jelly::extensions::SingularExtension = + ::pb_jelly::extensions::SingularExtension::new( + 101, + ::pb_jelly::wire_format::Type::Varint, + "singular_primitive", + ); + +pub const SINGULAR_MESSAGE: ::pb_jelly::extensions::SingularExtension = + ::pb_jelly::extensions::SingularExtension::new( + 301, + ::pb_jelly::wire_format::Type::LengthDelimited, + "singular_message", + ); + +pub const REPEATED_PRIMITIVE: ::pb_jelly::extensions::RepeatedExtension = + ::pb_jelly::extensions::RepeatedExtension::new( + 300, + ::pb_jelly::wire_format::Type::Varint, + "repeated_primitive", + ); + +pub const REPEATED_MESSAGE: ::pb_jelly::extensions::RepeatedExtension = + ::pb_jelly::extensions::RepeatedExtension::new( + 200, + ::pb_jelly::wire_format::Type::LengthDelimited, + "repeated_message", + ); + diff --git a/pb-test/gen/pb-jelly/proto_pbtest/src/lib.rs.expected b/pb-test/gen/pb-jelly/proto_pbtest/src/lib.rs.expected index a6f5407..c0b7cb0 100644 --- a/pb-test/gen/pb-jelly/proto_pbtest/src/lib.rs.expected +++ b/pb-test/gen/pb-jelly/proto_pbtest/src/lib.rs.expected @@ -25,6 +25,7 @@ extern crate lazy_static; pub mod bench; +pub mod extensions; pub mod r#mod; pub mod pbtest2; pub mod pbtest3; diff --git a/pb-test/gen/pb-jelly/proto_pbtest/src/pbtest3.rs.expected b/pb-test/gen/pb-jelly/proto_pbtest/src/pbtest3.rs.expected index ec09501..7cd55be 100644 --- a/pb-test/gen/pb-jelly/proto_pbtest/src/pbtest3.rs.expected +++ b/pb-test/gen/pb-jelly/proto_pbtest/src/pbtest3.rs.expected @@ -5491,7 +5491,8 @@ impl ::pb_jelly::Message for TestPreserveUnrecognized1 { } 
} } - unrecognized.serialize(&mut self._unrecognized)?; + self._unrecognized.reserve(unrecognized.compute_size()); + unrecognized.serialize(&mut std::io::Cursor::new(&mut self._unrecognized))?; Ok(()) } } @@ -5796,7 +5797,8 @@ impl ::pb_jelly::Message for TestPreserveUnrecognized2 { } } } - unrecognized.serialize(&mut self._unrecognized)?; + self._unrecognized.reserve(unrecognized.compute_size()); + unrecognized.serialize(&mut std::io::Cursor::new(&mut self._unrecognized))?; Ok(()) } } @@ -5902,7 +5904,8 @@ impl ::pb_jelly::Message for TestPreserveUnrecognizedEmpty { } } } - unrecognized.serialize(&mut self._unrecognized)?; + self._unrecognized.reserve(unrecognized.compute_size()); + unrecognized.serialize(&mut std::io::Cursor::new(&mut self._unrecognized))?; Ok(()) } } @@ -6217,7 +6220,8 @@ impl ::pb_jelly::Message for TestSmallStringPreserveUnrecognized { } } } - unrecognized.serialize(&mut self._unrecognized)?; + self._unrecognized.reserve(unrecognized.compute_size()); + unrecognized.serialize(&mut std::io::Cursor::new(&mut self._unrecognized))?; Ok(()) } } diff --git a/pb-test/proto/packages/pbtest/extensions.proto b/pb-test/proto/packages/pbtest/extensions.proto new file mode 100644 index 0000000..8e22963 --- /dev/null +++ b/pb-test/proto/packages/pbtest/extensions.proto @@ -0,0 +1,26 @@ +syntax = "proto2"; +package pbtest; + +import "pbtest/pbtest3.proto"; + +message Msg { + optional int32 base_field = 250; + extensions 100 to 200; + extensions 300 to max; +} + +extend Msg { + optional int32 singular_primitive = 101; + optional ForeignMessage3 singular_message = 301; + repeated int32 repeated_primitive = 300; + repeated ForeignMessage3 repeated_message = 200; +} + +message FakeMsg { + optional int32 base_field = 250; + + optional int32 singular_primitive = 101; + optional ForeignMessage3 singular_message = 301; + repeated int32 repeated_primitive = 300; + repeated ForeignMessage3 repeated_message = 200; +} diff --git a/pb-test/src/pbtest.rs 
b/pb-test/src/pbtest.rs index 0a7a51c..a69bc48 100644 --- a/pb-test/src/pbtest.rs +++ b/pb-test/src/pbtest.rs @@ -4,6 +4,7 @@ use std::io::Cursor; use std::io::Read; use bytes::Bytes; +use pb_jelly::extensions::Extensible; use pb_jelly::reflection::FieldMut; use pb_jelly::wire_format::Type; use pb_jelly::{ @@ -17,6 +18,7 @@ use pretty_assertions::{ assert_eq, assert_ne, }; +use proto_pbtest::extensions; use proto_pbtest::pbtest2::*; use proto_pbtest::pbtest3::*; @@ -937,3 +939,38 @@ fn test_mutual_recursion() { })), }); } + +#[test] +fn test_extensions() { + check_roundtrip(extensions::FakeMsg::default()); + check_roundtrip(extensions::FakeMsg { + base_field: Some(39), + singular_primitive: Some(123), + singular_message: Some(ForeignMessage3 { c: 321 }), + repeated_primitive: vec![456, 789], + repeated_message: vec![ForeignMessage3 { c: 654 }, ForeignMessage3 { c: 987 }], + }); + + // Check that serializing a FakeMsg and deserializing into Msg preserves the extension fields, + // and that those fields can be read using `get_extension()`. 
+ fn check_roundtrip(orig: extensions::FakeMsg) { + let m = extensions::Msg::deserialize_from_slice(&orig.serialize_to_vec()).unwrap(); + assert_eq!(m.base_field, orig.base_field); + assert_eq!( + m.get_extension(extensions::SINGULAR_PRIMITIVE).unwrap(), + orig.singular_primitive + ); + assert_eq!( + m.get_extension(extensions::SINGULAR_MESSAGE).unwrap(), + orig.singular_message, + ); + assert_eq!( + m.get_extension(extensions::REPEATED_PRIMITIVE).unwrap(), + orig.repeated_primitive, + ); + assert_eq!( + m.get_extension(extensions::REPEATED_MESSAGE).unwrap(), + orig.repeated_message + ); + } +} diff --git a/pb-test/src/verify_generated_files.rs b/pb-test/src/verify_generated_files.rs index 893744e..39d45f9 100644 --- a/pb-test/src/verify_generated_files.rs +++ b/pb-test/src/verify_generated_files.rs @@ -23,7 +23,7 @@ fn verify_generated_files() { // Assert the correct number of pb-test generated files // Developers - please change this number if the change is intentional - assert_eq!(proto_files.len(), 15); + assert_eq!(proto_files.len(), 16); // Assert contents of the generated files for proto_file in proto_files {