From c2192182dd06cba7935d5f6a3028f0fd461c7bfc Mon Sep 17 00:00:00 2001 From: xxchan Date: Fri, 13 Sep 2024 16:20:09 +0800 Subject: [PATCH] refactor(source): move protobuf to codec crate, and refactor tests (#18507) Signed-off-by: xxchan --- Cargo.lock | 12 +- ci/scripts/e2e-source-test.sh | 1 - e2e_test/sink/kafka/protobuf.slt | 29 +- e2e_test/source/basic/kafka.slt | 36 - .../basic/old_row_format_syntax/kafka.slt | 27 - src/connector/Cargo.toml | 6 +- src/connector/codec/Cargo.toml | 10 + src/connector/{ => codec}/build.rs | 6 +- src/connector/codec/src/common/mod.rs | 15 + .../codec/src/common/protobuf/compiler.rs | 86 ++ .../codec/src/common/protobuf/mod.rs | 16 + src/connector/codec/src/decoder/mod.rs | 1 + .../src/decoder/protobuf/mod.rs} | 6 +- .../codec/src/decoder/protobuf/parser.rs | 236 ++++++ src/connector/codec/src/lib.rs | 1 + .../codec/tests/integration_tests/avro.rs | 32 +- .../codec/tests/integration_tests/main.rs | 25 + .../codec/tests/integration_tests/protobuf.rs | 647 ++++++++++++++ .../integration_tests}/protobuf/.gitignore | 1 + .../codec/tests/integration_tests/utils.rs | 29 +- .../tests/test_data/all-types.pb} | Bin 3231 -> 2748 bytes .../codec/tests/test_data/all-types.proto | 76 ++ .../tests}/test_data/any-schema.proto | 2 +- .../tests}/test_data/complex-schema.proto | 0 .../codec/tests/test_data/recursive.proto | 24 + .../tests}/test_data/simple-schema.proto | 0 src/connector/src/parser/protobuf/mod.rs | 4 - src/connector/src/parser/protobuf/parser.rs | 789 +----------------- src/connector/src/parser/unified/mod.rs | 5 +- src/connector/src/schema/protobuf.rs | 105 +-- src/connector/src/sink/encoder/proto.rs | 16 +- src/connector/src/test_data/any-schema.pb | 30 - src/connector/src/test_data/complex-schema | Bin 408 -> 0 bytes .../test_data/proto_recursive/recursive.proto | 95 --- src/connector/src/test_data/simple-schema | 11 - src/tests/simulation/src/slt.rs | 6 - 36 files changed, 1235 insertions(+), 1150 deletions(-) rename src/connector/{ => codec}/build.rs (87%) create mode 100644 src/connector/codec/src/common/mod.rs create mode 100644 src/connector/codec/src/common/protobuf/compiler.rs create mode 100644 src/connector/codec/src/common/protobuf/mod.rs rename src/connector/{src/parser/unified/protobuf.rs => codec/src/decoder/protobuf/mod.rs} (95%) create mode 100644 src/connector/codec/src/decoder/protobuf/parser.rs create mode 100644 src/connector/codec/tests/integration_tests/protobuf.rs rename src/connector/{src/parser => codec/tests/integration_tests}/protobuf/.gitignore (50%) rename src/connector/{src/test_data/proto_recursive/recursive.pb => codec/tests/test_data/all-types.pb} (76%) create mode 100644 src/connector/codec/tests/test_data/all-types.proto rename src/connector/{src => codec/tests}/test_data/any-schema.proto (99%) rename src/connector/{src => codec/tests}/test_data/complex-schema.proto (100%) create mode 100644 src/connector/codec/tests/test_data/recursive.proto rename src/connector/{src => codec/tests}/test_data/simple-schema.proto (100%) delete mode 100644 src/connector/src/test_data/any-schema.pb delete mode 100644 src/connector/src/test_data/complex-schema delete mode 100644 src/connector/src/test_data/proto_recursive/recursive.proto delete mode 100644 src/connector/src/test_data/simple-schema diff --git a/Cargo.lock b/Cargo.lock index 366fb3b36a672..fce73f9891743 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -10613,6 +10613,7 @@ dependencies = [ "easy-ext", "enum-as-inner 0.6.0", "expect-test", + "fs-err", "futures", "futures-async-stream", "gcp-bigquery-client", @@ -10651,11 +10652,8 @@ dependencies = [ "pretty_assertions", "prometheus", "prost 0.13.1", - "prost-build 0.12.1", "prost-reflect", "prost-types 0.13.1", - "protobuf-native", - "protobuf-src", "pulsar", "quote", "rand", @@ -10717,11 +10715,19 @@ dependencies = [ "chrono", "easy-ext", "expect-test", + "fs-err", "hex", "itertools 0.12.1", "jsonbb", "jsonschema-transpiler", + "madsim-tokio", "num-bigint", + "prost 0.13.1", + "prost-build 0.12.1", + "prost-reflect", + "prost-types 0.13.1", + "protobuf-native", + "protobuf-src", "risingwave_common", "risingwave_pb", "rust_decimal", diff --git a/ci/scripts/e2e-source-test.sh b/ci/scripts/e2e-source-test.sh index 29f2a0ac7b5ce..6bf2f8a491576 100755 --- a/ci/scripts/e2e-source-test.sh +++ b/ci/scripts/e2e-source-test.sh @@ -45,7 +45,6 @@ risedev ci-kill echo "--- Prepare data" cp src/connector/src/test_data/simple-schema.avsc ./avro-simple-schema.avsc cp src/connector/src/test_data/complex-schema.avsc ./avro-complex-schema.avsc -cp src/connector/src/test_data/complex-schema ./proto-complex-schema cp src/connector/src/test_data/complex-schema.json ./json-complex-schema diff --git a/e2e_test/sink/kafka/protobuf.slt b/e2e_test/sink/kafka/protobuf.slt index 70de91e25c8d7..25b95a49cf1f3 100644 --- a/e2e_test/sink/kafka/protobuf.slt +++ b/e2e_test/sink/kafka/protobuf.slt @@ -4,17 +4,14 @@ set sink_decouple = false; system ok rpk topic create test-rw-sink-append-only-protobuf -system ok -cp src/connector/src/test_data/proto_recursive/recursive.pb ./proto-recursive - statement ok create table from_kafka with ( connector = 'kafka', topic = 'test-rw-sink-append-only-protobuf', properties.bootstrap.server = 'message_queue:29092') format plain encode protobuf ( - schema.location = 'file:///risingwave/proto-recursive', - message = 'recursive.AllTypes'); + schema.location = 'file:///risingwave/src/connector/codec/tests/test_data/all-types.pb', + message = 'all_types.AllTypes'); system ok rpk topic create test-rw-sink-append-only-protobuf-csr-a @@ -91,8 +88,8 @@ create sink sink0 from into_kafka with ( properties.bootstrap.server = 'message_queue:29092') format plain encode protobuf ( force_append_only = true, - schema.location = 'file:///risingwave/proto-recursive', - message = 'recursive.AllTypes'); + schema.location = 'file:///risingwave/src/connector/codec/tests/test_data/all-types.pb', + message = 'all_types.AllTypes'); statement ok create sink sink_csr_trivial as select string_field as field_a from into_kafka with ( @@ -121,8 +118,8 @@ create sink sink_upsert from into_kafka with ( properties.bootstrap.server = 'message_queue:29092', primary_key = 'string_field') format upsert encode protobuf ( - schema.location = 'file:///risingwave/proto-recursive', - message = 'recursive.AllTypes'); + schema.location = 'file:///risingwave/src/connector/codec/tests/test_data/all-types.pb', + message = 'all_types.AllTypes'); ---- db error: ERROR: Failed to run the query @@ -140,8 +137,8 @@ create sink sink_upsert from into_kafka with ( properties.bootstrap.server = 'message_queue:29092', primary_key = 'string_field') format upsert encode protobuf ( - schema.location = 'file:///risingwave/proto-recursive', - message = 'recursive.AllTypes') + schema.location = 'file:///risingwave/src/connector/codec/tests/test_data/all-types.pb', + message = 'all_types.AllTypes') key encode text; # Shall be ignored by force_append_only sinks but processed by upsert sinks. @@ -196,7 +193,7 @@ create sink sink_err from into_kafka with ( format plain encode protobuf ( force_append_only = true, schema.location = 'file:///risingwave/proto-recursiv', - message = 'recursive.AllTypes'); + message = 'all_types.AllTypes'); statement error field not in proto create sink sink_err as select 1 as extra_column with ( @@ -205,8 +202,8 @@ create sink sink_err as select 1 as extra_column with ( properties.bootstrap.server = 'message_queue:29092') format plain encode protobuf ( force_append_only = true, - schema.location = 'file:///risingwave/proto-recursive', - message = 'recursive.AllTypes'); + schema.location = 'file:///risingwave/src/connector/codec/tests/test_data/all-types.pb', + message = 'all_types.AllTypes'); statement error s3 URL not supported yet create sink sink_err from into_kafka with ( @@ -215,8 +212,8 @@ create sink sink_err from into_kafka with ( properties.bootstrap.server = 'message_queue:29092') format plain encode protobuf ( force_append_only = true, - schema.location = 's3:///risingwave/proto-recursive', - message = 'recursive.AllTypes'); + schema.location = 's3:///risingwave/src/connector/codec/tests/test_data/all-types.pb', + message = 'all_types.AllTypes'); statement ok drop table from_kafka cascade; diff --git a/e2e_test/source/basic/kafka.slt b/e2e_test/source/basic/kafka.slt index 0e413c3389d58..227c0aa46bac1 100644 --- a/e2e_test/source/basic/kafka.slt +++ b/e2e_test/source/basic/kafka.slt @@ -187,17 +187,6 @@ create table s10 with ( scan.startup.mode = 'earliest' ) FORMAT PLAIN ENCODE AVRO (schema.location = 'file:///risingwave/avro-complex-schema.avsc', with_deprecated_file_header = true); -statement ok -create table s11 with ( - connector = 'kafka', - topic = 'proto_c_bin', - properties.bootstrap.server = 'message_queue:29092', - scan.startup.mode = 'earliest') -FORMAT PLAIN ENCODE PROTOBUF ( - message = 'test.User', - schema.location = 'file:///risingwave/proto-complex-schema' -); - statement ok CREATE TABLE s12( id int, @@ -273,17 +262,6 @@ create table s16 (v1 int, v2 varchar) with ( scan.startup.mode = 'latest' ) FORMAT PLAIN ENCODE JSON -statement ok -create source s17 with ( - connector = 'kafka', - topic = 'proto_c_bin', - properties.bootstrap.server = 'message_queue:29092', - scan.startup.mode = 'earliest') -FORMAT PLAIN ENCODE PROTOBUF ( - message = 'test.User', - schema.location = 'file:///risingwave/proto-complex-schema' -); - statement ok create source s18 with ( connector = 'kafka', @@ -696,11 +674,6 @@ select id, code, timestamp, xfas, contacts, sex from s10; ---- 100 abc 1473305798 {"(0,200,10.0.0.1)","(1,400,10.0.0.2)"} ("{1xxx,2xxx}","{1xxx,2xxx}") MALE -query ITITT -select id, code, timestamp, xfas, contacts, sex from s11; ----- -0 abc 1473305798 {"(0,200,127.0.0.1)","(1,400,127.0.0.2)"} ("{1xxx,2xxx}","{1xxx,2xxx}") MALE - query ITITT select id, code, timestamp, xfas, contacts, jsonb from s12; ---- @@ -730,9 +703,6 @@ select count(*) from s16 statement error Not supported: alter source with schema registry alter source s18 add column v10 int; -statement error Not supported: alter source with schema registry -alter source s17 add column v10 int; - query III rowsort select * from s21; ---- @@ -875,9 +845,6 @@ drop table s9 statement ok drop table s10 -statement ok -drop table s11 - statement ok drop table s12 @@ -893,9 +860,6 @@ drop table s15 statement ok drop table s16 -statement ok -drop source s17 - statement ok drop source s18 diff --git a/e2e_test/source/basic/old_row_format_syntax/kafka.slt b/e2e_test/source/basic/old_row_format_syntax/kafka.slt index 1f4c118f30dc5..d67665a049daa 100644 --- a/e2e_test/source/basic/old_row_format_syntax/kafka.slt +++ b/e2e_test/source/basic/old_row_format_syntax/kafka.slt @@ -171,14 +171,6 @@ create table s10 with ( scan.startup.mode = 'earliest' ) row format avro row schema location 'file:///risingwave/avro-complex-schema.avsc' -statement ok -create table s11 with ( - connector = 'kafka', - topic = 'proto_c_bin', - properties.bootstrap.server = 'message_queue:29092', - scan.startup.mode = 'earliest' -) row format protobuf message 'test.User' row schema location 'file:///risingwave/proto-complex-schema' - statement ok CREATE TABLE s12( id int, @@ -254,14 +246,6 @@ create table s16 (v1 int, v2 varchar) with ( scan.startup.mode = 'latest' ) ROW FORMAT JSON -statement ok -create source s17 with ( - connector = 'kafka', - topic = 'proto_c_bin', - properties.bootstrap.server = 'message_queue:29092', - scan.startup.mode = 'earliest' -) row format protobuf message 'test.User' row schema location 'file:///risingwave/proto-complex-schema' - statement error without schema registry create source s18 with ( connector = 'kafka', @@ -570,11 +554,6 @@ select id, first_name, last_name, email from s8_no_schema_field; # ---- # 100 abc 1473305798 {"(0,200,10.0.0.1)","(1,400,10.0.0.2)"} ("{1xxx,2xxx}","{1xxx,2xxx}") MALE -query ITITT -select id, code, timestamp, xfas, contacts, sex from s11; ----- -0 abc 1473305798 {"(0,200,127.0.0.1)","(1,400,127.0.0.2)"} ("{1xxx,2xxx}","{1xxx,2xxx}") MALE - query ITITT select id, code, timestamp, xfas, contacts, jsonb from s12; ---- @@ -712,9 +691,6 @@ drop table s8_no_schema_field # statement ok # drop table s10 -statement ok -drop table s11 - statement ok drop table s12 @@ -730,9 +706,6 @@ drop table s15 statement ok drop table s16 -statement ok -drop source s17 - # statement ok # drop source s18 diff --git a/src/connector/Cargo.toml b/src/connector/Cargo.toml index 3801508a7aa19..2535847c98fe4 100644 --- a/src/connector/Cargo.toml +++ b/src/connector/Cargo.toml @@ -105,7 +105,6 @@ prometheus = { version = "0.13", features = ["process"] } prost = { workspace = true, features = ["no-recursion-limit"] } prost-reflect = { version = "0.14", features = ["serde"] } prost-types = "0.13" -protobuf-native = "0.2.2" pulsar = { version = "6.3", default-features = false, features = [ "tokio-runtime", "telemetry", @@ -194,6 +193,7 @@ assert_matches = "1" criterion = { workspace = true, features = ["async_tokio", "async"] } deltalake = { workspace = true, features = ["datafusion"] } expect-test = "1" +fs-err = "2" paste = "1" pretty_assertions = "1" quote = "1" @@ -206,10 +206,6 @@ tracing-subscriber = "0.3" tracing-test = "0.2" walkdir = "2" -[build-dependencies] -prost-build = "0.12" -protobuf-src = "1" - [[bench]] name = "debezium_json_parser" harness = false diff --git a/src/connector/codec/Cargo.toml b/src/connector/codec/Cargo.toml index 5086549f4bf4c..5848c236dbd4d 100644 --- a/src/connector/codec/Cargo.toml +++ b/src/connector/codec/Cargo.toml @@ -26,6 +26,10 @@ itertools = { workspace = true } jsonbb = { workspace = true } jst = { package = 'jsonschema-transpiler', git = "https://github.com/mozilla/jsonschema-transpiler", rev = "c1a89d720d118843d8bcca51084deb0ed223e4b4" } num-bigint = "0.4" +prost = { workspace = true, features = ["no-recursion-limit"] } +prost-reflect = { version = "0.14", features = ["serde"] } +prost-types = "0.13" +protobuf-native = "0.2.2" risingwave_common = { workspace = true } risingwave_pb = { workspace = true } rust_decimal = "1" @@ -37,7 +41,13 @@ tracing = "0.1" [dev-dependencies] expect-test = "1" +fs-err = "2" hex = "0.4" +tokio = { version = "0.2", package = "madsim-tokio" } + +[build-dependencies] +prost-build = "0.12" +protobuf-src = "1" [target.'cfg(not(madsim))'.dependencies] workspace-hack = { path = "../../workspace-hack" } diff --git a/src/connector/build.rs b/src/connector/codec/build.rs similarity index 87% rename from src/connector/build.rs rename to src/connector/codec/build.rs index 6ef6e1629438c..8a9438d59b9e8 100644 --- a/src/connector/build.rs +++ b/src/connector/codec/build.rs @@ -13,17 +13,17 @@ // limitations under the License. fn main() { - let proto_dir = "./src/test_data/proto_recursive"; + let proto_dir = "./tests/test_data/"; println!("cargo:rerun-if-changed={}", proto_dir); - let proto_files = ["recursive"]; + let proto_files = ["recursive", "all-types"]; let protos: Vec = proto_files .iter() .map(|f| format!("{}/{}.proto", proto_dir, f)) .collect(); prost_build::Config::new() - .out_dir("./src/parser/protobuf") + .out_dir("./tests/integration_tests/protobuf") .compile_protos(&protos, &Vec::::new()) .unwrap(); diff --git a/src/connector/codec/src/common/mod.rs b/src/connector/codec/src/common/mod.rs new file mode 100644 index 0000000000000..c8a7ca35c4209 --- /dev/null +++ b/src/connector/codec/src/common/mod.rs @@ -0,0 +1,15 @@ +// Copyright 2024 RisingWave Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +pub mod protobuf; diff --git a/src/connector/codec/src/common/protobuf/compiler.rs b/src/connector/codec/src/common/protobuf/compiler.rs new file mode 100644 index 0000000000000..80e86d002d4aa --- /dev/null +++ b/src/connector/codec/src/common/protobuf/compiler.rs @@ -0,0 +1,86 @@ +// Copyright 2024 RisingWave Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::path::{Path, PathBuf}; + +use itertools::Itertools; + +macro_rules! embed_wkts { + [$( $path:literal ),+ $(,)?] => { + &[$( + ( + concat!("google/protobuf/", $path), + include_bytes!(concat!(env!("PROTO_INCLUDE"), "/google/protobuf/", $path)).as_slice(), + ) + ),+] + }; +} +const WELL_KNOWN_TYPES: &[(&str, &[u8])] = embed_wkts![ + "any.proto", + "api.proto", + "compiler/plugin.proto", + "descriptor.proto", + "duration.proto", + "empty.proto", + "field_mask.proto", + "source_context.proto", + "struct.proto", + "timestamp.proto", + "type.proto", + "wrappers.proto", +]; + +#[derive(Debug, thiserror::Error)] +pub enum PbCompileError { + #[error("build_file_descriptor_set failed\n{}", errs.iter().map(|e| format!("\t{e}")).join("\n"))] + Build { + errs: Vec, + }, + #[error("serialize descriptor set failed")] + Serialize, +} + +pub fn compile_pb( + main_file: (PathBuf, Vec), + dependencies: impl IntoIterator)>, +) -> Result, PbCompileError> { + use protobuf_native::compiler::{ + SimpleErrorCollector, SourceTreeDescriptorDatabase, VirtualSourceTree, + }; + use protobuf_native::MessageLite; + + let root = main_file.0.clone(); + + let mut source_tree = VirtualSourceTree::new(); + for (path, bytes) in std::iter::once(main_file).chain(dependencies.into_iter()) { + source_tree.as_mut().add_file(&path, bytes); + } + for (path, bytes) in WELL_KNOWN_TYPES { + source_tree + .as_mut() + .add_file(Path::new(path), bytes.to_vec()); + } + + let mut error_collector = SimpleErrorCollector::new(); + // `db` needs to be dropped before we can iterate on `error_collector`. + let fds = { + let mut db = SourceTreeDescriptorDatabase::new(source_tree.as_mut()); + db.as_mut().record_errors_to(error_collector.as_mut()); + db.as_mut().build_file_descriptor_set(&[root]) + } + .map_err(|_| PbCompileError::Build { + errs: error_collector.as_mut().collect(), + })?; + fds.serialize().map_err(|_| PbCompileError::Serialize) +} diff --git a/src/connector/codec/src/common/protobuf/mod.rs b/src/connector/codec/src/common/protobuf/mod.rs new file mode 100644 index 0000000000000..f630dedf0d240 --- /dev/null +++ b/src/connector/codec/src/common/protobuf/mod.rs @@ -0,0 +1,16 @@ +// Copyright 2024 RisingWave Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +mod compiler; +pub use compiler::compile_pb; diff --git a/src/connector/codec/src/decoder/mod.rs b/src/connector/codec/src/decoder/mod.rs index bbfdbf0a90d79..e3e579ed36ec1 100644 --- a/src/connector/codec/src/decoder/mod.rs +++ b/src/connector/codec/src/decoder/mod.rs @@ -14,6 +14,7 @@ pub mod avro; pub mod json; +pub mod protobuf; pub mod utils; use risingwave_common::error::NotImplemented; diff --git a/src/connector/src/parser/unified/protobuf.rs b/src/connector/codec/src/decoder/protobuf/mod.rs similarity index 95% rename from src/connector/src/parser/unified/protobuf.rs rename to src/connector/codec/src/decoder/protobuf/mod.rs index 3ebeebca44373..7ad357fef50fb 100644 --- a/src/connector/src/parser/unified/protobuf.rs +++ b/src/connector/codec/src/decoder/protobuf/mod.rs @@ -12,17 +12,17 @@ // See the License for the specific language governing permissions and // limitations under the License. +pub mod parser; use std::borrow::Cow; use std::sync::LazyLock; +use parser::from_protobuf_value; use prost_reflect::{DynamicMessage, ReflectMessage}; use risingwave_common::log::LogSuppresser; use risingwave_common::types::{DataType, DatumCow, ToOwnedDatum}; use thiserror_ext::AsReport; -use super::{Access, AccessResult}; -use crate::parser::from_protobuf_value; -use crate::parser::unified::uncategorized; +use super::{uncategorized, Access, AccessResult}; pub struct ProtobufAccess { message: DynamicMessage, diff --git a/src/connector/codec/src/decoder/protobuf/parser.rs b/src/connector/codec/src/decoder/protobuf/parser.rs new file mode 100644 index 0000000000000..15778727fc466 --- /dev/null +++ b/src/connector/codec/src/decoder/protobuf/parser.rs @@ -0,0 +1,236 @@ +// Copyright 2024 RisingWave Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use anyhow::Context; +use itertools::Itertools; +use prost_reflect::{Cardinality, FieldDescriptor, Kind, MessageDescriptor, ReflectMessage, Value}; +use risingwave_common::array::{ListValue, StructValue}; +use risingwave_common::types::{ + DataType, DatumCow, Decimal, JsonbVal, ScalarImpl, ToOwnedDatum, F32, F64, +}; +use risingwave_pb::plan_common::{AdditionalColumn, ColumnDesc, ColumnDescVersion}; +use thiserror::Error; +use thiserror_ext::Macro; + +use crate::decoder::{uncategorized, AccessError, AccessResult}; + +pub fn pb_schema_to_column_descs( + message_descriptor: &MessageDescriptor, +) -> anyhow::Result> { + let mut columns = Vec::with_capacity(message_descriptor.fields().len()); + let mut index = 0; + let mut parse_trace: Vec = vec![]; + for field in message_descriptor.fields() { + columns.push(pb_field_to_col_desc(&field, &mut index, &mut parse_trace)?); + } + + Ok(columns) +} + +/// Maps a protobuf field to a RW column. +fn pb_field_to_col_desc( + field_descriptor: &FieldDescriptor, + index: &mut i32, + parse_trace: &mut Vec, +) -> anyhow::Result { + let field_type = protobuf_type_mapping(field_descriptor, parse_trace) + .context("failed to map protobuf type")?; + if let Kind::Message(m) = field_descriptor.kind() { + let field_descs = if let DataType::List { .. } = field_type { + vec![] + } else { + m.fields() + .map(|f| pb_field_to_col_desc(&f, index, parse_trace)) + .try_collect()? + }; + *index += 1; + Ok(ColumnDesc { + column_id: *index, + name: field_descriptor.name().to_string(), + column_type: Some(field_type.to_protobuf()), + field_descs, + type_name: m.full_name().to_string(), + generated_or_default_column: None, + description: None, + additional_column_type: 0, // deprecated + additional_column: Some(AdditionalColumn { column_type: None }), + version: ColumnDescVersion::Pr13707 as i32, + }) + } else { + *index += 1; + Ok(ColumnDesc { + column_id: *index, + name: field_descriptor.name().to_string(), + column_type: Some(field_type.to_protobuf()), + additional_column: Some(AdditionalColumn { column_type: None }), + version: ColumnDescVersion::Pr13707 as i32, + ..Default::default() + }) + } +} + +#[derive(Error, Debug, Macro)] +#[error("{0}")] +struct ProtobufTypeError(#[message] String); + +fn detect_loop_and_push( + trace: &mut Vec, + fd: &FieldDescriptor, +) -> std::result::Result<(), ProtobufTypeError> { + let identifier = format!("{}({})", fd.name(), fd.full_name()); + if trace.iter().any(|s| s == identifier.as_str()) { + bail_protobuf_type_error!( + "circular reference detected: {}, conflict with {}, kind {:?}", + trace.iter().format("->"), + identifier, + fd.kind(), + ); + } + trace.push(identifier); + Ok(()) +} + +pub fn from_protobuf_value<'a>( + field_desc: &FieldDescriptor, + value: &'a Value, + type_expected: &DataType, +) -> AccessResult> { + let kind = field_desc.kind(); + + macro_rules! borrowed { + ($v:expr) => { + return Ok(DatumCow::Borrowed(Some($v.into()))) + }; + } + + let v: ScalarImpl = match value { + Value::Bool(v) => ScalarImpl::Bool(*v), + Value::I32(i) => ScalarImpl::Int32(*i), + Value::U32(i) => ScalarImpl::Int64(*i as i64), + Value::I64(i) => ScalarImpl::Int64(*i), + Value::U64(i) => ScalarImpl::Decimal(Decimal::from(*i)), + Value::F32(f) => ScalarImpl::Float32(F32::from(*f)), + Value::F64(f) => ScalarImpl::Float64(F64::from(*f)), + Value::String(s) => borrowed!(s.as_str()), + Value::EnumNumber(idx) => { + let enum_desc = kind.as_enum().ok_or_else(|| AccessError::TypeError { + expected: "enum".to_owned(), + got: format!("{kind:?}"), + value: value.to_string(), + })?; + let enum_symbol = enum_desc.get_value(*idx).ok_or_else(|| { + uncategorized!("unknown enum index {} of enum {:?}", idx, enum_desc) + })?; + ScalarImpl::Utf8(enum_symbol.name().into()) + } + Value::Message(dyn_msg) => { + if dyn_msg.descriptor().full_name() == "google.protobuf.Any" { + ScalarImpl::Jsonb(JsonbVal::from( + serde_json::to_value(dyn_msg).map_err(AccessError::ProtobufAnyToJson)?, + )) + } else { + let desc = dyn_msg.descriptor(); + let DataType::Struct(st) = type_expected else { + return Err(AccessError::TypeError { + expected: type_expected.to_string(), + got: desc.full_name().to_string(), + value: value.to_string(), // Protobuf TEXT + }); + }; + + let mut rw_values = Vec::with_capacity(st.len()); + for (name, expected_field_type) in st.iter() { + let Some(field_desc) = desc.get_field_by_name(name) else { + // Field deleted in protobuf. Fallback to SQL NULL (of proper RW type). + rw_values.push(None); + continue; + }; + let value = dyn_msg.get_field(&field_desc); + rw_values.push( + from_protobuf_value(&field_desc, &value, expected_field_type)? + .to_owned_datum(), + ); + } + ScalarImpl::Struct(StructValue::new(rw_values)) + } + } + Value::List(values) => { + let DataType::List(element_type) = type_expected else { + return Err(AccessError::TypeError { + expected: type_expected.to_string(), + got: format!("repeated {:?}", kind), + value: value.to_string(), // Protobuf TEXT + }); + }; + let mut builder = element_type.create_array_builder(values.len()); + for value in values { + builder.append(from_protobuf_value(field_desc, value, element_type)?); + } + ScalarImpl::List(ListValue::new(builder.finish())) + } + Value::Bytes(value) => borrowed!(&**value), + _ => { + return Err(AccessError::UnsupportedType { + ty: format!("{kind:?}"), + }); + } + }; + Ok(Some(v).into()) +} + +/// Maps protobuf type to RW type. +fn protobuf_type_mapping( + field_descriptor: &FieldDescriptor, + parse_trace: &mut Vec, +) -> std::result::Result { + detect_loop_and_push(parse_trace, field_descriptor)?; + let field_type = field_descriptor.kind(); + let mut t = match field_type { + Kind::Bool => DataType::Boolean, + Kind::Double => DataType::Float64, + Kind::Float => DataType::Float32, + Kind::Int32 | Kind::Sint32 | Kind::Sfixed32 => DataType::Int32, + // Fixed32 represents [0, 2^32 - 1]. It's equal to u32. + Kind::Int64 | Kind::Sint64 | Kind::Sfixed64 | Kind::Uint32 | Kind::Fixed32 => { + DataType::Int64 + } + Kind::Uint64 | Kind::Fixed64 => DataType::Decimal, + Kind::String => DataType::Varchar, + Kind::Message(m) => match m.full_name() { + // Well-Known Types are identified by their full name + "google.protobuf.Any" => DataType::Jsonb, + _ => { + let fields = m + .fields() + .map(|f| protobuf_type_mapping(&f, parse_trace)) + .try_collect()?; + let field_names = m.fields().map(|f| f.name().to_string()).collect_vec(); + DataType::new_struct(fields, field_names) + } + }, + Kind::Enum(_) => DataType::Varchar, + Kind::Bytes => DataType::Bytea, + }; + if field_descriptor.is_map() { + bail_protobuf_type_error!( + "protobuf map type (on field `{}`) is not supported", + field_descriptor.full_name() + ); + } + if field_descriptor.cardinality() == Cardinality::Repeated { + t = DataType::List(Box::new(t)) + } + _ = parse_trace.pop(); + Ok(t) +} diff --git a/src/connector/codec/src/lib.rs b/src/connector/codec/src/lib.rs index 2119c1ece4e57..d3f0a8c6ec2cf 100644 --- a/src/connector/codec/src/lib.rs +++ b/src/connector/codec/src/lib.rs @@ -37,6 +37,7 @@ #![register_tool(rw)] #![recursion_limit = "256"] +pub mod common; /// Converts JSON/AVRO/Protobuf data to RisingWave datum. /// The core API is [`decoder::Access`]. pub mod decoder; diff --git a/src/connector/codec/tests/integration_tests/avro.rs b/src/connector/codec/tests/integration_tests/avro.rs index 11275f45e9783..ab1df6e7e82b8 100644 --- a/src/connector/codec/tests/integration_tests/avro.rs +++ b/src/connector/codec/tests/integration_tests/avro.rs @@ -64,33 +64,11 @@ fn avro_schema_str_to_risingwave_schema( Ok((resolved_schema, rw_schema)) } -/// Data driven testing for converting Avro Schema to RisingWave Schema, and then converting Avro data into RisingWave data. -/// -/// The expected results can be automatically updated. To run and update the tests: -/// ```bash -/// UPDATE_EXPECT=1 cargo test -p risingwave_connector_codec -/// ``` -/// Or use Rust Analyzer. Refer to . +/// Refer to [crate level documentation](crate) for the ideas. /// /// ## Arguments /// - `avro_schema`: Avro schema in JSON format. /// - `avro_data`: list of Avro data. Refer to [`TestDataEncoding`] for the format. -/// -/// ## Why not directly test the uppermost layer `AvroParserConfig` and `AvroAccessBuilder`? -/// -/// Because their interface are not clean enough, and have complex logic like schema registry. -/// We might need to separate logic to make them clenaer and then we can use it directly for testing. -/// -/// ## If we reimplement a similar logic here, what are we testing? -/// -/// Basically unit tests of `avro_schema_to_column_descs`, `convert_to_datum`, i.e., the type mapping. -/// -/// It makes some sense, as the data parsing logic is generally quite simple (one-liner), and the most -/// complex and error-prone part is the type mapping. -/// -/// ## Why test schema mapping and data mapping together? -/// -/// Because the expected data type for data mapping comes from the schema mapping. #[track_caller] fn check( avro_schema: &str, @@ -992,10 +970,10 @@ fn test_map() { map_map_int(#2): Jsonb, ]"#]], expect![[r#" - Owned(Jsonb(JsonbRef({"a": "x", "b": "y"}))) - Owned(Jsonb(JsonbRef({"m1": {"a": Number(1), "b": Number(2)}, "m2": {"c": Number(3), "d": Number(4)}}))) + Owned(Jsonb({"a": "x", "b": "y"})) + Owned(Jsonb({"m1": {"a": 1, "b": 2}, "m2": {"c": 3, "d": 4}})) ---- - Owned(Jsonb(JsonbRef({}))) - Owned(Jsonb(JsonbRef({})))"#]], + Owned(Jsonb({})) + Owned(Jsonb({}))"#]], ); } diff --git a/src/connector/codec/tests/integration_tests/main.rs b/src/connector/codec/tests/integration_tests/main.rs index 8c718f918d0a6..010fe05936517 100644 --- a/src/connector/codec/tests/integration_tests/main.rs +++ b/src/connector/codec/tests/integration_tests/main.rs @@ -12,6 +12,31 @@ // See the License for the specific language governing permissions and // limitations under the License. +//! Data driven testing for converting Avro/Protobuf Schema to RisingWave Schema, and then converting Avro/Protobuf data into RisingWave data. +//! +//! The expected results can be automatically updated. To run and update the tests: +//! ```bash +//! UPDATE_EXPECT=1 cargo test -p risingwave_connector_codec +//! ``` +//! Or use Rust Analyzer. Refer to . +//! +//! ## Why not directly test the uppermost layer `AvroParserConfig` and `AvroAccessBuilder`? +//! +//! Because their interface are not clean enough, and have complex logic like schema registry. +//! We might need to separate logic to make them cleaner and then we can use it directly for testing. +//! +//! ## If we reimplement a similar logic here, what are we testing? +//! +//! Basically unit tests of `avro_schema_to_column_descs`, `convert_to_datum`, i.e., the type mapping. +//! +//! It makes some sense, as the data parsing logic is generally quite simple (one-liner), and the most +//! complex and error-prone part is the type mapping. +//! +//! ## Why test schema mapping and data mapping together? +//! +//! Because the expected data type for data mapping comes from the schema mapping. + mod avro; +mod protobuf; pub mod utils; diff --git a/src/connector/codec/tests/integration_tests/protobuf.rs b/src/connector/codec/tests/integration_tests/protobuf.rs new file mode 100644 index 0000000000000..b07d5f739b81d --- /dev/null +++ b/src/connector/codec/tests/integration_tests/protobuf.rs @@ -0,0 +1,647 @@ +// Copyright 2024 RisingWave Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#[rustfmt::skip] +#[allow(clippy::all)] +mod recursive; +#[rustfmt::skip] +#[allow(clippy::all)] +mod all_types; +use std::path::PathBuf; + +use anyhow::Context; +use prost::Message; +use prost_reflect::{DescriptorPool, DynamicMessage, MessageDescriptor}; +use risingwave_connector_codec::common::protobuf::compile_pb; +use risingwave_connector_codec::decoder::protobuf::parser::*; +use risingwave_connector_codec::decoder::protobuf::ProtobufAccess; +use risingwave_connector_codec::decoder::Access; +use thiserror_ext::AsReport; + +use crate::utils::*; + +/// Refer to [crate level documentation](crate) for the ideas. +#[track_caller] +fn check( + pb_schema: MessageDescriptor, + pb_data: &[&[u8]], + expected_risingwave_schema: expect_test::Expect, + expected_risingwave_data: expect_test::Expect, +) { + let rw_schema = pb_schema_to_column_descs(&pb_schema); + + if let Err(e) = rw_schema { + expected_risingwave_schema.assert_eq(&e.to_report_string_pretty()); + expected_risingwave_data.assert_eq(""); + return; + } + + let rw_schema = rw_schema + .unwrap() + .iter() + .map(ColumnDesc::from) + .collect_vec(); + expected_risingwave_schema.assert_eq(&format!( + "{:#?}", + rw_schema.iter().map(ColumnDescTestDisplay).collect_vec() + )); + + let mut data_str = vec![]; + for data in pb_data { + let access = ProtobufAccess::new(DynamicMessage::decode(pb_schema.clone(), *data).unwrap()); + let mut row = vec![]; + for col in &rw_schema { + let rw_data = access.access(&[&col.name], &col.data_type); + match rw_data { + Ok(data) => row.push(format!("{:#?}", DatumCowTestDisplay(&data))), + Err(e) => row.push(format!( + "~~~~\nError at column `{}`: {}\n~~~~", + col.name, + e.to_report_string() + )), + } + } + data_str.push(format!("{}", row.iter().format("\n"))); + } + + expected_risingwave_data.assert_eq(&format!( + "{}", + data_str + .iter() + .format("\n================================================================\n") + )); +} + +fn load_message_descriptor( + file_name: &str, + message_name: &str, +) -> anyhow::Result { + let location = "tests/test_data/".to_string() + file_name; + let file_content = fs_err::read(&location).unwrap(); + let schema_bytes = if file_name.ends_with(".proto") { + compile_pb((PathBuf::from(&location), file_content), [])? + } else { + file_content + }; + let pool = DescriptorPool::decode(schema_bytes.as_slice()) + .with_context(|| format!("cannot build descriptor pool from schema `{location}`"))?; + + pool.get_message_by_name(message_name).with_context(|| { + format!( + "cannot find message `{}` in schema `{}`", + message_name, location, + ) + }) +} + +#[test] +fn test_simple_schema() -> anyhow::Result<()> { + // Id: 123, + // Address: "test address", + // City: "test city", + // Zipcode: 456, + // Rate: 1.2345, + // Date: "2021-01-01" + static PRE_GEN_PROTO_DATA: &[u8] = b"\x08\x7b\x12\x0c\x74\x65\x73\x74\x20\x61\x64\x64\x72\x65\x73\x73\x1a\x09\x74\x65\x73\x74\x20\x63\x69\x74\x79\x20\xc8\x03\x2d\x19\x04\x9e\x3f\x32\x0a\x32\x30\x32\x31\x2d\x30\x31\x2d\x30\x31"; + + let message_descriptor = + load_message_descriptor("simple-schema.proto", "test.TestRecord").unwrap(); + + // validate the binary data is correct + let value = DynamicMessage::decode(message_descriptor.clone(), PRE_GEN_PROTO_DATA).unwrap(); + expect![[r#" + [ + I32( + 123, + ), + String( + "test address", + ), + String( + "test city", + ), + I64( + 456, + ), + F32( + 1.2345, + ), + String( + "2021-01-01", + ), + ] + "#]] + .assert_debug_eq(&value.fields().map(|f| f.1).collect_vec()); + + check( + message_descriptor, + &[PRE_GEN_PROTO_DATA], + expect![[r#" + [ + id(#1): Int32, + address(#2): Varchar, + city(#3): Varchar, + zipcode(#4): Int64, + rate(#5): Float32, + date(#6): Varchar, + ]"#]], + expect![[r#" + Owned(Int32(123)) + Borrowed(Utf8("test address")) + Borrowed(Utf8("test city")) + Owned(Int64(456)) + Owned(Float32(OrderedFloat(1.2345))) + Borrowed(Utf8("2021-01-01"))"#]], + ); + + Ok(()) +} + +#[test] +fn test_complex_schema() -> anyhow::Result<()> { + let message_descriptor = load_message_descriptor("complex-schema.proto", "test.User").unwrap(); + + check( + message_descriptor, + &[], + expect![[r#" + [ + id(#1): Int32, + code(#2): Varchar, + timestamp(#3): Int64, + xfas(#4): List( + Struct { + device_model_id: Int32, + device_make_id: Int32, + ip: Varchar, + }, + ), type_name: test.Xfa, + contacts(#7): Struct { + emails: List(Varchar), + phones: List(Varchar), + }, type_name: test.Contacts, field_descs: [emails(#5): List(Varchar), phones(#6): List(Varchar)], + sex(#8): Varchar, + ]"#]], + expect![""], + ); + + Ok(()) +} + +#[test] +fn test_any_schema() -> anyhow::Result<()> { + let message_descriptor = load_message_descriptor("any-schema.proto", "test.TestAny").unwrap(); + + // id: 12345 + // name { + // type_url: "type.googleapis.com/test.Int32Value" + // value: "\010\322\376\006" + // } + // Unpacked Int32Value from Any: value: 114514 + static ANY_DATA_1: &[u8] = b"\x08\xb9\x60\x12\x2b\x0a\x23\x74\x79\x70\x65\x2e\x67\x6f\x6f\x67\x6c\x65\x61\x70\x69\x73\x2e\x63\x6f\x6d\x2f\x74\x65\x73\x74\x2e\x49\x6e\x74\x33\x32\x56\x61\x6c\x75\x65\x12\x04\x08\xd2\xfe\x06"; + + // "id": 12345, + // "any_value": { + // "type_url": "type.googleapis.com/test.AnyValue", + // "value": { + // "any_value_1": { + // "type_url": "type.googleapis.com/test.StringValue", + // "value": "114514" + // }, + // "any_value_2": { + // "type_url": "type.googleapis.com/test.Int32Value", + // "value": 114514 + // } + // } + // } + static ANY_DATA_2: &[u8] = b"\x08\xb9\x60\x12\x84\x01\x0a\x21\x74\x79\x70\x65\x2e\x67\x6f\x6f\x67\x6c\x65\x61\x70\x69\x73\x2e\x63\x6f\x6d\x2f\x74\x65\x73\x74\x2e\x41\x6e\x79\x56\x61\x6c\x75\x65\x12\x5f\x0a\x30\x0a\x24\x74\x79\x70\x65\x2e\x67\x6f\x6f\x67\x6c\x65\x61\x70\x69\x73\x2e\x63\x6f\x6d\x2f\x74\x65\x73\x74\x2e\x53\x74\x72\x69\x6e\x67\x56\x61\x6c\x75\x65\x12\x08\x0a\x06\x31\x31\x34\x35\x31\x34\x12\x2b\x0a\x23\x74\x79\x70\x65\x2e\x67\x6f\x6f\x67\x6c\x65\x61\x70\x69\x73\x2e\x63\x6f\x6d\x2f\x74\x65\x73\x74\x2e\x49\x6e\x74\x33\x32\x56\x61\x6c\x75\x65\x12\x04\x08\xd2\xfe\x06"; + + // id: 12345 + // name { + // type_url: "type.googleapis.com/test.StringValue" + // value: "\n\010John Doe" + // } + static ANY_DATA_3: &[u8] = b"\x08\xb9\x60\x12\x32\x0a\x24\x74\x79\x70\x65\x2e\x67\x6f\x6f\x67\x6c\x65\x61\x70\x69\x73\x2e\x63\x6f\x6d\x2f\x74\x65\x73\x74\x2e\x53\x74\x72\x69\x6e\x67\x56\x61\x6c\x75\x65\x12\x0a\x0a\x08\x4a\x6f\x68\x6e\x20\x44\x6f\x65"; + + // // id: 12345 + // // any_value: { + // // type_url: "type.googleapis.com/test.StringXalue" + // // value: "\n\010John Doe" + // // } + static ANY_DATA_INVALID: &[u8] = b"\x08\xb9\x60\x12\x32\x0a\x24\x74\x79\x70\x65\x2e\x67\x6f\x6f\x67\x6c\x65\x61\x70\x69\x73\x2e\x63\x6f\x6d\x2f\x74\x65\x73\x74\x2e\x53\x74\x72\x69\x6e\x67\x58\x61\x6c\x75\x65\x12\x0a\x0a\x08\x4a\x6f\x68\x6e\x20\x44\x6f\x65"; + + // validate the binary data is correct + { + let value1 = DynamicMessage::decode(message_descriptor.clone(), ANY_DATA_1).unwrap(); + expect![[r#" + [ + I32( + 12345, + ), + Message( + DynamicMessage { + desc: MessageDescriptor { + name: "Any", + full_name: "google.protobuf.Any", + is_map_entry: false, + fields: [ + FieldDescriptor { + name: "type_url", + full_name: "google.protobuf.Any.type_url", + json_name: "typeUrl", + number: 1, + kind: string, + cardinality: Optional, + containing_oneof: None, + default_value: None, + is_group: false, + is_list: false, + is_map: false, + is_packed: false, + supports_presence: false, + }, + FieldDescriptor { + name: "value", + full_name: "google.protobuf.Any.value", + json_name: "value", + number: 2, + kind: bytes, + cardinality: Optional, + containing_oneof: None, + default_value: None, + is_group: false, + is_list: false, + is_map: false, + is_packed: false, + supports_presence: false, + }, + ], + oneofs: [], + }, + fields: DynamicMessageFieldSet { + fields: { + 1: Value( + String( + "type.googleapis.com/test.Int32Value", + ), + ), + 2: Value( + Bytes( + b"\x08\xd2\xfe\x06", + ), + ), + }, + }, + }, + ), + ] + "#]] + .assert_debug_eq(&value1.fields().map(|f| f.1).collect_vec()); + + let value2 = DynamicMessage::decode(message_descriptor.clone(), ANY_DATA_2).unwrap(); + expect![[r#" + [ + I32( + 12345, + ), + Message( + DynamicMessage { + desc: MessageDescriptor { + name: "Any", + full_name: "google.protobuf.Any", + is_map_entry: false, + fields: [ + FieldDescriptor { + name: "type_url", + full_name: "google.protobuf.Any.type_url", + json_name: "typeUrl", + number: 1, + kind: string, + cardinality: Optional, + containing_oneof: None, + default_value: None, + is_group: false, + is_list: false, + is_map: false, + is_packed: false, + supports_presence: false, + }, + FieldDescriptor { + name: "value", + full_name: "google.protobuf.Any.value", + json_name: "value", + number: 2, + kind: bytes, + cardinality: Optional, + containing_oneof: None, + default_value: None, + is_group: false, + is_list: false, + is_map: false, + is_packed: false, + supports_presence: false, + }, + ], + oneofs: [], + }, + fields: DynamicMessageFieldSet { + fields: { + 1: Value( + String( + "type.googleapis.com/test.AnyValue", + ), + ), + 2: Value( + Bytes( + b"\n0\n$type.googleapis.com/test.StringValue\x12\x08\n\x06114514\x12+\n#type.googleapis.com/test.Int32Value\x12\x04\x08\xd2\xfe\x06", + ), + ), + }, + }, + }, + ), + ] + "#]] + .assert_debug_eq(&value2.fields().map(|f| f.1).collect_vec()); + + let value3 = DynamicMessage::decode(message_descriptor.clone(), ANY_DATA_INVALID).unwrap(); + expect![[r#" + [ + I32( + 12345, + ), + Message( + DynamicMessage { + desc: MessageDescriptor { + name: "Any", + full_name: "google.protobuf.Any", + is_map_entry: false, + fields: [ + FieldDescriptor { + name: "type_url", + full_name: "google.protobuf.Any.type_url", + json_name: "typeUrl", + number: 1, + kind: string, + cardinality: Optional, + containing_oneof: None, + default_value: None, + is_group: false, + is_list: false, + is_map: false, + is_packed: false, + supports_presence: false, + }, + FieldDescriptor { + name: "value", + full_name: "google.protobuf.Any.value", + json_name: "value", + number: 2, + kind: bytes, + cardinality: Optional, + containing_oneof: None, + default_value: None, + is_group: false, + is_list: false, + is_map: false, + is_packed: false, + supports_presence: false, + }, + ], + oneofs: [], + }, + fields: DynamicMessageFieldSet { + fields: { + 1: Value( + String( + "type.googleapis.com/test.StringXalue", + ), + ), + 2: Value( + Bytes( + b"\n\x08John Doe", + ), + ), + }, + }, + }, + ), + ] + "#]] + .assert_debug_eq(&value3.fields().map(|f| f.1).collect_vec()); + } + + check( + message_descriptor, + &[ANY_DATA_1, ANY_DATA_2, ANY_DATA_3, ANY_DATA_INVALID], + expect![[r#" + [ + id(#1): Int32, + any_value(#4): Jsonb, type_name: google.protobuf.Any, field_descs: [type_url(#2): Varchar, value(#3): Bytea], + ]"#]], + expect![[r#" + Owned(Int32(12345)) + Owned(Jsonb({ + "@type": "type.googleapis.com/test.Int32Value", + "value": Number(114514), + })) + ================================================================ + Owned(Int32(12345)) + Owned(Jsonb({ + "@type": "type.googleapis.com/test.AnyValue", + "anyValue1": { + "@type": "type.googleapis.com/test.StringValue", + "value": "114514", + }, + "anyValue2": { + "@type": "type.googleapis.com/test.Int32Value", + "value": Number(114514), + }, + })) + ================================================================ + Owned(Int32(12345)) + Owned(Jsonb({ + "@type": "type.googleapis.com/test.StringValue", + "value": "John Doe", + })) + ================================================================ + Owned(Int32(12345)) + ~~~~ + Error at column `any_value`: Fail to convert protobuf Any into jsonb: message 'test.StringXalue' not found + ~~~~"#]], + ); + + Ok(()) +} + +#[test] +fn test_all_types() -> anyhow::Result<()> { + use self::all_types::all_types::*; + use self::all_types::*; + + let message_descriptor = + load_message_descriptor("all-types.proto", "all_types.AllTypes").unwrap(); + + let data = { + AllTypes { + double_field: 1.2345, + float_field: 1.2345, + int32_field: 42, + int64_field: 1234567890, + uint32_field: 98765, + uint64_field: 9876543210, + sint32_field: -12345, + sint64_field: -987654321, + fixed32_field: 1234, + fixed64_field: 5678, + sfixed32_field: -56789, + sfixed64_field: -123456, + bool_field: true, + string_field: "Hello, Prost!".to_string(), + bytes_field: b"byte data".to_vec(), + enum_field: EnumType::Option1 as i32, + nested_message_field: Some(NestedMessage { + id: 100, + name: "Nested".to_string(), + }), + repeated_int_field: vec![1, 2, 3, 4, 5], + timestamp_field: Some(::prost_types::Timestamp { + seconds: 1630927032, + nanos: 500000000, + }), + duration_field: Some(::prost_types::Duration { + seconds: 60, + nanos: 500000000, + }), + any_field: Some(::prost_types::Any { + type_url: "type.googleapis.com/my_custom_type".to_string(), + value: b"My custom data".to_vec(), + }), + int32_value_field: Some(42), + string_value_field: Some("Hello, Wrapper!".to_string()), + example_oneof: Some(ExampleOneof::OneofInt32(123)), + } + }; + let mut data_bytes = Vec::new(); + data.encode(&mut data_bytes).unwrap(); + + check( + message_descriptor, + &[&data_bytes], + expect![[r#" + [ + double_field(#1): Float64, + float_field(#2): Float32, + int32_field(#3): Int32, + int64_field(#4): Int64, + uint32_field(#5): Int64, + uint64_field(#6): Decimal, + sint32_field(#7): Int32, + sint64_field(#8): Int64, + fixed32_field(#9): Int64, + fixed64_field(#10): Decimal, + sfixed32_field(#11): Int32, + sfixed64_field(#12): Int64, + bool_field(#13): Boolean, + string_field(#14): Varchar, + bytes_field(#15): Bytea, + enum_field(#16): Varchar, + nested_message_field(#19): Struct { + id: Int32, + name: Varchar, + }, type_name: all_types.AllTypes.NestedMessage, field_descs: [id(#17): Int32, name(#18): Varchar], + repeated_int_field(#20): List(Int32), + oneof_string(#21): Varchar, + oneof_int32(#22): Int32, + oneof_enum(#23): Varchar, + timestamp_field(#26): Struct { + seconds: Int64, + nanos: Int32, + }, type_name: google.protobuf.Timestamp, field_descs: [seconds(#24): Int64, nanos(#25): Int32], + duration_field(#29): Struct { + seconds: Int64, + nanos: Int32, + }, type_name: google.protobuf.Duration, field_descs: [seconds(#27): Int64, nanos(#28): Int32], + any_field(#32): Jsonb, type_name: google.protobuf.Any, field_descs: [type_url(#30): Varchar, value(#31): Bytea], + int32_value_field(#34): Struct { value: Int32 }, type_name: google.protobuf.Int32Value, field_descs: [value(#33): Int32], + string_value_field(#36): Struct { value: Varchar }, type_name: google.protobuf.StringValue, field_descs: [value(#35): Varchar], + ]"#]], + expect![[r#" + Owned(Float64(OrderedFloat(1.2345))) + Owned(Float32(OrderedFloat(1.2345))) + Owned(Int32(42)) + Owned(Int64(1234567890)) + Owned(Int64(98765)) + Owned(Decimal(Normalized(9876543210))) + Owned(Int32(-12345)) + Owned(Int64(-987654321)) + Owned(Int64(1234)) + Owned(Decimal(Normalized(5678))) + Owned(Int32(-56789)) + Owned(Int64(-123456)) + Owned(Bool(true)) + Borrowed(Utf8("Hello, Prost!")) + Borrowed(Bytea([98, 121, 116, 101, 32, 100, 97, 116, 97])) + Owned(Utf8("OPTION1")) + Owned(StructValue( + Int32(100), + Utf8("Nested"), + )) + Owned([ + Int32(1), + Int32(2), + Int32(3), + Int32(4), + Int32(5), + ]) + Owned(Utf8("")) + Owned(Int32(123)) + Owned(Utf8("DEFAULT")) + Owned(StructValue( + Int64(1630927032), + Int32(500000000), + )) + Owned(StructValue( + Int64(60), + Int32(500000000), + )) + ~~~~ + Error at column `any_field`: Fail to convert protobuf Any into jsonb: message 'my_custom_type' not found + ~~~~ + Owned(StructValue(Int32(42))) + Owned(StructValue(Utf8("Hello, Wrapper!")))"#]], + ); + + Ok(()) +} + +#[test] +fn test_recursive() -> anyhow::Result<()> { + let message_descriptor = + load_message_descriptor("recursive.proto", "recursive.ComplexRecursiveMessage").unwrap(); + + check( + message_descriptor, + &[], + expect![[r#" + failed to map protobuf type + + Caused by: + circular reference detected: parent(recursive.ComplexRecursiveMessage.parent)->siblings(recursive.ComplexRecursiveMessage.Parent.siblings), conflict with parent(recursive.ComplexRecursiveMessage.parent), kind recursive.ComplexRecursiveMessage.Parent + "#]], + expect![""], + ); + + Ok(()) +} diff --git a/src/connector/src/parser/protobuf/.gitignore b/src/connector/codec/tests/integration_tests/protobuf/.gitignore similarity index 50% rename from src/connector/src/parser/protobuf/.gitignore rename to src/connector/codec/tests/integration_tests/protobuf/.gitignore index 4109deeeb3337..6e5bea6ee81ce 100644 --- a/src/connector/src/parser/protobuf/.gitignore +++ b/src/connector/codec/tests/integration_tests/protobuf/.gitignore @@ -1 +1,2 @@ recursive.rs +all_types.rs diff --git a/src/connector/codec/tests/integration_tests/utils.rs b/src/connector/codec/tests/integration_tests/utils.rs index dd375656c51e3..889dbeffc306f 100644 --- a/src/connector/codec/tests/integration_tests/utils.rs +++ b/src/connector/codec/tests/integration_tests/utils.rs @@ -40,10 +40,15 @@ impl<'a> std::fmt::Debug for DataTypeTestDisplay<'a> { f.finish()?; Ok(()) } - DataType::List(t) => f - .debug_tuple("List") - .field(&DataTypeTestDisplay(t)) - .finish(), + DataType::List(t) => { + if t.is_struct() { + f.debug_tuple("List") + .field(&DataTypeTestDisplay(t)) + .finish() + } else { + write!(f, "List({:?})", &DataTypeTestDisplay(t)) + } + } DataType::Map(m) => { write!( f, @@ -88,6 +93,14 @@ impl<'a> std::fmt::Debug for ScalarRefImplTestDisplay<'a> { .debug_list() .entries(m.inner().iter().map(DatumRefTestDisplay)) .finish(), + ScalarRefImpl::Jsonb(j) => { + let compact_str = format!("{}", j); + if compact_str.len() > 50 { + write!(f, "Jsonb({:#?})", jsonbb::ValueRef::from(j)) + } else { + write!(f, "Jsonb({:#})", j) + } + } _ => { // do not use alternative display for simple types write!(f, "{:?}", self.0) @@ -174,7 +187,13 @@ impl<'a> std::fmt::Debug for ColumnDescTestDisplay<'a> { write!(f, ", type_name: {}", type_name)?; } if !field_descs.is_empty() { - write!(f, ", field_descs: {:?}", field_descs)?; + write!( + f, + ", field_descs: [{}]", + field_descs.iter().format_with(", ", |field_desc, f| { + f(&format_args!("{:?}", ColumnDescTestDisplay(field_desc))) + }) + )?; } if let Some(generated_or_default_column) = generated_or_default_column { write!( diff --git a/src/connector/src/test_data/proto_recursive/recursive.pb b/src/connector/codec/tests/test_data/all-types.pb similarity index 76% rename from src/connector/src/test_data/proto_recursive/recursive.pb rename to src/connector/codec/tests/test_data/all-types.pb index 5c611c18d0d30bb3aceaec86d5dc4979d34fd189..177976d5244add199f5144fbce200152114733b4 100644 GIT binary patch delta 90 zcmbO)xkq%vM;5u0++6&LIXSu|l?AEAdId%KCHX>}K!JF$z+`^b+Rgdw224moD>&_t Qc<(uPAb6WMaJ^>)0N}aDmDEthMzIn7PEA^K%PwQY(UB zDt%LnixbmRg`~MS^YT+tEhxC6<&FWhRxDq!vpsD==$t8|mu7ZPi1utsdwwsC9ada9u%2CI~rlu@xj1rRJ4L zumJ7hHqyjp7Lri`U=2Y~jY4)@9LX7(IVnI>Bv`=?Qi2)+vLDq&2vdUKnxss)xDc)r z65wLaPOXH72#6^p!o^yam{Xbx4hqg72uG@ok&6xD5+QjmZipw41Cy&DF%{%7aCia( z8>#{xq8tzbV3-0uP@I{RlbM$e3sIN{@VXzSXLC2JArn%>A7!^g;_-0oK=3x7<#^8s E0QwTVHvj+t diff --git a/src/connector/codec/tests/test_data/all-types.proto b/src/connector/codec/tests/test_data/all-types.proto new file mode 100644 index 0000000000000..7dcad51a645d6 --- /dev/null +++ b/src/connector/codec/tests/test_data/all-types.proto @@ -0,0 +1,76 @@ +syntax = "proto3"; + +import "google/protobuf/timestamp.proto"; +import "google/protobuf/duration.proto"; +import "google/protobuf/any.proto"; +import "google/protobuf/wrappers.proto"; + +package all_types; + +// all-types.pb is generated by `protoc all-types.proto -o all-types.pb --include_imports` in the current directory. + +message AllTypes { + // standard types + double double_field = 1; + float float_field = 2; + int32 int32_field = 3; + int64 int64_field = 4; + uint32 uint32_field = 5; + uint64 uint64_field = 6; + sint32 sint32_field = 7; + sint64 sint64_field = 8; + fixed32 fixed32_field = 9; + fixed64 fixed64_field = 10; + sfixed32 sfixed32_field = 11; + sfixed64 sfixed64_field = 12; + bool bool_field = 13; + string string_field = 14; + + bytes bytes_field = 15; + + // enum + enum EnumType { + DEFAULT = 0; + OPTION1 = 1; + OPTION2 = 2; + } + EnumType enum_field = 16; + + // nested message + message NestedMessage { + int32 id = 1; + string name = 2; + } + NestedMessage nested_message_field = 17; + + // repeated field + repeated int32 repeated_int_field = 18; + + // oneof field + oneof example_oneof { + string oneof_string = 19; + int32 oneof_int32 = 20; + EnumType oneof_enum = 21; + } + + // // map field + // map map_field = 22; + + // timestamp + google.protobuf.Timestamp timestamp_field = 23; + + // duration + google.protobuf.Duration duration_field = 24; + + // any + google.protobuf.Any any_field = 25; + + // -- Unsupported + // // struct + // import "google/protobuf/struct.proto"; + // google.protobuf.Struct struct_field = 26; + + // wrapper types + google.protobuf.Int32Value int32_value_field = 27; + google.protobuf.StringValue string_value_field = 28; + } diff --git a/src/connector/src/test_data/any-schema.proto b/src/connector/codec/tests/test_data/any-schema.proto similarity index 99% rename from src/connector/src/test_data/any-schema.proto rename to src/connector/codec/tests/test_data/any-schema.proto index 12a367100ce7d..6bd9dcdf32b8f 100644 --- a/src/connector/src/test_data/any-schema.proto +++ b/src/connector/codec/tests/test_data/any-schema.proto @@ -35,4 +35,4 @@ message StringStringInt32Value { message Float32StringValue { float first = 1; string second = 2; -} \ No newline at end of file +} diff --git a/src/connector/src/test_data/complex-schema.proto b/src/connector/codec/tests/test_data/complex-schema.proto similarity index 100% rename from src/connector/src/test_data/complex-schema.proto rename to src/connector/codec/tests/test_data/complex-schema.proto diff --git a/src/connector/codec/tests/test_data/recursive.proto b/src/connector/codec/tests/test_data/recursive.proto new file mode 100644 index 0000000000000..a26a6a98e172f --- /dev/null +++ b/src/connector/codec/tests/test_data/recursive.proto @@ -0,0 +1,24 @@ +syntax = "proto3"; + +package recursive; + +message ComplexRecursiveMessage { + string node_name = 1; + int32 node_id = 2; + + message Attributes { + string key = 1; + string value = 2; + } + + repeated Attributes attributes = 3; + + message Parent { + string parent_name = 1; + int32 parent_id = 2; + repeated ComplexRecursiveMessage siblings = 3; + } + + Parent parent = 4; + repeated ComplexRecursiveMessage children = 5; +} diff --git a/src/connector/src/test_data/simple-schema.proto b/src/connector/codec/tests/test_data/simple-schema.proto similarity index 100% rename from src/connector/src/test_data/simple-schema.proto rename to src/connector/codec/tests/test_data/simple-schema.proto diff --git a/src/connector/src/parser/protobuf/mod.rs b/src/connector/src/parser/protobuf/mod.rs index bfcb0adfe1a18..462e478932ee7 100644 --- a/src/connector/src/parser/protobuf/mod.rs +++ b/src/connector/src/parser/protobuf/mod.rs @@ -14,7 +14,3 @@ mod parser; pub use parser::*; - -#[rustfmt::skip] -#[cfg(test)] -mod recursive; diff --git a/src/connector/src/parser/protobuf/parser.rs b/src/connector/src/parser/protobuf/parser.rs index bbd1d3f0da1e3..93eeb19cc1565 100644 --- a/src/connector/src/parser/protobuf/parser.rs +++ b/src/connector/src/parser/protobuf/parser.rs @@ -13,23 +13,14 @@ // limitations under the License. use anyhow::Context; -use itertools::Itertools; -use prost_reflect::{ - Cardinality, DescriptorPool, DynamicMessage, FieldDescriptor, FileDescriptor, Kind, - MessageDescriptor, ReflectMessage, Value, -}; -use risingwave_common::array::{ListValue, StructValue}; -use risingwave_common::types::{ - DataType, DatumCow, Decimal, JsonbVal, ScalarImpl, ToOwnedDatum, F32, F64, -}; +use prost_reflect::{DescriptorPool, DynamicMessage, FileDescriptor, MessageDescriptor}; use risingwave_common::{bail, try_match_expand}; -use risingwave_pb::plan_common::{AdditionalColumn, ColumnDesc, ColumnDescVersion}; -use thiserror::Error; -use thiserror_ext::Macro; +pub use risingwave_connector_codec::decoder::protobuf::parser::*; +use risingwave_connector_codec::decoder::protobuf::ProtobufAccess; +use risingwave_pb::plan_common::ColumnDesc; use crate::error::ConnectorResult; -use crate::parser::unified::protobuf::ProtobufAccess; -use crate::parser::unified::{uncategorized, AccessError, AccessImpl, AccessResult}; +use crate::parser::unified::AccessImpl; use crate::parser::util::bytes_from_url; use crate::parser::{AccessBuilder, EncodingProperties}; use crate::schema::schema_registry::{extract_schema_id, handle_sr_list, Client, WireFormatError}; @@ -124,216 +115,8 @@ impl ProtobufParserConfig { /// Maps the protobuf schema to relational schema. pub fn map_to_columns(&self) -> ConnectorResult> { - let mut columns = Vec::with_capacity(self.message_descriptor.fields().len()); - let mut index = 0; - let mut parse_trace: Vec = vec![]; - for field in self.message_descriptor.fields() { - columns.push(Self::pb_field_to_col_desc( - &field, - &mut index, - &mut parse_trace, - )?); - } - - Ok(columns) - } - - /// Maps a protobuf field to a RW column. - fn pb_field_to_col_desc( - field_descriptor: &FieldDescriptor, - index: &mut i32, - parse_trace: &mut Vec, - ) -> ConnectorResult { - let field_type = protobuf_type_mapping(field_descriptor, parse_trace) - .context("failed to map protobuf type")?; - if let Kind::Message(m) = field_descriptor.kind() { - let field_descs = if let DataType::List { .. } = field_type { - vec![] - } else { - m.fields() - .map(|f| Self::pb_field_to_col_desc(&f, index, parse_trace)) - .try_collect()? - }; - *index += 1; - Ok(ColumnDesc { - column_id: *index, - name: field_descriptor.name().to_string(), - column_type: Some(field_type.to_protobuf()), - field_descs, - type_name: m.full_name().to_string(), - generated_or_default_column: None, - description: None, - additional_column_type: 0, // deprecated - additional_column: Some(AdditionalColumn { column_type: None }), - version: ColumnDescVersion::Pr13707 as i32, - }) - } else { - *index += 1; - Ok(ColumnDesc { - column_id: *index, - name: field_descriptor.name().to_string(), - column_type: Some(field_type.to_protobuf()), - additional_column: Some(AdditionalColumn { column_type: None }), - version: ColumnDescVersion::Pr13707 as i32, - ..Default::default() - }) - } - } -} - -#[derive(Error, Debug, Macro)] -#[error("{0}")] -struct ProtobufTypeError(#[message] String); - -fn detect_loop_and_push( - trace: &mut Vec, - fd: &FieldDescriptor, -) -> std::result::Result<(), ProtobufTypeError> { - let identifier = format!("{}({})", fd.name(), fd.full_name()); - if trace.iter().any(|s| s == identifier.as_str()) { - bail_protobuf_type_error!( - "circular reference detected: {}, conflict with {}, kind {:?}", - trace.iter().format("->"), - identifier, - fd.kind(), - ); - } - trace.push(identifier); - Ok(()) -} - -pub fn from_protobuf_value<'a>( - field_desc: &FieldDescriptor, - value: &'a Value, - type_expected: &DataType, -) -> AccessResult> { - let kind = field_desc.kind(); - - macro_rules! borrowed { - ($v:expr) => { - return Ok(DatumCow::Borrowed(Some($v.into()))) - }; + pb_schema_to_column_descs(&self.message_descriptor).map_err(|e| e.into()) } - - let v: ScalarImpl = match value { - Value::Bool(v) => ScalarImpl::Bool(*v), - Value::I32(i) => ScalarImpl::Int32(*i), - Value::U32(i) => ScalarImpl::Int64(*i as i64), - Value::I64(i) => ScalarImpl::Int64(*i), - Value::U64(i) => ScalarImpl::Decimal(Decimal::from(*i)), - Value::F32(f) => ScalarImpl::Float32(F32::from(*f)), - Value::F64(f) => ScalarImpl::Float64(F64::from(*f)), - Value::String(s) => borrowed!(s.as_str()), - Value::EnumNumber(idx) => { - let enum_desc = kind.as_enum().ok_or_else(|| AccessError::TypeError { - expected: "enum".to_owned(), - got: format!("{kind:?}"), - value: value.to_string(), - })?; - let enum_symbol = enum_desc.get_value(*idx).ok_or_else(|| { - uncategorized!("unknown enum index {} of enum {:?}", idx, enum_desc) - })?; - ScalarImpl::Utf8(enum_symbol.name().into()) - } - Value::Message(dyn_msg) => { - if dyn_msg.descriptor().full_name() == "google.protobuf.Any" { - ScalarImpl::Jsonb(JsonbVal::from( - serde_json::to_value(dyn_msg).map_err(AccessError::ProtobufAnyToJson)?, - )) - } else { - let desc = dyn_msg.descriptor(); - let DataType::Struct(st) = type_expected else { - return Err(AccessError::TypeError { - expected: type_expected.to_string(), - got: desc.full_name().to_string(), - value: value.to_string(), // Protobuf TEXT - }); - }; - - let mut rw_values = Vec::with_capacity(st.len()); - for (name, expected_field_type) in st.iter() { - let Some(field_desc) = desc.get_field_by_name(name) else { - // Field deleted in protobuf. Fallback to SQL NULL (of proper RW type). - rw_values.push(None); - continue; - }; - let value = dyn_msg.get_field(&field_desc); - rw_values.push( - from_protobuf_value(&field_desc, &value, expected_field_type)? - .to_owned_datum(), - ); - } - ScalarImpl::Struct(StructValue::new(rw_values)) - } - } - Value::List(values) => { - let DataType::List(element_type) = type_expected else { - return Err(AccessError::TypeError { - expected: type_expected.to_string(), - got: format!("repeated {:?}", kind), - value: value.to_string(), // Protobuf TEXT - }); - }; - let mut builder = element_type.create_array_builder(values.len()); - for value in values { - builder.append(from_protobuf_value(field_desc, value, element_type)?); - } - ScalarImpl::List(ListValue::new(builder.finish())) - } - Value::Bytes(value) => borrowed!(&**value), - _ => { - return Err(AccessError::UnsupportedType { - ty: format!("{kind:?}"), - }); - } - }; - Ok(Some(v).into()) -} - -/// Maps protobuf type to RW type. -fn protobuf_type_mapping( - field_descriptor: &FieldDescriptor, - parse_trace: &mut Vec, -) -> std::result::Result { - detect_loop_and_push(parse_trace, field_descriptor)?; - let field_type = field_descriptor.kind(); - let mut t = match field_type { - Kind::Bool => DataType::Boolean, - Kind::Double => DataType::Float64, - Kind::Float => DataType::Float32, - Kind::Int32 | Kind::Sint32 | Kind::Sfixed32 => DataType::Int32, - // Fixed32 represents [0, 2^32 - 1]. It's equal to u32. - Kind::Int64 | Kind::Sint64 | Kind::Sfixed64 | Kind::Uint32 | Kind::Fixed32 => { - DataType::Int64 - } - Kind::Uint64 | Kind::Fixed64 => DataType::Decimal, - Kind::String => DataType::Varchar, - Kind::Message(m) => match m.full_name() { - // Well-Known Types are identified by their full name - "google.protobuf.Any" => DataType::Jsonb, - _ => { - let fields = m - .fields() - .map(|f| protobuf_type_mapping(&f, parse_trace)) - .try_collect()?; - let field_names = m.fields().map(|f| f.name().to_string()).collect_vec(); - DataType::new_struct(fields, field_names) - } - }, - Kind::Enum(_) => DataType::Varchar, - Kind::Bytes => DataType::Bytea, - }; - if field_descriptor.is_map() { - bail_protobuf_type_error!( - "protobuf map type (on field `{}`) is not supported", - field_descriptor.full_name() - ); - } - if field_descriptor.cardinality() == Cardinality::Repeated { - t = DataType::List(Box::new(t)) - } - _ = parse_trace.pop(); - Ok(t) } /// A port from the implementation of confluent's Varint Zig-zag deserialization. @@ -391,567 +174,7 @@ pub(crate) fn resolve_pb_header(payload: &[u8]) -> ConnectorResult<&[u8]> { #[cfg(test)] mod test { - use std::path::PathBuf; - - use prost::Message; - use risingwave_common::types::StructType; - use risingwave_connector_codec::decoder::AccessExt; - use risingwave_pb::catalog::StreamSourceInfo; - use risingwave_pb::data::data_type::PbTypeName; - use risingwave_pb::plan_common::{PbEncodeType, PbFormatType}; - use serde_json::json; - use thiserror_ext::AsReport as _; - use super::*; - use crate::parser::protobuf::recursive::all_types::{EnumType, ExampleOneof, NestedMessage}; - use crate::parser::protobuf::recursive::AllTypes; - use crate::parser::SpecificParserConfig; - - fn schema_dir() -> String { - let dir = PathBuf::from("src/test_data"); - format!( - "file://{}", - std::fs::canonicalize(dir).unwrap().to_str().unwrap() - ) - } - - // Id: 123, - // Address: "test address", - // City: "test city", - // Zipcode: 456, - // Rate: 1.2345, - // Date: "2021-01-01" - static PRE_GEN_PROTO_DATA: &[u8] = b"\x08\x7b\x12\x0c\x74\x65\x73\x74\x20\x61\x64\x64\x72\x65\x73\x73\x1a\x09\x74\x65\x73\x74\x20\x63\x69\x74\x79\x20\xc8\x03\x2d\x19\x04\x9e\x3f\x32\x0a\x32\x30\x32\x31\x2d\x30\x31\x2d\x30\x31"; - - #[tokio::test] - async fn test_simple_schema() -> crate::error::ConnectorResult<()> { - let location = schema_dir() + "/simple-schema"; - println!("location: {}", location); - let message_name = "test.TestRecord"; - let info = StreamSourceInfo { - proto_message_name: message_name.to_string(), - row_schema_location: location.to_string(), - use_schema_registry: false, - format: PbFormatType::Plain.into(), - row_encode: PbEncodeType::Protobuf.into(), - ..Default::default() - }; - let parser_config = SpecificParserConfig::new(&info, &Default::default())?; - let conf = ProtobufParserConfig::new(parser_config.encoding_config).await?; - let value = DynamicMessage::decode(conf.message_descriptor, PRE_GEN_PROTO_DATA).unwrap(); - - assert_eq!( - value.get_field_by_name("id").unwrap().into_owned(), - Value::I32(123) - ); - assert_eq!( - value.get_field_by_name("address").unwrap().into_owned(), - Value::String("test address".to_string()) - ); - assert_eq!( - value.get_field_by_name("city").unwrap().into_owned(), - Value::String("test city".to_string()) - ); - assert_eq!( - value.get_field_by_name("zipcode").unwrap().into_owned(), - Value::I64(456) - ); - assert_eq!( - value.get_field_by_name("rate").unwrap().into_owned(), - Value::F32(1.2345) - ); - assert_eq!( - value.get_field_by_name("date").unwrap().into_owned(), - Value::String("2021-01-01".to_string()) - ); - - Ok(()) - } - - #[tokio::test] - async fn test_complex_schema() -> crate::error::ConnectorResult<()> { - let location = schema_dir() + "/complex-schema"; - let message_name = "test.User"; - - let info = StreamSourceInfo { - proto_message_name: message_name.to_string(), - row_schema_location: location.to_string(), - use_schema_registry: false, - format: PbFormatType::Plain.into(), - row_encode: PbEncodeType::Protobuf.into(), - ..Default::default() - }; - let parser_config = SpecificParserConfig::new(&info, &Default::default())?; - let conf = ProtobufParserConfig::new(parser_config.encoding_config).await?; - let columns = conf.map_to_columns().unwrap(); - - assert_eq!(columns[0].name, "id".to_string()); - assert_eq!(columns[1].name, "code".to_string()); - assert_eq!(columns[2].name, "timestamp".to_string()); - - let data_type = columns[3].column_type.as_ref().unwrap(); - assert_eq!(data_type.get_type_name().unwrap(), PbTypeName::List); - let inner_field_type = data_type.field_type.clone(); - assert_eq!( - inner_field_type[0].get_type_name().unwrap(), - PbTypeName::Struct - ); - let struct_inner = inner_field_type[0].field_type.clone(); - assert_eq!(struct_inner[0].get_type_name().unwrap(), PbTypeName::Int32); - assert_eq!(struct_inner[1].get_type_name().unwrap(), PbTypeName::Int32); - assert_eq!( - struct_inner[2].get_type_name().unwrap(), - PbTypeName::Varchar - ); - - assert_eq!(columns[4].name, "contacts".to_string()); - let inner_field_type = columns[4].column_type.as_ref().unwrap().field_type.clone(); - assert_eq!( - inner_field_type[0].get_type_name().unwrap(), - PbTypeName::List - ); - assert_eq!( - inner_field_type[1].get_type_name().unwrap(), - PbTypeName::List - ); - Ok(()) - } - - #[tokio::test] - async fn test_refuse_recursive_proto_message() { - let location = schema_dir() + "/proto_recursive/recursive.pb"; - let message_name = "recursive.ComplexRecursiveMessage"; - - let info = StreamSourceInfo { - proto_message_name: message_name.to_string(), - row_schema_location: location.to_string(), - use_schema_registry: false, - format: PbFormatType::Plain.into(), - row_encode: PbEncodeType::Protobuf.into(), - ..Default::default() - }; - let parser_config = SpecificParserConfig::new(&info, &Default::default()).unwrap(); - let conf = ProtobufParserConfig::new(parser_config.encoding_config) - .await - .unwrap(); - let columns = conf.map_to_columns(); - // expect error message: - // "Err(Protocol error: circular reference detected: - // parent(recursive.ComplexRecursiveMessage.parent)->siblings(recursive. - // ComplexRecursiveMessage.Parent.siblings), conflict with - // parent(recursive.ComplexRecursiveMessage.parent), kind - // recursive.ComplexRecursiveMessage.Parent" - assert!(columns.is_err()); - } - - async fn create_recursive_pb_parser_config( - location: &str, - message_name: &str, - ) -> ProtobufParserConfig { - let location = schema_dir() + location; - - let info = StreamSourceInfo { - proto_message_name: message_name.to_string(), - row_schema_location: location.to_string(), - use_schema_registry: false, - format: PbFormatType::Plain.into(), - row_encode: PbEncodeType::Protobuf.into(), - ..Default::default() - }; - let parser_config = SpecificParserConfig::new(&info, &Default::default()).unwrap(); - - ProtobufParserConfig::new(parser_config.encoding_config) - .await - .unwrap() - } - - #[tokio::test] - async fn test_all_types_create_source() { - let conf = create_recursive_pb_parser_config( - "/proto_recursive/recursive.pb", - "recursive.AllTypes", - ) - .await; - - // Ensure that the parser can recognize the schema. - let columns = conf - .map_to_columns() - .unwrap() - .into_iter() - .map(|c| DataType::from(&c.column_type.unwrap())) - .collect_vec(); - assert_eq!( - columns, - vec![ - DataType::Float64, // double_field - DataType::Float32, // float_field - DataType::Int32, // int32_field - DataType::Int64, // int64_field - DataType::Int64, // uint32_field - DataType::Decimal, // uint64_field - DataType::Int32, // sint32_field - DataType::Int64, // sint64_field - DataType::Int64, // fixed32_field - DataType::Decimal, // fixed64_field - DataType::Int32, // sfixed32_field - DataType::Int64, // sfixed64_field - DataType::Boolean, // bool_field - DataType::Varchar, // string_field - DataType::Bytea, // bytes_field - DataType::Varchar, // enum_field - DataType::Struct(StructType::new(vec![ - ("id", DataType::Int32), - ("name", DataType::Varchar) - ])), // nested_message_field - DataType::List(DataType::Int32.into()), // repeated_int_field - DataType::Varchar, // oneof_string - DataType::Int32, // oneof_int32 - DataType::Varchar, // oneof_enum - DataType::Struct(StructType::new(vec![ - ("seconds", DataType::Int64), - ("nanos", DataType::Int32) - ])), // timestamp_field - DataType::Struct(StructType::new(vec![ - ("seconds", DataType::Int64), - ("nanos", DataType::Int32) - ])), // duration_field - DataType::Jsonb, // any_field - DataType::Struct(StructType::new(vec![("value", DataType::Int32)])), /* int32_value_field */ - DataType::Struct(StructType::new(vec![("value", DataType::Varchar)])), /* string_value_field */ - ] - ) - } - - #[tokio::test] - async fn test_all_types_data_parsing() { - let m = create_all_types_message(); - let mut payload = Vec::new(); - m.encode(&mut payload).unwrap(); - - let conf = create_recursive_pb_parser_config( - "/proto_recursive/recursive.pb", - "recursive.AllTypes", - ) - .await; - let mut access_builder = ProtobufAccessBuilder::new(conf).unwrap(); - let access = access_builder.generate_accessor(payload).await.unwrap(); - if let AccessImpl::Protobuf(a) = access { - assert_all_types_eq(&a, &m); - } else { - panic!("unexpected") - } - } - - fn assert_all_types_eq(a: &ProtobufAccess, m: &AllTypes) { - type S = ScalarImpl; - - pb_eq(a, "double_field", S::Float64(m.double_field.into())); - pb_eq(a, "float_field", S::Float32(m.float_field.into())); - pb_eq(a, "int32_field", S::Int32(m.int32_field)); - pb_eq(a, "int64_field", S::Int64(m.int64_field)); - pb_eq(a, "uint32_field", S::Int64(m.uint32_field.into())); - pb_eq(a, "uint64_field", S::Decimal(m.uint64_field.into())); - pb_eq(a, "sint32_field", S::Int32(m.sint32_field)); - pb_eq(a, "sint64_field", S::Int64(m.sint64_field)); - pb_eq(a, "fixed32_field", S::Int64(m.fixed32_field.into())); - pb_eq(a, "fixed64_field", S::Decimal(m.fixed64_field.into())); - pb_eq(a, "sfixed32_field", S::Int32(m.sfixed32_field)); - pb_eq(a, "sfixed64_field", S::Int64(m.sfixed64_field)); - pb_eq(a, "bool_field", S::Bool(m.bool_field)); - pb_eq(a, "string_field", S::Utf8(m.string_field.as_str().into())); - pb_eq(a, "bytes_field", S::Bytea(m.bytes_field.clone().into())); - pb_eq(a, "enum_field", S::Utf8("OPTION1".into())); - pb_eq( - a, - "nested_message_field", - S::Struct(StructValue::new(vec![ - Some(ScalarImpl::Int32(100)), - Some(ScalarImpl::Utf8("Nested".into())), - ])), - ); - pb_eq( - a, - "repeated_int_field", - S::List(ListValue::from_iter(m.repeated_int_field.clone())), - ); - pb_eq( - a, - "timestamp_field", - S::Struct(StructValue::new(vec![ - Some(ScalarImpl::Int64(1630927032)), - Some(ScalarImpl::Int32(500000000)), - ])), - ); - pb_eq( - a, - "duration_field", - S::Struct(StructValue::new(vec![ - Some(ScalarImpl::Int64(60)), - Some(ScalarImpl::Int32(500000000)), - ])), - ); - pb_eq( - a, - "int32_value_field", - S::Struct(StructValue::new(vec![Some(ScalarImpl::Int32(42))])), - ); - pb_eq( - a, - "string_value_field", - S::Struct(StructValue::new(vec![Some(ScalarImpl::Utf8( - m.string_value_field.as_ref().unwrap().as_str().into(), - ))])), - ); - pb_eq(a, "oneof_string", S::Utf8("".into())); - pb_eq(a, "oneof_int32", S::Int32(123)); - pb_eq(a, "oneof_enum", S::Utf8("DEFAULT".into())); - } - - fn pb_eq(a: &ProtobufAccess, field_name: &str, value: ScalarImpl) { - let field = a.descriptor().get_field_by_name(field_name).unwrap(); - let dummy_type = protobuf_type_mapping(&field, &mut vec![]).unwrap(); - let d = a.access_owned(&[field_name], &dummy_type).unwrap().unwrap(); - assert_eq!(d, value, "field: {} value: {:?}", field_name, d); - } - - fn create_all_types_message() -> AllTypes { - AllTypes { - double_field: 1.2345, - float_field: 1.2345, - int32_field: 42, - int64_field: 1234567890, - uint32_field: 98765, - uint64_field: 9876543210, - sint32_field: -12345, - sint64_field: -987654321, - fixed32_field: 1234, - fixed64_field: 5678, - sfixed32_field: -56789, - sfixed64_field: -123456, - bool_field: true, - string_field: "Hello, Prost!".to_string(), - bytes_field: b"byte data".to_vec(), - enum_field: EnumType::Option1 as i32, - nested_message_field: Some(NestedMessage { - id: 100, - name: "Nested".to_string(), - }), - repeated_int_field: vec![1, 2, 3, 4, 5], - timestamp_field: Some(::prost_types::Timestamp { - seconds: 1630927032, - nanos: 500000000, - }), - duration_field: Some(::prost_types::Duration { - seconds: 60, - nanos: 500000000, - }), - any_field: Some(::prost_types::Any { - type_url: "type.googleapis.com/my_custom_type".to_string(), - value: b"My custom data".to_vec(), - }), - int32_value_field: Some(42), - string_value_field: Some("Hello, Wrapper!".to_string()), - example_oneof: Some(ExampleOneof::OneofInt32(123)), - } - } - - // id: 12345 - // name { - // type_url: "type.googleapis.com/test.StringValue" - // value: "\n\010John Doe" - // } - static ANY_GEN_PROTO_DATA: &[u8] = b"\x08\xb9\x60\x12\x32\x0a\x24\x74\x79\x70\x65\x2e\x67\x6f\x6f\x67\x6c\x65\x61\x70\x69\x73\x2e\x63\x6f\x6d\x2f\x74\x65\x73\x74\x2e\x53\x74\x72\x69\x6e\x67\x56\x61\x6c\x75\x65\x12\x0a\x0a\x08\x4a\x6f\x68\x6e\x20\x44\x6f\x65"; - - #[tokio::test] - async fn test_any_schema() -> crate::error::ConnectorResult<()> { - let conf = create_recursive_pb_parser_config("/any-schema.pb", "test.TestAny").await; - - println!("Current conf: {:#?}", conf); - println!("---------------------------"); - - let message = - DynamicMessage::decode(conf.message_descriptor.clone(), ANY_GEN_PROTO_DATA).unwrap(); - - println!("Test ANY_GEN_PROTO_DATA, current value: {:#?}", message); - println!("---------------------------"); - - let field = conf - .message_descriptor - .get_field_by_name("any_value") - .unwrap(); - let value = message.get_field(&field); - - let ret = from_protobuf_value(&field, &value, &DataType::Jsonb) - .unwrap() - .to_owned_datum(); - println!("Decoded Value for ANY_GEN_PROTO_DATA: {:#?}", ret); - println!("---------------------------"); - - match ret { - Some(ScalarImpl::Jsonb(jv)) => { - assert_eq!( - jv, - JsonbVal::from(json!({ - "@type": "type.googleapis.com/test.StringValue", - "value": "John Doe" - })) - ); - } - _ => panic!("Expected ScalarImpl::Jsonb"), - } - - Ok(()) - } - - // id: 12345 - // name { - // type_url: "type.googleapis.com/test.Int32Value" - // value: "\010\322\376\006" - // } - // Unpacked Int32Value from Any: value: 114514 - static ANY_GEN_PROTO_DATA_1: &[u8] = b"\x08\xb9\x60\x12\x2b\x0a\x23\x74\x79\x70\x65\x2e\x67\x6f\x6f\x67\x6c\x65\x61\x70\x69\x73\x2e\x63\x6f\x6d\x2f\x74\x65\x73\x74\x2e\x49\x6e\x74\x33\x32\x56\x61\x6c\x75\x65\x12\x04\x08\xd2\xfe\x06"; - - #[tokio::test] - async fn test_any_schema_1() -> crate::error::ConnectorResult<()> { - let conf = create_recursive_pb_parser_config("/any-schema.pb", "test.TestAny").await; - - println!("Current conf: {:#?}", conf); - println!("---------------------------"); - - let message = - DynamicMessage::decode(conf.message_descriptor.clone(), ANY_GEN_PROTO_DATA_1).unwrap(); - - println!("Current Value: {:#?}", message); - println!("---------------------------"); - - let field = conf - .message_descriptor - .get_field_by_name("any_value") - .unwrap(); - let value = message.get_field(&field); - - let ret = from_protobuf_value(&field, &value, &DataType::Jsonb) - .unwrap() - .to_owned_datum(); - println!("Decoded Value for ANY_GEN_PROTO_DATA: {:#?}", ret); - println!("---------------------------"); - - match ret { - Some(ScalarImpl::Jsonb(jv)) => { - assert_eq!( - jv, - JsonbVal::from(json!({ - "@type": "type.googleapis.com/test.Int32Value", - "value": 114514 - })) - ); - } - _ => panic!("Expected ScalarImpl::Jsonb"), - } - - Ok(()) - } - - // "id": 12345, - // "any_value": { - // "type_url": "type.googleapis.com/test.AnyValue", - // "value": { - // "any_value_1": { - // "type_url": "type.googleapis.com/test.StringValue", - // "value": "114514" - // }, - // "any_value_2": { - // "type_url": "type.googleapis.com/test.Int32Value", - // "value": 114514 - // } - // } - // } - static ANY_RECURSIVE_GEN_PROTO_DATA: &[u8] = b"\x08\xb9\x60\x12\x84\x01\x0a\x21\x74\x79\x70\x65\x2e\x67\x6f\x6f\x67\x6c\x65\x61\x70\x69\x73\x2e\x63\x6f\x6d\x2f\x74\x65\x73\x74\x2e\x41\x6e\x79\x56\x61\x6c\x75\x65\x12\x5f\x0a\x30\x0a\x24\x74\x79\x70\x65\x2e\x67\x6f\x6f\x67\x6c\x65\x61\x70\x69\x73\x2e\x63\x6f\x6d\x2f\x74\x65\x73\x74\x2e\x53\x74\x72\x69\x6e\x67\x56\x61\x6c\x75\x65\x12\x08\x0a\x06\x31\x31\x34\x35\x31\x34\x12\x2b\x0a\x23\x74\x79\x70\x65\x2e\x67\x6f\x6f\x67\x6c\x65\x61\x70\x69\x73\x2e\x63\x6f\x6d\x2f\x74\x65\x73\x74\x2e\x49\x6e\x74\x33\x32\x56\x61\x6c\x75\x65\x12\x04\x08\xd2\xfe\x06"; - - #[tokio::test] - async fn test_any_recursive() -> crate::error::ConnectorResult<()> { - let conf = create_recursive_pb_parser_config("/any-schema.pb", "test.TestAny").await; - - println!("Current conf: {:#?}", conf); - println!("---------------------------"); - - let message = DynamicMessage::decode( - conf.message_descriptor.clone(), - ANY_RECURSIVE_GEN_PROTO_DATA, - ) - .unwrap(); - - println!("Current Value: {:#?}", message); - println!("---------------------------"); - - let field = conf - .message_descriptor - .get_field_by_name("any_value") - .unwrap(); - let value = message.get_field(&field); - - let ret = from_protobuf_value(&field, &value, &DataType::Jsonb) - .unwrap() - .to_owned_datum(); - println!("Decoded Value for ANY_RECURSIVE_GEN_PROTO_DATA: {:#?}", ret); - println!("---------------------------"); - - match ret { - Some(ScalarImpl::Jsonb(jv)) => { - assert_eq!( - jv, - JsonbVal::from(json!({ - "@type": "type.googleapis.com/test.AnyValue", - "anyValue1": { - "@type": "type.googleapis.com/test.StringValue", - "value": "114514", - }, - "anyValue2": { - "@type": "type.googleapis.com/test.Int32Value", - "value": 114514, - } - })) - ); - } - _ => panic!("Expected ScalarImpl::Jsonb"), - } - - Ok(()) - } - - // id: 12345 - // any_value: { - // type_url: "type.googleapis.com/test.StringXalue" - // value: "\n\010John Doe" - // } - static ANY_GEN_PROTO_DATA_INVALID: &[u8] = b"\x08\xb9\x60\x12\x32\x0a\x24\x74\x79\x70\x65\x2e\x67\x6f\x6f\x67\x6c\x65\x61\x70\x69\x73\x2e\x63\x6f\x6d\x2f\x74\x65\x73\x74\x2e\x53\x74\x72\x69\x6e\x67\x58\x61\x6c\x75\x65\x12\x0a\x0a\x08\x4a\x6f\x68\x6e\x20\x44\x6f\x65"; - - #[tokio::test] - async fn test_any_invalid() -> crate::error::ConnectorResult<()> { - let conf = create_recursive_pb_parser_config("/any-schema.pb", "test.TestAny").await; - - let message = - DynamicMessage::decode(conf.message_descriptor.clone(), ANY_GEN_PROTO_DATA_INVALID) - .unwrap(); - - let field = conf - .message_descriptor - .get_field_by_name("any_value") - .unwrap(); - let value = message.get_field(&field); - - let err = from_protobuf_value(&field, &value, &DataType::Jsonb).unwrap_err(); - - let expected = expect_test::expect![[r#" - Fail to convert protobuf Any into jsonb - - Caused by: - message 'test.StringXalue' not found - "#]]; - expected.assert_eq(err.to_report_string_pretty().as_str()); - - Ok(()) - } #[test] fn test_decode_varint_zigzag() { diff --git a/src/connector/src/parser/unified/mod.rs b/src/connector/src/parser/unified/mod.rs index fdfe3aae6aaee..adf32df572307 100644 --- a/src/connector/src/parser/unified/mod.rs +++ b/src/connector/src/parser/unified/mod.rs @@ -17,11 +17,11 @@ use auto_impl::auto_impl; use risingwave_common::types::{DataType, DatumCow}; use risingwave_connector_codec::decoder::avro::AvroAccess; -pub use risingwave_connector_codec::decoder::{uncategorized, Access, AccessError, AccessResult}; +use risingwave_connector_codec::decoder::protobuf::ProtobufAccess; +pub use risingwave_connector_codec::decoder::{Access, AccessError, AccessResult}; use self::bytes::BytesAccess; use self::json::JsonAccess; -use self::protobuf::ProtobufAccess; use crate::parser::unified::debezium::MongoJsonAccess; use crate::source::SourceColumnDesc; @@ -30,7 +30,6 @@ pub mod debezium; pub mod json; pub mod kv_event; pub mod maxwell; -pub mod protobuf; pub mod util; pub enum AccessImpl<'a> { diff --git a/src/connector/src/schema/protobuf.rs b/src/connector/src/schema/protobuf.rs index d140af83c853f..634d692066ac1 100644 --- a/src/connector/src/schema/protobuf.rs +++ b/src/connector/src/schema/protobuf.rs @@ -13,9 +13,10 @@ // limitations under the License. use std::collections::BTreeMap; +use std::path::PathBuf; -use itertools::Itertools as _; use prost_reflect::{DescriptorPool, FileDescriptor, MessageDescriptor}; +use risingwave_connector_codec::common::protobuf::compile_pb; use super::loader::{LoadedSchema, SchemaLoader}; use super::schema_registry::Subject; @@ -98,91 +99,29 @@ pub async fn fetch_from_registry( impl LoadedSchema for FileDescriptor { fn compile(primary: Subject, references: Vec) -> Result { let primary_name = primary.name.clone(); - match compile_pb(primary, references) { - Err(e) => Err(SchemaFetchError::SchemaCompile(e.into())), - Ok(b) => { - let pool = DescriptorPool::decode(b.as_slice()) - .map_err(|e| SchemaFetchError::SchemaCompile(e.into()))?; - pool.get_file_by_name(&primary_name).ok_or_else(|| { - SchemaFetchError::SchemaCompile( - anyhow::anyhow!("{primary_name} lost after compilation").into(), - ) - }) - } - } - } -} - -macro_rules! embed_wkts { - [$( $path:literal ),+ $(,)?] => { - &[$( - ( - concat!("google/protobuf/", $path), - include_bytes!(concat!(env!("PROTO_INCLUDE"), "/google/protobuf/", $path)).as_slice(), + let compiled_pb = compile_pb_subject(primary, references)?; + let pool = DescriptorPool::decode(compiled_pb.as_slice()) + .map_err(|e| SchemaFetchError::SchemaCompile(e.into()))?; + pool.get_file_by_name(&primary_name).ok_or_else(|| { + SchemaFetchError::SchemaCompile( + anyhow::anyhow!("{primary_name} lost after compilation").into(), ) - ),+] - }; -} -const WELL_KNOWN_TYPES: &[(&str, &[u8])] = embed_wkts![ - "any.proto", - "api.proto", - "compiler/plugin.proto", - "descriptor.proto", - "duration.proto", - "empty.proto", - "field_mask.proto", - "source_context.proto", - "struct.proto", - "timestamp.proto", - "type.proto", - "wrappers.proto", -]; - -#[derive(Debug, thiserror::Error)] -pub enum PbCompileError { - #[error("build_file_descriptor_set failed\n{}", errs.iter().map(|e| format!("\t{e}")).join("\n"))] - Build { - errs: Vec, - }, - #[error("serialize descriptor set failed")] - Serialize, + }) + } } -pub fn compile_pb( +fn compile_pb_subject( primary_subject: Subject, dependency_subjects: Vec, -) -> Result, PbCompileError> { - use std::iter; - use std::path::Path; - - use protobuf_native::compiler::{ - SimpleErrorCollector, SourceTreeDescriptorDatabase, VirtualSourceTree, - }; - use protobuf_native::MessageLite; - - let mut source_tree = VirtualSourceTree::new(); - for subject in iter::once(&primary_subject).chain(dependency_subjects.iter()) { - source_tree.as_mut().add_file( - Path::new(&subject.name), - subject.schema.content.as_bytes().to_vec(), - ); - } - for (path, bytes) in WELL_KNOWN_TYPES { - source_tree - .as_mut() - .add_file(Path::new(path), bytes.to_vec()); - } - - let mut error_collector = SimpleErrorCollector::new(); - // `db` needs to be dropped before we can iterate on `error_collector`. - let fds = { - let mut db = SourceTreeDescriptorDatabase::new(source_tree.as_mut()); - db.as_mut().record_errors_to(error_collector.as_mut()); - db.as_mut() - .build_file_descriptor_set(&[Path::new(&primary_subject.name)]) - } - .map_err(|_| PbCompileError::Build { - errs: error_collector.as_mut().collect(), - })?; - fds.serialize().map_err(|_| PbCompileError::Serialize) +) -> Result, SchemaFetchError> { + compile_pb( + ( + PathBuf::from(&primary_subject.name), + primary_subject.schema.content.as_bytes().to_vec(), + ), + dependency_subjects + .into_iter() + .map(|s| (PathBuf::from(&s.name), s.schema.content.as_bytes().to_vec())), + ) + .map_err(|e| SchemaFetchError::SchemaCompile(e.into())) } diff --git a/src/connector/src/sink/encoder/proto.rs b/src/connector/src/sink/encoder/proto.rs index 8046606b5690c..ce6e8503b624e 100644 --- a/src/connector/src/sink/encoder/proto.rs +++ b/src/connector/src/sink/encoder/proto.rs @@ -440,10 +440,10 @@ mod tests { #[test] fn test_encode_proto_ok() { let pool_path = std::path::Path::new(env!("CARGO_MANIFEST_DIR")) - .join("src/test_data/proto_recursive/recursive.pb"); + .join("codec/tests/test_data/all-types.pb"); let pool_bytes = std::fs::read(pool_path).unwrap(); let pool = prost_reflect::DescriptorPool::decode(pool_bytes.as_ref()).unwrap(); - let descriptor = pool.get_message_by_name("recursive.AllTypes").unwrap(); + let descriptor = pool.get_message_by_name("all_types.AllTypes").unwrap(); let schema = Schema::new(vec![ Field::with_name(DataType::Boolean, "bool_field"), Field::with_name(DataType::Varchar, "string_field"), @@ -495,7 +495,7 @@ mod tests { // Hint: write the binary output to a file `test.binpb`, and view it with `protoc`: // ``` // protoc --decode_raw < test.binpb - // protoc --decode=recursive.AllTypes recursive.proto < test.binpb + // protoc --decode=all_types.AllTypes all-types.proto < test.binpb // ``` [ 9, 0, 0, 0, 0, 0, 0, 17, 64, 21, 0, 0, 96, 64, 24, 22, 32, 23, 56, 48, 93, 26, 0, @@ -509,10 +509,10 @@ mod tests { #[test] fn test_encode_proto_repeated() { let pool_path = std::path::Path::new(env!("CARGO_MANIFEST_DIR")) - .join("src/test_data/proto_recursive/recursive.pb"); - let pool_bytes = std::fs::read(pool_path).unwrap(); + .join("codec/tests/test_data/all-types.pb"); + let pool_bytes = fs_err::read(pool_path).unwrap(); let pool = prost_reflect::DescriptorPool::decode(pool_bytes.as_ref()).unwrap(); - let message_descriptor = pool.get_message_by_name("recursive.AllTypes").unwrap(); + let message_descriptor = pool.get_message_by_name("all_types.AllTypes").unwrap(); let schema = Schema::new(vec![Field::with_name( DataType::List(DataType::List(DataType::Int32.into()).into()), @@ -561,10 +561,10 @@ mod tests { #[test] fn test_encode_proto_err() { let pool_path = std::path::Path::new(env!("CARGO_MANIFEST_DIR")) - .join("src/test_data/proto_recursive/recursive.pb"); + .join("codec/tests/test_data/all-types.pb"); let pool_bytes = std::fs::read(pool_path).unwrap(); let pool = prost_reflect::DescriptorPool::decode(pool_bytes.as_ref()).unwrap(); - let message_descriptor = pool.get_message_by_name("recursive.AllTypes").unwrap(); + let message_descriptor = pool.get_message_by_name("all_types.AllTypes").unwrap(); let err = validate_fields( std::iter::once(("not_exists", &DataType::Int16)), diff --git a/src/connector/src/test_data/any-schema.pb b/src/connector/src/test_data/any-schema.pb deleted file mode 100644 index 977f64cec3775..0000000000000 --- a/src/connector/src/test_data/any-schema.pb +++ /dev/null @@ -1,30 +0,0 @@ - -ä -google/protobuf/any.protogoogle.protobuf"6 -Any -type_url ( RtypeUrl -value ( RvalueBv -com.google.protobufBAnyProtoPZ,google.golang.org/protobuf/types/known/anypb¢GPBªGoogle.Protobuf.WellKnownTypesbproto3 -á -any-schema.prototestgoogle/protobuf/any.proto"L -TestAny -id (Rid1 - any_value ( 2.google.protobuf.AnyRanyValue"# - StringValue -value ( Rvalue"" - -Int32Value -value (Rvalue"v -AnyValue4 - any_value_1 ( 2.google.protobuf.AnyR anyValue14 - any_value_2 ( 2.google.protobuf.AnyR anyValue2"@ -StringInt32Value -first ( Rfirst -second (Rsecond"Ž -StringStringInt32Value -first ( Rfirst. -second ( 2.test.StringInt32ValueRsecond. -third ( 2.test.Float32StringValueRthird"B -Float32StringValue -first (Rfirst -second ( Rsecondbproto3 \ No newline at end of file diff --git a/src/connector/src/test_data/complex-schema b/src/connector/src/test_data/complex-schema deleted file mode 100644 index ff7cd64120883945d3ad63fff87d21e3adf126ed..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 408 zcmYL_&q~8U5XMb*6OyS$8AJjqlpst;)XNfWw3@*=*6kLYav z)Z5H&z8Su6;D^Dvs>-_1?VV9uUCOlHS51Yy(WXg$9pLXq?=b>5&thll%sfFAB5+k@ zI(CVKqO0)=!X__iO_p`cXu!ljz-!>`Mw*yU*=({1Q+q%C*htY~oI{|hT_aUvVvlww zi>Q}84OATFcnQqtHugytjhEhZD=bYEFuIDjaGL4wTSC@%v06#jleG2u5Fc0Y1GgFATkh(KCk5@_JkCw4sX~@4GpAq&yEnle_nRA`{{SzzyAaOfNSyq diff --git a/src/connector/src/test_data/proto_recursive/recursive.proto b/src/connector/src/test_data/proto_recursive/recursive.proto deleted file mode 100644 index 93f177055788c..0000000000000 --- a/src/connector/src/test_data/proto_recursive/recursive.proto +++ /dev/null @@ -1,95 +0,0 @@ -syntax = "proto3"; - -import "google/protobuf/timestamp.proto"; -import "google/protobuf/duration.proto"; -import "google/protobuf/any.proto"; -import "google/protobuf/wrappers.proto"; - -package recursive; - -message ComplexRecursiveMessage { - string node_name = 1; - int32 node_id = 2; - - message Attributes { - string key = 1; - string value = 2; - } - - repeated Attributes attributes = 3; - - message Parent { - string parent_name = 1; - int32 parent_id = 2; - repeated ComplexRecursiveMessage siblings = 3; - } - - Parent parent = 4; - repeated ComplexRecursiveMessage children = 5; -} - -message AllTypes { - // standard types - double double_field = 1; - float float_field = 2; - int32 int32_field = 3; - int64 int64_field = 4; - uint32 uint32_field = 5; - uint64 uint64_field = 6; - sint32 sint32_field = 7; - sint64 sint64_field = 8; - fixed32 fixed32_field = 9; - fixed64 fixed64_field = 10; - sfixed32 sfixed32_field = 11; - sfixed64 sfixed64_field = 12; - bool bool_field = 13; - string string_field = 14; - - bytes bytes_field = 15; - - // enum - enum EnumType { - DEFAULT = 0; - OPTION1 = 1; - OPTION2 = 2; - } - EnumType enum_field = 16; - - // nested message - message NestedMessage { - int32 id = 1; - string name = 2; - } - NestedMessage nested_message_field = 17; - - // repeated field - repeated int32 repeated_int_field = 18; - - // oneof field - oneof example_oneof { - string oneof_string = 19; - int32 oneof_int32 = 20; - EnumType oneof_enum = 21; - } - - // // map field - // map map_field = 22; - - // timestamp - google.protobuf.Timestamp timestamp_field = 23; - - // duration - google.protobuf.Duration duration_field = 24; - - // any - google.protobuf.Any any_field = 25; - - // -- Unsupported - // // struct - // import "google/protobuf/struct.proto"; - // google.protobuf.Struct struct_field = 26; - - // wrapper types - google.protobuf.Int32Value int32_value_field = 27; - google.protobuf.StringValue string_value_field = 28; -} \ No newline at end of file diff --git a/src/connector/src/test_data/simple-schema b/src/connector/src/test_data/simple-schema deleted file mode 100644 index 97686ce9c478d..0000000000000 --- a/src/connector/src/test_data/simple-schema +++ /dev/null @@ -1,11 +0,0 @@ - -² -simple-schema.prototest"Œ - -TestRecord -id (Rid -address ( Raddress -city ( Rcity -zipcode (Rzipcode -rate (Rrate -date ( Rdatebproto3 \ No newline at end of file diff --git a/src/tests/simulation/src/slt.rs b/src/tests/simulation/src/slt.rs index 799602a00aa3f..7ac5a7b27d70b 100644 --- a/src/tests/simulation/src/slt.rs +++ b/src/tests/simulation/src/slt.rs @@ -497,8 +497,6 @@ fn hack_kafka_test(path: &Path) -> tempfile::NamedTempFile { let complex_avsc_full_path = std::fs::canonicalize("src/connector/src/test_data/complex-schema.avsc") .expect("failed to get schema path"); - let proto_full_path = std::fs::canonicalize("src/connector/src/test_data/complex-schema") - .expect("failed to get schema path"); let json_schema_full_path = std::fs::canonicalize("src/connector/src/test_data/complex-schema.json") .expect("failed to get schema path"); @@ -513,10 +511,6 @@ fn hack_kafka_test(path: &Path) -> tempfile::NamedTempFile { "/risingwave/avro-complex-schema.avsc", complex_avsc_full_path.to_str().unwrap(), ) - .replace( - "/risingwave/proto-complex-schema", - proto_full_path.to_str().unwrap(), - ) .replace( "/risingwave/json-complex-schema", json_schema_full_path.to_str().unwrap(),