diff --git a/pb-jelly-gen/codegen/codegen.py b/pb-jelly-gen/codegen/codegen.py index 899ca0d..762c1d7 100755 --- a/pb-jelly-gen/codegen/codegen.py +++ b/pb-jelly-gen/codegen/codegen.py @@ -1,12 +1,9 @@ #!/usr/bin/env python3 -import itertools import os import re import sys -import google.protobuf - from collections import defaultdict, namedtuple, OrderedDict from contextlib import contextmanager from typing import ( @@ -35,7 +32,6 @@ OneofDescriptorProto, SourceCodeInfo, ) -from google.protobuf.message import Message from proto.rust import extensions_pb2 @@ -301,7 +297,7 @@ def is_nullable(self) -> bool: def is_empty_oneof_field(self) -> bool: assert self.oneof - return self.field.type_name == ".google.protobuf.Empty" + return self.field.type_name == ".google.protobuf.Empty" and not self.is_boxed() def can_be_packed(self) -> bool: # Return true if incoming messages could be packed on the wire @@ -366,10 +362,7 @@ def set_method(self) -> Tuple[Text, Text]: elif self.field.type == FieldDescriptorProto.TYPE_ENUM: return self.rust_type(), "v" elif self.field.type == FieldDescriptorProto.TYPE_MESSAGE: - if self.is_boxed(): - return "Box<%s>" % self.rust_type(), "v" - else: - return self.rust_type(), "v" + return self.rust_type(maybe_boxed=True), "v" raise AssertionError("Unexpected field type") def take_method(self) -> Tuple[Optional[Text], Optional[Text]]: @@ -402,10 +395,7 @@ def take_method(self) -> Tuple[Optional[Text], Optional[Text]]: elif self.field.type == FieldDescriptorProto.TYPE_ENUM: return self.rust_type(), expr elif self.field.type == FieldDescriptorProto.TYPE_MESSAGE: - if self.is_boxed(): - return "Box<%s>" % self.rust_type(), expr - else: - return self.rust_type(), expr + return self.rust_type(maybe_boxed=True), expr raise AssertionError("Unexpected field type") def get_method(self) -> Tuple[Text, Text]: @@ -456,7 +446,10 @@ def get_method(self) -> Tuple[Text, Text]: ) raise AssertionError("Unexpected field type") - def rust_type(self) -> Text: + def rust_type(self, maybe_boxed: bool = False) -> Text: + if maybe_boxed and self.is_boxed(): + return "::std::boxed::Box<%s>" % self.rust_type(maybe_boxed=False) + typ = self.field.type if self.has_custom_type(): @@ -493,18 +486,14 @@ def rust_type(self) -> Text: ) def __str__(self) -> str: - rust_type = self.rust_type() + rust_type = self.rust_type(maybe_boxed=True) if self.is_repeated(): return "::std::vec::Vec<%s>" % rust_type - elif self.is_nullable() and self.is_boxed(): - return "::std::option::Option<::std::boxed::Box<%s>>" % str(rust_type) - elif self.is_boxed(): - return "::std::boxed::Box<%s>" % rust_type elif self.is_nullable(): return "::std::option::Option<%s>" % rust_type else: - return str(rust_type) + return rust_type def oneof_field_match(self, var: Text) -> Text: if self.is_empty_oneof_field(): @@ -584,6 +573,11 @@ def field_iter( "let %s: &%s = &::std::default::Default::default();" % (var, typ.rust_type()) ) + elif typ.is_boxed(): + ctx.write( + "let %(var)s: &%(typ)s = &**%(var)s;" + % dict(var=var, typ=typ.rust_type()) + ) yield elif ( field.type == FieldDescriptorProto.TYPE_MESSAGE @@ -959,7 +953,9 @@ def gen_msg( with block(self, "pub enum " + oneof_msg_name(name, oneof)): for oneof_field in oneof_fields[oneof.name]: typ = self.rust_type(msg_type, oneof_field) - self.write("%s," % typ.oneof_field_match(typ.rust_type())) + self.write( + "%s," % typ.oneof_field_match(typ.rust_type(maybe_boxed=True)) + ) if not self.is_proto3: with block(self, "impl " + name): @@ -1461,6 +1457,8 @@ def gen_msg( typ.oneof.name, ), ): + if typ.is_boxed(): + self.write("let val = &mut **val;") self.write( "return ::pb_jelly::reflection::FieldMut::Value(val);" ) @@ -1744,20 +1742,7 @@ def _set_boxed_if_recursive( visited, looking_for, self.find_msg(field.type_name) ) if need_box or field.type_name == looking_for: - # We only box normal fields, not oneof variants - # - # TODO: We are restricting this case because the codegen - # can't currently box oneof variants. This means there are - # cases won't work with the Rust codegen. Specifically, if - # you have a oneof variant that directly references the - # containing message or is co-recursive to another message, - # the codegen won't box the variant and the resulting code - # won't compile. - if not ( - field.HasField("oneof_index") - and pt.typ.oneof_decl[field.oneof_index] - ): - field.options.Extensions[extensions_pb2.box_it] = True + field.options.Extensions[extensions_pb2.box_it] = True any_field_boxed = True return any_field_boxed diff --git a/pb-test/gen/pb-jelly/proto_pbtest/src/pbtest2.rs.expected b/pb-test/gen/pb-jelly/proto_pbtest/src/pbtest2.rs.expected index 6320bb6..603713d 100644 --- a/pb-test/gen/pb-jelly/proto_pbtest/src/pbtest2.rs.expected +++ b/pb-test/gen/pb-jelly/proto_pbtest/src/pbtest2.rs.expected @@ -3085,10 +3085,10 @@ impl TestMessage { pub fn has_optional_foreign_message_boxed(&self) -> bool { self.optional_foreign_message_boxed.is_some() } - pub fn set_optional_foreign_message_boxed(&mut self, v: Box) { + pub fn set_optional_foreign_message_boxed(&mut self, v: ::std::boxed::Box) { self.optional_foreign_message_boxed = Some(v); } - pub fn take_optional_foreign_message_boxed(&mut self) -> Box { + pub fn take_optional_foreign_message_boxed(&mut self) -> ::std::boxed::Box { self.optional_foreign_message_boxed.take().unwrap_or_default() } pub fn get_optional_foreign_message_boxed(&self) -> &ForeignMessage { diff --git a/pb-test/gen/pb-jelly/proto_pbtest/src/pbtest3.rs.expected b/pb-test/gen/pb-jelly/proto_pbtest/src/pbtest3.rs.expected index 5ca70e5..54ba73f 100644 --- a/pb-test/gen/pb-jelly/proto_pbtest/src/pbtest3.rs.expected +++ b/pb-test/gen/pb-jelly/proto_pbtest/src/pbtest3.rs.expected @@ -8783,3 +8783,344 @@ impl ::pb_jelly::Reflection for TestProto3Optional { } } +#[derive(Clone, Debug, PartialEq)] +pub struct RecursiveOneof { + /// This field should be boxed automatically. + /// Boxing should override the empty-oneof-field special case. + pub oneof_field: ::std::option::Option, +} +#[derive(Clone, Debug, PartialEq)] +pub enum RecursiveOneof_OneofField { + Field(::std::boxed::Box), + Empty, + BoxedEmpty(::std::boxed::Box<::proto_google::empty::Empty>), + NotBoxed(ForeignMessage3), + Boxed(::std::boxed::Box), +} +impl ::std::default::Default for RecursiveOneof { + fn default() -> Self { + RecursiveOneof { + oneof_field: None, + } + } +} +lazy_static! { + pub static ref RecursiveOneof_default: RecursiveOneof = RecursiveOneof::default(); +} +impl ::pb_jelly::Message for RecursiveOneof { + fn descriptor(&self) -> ::std::option::Option<::pb_jelly::MessageDescriptor> { + Some(::pb_jelly::MessageDescriptor { + name: "RecursiveOneof", + full_name: "pbtest.RecursiveOneof", + fields: &[ + ::pb_jelly::FieldDescriptor { + name: "field", + full_name: "pbtest.RecursiveOneof.field", + index: 0, + number: 1, + typ: ::pb_jelly::wire_format::Type::LengthDelimited, + label: ::pb_jelly::Label::Optional, + oneof_index: Some(0), + }, + ::pb_jelly::FieldDescriptor { + name: "empty", + full_name: "pbtest.RecursiveOneof.empty", + index: 1, + number: 2, + typ: ::pb_jelly::wire_format::Type::LengthDelimited, + label: ::pb_jelly::Label::Optional, + oneof_index: Some(0), + }, + ::pb_jelly::FieldDescriptor { + name: "boxed_empty", + full_name: "pbtest.RecursiveOneof.boxed_empty", + index: 2, + number: 3, + typ: ::pb_jelly::wire_format::Type::LengthDelimited, + label: ::pb_jelly::Label::Optional, + oneof_index: Some(0), + }, + ::pb_jelly::FieldDescriptor { + name: "not_boxed", + full_name: "pbtest.RecursiveOneof.not_boxed", + index: 3, + number: 4, + typ: ::pb_jelly::wire_format::Type::LengthDelimited, + label: ::pb_jelly::Label::Optional, + oneof_index: Some(0), + }, + ::pb_jelly::FieldDescriptor { + name: "boxed", + full_name: "pbtest.RecursiveOneof.boxed", + index: 4, + number: 5, + typ: ::pb_jelly::wire_format::Type::LengthDelimited, + label: ::pb_jelly::Label::Optional, + oneof_index: Some(0), + }, + ], + oneofs: &[ + ::pb_jelly::OneofDescriptor { + name: "oneof_field", + }, + ], + }) + } + fn compute_size(&self) -> usize { + let mut size = 0; + let mut field_size = 0; + if let Some(RecursiveOneof_OneofField::Field(ref val)) = self.oneof_field { + let val: &RecursiveOneof = &**val; + let l = ::pb_jelly::Message::compute_size(val); + field_size += ::pb_jelly::wire_format::serialized_length(1); + field_size += ::pb_jelly::varint::serialized_length(l as u64); + field_size += l; + } + size += field_size; + let mut empty_size = 0; + if let Some(RecursiveOneof_OneofField::Empty) = self.oneof_field { + let val: &::proto_google::empty::Empty = &::std::default::Default::default(); + let l = ::pb_jelly::Message::compute_size(val); + empty_size += ::pb_jelly::wire_format::serialized_length(2); + empty_size += ::pb_jelly::varint::serialized_length(l as u64); + empty_size += l; + } + size += empty_size; + let mut boxed_empty_size = 0; + if let Some(RecursiveOneof_OneofField::BoxedEmpty(ref val)) = self.oneof_field { + let val: &::proto_google::empty::Empty = &**val; + let l = ::pb_jelly::Message::compute_size(val); + boxed_empty_size += ::pb_jelly::wire_format::serialized_length(3); + boxed_empty_size += ::pb_jelly::varint::serialized_length(l as u64); + boxed_empty_size += l; + } + size += boxed_empty_size; + let mut not_boxed_size = 0; + if let Some(RecursiveOneof_OneofField::NotBoxed(ref val)) = self.oneof_field { + let l = ::pb_jelly::Message::compute_size(val); + not_boxed_size += ::pb_jelly::wire_format::serialized_length(4); + not_boxed_size += ::pb_jelly::varint::serialized_length(l as u64); + not_boxed_size += l; + } + size += not_boxed_size; + let mut boxed_size = 0; + if let Some(RecursiveOneof_OneofField::Boxed(ref val)) = self.oneof_field { + let val: &ForeignMessage3 = &**val; + let l = ::pb_jelly::Message::compute_size(val); + boxed_size += ::pb_jelly::wire_format::serialized_length(5); + boxed_size += ::pb_jelly::varint::serialized_length(l as u64); + boxed_size += l; + } + size += boxed_size; + size + } + fn compute_grpc_slices_size(&self) -> usize { + let mut size = 0; + if let Some(RecursiveOneof_OneofField::Field(ref val)) = self.oneof_field { + let val: &RecursiveOneof = &**val; + size += ::pb_jelly::Message::compute_grpc_slices_size(val); + } + if let Some(RecursiveOneof_OneofField::Empty) = self.oneof_field { + let val: &::proto_google::empty::Empty = &::std::default::Default::default(); + size += ::pb_jelly::Message::compute_grpc_slices_size(val); + } + if let Some(RecursiveOneof_OneofField::BoxedEmpty(ref val)) = self.oneof_field { + let val: &::proto_google::empty::Empty = &**val; + size += ::pb_jelly::Message::compute_grpc_slices_size(val); + } + if let Some(RecursiveOneof_OneofField::NotBoxed(ref val)) = self.oneof_field { + size += ::pb_jelly::Message::compute_grpc_slices_size(val); + } + if let Some(RecursiveOneof_OneofField::Boxed(ref val)) = self.oneof_field { + let val: &ForeignMessage3 = &**val; + size += ::pb_jelly::Message::compute_grpc_slices_size(val); + } + size + } + fn serialize(&self, w: &mut W) -> ::std::io::Result<()> { + if let Some(RecursiveOneof_OneofField::Field(ref val)) = self.oneof_field { + let val: &RecursiveOneof = &**val; + ::pb_jelly::wire_format::write(1, ::pb_jelly::wire_format::Type::LengthDelimited, w)?; + let l = ::pb_jelly::Message::compute_size(val); + ::pb_jelly::varint::write(l as u64, w)?; + ::pb_jelly::Message::serialize(val, w)?; + } + if let Some(RecursiveOneof_OneofField::Empty) = self.oneof_field { + let val: &::proto_google::empty::Empty = &::std::default::Default::default(); + ::pb_jelly::wire_format::write(2, ::pb_jelly::wire_format::Type::LengthDelimited, w)?; + let l = ::pb_jelly::Message::compute_size(val); + ::pb_jelly::varint::write(l as u64, w)?; + ::pb_jelly::Message::serialize(val, w)?; + } + if let Some(RecursiveOneof_OneofField::BoxedEmpty(ref val)) = self.oneof_field { + let val: &::proto_google::empty::Empty = &**val; + ::pb_jelly::wire_format::write(3, ::pb_jelly::wire_format::Type::LengthDelimited, w)?; + let l = ::pb_jelly::Message::compute_size(val); + ::pb_jelly::varint::write(l as u64, w)?; + ::pb_jelly::Message::serialize(val, w)?; + } + if let Some(RecursiveOneof_OneofField::NotBoxed(ref val)) = self.oneof_field { + ::pb_jelly::wire_format::write(4, ::pb_jelly::wire_format::Type::LengthDelimited, w)?; + let l = ::pb_jelly::Message::compute_size(val); + ::pb_jelly::varint::write(l as u64, w)?; + ::pb_jelly::Message::serialize(val, w)?; + } + if let Some(RecursiveOneof_OneofField::Boxed(ref val)) = self.oneof_field { + let val: &ForeignMessage3 = &**val; + ::pb_jelly::wire_format::write(5, ::pb_jelly::wire_format::Type::LengthDelimited, w)?; + let l = ::pb_jelly::Message::compute_size(val); + ::pb_jelly::varint::write(l as u64, w)?; + ::pb_jelly::Message::serialize(val, w)?; + } + Ok(()) + } + fn deserialize(&mut self, mut buf: &mut B) -> ::std::io::Result<()> { + while let Some((field_number, typ)) = ::pb_jelly::wire_format::read(&mut buf)? { + match field_number { + 1 => { + ::pb_jelly::ensure_wire_format(typ, ::pb_jelly::wire_format::Type::LengthDelimited, "RecursiveOneof", 1)?; + let len = ::pb_jelly::varint::ensure_read(&mut buf)?; + let mut next = ::pb_jelly::ensure_split(buf, len as usize)?; + let mut val: RecursiveOneof = ::std::default::Default::default(); + ::pb_jelly::Message::deserialize(&mut val, &mut next)?; + self.oneof_field = Some(RecursiveOneof_OneofField::Field(Box::new(val))); + } + 2 => { + ::pb_jelly::ensure_wire_format(typ, ::pb_jelly::wire_format::Type::LengthDelimited, "RecursiveOneof", 2)?; + let len = ::pb_jelly::varint::ensure_read(&mut buf)?; + let mut next = ::pb_jelly::ensure_split(buf, len as usize)?; + let mut val: ::proto_google::empty::Empty = ::std::default::Default::default(); + ::pb_jelly::Message::deserialize(&mut val, &mut next)?; + self.oneof_field = Some(RecursiveOneof_OneofField::Empty); + } + 3 => { + ::pb_jelly::ensure_wire_format(typ, ::pb_jelly::wire_format::Type::LengthDelimited, "RecursiveOneof", 3)?; + let len = ::pb_jelly::varint::ensure_read(&mut buf)?; + let mut next = ::pb_jelly::ensure_split(buf, len as usize)?; + let mut val: ::proto_google::empty::Empty = ::std::default::Default::default(); + ::pb_jelly::Message::deserialize(&mut val, &mut next)?; + self.oneof_field = Some(RecursiveOneof_OneofField::BoxedEmpty(Box::new(val))); + } + 4 => { + ::pb_jelly::ensure_wire_format(typ, ::pb_jelly::wire_format::Type::LengthDelimited, "RecursiveOneof", 4)?; + let len = ::pb_jelly::varint::ensure_read(&mut buf)?; + let mut next = ::pb_jelly::ensure_split(buf, len as usize)?; + let mut val: ForeignMessage3 = ::std::default::Default::default(); + ::pb_jelly::Message::deserialize(&mut val, &mut next)?; + self.oneof_field = Some(RecursiveOneof_OneofField::NotBoxed(val)); + } + 5 => { + ::pb_jelly::ensure_wire_format(typ, ::pb_jelly::wire_format::Type::LengthDelimited, "RecursiveOneof", 5)?; + let len = ::pb_jelly::varint::ensure_read(&mut buf)?; + let mut next = ::pb_jelly::ensure_split(buf, len as usize)?; + let mut val: ForeignMessage3 = ::std::default::Default::default(); + ::pb_jelly::Message::deserialize(&mut val, &mut next)?; + self.oneof_field = Some(RecursiveOneof_OneofField::Boxed(Box::new(val))); + } + _ => { + ::pb_jelly::skip(typ, &mut buf)?; + } + } + } + Ok(()) + } +} +impl ::pb_jelly::Reflection for RecursiveOneof { + fn which_one_of(&self, oneof_name: &str) -> ::std::option::Option<&'static str> { + match oneof_name { + "oneof_field" => { + if let Some(RecursiveOneof_OneofField::Field(ref val)) = self.oneof_field { + let val: &RecursiveOneof = &**val; + return Some("field"); + } + if let Some(RecursiveOneof_OneofField::Empty) = self.oneof_field { + let val: &::proto_google::empty::Empty = &::std::default::Default::default(); + return Some("empty"); + } + if let Some(RecursiveOneof_OneofField::BoxedEmpty(ref val)) = self.oneof_field { + let val: &::proto_google::empty::Empty = &**val; + return Some("boxed_empty"); + } + if let Some(RecursiveOneof_OneofField::NotBoxed(ref val)) = self.oneof_field { + return Some("not_boxed"); + } + if let Some(RecursiveOneof_OneofField::Boxed(ref val)) = self.oneof_field { + let val: &ForeignMessage3 = &**val; + return Some("boxed"); + } + None + } + _ => { + panic!("unknown oneof name given"); + } + } + } + fn get_field_mut(&mut self, field_name: &str) -> ::pb_jelly::reflection::FieldMut<'_> { + match field_name { + "field" => { + match self.oneof_field { + Some(RecursiveOneof_OneofField::Field(_)) => (), + _ => { + self.oneof_field = Some(RecursiveOneof_OneofField::Field(::std::default::Default::default())); + }, + } + if let Some(RecursiveOneof_OneofField::Field(ref mut val)) = self.oneof_field { + let val = &mut **val; + return ::pb_jelly::reflection::FieldMut::Value(val); + } + unreachable!() + } + "empty" => { + match self.oneof_field { + Some(RecursiveOneof_OneofField::Empty) => (), + _ => { + self.oneof_field = Some(RecursiveOneof_OneofField::Empty); + }, + } + ::pb_jelly::reflection::FieldMut::Empty + } + "boxed_empty" => { + match self.oneof_field { + Some(RecursiveOneof_OneofField::BoxedEmpty(_)) => (), + _ => { + self.oneof_field = Some(RecursiveOneof_OneofField::BoxedEmpty(::std::default::Default::default())); + }, + } + if let Some(RecursiveOneof_OneofField::BoxedEmpty(ref mut val)) = self.oneof_field { + let val = &mut **val; + return ::pb_jelly::reflection::FieldMut::Value(val); + } + unreachable!() + } + "not_boxed" => { + match self.oneof_field { + Some(RecursiveOneof_OneofField::NotBoxed(_)) => (), + _ => { + self.oneof_field = Some(RecursiveOneof_OneofField::NotBoxed(::std::default::Default::default())); + }, + } + if let Some(RecursiveOneof_OneofField::NotBoxed(ref mut val)) = self.oneof_field { + return ::pb_jelly::reflection::FieldMut::Value(val); + } + unreachable!() + } + "boxed" => { + match self.oneof_field { + Some(RecursiveOneof_OneofField::Boxed(_)) => (), + _ => { + self.oneof_field = Some(RecursiveOneof_OneofField::Boxed(::std::default::Default::default())); + }, + } + if let Some(RecursiveOneof_OneofField::Boxed(ref mut val)) = self.oneof_field { + let val = &mut **val; + return ::pb_jelly::reflection::FieldMut::Value(val); + } + unreachable!() + } + _ => { + panic!("unknown field name given") + } + } + } +} + diff --git a/pb-test/proto/packages/pbtest/pbtest3.proto b/pb-test/proto/packages/pbtest/pbtest3.proto index 3d292a8..b8099d9 100644 --- a/pb-test/proto/packages/pbtest/pbtest3.proto +++ b/pb-test/proto/packages/pbtest/pbtest3.proto @@ -347,3 +347,15 @@ message TestProto3Optional { string real_oneof_2_2 = 18; } } + +message RecursiveOneof { + oneof oneof_field { + // This field should be boxed automatically. + RecursiveOneof field = 1; + google.protobuf.Empty empty = 2; + // Boxing should override the empty-oneof-field special case. + google.protobuf.Empty boxed_empty = 3 [(rust.box_it) = true]; + ForeignMessage3 not_boxed = 4; + ForeignMessage3 boxed = 5 [(rust.box_it) = true]; + } +} diff --git a/pb-test/src/lib.rs b/pb-test/src/lib.rs index 32bd324..e215181 100644 --- a/pb-test/src/lib.rs +++ b/pb-test/src/lib.rs @@ -1,5 +1,4 @@ #![warn(rust_2018_idioms)] -#![feature(bench_black_box)] #![feature(test)] #[allow(unused_extern_crates)] diff --git a/pb-test/src/pbtest.rs b/pb-test/src/pbtest.rs index e5c8aea..5ce8b26 100644 --- a/pb-test/src/pbtest.rs +++ b/pb-test/src/pbtest.rs @@ -871,3 +871,20 @@ fn test_proto3_optional() { .unwrap(); assert_eq!(proto.a_int32, Some(456)); } + +// Test that boxing works properly for oneof fields. +#[test] +fn test_recursive_oneof() { + let message = RecursiveOneof { + oneof_field: Some(RecursiveOneof_OneofField::BoxedEmpty(std::default::Default::default())), + }; + check_roundtrip(&message); + let message = RecursiveOneof { + oneof_field: Some(RecursiveOneof_OneofField::Field(Box::new(message))), + }; + check_roundtrip(&message); + let message = RecursiveOneof { + oneof_field: Some(RecursiveOneof_OneofField::Field(Box::new(message))), + }; + check_roundtrip(&message); +}