Skip to content

Commit

Permalink
refactor(connector): share schema registry loader between avro and pr…
Browse files Browse the repository at this point in the history
…otobuf (#14642)
  • Loading branch information
xiangjinwu authored Mar 8, 2024
1 parent 6f9244d commit ef7eb1d
Show file tree
Hide file tree
Showing 7 changed files with 237 additions and 188 deletions.
1 change: 0 additions & 1 deletion src/connector/src/parser/protobuf/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@

mod parser;
pub use parser::*;
mod schema_resolver;

#[rustfmt::skip]
#[cfg(test)]
Expand Down
40 changes: 20 additions & 20 deletions src/connector/src/parser/protobuf/parser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ use std::sync::Arc;
use anyhow::Context;
use itertools::Itertools;
use prost_reflect::{
Cardinality, DescriptorPool, DynamicMessage, FieldDescriptor, Kind, MessageDescriptor,
ReflectMessage, Value,
Cardinality, DescriptorPool, DynamicMessage, FieldDescriptor, FileDescriptor, Kind,
MessageDescriptor, ReflectMessage, Value,
};
use risingwave_common::array::{ListValue, StructValue};
use risingwave_common::types::{DataType, Datum, Decimal, JsonbVal, ScalarImpl, F32, F64};
Expand All @@ -27,17 +27,15 @@ use risingwave_pb::plan_common::{AdditionalColumn, ColumnDesc, ColumnDescVersion
use thiserror::Error;
use thiserror_ext::{AsReport, Macro};

use super::schema_resolver::*;
use crate::error::ConnectorResult;
use crate::parser::unified::protobuf::ProtobufAccess;
use crate::parser::unified::{
bail_uncategorized, uncategorized, AccessError, AccessImpl, AccessResult,
};
use crate::parser::util::bytes_from_url;
use crate::parser::{AccessBuilder, EncodingProperties};
use crate::schema::schema_registry::{
extract_schema_id, get_subject_by_strategy, handle_sr_list, Client, WireFormatError,
};
use crate::schema::schema_registry::{extract_schema_id, handle_sr_list, Client, WireFormatError};
use crate::schema::SchemaLoader;

#[derive(Debug)]
pub struct ProtobufAccessBuilder {
Expand Down Expand Up @@ -100,25 +98,27 @@ impl ProtobufParserConfig {
// https://docs.confluent.io/platform/7.5/control-center/topics/schema.html#c3-schemas-best-practices-key-value-pairs
bail!("protobuf key is not supported");
}
let schema_bytes = if protobuf_config.use_schema_registry {
let schema_value = get_subject_by_strategy(
&protobuf_config.name_strategy,
protobuf_config.topic.as_str(),
Some(message_name.as_ref()),
false,
)?;
tracing::debug!("infer value subject {schema_value}");

let pool = if protobuf_config.use_schema_registry {
let client = Client::new(url, &protobuf_config.client_config)?;
compile_file_descriptor_from_schema_registry(schema_value.as_str(), &client).await?
let loader = SchemaLoader {
client,
name_strategy: protobuf_config.name_strategy,
topic: protobuf_config.topic,
key_record_name: None,
val_record_name: Some(message_name.clone()),
};
let (_schema_id, root_file_descriptor) = loader
.load_val_schema::<FileDescriptor>()
.await
.context("load schema failed")?;
root_file_descriptor.parent_pool().clone()
} else {
let url = url.first().unwrap();
bytes_from_url(url, protobuf_config.aws_auth_props.as_ref()).await?
let schema_bytes = bytes_from_url(url, protobuf_config.aws_auth_props.as_ref()).await?;
DescriptorPool::decode(schema_bytes.as_slice())
.with_context(|| format!("cannot build descriptor pool from schema `{location}`"))?
};

let pool = DescriptorPool::decode(schema_bytes.as_slice())
.with_context(|| format!("cannot build descriptor pool from schema `{}`", location))?;

let message_descriptor = pool.get_message_by_name(message_name).with_context(|| {
format!(
"cannot find message `{}` in schema `{}`",
Expand Down
95 changes: 0 additions & 95 deletions src/connector/src/parser/protobuf/schema_resolver.rs

This file was deleted.

94 changes: 23 additions & 71 deletions src/connector/src/schema/avro.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,89 +16,41 @@ use std::collections::BTreeMap;
use std::sync::Arc;

use apache_avro::Schema as AvroSchema;
use risingwave_pb::catalog::PbSchemaRegistryNameStrategy;

use super::schema_registry::{
get_subject_by_strategy, handle_sr_list, name_strategy_from_str, Client, ConfluentSchema,
SchemaRegistryAuth,
};
use super::{
invalid_option_error, InvalidOptionError, SchemaFetchError, KEY_MESSAGE_NAME_KEY,
MESSAGE_NAME_KEY, NAME_STRATEGY_KEY, SCHEMA_REGISTRY_KEY,
};
use super::loader::{LoadedSchema, SchemaLoader};
use super::schema_registry::Subject;
use super::SchemaFetchError;

pub struct SchemaWithId {
pub schema: Arc<AvroSchema>,
pub id: i32,
}

impl TryFrom<ConfluentSchema> for SchemaWithId {
type Error = SchemaFetchError;

fn try_from(fetched: ConfluentSchema) -> Result<Self, Self::Error> {
let parsed = AvroSchema::parse_str(&fetched.content)
.map_err(|e| SchemaFetchError::SchemaCompile(e.into()))?;
Ok(Self {
schema: Arc::new(parsed),
id: fetched.id,
})
}
}

/// Schema registry only
pub async fn fetch_schema(
format_options: &BTreeMap<String, String>,
topic: &str,
) -> Result<(SchemaWithId, SchemaWithId), SchemaFetchError> {
let schema_location = format_options
.get(SCHEMA_REGISTRY_KEY)
.ok_or_else(|| invalid_option_error!("{SCHEMA_REGISTRY_KEY} required"))?
.clone();
let client_config = format_options.into();
let name_strategy = format_options
.get(NAME_STRATEGY_KEY)
.map(|s| {
name_strategy_from_str(s)
.ok_or_else(|| invalid_option_error!("unrecognized strategy {s}"))
})
.transpose()?
.unwrap_or_default();
let key_record_name = format_options
.get(KEY_MESSAGE_NAME_KEY)
.map(std::ops::Deref::deref);
let val_record_name = format_options
.get(MESSAGE_NAME_KEY)
.map(std::ops::Deref::deref);

let (key_schema, val_schema) = fetch_schema_inner(
&schema_location,
&client_config,
&name_strategy,
topic,
key_record_name,
val_record_name,
)
.await?;

Ok((key_schema.try_into()?, val_schema.try_into()?))
let loader = SchemaLoader::from_format_options(topic, format_options)?;

let (key_id, key_avro) = loader.load_key_schema().await?;
let (val_id, val_avro) = loader.load_val_schema().await?;

Ok((
SchemaWithId {
id: key_id,
schema: Arc::new(key_avro),
},
SchemaWithId {
id: val_id,
schema: Arc::new(val_avro),
},
))
}

async fn fetch_schema_inner(
schema_location: &str,
client_config: &SchemaRegistryAuth,
name_strategy: &PbSchemaRegistryNameStrategy,
topic: &str,
key_record_name: Option<&str>,
val_record_name: Option<&str>,
) -> Result<(ConfluentSchema, ConfluentSchema), SchemaFetchError> {
let urls = handle_sr_list(schema_location)?;
let client = Client::new(urls, client_config)?;

let key_subject = get_subject_by_strategy(name_strategy, topic, key_record_name, true)?;
let key_schema = client.get_schema_by_subject(&key_subject).await?;

let val_subject = get_subject_by_strategy(name_strategy, topic, val_record_name, false)?;
let val_schema = client.get_schema_by_subject(&val_subject).await?;

Ok((key_schema, val_schema))
impl LoadedSchema for AvroSchema {
fn compile(primary: Subject, _: Vec<Subject>) -> Result<Self, SchemaFetchError> {
AvroSchema::parse_str(&primary.schema.content)
.map_err(|e| SchemaFetchError::SchemaCompile(e.into()))
}
}
95 changes: 95 additions & 0 deletions src/connector/src/schema/loader.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
// Copyright 2024 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::collections::BTreeMap;

use risingwave_pb::catalog::PbSchemaRegistryNameStrategy;

use super::schema_registry::{
get_subject_by_strategy, handle_sr_list, name_strategy_from_str, Client, Subject,
};
use super::{invalid_option_error, InvalidOptionError, SchemaFetchError};

const MESSAGE_NAME_KEY: &str = "message";
const KEY_MESSAGE_NAME_KEY: &str = "key.message";
const SCHEMA_LOCATION_KEY: &str = "schema.location";
const SCHEMA_REGISTRY_KEY: &str = "schema.registry";
const NAME_STRATEGY_KEY: &str = "schema.registry.name.strategy";

pub struct SchemaLoader {
pub client: Client,
pub name_strategy: PbSchemaRegistryNameStrategy,
pub topic: String,
pub key_record_name: Option<String>,
pub val_record_name: Option<String>,
}

impl SchemaLoader {
pub fn from_format_options(
topic: &str,
format_options: &BTreeMap<String, String>,
) -> Result<Self, SchemaFetchError> {
let schema_location = format_options
.get(SCHEMA_REGISTRY_KEY)
.ok_or_else(|| invalid_option_error!("{SCHEMA_REGISTRY_KEY} required"))?;
let client_config = format_options.into();
let urls = handle_sr_list(schema_location)?;
let client = Client::new(urls, &client_config)?;

let name_strategy = format_options
.get(NAME_STRATEGY_KEY)
.map(|s| {
name_strategy_from_str(s)
.ok_or_else(|| invalid_option_error!("unrecognized strategy {s}"))
})
.transpose()?
.unwrap_or_default();
let key_record_name = format_options.get(KEY_MESSAGE_NAME_KEY).cloned();
let val_record_name = format_options.get(MESSAGE_NAME_KEY).cloned();

Ok(Self {
client,
name_strategy,
topic: topic.into(),
key_record_name,
val_record_name,
})
}

async fn load_schema<Out: LoadedSchema, const IS_KEY: bool>(
&self,
record: Option<&str>,
) -> Result<(i32, Out), SchemaFetchError> {
let subject = get_subject_by_strategy(&self.name_strategy, &self.topic, record, IS_KEY)?;
let (primary_subject, dependency_subjects) =
self.client.get_subject_and_references(&subject).await?;
let schema_id = primary_subject.schema.id;
let out = Out::compile(primary_subject, dependency_subjects)?;
Ok((schema_id, out))
}

pub async fn load_key_schema<Out: LoadedSchema>(&self) -> Result<(i32, Out), SchemaFetchError> {
self.load_schema::<Out, true>(self.key_record_name.as_deref())
.await
}

pub async fn load_val_schema<Out: LoadedSchema>(&self) -> Result<(i32, Out), SchemaFetchError> {
self.load_schema::<Out, false>(self.val_record_name.as_deref())
.await
}
}

pub trait LoadedSchema: Sized {
fn compile(primary: Subject, references: Vec<Subject>) -> Result<Self, SchemaFetchError>;
}
3 changes: 3 additions & 0 deletions src/connector/src/schema/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,12 @@
use crate::error::ConnectorError;

pub mod avro;
mod loader;
pub mod protobuf;
pub mod schema_registry;

pub use loader::SchemaLoader;

const MESSAGE_NAME_KEY: &str = "message";
const KEY_MESSAGE_NAME_KEY: &str = "key.message";
const SCHEMA_LOCATION_KEY: &str = "schema.location";
Expand Down
Loading

0 comments on commit ef7eb1d

Please sign in to comment.