From 4d8c43fd17dfae40f60db34aa345ee626e739da8 Mon Sep 17 00:00:00 2001 From: Geoffry Song Date: Wed, 18 Oct 2023 13:13:00 -0700 Subject: [PATCH] Implement box_it for oneof fields (#150) This fixes codegen for messages that have recursive fields inside of a oneof. --- pb-jelly-gen/codegen/codegen.py | 60 ++- .../proto_pbtest/src/pbtest2.rs.expected | 4 +- .../proto_pbtest/src/pbtest3.rs.expected | 445 ++++++++++++++++++ pb-test/proto/packages/pbtest/pbtest3.proto | 16 + pb-test/src/lib.rs | 1 - pb-test/src/pbtest.rs | 17 + 6 files changed, 507 insertions(+), 36 deletions(-) diff --git a/pb-jelly-gen/codegen/codegen.py b/pb-jelly-gen/codegen/codegen.py index 899ca0d..08275a3 100755 --- a/pb-jelly-gen/codegen/codegen.py +++ b/pb-jelly-gen/codegen/codegen.py @@ -1,12 +1,9 @@ #!/usr/bin/env python3 -import itertools import os import re import sys -import google.protobuf - from collections import defaultdict, namedtuple, OrderedDict from contextlib import contextmanager from typing import ( @@ -35,7 +32,6 @@ OneofDescriptorProto, SourceCodeInfo, ) -from google.protobuf.message import Message from proto.rust import extensions_pb2 @@ -280,6 +276,8 @@ def custom_type(self) -> Text: return self.field.options.Extensions[extensions_pb2.type] def is_nullable(self) -> bool: + if self.oneof: + return False if ( self.field.type in PRIMITIVE_TYPES and self.is_proto3 @@ -301,7 +299,7 @@ def is_nullable(self) -> bool: def is_empty_oneof_field(self) -> bool: assert self.oneof - return self.field.type_name == ".google.protobuf.Empty" + return self.field.type_name == ".google.protobuf.Empty" and not self.is_boxed() def can_be_packed(self) -> bool: # Return true if incoming messages could be packed on the wire @@ -367,7 +365,7 @@ def set_method(self) -> Tuple[Text, Text]: return self.rust_type(), "v" elif self.field.type == FieldDescriptorProto.TYPE_MESSAGE: if self.is_boxed(): - return "Box<%s>" % self.rust_type(), "v" + return "::std::boxed::Box<%s>" % self.rust_type(), "v" else: return self.rust_type(), "v" raise AssertionError("Unexpected field type") @@ -403,7 +401,7 @@ def take_method(self) -> Tuple[Optional[Text], Optional[Text]]: return self.rust_type(), expr elif self.field.type == FieldDescriptorProto.TYPE_MESSAGE: if self.is_boxed(): - return "Box<%s>" % self.rust_type(), expr + return "::std::boxed::Box<%s>" % self.rust_type(), expr else: return self.rust_type(), expr raise AssertionError("Unexpected field type") @@ -492,19 +490,18 @@ def rust_type(self) -> Text: "Unsupported type: {!r}".format(FieldDescriptorProto.Type.Name(typ)) ) - def __str__(self) -> str: + def storage_type(self) -> str: rust_type = self.rust_type() + if self.is_boxed(): + rust_type = "::std::boxed::Box<%s>" % rust_type + if self.is_repeated(): - return "::std::vec::Vec<%s>" % rust_type - elif self.is_nullable() and self.is_boxed(): - return "::std::option::Option<::std::boxed::Box<%s>>" % str(rust_type) - elif self.is_boxed(): - return "::std::boxed::Box<%s>" % rust_type + rust_type = "::std::vec::Vec<%s>" % rust_type elif self.is_nullable(): - return "::std::option::Option<%s>" % rust_type - else: - return str(rust_type) + rust_type = "::std::option::Option<%s>" % rust_type + + return rust_type def oneof_field_match(self, var: Text) -> Text: if self.is_empty_oneof_field(): @@ -584,6 +581,11 @@ def field_iter( "let %s: &%s = &::std::default::Default::default();" % (var, typ.rust_type()) ) + elif typ.is_boxed(): + ctx.write( + "let %(var)s: &%(typ)s = &**%(var)s;" + % dict(var=var, typ=typ.rust_type()) + ) yield elif ( field.type == FieldDescriptorProto.TYPE_MESSAGE @@ -612,7 +614,7 @@ def field_iter( with block( ctx, "if self.%s != <%s as ::std::default::Default>::default()" - % (field.name, typ), + % (field.name, typ.storage_type()), ): if typ.is_boxed(): ctx.write("let %s = &*self.%s;" % (var, field.name)) @@ -937,7 +939,7 @@ def gen_msg( if typ.oneof: oneof_fields[typ.oneof.name].append(field) else: - self.write("pub %s: %s," % (field.name, typ)) + self.write("pub %s: %s," % (field.name, typ.storage_type())) for oneof in oneof_decls: if oneof_nullable(oneof): @@ -959,7 +961,7 @@ def gen_msg( with block(self, "pub enum " + oneof_msg_name(name, oneof)): for oneof_field in oneof_fields[oneof.name]: typ = self.rust_type(msg_type, oneof_field) - self.write("%s," % typ.oneof_field_match(typ.rust_type())) + self.write("%s," % typ.oneof_field_match(typ.storage_type())) if not self.is_proto3: with block(self, "impl " + name): @@ -1461,6 +1463,8 @@ def gen_msg( typ.oneof.name, ), ): + if typ.is_boxed(): + self.write("let val = &mut **val;") self.write( "return ::pb_jelly::reflection::FieldMut::Value(val);" ) @@ -1683,6 +1687,9 @@ def calc_impls( msg_impls_eq = False if not self.impls_by_msg[field_fq_msg].Copy: msg_impls_copy = False + + if rust_type.is_boxed(): + msg_impls_copy = False else: raise RuntimeError( "Unsupported type: {!r}".format(FieldDescriptorProto.Type.Name(typ)) @@ -1744,20 +1751,7 @@ def _set_boxed_if_recursive( visited, looking_for, self.find_msg(field.type_name) ) if need_box or field.type_name == looking_for: - # We only box normal fields, not oneof variants - # - # TODO: We are restricting this case because the codegen - # can't currently box oneof variants. This means there are - # cases won't work with the Rust codegen. Specifically, if - # you have a oneof variant that directly references the - # containing message or is co-recursive to another message, - # the codegen won't box the variant and the resulting code - # won't compile. - if not ( - field.HasField("oneof_index") - and pt.typ.oneof_decl[field.oneof_index] - ): - field.options.Extensions[extensions_pb2.box_it] = True + field.options.Extensions[extensions_pb2.box_it] = True any_field_boxed = True return any_field_boxed diff --git a/pb-test/gen/pb-jelly/proto_pbtest/src/pbtest2.rs.expected b/pb-test/gen/pb-jelly/proto_pbtest/src/pbtest2.rs.expected index 6320bb6..603713d 100644 --- a/pb-test/gen/pb-jelly/proto_pbtest/src/pbtest2.rs.expected +++ b/pb-test/gen/pb-jelly/proto_pbtest/src/pbtest2.rs.expected @@ -3085,10 +3085,10 @@ impl TestMessage { pub fn has_optional_foreign_message_boxed(&self) -> bool { self.optional_foreign_message_boxed.is_some() } - pub fn set_optional_foreign_message_boxed(&mut self, v: Box) { + pub fn set_optional_foreign_message_boxed(&mut self, v: ::std::boxed::Box) { self.optional_foreign_message_boxed = Some(v); } - pub fn take_optional_foreign_message_boxed(&mut self) -> Box { + pub fn take_optional_foreign_message_boxed(&mut self) -> ::std::boxed::Box { self.optional_foreign_message_boxed.take().unwrap_or_default() } pub fn get_optional_foreign_message_boxed(&self) -> &ForeignMessage { diff --git a/pb-test/gen/pb-jelly/proto_pbtest/src/pbtest3.rs.expected b/pb-test/gen/pb-jelly/proto_pbtest/src/pbtest3.rs.expected index 5ca70e5..571b2fe 100644 --- a/pb-test/gen/pb-jelly/proto_pbtest/src/pbtest3.rs.expected +++ b/pb-test/gen/pb-jelly/proto_pbtest/src/pbtest3.rs.expected @@ -5817,6 +5817,110 @@ impl ::pb_jelly::Reflection for TestMessage3_NestedMessage_Dir { } } +#[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)] +pub struct TestBoxedNonnullable { + pub field: ::std::boxed::Box, +} +impl ::std::default::Default for TestBoxedNonnullable { + fn default() -> Self { + TestBoxedNonnullable { + field: ::std::default::Default::default(), + } + } +} +lazy_static! { + pub static ref TestBoxedNonnullable_default: TestBoxedNonnullable = TestBoxedNonnullable::default(); +} +impl ::pb_jelly::Message for TestBoxedNonnullable { + fn descriptor(&self) -> ::std::option::Option<::pb_jelly::MessageDescriptor> { + Some(::pb_jelly::MessageDescriptor { + name: "TestBoxedNonnullable", + full_name: "pbtest.TestBoxedNonnullable", + fields: &[ + ::pb_jelly::FieldDescriptor { + name: "field", + full_name: "pbtest.TestBoxedNonnullable.field", + index: 0, + number: 1, + typ: ::pb_jelly::wire_format::Type::LengthDelimited, + label: ::pb_jelly::Label::Optional, + oneof_index: None, + }, + ], + oneofs: &[ + ], + }) + } + fn compute_size(&self) -> usize { + let mut size = 0; + let mut field_size = 0; + { + let val = &*self.field; + let l = ::pb_jelly::Message::compute_size(val); + field_size += ::pb_jelly::wire_format::serialized_length(1); + field_size += ::pb_jelly::varint::serialized_length(l as u64); + field_size += l; + } + size += field_size; + size + } + fn compute_grpc_slices_size(&self) -> usize { + let mut size = 0; + { + let val = &*self.field; + size += ::pb_jelly::Message::compute_grpc_slices_size(val); + } + size + } + fn serialize(&self, w: &mut W) -> ::std::io::Result<()> { + { + let val = &*self.field; + ::pb_jelly::wire_format::write(1, ::pb_jelly::wire_format::Type::LengthDelimited, w)?; + let l = ::pb_jelly::Message::compute_size(val); + ::pb_jelly::varint::write(l as u64, w)?; + ::pb_jelly::Message::serialize(val, w)?; + } + Ok(()) + } + fn deserialize(&mut self, mut buf: &mut B) -> ::std::io::Result<()> { + while let Some((field_number, typ)) = ::pb_jelly::wire_format::read(&mut buf)? { + match field_number { + 1 => { + ::pb_jelly::ensure_wire_format(typ, ::pb_jelly::wire_format::Type::LengthDelimited, "TestBoxedNonnullable", 1)?; + let len = ::pb_jelly::varint::ensure_read(&mut buf)?; + let mut next = ::pb_jelly::ensure_split(buf, len as usize)?; + let mut val: ForeignMessage3 = ::std::default::Default::default(); + ::pb_jelly::Message::deserialize(&mut val, &mut next)?; + self.field = Box::new(val); + } + _ => { + ::pb_jelly::skip(typ, &mut buf)?; + } + } + } + Ok(()) + } +} +impl ::pb_jelly::Reflection for TestBoxedNonnullable { + fn which_one_of(&self, oneof_name: &str) -> ::std::option::Option<&'static str> { + match oneof_name { + _ => { + panic!("unknown oneof name given"); + } + } + } + fn get_field_mut(&mut self, field_name: &str) -> ::pb_jelly::reflection::FieldMut<'_> { + match field_name { + "field" => { + ::pb_jelly::reflection::FieldMut::Value(self.field.as_mut()) + } + _ => { + panic!("unknown field name given") + } + } + } +} + #[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)] pub struct TestMessage3NonNullableOneof { pub other_field: u64, @@ -8783,3 +8887,344 @@ impl ::pb_jelly::Reflection for TestProto3Optional { } } +#[derive(Clone, Debug, PartialEq)] +pub struct RecursiveOneof { + /// This field should be boxed automatically. + /// Boxing should override the empty-oneof-field special case. + pub oneof_field: ::std::option::Option, +} +#[derive(Clone, Debug, PartialEq)] +pub enum RecursiveOneof_OneofField { + Field(::std::boxed::Box), + Empty, + BoxedEmpty(::std::boxed::Box<::proto_google::empty::Empty>), + NotBoxed(ForeignMessage3), + Boxed(::std::boxed::Box), +} +impl ::std::default::Default for RecursiveOneof { + fn default() -> Self { + RecursiveOneof { + oneof_field: None, + } + } +} +lazy_static! { + pub static ref RecursiveOneof_default: RecursiveOneof = RecursiveOneof::default(); +} +impl ::pb_jelly::Message for RecursiveOneof { + fn descriptor(&self) -> ::std::option::Option<::pb_jelly::MessageDescriptor> { + Some(::pb_jelly::MessageDescriptor { + name: "RecursiveOneof", + full_name: "pbtest.RecursiveOneof", + fields: &[ + ::pb_jelly::FieldDescriptor { + name: "field", + full_name: "pbtest.RecursiveOneof.field", + index: 0, + number: 1, + typ: ::pb_jelly::wire_format::Type::LengthDelimited, + label: ::pb_jelly::Label::Optional, + oneof_index: Some(0), + }, + ::pb_jelly::FieldDescriptor { + name: "empty", + full_name: "pbtest.RecursiveOneof.empty", + index: 1, + number: 2, + typ: ::pb_jelly::wire_format::Type::LengthDelimited, + label: ::pb_jelly::Label::Optional, + oneof_index: Some(0), + }, + ::pb_jelly::FieldDescriptor { + name: "boxed_empty", + full_name: "pbtest.RecursiveOneof.boxed_empty", + index: 2, + number: 3, + typ: ::pb_jelly::wire_format::Type::LengthDelimited, + label: ::pb_jelly::Label::Optional, + oneof_index: Some(0), + }, + ::pb_jelly::FieldDescriptor { + name: "not_boxed", + full_name: "pbtest.RecursiveOneof.not_boxed", + index: 3, + number: 4, + typ: ::pb_jelly::wire_format::Type::LengthDelimited, + label: ::pb_jelly::Label::Optional, + oneof_index: Some(0), + }, + ::pb_jelly::FieldDescriptor { + name: "boxed", + full_name: "pbtest.RecursiveOneof.boxed", + index: 4, + number: 5, + typ: ::pb_jelly::wire_format::Type::LengthDelimited, + label: ::pb_jelly::Label::Optional, + oneof_index: Some(0), + }, + ], + oneofs: &[ + ::pb_jelly::OneofDescriptor { + name: "oneof_field", + }, + ], + }) + } + fn compute_size(&self) -> usize { + let mut size = 0; + let mut field_size = 0; + if let Some(RecursiveOneof_OneofField::Field(ref val)) = self.oneof_field { + let val: &RecursiveOneof = &**val; + let l = ::pb_jelly::Message::compute_size(val); + field_size += ::pb_jelly::wire_format::serialized_length(1); + field_size += ::pb_jelly::varint::serialized_length(l as u64); + field_size += l; + } + size += field_size; + let mut empty_size = 0; + if let Some(RecursiveOneof_OneofField::Empty) = self.oneof_field { + let val: &::proto_google::empty::Empty = &::std::default::Default::default(); + let l = ::pb_jelly::Message::compute_size(val); + empty_size += ::pb_jelly::wire_format::serialized_length(2); + empty_size += ::pb_jelly::varint::serialized_length(l as u64); + empty_size += l; + } + size += empty_size; + let mut boxed_empty_size = 0; + if let Some(RecursiveOneof_OneofField::BoxedEmpty(ref val)) = self.oneof_field { + let val: &::proto_google::empty::Empty = &**val; + let l = ::pb_jelly::Message::compute_size(val); + boxed_empty_size += ::pb_jelly::wire_format::serialized_length(3); + boxed_empty_size += ::pb_jelly::varint::serialized_length(l as u64); + boxed_empty_size += l; + } + size += boxed_empty_size; + let mut not_boxed_size = 0; + if let Some(RecursiveOneof_OneofField::NotBoxed(ref val)) = self.oneof_field { + let l = ::pb_jelly::Message::compute_size(val); + not_boxed_size += ::pb_jelly::wire_format::serialized_length(4); + not_boxed_size += ::pb_jelly::varint::serialized_length(l as u64); + not_boxed_size += l; + } + size += not_boxed_size; + let mut boxed_size = 0; + if let Some(RecursiveOneof_OneofField::Boxed(ref val)) = self.oneof_field { + let val: &ForeignMessage3 = &**val; + let l = ::pb_jelly::Message::compute_size(val); + boxed_size += ::pb_jelly::wire_format::serialized_length(5); + boxed_size += ::pb_jelly::varint::serialized_length(l as u64); + boxed_size += l; + } + size += boxed_size; + size + } + fn compute_grpc_slices_size(&self) -> usize { + let mut size = 0; + if let Some(RecursiveOneof_OneofField::Field(ref val)) = self.oneof_field { + let val: &RecursiveOneof = &**val; + size += ::pb_jelly::Message::compute_grpc_slices_size(val); + } + if let Some(RecursiveOneof_OneofField::Empty) = self.oneof_field { + let val: &::proto_google::empty::Empty = &::std::default::Default::default(); + size += ::pb_jelly::Message::compute_grpc_slices_size(val); + } + if let Some(RecursiveOneof_OneofField::BoxedEmpty(ref val)) = self.oneof_field { + let val: &::proto_google::empty::Empty = &**val; + size += ::pb_jelly::Message::compute_grpc_slices_size(val); + } + if let Some(RecursiveOneof_OneofField::NotBoxed(ref val)) = self.oneof_field { + size += ::pb_jelly::Message::compute_grpc_slices_size(val); + } + if let Some(RecursiveOneof_OneofField::Boxed(ref val)) = self.oneof_field { + let val: &ForeignMessage3 = &**val; + size += ::pb_jelly::Message::compute_grpc_slices_size(val); + } + size + } + fn serialize(&self, w: &mut W) -> ::std::io::Result<()> { + if let Some(RecursiveOneof_OneofField::Field(ref val)) = self.oneof_field { + let val: &RecursiveOneof = &**val; + ::pb_jelly::wire_format::write(1, ::pb_jelly::wire_format::Type::LengthDelimited, w)?; + let l = ::pb_jelly::Message::compute_size(val); + ::pb_jelly::varint::write(l as u64, w)?; + ::pb_jelly::Message::serialize(val, w)?; + } + if let Some(RecursiveOneof_OneofField::Empty) = self.oneof_field { + let val: &::proto_google::empty::Empty = &::std::default::Default::default(); + ::pb_jelly::wire_format::write(2, ::pb_jelly::wire_format::Type::LengthDelimited, w)?; + let l = ::pb_jelly::Message::compute_size(val); + ::pb_jelly::varint::write(l as u64, w)?; + ::pb_jelly::Message::serialize(val, w)?; + } + if let Some(RecursiveOneof_OneofField::BoxedEmpty(ref val)) = self.oneof_field { + let val: &::proto_google::empty::Empty = &**val; + ::pb_jelly::wire_format::write(3, ::pb_jelly::wire_format::Type::LengthDelimited, w)?; + let l = ::pb_jelly::Message::compute_size(val); + ::pb_jelly::varint::write(l as u64, w)?; + ::pb_jelly::Message::serialize(val, w)?; + } + if let Some(RecursiveOneof_OneofField::NotBoxed(ref val)) = self.oneof_field { + ::pb_jelly::wire_format::write(4, ::pb_jelly::wire_format::Type::LengthDelimited, w)?; + let l = ::pb_jelly::Message::compute_size(val); + ::pb_jelly::varint::write(l as u64, w)?; + ::pb_jelly::Message::serialize(val, w)?; + } + if let Some(RecursiveOneof_OneofField::Boxed(ref val)) = self.oneof_field { + let val: &ForeignMessage3 = &**val; + ::pb_jelly::wire_format::write(5, ::pb_jelly::wire_format::Type::LengthDelimited, w)?; + let l = ::pb_jelly::Message::compute_size(val); + ::pb_jelly::varint::write(l as u64, w)?; + ::pb_jelly::Message::serialize(val, w)?; + } + Ok(()) + } + fn deserialize(&mut self, mut buf: &mut B) -> ::std::io::Result<()> { + while let Some((field_number, typ)) = ::pb_jelly::wire_format::read(&mut buf)? { + match field_number { + 1 => { + ::pb_jelly::ensure_wire_format(typ, ::pb_jelly::wire_format::Type::LengthDelimited, "RecursiveOneof", 1)?; + let len = ::pb_jelly::varint::ensure_read(&mut buf)?; + let mut next = ::pb_jelly::ensure_split(buf, len as usize)?; + let mut val: RecursiveOneof = ::std::default::Default::default(); + ::pb_jelly::Message::deserialize(&mut val, &mut next)?; + self.oneof_field = Some(RecursiveOneof_OneofField::Field(Box::new(val))); + } + 2 => { + ::pb_jelly::ensure_wire_format(typ, ::pb_jelly::wire_format::Type::LengthDelimited, "RecursiveOneof", 2)?; + let len = ::pb_jelly::varint::ensure_read(&mut buf)?; + let mut next = ::pb_jelly::ensure_split(buf, len as usize)?; + let mut val: ::proto_google::empty::Empty = ::std::default::Default::default(); + ::pb_jelly::Message::deserialize(&mut val, &mut next)?; + self.oneof_field = Some(RecursiveOneof_OneofField::Empty); + } + 3 => { + ::pb_jelly::ensure_wire_format(typ, ::pb_jelly::wire_format::Type::LengthDelimited, "RecursiveOneof", 3)?; + let len = ::pb_jelly::varint::ensure_read(&mut buf)?; + let mut next = ::pb_jelly::ensure_split(buf, len as usize)?; + let mut val: ::proto_google::empty::Empty = ::std::default::Default::default(); + ::pb_jelly::Message::deserialize(&mut val, &mut next)?; + self.oneof_field = Some(RecursiveOneof_OneofField::BoxedEmpty(Box::new(val))); + } + 4 => { + ::pb_jelly::ensure_wire_format(typ, ::pb_jelly::wire_format::Type::LengthDelimited, "RecursiveOneof", 4)?; + let len = ::pb_jelly::varint::ensure_read(&mut buf)?; + let mut next = ::pb_jelly::ensure_split(buf, len as usize)?; + let mut val: ForeignMessage3 = ::std::default::Default::default(); + ::pb_jelly::Message::deserialize(&mut val, &mut next)?; + self.oneof_field = Some(RecursiveOneof_OneofField::NotBoxed(val)); + } + 5 => { + ::pb_jelly::ensure_wire_format(typ, ::pb_jelly::wire_format::Type::LengthDelimited, "RecursiveOneof", 5)?; + let len = ::pb_jelly::varint::ensure_read(&mut buf)?; + let mut next = ::pb_jelly::ensure_split(buf, len as usize)?; + let mut val: ForeignMessage3 = ::std::default::Default::default(); + ::pb_jelly::Message::deserialize(&mut val, &mut next)?; + self.oneof_field = Some(RecursiveOneof_OneofField::Boxed(Box::new(val))); + } + _ => { + ::pb_jelly::skip(typ, &mut buf)?; + } + } + } + Ok(()) + } +} +impl ::pb_jelly::Reflection for RecursiveOneof { + fn which_one_of(&self, oneof_name: &str) -> ::std::option::Option<&'static str> { + match oneof_name { + "oneof_field" => { + if let Some(RecursiveOneof_OneofField::Field(ref val)) = self.oneof_field { + let val: &RecursiveOneof = &**val; + return Some("field"); + } + if let Some(RecursiveOneof_OneofField::Empty) = self.oneof_field { + let val: &::proto_google::empty::Empty = &::std::default::Default::default(); + return Some("empty"); + } + if let Some(RecursiveOneof_OneofField::BoxedEmpty(ref val)) = self.oneof_field { + let val: &::proto_google::empty::Empty = &**val; + return Some("boxed_empty"); + } + if let Some(RecursiveOneof_OneofField::NotBoxed(ref val)) = self.oneof_field { + return Some("not_boxed"); + } + if let Some(RecursiveOneof_OneofField::Boxed(ref val)) = self.oneof_field { + let val: &ForeignMessage3 = &**val; + return Some("boxed"); + } + None + } + _ => { + panic!("unknown oneof name given"); + } + } + } + fn get_field_mut(&mut self, field_name: &str) -> ::pb_jelly::reflection::FieldMut<'_> { + match field_name { + "field" => { + match self.oneof_field { + Some(RecursiveOneof_OneofField::Field(_)) => (), + _ => { + self.oneof_field = Some(RecursiveOneof_OneofField::Field(::std::default::Default::default())); + }, + } + if let Some(RecursiveOneof_OneofField::Field(ref mut val)) = self.oneof_field { + let val = &mut **val; + return ::pb_jelly::reflection::FieldMut::Value(val); + } + unreachable!() + } + "empty" => { + match self.oneof_field { + Some(RecursiveOneof_OneofField::Empty) => (), + _ => { + self.oneof_field = Some(RecursiveOneof_OneofField::Empty); + }, + } + ::pb_jelly::reflection::FieldMut::Empty + } + "boxed_empty" => { + match self.oneof_field { + Some(RecursiveOneof_OneofField::BoxedEmpty(_)) => (), + _ => { + self.oneof_field = Some(RecursiveOneof_OneofField::BoxedEmpty(::std::default::Default::default())); + }, + } + if let Some(RecursiveOneof_OneofField::BoxedEmpty(ref mut val)) = self.oneof_field { + let val = &mut **val; + return ::pb_jelly::reflection::FieldMut::Value(val); + } + unreachable!() + } + "not_boxed" => { + match self.oneof_field { + Some(RecursiveOneof_OneofField::NotBoxed(_)) => (), + _ => { + self.oneof_field = Some(RecursiveOneof_OneofField::NotBoxed(::std::default::Default::default())); + }, + } + if let Some(RecursiveOneof_OneofField::NotBoxed(ref mut val)) = self.oneof_field { + return ::pb_jelly::reflection::FieldMut::Value(val); + } + unreachable!() + } + "boxed" => { + match self.oneof_field { + Some(RecursiveOneof_OneofField::Boxed(_)) => (), + _ => { + self.oneof_field = Some(RecursiveOneof_OneofField::Boxed(::std::default::Default::default())); + }, + } + if let Some(RecursiveOneof_OneofField::Boxed(ref mut val)) = self.oneof_field { + let val = &mut **val; + return ::pb_jelly::reflection::FieldMut::Value(val); + } + unreachable!() + } + _ => { + panic!("unknown field name given") + } + } + } +} + diff --git a/pb-test/proto/packages/pbtest/pbtest3.proto b/pb-test/proto/packages/pbtest/pbtest3.proto index 3d292a8..a05b619 100644 --- a/pb-test/proto/packages/pbtest/pbtest3.proto +++ b/pb-test/proto/packages/pbtest/pbtest3.proto @@ -208,6 +208,10 @@ message TestMessage3 { repeated bytes zero_or_fixed_length_repeated = 79 [(rust.type)="Option<[u8; 4]>"]; } +message TestBoxedNonnullable { + ForeignMessage3 field = 1 [(rust.box_it)=true, (rust.nullable_field)=false]; +} + message TestMessage3NonNullableOneof { oneof non_nullable_oneof { option (rust.nullable) = false; @@ -347,3 +351,15 @@ message TestProto3Optional { string real_oneof_2_2 = 18; } } + +message RecursiveOneof { + oneof oneof_field { + // This field should be boxed automatically. + RecursiveOneof field = 1; + google.protobuf.Empty empty = 2; + // Boxing should override the empty-oneof-field special case. + google.protobuf.Empty boxed_empty = 3 [(rust.box_it) = true]; + ForeignMessage3 not_boxed = 4; + ForeignMessage3 boxed = 5 [(rust.box_it) = true]; + } +} diff --git a/pb-test/src/lib.rs b/pb-test/src/lib.rs index 32bd324..e215181 100644 --- a/pb-test/src/lib.rs +++ b/pb-test/src/lib.rs @@ -1,5 +1,4 @@ #![warn(rust_2018_idioms)] -#![feature(bench_black_box)] #![feature(test)] #[allow(unused_extern_crates)] diff --git a/pb-test/src/pbtest.rs b/pb-test/src/pbtest.rs index e5c8aea..5ce8b26 100644 --- a/pb-test/src/pbtest.rs +++ b/pb-test/src/pbtest.rs @@ -871,3 +871,20 @@ fn test_proto3_optional() { .unwrap(); assert_eq!(proto.a_int32, Some(456)); } + +// Test that boxing works properly for oneof fields. +#[test] +fn test_recursive_oneof() { + let message = RecursiveOneof { + oneof_field: Some(RecursiveOneof_OneofField::BoxedEmpty(std::default::Default::default())), + }; + check_roundtrip(&message); + let message = RecursiveOneof { + oneof_field: Some(RecursiveOneof_OneofField::Field(Box::new(message))), + }; + check_roundtrip(&message); + let message = RecursiveOneof { + oneof_field: Some(RecursiveOneof_OneofField::Field(Box::new(message))), + }; + check_roundtrip(&message); +}