diff --git a/Cargo.lock b/Cargo.lock index 11fa66aef18f4..ec551ead138da 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -327,7 +327,26 @@ dependencies = [ "arrow-row 48.0.1", "arrow-schema 48.0.1", "arrow-select 48.0.1", - "arrow-string", + "arrow-string 48.0.1", +] + +[[package]] +name = "arrow" +version = "50.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "aa285343fba4d829d49985bdc541e3789cf6000ed0e84be7c039438df4a4e78c" +dependencies = [ + "arrow-arith 50.0.0", + "arrow-array 50.0.0", + "arrow-buffer 50.0.0", + "arrow-cast 50.0.0", + "arrow-data 50.0.0", + "arrow-ipc 50.0.0", + "arrow-ord 50.0.0", + "arrow-row 50.0.0", + "arrow-schema 50.0.0", + "arrow-select 50.0.0", + "arrow-string 50.0.0", ] [[package]] @@ -681,6 +700,22 @@ dependencies = [ "regex-syntax 0.8.2", ] +[[package]] +name = "arrow-string" +version = "50.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "00f3b37f2aeece31a2636d1b037dabb69ef590e03bdc7eb68519b51ec86932a7" +dependencies = [ + "arrow-array 50.0.0", + "arrow-buffer 50.0.0", + "arrow-data 50.0.0", + "arrow-schema 50.0.0", + "arrow-select 50.0.0", + "num", + "regex", + "regex-syntax 0.8.2", +] + [[package]] name = "arrow-udf-js" version = "0.1.2" @@ -1737,6 +1772,7 @@ dependencies = [ "num-bigint", "num-integer", "num-traits", + "serde", ] [[package]] @@ -3286,7 +3322,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "676796427e638d85e9eadf13765705212be60b34f8fc5d3934d95184c63ca1b4" dependencies = [ "ahash 0.8.6", - "arrow", + "arrow 48.0.1", "arrow-array 48.0.1", "arrow-schema 48.0.1", "async-compression", @@ -3333,7 +3369,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "31e23b3d21a6531259d291bd20ce59282ea794bda1018b0a1e278c13cd52e50c" dependencies = [ "ahash 0.8.6", - "arrow", + "arrow 48.0.1", "arrow-array 48.0.1", "arrow-buffer 48.0.1", "arrow-schema 48.0.1", @@ -3351,7 +3387,7 @@ version = "33.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4de1fd0d8db0f2b8e4f4121bfa1c7c09d3a5c08a0a65c2229cd849eb65cff855" dependencies = [ - "arrow", + "arrow 48.0.1", "chrono", "dashmap", "datafusion-common", @@ -3373,7 +3409,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "18e227fe88bf6730cab378d0cd8fc4c6b2ea42bc7e414a8ea9feba7225932735" dependencies = [ "ahash 0.8.6", - "arrow", + "arrow 48.0.1", "arrow-array 48.0.1", "datafusion-common", "sqlparser", @@ -3387,7 +3423,7 @@ version = "33.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5c6648e62ea7605b9bfcd87fdc9d67e579c3b9ac563a87734ae5fe6d79ee4547" dependencies = [ - "arrow", + "arrow 48.0.1", "async-trait", "chrono", "datafusion-common", @@ -3406,7 +3442,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f32b8574add16a32411a9b3fb3844ac1fc09ab4e7be289f86fd56d620e4f2508" dependencies = [ "ahash 0.8.6", - "arrow", + "arrow 48.0.1", "arrow-array 48.0.1", "arrow-buffer 48.0.1", "arrow-ord 48.0.1", @@ -3441,7 +3477,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "796abd77d5bfecd9e5275a99daf0ec45f5b3a793ec431349ce8211a67826fd22" dependencies = [ "ahash 0.8.6", - "arrow", + "arrow 48.0.1", "arrow-array 48.0.1", "arrow-buffer 48.0.1", "arrow-schema 48.0.1", @@ -3471,7 +3507,7 @@ version = "33.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "26de2592417beb20f73f29b131a04d7de14e2a6336c631554d611584b4306236" 
dependencies = [ - "arrow", + "arrow 48.0.1", "chrono", "datafusion", "datafusion-common", @@ -3486,7 +3522,7 @@ version = "33.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ced70b8a5648ba7b95c61fc512183c33287ffe2c9f22ffe22700619d7d48c84f" dependencies = [ - "arrow", + "arrow 48.0.1", "arrow-schema 48.0.1", "datafusion-common", "datafusion-expr", @@ -3543,7 +3579,7 @@ name = "deltalake-core" version = "0.17.0" source = "git+https://github.com/risingwavelabs/delta-rs?rev=5c2dccd4640490202ffe98adbd13b09cef8e007b#5c2dccd4640490202ffe98adbd13b09cef8e007b" dependencies = [ - "arrow", + "arrow 48.0.1", "arrow-array 48.0.1", "arrow-buffer 48.0.1", "arrow-cast 48.0.1", @@ -5492,6 +5528,28 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "google-cloud-auth" +version = "0.13.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3bf7cb7864f08a92e77c26bb230d021ea57691788fb5dd51793f96965d19e7f9" +dependencies = [ + "async-trait", + "base64 0.21.7", + "google-cloud-metadata", + "google-cloud-token", + "home", + "jsonwebtoken", + "reqwest 0.11.20", + "serde", + "serde_json", + "thiserror", + "time", + "tokio", + "tracing", + "urlencoding", +] + [[package]] name = "google-cloud-auth" version = "0.14.0" @@ -5514,6 +5572,33 @@ dependencies = [ "urlencoding", ] +[[package]] +name = "google-cloud-bigquery" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c48abc8687f4c4cc143dd5bd3da5f1d7ef38334e4af5cef6de4c39295c6a3fd0" +dependencies = [ + "anyhow", + "arrow 50.0.0", + "async-trait", + "backon", + "base64 0.21.7", + "bigdecimal 0.4.2", + "google-cloud-auth 0.13.2", + "google-cloud-gax", + "google-cloud-googleapis", + "google-cloud-token", + "num-bigint", + "reqwest 0.11.20", + "reqwest-middleware", + "serde", + "serde_json", + "thiserror", + "time", + "tokio", + "tracing", +] + [[package]] name = "google-cloud-gax" version = "0.17.0" @@ -5560,7 +5645,7 @@ checksum = "0b2184a5c70b994e6d77eb1c140e193e7f5fe6015e9115322fac24f7e33f003c" dependencies = [ "async-channel", "async-stream", - "google-cloud-auth", + "google-cloud-auth 0.14.0", "google-cloud-gax", "google-cloud-googleapis", "google-cloud-token", @@ -9691,6 +9776,7 @@ dependencies = [ "js-sys", "log", "mime", + "mime_guess", "native-tls", "once_cell", "percent-encoding", @@ -9767,6 +9853,21 @@ dependencies = [ "winreg", ] +[[package]] +name = "reqwest-middleware" +version = "0.2.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a735987236a8e238bf0296c7e351b999c188ccc11477f311b82b55c93984216" +dependencies = [ + "anyhow", + "async-trait", + "http 0.2.9", + "reqwest 0.11.20", + "serde", + "task-local-extensions", + "thiserror", +] + [[package]] name = "resolv-conf" version = "0.7.0" @@ -10407,6 +10508,9 @@ dependencies = [ "futures-async-stream", "gcp-bigquery-client", "glob", + "google-cloud-bigquery", + "google-cloud-gax", + "google-cloud-googleapis", "google-cloud-pubsub", "http 0.2.9", "icelake", @@ -13969,6 +14073,15 @@ version = "0.12.14" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e1fc403891a21bcfb7c37834ba66a547a8f402146eba7265b5a6d88059c9ff2f" +[[package]] +name = "task-local-extensions" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ba323866e5d033818e3240feeb9f7db2c4296674e4d9e16b97b7bf8f490434e8" +dependencies = [ + "pin-utils", +] + [[package]] name = "tempfile" version = "3.10.0" diff --git 
a/src/connector/Cargo.toml b/src/connector/Cargo.toml index 66c33335501d5..be5242e819163 100644 --- a/src/connector/Cargo.toml +++ b/src/connector/Cargo.toml @@ -58,6 +58,9 @@ futures = { version = "0.3", default-features = false, features = ["alloc"] } futures-async-stream = { workspace = true } gcp-bigquery-client = "0.18.0" glob = "0.3" +google-cloud-bigquery = { version = "0.7.0", features = ["auth"] } +google-cloud-gax = "0.17.0" +google-cloud-googleapis = "0.12.0" google-cloud-pubsub = "0.24" http = "0.2" icelake = { workspace = true } diff --git a/src/connector/src/lib.rs b/src/connector/src/lib.rs index f28f3cdce5e77..70de9f8561a76 100644 --- a/src/connector/src/lib.rs +++ b/src/connector/src/lib.rs @@ -34,6 +34,7 @@ #![feature(error_generic_member_access)] #![feature(negative_impls)] #![feature(register_tool)] +#![feature(assert_matches)] #![register_tool(rw)] #![recursion_limit = "256"] diff --git a/src/connector/src/sink/big_query.rs b/src/connector/src/sink/big_query.rs index e7f74614c9cf0..ee385ad8c010e 100644 --- a/src/connector/src/sink/big_query.rs +++ b/src/connector/src/sink/big_query.rs @@ -13,27 +13,43 @@ // limitations under the License. use core::mem; +use core::time::Duration; use std::collections::HashMap; use std::sync::Arc; use anyhow::anyhow; use async_trait::async_trait; use gcp_bigquery_client::model::query_request::QueryRequest; -use gcp_bigquery_client::model::table_data_insert_all_request::TableDataInsertAllRequest; -use gcp_bigquery_client::model::table_data_insert_all_request_rows::TableDataInsertAllRequestRows; use gcp_bigquery_client::Client; +use google_cloud_bigquery::grpc::apiv1::bigquery_client::StreamingWriteClient; +use google_cloud_bigquery::grpc::apiv1::conn_pool::{WriteConnectionManager, DOMAIN}; +use google_cloud_gax::conn::{ConnectionOptions, Environment}; +use google_cloud_gax::grpc::Request; +use google_cloud_googleapis::cloud::bigquery::storage::v1::append_rows_request::{ + ProtoData, Rows as AppendRowsRequestRows, +}; +use google_cloud_googleapis::cloud::bigquery::storage::v1::{ + AppendRowsRequest, ProtoRows, ProtoSchema, +}; +use google_cloud_pubsub::client::google_cloud_auth; +use google_cloud_pubsub::client::google_cloud_auth::credentials::CredentialsFile; +use prost_reflect::{FieldDescriptor, MessageDescriptor}; +use prost_types::{ + field_descriptor_proto, DescriptorProto, FieldDescriptorProto, FileDescriptorProto, + FileDescriptorSet, +}; use risingwave_common::array::{Op, StreamChunk}; use risingwave_common::buffer::Bitmap; use risingwave_common::catalog::Schema; use risingwave_common::types::DataType; use serde_derive::Deserialize; -use serde_json::Value; use serde_with::{serde_as, DisplayFromStr}; use url::Url; +use uuid::Uuid; use with_options::WithOptions; use yup_oauth2::ServiceAccountKey; -use super::encoder::{JsonEncoder, RowEncoder}; +use super::encoder::{ProtoEncoder, ProtoHeader, RowEncoder, SerTo}; use super::writer::LogSinkerOf; use super::{SinkError, SINK_TYPE_APPEND_ONLY, SINK_TYPE_OPTION, SINK_TYPE_UPSERT}; use crate::aws_utils::load_file_descriptor_from_s3; @@ -44,6 +60,10 @@ use crate::sink::{ }; pub const BIGQUERY_SINK: &str = "bigquery"; +pub const CHANGE_TYPE: &str = "_CHANGE_TYPE"; +const DEFAULT_GRPC_CHANNEL_NUMS: usize = 4; +const CONNECT_TIMEOUT: Option = Some(Duration::from_secs(30)); +const CONNECTION_TIMEOUT: Option = None; #[serde_as] #[derive(Deserialize, Debug, Clone, WithOptions)] @@ -68,27 +88,44 @@ fn default_max_batch_rows() -> usize { } impl BigQueryCommon { - pub(crate) async fn 
build_client(&self, aws_auth_props: &AwsAuthProps) -> Result { - let service_account = if let Some(local_path) = &self.local_path { - let auth_json = std::fs::read_to_string(local_path) - .map_err(|err| SinkError::BigQuery(anyhow::anyhow!(err)))?; - serde_json::from_str::(&auth_json) - .map_err(|err| SinkError::BigQuery(anyhow::anyhow!(err)))? + async fn build_client(&self, aws_auth_props: &AwsAuthProps) -> Result { + let auth_json = self.get_auth_json_from_path(aws_auth_props).await?; + + let service_account = serde_json::from_str::(&auth_json) + .map_err(|err| SinkError::BigQuery(anyhow::anyhow!(err)))?; + let client: Client = Client::from_service_account_key(service_account, false) + .await + .map_err(|err| SinkError::BigQuery(anyhow::anyhow!(err)))?; + Ok(client) + } + + async fn build_writer_client( + &self, + aws_auth_props: &AwsAuthProps, + ) -> Result { + let auth_json = self.get_auth_json_from_path(aws_auth_props).await?; + + let credentials_file = CredentialsFile::new_from_str(&auth_json) + .await + .map_err(|e| SinkError::BigQuery(e.into()))?; + let client = StorageWriterClient::new(credentials_file).await?; + Ok(client) + } + + async fn get_auth_json_from_path(&self, aws_auth_props: &AwsAuthProps) -> Result { + if let Some(local_path) = &self.local_path { + std::fs::read_to_string(local_path) + .map_err(|err| SinkError::BigQuery(anyhow::anyhow!(err))) } else if let Some(s3_path) = &self.s3_path { let url = Url::parse(s3_path).map_err(|err| SinkError::BigQuery(anyhow::anyhow!(err)))?; - let auth_json = load_file_descriptor_from_s3(&url, aws_auth_props) + let auth_vec = load_file_descriptor_from_s3(&url, aws_auth_props) .await .map_err(|err| SinkError::BigQuery(anyhow::anyhow!(err)))?; - serde_json::from_slice::(&auth_json) - .map_err(|err| SinkError::BigQuery(anyhow::anyhow!(err)))? + Ok(String::from_utf8(auth_vec).map_err(|e| SinkError::BigQuery(e.into()))?) 
} else { - return Err(SinkError::BigQuery(anyhow::anyhow!("`bigquery.local.path` and `bigquery.s3.path` set at least one, configure as needed."))); - }; - let client: Client = Client::from_service_account_key(service_account, false) - .await - .map_err(|err| SinkError::BigQuery(anyhow::anyhow!(err)))?; - Ok(client) + Err(SinkError::BigQuery(anyhow::anyhow!("`bigquery.local.path` and `bigquery.s3.path` set at least one, configure as needed."))) + } } } @@ -187,9 +224,7 @@ impl BigQuerySink { DataType::Decimal => Ok("NUMERIC".to_owned()), DataType::Date => Ok("DATE".to_owned()), DataType::Varchar => Ok("STRING".to_owned()), - DataType::Time => Err(SinkError::BigQuery(anyhow::anyhow!( - "Bigquery cannot support Time" - ))), + DataType::Time => Ok("TIME".to_owned()), DataType::Timestamp => Ok("DATETIME".to_owned()), DataType::Timestamptz => Ok("TIMESTAMP".to_owned()), DataType::Interval => Ok("INTERVAL".to_owned()), @@ -234,12 +269,10 @@ impl Sink for BigQuerySink { } async fn validate(&self) -> Result<()> { - if !self.is_append_only { + if !self.is_append_only && self.pk_indices.is_empty() { return Err(SinkError::Config(anyhow!( - "BigQuery sink don't support upsert" - ))); + "Primary key not defined for upsert bigquery sink (please define in `primary_key` field)"))); } - let client = self .config .common @@ -280,10 +313,15 @@ pub struct BigQuerySinkWriter { pub config: BigQueryConfig, schema: Schema, pk_indices: Vec, - client: Client, + client: StorageWriterClient, is_append_only: bool, - insert_request: TableDataInsertAllRequest, - row_encoder: JsonEncoder, + row_encoder: ProtoEncoder, + writer_pb_schema: ProtoSchema, + message_descriptor: MessageDescriptor, + write_stream: String, + proto_field: Option, + write_rows: Vec, + write_rows_count: usize, } impl TryFrom for BigQuerySink { @@ -308,66 +346,131 @@ impl BigQuerySinkWriter { pk_indices: Vec, is_append_only: bool, ) -> Result { - let client = config.common.build_client(&config.aws_auth_props).await?; + let client = config + .common + .build_writer_client(&config.aws_auth_props) + .await?; + let mut descriptor_proto = build_protobuf_schema( + schema + .fields() + .iter() + .map(|f| (f.name.as_str(), &f.data_type)), + config.common.table.clone(), + )?; + + if !is_append_only { + let field = FieldDescriptorProto { + name: Some(CHANGE_TYPE.to_string()), + number: Some((schema.len() + 1) as i32), + r#type: Some(field_descriptor_proto::Type::String.into()), + ..Default::default() + }; + descriptor_proto.field.push(field); + } + + let descriptor_pool = build_protobuf_descriptor_pool(&descriptor_proto); + let message_descriptor = descriptor_pool + .get_message_by_name(&config.common.table) + .ok_or_else(|| { + SinkError::BigQuery(anyhow::anyhow!( + "Can't find message proto {}", + &config.common.table + )) + })?; + let proto_field = if !is_append_only { + let proto_field = message_descriptor + .get_field_by_name(CHANGE_TYPE) + .ok_or_else(|| { + SinkError::BigQuery(anyhow::anyhow!("Can't find {}", CHANGE_TYPE)) + })?; + Some(proto_field) + } else { + None + }; + let row_encoder = ProtoEncoder::new( + schema.clone(), + None, + message_descriptor.clone(), + ProtoHeader::None, + )?; Ok(Self { + write_stream: format!( + "projects/{}/datasets/{}/tables/{}/streams/_default", + config.common.project, config.common.dataset, config.common.table + ), config, - schema: schema.clone(), + schema, pk_indices, client, is_append_only, - insert_request: TableDataInsertAllRequest::new(), - row_encoder: JsonEncoder::new_with_bigquery(schema, None), + 
row_encoder, + message_descriptor, + proto_field, + writer_pb_schema: ProtoSchema { + proto_descriptor: Some(descriptor_proto), + }, + write_rows: vec![], + write_rows_count: 0, }) } - async fn append_only(&mut self, chunk: StreamChunk) -> Result<()> { - let mut insert_vec = Vec::with_capacity(chunk.capacity()); + fn append_only(&mut self, chunk: StreamChunk) -> Result>> { + let mut serialized_rows: Vec> = Vec::with_capacity(chunk.capacity()); for (op, row) in chunk.rows() { if op != Op::Insert { - return Err(SinkError::BigQuery(anyhow::anyhow!( - "BigQuery sink don't support upsert" - ))); + continue; } - insert_vec.push(TableDataInsertAllRequestRows { - insert_id: None, - json: Value::Object(self.row_encoder.encode(row)?), - }) + serialized_rows.push(self.row_encoder.encode(row)?.ser_to()?) } - self.insert_request - .add_rows(insert_vec) - .map_err(|e| SinkError::BigQuery(e.into()))?; - if self - .insert_request - .len() - .ge(&self.config.common.max_batch_rows) - { - self.insert_data().await?; - } - Ok(()) + Ok(serialized_rows) } - async fn insert_data(&mut self) -> Result<()> { - if !self.insert_request.is_empty() { - let insert_request = - mem::replace(&mut self.insert_request, TableDataInsertAllRequest::new()); - let request = self - .client - .tabledata() - .insert_all( - &self.config.common.project, - &self.config.common.dataset, - &self.config.common.table, - insert_request, - ) - .await - .map_err(|e| SinkError::BigQuery(e.into()))?; - if let Some(error) = request.insert_errors { - return Err(SinkError::BigQuery(anyhow::anyhow!( - "Insert error: {:?}", - error - ))); + fn upsert(&mut self, chunk: StreamChunk) -> Result>> { + let mut serialized_rows: Vec> = Vec::with_capacity(chunk.capacity()); + for (op, row) in chunk.rows() { + if op == Op::UpdateDelete { + continue; } + let mut pb_row = self.row_encoder.encode(row)?; + match op { + Op::Insert => pb_row + .message + .try_set_field( + self.proto_field.as_ref().unwrap(), + prost_reflect::Value::String("UPSERT".to_string()), + ) + .map_err(|e| SinkError::BigQuery(e.into()))?, + Op::Delete => pb_row + .message + .try_set_field( + self.proto_field.as_ref().unwrap(), + prost_reflect::Value::String("DELETE".to_string()), + ) + .map_err(|e| SinkError::BigQuery(e.into()))?, + Op::UpdateDelete => continue, + Op::UpdateInsert => pb_row + .message + .try_set_field( + self.proto_field.as_ref().unwrap(), + prost_reflect::Value::String("UPSERT".to_string()), + ) + .map_err(|e| SinkError::BigQuery(e.into()))?, + }; + + serialized_rows.push(pb_row.ser_to()?) } + Ok(serialized_rows) + } + + async fn write_rows(&mut self) -> Result<()> { + if self.write_rows.is_empty() { + return Ok(()); + } + let rows = mem::take(&mut self.write_rows); + self.write_rows_count = 0; + self.client + .append_rows(rows, self.write_stream.clone()) + .await?; Ok(()) } } @@ -375,13 +478,24 @@ impl BigQuerySinkWriter { #[async_trait] impl SinkWriter for BigQuerySinkWriter { async fn write_batch(&mut self, chunk: StreamChunk) -> Result<()> { - if self.is_append_only { - self.append_only(chunk).await + let serialized_rows = if self.is_append_only { + self.append_only(chunk)? } else { - Err(SinkError::BigQuery(anyhow::anyhow!( - "BigQuery sink don't support upsert" - ))) + self.upsert(chunk)? 
+ }; + if !serialized_rows.is_empty() { + self.write_rows_count += serialized_rows.len(); + let rows = AppendRowsRequestRows::ProtoRows(ProtoData { + writer_schema: Some(self.writer_pb_schema.clone()), + rows: Some(ProtoRows { serialized_rows }), + }); + self.write_rows.push(rows); + + if self.write_rows_count >= self.config.common.max_batch_rows { + self.write_rows().await?; + } } + Ok(()) } async fn begin_epoch(&mut self, _epoch: u64) -> Result<()> { @@ -392,8 +506,11 @@ impl SinkWriter for BigQuerySinkWriter { Ok(()) } - async fn barrier(&mut self, _is_checkpoint: bool) -> Result<()> { - self.insert_data().await + async fn barrier(&mut self, is_checkpoint: bool) -> Result<()> { + if is_checkpoint { + self.write_rows().await?; + } + Ok(()) } async fn update_vnode_bitmap(&mut self, _vnode_bitmap: Arc) -> Result<()> { @@ -401,11 +518,183 @@ impl SinkWriter for BigQuerySinkWriter { } } +struct StorageWriterClient { + client: StreamingWriteClient, + environment: Environment, +} +impl StorageWriterClient { + pub async fn new(credentials: CredentialsFile) -> Result { + let ts_grpc = google_cloud_auth::token::DefaultTokenSourceProvider::new_with_credentials( + Self::bigquery_grpc_auth_config(), + Box::new(credentials), + ) + .await + .map_err(|e| SinkError::BigQuery(e.into()))?; + let conn_options = ConnectionOptions { + connect_timeout: CONNECT_TIMEOUT, + timeout: CONNECTION_TIMEOUT, + }; + let environment = Environment::GoogleCloud(Box::new(ts_grpc)); + let conn = WriteConnectionManager::new( + DEFAULT_GRPC_CHANNEL_NUMS, + &environment, + DOMAIN, + &conn_options, + ) + .await + .map_err(|e| SinkError::BigQuery(e.into()))?; + let client = conn.conn(); + Ok(StorageWriterClient { + client, + environment, + }) + } + + pub async fn append_rows( + &mut self, + rows: Vec, + write_stream: String, + ) -> Result<()> { + let mut resp_count = rows.len(); + let append_req: Vec = rows + .into_iter() + .map(|row| AppendRowsRequest { + write_stream: write_stream.clone(), + offset: None, + trace_id: Uuid::new_v4().hyphenated().to_string(), + missing_value_interpretations: HashMap::default(), + rows: Some(row), + }) + .collect(); + let mut resp = self + .client + .append_rows(Request::new(tokio_stream::iter(append_req))) + .await + .map_err(|e| SinkError::BigQuery(e.into()))? + .into_inner(); + while let Some(append_rows_response) = resp + .message() + .await + .map_err(|e| SinkError::BigQuery(e.into()))? 
+ { + resp_count -= 1; + if !append_rows_response.row_errors.is_empty() { + return Err(SinkError::BigQuery(anyhow::anyhow!( + "Insert error {:?}", + append_rows_response.row_errors + ))); + } + } + assert_eq!(resp_count,0,"bigquery sink insert error: the number of response inserted is not equal to the number of request"); + Ok(()) + } + + fn bigquery_grpc_auth_config() -> google_cloud_auth::project::Config<'static> { + google_cloud_auth::project::Config { + audience: Some(google_cloud_bigquery::grpc::apiv1::conn_pool::AUDIENCE), + scopes: Some(&google_cloud_bigquery::grpc::apiv1::conn_pool::SCOPES), + sub: None, + } + } +} + +fn build_protobuf_descriptor_pool(desc: &DescriptorProto) -> prost_reflect::DescriptorPool { + let file_descriptor = FileDescriptorProto { + message_type: vec![desc.clone()], + name: Some("bigquery".to_string()), + ..Default::default() + }; + + prost_reflect::DescriptorPool::from_file_descriptor_set(FileDescriptorSet { + file: vec![file_descriptor], + }) + .unwrap() +} + +fn build_protobuf_schema<'a>( + fields: impl Iterator, + name: String, +) -> Result { + let mut proto = DescriptorProto { + name: Some(name), + ..Default::default() + }; + let mut struct_vec = vec![]; + let field_vec = fields + .enumerate() + .map(|(index, (name, data_type))| { + let (field, des_proto) = + build_protobuf_field(data_type, (index + 1) as i32, name.to_string())?; + if let Some(sv) = des_proto { + struct_vec.push(sv); + } + Ok(field) + }) + .collect::>>()?; + proto.field = field_vec; + proto.nested_type = struct_vec; + Ok(proto) +} + +fn build_protobuf_field( + data_type: &DataType, + index: i32, + name: String, +) -> Result<(FieldDescriptorProto, Option)> { + let mut field = FieldDescriptorProto { + name: Some(name.clone()), + number: Some(index), + ..Default::default() + }; + match data_type { + DataType::Boolean => field.r#type = Some(field_descriptor_proto::Type::Bool.into()), + DataType::Int32 => field.r#type = Some(field_descriptor_proto::Type::Int32.into()), + DataType::Int16 | DataType::Int64 => { + field.r#type = Some(field_descriptor_proto::Type::Int64.into()) + } + DataType::Float64 => field.r#type = Some(field_descriptor_proto::Type::Double.into()), + DataType::Decimal => field.r#type = Some(field_descriptor_proto::Type::String.into()), + DataType::Date => field.r#type = Some(field_descriptor_proto::Type::Int32.into()), + DataType::Varchar => field.r#type = Some(field_descriptor_proto::Type::String.into()), + DataType::Time => field.r#type = Some(field_descriptor_proto::Type::String.into()), + DataType::Timestamp => field.r#type = Some(field_descriptor_proto::Type::String.into()), + DataType::Timestamptz => field.r#type = Some(field_descriptor_proto::Type::String.into()), + DataType::Interval => field.r#type = Some(field_descriptor_proto::Type::String.into()), + DataType::Struct(s) => { + field.r#type = Some(field_descriptor_proto::Type::Message.into()); + let name = format!("Struct{}", name); + let sub_proto = build_protobuf_schema(s.iter(), name.clone())?; + field.type_name = Some(name); + return Ok((field, Some(sub_proto))); + } + DataType::List(l) => { + let (mut field, proto) = build_protobuf_field(l.as_ref(), index, name.clone())?; + field.label = Some(field_descriptor_proto::Label::Repeated.into()); + return Ok((field, proto)); + } + DataType::Bytea => field.r#type = Some(field_descriptor_proto::Type::Bytes.into()), + DataType::Jsonb => field.r#type = Some(field_descriptor_proto::Type::String.into()), + DataType::Serial => field.r#type = 
Some(field_descriptor_proto::Type::Int64.into()), + DataType::Float32 | DataType::Int256 => { + return Err(SinkError::BigQuery(anyhow::anyhow!( + "Don't support Float32 and Int256" + ))) + } + } + Ok((field, None)) +} + #[cfg(test)] mod test { + + use std::assert_matches::assert_matches; + + use risingwave_common::catalog::{Field, Schema}; use risingwave_common::types::{DataType, StructType}; - use crate::sink::big_query::BigQuerySink; + use crate::sink::big_query::{ + build_protobuf_descriptor_pool, build_protobuf_schema, BigQuerySink, + }; #[tokio::test] async fn test_type_check() { @@ -425,4 +714,63 @@ mod test { big_query_type_string ); } + + #[tokio::test] + async fn test_schema_check() { + let schema = Schema { + fields: vec![ + Field::with_name(DataType::Int64, "v1"), + Field::with_name(DataType::Float64, "v2"), + Field::with_name( + DataType::List(Box::new(DataType::Struct(StructType::new(vec![ + ("v1".to_owned(), DataType::List(Box::new(DataType::Int64))), + ( + "v3".to_owned(), + DataType::Struct(StructType::new(vec![ + ("v1".to_owned(), DataType::Int64), + ("v2".to_owned(), DataType::Int64), + ])), + ), + ])))), + "v3", + ), + ], + }; + let fields = schema + .fields() + .iter() + .map(|f| (f.name.as_str(), &f.data_type)); + let desc = build_protobuf_schema(fields, "t1".to_string()).unwrap(); + let pool = build_protobuf_descriptor_pool(&desc); + let t1_message = pool.get_message_by_name("t1").unwrap(); + assert_matches!( + t1_message.get_field_by_name("v1").unwrap().kind(), + prost_reflect::Kind::Int64 + ); + assert_matches!( + t1_message.get_field_by_name("v2").unwrap().kind(), + prost_reflect::Kind::Double + ); + assert_matches!( + t1_message.get_field_by_name("v3").unwrap().kind(), + prost_reflect::Kind::Message(_) + ); + + let v3_message = pool.get_message_by_name("t1.Structv3").unwrap(); + assert_matches!( + v3_message.get_field_by_name("v1").unwrap().kind(), + prost_reflect::Kind::Int64 + ); + assert!(v3_message.get_field_by_name("v1").unwrap().is_list()); + + let v3_v3_message = pool.get_message_by_name("t1.Structv3.Structv3").unwrap(); + assert_matches!( + v3_v3_message.get_field_by_name("v1").unwrap().kind(), + prost_reflect::Kind::Int64 + ); + assert_matches!( + v3_v3_message.get_field_by_name("v2").unwrap().kind(), + prost_reflect::Kind::Int64 + ); + } } diff --git a/src/connector/src/sink/encoder/json.rs b/src/connector/src/sink/encoder/json.rs index bbd424d5db036..6eb7ca5cd1b50 100644 --- a/src/connector/src/sink/encoder/json.rs +++ b/src/connector/src/sink/encoder/json.rs @@ -110,19 +110,6 @@ impl JsonEncoder { } } - pub fn new_with_bigquery(schema: Schema, col_indices: Option>) -> Self { - Self { - schema, - col_indices, - time_handling_mode: TimeHandlingMode::Milli, - date_handling_mode: DateHandlingMode::String, - timestamp_handling_mode: TimestampHandlingMode::String, - timestamptz_handling_mode: TimestamptzHandlingMode::UtcString, - custom_json_type: CustomJsonType::BigQuery, - kafka_connect: None, - } - } - pub fn with_kafka_connect(self, kafka_connect: KafkaConnectParams) -> Self { Self { kafka_connect: Some(Arc::new(kafka_connect)), @@ -200,14 +187,7 @@ fn datum_to_json_object( ) -> ArrayResult { let scalar_ref = match datum { None => { - if let CustomJsonType::BigQuery = custom_json_type - && matches!(field.data_type(), DataType::List(_)) - { - // Bigquery need to convert null of array to empty array https://cloud.google.com/bigquery/docs/reference/standard-sql/data-types - return Ok(Value::Array(vec![])); - } else { - return Ok(Value::Null); - } + return 
Ok(Value::Null); } Some(datum) => datum, }; @@ -245,10 +225,7 @@ fn datum_to_json_object( v.rescale(*s as u32); json!(v.to_text()) } - CustomJsonType::Es - | CustomJsonType::None - | CustomJsonType::BigQuery - | CustomJsonType::StarRocks => { + CustomJsonType::Es | CustomJsonType::None | CustomJsonType::StarRocks => { json!(v.to_text()) } }, @@ -300,7 +277,7 @@ fn datum_to_json_object( } (DataType::Jsonb, ScalarRefImpl::Jsonb(jsonb_ref)) => match custom_json_type { CustomJsonType::Es | CustomJsonType::StarRocks => JsonbVal::from(jsonb_ref).take(), - CustomJsonType::Doris(_) | CustomJsonType::None | CustomJsonType::BigQuery => { + CustomJsonType::Doris(_) | CustomJsonType::None => { json!(jsonb_ref.to_string()) } }, @@ -351,7 +328,7 @@ fn datum_to_json_object( "starrocks can't support struct".to_string(), )); } - CustomJsonType::Es | CustomJsonType::None | CustomJsonType::BigQuery => { + CustomJsonType::Es | CustomJsonType::None => { let mut map = Map::with_capacity(st.len()); for (sub_datum_ref, sub_field) in struct_ref.iter_fields_ref().zip_eq_debug( st.iter() diff --git a/src/connector/src/sink/encoder/mod.rs b/src/connector/src/sink/encoder/mod.rs index 0b6899dbad955..dd590d0302ecf 100644 --- a/src/connector/src/sink/encoder/mod.rs +++ b/src/connector/src/sink/encoder/mod.rs @@ -144,8 +144,6 @@ pub enum CustomJsonType { Es, // starrocks' need jsonb is struct StarRocks, - // bigquery need null array -> [] - BigQuery, None, } diff --git a/src/connector/src/sink/encoder/proto.rs b/src/connector/src/sink/encoder/proto.rs index a5f1090dbafaf..3f50b3d97ff26 100644 --- a/src/connector/src/sink/encoder/proto.rs +++ b/src/connector/src/sink/encoder/proto.rs @@ -77,7 +77,7 @@ impl ProtoEncoder { } pub struct ProtoEncoded { - message: DynamicMessage, + pub message: DynamicMessage, header: ProtoHeader, } @@ -307,7 +307,6 @@ fn encode_field( proto_field.kind() ))) }; - let value = match &data_type { // Group A: perfect match between RisingWave types and ProtoBuf types DataType::Boolean => match (expect_list, proto_field.kind()) { @@ -364,18 +363,59 @@ fn encode_field( Ok(Value::Message(message.transcode_to_dynamic())) })? } + (false, Kind::String) => { + maybe.on_base(|s| Ok(Value::String(s.into_timestamptz().to_string())))? + } + _ => return no_match_err(), + }, + DataType::Jsonb => match (expect_list, proto_field.kind()) { + (false, Kind::String) => { + maybe.on_base(|s| Ok(Value::String(s.into_jsonb().to_string())))? + } + _ => return no_match_err(), /* Value, NullValue, Struct (map), ListValue + * Group C: experimental */ + }, + DataType::Int16 => match (expect_list, proto_field.kind()) { + (false, Kind::Int64) => maybe.on_base(|s| Ok(Value::I64(s.into_int16() as i64)))?, _ => return no_match_err(), }, - DataType::Jsonb => return no_match_err(), // Value, NullValue, Struct (map), ListValue - // Group C: experimental - DataType::Int16 => return no_match_err(), - DataType::Date => return no_match_err(), // google.type.Date - DataType::Time => return no_match_err(), // google.type.TimeOfDay - DataType::Timestamp => return no_match_err(), // google.type.DateTime - DataType::Decimal => return no_match_err(), // google.type.Decimal - DataType::Interval => return no_match_err(), - // Group D: unsupported - DataType::Serial | DataType::Int256 => { + DataType::Date => match (expect_list, proto_field.kind()) { + (false, Kind::Int32) => { + maybe.on_base(|s| Ok(Value::I32(s.into_date().get_nums_days_unix_epoch())))? 
+ } + _ => return no_match_err(), // google.type.Date + }, + DataType::Time => match (expect_list, proto_field.kind()) { + (false, Kind::String) => { + maybe.on_base(|s| Ok(Value::String(s.into_time().to_string())))? + } + _ => return no_match_err(), // google.type.TimeOfDay + }, + DataType::Timestamp => match (expect_list, proto_field.kind()) { + (false, Kind::String) => { + maybe.on_base(|s| Ok(Value::String(s.into_timestamp().to_string())))? + } + _ => return no_match_err(), // google.type.DateTime + }, + DataType::Decimal => match (expect_list, proto_field.kind()) { + (false, Kind::String) => { + maybe.on_base(|s| Ok(Value::String(s.into_decimal().to_string())))? + } + _ => return no_match_err(), // google.type.Decimal + }, + DataType::Interval => match (expect_list, proto_field.kind()) { + (false, Kind::String) => { + maybe.on_base(|s| Ok(Value::String(s.into_interval().as_iso_8601())))? + } + _ => return no_match_err(), // Group D: unsupported + }, + DataType::Serial => match (expect_list, proto_field.kind()) { + (false, Kind::Int64) => { + maybe.on_base(|s| Ok(Value::I64(s.into_serial().as_row_id())))? + } + _ => return no_match_err(), // Group D: unsupported + }, + DataType::Int256 => { return no_match_err(); } }; @@ -398,7 +438,6 @@ mod tests { let pool_bytes = std::fs::read(pool_path).unwrap(); let pool = prost_reflect::DescriptorPool::decode(pool_bytes.as_ref()).unwrap(); let descriptor = pool.get_message_by_name("recursive.AllTypes").unwrap(); - let schema = Schema::new(vec![ Field::with_name(DataType::Boolean, "bool_field"), Field::with_name(DataType::Varchar, "string_field"),
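
Notes on the new write path.

The sink no longer posts JSON rows through `tabledata().insert_all()`. Instead it derives a protobuf `DescriptorProto` from the sink schema at runtime (`build_protobuf_schema`), loads it into a `prost_reflect::DescriptorPool` (`build_protobuf_descriptor_pool`), and encodes every row as a `DynamicMessage` whose serialized bytes end up in `ProtoRows::serialized_rows` of an `AppendRowsRequest`. Below is a minimal sketch of that descriptor round trip, assuming a two-column table; the column names, literal values, and the `main` wrapper are illustrative, not the connector's code:

use prost::Message;
use prost_reflect::{DescriptorPool, DynamicMessage, Value};
use prost_types::{
    field_descriptor_proto, DescriptorProto, FieldDescriptorProto, FileDescriptorProto,
    FileDescriptorSet,
};

fn main() {
    // Describe a message with an INT64 column `v1` and a STRING column `v2`.
    let message = DescriptorProto {
        name: Some("t1".to_string()),
        field: vec![
            FieldDescriptorProto {
                name: Some("v1".to_string()),
                number: Some(1),
                r#type: Some(field_descriptor_proto::Type::Int64.into()),
                ..Default::default()
            },
            FieldDescriptorProto {
                name: Some("v2".to_string()),
                number: Some(2),
                r#type: Some(field_descriptor_proto::Type::String.into()),
                ..Default::default()
            },
        ],
        ..Default::default()
    };

    // Wrap it in a synthetic file descriptor and register it in a pool,
    // mirroring `build_protobuf_descriptor_pool` in this patch.
    let file = FileDescriptorProto {
        name: Some("bigquery".to_string()),
        message_type: vec![message],
        ..Default::default()
    };
    let pool =
        DescriptorPool::from_file_descriptor_set(FileDescriptorSet { file: vec![file] }).unwrap();
    let desc = pool.get_message_by_name("t1").unwrap();

    // Encode one row; these bytes are what the sink collects into
    // `ProtoRows::serialized_rows` before calling `append_rows`.
    let mut row = DynamicMessage::new(desc);
    row.set_field_by_name("v1", Value::I64(42));
    row.set_field_by_name("v2", Value::String("hello".to_string()));
    let bytes: Vec<u8> = row.encode_to_vec();
    assert!(!bytes.is_empty());
}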
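
For non-append-only sinks the generated descriptor gets one extra STRING field named `_CHANGE_TYPE` (the `CHANGE_TYPE` constant), and rows are appended to the table's default write stream, `projects/{project}/datasets/{dataset}/tables/{table}/streams/_default`; BigQuery's Storage Write API CDC support reads that column to decide whether a row is an upsert or a delete. A sketch of the op-to-change-type mapping implied by `upsert()` follows; the helper function is hypothetical, only the string values and the skipped `UpdateDelete` case come from the patch:

use risingwave_common::array::Op;

/// Maps a stream op to the `_CHANGE_TYPE` value written alongside the row.
fn change_type(op: Op) -> Option<&'static str> {
    match op {
        Op::Insert | Op::UpdateInsert => Some("UPSERT"),
        Op::Delete => Some("DELETE"),
        // The old image of an update is dropped; the paired UpdateInsert
        // already carries the new value.
        Op::UpdateDelete => None,
    }
}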
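
Type mapping and flushing, as implemented above: `build_protobuf_field` encodes Boolean as BOOL, Int32 as INT32, Int16/Int64/Serial as INT64, Float64 as DOUBLE, Date as INT32 (days since the Unix epoch), Bytea as BYTES, and Decimal, Time, Timestamp, Timestamptz, Interval and Jsonb as STRING; structs become nested messages, lists become repeated fields, and Float32 and Int256 are rejected. Encoded batches accumulate in `write_rows` and are flushed once `write_rows_count` reaches `bigquery.max_batch_rows` or a checkpoint barrier arrives, and `StorageWriterClient::append_rows` fails the write if any response carries `row_errors` or if the response count does not match the request count.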