Skip to content

Commit

Permalink
Implement extensions
Browse files Browse the repository at this point in the history
  • Loading branch information
goffrie committed Oct 19, 2023
1 parent aa75014 commit 6c77402
Show file tree
Hide file tree
Showing 9 changed files with 847 additions and 45 deletions.
163 changes: 134 additions & 29 deletions pb-jelly-gen/codegen/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,11 +121,7 @@
# https://github.com/protocolbuffers/protobuf/blob/master/src/google/protobuf/descriptor.proto
SourceCodeLocation = List[int]
ProtoTypes = Union[FileDescriptorProto, EnumDescriptorProto, DescriptorProto]
WalkRet = Tuple[
List[Tuple[List[Text], EnumDescriptorProto, SourceCodeLocation]],
List[Tuple[List[Text], DescriptorProto, SourceCodeLocation]],
]
ModTree = DefaultDict[Text, DefaultDict[Text, Any]]
ModTree = DefaultDict[Text, "ModTree"]


def camelcase(underscored: Text) -> Text:
Expand All @@ -148,7 +144,7 @@ def __init__(
self,
ctx: "Context",
proto_file: FileDescriptorProto,
msg_type: DescriptorProto,
msg_type: Optional[DescriptorProto],
field: FieldDescriptorProto,
) -> None:
self.ctx = ctx
Expand All @@ -159,6 +155,7 @@ def __init__(
self.oneof = (
field.HasField("oneof_index")
and not field.proto3_optional
and msg_type is not None
and msg_type.oneof_decl[field.oneof_index]
)

Expand Down Expand Up @@ -692,7 +689,7 @@ def write_comments(self, sci_loc: Optional[SourceCodeInfo.Location]) -> None:
self.write_line_broken_text_with_prefix(sci_loc.trailing_comments, "///")

def rust_type(
self, msg_type: DescriptorProto, field: FieldDescriptorProto
self, msg_type: Optional[DescriptorProto], field: FieldDescriptorProto
) -> RustType:
rust_type = RustType(self.ctx, self.proto_file, msg_type, field)

Expand Down Expand Up @@ -896,6 +893,11 @@ def gen_msg(
assert self.indentation == 0
name = "_".join(path + [msg_type.name])

preserve_unrecognized = msg_type.options.Extensions[
extensions_pb2.preserve_unrecognized
]
has_extensions = len(msg_type.extension_range) > 0

# Adjust some field names
for field in msg_type.field:
if field.name in RESERVED_KEYWORDS:
Expand Down Expand Up @@ -952,9 +954,12 @@ def gen_msg(
"pub %s: %s," % (oneof.name, oneof_msg_name(name, oneof))
)

if msg_type.options.Extensions[extensions_pb2.preserve_unrecognized]:
if preserve_unrecognized:
self.write("pub _unrecognized: Vec<u8>,")

if has_extensions:
self.write("pub _extensions: ::pb_jelly::Unrecognized,")

# Generate any oneof enum structs
for oneof in oneof_decls:
self.write("#[derive(%s)]" % ", ".join(sorted(derives)))
Expand Down Expand Up @@ -1042,10 +1047,10 @@ def gen_msg(
oneof_field = oneof_fields[oneof.name][0]
typ = self.rust_type(msg_type, oneof_field)
self.write("%s: %s," % (oneof.name, typ.default(name)))
if msg_type.options.Extensions[
extensions_pb2.preserve_unrecognized
]:
if preserve_unrecognized:
self.write("_unrecognized: Vec::new(),")
if has_extensions:
self.write("_extensions: ::pb_jelly::Unrecognized::default(),")

with block(self, "lazy_static!"):
self.write(
Expand Down Expand Up @@ -1122,10 +1127,7 @@ def gen_msg(
self.write('name: "%s",' % oneof.name)

with block(self, "fn compute_size(&self) -> usize"):
if (
len(msg_type.field) > 0
or msg_type.options.Extensions[extensions_pb2.preserve_unrecognized]
):
if len(msg_type.field) > 0 or preserve_unrecognized or has_extensions:
self.write("let mut size = 0;")
for field in msg_type.field:
typ = self.rust_type(msg_type, field)
Expand Down Expand Up @@ -1157,10 +1159,10 @@ def gen_msg(
% field.name
)
self.write("size += %s_size;" % field.name)
if msg_type.options.Extensions[
extensions_pb2.preserve_unrecognized
]:
if preserve_unrecognized:
self.write("size += self._unrecognized.len();")
if has_extensions:
self.write("size += self._extensions.compute_size();")
self.write("size")
else:
self.write("0")
Expand Down Expand Up @@ -1210,17 +1212,16 @@ def gen_msg(
)
self.write("::pb_jelly::varint::write(l as u64, w)?;")
self.write("::pb_jelly::Message::serialize(val, w)?;")
if msg_type.options.Extensions[extensions_pb2.preserve_unrecognized]:
if preserve_unrecognized:
self.write("w.write_all(&self._unrecognized)?;")
if has_extensions:
self.write("self._extensions.serialize(w)?;")
self.write("Ok(())")

with block(
self,
"fn deserialize<B: ::pb_jelly::PbBufferReader>(&mut self, mut buf: &mut B) -> ::std::io::Result<()>",
):
preserve_unrecognized = msg_type.options.Extensions[
extensions_pb2.preserve_unrecognized
]
if preserve_unrecognized:
self.write(
"let mut unrecognized = ::pb_jelly::Unrecognized::default();"
Expand Down Expand Up @@ -1371,6 +1372,15 @@ def gen_msg(
"self.%s = %s;"
% (field.name, field_val)
)
if has_extensions:
pattern = " | ".join(
"{}..={}".format(r.start, r.end - 1)
for r in msg_type.extension_range
)
with block(self, pattern + " =>"):
self.write(
"self._extensions.gather(field_number, typ, &mut buf)?;"
)
with block(self, "_ =>"):
if preserve_unrecognized:
self.write(
Expand All @@ -1396,7 +1406,12 @@ def gen_msg(
)

if preserve_unrecognized:
self.write("unrecognized.serialize(&mut self._unrecognized)?;")
self.write(
"self._unrecognized.reserve(unrecognized.compute_size());"
)
self.write(
"unrecognized.serialize(&mut std::io::Cursor::new(&mut self._unrecognized))?;"
)
self.write("Ok(())")

with block(self, "impl ::pb_jelly::Reflection for " + name):
Expand Down Expand Up @@ -1499,9 +1514,58 @@ def gen_msg(
with block(self, "_ =>"):
self.write('panic!("unknown field name given")')

if has_extensions:
with block(self, "impl ::pb_jelly::extensions::Extensible for " + name):
with block(
self,
"fn _extensions(&self) -> &::pb_jelly::Unrecognized",
):
self.write("&self._extensions")

def gen_extension(
    self,
    path: List[Text],
    extension_field: FieldDescriptorProto,
    scl: SourceCodeLocation,
) -> None:
    """Write the Rust `pub const` that describes one proto extension field.

    The constant's name is the `_`-joined `path` plus the field name,
    upper-cased. Its type is `RepeatedExtension` when the field is labelled
    repeated and `SingularExtension` otherwise, parameterised by the Rust
    name of the extended message and the Rust type of the field.

    Args:
        path: Names of the enclosing scopes, used to prefix the constant name.
        extension_field: Descriptor of the extension field being generated.
        scl: Source-code-location path used to look up attached comments.
    """
    crate, mod_parts = self.ctx.crate_from_proto_filename(self.proto_file.name)

    # Emit any comments recorded for this location before the constant itself.
    location = self.source_code_info_by_scl.get(tuple(scl))
    self.write_comments(location)

    const_name = "_".join(path + [extension_field.name]).upper()
    # No containing message descriptor is supplied for an extension field.
    field_rust_type = self.rust_type(None, extension_field)
    target = self.ctx.find(extension_field.extendee)

    if extension_field.label == FieldDescriptorProto.LABEL_REPEATED:
        wrapper = "RepeatedExtension"
    else:
        wrapper = "SingularExtension"

    template = """pub const {name}: ::pb_jelly::extensions::{kind}<{extendee}, {field_type}> =
::pb_jelly::extensions::{kind}::new(
{field_number},
::pb_jelly::wire_format::Type::{wire_format},
"{raw_name}",
);"""

    # Resolve the pieces in the same order the template arguments were
    # originally evaluated.
    extendee_path = target.rust_name(crate, mod_parts)
    field_type = field_rust_type.rust_type()
    wire_format = field_rust_type.wire_format()

    self.write(
        template.format(
            name=const_name,
            extendee=extendee_path,
            field_type=field_type,
            kind=wrapper,
            field_number=extension_field.number,
            wire_format=wire_format,
            raw_name=extension_field.name,
        )
    )


def walk(proto: FileDescriptorProto) -> WalkRet:
enums, messages = [], []
def walk(
proto: FileDescriptorProto,
) -> Tuple[
List[Tuple[List[Text], EnumDescriptorProto, SourceCodeLocation]],
List[Tuple[List[Text], DescriptorProto, SourceCodeLocation]],
List[Tuple[List[Text], FieldDescriptorProto, SourceCodeLocation]],
]:
enums, messages, extensions = [], [], []

def _walk(
proto: ProtoTypes, parents: List[Text], scl_prefix: SourceCodeLocation
Expand All @@ -1518,6 +1582,15 @@ def _walk(
for i, nested_message in enumerate(proto.nested_type):
ntfn = DescriptorProto.NESTED_TYPE_FIELD_NUMBER
_walk(nested_message, parents + [proto.name], scl_prefix + [ntfn, i])

for i, nested_extension in enumerate(proto.extension):
extensions.append(
(
parents + [proto.name],
nested_extension,
scl_prefix + [DescriptorProto.EXTENSION_FIELD_NUMBER, i],
)
)
elif isinstance(proto, FileDescriptorProto):
for i, enum_type in enumerate(proto.enum_type):
etfn = FileDescriptorProto.ENUM_TYPE_FIELD_NUMBER
Expand All @@ -1527,8 +1600,17 @@ def _walk(
mtfn = FileDescriptorProto.MESSAGE_TYPE_FIELD_NUMBER
_walk(message_type, parents, scl_prefix + [mtfn, i])

# File-level extensions are stored in FileDescriptorProto.extension, so the
# SourceCodeInfo path must use that message's field number. Using
# DescriptorProto's number here would address a different field of the file
# descriptor and the extension's comments would not be found.
eftn = FileDescriptorProto.EXTENSION_FIELD_NUMBER
for i, file_extension in enumerate(proto.extension):
    extensions.append((parents, file_extension, scl_prefix + [eftn, i]))

_walk(proto, [], [])
return enums, messages
return enums, messages, extensions


M = TypeVar("M", DescriptorProto, EnumDescriptorProto)
Expand Down Expand Up @@ -1621,8 +1703,14 @@ def calc_impls(
(msg_impls_eq, msg_impls_copy) = (True, True)

if msg_type.options.Extensions[extensions_pb2.preserve_unrecognized]:
# TODO: copy pasta
msg_impls_copy = False # Preserve unparsed has a Vec which is not Copy

if len(msg_type.extension_range) > 0:
# `Unrecognized` is neither Copy nor Eq
msg_impls_eq = False
msg_impls_copy = False

for field in msg_type.field:
typ = field.type
rust_type = RustType(self, proto_file, msg_type, field)
Expand Down Expand Up @@ -1672,6 +1760,7 @@ def calc_impls(
)

if msg_type.options.Extensions[extensions_pb2.preserve_unrecognized]:
# TODO: this check isn't really necessary, but it is useful
assert field_type.typ.options.Extensions[
extensions_pb2.preserve_unrecognized
], "%s preserves unrecognized but child message %s does not" % (
Expand All @@ -1698,7 +1787,7 @@ def calc_impls(
self.impls_by_msg[fq_msg] = Impls(Eq=msg_impls_eq, Copy=msg_impls_copy)

def feed(self, proto_file: FileDescriptorProto, to_generate: List[Text]) -> None:
enums, messages = walk(proto_file)
enums, messages, extensions = walk(proto_file)

for name in to_generate:
crate, _ = self.crate_from_proto_filename(name)
Expand All @@ -1712,11 +1801,23 @@ def feed(self, proto_file: FileDescriptorProto, to_generate: List[Text]) -> None
msg_pt = ProtoType(self, proto_file, path, typ)
self.proto_types[msg_pt.proto_name()] = msg_pt

crate, _ = self.crate_from_proto_filename(proto_file.name)

for path, typ, _ in messages:
fq_msg = (proto_file.name, "_".join(path + [typ.name]))
crate, _ = self.crate_from_proto_filename(proto_file.name)
self.calc_impls(proto_file, crate, typ, fq_msg)

if crate in self.deps_map:
for path, field, _ in extensions:
for type_name in [field.type_name, field.extendee]:
if type_name:
field_type = self.find(type_name)
dep_crate, _ = self.crate_from_proto_filename(
field_type.proto_file.name
)
if dep_crate != crate:
self.deps_map[crate].add(dep_crate)

def find_enum(self, typename: Text) -> ProtoType[EnumDescriptorProto]:
pt = self.find(typename)
assert isinstance(pt.typ, EnumDescriptorProto)
Expand Down Expand Up @@ -2024,7 +2125,7 @@ def add_mod(writer: CodeWriter) -> None:
if writer.derive_serde:
derive_serde = True

enums, messages = walk(proto_file)
enums, messages, extensions = walk(proto_file)

for path, enum_typ, scl in enums:
writer.gen_enum(path, enum_typ, scl)
Expand All @@ -2035,6 +2136,10 @@ def add_mod(writer: CodeWriter) -> None:
writer.gen_msg(path, msg_typ, scl)
writer.write("")

for path, extension_field, scl in extensions:
writer.gen_extension(path, extension_field, scl)
writer.write("")

add_mod(writer=writer)

# check if the writer ever used a small string optimization
Expand Down
Loading

0 comments on commit 6c77402

Please sign in to comment.