Skip to content

Commit

Permalink
Implement box_it for oneof fields
Browse files Browse the repository at this point in the history
  • Loading branch information
goffrie committed Oct 17, 2023
1 parent 238912e commit 00fccbf
Show file tree
Hide file tree
Showing 6 changed files with 392 additions and 38 deletions.
55 changes: 20 additions & 35 deletions pb-jelly-gen/codegen/codegen.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,9 @@
#!/usr/bin/env python3

import itertools
import os
import re
import sys

import google.protobuf

from collections import defaultdict, namedtuple, OrderedDict
from contextlib import contextmanager
from typing import (
Expand Down Expand Up @@ -35,7 +32,6 @@
OneofDescriptorProto,
SourceCodeInfo,
)
from google.protobuf.message import Message

from proto.rust import extensions_pb2

Expand Down Expand Up @@ -301,7 +297,7 @@ def is_nullable(self) -> bool:

def is_empty_oneof_field(self) -> bool:
assert self.oneof
return self.field.type_name == ".google.protobuf.Empty"
return self.field.type_name == ".google.protobuf.Empty" and not self.is_boxed()

def can_be_packed(self) -> bool:
# Return true if incoming messages could be packed on the wire
Expand Down Expand Up @@ -366,10 +362,7 @@ def set_method(self) -> Tuple[Text, Text]:
elif self.field.type == FieldDescriptorProto.TYPE_ENUM:
return self.rust_type(), "v"
elif self.field.type == FieldDescriptorProto.TYPE_MESSAGE:
if self.is_boxed():
return "Box<%s>" % self.rust_type(), "v"
else:
return self.rust_type(), "v"
return self.rust_type(maybe_boxed=True), "v"
raise AssertionError("Unexpected field type")

def take_method(self) -> Tuple[Optional[Text], Optional[Text]]:
Expand Down Expand Up @@ -402,10 +395,7 @@ def take_method(self) -> Tuple[Optional[Text], Optional[Text]]:
elif self.field.type == FieldDescriptorProto.TYPE_ENUM:
return self.rust_type(), expr
elif self.field.type == FieldDescriptorProto.TYPE_MESSAGE:
if self.is_boxed():
return "Box<%s>" % self.rust_type(), expr
else:
return self.rust_type(), expr
return self.rust_type(maybe_boxed=True), expr
raise AssertionError("Unexpected field type")

def get_method(self) -> Tuple[Text, Text]:
Expand Down Expand Up @@ -456,7 +446,10 @@ def get_method(self) -> Tuple[Text, Text]:
)
raise AssertionError("Unexpected field type")

def rust_type(self) -> Text:
def rust_type(self, maybe_boxed: bool = False) -> Text:
if maybe_boxed and self.is_boxed():
return "::std::boxed::Box<%s>" % self.rust_type(maybe_boxed=False)

typ = self.field.type

if self.has_custom_type():
Expand Down Expand Up @@ -493,18 +486,14 @@ def rust_type(self) -> Text:
)

def __str__(self) -> str:
rust_type = self.rust_type()
rust_type = self.rust_type(maybe_boxed=True)

if self.is_repeated():
return "::std::vec::Vec<%s>" % rust_type
elif self.is_nullable() and self.is_boxed():
return "::std::option::Option<::std::boxed::Box<%s>>" % str(rust_type)
elif self.is_boxed():
return "::std::boxed::Box<%s>" % rust_type
elif self.is_nullable():
return "::std::option::Option<%s>" % rust_type
else:
return str(rust_type)
return rust_type

def oneof_field_match(self, var: Text) -> Text:
if self.is_empty_oneof_field():
Expand Down Expand Up @@ -584,6 +573,11 @@ def field_iter(
"let %s: &%s = &::std::default::Default::default();"
% (var, typ.rust_type())
)
elif typ.is_boxed():
ctx.write(
"let %(var)s: &%(typ)s = &**%(var)s;"
% dict(var=var, typ=typ.rust_type())
)
yield
elif (
field.type == FieldDescriptorProto.TYPE_MESSAGE
Expand Down Expand Up @@ -959,7 +953,9 @@ def gen_msg(
with block(self, "pub enum " + oneof_msg_name(name, oneof)):
for oneof_field in oneof_fields[oneof.name]:
typ = self.rust_type(msg_type, oneof_field)
self.write("%s," % typ.oneof_field_match(typ.rust_type()))
self.write(
"%s," % typ.oneof_field_match(typ.rust_type(maybe_boxed=True))
)

if not self.is_proto3:
with block(self, "impl " + name):
Expand Down Expand Up @@ -1461,6 +1457,8 @@ def gen_msg(
typ.oneof.name,
),
):
if typ.is_boxed():
self.write("let val = &mut **val;")
self.write(
"return ::pb_jelly::reflection::FieldMut::Value(val);"
)
Expand Down Expand Up @@ -1744,20 +1742,7 @@ def _set_boxed_if_recursive(
visited, looking_for, self.find_msg(field.type_name)
)
if need_box or field.type_name == looking_for:
# We only box normal fields, not oneof variants
#
# TODO: We are restricting this case because the codegen
# can't currently box oneof variants. This means there are
# cases won't work with the Rust codegen. Specifically, if
# you have a oneof variant that directly references the
# containing message or is co-recursive to another message,
# the codegen won't box the variant and the resulting code
# won't compile.
if not (
field.HasField("oneof_index")
and pt.typ.oneof_decl[field.oneof_index]
):
field.options.Extensions[extensions_pb2.box_it] = True
field.options.Extensions[extensions_pb2.box_it] = True
any_field_boxed = True
return any_field_boxed

Expand Down
4 changes: 2 additions & 2 deletions pb-test/gen/pb-jelly/proto_pbtest/src/pbtest2.rs.expected
Original file line number Diff line number Diff line change
Expand Up @@ -3085,10 +3085,10 @@ impl TestMessage {
pub fn has_optional_foreign_message_boxed(&self) -> bool {
self.optional_foreign_message_boxed.is_some()
}
pub fn set_optional_foreign_message_boxed(&mut self, v: Box<ForeignMessage>) {
pub fn set_optional_foreign_message_boxed(&mut self, v: ::std::boxed::Box<ForeignMessage>) {
self.optional_foreign_message_boxed = Some(v);
}
pub fn take_optional_foreign_message_boxed(&mut self) -> Box<ForeignMessage> {
pub fn take_optional_foreign_message_boxed(&mut self) -> ::std::boxed::Box<ForeignMessage> {
self.optional_foreign_message_boxed.take().unwrap_or_default()
}
pub fn get_optional_foreign_message_boxed(&self) -> &ForeignMessage {
Expand Down
Loading

0 comments on commit 00fccbf

Please sign in to comment.