Skip to content

Commit

Permalink
feat: Handle exception thrown from native side
Browse files Browse the repository at this point in the history
This PR catches exceptions thrown from native side via calling Java methods, and convert them into a `CometError::JavaException` which can then be properly propagated to the JVM.
  • Loading branch information
sunchao committed Feb 20, 2024
1 parent 7018225 commit 588cdf7
Show file tree
Hide file tree
Showing 3 changed files with 153 additions and 33 deletions.
3 changes: 3 additions & 0 deletions core/src/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,9 @@ pub enum CometError {
#[from]
source: DataFusionError,
},

#[error("{class}: {msg}")]
JavaException { class: String, msg: String },
}

pub fn init() {
Expand Down
26 changes: 13 additions & 13 deletions core/src/execution/datafusion/expressions/subquery.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ impl PhysicalExpr for Subquery {
let mut env = JVMClasses::get_env();

unsafe {
let is_null = jni_static_call!(env,
let is_null = jni_static_call!(&mut env,
comet_exec.is_null(self.exec_context_id, self.id) -> jboolean
)?;

Expand All @@ -105,50 +105,50 @@ impl PhysicalExpr for Subquery {

match &self.data_type {
DataType::Boolean => {
let r = jni_static_call!(env,
let r = jni_static_call!(&mut env,
comet_exec.get_bool(self.exec_context_id, self.id) -> jboolean
)?;
Ok(ColumnarValue::Scalar(ScalarValue::Boolean(Some(r > 0))))
}
DataType::Int8 => {
let r = jni_static_call!(env,
let r = jni_static_call!(&mut env,
comet_exec.get_byte(self.exec_context_id, self.id) -> jbyte
)?;
Ok(ColumnarValue::Scalar(ScalarValue::Int8(Some(r))))
}
DataType::Int16 => {
let r = jni_static_call!(env,
let r = jni_static_call!(&mut env,
comet_exec.get_short(self.exec_context_id, self.id) -> jshort
)?;
Ok(ColumnarValue::Scalar(ScalarValue::Int16(Some(r))))
}
DataType::Int32 => {
let r = jni_static_call!(env,
let r = jni_static_call!(&mut env,
comet_exec.get_int(self.exec_context_id, self.id) -> jint
)?;
Ok(ColumnarValue::Scalar(ScalarValue::Int32(Some(r))))
}
DataType::Int64 => {
let r = jni_static_call!(env,
let r = jni_static_call!(&mut env,
comet_exec.get_long(self.exec_context_id, self.id) -> jlong
)?;
Ok(ColumnarValue::Scalar(ScalarValue::Int64(Some(r))))
}
DataType::Float32 => {
let r = jni_static_call!(env,
let r = jni_static_call!(&mut env,
comet_exec.get_float(self.exec_context_id, self.id) -> f32
)?;
Ok(ColumnarValue::Scalar(ScalarValue::Float32(Some(r))))
}
DataType::Float64 => {
let r = jni_static_call!(env,
let r = jni_static_call!(&mut env,
comet_exec.get_double(self.exec_context_id, self.id) -> f64
)?;

Ok(ColumnarValue::Scalar(ScalarValue::Float64(Some(r))))
}
DataType::Decimal128(p, s) => {
let bytes = jni_static_call!(env,
let bytes = jni_static_call!(&mut env,
comet_exec.get_decimal(self.exec_context_id, self.id) -> BinaryWrapper
)?;
let bytes: &JByteArray = bytes.get().into();
Expand All @@ -161,14 +161,14 @@ impl PhysicalExpr for Subquery {
)))
}
DataType::Date32 => {
let r = jni_static_call!(env,
let r = jni_static_call!(&mut env,
comet_exec.get_int(self.exec_context_id, self.id) -> jint
)?;

Ok(ColumnarValue::Scalar(ScalarValue::Date32(Some(r))))
}
DataType::Timestamp(TimeUnit::Microsecond, timezone) => {
let r = jni_static_call!(env,
let r = jni_static_call!(&mut env,
comet_exec.get_long(self.exec_context_id, self.id) -> jlong
)?;

Expand All @@ -178,15 +178,15 @@ impl PhysicalExpr for Subquery {
)))
}
DataType::Utf8 => {
let string = jni_static_call!(env,
let string = jni_static_call!(&mut env,
comet_exec.get_string(self.exec_context_id, self.id) -> StringWrapper
)?;

let string = env.get_string(string.get()).unwrap().into();
Ok(ColumnarValue::Scalar(ScalarValue::Utf8(Some(string))))
}
DataType::Binary => {
let bytes = jni_static_call!(env,
let bytes = jni_static_call!(&mut env,
comet_exec.get_binary(self.exec_context_id, self.id) -> BinaryWrapper
)?;
let bytes: &JByteArray = bytes.get().into();
Expand Down
157 changes: 137 additions & 20 deletions core/src/jvm_bridge/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@
use jni::{
errors::{Error, Result as JniResult},
objects::{JClass, JObject, JString, JValueGen, JValueOwned},
objects::{JClass, JMethodID, JObject, JString, JThrowable, JValueGen, JValueOwned},
signature::ReturnType,
AttachGuard, JNIEnv,
};
use once_cell::sync::OnceCell;
Expand Down Expand Up @@ -58,29 +59,52 @@ macro_rules! jni_new_string {
/// jname and value are the arguments.
macro_rules! jni_call {
($env:expr, $clsname:ident($obj:expr).$method:ident($($args:expr),* $(,)?) -> $ret:ty) => {{
$crate::jvm_bridge::jni_map_error!(
$env,
$env.call_method_unchecked(
$obj,
paste::paste! {$crate::jvm_bridge::JVMClasses::get().[<$clsname>].[<method_ $method>]},
paste::paste! {$crate::jvm_bridge::JVMClasses::get().[<$clsname>].[<method_ $method _ret>]}.clone(),
$crate::jvm_bridge::jvalues!($($args,)*)
)
).and_then(|result| $crate::jvm_bridge::jni_map_error!($env, <$ret>::try_from(result)))
let method_id = paste::paste! {
$crate::jvm_bridge::JVMClasses::get().[<$clsname>].[<method_ $method>]
};
let ret_type = paste::paste! {
$crate::jvm_bridge::JVMClasses::get().[<$clsname>].[<method_ $method _ret>]
}.clone();
let args = $crate::jvm_bridge::jvalues!($($args,)*);

// Call the JVM method and obtain the returned value
let ret = $env.call_method_unchecked($obj, method_id, ret_type, args);

// Check if JVM has thrown any exception, and handle it if so.
let result = if let Some(exception) = $crate::jvm_bridge::check_exception($env).unwrap() {
Err(exception.into())
} else {
$crate::jvm_bridge::jni_map_error!($env, ret)
};

result.and_then(|result| $crate::jvm_bridge::jni_map_error!($env, <$ret>::try_from(result)))
}}
}

macro_rules! jni_static_call {
($env:expr, $clsname:ident.$method:ident($($args:expr),* $(,)?) -> $ret:ty) => {{
$crate::jvm_bridge::jni_map_error!(
$env,
$env.call_static_method_unchecked(
&paste::paste! {$crate::jvm_bridge::JVMClasses::get().[<$clsname>].[<class>]},
paste::paste! {$crate::jvm_bridge::JVMClasses::get().[<$clsname>].[<method_ $method>]},
paste::paste! {$crate::jvm_bridge::JVMClasses::get().[<$clsname>].[<method_ $method _ret>]}.clone(),
$crate::jvm_bridge::jvalues!($($args,)*)
)
).and_then(|result| $crate::jvm_bridge::jni_map_error!($env, <$ret>::try_from(result)))
let clazz = &paste::paste! {
$crate::jvm_bridge::JVMClasses::get().[<$clsname>].[<class>]
};
let method_id = paste::paste! {
$crate::jvm_bridge::JVMClasses::get().[<$clsname>].[<method_ $method>]
};
let ret_type = paste::paste! {
$crate::jvm_bridge::JVMClasses::get().[<$clsname>].[<method_ $method _ret>]
}.clone();
let args = $crate::jvm_bridge::jvalues!($($args,)*);

// Call the JVM static method and obtain the returned value
let ret = $env.call_static_method_unchecked(clazz, method_id, ret_type, args);

// Check if JVM has thrown any exception, and handle it if so.
let result = if let Some(exception) = $crate::jvm_bridge::check_exception($env).unwrap() {
Err(exception.into())
} else {
$crate::jvm_bridge::jni_map_error!($env, ret)
};

result.and_then(|result| $crate::jvm_bridge::jni_map_error!($env, <$ret>::try_from(result)))
}}
}

Expand Down Expand Up @@ -167,11 +191,21 @@ pub fn get_global_jclass(env: &mut JNIEnv, cls: &str) -> JniResult<JClass<'stati
mod comet_exec;
pub use comet_exec::*;
mod comet_metric_node;
use crate::JAVA_VM;
use crate::{
errors::{CometError, CometResult},
JAVA_VM,
};
pub use comet_metric_node::*;

/// The JVM classes that are used in the JNI calls.
pub struct JVMClasses<'a> {
/// Cached method ID for "java.lang.Object#getClass"
pub object_get_class_method: JMethodID,
/// Cached method ID for "java.lang.Class#getName"
pub class_get_name_method: JMethodID,
/// Cached method ID for "java.lang.Throwable#getMessage"
pub throwable_get_message_method: JMethodID,

/// The CometMetricNode class. Used for updating the metrics.
pub comet_metric_node: CometMetricNode<'a>,
/// The static CometExec class. Used for getting the subquery result.
Expand All @@ -192,7 +226,25 @@ impl JVMClasses<'_> {
// `JNIEnv` except for creating the global references of the classes.
let env = unsafe { std::mem::transmute::<_, &'static mut JNIEnv>(env) };

let clazz = env.find_class("java/lang/Object").unwrap();
let object_get_class_method = env
.get_method_id(clazz, "getClass", "()Ljava/lang/Class;")
.unwrap();

let clazz = env.find_class("java/lang/Class").unwrap();
let class_get_name_method = env
.get_method_id(clazz, "getName", "()Ljava/lang/String;")
.unwrap();

let clazz = env.find_class("java/lang/Throwable").unwrap();
let throwable_get_message_method = env
.get_method_id(clazz, "getMessage", "()Ljava/lang/String;")
.unwrap();

JVMClasses {
object_get_class_method,
class_get_name_method,
throwable_get_message_method,
comet_metric_node: CometMetricNode::new(env).unwrap(),
comet_exec: CometExec::new(env).unwrap(),
}
Expand All @@ -211,3 +263,68 @@ impl JVMClasses<'_> {
}
}
}

pub(crate) fn check_exception(env: &mut JNIEnv) -> CometResult<Option<CometError>> {
let result = if env.exception_check()? {
let exception = env.exception_occurred()?;
env.exception_clear()?;
let exception_err = convert_exception(env, &exception)?;
Some(exception_err)
} else {
None
};

Ok(result)
}

/// Given a `JThrowable` which is thrown from calling a Java method on the native side,
/// this converts it into a `CometError::JavaException` with the exception class name
/// and exception message. This error can then be populated to the JVM side to let
/// users know the cause of the native side error.
pub(crate) fn convert_exception(
env: &mut JNIEnv,
throwable: &JThrowable,
) -> CometResult<CometError> {
unsafe {
let cache = JVMClasses::get();

// get the class name of the exception by:
// 1. get the `Class` object of the input `throwable` via `Object#getClass` method
// 2. get the exception class name via calling `Class#getName` on the above object
let class_obj = env
.call_method_unchecked(
throwable,
cache.object_get_class_method,
ReturnType::Object,
&[],
)?
.l()?;
let exception_class_name = env
.call_method_unchecked(
class_obj,
cache.class_get_name_method,
ReturnType::Object,
&[],
)?
.l()?
.into();
let exception_class_name_str = env.get_string(&exception_class_name)?.into();

// get the exception message via calling `Throwable#getMessage` on the throwable object
let message = env
.call_method_unchecked(
throwable,
cache.throwable_get_message_method,
ReturnType::Object,
&[],
)?
.l()?
.into();
let message_str = env.get_string(&message)?.into();

Ok(CometError::JavaException {
class: exception_class_name_str,
msg: message_str,
})
}
}

0 comments on commit 588cdf7

Please sign in to comment.