diff --git a/pb-jelly-gen/codegen/codegen.py b/pb-jelly-gen/codegen/codegen.py index 762c1d7..7b13d46 100755 --- a/pb-jelly-gen/codegen/codegen.py +++ b/pb-jelly-gen/codegen/codegen.py @@ -276,6 +276,8 @@ def custom_type(self) -> Text: return self.field.options.Extensions[extensions_pb2.type] def is_nullable(self) -> bool: + if self.oneof: + return False if ( self.field.type in PRIMITIVE_TYPES and self.is_proto3 @@ -362,7 +364,7 @@ def set_method(self) -> Tuple[Text, Text]: elif self.field.type == FieldDescriptorProto.TYPE_ENUM: return self.rust_type(), "v" elif self.field.type == FieldDescriptorProto.TYPE_MESSAGE: - return self.rust_type(maybe_boxed=True), "v" + return self.storage_type(), "v" raise AssertionError("Unexpected field type") def take_method(self) -> Tuple[Optional[Text], Optional[Text]]: @@ -395,7 +397,7 @@ def take_method(self) -> Tuple[Optional[Text], Optional[Text]]: elif self.field.type == FieldDescriptorProto.TYPE_ENUM: return self.rust_type(), expr elif self.field.type == FieldDescriptorProto.TYPE_MESSAGE: - return self.rust_type(maybe_boxed=True), expr + return self.storage_type(), expr raise AssertionError("Unexpected field type") def get_method(self) -> Tuple[Text, Text]: @@ -446,10 +448,7 @@ def get_method(self) -> Tuple[Text, Text]: ) raise AssertionError("Unexpected field type") - def rust_type(self, maybe_boxed: bool = False) -> Text: - if maybe_boxed and self.is_boxed(): - return "::std::boxed::Box<%s>" % self.rust_type(maybe_boxed=False) - + def rust_type(self) -> Text: typ = self.field.type if self.has_custom_type(): @@ -485,15 +484,18 @@ def rust_type(self, maybe_boxed: bool = False) -> Text: "Unsupported type: {!r}".format(FieldDescriptorProto.Type.Name(typ)) ) - def __str__(self) -> str: - rust_type = self.rust_type(maybe_boxed=True) + def storage_type(self) -> str: + rust_type = self.rust_type() + + if self.is_boxed(): + rust_type = "::std::boxed::Box<%s>" % rust_type if self.is_repeated(): - return "::std::vec::Vec<%s>" % rust_type + rust_type = "::std::vec::Vec<%s>" % rust_type elif self.is_nullable(): - return "::std::option::Option<%s>" % rust_type - else: - return rust_type + rust_type = "::std::option::Option<%s>" % rust_type + + return rust_type def oneof_field_match(self, var: Text) -> Text: if self.is_empty_oneof_field(): @@ -931,7 +933,7 @@ def gen_msg( if typ.oneof: oneof_fields[typ.oneof.name].append(field) else: - self.write("pub %s: %s," % (field.name, typ)) + self.write("pub %s: %s," % (field.name, typ.storage_type())) for oneof in oneof_decls: if oneof_nullable(oneof): @@ -954,7 +956,7 @@ def gen_msg( for oneof_field in oneof_fields[oneof.name]: typ = self.rust_type(msg_type, oneof_field) self.write( - "%s," % typ.oneof_field_match(typ.rust_type(maybe_boxed=True)) + "%s," % typ.oneof_field_match(typ.storage_type()) ) if not self.is_proto3: