Skip to content

Commit

Permalink
refactor: introduce thiserror-ext for boxed error wrapper definition (
Browse files Browse the repository at this point in the history
#13200)

Signed-off-by: Bugen Zhao <[email protected]>
  • Loading branch information
BugenZhao authored Nov 2, 2023
1 parent 537bbe2 commit c02cdc4
Show file tree
Hide file tree
Showing 8 changed files with 43 additions and 23 deletions.
22 changes: 22 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ arrow-buffer = "48"
arrow-flight = "48"
arrow-select = "48"
arrow-ord = "48"
thiserror-ext = "0.0.1"
tikv-jemalloc-ctl = { git = "https://github.com/risingwavelabs/jemallocator.git", rev = "64a2d9" }
tikv-jemallocator = { git = "https://github.com/risingwavelabs/jemallocator.git", features = [
"profiling",
Expand Down
2 changes: 1 addition & 1 deletion src/expr/core/src/expr/expr_udf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ impl Build for UdfExpression {
"",
DataType::from(t)
.try_into()
.map_err(risingwave_udf::Error::Unsupported)?,
.map_err(risingwave_udf::Error::unsupported)?,
true,
))
})
Expand Down
2 changes: 1 addition & 1 deletion src/expr/core/src/table_function/user_defined.rs
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ pub fn new_user_defined(prost: &PbTableFunction, chunk_size: usize) -> Result<Bo
"",
DataType::from(t)
.try_into()
.map_err(risingwave_udf::Error::Unsupported)?,
.map_err(risingwave_udf::Error::unsupported)?,
true,
))
})
Expand Down
1 change: 1 addition & 0 deletions src/udf/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ cfg-or-panic = "0.2"
futures-util = "0.3.28"
static_assertions = "1"
thiserror = "1"
thiserror-ext = { workspace = true }
tokio = { version = "0.2", package = "madsim-tokio", features = ["rt", "macros"] }
tonic = { workspace = true }

Expand Down
25 changes: 8 additions & 17 deletions src/udf/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,21 +13,24 @@
// limitations under the License.

use arrow_flight::error::FlightError;
use thiserror::Error;
use thiserror_ext::{Box, Construct};

/// A specialized `Result` type for UDF operations.
pub type Result<T, E = Error> = std::result::Result<T, E>;

/// The error type for UDF operations.
#[derive(thiserror::Error, Debug)]
pub enum Error {
#[derive(Error, Debug, Box, Construct)]
#[thiserror_ext(type = Error)]
pub enum ErrorInner {
#[error("failed to connect to UDF service: {0}")]
Connect(#[from] tonic::transport::Error),

#[error("failed to send requests to UDF service: {0}")]
Tonic(#[from] Box<tonic::Status>),
Tonic(#[from] tonic::Status),

#[error("failed to call UDF: {0}")]
Flight(#[from] Box<FlightError>),
Flight(#[from] FlightError),

#[error("type mismatch: {0}")]
TypeMismatch(String),
Expand All @@ -45,16 +48,4 @@ pub enum Error {
ServiceError(String),
}

static_assertions::const_assert_eq!(std::mem::size_of::<Error>(), 40);

impl From<tonic::Status> for Error {
fn from(status: tonic::Status) -> Self {
Error::from(Box::new(status))
}
}

impl From<FlightError> for Error {
fn from(error: FlightError) -> Self {
Error::from(Box::new(error))
}
}
static_assertions::const_assert_eq!(std::mem::size_of::<Error>(), 8);
11 changes: 7 additions & 4 deletions src/udf/src/external.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ impl ArrowFlightUdfClient {
let full_schema = Schema::try_from(info)
.map_err(|e| FlightError::DecodeError(format!("Error decoding schema: {e}")))?;
if input_num > full_schema.fields.len() {
return Err(Error::ServiceError(format!(
return Err(Error::service_error(format!(
"function {:?} schema info not consistency: input_num: {}, total_fields: {}",
id,
input_num,
Expand All @@ -73,13 +73,13 @@ impl ArrowFlightUdfClient {
let expect_input_types: Vec<_> = args.fields.iter().map(|f| f.data_type()).collect();
let expect_result_types: Vec<_> = returns.fields.iter().map(|f| f.data_type()).collect();
if !data_types_match(&expect_input_types, &actual_input_types) {
return Err(Error::TypeMismatch(format!(
return Err(Error::type_mismatch(format!(
"function: {:?}, expect arguments: {:?}, actual: {:?}",
id, expect_input_types, actual_input_types
)));
}
if !data_types_match(&expect_result_types, &actual_result_types) {
return Err(Error::TypeMismatch(format!(
return Err(Error::type_mismatch(format!(
"function: {:?}, expect return: {:?}, actual: {:?}",
id, expect_result_types, actual_result_types
)));
Expand All @@ -91,7 +91,10 @@ impl ArrowFlightUdfClient {
pub async fn call(&self, id: &str, input: RecordBatch) -> Result<RecordBatch> {
let mut output_stream = self.call_stream(id, stream::once(async { input })).await?;
// TODO: support no output
let head = output_stream.next().await.ok_or(Error::NoReturned)??;
let head = output_stream
.next()
.await
.ok_or_else(Error::no_returned)??;
let mut remaining = vec![];
while let Some(batch) = output_stream.next().await {
remaining.push(batch?);
Expand Down
2 changes: 2 additions & 0 deletions src/udf/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#![feature(error_generic_member_access)]

mod error;
mod external;

Expand Down

0 comments on commit c02cdc4

Please sign in to comment.