diff --git a/core/Cargo.lock b/core/Cargo.lock index 0585d7ec7..9c40b9153 100644 --- a/core/Cargo.lock +++ b/core/Cargo.lock @@ -1087,7 +1087,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a258e46cdc063eb8519c00b9fc845fc47bcfca4130e2f08e88665ceda8474245" dependencies = [ "libc", - "windows-sys", + "windows-sys 0.52.0", ] [[package]] @@ -1336,7 +1336,7 @@ version = "0.5.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e3d1354bf6b7235cb4a0576c2619fd4ed18183f689b12b006a0ee7329eeff9a5" dependencies = [ - "windows-sys", + "windows-sys 0.52.0", ] [[package]] @@ -1436,7 +1436,7 @@ checksum = "0bad00257d07be169d870ab665980b06cdb366d792ad690bf2e76876dc503455" dependencies = [ "hermit-abi", "rustix", - "windows-sys", + "windows-sys 0.52.0", ] [[package]] @@ -1472,18 +1472,32 @@ version = "1.0.10" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b1a46d1a171d865aa5f83f92695765caa047a9b4cbae2cbf37dbd613a793fd4c" +[[package]] +name = "java-locator" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "90003f2fd9c52f212c21d8520f1128da0080bad6fff16b68fe6e7f2f0c3780c2" +dependencies = [ + "glob", + "lazy_static", +] + [[package]] name = "jni" -version = "0.19.0" +version = "0.21.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c6df18c2e3db7e453d3c6ac5b3e9d5182664d28788126d39b91f2d1e22b017ec" +checksum = "1a87aa2bb7d2af34197c04845522473242e1aa17c12f4935d5856491a7fb8c97" dependencies = [ "cesu8", + "cfg-if", "combine", + "java-locator", "jni-sys", + "libloading", "log", "thiserror", "walkdir", + "windows-sys 0.45.0", ] [[package]] @@ -1586,6 +1600,16 @@ version = "0.2.151" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "302d7ab3130588088d277783b1e2d2e10c9e9e4a16dd9050e6ec93fb3e7048f4" +[[package]] +name = "libloading" +version = "0.7.4" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "b67380fd3b2fbe7527a606e18729d21c6f3951633d0500574c4dc22d2d638b9f" +dependencies = [ + "cfg-if", + "winapi", +] + [[package]] name = "libm" version = "0.2.8" @@ -2319,7 +2343,7 @@ dependencies = [ "errno", "libc", "linux-raw-sys", - "windows-sys", + "windows-sys 0.52.0", ] [[package]] @@ -2602,7 +2626,7 @@ dependencies = [ "fastrand", "redox_syscall", "rustix", - "windows-sys", + "windows-sys 0.52.0", ] [[package]] @@ -3009,6 +3033,15 @@ dependencies = [ "windows-targets 0.52.0", ] +[[package]] +name = "windows-sys" +version = "0.45.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "75283be5efb2831d37ea142365f009c02ec203cd29a3ebecbc093d52315b66d0" +dependencies = [ + "windows-targets 0.42.2", +] + [[package]] name = "windows-sys" version = "0.52.0" @@ -3018,6 +3051,21 @@ dependencies = [ "windows-targets 0.52.0", ] +[[package]] +name = "windows-targets" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e5180c00cd44c9b1c88adb3693291f1cd93605ded80c250a75d472756b4d071" +dependencies = [ + "windows_aarch64_gnullvm 0.42.2", + "windows_aarch64_msvc 0.42.2", + "windows_i686_gnu 0.42.2", + "windows_i686_msvc 0.42.2", + "windows_x86_64_gnu 0.42.2", + "windows_x86_64_gnullvm 0.42.2", + "windows_x86_64_msvc 0.42.2", +] + [[package]] name = "windows-targets" version = "0.48.5" @@ -3048,6 +3096,12 @@ dependencies = [ "windows_x86_64_msvc 0.52.0", ] +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "597a5118570b68bc08d8d59125332c54f1ba9d9adeedeef5b99b02ba2b0698f8" + [[package]] name = "windows_aarch64_gnullvm" version = "0.48.5" @@ -3060,6 +3114,12 @@ version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cb7764e35d4db8a7921e09562a0304bf2f93e0a51bfccee0bd0bb0b666b015ea" +[[package]] +name = 
"windows_aarch64_msvc" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e08e8864a60f06ef0d0ff4ba04124db8b0fb3be5776a5cd47641e942e58c4d43" + [[package]] name = "windows_aarch64_msvc" version = "0.48.5" @@ -3072,6 +3132,12 @@ version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bbaa0368d4f1d2aaefc55b6fcfee13f41544ddf36801e793edbbfd7d7df075ef" +[[package]] +name = "windows_i686_gnu" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c61d927d8da41da96a81f029489353e68739737d3beca43145c8afec9a31a84f" + [[package]] name = "windows_i686_gnu" version = "0.48.5" @@ -3084,6 +3150,12 @@ version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a28637cb1fa3560a16915793afb20081aba2c92ee8af57b4d5f28e4b3e7df313" +[[package]] +name = "windows_i686_msvc" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "44d840b6ec649f480a41c8d80f9c65108b92d89345dd94027bfe06ac444d1060" + [[package]] name = "windows_i686_msvc" version = "0.48.5" @@ -3096,6 +3168,12 @@ version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ffe5e8e31046ce6230cc7215707b816e339ff4d4d67c65dffa206fd0f7aa7b9a" +[[package]] +name = "windows_x86_64_gnu" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8de912b8b8feb55c064867cf047dda097f92d51efad5b491dfb98f6bbb70cb36" + [[package]] name = "windows_x86_64_gnu" version = "0.48.5" @@ -3108,6 +3186,12 @@ version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3d6fa32db2bc4a2f5abeacf2b69f7992cd09dca97498da74a151a3132c26befd" +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "26d41b46a36d453748aedef1486d5c7a85db22e56aff34643984ea85514e94a3" 
+ [[package]] name = "windows_x86_64_gnullvm" version = "0.48.5" @@ -3120,6 +3204,12 @@ version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1a657e1e9d3f514745a572a6846d3c7aa7dbe1658c056ed9c3344c4109a6949e" +[[package]] +name = "windows_x86_64_msvc" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9aec5da331524158c6d1a4ac0ab1541149c0b9505fde06423b02f5ef0106b9f0" + [[package]] name = "windows_x86_64_msvc" version = "0.48.5" diff --git a/core/Cargo.toml b/core/Cargo.toml index d27b83366..b4df34d0c 100644 --- a/core/Cargo.toml +++ b/core/Cargo.toml @@ -48,7 +48,7 @@ serde = { version = "1", features = ["derive"] } lazy_static = "1.4.0" prost = "0.12.1" thrift = "0.17" -jni = "0.19" +jni = "0.21" byteorder = "1.4.3" snap = "1.1" brotli = "3.3" @@ -81,7 +81,7 @@ prost-build = "0.9.0" [dev-dependencies] pprof = { version = "0.13.0", features = ["flamegraph"] } criterion = "0.5.1" -jni = { version = "0.19", features = ["invocation"] } +jni = { version = "0.21", features = ["invocation"] } lazy_static = "1.4" assertables = "7" diff --git a/core/src/errors.rs b/core/src/errors.rs index a5f52d377..7188ebd1d 100644 --- a/core/src/errors.rs +++ b/core/src/errors.rs @@ -298,7 +298,7 @@ impl JNIDefault for () { // `RuntimeException` back to the calling Java. Since a return result is required, use `JNIDefault` // to create a reasonable result. This returned default value will be ignored due to the exception. pub fn unwrap_or_throw_default( - env: &JNIEnv, + env: &mut JNIEnv, result: std::result::Result, ) -> T { match result { @@ -314,7 +314,7 @@ pub fn unwrap_or_throw_default( } } -fn throw_exception(env: &JNIEnv, error: &E, backtrace: Option) { +fn throw_exception(env: &mut JNIEnv, error: &E, backtrace: Option) { // If there isn't already an exception? if env.exception_check().is_ok() { // ... 
then throw new exception @@ -380,37 +380,46 @@ fn flatten(result: Result, E>) -> Result { result.and_then(convert::identity) } -// It is currently undefined behavior to unwind from Rust code into foreign code, so we can wrap -// our JNI functions and turn these panics into a `RuntimeException`. -pub fn try_or_throw(env: JNIEnv, f: F) -> T +// Implements "currying" from `FnOnce(T) -> R` to `FnOnce() -> R`, given +// an instance of T. Currying is not supported in Rust so we have to use this +// custom function to achieve something similar here. +fn curry<'a, T: 'a, F, R>(f: F, t: T) -> impl FnOnce() -> R + 'a where - T: JNIDefault, - F: FnOnce() -> T + UnwindSafe, + F: FnOnce(T) -> R + 'a, { - unwrap_or_throw_default(&env, catch_unwind(f).map_err(CometError::from)) + || f(t) } // This is a duplicate of `try_unwrap_or_throw`, which is used to work around Arrow's lack of // `UnwindSafe` handling. -pub fn try_assert_unwind_safe_or_throw(env: JNIEnv, f: F) -> T +pub fn try_assert_unwind_safe_or_throw(env: &JNIEnv, f: F) -> T where T: JNIDefault, - F: FnOnce() -> Result, + F: FnOnce(JNIEnv) -> Result, { + let mut env1 = unsafe { JNIEnv::from_raw(env.get_raw()).unwrap() }; + let env2 = unsafe { JNIEnv::from_raw(env.get_raw()).unwrap() }; unwrap_or_throw_default( - &env, - flatten(catch_unwind(std::panic::AssertUnwindSafe(f)).map_err(CometError::from)), + &mut env1, + flatten( + catch_unwind(std::panic::AssertUnwindSafe(curry(f, env2))).map_err(CometError::from), + ), ) } // It is currently undefined behavior to unwind from Rust code into foreign code, so we can wrap // our JNI functions and turn these panics into a `RuntimeException`. 
-pub fn try_unwrap_or_throw(env: JNIEnv, f: F) -> T +pub fn try_unwrap_or_throw(env: &JNIEnv, f: F) -> T where T: JNIDefault, - F: FnOnce() -> Result + UnwindSafe, + F: FnOnce(JNIEnv) -> Result + UnwindSafe, { - unwrap_or_throw_default(&env, flatten(catch_unwind(f).map_err(CometError::from))) + let mut env1 = unsafe { JNIEnv::from_raw(env.get_raw()).unwrap() }; + let env2 = unsafe { JNIEnv::from_raw(env.get_raw()).unwrap() }; + unwrap_or_throw_default( + &mut env1, + flatten(catch_unwind(curry(f, env2)).map_err(CometError::from)), + ) } #[cfg(test)] @@ -425,7 +434,7 @@ mod tests { }; use jni::{ - objects::{JClass, JObject, JString, JThrowable}, + objects::{JClass, JIntArray, JString, JThrowable}, sys::{jintArray, jstring}, AttachGuard, InitArgsBuilder, JNIEnv, JNIVersion, JavaVM, }; @@ -482,14 +491,14 @@ mod tests { #[test] pub fn error_from_panic() { let _guard = attach_current_thread(); - let env = jvm().get_env().unwrap(); + let mut env = jvm().get_env().unwrap(); - try_or_throw(env, || { + try_unwrap_or_throw(&env, |_| -> CometResult<()> { panic!("oops!"); }); assert_pending_java_exception_detailed( - &env, + &mut env, Some("java/lang/RuntimeException"), Some("oops!"), ); @@ -500,38 +509,16 @@ mod tests { #[test] pub fn object_result() { let _guard = attach_current_thread(); - let env = jvm().get_env().unwrap(); + let mut env = jvm().get_env().unwrap(); let clazz = env.find_class("java/lang/Object").unwrap(); let input = env.new_string("World".to_string()).unwrap(); - let actual = Java_Errors_hello(env, clazz, input); - - let actual_string = String::from(env.get_string(actual.into()).unwrap().to_str().unwrap()); - assert_eq!("Hello, World!", actual_string); - } - - // Verify that functions that return an object can handle throwing exceptions. The test - // causes an exception by passing a `null` where a string value is expected. 
- #[test] - pub fn object_panic_exception() { - let _guard = attach_current_thread(); - let env = jvm().get_env().unwrap(); - // Class java.lang.object is just a stand-in - let class = env.find_class("java/lang/Object").unwrap(); - let input = JString::from(JObject::null()); - let _actual = Java_Errors_hello(env, class, input); - - assert!(env.exception_check().unwrap()); - let exception = env.exception_occurred().expect("Unable to get exception"); - env.exception_clear().unwrap(); + let actual = Java_Errors_hello(&env, clazz, input); + let actual_s = unsafe { JString::from_raw(actual) }; - assert_exception_message_with_stacktrace( - &env, - exception, - "Couldn't get java string!: NullPtr(\"get_string obj argument\")", - "at Java_Errors_hello(", - ); + let actual_string = String::from(env.get_string(&actual_s).unwrap().to_str().unwrap()); + assert_eq!("Hello, World!", actual_string); } // Verify that functions that return an native time are handled correctly. This is basically @@ -539,13 +526,13 @@ mod tests { #[test] pub fn jlong_result() { let _guard = attach_current_thread(); - let env = jvm().get_env().unwrap(); + let mut env = jvm().get_env().unwrap(); // Class java.lang.object is just a stand-in let class = env.find_class("java/lang/Object").unwrap(); let a: jlong = 6; let b: jlong = 3; - let actual = Java_Errors_div(env, class, a, b); + let actual = Java_Errors_div(&env, class, a, b); assert_eq!(2, actual); } @@ -555,16 +542,16 @@ mod tests { #[test] pub fn jlong_panic_exception() { let _guard = attach_current_thread(); - let env = jvm().get_env().unwrap(); + let mut env = jvm().get_env().unwrap(); // Class java.lang.object is just a stand-in let class = env.find_class("java/lang/Object").unwrap(); let a: jlong = 6; let b: jlong = 0; - let _actual = Java_Errors_div(env, class, a, b); + let _actual = Java_Errors_div(&env, class, a, b); assert_pending_java_exception_detailed( - &env, + &mut env, Some("java/lang/RuntimeException"), Some("attempt to divide by 
zero"), ); @@ -575,13 +562,13 @@ mod tests { #[test] pub fn jlong_result_ok() { let _guard = attach_current_thread(); - let env = jvm().get_env().unwrap(); + let mut env = jvm().get_env().unwrap(); // Class java.lang.object is just a stand-in let class = env.find_class("java/lang/Object").unwrap(); let a: JString = env.new_string("9".to_string()).unwrap(); let b: JString = env.new_string("3".to_string()).unwrap(); - let actual = Java_Errors_div_with_parse(env, class, a, b); + let actual = Java_Errors_div_with_parse(&env, class, a, b); assert_eq!(3, actual); } @@ -591,16 +578,16 @@ mod tests { #[test] pub fn jlong_result_err() { let _guard = attach_current_thread(); - let env = jvm().get_env().unwrap(); + let mut env = jvm().get_env().unwrap(); // Class java.lang.object is just a stand-in let class = env.find_class("java/lang/Object").unwrap(); let a: JString = env.new_string("NaN".to_string()).unwrap(); let b: JString = env.new_string("3".to_string()).unwrap(); - let _actual = Java_Errors_div_with_parse(env, class, a, b); + let _actual = Java_Errors_div_with_parse(&env, class, a, b); assert_pending_java_exception_detailed( - &env, + &mut env, Some("java/lang/NumberFormatException"), Some("invalid digit found in string"), ); @@ -611,17 +598,18 @@ mod tests { #[test] pub fn jint_array_result() { let _guard = attach_current_thread(); - let env = jvm().get_env().unwrap(); + let mut env = jvm().get_env().unwrap(); // Class java.lang.object is just a stand-in let class = env.find_class("java/lang/Object").unwrap(); let buf = [2, 4, 6]; let input = env.new_int_array(3).unwrap(); - env.set_int_array_region(input, 0, &buf).unwrap(); - let actual = Java_Errors_array_div(env, class, input, 2); + env.set_int_array_region(&input, 0, &buf).unwrap(); + let actual = Java_Errors_array_div(&env, class, &input, 2); + let actual_s = unsafe { JIntArray::from_raw(actual) }; let mut buf: [i32; 3] = [0; 3]; - env.get_int_array_region(actual, 0, &mut buf).unwrap(); + 
env.get_int_array_region(&actual_s, 0, &mut buf).unwrap(); assert_eq!([1, 2, 3], buf); } @@ -630,17 +618,17 @@ mod tests { #[test] pub fn jint_array_panic_exception() { let _guard = attach_current_thread(); - let env = jvm().get_env().unwrap(); + let mut env = jvm().get_env().unwrap(); // Class java.lang.object is just a stand-in let class = env.find_class("java/lang/Object").unwrap(); let buf = [2, 4, 6]; let input = env.new_int_array(3).unwrap(); - env.set_int_array_region(input, 0, &buf).unwrap(); - let _actual = Java_Errors_array_div(env, class, input, 0); + env.set_int_array_region(&input, 0, &buf).unwrap(); + let _actual = Java_Errors_array_div(&env, class, &input, 0); assert_pending_java_exception_detailed( - &env, + &mut env, Some("java/lang/RuntimeException"), Some("attempt to divide by zero"), ); @@ -683,13 +671,13 @@ mod tests { // * throwing an exception from `.expect()` #[no_mangle] pub extern "system" fn Java_Errors_hello( - env: JNIEnv, + e: &JNIEnv, _class: JClass, input: JString, ) -> jstring { - try_or_throw(env, || { + try_unwrap_or_throw(&e, |mut env| { let input: String = env - .get_string(input) + .get_string(&input) .expect("Couldn't get java string!") .into(); @@ -697,7 +685,7 @@ mod tests { .new_string(format!("Hello, {}!", input)) .expect("Couldn't create java string!"); - output.into_inner() + Ok(output.into_raw()) }) } @@ -706,24 +694,24 @@ mod tests { // * throwing an exception when dividing by zero #[no_mangle] pub extern "system" fn Java_Errors_div( - env: JNIEnv, + env: &JNIEnv, _class: JClass, a: jlong, b: jlong, ) -> jlong { - try_or_throw(env, || a / b) + try_unwrap_or_throw(env, |_| Ok(a / b)) } #[no_mangle] pub extern "system" fn Java_Errors_div_with_parse( - env: JNIEnv, + e: &JNIEnv, _class: JClass, a: JString, b: JString, ) -> jlong { - try_unwrap_or_throw(env, || { - let a_value: i64 = env.get_string(a)?.to_str()?.parse()?; - let b_value: i64 = env.get_string(b)?.to_str()?.parse()?; + try_unwrap_or_throw(e, |mut env| { + let 
a_value: i64 = env.get_string(&a)?.to_str()?.parse()?; + let b_value: i64 = env.get_string(&b)?.to_str()?.parse()?; Ok(a_value / b_value) }) } @@ -733,27 +721,27 @@ mod tests { // * throwing an exception when dividing by zero #[no_mangle] pub extern "system" fn Java_Errors_array_div( - env: JNIEnv, + e: &JNIEnv, _class: JClass, - input: jintArray, + input: &JIntArray, divisor: jint, ) -> jintArray { - try_or_throw(env, || { + try_unwrap_or_throw(e, |env| { let mut input_buf: [jint; 3] = [0; 3]; - env.get_int_array_region(input, 0, &mut input_buf).unwrap(); + env.get_int_array_region(input, 0, &mut input_buf)?; let buf = input_buf.map(|v| -> jint { v / divisor }); - let result = env.new_int_array(3).unwrap(); - env.set_int_array_region(result, 0, &buf).unwrap(); - result + let result = env.new_int_array(3)?; + env.set_int_array_region(&result, 0, &buf)?; + Ok(result.into_raw()) }) } // Helper method that asserts there is a pending Java exception which is an `instance_of` // `expected_type` with a message matching `expected_message` and clears it if any. fn assert_pending_java_exception_detailed( - env: &JNIEnv, + env: &mut JNIEnv, expected_type: Option<&str>, expected_message: Option<&str>, ) { @@ -762,7 +750,7 @@ mod tests { env.exception_clear().unwrap(); if let Some(expected_type) = expected_type { - assert_exception_type(env, exception, expected_type); + assert_exception_type(env, &exception, expected_type); } if let Some(expected_message) = expected_message { @@ -771,7 +759,7 @@ mod tests { } // Asserts that exception is an `instance_of` `expected_type` type. 
- fn assert_exception_type(env: &JNIEnv, exception: JThrowable, expected_type: &str) { + fn assert_exception_type(env: &mut JNIEnv, exception: &JThrowable, expected_type: &str) { if !env.is_instance_of(exception, expected_type).unwrap() { let class: JClass = env.get_object_class(exception).unwrap(); let name = env @@ -779,19 +767,21 @@ mod tests { .unwrap() .l() .unwrap(); - let class_name: String = env.get_string(name.into()).unwrap().into(); + let name_string = name.into(); + let class_name: String = env.get_string(&name_string).unwrap().into(); assert_eq!(class_name.replace('.', "/"), expected_type); }; } // Asserts that exception's message matches `expected_message`. - fn assert_exception_message(env: &JNIEnv, exception: JThrowable, expected_message: &str) { + fn assert_exception_message(env: &mut JNIEnv, exception: JThrowable, expected_message: &str) { let message = env .call_method(exception, "getMessage", "()Ljava/lang/String;", &[]) .unwrap() .l() .unwrap(); - let msg_rust: String = env.get_string(message.into()).unwrap().into(); + let message_string = message.into(); + let msg_rust: String = env.get_string(&message_string).unwrap().into(); println!("{}", msg_rust); // Since panics result in multi-line messages which include the backtrace, just use the // first line. @@ -800,7 +790,7 @@ mod tests { // Asserts that exception's message matches `expected_message`. fn assert_exception_message_with_stacktrace( - env: &JNIEnv, + env: &mut JNIEnv, exception: JThrowable, expected_message: &str, stacktrace_contains: &str, @@ -810,7 +800,8 @@ mod tests { .unwrap() .l() .unwrap(); - let msg_rust: String = env.get_string(message.into()).unwrap().into(); + let message_string = message.into(); + let msg_rust: String = env.get_string(&message_string).unwrap().into(); // Since panics result in multi-line messages which include the backtrace, just use the // first line. 
assert_starts_with!(msg_rust, expected_message); diff --git a/core/src/execution/datafusion/expressions/subquery.rs b/core/src/execution/datafusion/expressions/subquery.rs index a82fb357c..a4b32ba16 100644 --- a/core/src/execution/datafusion/expressions/subquery.rs +++ b/core/src/execution/datafusion/expressions/subquery.rs @@ -20,7 +20,10 @@ use arrow_schema::{DataType, Schema, TimeUnit}; use datafusion::logical_expr::ColumnarValue; use datafusion_common::{internal_err, DataFusionError, ScalarValue}; use datafusion_physical_expr::PhysicalExpr; -use jni::sys::{jboolean, jbyte, jint, jlong, jshort}; +use jni::{ + objects::JByteArray, + sys::{jboolean, jbyte, jint, jlong, jshort}, +}; use std::{ any::Any, fmt::{Display, Formatter}, @@ -87,109 +90,112 @@ impl PhysicalExpr for Subquery { } fn evaluate(&self, _: &RecordBatch) -> datafusion_common::Result { - let env = JVMClasses::get_env(); - - let is_null = - jni_static_call!(env, comet_exec.is_null(self.exec_context_id, self.id) -> jboolean)?; + let mut env = JVMClasses::get_env(); - if is_null > 0 { - return Ok(ColumnarValue::Scalar(ScalarValue::try_from( - &self.data_type, - )?)); - } + unsafe { + let is_null = jni_static_call!(env, + comet_exec.is_null(self.exec_context_id, self.id) -> jboolean + )?; - match &self.data_type { - DataType::Boolean => { - let r = jni_static_call!(env, - comet_exec.get_bool(self.exec_context_id, self.id) -> jboolean - )?; - Ok(ColumnarValue::Scalar(ScalarValue::Boolean(Some(r > 0)))) - } - DataType::Int8 => { - let r = jni_static_call!(env, - comet_exec.get_byte(self.exec_context_id, self.id) -> jbyte - )?; - Ok(ColumnarValue::Scalar(ScalarValue::Int8(Some(r)))) - } - DataType::Int16 => { - let r = jni_static_call!(env, - comet_exec.get_short(self.exec_context_id, self.id) -> jshort - )?; - Ok(ColumnarValue::Scalar(ScalarValue::Int16(Some(r)))) + if is_null > 0 { + return Ok(ColumnarValue::Scalar(ScalarValue::try_from( + &self.data_type, + )?)); } - DataType::Int32 => { - let r = 
jni_static_call!(env, - comet_exec.get_int(self.exec_context_id, self.id) -> jint - )?; - Ok(ColumnarValue::Scalar(ScalarValue::Int32(Some(r)))) - } - DataType::Int64 => { - let r = jni_static_call!(env, - comet_exec.get_long(self.exec_context_id, self.id) -> jlong - )?; - Ok(ColumnarValue::Scalar(ScalarValue::Int64(Some(r)))) - } - DataType::Float32 => { - let r = jni_static_call!(env, - comet_exec.get_float(self.exec_context_id, self.id) -> f32 - )?; - Ok(ColumnarValue::Scalar(ScalarValue::Float32(Some(r)))) - } - DataType::Float64 => { - let r = jni_static_call!(env, - comet_exec.get_double(self.exec_context_id, self.id) -> f64 - )?; - - Ok(ColumnarValue::Scalar(ScalarValue::Float64(Some(r)))) - } - DataType::Decimal128(p, s) => { - let bytes = jni_static_call!(env, - comet_exec.get_decimal(self.exec_context_id, self.id) -> BinaryWrapper - )?; - - let slice = env.convert_byte_array((*bytes.get()).into_inner()).unwrap(); - - Ok(ColumnarValue::Scalar(ScalarValue::Decimal128( - Some(bytes_to_i128(&slice)), - *p, - *s, - ))) - } - DataType::Date32 => { - let r = jni_static_call!(env, - comet_exec.get_int(self.exec_context_id, self.id) -> jint - )?; - - Ok(ColumnarValue::Scalar(ScalarValue::Date32(Some(r)))) - } - DataType::Timestamp(TimeUnit::Microsecond, timezone) => { - let r = jni_static_call!(env, - comet_exec.get_long(self.exec_context_id, self.id) -> jlong - )?; - - Ok(ColumnarValue::Scalar(ScalarValue::TimestampMicrosecond( - Some(r), - timezone.clone(), - ))) - } - DataType::Utf8 => { - let string = jni_static_call!(env, - comet_exec.get_string(self.exec_context_id, self.id) -> StringWrapper - )?; - - let string = env.get_string(*string.get()).unwrap().into(); - Ok(ColumnarValue::Scalar(ScalarValue::Utf8(Some(string)))) - } - DataType::Binary => { - let bytes = jni_static_call!(env, - comet_exec.get_binary(self.exec_context_id, self.id) -> BinaryWrapper - )?; - - let slice = env.convert_byte_array((*bytes.get()).into_inner()).unwrap(); - 
Ok(ColumnarValue::Scalar(ScalarValue::Binary(Some(slice)))) + match &self.data_type { + DataType::Boolean => { + let r = jni_static_call!(env, + comet_exec.get_bool(self.exec_context_id, self.id) -> jboolean + )?; + Ok(ColumnarValue::Scalar(ScalarValue::Boolean(Some(r > 0)))) + } + DataType::Int8 => { + let r = jni_static_call!(env, + comet_exec.get_byte(self.exec_context_id, self.id) -> jbyte + )?; + Ok(ColumnarValue::Scalar(ScalarValue::Int8(Some(r)))) + } + DataType::Int16 => { + let r = jni_static_call!(env, + comet_exec.get_short(self.exec_context_id, self.id) -> jshort + )?; + Ok(ColumnarValue::Scalar(ScalarValue::Int16(Some(r)))) + } + DataType::Int32 => { + let r = jni_static_call!(env, + comet_exec.get_int(self.exec_context_id, self.id) -> jint + )?; + Ok(ColumnarValue::Scalar(ScalarValue::Int32(Some(r)))) + } + DataType::Int64 => { + let r = jni_static_call!(env, + comet_exec.get_long(self.exec_context_id, self.id) -> jlong + )?; + Ok(ColumnarValue::Scalar(ScalarValue::Int64(Some(r)))) + } + DataType::Float32 => { + let r = jni_static_call!(env, + comet_exec.get_float(self.exec_context_id, self.id) -> f32 + )?; + Ok(ColumnarValue::Scalar(ScalarValue::Float32(Some(r)))) + } + DataType::Float64 => { + let r = jni_static_call!(env, + comet_exec.get_double(self.exec_context_id, self.id) -> f64 + )?; + + Ok(ColumnarValue::Scalar(ScalarValue::Float64(Some(r)))) + } + DataType::Decimal128(p, s) => { + let bytes = jni_static_call!(env, + comet_exec.get_decimal(self.exec_context_id, self.id) -> BinaryWrapper + )?; + let bytes: &JByteArray = bytes.get().into(); + let slice = env.convert_byte_array(bytes).unwrap(); + + Ok(ColumnarValue::Scalar(ScalarValue::Decimal128( + Some(bytes_to_i128(&slice)), + *p, + *s, + ))) + } + DataType::Date32 => { + let r = jni_static_call!(env, + comet_exec.get_int(self.exec_context_id, self.id) -> jint + )?; + + Ok(ColumnarValue::Scalar(ScalarValue::Date32(Some(r)))) + } + DataType::Timestamp(TimeUnit::Microsecond, timezone) => { + 
let r = jni_static_call!(env, + comet_exec.get_long(self.exec_context_id, self.id) -> jlong + )?; + + Ok(ColumnarValue::Scalar(ScalarValue::TimestampMicrosecond( + Some(r), + timezone.clone(), + ))) + } + DataType::Utf8 => { + let string = jni_static_call!(env, + comet_exec.get_string(self.exec_context_id, self.id) -> StringWrapper + )?; + + let string = env.get_string(string.get()).unwrap().into(); + Ok(ColumnarValue::Scalar(ScalarValue::Utf8(Some(string)))) + } + DataType::Binary => { + let bytes = jni_static_call!(env, + comet_exec.get_binary(self.exec_context_id, self.id) -> BinaryWrapper + )?; + let bytes: &JByteArray = bytes.get().into(); + let slice = env.convert_byte_array(bytes).unwrap(); + + Ok(ColumnarValue::Scalar(ScalarValue::Binary(Some(slice)))) + } + _ => internal_err!("Unsupported scalar subquery data type {:?}", self.data_type), } - _ => internal_err!("Unsupported scalar subquery data type {:?}", self.data_type), } } diff --git a/core/src/execution/jni_api.rs b/core/src/execution/jni_api.rs index 9981cece3..831f78838 100644 --- a/core/src/execution/jni_api.rs +++ b/core/src/execution/jni_api.rs @@ -36,7 +36,10 @@ use datafusion_common::DataFusionError; use futures::poll; use jni::{ errors::Result as JNIResult, - objects::{JClass, JMap, JObject, JString, ReleaseMode}, + objects::{ + AutoElements, JBooleanArray, JByteArray, JClass, JIntArray, JLongArray, JMap, JObject, + JObjectArray, JPrimitiveArray, JString, ReleaseMode, + }, sys::{jbyteArray, jint, jlong, jlongArray}, JNIEnv, }; @@ -45,7 +48,7 @@ use std::{collections::HashMap, sync::Arc, task::Poll}; use super::{serde, utils::SparkArrowConvert}; use crate::{ - errors::{try_unwrap_or_throw, CometError}, + errors::{try_unwrap_or_throw, CometError, CometResult}, execution::{ datafusion::planner::PhysicalPlanner, metrics::utils::update_comet_metric, serde::to_arrow_datatype, shuffle::row::process_sorted_row_partition, sort::RdxSort, @@ -55,7 +58,7 @@ use crate::{ }; use futures::stream::StreamExt; 
use jni::{ - objects::{AutoArray, GlobalRef}, + objects::GlobalRef, sys::{jboolean, jbooleanArray, jdouble, jintArray, jobjectArray, jstring}, }; use tokio::runtime::Runtime; @@ -88,21 +91,24 @@ struct ExecutionContext { pub debug_native: bool, } -#[no_mangle] /// Accept serialized query plan and return the address of the native query plan. -pub extern "system" fn Java_org_apache_comet_Native_createPlan( - env: JNIEnv, +/// # Safety +/// This function is inheritly unsafe since it deals with raw pointers passed from JNI. +#[no_mangle] +pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan( + e: JNIEnv, _class: JClass, id: jlong, config_object: JObject, serialized_query: jbyteArray, metrics_node: JObject, ) -> jlong { - try_unwrap_or_throw(env, || { + try_unwrap_or_throw(&e, |mut env| { // Init JVM classes - JVMClasses::init(&env); + JVMClasses::init(&mut env); - let bytes = env.convert_byte_array(serialized_query)?; + let array = unsafe { JPrimitiveArray::from_raw(serialized_query) }; + let bytes = env.convert_byte_array(array)?; // Deserialize query plan let spark_plan = serde::deserialize_op(bytes.as_slice())?; @@ -110,13 +116,13 @@ pub extern "system" fn Java_org_apache_comet_Native_createPlan( // Sets up context let mut configs = HashMap::new(); - let config_map = JMap::from_env(&env, config_object)?; - config_map.iter()?.for_each(|config| { - let key: String = env.get_string(JString::from(config.0)).unwrap().into(); - let value: String = env.get_string(JString::from(config.1)).unwrap().into(); - + let config_map = JMap::from_env(&mut env, &config_object)?; + let mut map_iter = config_map.iter(&mut env)?; + while let Some((key, value)) = map_iter.next(&mut env)? 
{ + let key: String = env.get_string(&JString::from(key)).unwrap().into(); + let value: String = env.get_string(&JString::from(value)).unwrap().into(); configs.insert(key, value); - }); + } // Whether we've enabled additional debugging on the native side let debug_native = configs @@ -157,8 +163,8 @@ pub extern "system" fn Java_org_apache_comet_Native_createPlan( /// Parse Comet configs and configure DataFusion session context. fn prepare_datafusion_session_context( conf: &HashMap, -) -> Result { - // Get the batch size from Comet JVM side +) -> CometResult { + // Get the batch size from Boson JVM side let batch_size = conf .get("batch_size") .ok_or(CometError::Internal( @@ -205,10 +211,10 @@ fn prepare_datafusion_session_context( /// Prepares arrow arrays for output. fn prepare_output( + env: &mut JNIEnv, output: Result, - env: JNIEnv, exec_context: &mut ExecutionContext, -) -> Result { +) -> CometResult { let output_batch = output?; let results = output_batch.columns(); let num_rows = output_batch.num_rows(); @@ -226,7 +232,7 @@ fn prepare_output( let return_flag = 1; let long_array = env.new_long_array((results.len() * 2) as i32 + 2)?; - env.set_long_array_region(long_array, 0, &[return_flag, num_rows as jlong])?; + env.set_long_array_region(&long_array, 0, &[return_flag, num_rows as jlong])?; let mut arrays = vec![]; @@ -241,48 +247,61 @@ fn prepare_output( arrays.push((arrow_array, arrow_schema)); } - env.set_long_array_region(long_array, (i * 2) as i32 + 2, &[array, schema])?; + env.set_long_array_region(&long_array, (i * 2) as i32 + 2, &[array, schema])?; i += 1; } // Update metrics - update_metrics(&env, exec_context)?; + update_metrics(env, exec_context)?; // Record the pointer to allocated Arrow Arrays exec_context.ffi_arrays = arrays; - Ok(long_array) + Ok(long_array.into_raw()) } -#[no_mangle] /// Accept serialized query plan and the addresses of Arrow Arrays from Spark, /// then execute the query. Return addresses of arrow vector. 
-pub extern "system" fn Java_org_apache_comet_Native_executePlan( - env: JNIEnv, +/// # Safety +/// This function is inheritly unsafe since it deals with raw pointers passed from JNI. +#[no_mangle] +pub unsafe extern "system" fn Java_org_apache_comet_Native_executePlan( + e: JNIEnv, _class: JClass, exec_context: jlong, addresses_array: jobjectArray, finishes: jbooleanArray, batch_rows: jint, ) -> jlongArray { - try_unwrap_or_throw(env, || { - let addresses_vec = convert_addresses_arrays(&env, addresses_array)?; - let mut all_inputs: Vec> = Vec::with_capacity(addresses_vec.len()); - + try_unwrap_or_throw(&e, |mut env| unsafe { let exec_context = get_execution_context(exec_context); - for addresses in addresses_vec.iter() { + + let addresses = JObjectArray::from_raw(addresses_array); + let num_addresses = env.get_array_length(&addresses)? as usize; + + let mut all_inputs: Vec> = Vec::with_capacity(num_addresses); + + for i in 0..num_addresses { let mut inputs: Vec = vec![]; - let array_num = addresses.size()? as usize; - assert_eq!(array_num % 2, 0, "Arrow Array addresses are invalid!"); + let inner_addresses = env.get_object_array_element(&addresses, i as i32)?.into(); + let inner_address_array: AutoElements = + env.get_array_elements(&inner_addresses, ReleaseMode::NoCopyBack)?; - let num_arrays = array_num / 2; - let array_elements = addresses.as_ptr(); + let num_inner_address = inner_address_array.len(); + assert_eq!( + num_inner_address % 2, + 0, + "Arrow Array addresses are invalid!" 
+ ); + + let num_arrays = num_inner_address / 2; + let array_elements = inner_address_array.as_ptr(); let mut i: usize = 0; while i < num_arrays { - let array_ptr = unsafe { *(array_elements.add(i * 2)) }; - let schema_ptr = unsafe { *(array_elements.add(i * 2 + 1)) }; + let array_ptr = *(array_elements.add(i * 2)); + let schema_ptr = *(array_elements.add(i * 2 + 1)); let array_data = ArrayData::from_spark((array_ptr, schema_ptr))?; if exec_context.debug_native { @@ -298,7 +317,8 @@ pub extern "system" fn Java_org_apache_comet_Native_executePlan( } // Prepares the input batches. - let eofs = env.get_boolean_array_elements(finishes, ReleaseMode::NoCopyBack)?; + let array = JBooleanArray::from_raw(finishes); + let eofs = env.get_array_elements(&array, ReleaseMode::NoCopyBack)?; let eof_flags = eofs.as_ptr(); // Whether reaching the end of input batches. @@ -306,7 +326,7 @@ pub extern "system" fn Java_org_apache_comet_Native_executePlan( let mut input_batches = all_inputs .into_iter() .enumerate() - .map(|(idx, inputs)| unsafe { + .map(|(idx, inputs)| { let eof = eof_flags.add(idx); if *eof == 1 { @@ -364,25 +384,25 @@ pub extern "system" fn Java_org_apache_comet_Native_executePlan( match poll_output { Poll::Ready(Some(output)) => { - return prepare_output(output, env, exec_context); + return prepare_output(&mut env, output, exec_context); } Poll::Ready(None) => { // Reaches EOF of output. // Update metrics - update_metrics(&env, exec_context)?; + update_metrics(&mut env, exec_context)?; let long_array = env.new_long_array(1)?; - env.set_long_array_region(long_array, 0, &[-1])?; + env.set_long_array_region(&long_array, 0, &[-1])?; - return Ok(long_array); + return Ok(long_array.into_raw()); } - // After reaching the end of any input, a poll pending means there are more than one - // blocking operators, we don't need go back-forth between JVM/Native. Just - // keeping polling. 
+ // After reaching the end of any input, a poll pending means there are more than + // one blocking operators, we don't need go back-forth + // between JVM/Native. Just keeping polling. Poll::Pending if finished => { // Update metrics - update_metrics(&env, exec_context)?; + update_metrics(&mut env, exec_context)?; // Output not ready yet continue; @@ -391,7 +411,7 @@ pub extern "system" fn Java_org_apache_comet_Native_executePlan( // operators. Just returning to keep reading next input. Poll::Pending => { // Update metrics - update_metrics(&env, exec_context)?; + update_metrics(&mut env, exec_context)?; return return_pending(env); } } @@ -401,19 +421,18 @@ pub extern "system" fn Java_org_apache_comet_Native_executePlan( fn return_pending(env: JNIEnv) -> Result { let long_array = env.new_long_array(1)?; - env.set_long_array_region(long_array, 0, &[0])?; - - Ok(long_array) + env.set_long_array_region(&long_array, 0, &[0])?; + Ok(long_array.into_raw()) } #[no_mangle] /// Peeks into next output if any. pub extern "system" fn Java_org_apache_comet_Native_peekNext( - env: JNIEnv, + e: JNIEnv, _class: JClass, exec_context: jlong, ) -> jlongArray { - try_unwrap_or_throw(env, || { + try_unwrap_or_throw(&e, |mut env| { // Retrieve the query let exec_context = get_execution_context(exec_context); @@ -427,10 +446,10 @@ pub extern "system" fn Java_org_apache_comet_Native_peekNext( let poll_output = exec_context.runtime.block_on(async { poll!(next_item) }); match poll_output { - Poll::Ready(Some(output)) => prepare_output(output, env, exec_context), + Poll::Ready(Some(output)) => prepare_output(&mut env, output, exec_context), _ => { // Update metrics - update_metrics(&env, exec_context)?; + update_metrics(&mut env, exec_context)?; return_pending(env) } } @@ -440,11 +459,11 @@ pub extern "system" fn Java_org_apache_comet_Native_peekNext( #[no_mangle] /// Drop the native query plan object and context object. 
pub extern "system" fn Java_org_apache_comet_Native_releasePlan( - env: JNIEnv, + e: JNIEnv, _class: JClass, exec_context: jlong, ) { - try_unwrap_or_throw(env, || unsafe { + try_unwrap_or_throw(&e, |_| unsafe { let execution_context = get_execution_context(exec_context); let _: Box = Box::from_raw(execution_context); Ok(()) @@ -452,51 +471,32 @@ pub extern "system" fn Java_org_apache_comet_Native_releasePlan( } /// Updates the metrics of the query plan. -fn update_metrics(env: &JNIEnv, exec_context: &ExecutionContext) -> Result<(), CometError> { +fn update_metrics(env: &mut JNIEnv, exec_context: &ExecutionContext) -> CometResult<()> { let native_query = exec_context.root_op.as_ref().unwrap(); let metrics = exec_context.metrics.as_obj(); update_comet_metric(env, metrics, native_query) } -/// Converts a Java array of address arrays to a Rust vector of address arrays. -fn convert_addresses_arrays<'a>( - env: &'a JNIEnv<'a>, - addresses_array: jobjectArray, -) -> JNIResult>> { - let array_len = env.get_array_length(addresses_array)?; - let mut res: Vec> = Vec::new(); - - for i in 0..array_len { - let array: AutoArray = env.get_array_elements( - env.get_object_array_element(addresses_array, i)? - .into_inner() as jlongArray, - ReleaseMode::NoCopyBack, - )?; - res.push(array); - } - - Ok(res) -} - fn convert_datatype_arrays( - env: &'_ JNIEnv<'_>, + env: &'_ mut JNIEnv<'_>, serialized_datatypes: jobjectArray, ) -> JNIResult> { - let array_len = env.get_array_length(serialized_datatypes)?; - let mut res: Vec = Vec::new(); - - for i in 0..array_len { - let array = env - .get_object_array_element(serialized_datatypes, i)? 
- .into_inner() as jbyteArray; + unsafe { + let obj_array = JObjectArray::from_raw(serialized_datatypes); + let array_len = env.get_array_length(&obj_array)?; + let mut res: Vec = Vec::new(); + + for i in 0..array_len { + let inner_array = env.get_object_array_element(&obj_array, i)?; + let inner_array: JByteArray = inner_array.into(); + let bytes = env.convert_byte_array(inner_array)?; + let data_type = serde::deserialize_data_type(bytes.as_slice()).unwrap(); + let arrow_dt = to_arrow_datatype(&data_type); + res.push(arrow_dt); + } - let bytes = env.convert_byte_array(array)?; - let data_type = serde::deserialize_data_type(bytes.as_slice()).unwrap(); - let arrow_dt = to_arrow_datatype(&data_type); - res.push(arrow_dt); + Ok(res) } - - Ok(res) } fn get_execution_context<'a>(id: i64) -> &'a mut ExecutionContext { @@ -507,10 +507,12 @@ fn get_execution_context<'a>(id: i64) -> &'a mut ExecutionContext { } } +/// Used by Boson shuffle external sorter to write sorted records to disk. +/// # Safety +/// This function is inheritly unsafe since it deals with raw pointers passed from JNI. #[no_mangle] -/// Used by Comet shuffle external sorter to write sorted records to disk. -pub extern "system" fn Java_org_apache_comet_Native_writeSortedFileNative( - env: JNIEnv, +pub unsafe extern "system" fn Java_org_apache_comet_Native_writeSortedFileNative( + e: JNIEnv, _class: JClass, row_addresses: jlongArray, row_sizes: jintArray, @@ -521,18 +523,23 @@ pub extern "system" fn Java_org_apache_comet_Native_writeSortedFileNative( checksum_algo: jint, current_checksum: jlong, ) -> jlongArray { - try_unwrap_or_throw(env, || { - let row_num = env.get_array_length(row_addresses)? 
as usize; + try_unwrap_or_throw(&e, |mut env| unsafe { + let data_types = convert_datatype_arrays(&mut env, serialized_datatypes)?; - let data_types = convert_datatype_arrays(&env, serialized_datatypes)?; + let row_address_array = JLongArray::from_raw(row_addresses); + let row_num = env.get_array_length(&row_address_array)? as usize; + let row_addresses = env.get_array_elements(&row_address_array, ReleaseMode::NoCopyBack)?; - let row_addresses = env.get_long_array_elements(row_addresses, ReleaseMode::NoCopyBack)?; - let row_sizes = env.get_int_array_elements(row_sizes, ReleaseMode::NoCopyBack)?; + let row_size_array = JIntArray::from_raw(row_sizes); + let row_sizes = env.get_array_elements(&row_size_array, ReleaseMode::NoCopyBack)?; let row_addresses_ptr = row_addresses.as_ptr(); let row_sizes_ptr = row_sizes.as_ptr(); - let output_path: String = env.get_string(JString::from(file_path)).unwrap().into(); + let output_path: String = env + .get_string(&JString::from_raw(file_path)) + .unwrap() + .into(); let checksum_enabled = checksum_enabled == 1; let current_checksum = if current_checksum == i64::MIN { @@ -563,21 +570,21 @@ pub extern "system" fn Java_org_apache_comet_Native_writeSortedFileNative( }; let long_array = env.new_long_array(2)?; - env.set_long_array_region(long_array, 0, &[written_bytes, checksum])?; + env.set_long_array_region(&long_array, 0, &[written_bytes, checksum])?; - Ok(long_array) + Ok(long_array.into_raw()) }) } #[no_mangle] -/// Used by Comet shuffle external sorter to sort in-memory row partition ids. +/// Used by Boson shuffle external sorter to sort in-memory row partition ids. pub extern "system" fn Java_org_apache_comet_Native_sortRowPartitionsNative( - env: JNIEnv, + e: JNIEnv, _class: JClass, address: jlong, size: jlong, ) { - try_unwrap_or_throw(env, || { + try_unwrap_or_throw(&e, |_| { // SAFETY: JVM unsafe memory allocation is aligned with long. 
let array = unsafe { std::slice::from_raw_parts_mut(address as *mut i64, size as usize) }; array.rdxsort(); diff --git a/core/src/execution/metrics/utils.rs b/core/src/execution/metrics/utils.rs index eb36a5562..6990aa54f 100644 --- a/core/src/execution/metrics/utils.rs +++ b/core/src/execution/metrics/utils.rs @@ -27,8 +27,8 @@ use std::sync::Arc; /// update the metrics of all the children nodes. The metrics are pulled from the /// DataFusion execution plan and pushed to the Java side through JNI. pub fn update_comet_metric( - env: &JNIEnv, - metric_node: JObject, + env: &mut JNIEnv, + metric_node: &JObject, execution_plan: &Arc, ) -> Result<(), CometError> { update_metrics( @@ -43,27 +43,31 @@ pub fn update_comet_metric( .collect::>(), )?; - for (i, child_plan) in execution_plan.children().iter().enumerate() { - let child_metric_node: JObject = jni_call!(env, - comet_metric_node(metric_node).get_child_node(i as i32) -> JObject - )?; - if child_metric_node.is_null() { - continue; + unsafe { + for (i, child_plan) in execution_plan.children().iter().enumerate() { + let child_metric_node: JObject = jni_call!(env, + comet_metric_node(metric_node).get_child_node(i as i32) -> JObject + )?; + if child_metric_node.is_null() { + continue; + } + update_comet_metric(env, &child_metric_node, child_plan)?; } - update_comet_metric(env, child_metric_node, child_plan)?; } Ok(()) } #[inline] fn update_metrics( - env: &JNIEnv, - metric_node: JObject, + env: &mut JNIEnv, + metric_node: &JObject, metric_values: &[(&str, i64)], ) -> Result<(), CometError> { - for &(name, value) in metric_values { - let jname = jni_new_string!(env, &name)?; - jni_call!(env, comet_metric_node(metric_node).add(jname, value) -> ())?; + unsafe { + for &(name, value) in metric_values { + let jname = jni_new_string!(env, &name)?; + jni_call!(env, comet_metric_node(metric_node).add(&jname, value) -> ())?; + } } Ok(()) } diff --git a/core/src/jvm_bridge/comet_exec.rs b/core/src/jvm_bridge/comet_exec.rs index 
e28fc080f..6b6652eb4 100644 --- a/core/src/jvm_bridge/comet_exec.rs +++ b/core/src/jvm_bridge/comet_exec.rs @@ -18,7 +18,7 @@ use jni::{ errors::Result as JniResult, objects::{JClass, JStaticMethodID}, - signature::{JavaType, Primitive}, + signature::{Primitive, ReturnType}, JNIEnv, }; @@ -27,75 +27,83 @@ use super::get_global_jclass; /// A struct that holds all the JNI methods and fields for JVM CometExec object. pub struct CometExec<'a> { pub class: JClass<'a>, - pub method_get_bool: JStaticMethodID<'a>, - pub method_get_bool_ret: JavaType, - pub method_get_byte: JStaticMethodID<'a>, - pub method_get_byte_ret: JavaType, - pub method_get_short: JStaticMethodID<'a>, - pub method_get_short_ret: JavaType, - pub method_get_int: JStaticMethodID<'a>, - pub method_get_int_ret: JavaType, - pub method_get_long: JStaticMethodID<'a>, - pub method_get_long_ret: JavaType, - pub method_get_float: JStaticMethodID<'a>, - pub method_get_float_ret: JavaType, - pub method_get_double: JStaticMethodID<'a>, - pub method_get_double_ret: JavaType, - pub method_get_decimal: JStaticMethodID<'a>, - pub method_get_decimal_ret: JavaType, - pub method_get_string: JStaticMethodID<'a>, - pub method_get_string_ret: JavaType, - pub method_get_binary: JStaticMethodID<'a>, - pub method_get_binary_ret: JavaType, - pub method_is_null: JStaticMethodID<'a>, - pub method_is_null_ret: JavaType, + pub method_get_bool: JStaticMethodID, + pub method_get_bool_ret: ReturnType, + pub method_get_byte: JStaticMethodID, + pub method_get_byte_ret: ReturnType, + pub method_get_short: JStaticMethodID, + pub method_get_short_ret: ReturnType, + pub method_get_int: JStaticMethodID, + pub method_get_int_ret: ReturnType, + pub method_get_long: JStaticMethodID, + pub method_get_long_ret: ReturnType, + pub method_get_float: JStaticMethodID, + pub method_get_float_ret: ReturnType, + pub method_get_double: JStaticMethodID, + pub method_get_double_ret: ReturnType, + pub method_get_decimal: JStaticMethodID, + pub 
method_get_decimal_ret: ReturnType, + pub method_get_string: JStaticMethodID, + pub method_get_string_ret: ReturnType, + pub method_get_binary: JStaticMethodID, + pub method_get_binary_ret: ReturnType, + pub method_is_null: JStaticMethodID, + pub method_is_null_ret: ReturnType, } impl<'a> CometExec<'a> { pub const JVM_CLASS: &'static str = "org/apache/spark/sql/comet/CometScalarSubquery"; - pub fn new(env: &JNIEnv<'a>) -> JniResult> { + pub fn new(env: &mut JNIEnv<'a>) -> JniResult> { // Get the global class reference let class = get_global_jclass(env, Self::JVM_CLASS)?; Ok(CometExec { - class, method_get_bool: env - .get_static_method_id(class, "getBoolean", "(JJ)Z") + .get_static_method_id(Self::JVM_CLASS, "getBoolean", "(JJ)Z") + .unwrap(), + method_get_bool_ret: ReturnType::Primitive(Primitive::Boolean), + method_get_byte: env + .get_static_method_id(Self::JVM_CLASS, "getByte", "(JJ)B") .unwrap(), - method_get_bool_ret: JavaType::Primitive(Primitive::Boolean), - method_get_byte: env.get_static_method_id(class, "getByte", "(JJ)B").unwrap(), - method_get_byte_ret: JavaType::Primitive(Primitive::Byte), + method_get_byte_ret: ReturnType::Primitive(Primitive::Byte), method_get_short: env - .get_static_method_id(class, "getShort", "(JJ)S") + .get_static_method_id(Self::JVM_CLASS, "getShort", "(JJ)S") + .unwrap(), + method_get_short_ret: ReturnType::Primitive(Primitive::Short), + method_get_int: env + .get_static_method_id(Self::JVM_CLASS, "getInt", "(JJ)I") .unwrap(), - method_get_short_ret: JavaType::Primitive(Primitive::Short), - method_get_int: env.get_static_method_id(class, "getInt", "(JJ)I").unwrap(), - method_get_int_ret: JavaType::Primitive(Primitive::Int), - method_get_long: env.get_static_method_id(class, "getLong", "(JJ)J").unwrap(), - method_get_long_ret: JavaType::Primitive(Primitive::Long), + method_get_int_ret: ReturnType::Primitive(Primitive::Int), + method_get_long: env + .get_static_method_id(Self::JVM_CLASS, "getLong", "(JJ)J") + .unwrap(), + 
method_get_long_ret: ReturnType::Primitive(Primitive::Long), method_get_float: env - .get_static_method_id(class, "getFloat", "(JJ)F") + .get_static_method_id(Self::JVM_CLASS, "getFloat", "(JJ)F") .unwrap(), - method_get_float_ret: JavaType::Primitive(Primitive::Float), + method_get_float_ret: ReturnType::Primitive(Primitive::Float), method_get_double: env - .get_static_method_id(class, "getDouble", "(JJ)D") + .get_static_method_id(Self::JVM_CLASS, "getDouble", "(JJ)D") .unwrap(), - method_get_double_ret: JavaType::Primitive(Primitive::Double), + method_get_double_ret: ReturnType::Primitive(Primitive::Double), method_get_decimal: env - .get_static_method_id(class, "getDecimal", "(JJ)[B") + .get_static_method_id(Self::JVM_CLASS, "getDecimal", "(JJ)[B") .unwrap(), - method_get_decimal_ret: JavaType::Array(Box::new(JavaType::Primitive(Primitive::Byte))), + method_get_decimal_ret: ReturnType::Array, method_get_string: env - .get_static_method_id(class, "getString", "(JJ)Ljava/lang/String;") + .get_static_method_id(Self::JVM_CLASS, "getString", "(JJ)Ljava/lang/String;") .unwrap(), - method_get_string_ret: JavaType::Object("java/lang/String".to_owned()), + method_get_string_ret: ReturnType::Object, method_get_binary: env - .get_static_method_id(class, "getBinary", "(JJ)[B") + .get_static_method_id(Self::JVM_CLASS, "getBinary", "(JJ)[B") + .unwrap(), + method_get_binary_ret: ReturnType::Array, + method_is_null: env + .get_static_method_id(Self::JVM_CLASS, "isNull", "(JJ)Z") .unwrap(), - method_get_binary_ret: JavaType::Array(Box::new(JavaType::Primitive(Primitive::Byte))), - method_is_null: env.get_static_method_id(class, "isNull", "(JJ)Z").unwrap(), - method_is_null_ret: JavaType::Primitive(Primitive::Boolean), + method_is_null_ret: ReturnType::Primitive(Primitive::Boolean), + class, }) } } diff --git a/core/src/jvm_bridge/comet_metric_node.rs b/core/src/jvm_bridge/comet_metric_node.rs index 1d4928a09..d0176f427 100644 --- a/core/src/jvm_bridge/comet_metric_node.rs +++ 
b/core/src/jvm_bridge/comet_metric_node.rs @@ -18,7 +18,7 @@ use jni::{ errors::Result as JniResult, objects::{JClass, JMethodID}, - signature::{JavaType, Primitive}, + signature::{Primitive, ReturnType}, JNIEnv, }; @@ -27,33 +27,33 @@ use super::get_global_jclass; /// A struct that holds all the JNI methods and fields for JVM CometMetricNode class. pub struct CometMetricNode<'a> { pub class: JClass<'a>, - pub method_get_child_node: JMethodID<'a>, - pub method_get_child_node_ret: JavaType, - pub method_add: JMethodID<'a>, - pub method_add_ret: JavaType, + pub method_get_child_node: JMethodID, + pub method_get_child_node_ret: ReturnType, + pub method_add: JMethodID, + pub method_add_ret: ReturnType, } impl<'a> CometMetricNode<'a> { pub const JVM_CLASS: &'static str = "org/apache/spark/sql/comet/CometMetricNode"; - pub fn new(env: &JNIEnv<'a>) -> JniResult> { + pub fn new(env: &mut JNIEnv<'a>) -> JniResult> { // Get the global class reference let class = get_global_jclass(env, Self::JVM_CLASS)?; Ok(CometMetricNode { - class, method_get_child_node: env .get_method_id( - class, + Self::JVM_CLASS, "getChildNode", format!("(I)L{:};", Self::JVM_CLASS).as_str(), ) .unwrap(), - method_get_child_node_ret: JavaType::Object(Self::JVM_CLASS.to_owned()), + method_get_child_node_ret: ReturnType::Object, method_add: env - .get_method_id(class, "add", "(Ljava/lang/String;J)V") + .get_method_id(Self::JVM_CLASS, "add", "(Ljava/lang/String;J)V") .unwrap(), - method_add_ret: JavaType::Primitive(Primitive::Void), + method_add_ret: ReturnType::Primitive(Primitive::Void), + class, }) } } diff --git a/core/src/jvm_bridge/mod.rs b/core/src/jvm_bridge/mod.rs index 6f162a0ea..331e7768d 100644 --- a/core/src/jvm_bridge/mod.rs +++ b/core/src/jvm_bridge/mod.rs @@ -19,7 +19,7 @@ use jni::{ errors::{Error, Result as JniResult}, - objects::{JClass, JObject, JString, JValue}, + objects::{JClass, JObject, JString, JValueGen, JValueOwned}, AttachGuard, JNIEnv, }; use once_cell::sync::OnceCell; @@ 
-38,7 +38,7 @@ macro_rules! jni_map_error { /// Macro for converting Rust types to JNI types. macro_rules! jvalues { ($($args:expr,)* $(,)?) => {{ - &[$(jni::objects::JValue::from($args)),*] as &[jni::objects::JValue] + &[$(jni::objects::JValue::from($args).as_jni()),*] as &[jni::sys::jvalue] }} } @@ -75,7 +75,7 @@ macro_rules! jni_static_call { $crate::jvm_bridge::jni_map_error!( $env, $env.call_static_method_unchecked( - paste::paste! {$crate::jvm_bridge::JVMClasses::get().[<$clsname>].[]}, + &paste::paste! {$crate::jvm_bridge::JVMClasses::get().[<$clsname>].[]}, paste::paste! {$crate::jvm_bridge::JVMClasses::get().[<$clsname>].[]}, paste::paste! {$crate::jvm_bridge::JVMClasses::get().[<$clsname>].[]}.clone(), $crate::jvm_bridge::jvalues!($($args,)*) @@ -114,23 +114,23 @@ impl<'a> BinaryWrapper<'a> { } } -impl<'a> TryFrom> for StringWrapper<'a> { +impl<'a> TryFrom> for StringWrapper<'a> { type Error = Error; - fn try_from(value: JValue<'a>) -> Result, Error> { + fn try_from(value: JValueOwned<'a>) -> Result, Error> { match value { - JValue::Object(b) => Ok(StringWrapper::new(JString::from(b))), + JValueGen::Object(b) => Ok(StringWrapper::new(JString::from(b))), _ => Err(Error::WrongJValueType("object", value.type_name())), } } } -impl<'a> TryFrom> for BinaryWrapper<'a> { +impl<'a> TryFrom> for BinaryWrapper<'a> { type Error = Error; - fn try_from(value: JValue<'a>) -> Result, Error> { + fn try_from(value: JValueOwned<'a>) -> Result, Error> { match value { - JValue::Object(b) => Ok(BinaryWrapper::new(b)), + JValueGen::Object(b) => Ok(BinaryWrapper::new(b)), _ => Err(Error::WrongJValueType("object", value.type_name())), } } @@ -151,7 +151,7 @@ pub(crate) use jni_static_call; pub(crate) use jvalues; /// Gets a global reference to a Java class. 
-pub fn get_global_jclass(env: &JNIEnv<'_>, cls: &str) -> JniResult> { +pub fn get_global_jclass(env: &mut JNIEnv, cls: &str) -> JniResult> { let local_jclass = env.find_class(cls)?; let global = env.new_global_ref::(local_jclass.into())?; @@ -186,11 +186,11 @@ static JVM_CLASSES: OnceCell = OnceCell::new(); impl JVMClasses<'_> { /// Creates a new JVMClasses struct. - pub fn init(env: &JNIEnv) { + pub fn init(env: &mut JNIEnv) { JVM_CLASSES.get_or_init(|| { // A hack to make the `JNIEnv` static. It is not safe but we don't really use the // `JNIEnv` except for creating the global references of the classes. - let env = unsafe { std::mem::transmute::<_, &'static JNIEnv>(env) }; + let env = unsafe { std::mem::transmute::<_, &'static mut JNIEnv>(env) }; JVMClasses { comet_metric_node: CometMetricNode::new(env).unwrap(), diff --git a/core/src/lib.rs b/core/src/lib.rs index c85263f4f..d10478885 100644 --- a/core/src/lib.rs +++ b/core/src/lib.rs @@ -45,7 +45,7 @@ use once_cell::sync::OnceCell; pub use data_type::*; -use crate::errors::{try_unwrap_or_throw, CometError, CometResult}; +use errors::{try_unwrap_or_throw, CometError, CometResult}; #[macro_use] mod errors; @@ -64,15 +64,15 @@ static JAVA_VM: OnceCell = OnceCell::new(); #[no_mangle] pub extern "system" fn Java_org_apache_comet_NativeBase_init( - env: JNIEnv, + e: JNIEnv, _: JClass, log_conf_path: JString, ) { // Initialize the error handling to capture panic backtraces errors::init(); - try_unwrap_or_throw(env, || { - let path: String = env.get_string(log_conf_path)?.into(); + try_unwrap_or_throw(&e, |mut env| { + let path: String = env.get_string(&log_conf_path)?.into(); // empty path means there is no custom log4rs config file provided, so fallback to use // the default configuration diff --git a/core/src/parquet/mod.rs b/core/src/parquet/mod.rs index b1a7b939c..4f87d15de 100644 --- a/core/src/parquet/mod.rs +++ b/core/src/parquet/mod.rs @@ -41,7 +41,7 @@ use jni::{ use 
crate::execution::utils::SparkArrowConvert; use arrow::buffer::{Buffer, MutableBuffer}; -use jni::objects::ReleaseMode; +use jni::objects::{JBooleanArray, JLongArray, JPrimitiveArray, ReleaseMode}; use read::ColumnReader; use util::jni::{convert_column_descriptor, convert_encoding}; @@ -58,7 +58,7 @@ struct Context { #[no_mangle] pub extern "system" fn Java_org_apache_comet_parquet_Native_initColumnReader( - env: JNIEnv, + e: JNIEnv, _jclass: JClass, primitive_type: jint, logical_type: jint, @@ -78,9 +78,9 @@ pub extern "system" fn Java_org_apache_comet_parquet_Native_initColumnReader( use_decimal_128: jboolean, use_legacy_date_timestamp: jboolean, ) -> jlong { - try_unwrap_or_throw(env, || { + try_unwrap_or_throw(&e, |mut env| { let desc = convert_column_descriptor( - &env, + &mut env, primitive_type, logical_type, max_dl, @@ -111,66 +111,74 @@ pub extern "system" fn Java_org_apache_comet_parquet_Native_initColumnReader( }) } +/// # Safety +/// This function is inheritly unsafe since it deals with raw pointers passed from JNI. 
#[no_mangle] -pub extern "system" fn Java_org_apache_comet_parquet_Native_setDictionaryPage( - env: JNIEnv, +pub unsafe extern "system" fn Java_org_apache_comet_parquet_Native_setDictionaryPage( + e: JNIEnv, _jclass: JClass, handle: jlong, page_value_count: jint, page_data: jbyteArray, encoding: jint, ) { - try_unwrap_or_throw(env, || { + try_unwrap_or_throw(&e, |env| { let reader = get_reader(handle)?; // convert value encoding ordinal to the native encoding definition let encoding = convert_encoding(encoding); // copy the input on-heap buffer to native - let page_len = env.get_array_length(page_data)?; + let page_data_array = unsafe { JPrimitiveArray::from_raw(page_data) }; + let page_len = env.get_array_length(&page_data_array)?; let mut buffer = MutableBuffer::from_len_zeroed(page_len as usize); - env.get_byte_array_region(page_data, 0, from_u8_slice(buffer.as_slice_mut()))?; + env.get_byte_array_region(&page_data_array, 0, from_u8_slice(buffer.as_slice_mut()))?; reader.set_dictionary_page(page_value_count as usize, buffer.into(), encoding); Ok(()) }) } +/// # Safety +/// This function is inheritly unsafe since it deals with raw pointers passed from JNI. 
#[no_mangle] -pub extern "system" fn Java_org_apache_comet_parquet_Native_setPageV1( - env: JNIEnv, +pub unsafe extern "system" fn Java_org_apache_comet_parquet_Native_setPageV1( + e: JNIEnv, _jclass: JClass, handle: jlong, page_value_count: jint, page_data: jbyteArray, value_encoding: jint, ) { - try_unwrap_or_throw(env, || { + try_unwrap_or_throw(&e, |env| { let reader = get_reader(handle)?; // convert value encoding ordinal to the native encoding definition let encoding = convert_encoding(value_encoding); // copy the input on-heap buffer to native - let page_len = env.get_array_length(page_data)?; + let page_data_array = unsafe { JPrimitiveArray::from_raw(page_data) }; + let page_len = env.get_array_length(&page_data_array)?; let mut buffer = MutableBuffer::from_len_zeroed(page_len as usize); - env.get_byte_array_region(page_data, 0, from_u8_slice(buffer.as_slice_mut()))?; + env.get_byte_array_region(&page_data_array, 0, from_u8_slice(buffer.as_slice_mut()))?; reader.set_page_v1(page_value_count as usize, buffer.into(), encoding); Ok(()) }) } +/// # Safety +/// This function is inheritly unsafe since it deals with raw pointers passed from JNI. #[no_mangle] -pub extern "system" fn Java_org_apache_comet_parquet_Native_setPageBufferV1( - env: JNIEnv, +pub unsafe extern "system" fn Java_org_apache_comet_parquet_Native_setPageBufferV1( + e: JNIEnv, _jclass: JClass, handle: jlong, page_value_count: jint, buffer: jobject, value_encoding: jint, ) { - try_unwrap_or_throw(env, || { + try_unwrap_or_throw(&e, |env| { let ctx = get_context(handle)?; let reader = &mut ctx.column_reader; @@ -178,19 +186,20 @@ pub extern "system" fn Java_org_apache_comet_parquet_Native_setPageBufferV1( let encoding = convert_encoding(value_encoding); // Get slices from Java DirectByteBuffer - let jbuffer = JByteBuffer::from(buffer); + let jbuffer = unsafe { JByteBuffer::from_raw(buffer) }; // Convert the page to global reference so it won't get GC'd by Java. 
Also free the last // page if there is any. - ctx.last_data_page = Some(env.new_global_ref(jbuffer)?); + ctx.last_data_page = Some(env.new_global_ref(&jbuffer)?); - let buf_slice = env.get_direct_buffer_address(jbuffer)?; + let buf_slice = env.get_direct_buffer_address(&jbuffer)?; + let buf_capacity = env.get_direct_buffer_capacity(&jbuffer)?; unsafe { - let page_ptr = NonNull::new_unchecked(buf_slice.as_ptr() as *mut u8); + let page_ptr = NonNull::new_unchecked(buf_slice); let buffer = Buffer::from_custom_allocation( page_ptr, - buf_slice.len(), + buf_capacity, Arc::new(FFI_ArrowArray::empty()), ); reader.set_page_v1(page_value_count as usize, buffer, encoding); @@ -199,9 +208,11 @@ pub extern "system" fn Java_org_apache_comet_parquet_Native_setPageBufferV1( }) } +/// # Safety +/// This function is inherently unsafe since it deals with raw pointers passed from JNI. #[no_mangle] -pub extern "system" fn Java_org_apache_comet_parquet_Native_setPageV2( - env: JNIEnv, +pub unsafe extern "system" fn Java_org_apache_comet_parquet_Native_setPageV2( + e: JNIEnv, _jclass: JClass, handle: jlong, page_value_count: jint, @@ -210,24 +221,27 @@ pub extern "system" fn Java_org_apache_comet_parquet_Native_setPageV2( value_data: jbyteArray, value_encoding: jint, ) { - try_unwrap_or_throw(env, || { + try_unwrap_or_throw(&e, |env| { let reader = get_reader(handle)?; // convert value encoding ordinal to the native encoding definition let encoding = convert_encoding(value_encoding); // copy the input on-heap buffer to native - let dl_len = env.get_array_length(def_level_data)?; + let def_level_array = unsafe { JPrimitiveArray::from_raw(def_level_data) }; + let dl_len = env.get_array_length(&def_level_array)?; let mut dl_buffer = MutableBuffer::from_len_zeroed(dl_len as usize); - env.get_byte_array_region(def_level_data, 0, from_u8_slice(dl_buffer.as_slice_mut()))?; + env.get_byte_array_region(&def_level_array, 0, from_u8_slice(dl_buffer.as_slice_mut()))?; - let rl_len =
env.get_array_length(rep_level_data)?; + let rep_level_array = unsafe { JPrimitiveArray::from_raw(rep_level_data) }; + let rl_len = env.get_array_length(&rep_level_array)?; let mut rl_buffer = MutableBuffer::from_len_zeroed(rl_len as usize); - env.get_byte_array_region(rep_level_data, 0, from_u8_slice(rl_buffer.as_slice_mut()))?; + env.get_byte_array_region(&rep_level_array, 0, from_u8_slice(rl_buffer.as_slice_mut()))?; - let v_len = env.get_array_length(value_data)?; + let value_array = unsafe { JPrimitiveArray::from_raw(value_data) }; + let v_len = env.get_array_length(&value_array)?; let mut v_buffer = MutableBuffer::from_len_zeroed(v_len as usize); - env.get_byte_array_region(value_data, 0, from_u8_slice(v_buffer.as_slice_mut()))?; + env.get_byte_array_region(&value_array, 0, from_u8_slice(v_buffer.as_slice_mut()))?; reader.set_page_v2( page_value_count as usize, @@ -246,7 +260,7 @@ pub extern "system" fn Java_org_apache_comet_parquet_Native_setNull( _jclass: JClass, handle: jlong, ) { - try_unwrap_or_throw(env, || { + try_unwrap_or_throw(&env, |_| { let reader = get_reader(handle)?; reader.set_null(); Ok(()) @@ -260,7 +274,7 @@ pub extern "system" fn Java_org_apache_comet_parquet_Native_setBoolean( handle: jlong, value: jboolean, ) { - try_unwrap_or_throw(env, || { + try_unwrap_or_throw(&env, |_| { let reader = get_reader(handle)?; reader.set_boolean(value != 0); Ok(()) @@ -274,7 +288,7 @@ pub extern "system" fn Java_org_apache_comet_parquet_Native_setByte( handle: jlong, value: jbyte, ) { - try_unwrap_or_throw(env, || { + try_unwrap_or_throw(&env, |_| { let reader = get_reader(handle)?; reader.set_fixed::(value); Ok(()) @@ -288,7 +302,7 @@ pub extern "system" fn Java_org_apache_comet_parquet_Native_setShort( handle: jlong, value: jshort, ) { - try_unwrap_or_throw(env, || { + try_unwrap_or_throw(&env, |_| { let reader = get_reader(handle)?; reader.set_fixed::(value); Ok(()) @@ -302,7 +316,7 @@ pub extern "system" fn Java_org_apache_comet_parquet_Native_setInt( 
handle: jlong, value: jint, ) { - try_unwrap_or_throw(env, || { + try_unwrap_or_throw(&env, |_| { let reader = get_reader(handle)?; reader.set_fixed::(value); Ok(()) @@ -316,7 +330,7 @@ pub extern "system" fn Java_org_apache_comet_parquet_Native_setLong( handle: jlong, value: jlong, ) { - try_unwrap_or_throw(env, || { + try_unwrap_or_throw(&env, |_| { let reader = get_reader(handle)?; reader.set_fixed::(value); Ok(()) @@ -330,7 +344,7 @@ pub extern "system" fn Java_org_apache_comet_parquet_Native_setFloat( handle: jlong, value: jfloat, ) { - try_unwrap_or_throw(env, || { + try_unwrap_or_throw(&env, |_| { let reader = get_reader(handle)?; reader.set_fixed::(value); Ok(()) @@ -344,44 +358,50 @@ pub extern "system" fn Java_org_apache_comet_parquet_Native_setDouble( handle: jlong, value: jdouble, ) { - try_unwrap_or_throw(env, || { + try_unwrap_or_throw(&env, |_| { let reader = get_reader(handle)?; reader.set_fixed::(value); Ok(()) }) } +/// # Safety +/// This function is inherently unsafe since it deals with raw pointers passed from JNI. #[no_mangle] -pub extern "system" fn Java_org_apache_comet_parquet_Native_setBinary( - env: JNIEnv, +pub unsafe extern "system" fn Java_org_apache_comet_parquet_Native_setBinary( + e: JNIEnv, _jclass: JClass, handle: jlong, value: jbyteArray, ) { - try_unwrap_or_throw(env, || { + try_unwrap_or_throw(&e, |env| { let reader = get_reader(handle)?; - let len = env.get_array_length(value)?; + let value_array = unsafe { JPrimitiveArray::from_raw(value) }; + let len = env.get_array_length(&value_array)?; let mut buffer = MutableBuffer::from_len_zeroed(len as usize); - env.get_byte_array_region(value, 0, from_u8_slice(buffer.as_slice_mut()))?; + env.get_byte_array_region(&value_array, 0, from_u8_slice(buffer.as_slice_mut()))?; reader.set_binary(buffer); Ok(()) }) } +/// # Safety +/// This function is inherently unsafe since it deals with raw pointers passed from JNI.
#[no_mangle] -pub extern "system" fn Java_org_apache_comet_parquet_Native_setDecimal( - env: JNIEnv, +pub unsafe extern "system" fn Java_org_apache_comet_parquet_Native_setDecimal( + e: JNIEnv, _jclass: JClass, handle: jlong, value: jbyteArray, ) { - try_unwrap_or_throw(env, || { + try_unwrap_or_throw(&e, |env| { let reader = get_reader(handle)?; - let len = env.get_array_length(value)?; + let value_array = unsafe { JPrimitiveArray::from_raw(value) }; + let len = env.get_array_length(&value_array)?; let mut buffer = MutableBuffer::from_len_zeroed(len as usize); - env.get_byte_array_region(value, 0, from_u8_slice(buffer.as_slice_mut()))?; + env.get_byte_array_region(&value_array, 0, from_u8_slice(buffer.as_slice_mut()))?; reader.set_decimal_flba(buffer); Ok(()) }) @@ -395,26 +415,29 @@ pub extern "system" fn Java_org_apache_comet_parquet_Native_setPosition( value: jlong, size: jint, ) { - try_unwrap_or_throw(env, || { + try_unwrap_or_throw(&env, |_| { let reader = get_reader(handle)?; reader.set_position(value, size as usize); Ok(()) }) } +/// # Safety +/// This function is inherently unsafe since it deals with raw pointers passed from JNI. #[no_mangle] -pub extern "system" fn Java_org_apache_comet_parquet_Native_setIndices( - env: JNIEnv, +pub unsafe extern "system" fn Java_org_apache_comet_parquet_Native_setIndices( + e: JNIEnv, _jclass: JClass, handle: jlong, offset: jlong, batch_size: jint, indices: jlongArray, ) -> jlong { - try_unwrap_or_throw(env, || { + try_unwrap_or_throw(&e, |mut env| { let reader = get_reader(handle)?; - let indices = env.get_long_array_elements(indices, ReleaseMode::NoCopyBack)?; - let len = indices.size()? as usize; + let indice_array = unsafe { JLongArray::from_raw(indices) }; + let indices = unsafe { env.get_array_elements(&indice_array, ReleaseMode::NoCopyBack)?
}; + let len = indices.len(); // paris alternately contains start index and length of continuous indices let pairs = unsafe { core::slice::from_raw_parts_mut(indices.as_ptr(), len) }; let mut skipped = 0; @@ -437,19 +460,22 @@ pub extern "system" fn Java_org_apache_comet_parquet_Native_setIndices( }) } +/// # Safety +/// This function is inherently unsafe since it deals with raw pointers passed from JNI. #[no_mangle] -pub extern "system" fn Java_org_apache_comet_parquet_Native_setIsDeleted( - env: JNIEnv, +pub unsafe extern "system" fn Java_org_apache_comet_parquet_Native_setIsDeleted( + e: JNIEnv, _jclass: JClass, handle: jlong, is_deleted: jbooleanArray, ) { - try_unwrap_or_throw(env, || { + try_unwrap_or_throw(&e, |env| { let reader = get_reader(handle)?; - let len = env.get_array_length(is_deleted)?; + let is_deleted_array = unsafe { JBooleanArray::from_raw(is_deleted) }; + let len = env.get_array_length(&is_deleted_array)?; let mut buffer = MutableBuffer::from_len_zeroed(len as usize); - env.get_boolean_array_region(is_deleted, 0, buffer.as_slice_mut())?; + env.get_boolean_array_region(&is_deleted_array, 0, buffer.as_slice_mut())?; reader.set_is_deleted(buffer); Ok(()) }) @@ -461,7 +487,7 @@ pub extern "system" fn Java_org_apache_comet_parquet_Native_resetBatch( _jclass: JClass, handle: jlong, ) { - try_unwrap_or_throw(env, || { + try_unwrap_or_throw(&env, |_| { let reader = get_reader(handle)?; reader.reset_batch(); Ok(()) @@ -470,20 +496,20 @@ pub extern "system" fn Java_org_apache_comet_parquet_Native_resetBatch( #[no_mangle] pub extern "system" fn Java_org_apache_comet_parquet_Native_readBatch( - env: JNIEnv, + e: JNIEnv, _jclass: JClass, handle: jlong, batch_size: jint, null_pad_size: jint, ) -> jintArray { - try_unwrap_or_throw(env, || { + try_unwrap_or_throw(&e, |env| { let reader = get_reader(handle)?; let (num_values, num_nulls) = reader.read_batch(batch_size as usize, null_pad_size as usize); let res = env.new_int_array(2)?; let buf: [i32; 2] =
[num_values as i32, num_nulls as i32]; - env.set_int_array_region(res, 0, &buf)?; - Ok(res) + env.set_int_array_region(&res, 0, &buf)?; + Ok(res.into_raw()) }) } @@ -495,7 +521,7 @@ pub extern "system" fn Java_org_apache_comet_parquet_Native_skipBatch( batch_size: jint, discard: jboolean, ) -> jint { - try_unwrap_or_throw(env, || { + try_unwrap_or_throw(&env, |_| { let reader = get_reader(handle)?; Ok(reader.skip_batch(batch_size as usize, discard == 0) as jint) }) @@ -503,11 +529,11 @@ pub extern "system" fn Java_org_apache_comet_parquet_Native_skipBatch( #[no_mangle] pub extern "system" fn Java_org_apache_comet_parquet_Native_currentBatch( - env: JNIEnv, + e: JNIEnv, _jclass: JClass, handle: jlong, ) -> jlongArray { - try_unwrap_or_throw(env, || { + try_unwrap_or_throw(&e, |env| { let ctx = get_context(handle)?; let reader = &mut ctx.column_reader; let data = reader.current_batch(); @@ -520,9 +546,9 @@ pub extern "system" fn Java_org_apache_comet_parquet_Native_currentBatch( let res = env.new_long_array(2)?; let buf: [i64; 2] = [array, schema]; - env.set_long_array_region(res, 0, &buf) + env.set_long_array_region(&res, 0, &buf) .expect("set long array region failed"); - Ok(res) + Ok(res.into_raw()) } }) } @@ -547,7 +573,7 @@ pub extern "system" fn Java_org_apache_comet_parquet_Native_closeColumnReader( _jclass: JClass, handle: jlong, ) { - try_unwrap_or_throw(env, || { + try_unwrap_or_throw(&env, |_| { unsafe { let ctx = handle as *mut Context; let _ = Box::from_raw(ctx); diff --git a/core/src/parquet/util/jni.rs b/core/src/parquet/util/jni.rs index 000eeee0b..225abfc03 100644 --- a/core/src/parquet/util/jni.rs +++ b/core/src/parquet/util/jni.rs @@ -19,8 +19,8 @@ use std::sync::Arc; use jni::{ errors::Result as JNIResult, - objects::{JMethodID, JString}, - sys::{jboolean, jint, jobjectArray, jstring}, + objects::{JObjectArray, JString}, + sys::{jboolean, jint, jobjectArray}, JNIEnv, }; @@ -33,7 +33,7 @@ use parquet::{ /// Convert primitives from Spark side into a 
`ColumnDescriptor`. #[allow(clippy::too_many_arguments)] pub fn convert_column_descriptor( - env: &JNIEnv, + env: &mut JNIEnv, physical_type_id: jint, logical_type_id: jint, max_dl: jint, @@ -114,12 +114,13 @@ impl TypePromotionInfo { } } -fn convert_column_path(env: &JNIEnv, path: jobjectArray) -> JNIResult { - let array_len = env.get_array_length(path)?; +fn convert_column_path(env: &mut JNIEnv, path: jobjectArray) -> JNIResult { + let path_array = unsafe { JObjectArray::from_raw(path) }; + let array_len = env.get_array_length(&path_array)?; let mut res: Vec = Vec::new(); for i in 0..array_len { - let p: JString = (env.get_object_array_element(path, i)?.into_inner() as jstring).into(); - res.push(env.get_string(p)?.into()); + let p: JString = env.get_object_array_element(&path_array, i)?.into(); + res.push(env.get_string(&p)?.into()); } Ok(ColumnPath::new(res)) } @@ -184,16 +185,3 @@ fn fix_type_length(t: &PhysicalType, type_length: i32) -> i32 { _ => type_length, } } - -fn get_method_id<'a>(env: &'a JNIEnv, class: &'a str, method: &str, sig: &str) -> JMethodID<'a> { - // first verify the class exists - let _ = env - .find_class(class) - .unwrap_or_else(|_| panic!("Class '{}' not found", class)); - env.get_method_id(class, method, sig).unwrap_or_else(|_| { - panic!( - "Method '{}' with signature '{}' of class '{}' not found", - method, sig, class - ) - }) -} diff --git a/core/src/parquet/util/mod.rs b/core/src/parquet/util/mod.rs index 6a8c731d4..7a37b786d 100644 --- a/core/src/parquet/util/mod.rs +++ b/core/src/parquet/util/mod.rs @@ -22,7 +22,5 @@ pub mod memory; mod buffer; pub use buffer::*; -mod jni_buffer; -pub use jni_buffer::*; pub mod test_common;