From 6f045d07664d70a7d7d44aee2e36d7aba21afb67 Mon Sep 17 00:00:00 2001 From: KyJah Keys Date: Wed, 7 Feb 2024 18:31:25 -0500 Subject: [PATCH] replaced all do_with_* methods with macros --- src/DeltaLake/Bridge/Interop/Interop.cs | 12 +- src/DeltaLake/Bridge/Runtime.cs | 2 + .../Bridge/include/delta-lake-bridge.h | 13 +- src/DeltaLake/Bridge/src/table.rs | 481 +++++++++--------- 4 files changed, 254 insertions(+), 254 deletions(-) diff --git a/src/DeltaLake/Bridge/Interop/Interop.cs b/src/DeltaLake/Bridge/Interop/Interop.cs index ca6bf95..9dd2938 100644 --- a/src/DeltaLake/Bridge/Interop/Interop.cs +++ b/src/DeltaLake/Bridge/Interop/Interop.cs @@ -366,13 +366,13 @@ internal static unsafe partial class Methods public static extern ByteArray* table_uri([NativeTypeName("const struct RawDeltaTable *")] RawDeltaTable* table); [DllImport("delta_rs_bridge", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true)] - public static extern void table_free([NativeTypeName("struct RawDeltaTable *")] RawDeltaTable* table); + public static extern void table_free([NativeTypeName("struct RawDeltaTable * _Nonnull")] RawDeltaTable* table); [DllImport("delta_rs_bridge", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true)] - public static extern void create_deltalake([NativeTypeName("struct Runtime * _Nonnull")] Runtime* runtime, [NativeTypeName("struct TableCreatOptions * _Nonnull")] TableCreatOptions* options, [NativeTypeName("TableNewCallback")] IntPtr callback); + public static extern void create_deltalake([NativeTypeName("struct Runtime * _Nonnull")] Runtime* runtime, [NativeTypeName("struct TableCreatOptions * _Nonnull")] TableCreatOptions* options, [NativeTypeName("const struct CancellationToken *")] CancellationToken* cancellation_token, [NativeTypeName("TableNewCallback")] IntPtr callback); [DllImport("delta_rs_bridge", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true)] - public static extern void table_new([NativeTypeName("struct Runtime * _Nonnull")] Runtime* runtime, [NativeTypeName("struct ByteArrayRef * _Nonnull")] ByteArrayRef* table_uri, [NativeTypeName("struct TableOptions * _Nonnull")] TableOptions* table_options, [NativeTypeName("TableNewCallback")] IntPtr callback); + public static extern void table_new([NativeTypeName("struct Runtime * _Nonnull")] Runtime* runtime, [NativeTypeName("struct ByteArrayRef * _Nonnull")] ByteArrayRef* table_uri, [NativeTypeName("struct TableOptions * _Nonnull")] TableOptions* table_options, [NativeTypeName("const struct CancellationToken *")] CancellationToken* cancellation_token, [NativeTypeName("TableNewCallback")] IntPtr callback); [DllImport("delta_rs_bridge", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true)] [return: NativeTypeName("struct GenericOrError")] @@ -396,7 +396,7 @@ internal static unsafe partial class Methods public static extern byte table_load_with_datetime([NativeTypeName("struct Runtime * _Nonnull")] Runtime* runtime, [NativeTypeName("struct RawDeltaTable * _Nonnull")] RawDeltaTable* table, [NativeTypeName("int64_t")] long ts_milliseconds, [NativeTypeName("const struct CancellationToken *")] CancellationToken* cancellation_token, [NativeTypeName("TableEmptyCallback")] IntPtr callback); [DllImport("delta_rs_bridge", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true)] - public static extern void table_merge([NativeTypeName("struct Runtime * _Nonnull")] Runtime* runtime, [NativeTypeName("struct RawDeltaTable * _Nonnull")] RawDeltaTable* delta_table, [NativeTypeName("const struct ByteArrayRef *")] ByteArrayRef* query, void* stream, [NativeTypeName("const struct CancellationToken *")] CancellationToken* cancellation_token, [NativeTypeName("GenericErrorCallback")] IntPtr callback); + public static extern void table_merge([NativeTypeName("struct Runtime * _Nonnull")] Runtime* runtime, [NativeTypeName("struct RawDeltaTable * _Nonnull")] RawDeltaTable* delta_table, [NativeTypeName("struct ByteArrayRef * _Nonnull")] ByteArrayRef* query, void* stream, [NativeTypeName("const struct CancellationToken *")] CancellationToken* cancellation_token, [NativeTypeName("GenericErrorCallback")] IntPtr callback); [DllImport("delta_rs_bridge", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true)] [return: NativeTypeName("struct ProtocolResponse")] @@ -412,7 +412,7 @@ internal static unsafe partial class Methods public static extern void table_delete([NativeTypeName("struct Runtime * _Nonnull")] Runtime* runtime, [NativeTypeName("struct RawDeltaTable * _Nonnull")] RawDeltaTable* table, [NativeTypeName("const struct ByteArrayRef *")] ByteArrayRef* predicate, [NativeTypeName("const struct CancellationToken *")] CancellationToken* cancellation_token, [NativeTypeName("GenericErrorCallback")] IntPtr callback); [DllImport("delta_rs_bridge", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true)] - public static extern void table_query([NativeTypeName("struct Runtime * _Nonnull")] Runtime* runtime, [NativeTypeName("struct RawDeltaTable * _Nonnull")] RawDeltaTable* table, [NativeTypeName("const struct ByteArrayRef *")] ByteArrayRef* query, [NativeTypeName("const struct ByteArrayRef *")] ByteArrayRef* table_name, [NativeTypeName("const struct CancellationToken *")] CancellationToken* cancellation_token, [NativeTypeName("GenericErrorCallback")] IntPtr callback); + public static extern void table_query([NativeTypeName("struct Runtime * _Nonnull")] Runtime* runtime, [NativeTypeName("struct RawDeltaTable * _Nonnull")] RawDeltaTable* table, [NativeTypeName("struct ByteArrayRef * _Nonnull")] ByteArrayRef* query, [NativeTypeName("struct ByteArrayRef * _Nonnull")] ByteArrayRef* table_name, [NativeTypeName("const struct CancellationToken *")] CancellationToken* cancellation_token, [NativeTypeName("GenericErrorCallback")] IntPtr callback); [DllImport("delta_rs_bridge", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true)] public static extern void table_insert([NativeTypeName("struct Runtime * _Nonnull")] Runtime* runtime, [NativeTypeName("struct RawDeltaTable * _Nonnull")] RawDeltaTable* table, [NativeTypeName("void * _Nonnull")] void* stream, [NativeTypeName("const struct ByteArrayRef *")] ByteArrayRef* predicate, [NativeTypeName("const struct ByteArrayRef * _Nonnull")] ByteArrayRef* mode, [NativeTypeName("uintptr_t")] UIntPtr max_rows_per_group, [NativeTypeName("bool")] byte overwrite_schema, [NativeTypeName("const struct CancellationToken *")] CancellationToken* cancellation_token, [NativeTypeName("GenericErrorCallback")] IntPtr callback); @@ -422,7 +422,7 @@ internal static unsafe partial class Methods public static extern GenericOrError table_schema([NativeTypeName("struct Runtime * _Nonnull")] Runtime* runtime, [NativeTypeName("struct RawDeltaTable * _Nonnull")] RawDeltaTable* table); [DllImport("delta_rs_bridge", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true)] - public static extern void table_checkpoint([NativeTypeName("struct Runtime * _Nonnull")] Runtime* runtime, [NativeTypeName("struct RawDeltaTable * _Nonnull")] RawDeltaTable* table, [NativeTypeName("TableEmptyCallback")] IntPtr callback); + public static extern void table_checkpoint([NativeTypeName("struct Runtime * _Nonnull")] Runtime* runtime, [NativeTypeName("struct RawDeltaTable * _Nonnull")] RawDeltaTable* table, [NativeTypeName("const struct CancellationToken *")] CancellationToken* cancellation_token, [NativeTypeName("TableEmptyCallback")] IntPtr callback); [DllImport("delta_rs_bridge", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true)] public static extern void table_vacuum([NativeTypeName("struct Runtime * _Nonnull")] Runtime* runtime, [NativeTypeName("struct RawDeltaTable * _Nonnull")] RawDeltaTable* table, [NativeTypeName("const struct VacuumOptions *")] VacuumOptions* options, [NativeTypeName("GenericErrorCallback")] IntPtr callback); diff --git a/src/DeltaLake/Bridge/Runtime.cs b/src/DeltaLake/Bridge/Runtime.cs index b846400..8972f2c 100644 --- a/src/DeltaLake/Bridge/Runtime.cs +++ b/src/DeltaLake/Bridge/Runtime.cs @@ -82,6 +82,7 @@ internal async Task LoadTableAsync( Ptr, scope.Pointer(scope.ByteArray(tableUri)), scope.Pointer(nativeOptions), + scope.CancellationToken(cancellationToken), scope.FunctionPointer((success, fail) => { if (cancellationToken.IsCancellationRequested) @@ -133,6 +134,7 @@ internal async Task
CreateTableAsync(DeltaLake.Table.TableCreateOptions o Interop.Methods.create_deltalake( Ptr, scope.Pointer(nativeOptions), + scope.CancellationToken(cancellationToken), scope.FunctionPointer((success, fail) => { if (cancellationToken.IsCancellationRequested) diff --git a/src/DeltaLake/Bridge/include/delta-lake-bridge.h b/src/DeltaLake/Bridge/include/delta-lake-bridge.h index acf5083..9924c0c 100644 --- a/src/DeltaLake/Bridge/include/delta-lake-bridge.h +++ b/src/DeltaLake/Bridge/include/delta-lake-bridge.h @@ -273,15 +273,17 @@ void partition_filter_list_free(struct PartitionFilterList *list); struct ByteArray *table_uri(const struct RawDeltaTable *table); -void table_free(struct RawDeltaTable *table); +void table_free(struct RawDeltaTable *_Nonnull table); void create_deltalake(struct Runtime *_Nonnull runtime, struct TableCreatOptions *_Nonnull options, + const struct CancellationToken *cancellation_token, TableNewCallback callback); void table_new(struct Runtime *_Nonnull runtime, struct ByteArrayRef *_Nonnull table_uri, struct TableOptions *_Nonnull table_options, + const struct CancellationToken *cancellation_token, TableNewCallback callback); struct GenericOrError table_file_uris(struct Runtime *_Nonnull runtime, @@ -317,8 +319,8 @@ bool table_load_with_datetime(struct Runtime *_Nonnull runtime, void table_merge(struct Runtime *_Nonnull runtime, struct RawDeltaTable *_Nonnull delta_table, - const struct ByteArrayRef *query, - void *stream, + struct ByteArrayRef *_Nonnull query, + void *_Nonnull stream, const struct CancellationToken *cancellation_token, GenericErrorCallback callback); @@ -349,8 +351,8 @@ void table_delete(struct Runtime *_Nonnull runtime, void table_query(struct Runtime *_Nonnull runtime, struct RawDeltaTable *_Nonnull table, - const struct ByteArrayRef *query, - const struct ByteArrayRef *table_name, + struct ByteArrayRef *_Nonnull query, + struct ByteArrayRef *_Nonnull table_name, const struct CancellationToken *cancellation_token, GenericErrorCallback callback); @@ -372,6 +374,7 @@ struct GenericOrError table_schema(struct Runtime *_Nonnull runtime, void table_checkpoint(struct Runtime *_Nonnull runtime, struct RawDeltaTable *_Nonnull table, + const struct CancellationToken *cancellation_token, TableEmptyCallback callback); void table_vacuum(struct Runtime *_Nonnull runtime, diff --git a/src/DeltaLake/Bridge/src/table.rs b/src/DeltaLake/Bridge/src/table.rs index 077dbe2..06e9022 100644 --- a/src/DeltaLake/Bridge/src/table.rs +++ b/src/DeltaLake/Bridge/src/table.rs @@ -49,7 +49,22 @@ macro_rules! run_async_with_cancellation { ($runtime: expr, $table:expr, $cancellation_token: expr, $rt:ident, $tbl:ident, $work: block, $on_cancel: block ) => {{ let ($rt, $tbl) = unsafe { ($runtime.as_mut(), $table.as_mut()) }; let runtime_handle = $rt.handle(); - let cancel_token = unsafe { $cancellation_token.as_ref() }.map(|v| v.token.clone()); + let cancel_token = $cancellation_token.map(|v| v.token.clone()); + runtime_handle.spawn(async move { + if let Some(cancel_token) = cancel_token { + tokio::select! { + _ = cancel_token.cancelled() => unsafe {$on_cancel}, + _ = async $work => {}, + } + } else { + (async $work).await + } + }); + }}; + ($runtime: expr, $cancellation_token: expr, $rt:ident, $work: block, $on_cancel: block ) => {{ + let $rt = unsafe { $runtime.as_mut() }; + let runtime_handle = $rt.handle(); + let cancel_token = $cancellation_token.map(|v| v.token.clone()); runtime_handle.spawn(async move { if let Some(cancel_token) = cancel_token { tokio::select! { @@ -62,6 +77,7 @@ macro_rules! run_async_with_cancellation { }); }}; } + pub struct RawDeltaTable { table: deltalake::DeltaTable, } @@ -230,9 +246,9 @@ pub extern "C" fn table_uri(table: *const RawDeltaTable) -> *mut ByteArray { } #[no_mangle] -pub extern "C" fn table_free(table: *mut RawDeltaTable) { +pub extern "C" fn table_free(table: NonNull) { unsafe { - let _ = Box::from_raw(table); + let _ = Box::from_raw(table.as_ptr()); } } @@ -240,9 +256,10 @@ pub extern "C" fn table_free(table: *mut RawDeltaTable) { pub extern "C" fn create_deltalake( mut runtime: NonNull, options: NonNull, + cancellation_token: Option<&CancellationToken>, callback: TableNewCallback, ) { - let (runtime, options) = unsafe { (runtime.as_mut(), options.as_ref()) }; + let options = unsafe { options.as_ref() }; let table_uri = options.table_uri.to_owned_string(); let schema = unsafe { &*(options.schema as *mut arrow::ffi::FFI_ArrowSchema) }; @@ -252,7 +269,7 @@ pub extern "C" fn create_deltalake( callback( std::ptr::null_mut(), Box::into_raw(Box::new(DeltaTableError::new( - runtime, + runtime.as_mut(), DeltaTableErrorCode::Utf8, &err.to_string(), ))), @@ -285,7 +302,7 @@ pub extern "C" fn create_deltalake( callback( std::ptr::null_mut(), Box::into_raw(Box::new(DeltaTableError::new( - runtime, + runtime.as_mut(), DeltaTableErrorCode::Utf8, &err.to_string(), ))), @@ -294,30 +311,36 @@ pub extern "C" fn create_deltalake( } } }; - runtime.handle().spawn(async move { - match create_delta_table( - runtime, - table_uri, - schema, - partition_by, - save_mode, - name, - description, - configuration, - storage_options, - custom_metadata, - ) - .await + run_async_with_cancellation!( + runtime, + cancellation_token, + rt, { - Ok(table) => unsafe { - callback( - Box::into_raw(Box::new(RawDeltaTable::new(table))), - std::ptr::null(), - ); - }, - Err(err) => unsafe { callback(std::ptr::null_mut(), err.into_raw()) }, - } - }); + match create_delta_table( + rt, + table_uri, + schema, + partition_by, + save_mode, + name, + description, + configuration, + storage_options, + custom_metadata, + ) + .await + { + Ok(table) => unsafe { + callback( + Box::into_raw(Box::new(RawDeltaTable::new(table))), + std::ptr::null(), + ); + }, + Err(err) => unsafe { callback(std::ptr::null_mut(), err.into_raw()) }, + } + }, + { callback(std::ptr::null_mut(), std::ptr::null()) } + ) } #[no_mangle] @@ -325,9 +348,10 @@ pub extern "C" fn table_new( mut runtime: NonNull, table_uri: NonNull, table_options: NonNull, + cancellation_token: Option<&CancellationToken>, callback: TableNewCallback, ) { - let (runtime, options) = unsafe { (runtime.as_mut(), table_options.as_ref()) }; + let options = unsafe { table_options.as_ref() }; let table_uri = unsafe { let uri = table_uri.as_ref(); match std::str::from_utf8(uri.to_slice()) { @@ -336,7 +360,7 @@ pub extern "C" fn table_new( callback( std::ptr::null_mut(), Box::into_raw(Box::new(DeltaTableError::new( - runtime, + runtime.as_mut(), DeltaTableErrorCode::Utf8, &err.to_string(), ))), @@ -369,23 +393,28 @@ pub extern "C" fn table_new( .unwrap(); } - let runtime_handle = runtime.handle(); - runtime_handle.spawn(async move { - match builder.load().await { - Ok(table) => unsafe { - callback( - Box::into_raw(Box::new(RawDeltaTable::new(table))), - std::ptr::null(), - ) - }, - Err(err) => unsafe { - callback( - std::ptr::null_mut(), - Box::into_raw(Box::new(DeltaTableError::from_error(runtime, err))), - ) - }, - } - }); + run_async_with_cancellation!( + runtime, + cancellation_token, + rt, + { + match builder.load().await { + Ok(table) => unsafe { + callback( + Box::into_raw(Box::new(RawDeltaTable::new(table))), + std::ptr::null(), + ) + }, + Err(err) => unsafe { + callback( + std::ptr::null_mut(), + Box::into_raw(Box::new(DeltaTableError::from_error(rt, err))), + ) + }, + } + }, + { callback(std::ptr::null_mut(), std::ptr::null()) } + ); } #[no_mangle] @@ -471,7 +500,7 @@ pub extern "C" fn history( mut runtime: NonNull, mut table: NonNull, limit: usize, - cancellation_token: *const CancellationToken, + cancellation_token: Option<&CancellationToken>, callback: GenericErrorCallback, ) { run_async_with_cancellation!( @@ -506,16 +535,18 @@ pub extern "C" fn history( #[no_mangle] pub extern "C" fn table_update_incremental( - runtime: NonNull, - table: NonNull, - cancellation_token: *const CancellationToken, + mut runtime: NonNull, + mut table: NonNull, + cancellation_token: Option<&CancellationToken>, callback: TableEmptyCallback, ) { - do_with_table_and_runtime_and_cancel( + run_async_with_cancellation!( runtime, table, cancellation_token, - move |rt, tbl| async move { + rt, + tbl, + { match tbl.table.update_incremental(None).await { Ok(_) => unsafe { callback(std::ptr::null()); @@ -526,23 +557,25 @@ pub extern "C" fn table_update_incremental( }, }; }, - move || unsafe { callback(std::ptr::null()) }, + { callback(std::ptr::null()) } ); } #[no_mangle] pub extern "C" fn table_load_version( - runtime: NonNull, - table: NonNull, + mut runtime: NonNull, + mut table: NonNull, version: i64, - cancellation_token: *const CancellationToken, + cancellation_token: Option<&CancellationToken>, callback: TableEmptyCallback, ) { - do_with_table_and_runtime_and_cancel( + run_async_with_cancellation!( runtime, table, cancellation_token, - move |rt, tbl| async move { + rt, + tbl, + { match tbl.table.load_version(version).await { Ok(_) => unsafe { callback(std::ptr::null()) }, Err(err) => { @@ -551,16 +584,16 @@ pub extern "C" fn table_load_version( } }; }, - move || unsafe { callback(std::ptr::null()) }, + { callback(std::ptr::null()) } ) } #[no_mangle] pub extern "C" fn table_load_with_datetime( - runtime: NonNull, - table: NonNull, + mut runtime: NonNull, + mut table: NonNull, ts_milliseconds: i64, - cancellation_token: *const CancellationToken, + cancellation_token: Option<&CancellationToken>, callback: TableEmptyCallback, ) -> bool { let naive_dt = match NaiveDateTime::from_timestamp_millis(ts_milliseconds) { @@ -569,11 +602,13 @@ pub extern "C" fn table_load_with_datetime( }; let dt = DateTime::::from_naive_utc_and_offset(naive_dt, Utc); - do_with_table_and_runtime_and_cancel( + run_async_with_cancellation!( runtime, table, cancellation_token, - move |rt, tbl| async move { + rt, + tbl, + { match tbl.table.load_with_datetime(dt).await { Ok(_) => unsafe { callback(std::ptr::null()) }, Err(err) => { @@ -582,7 +617,7 @@ pub extern "C" fn table_load_with_datetime( } }; }, - move || unsafe { callback(std::ptr::null()) }, + { callback(std::ptr::null()) } ); true } @@ -590,14 +625,13 @@ pub extern "C" fn table_load_with_datetime( #[no_mangle] pub extern "C" fn table_merge( mut runtime: NonNull, - delta_table: NonNull, - query: *const ByteArrayRef, - stream: *mut c_void, - cancellation_token: *const CancellationToken, + mut delta_table: NonNull, + query: NonNull, + stream: NonNull, + cancellation_token: Option<&CancellationToken>, callback: GenericErrorCallback, ) { - let query = unsafe { &*query }; - let query_str = query.to_str(); + let query_str = unsafe { query.as_ref().to_str() }; let mut parser = match DeltaLakeParser::new(query_str) { Ok(data) => data, Err(err) => unsafe { @@ -614,7 +648,10 @@ pub extern "C" fn table_merge( }, }; let source_df = unsafe { - match ffi_to_df(runtime.as_mut(), stream as *mut FFI_ArrowArrayStream) { + match ffi_to_df( + runtime.as_mut(), + stream.cast::().as_ptr(), + ) { Ok(source_df) => source_df, Err(error) => { callback(std::ptr::null(), error.into_raw()); @@ -629,11 +666,13 @@ pub extern "C" fn table_merge( source, on, clauses, - }) => do_with_table_and_runtime_and_cancel( + }) => run_async_with_cancellation!( runtime, delta_table, cancellation_token, - move |rt, tbl| async move { + rt, + tbl, + { let snapshot = match tbl.table.snapshot() { Ok(snapshot) => snapshot.clone(), Err(err) => unsafe { @@ -726,7 +765,7 @@ pub extern "C" fn table_merge( }, }; }, - move || unsafe { callback(std::ptr::null(), std::ptr::null()) }, + { callback(std::ptr::null(), std::ptr::null()) } ), Err(err) => unsafe { callback( @@ -744,33 +783,35 @@ pub extern "C" fn table_merge( #[no_mangle] pub extern "C" fn table_protocol_versions( - runtime: NonNull, - table: NonNull, + mut runtime: NonNull, + mut table: NonNull, ) -> ProtocolResponse { - do_with_table_and_runtime_sync(runtime, table, |rt, tbl| match tbl.table.protocol() { - Ok(protocol) => ProtocolResponse { - min_reader_version: protocol.min_reader_version, - min_writer_version: protocol.min_writer_version, - error: std::ptr::null(), - }, - Err(err) => ProtocolResponse { - min_reader_version: 0, - min_writer_version: 0, - error: DeltaTableError::from_error(rt, err).into_raw(), - }, + run_sync!(runtime, table, rt, tbl, { + match tbl.table.protocol() { + Ok(protocol) => ProtocolResponse { + min_reader_version: protocol.min_reader_version, + min_writer_version: protocol.min_writer_version, + error: std::ptr::null(), + }, + Err(err) => ProtocolResponse { + min_reader_version: 0, + min_writer_version: 0, + error: DeltaTableError::from_error(rt, err).into_raw(), + }, + } }) } #[no_mangle] pub extern "C" fn table_restore( - runtime: NonNull, - table: NonNull, + mut runtime: NonNull, + mut table: NonNull, version_or_timestamp: i64, is_timestamp: bool, ignore_missing_files: bool, protocol_downgrade_allowed: bool, custom_metadata: *mut Map, - cancellation_token: *const CancellationToken, + cancellation_token: Option<&CancellationToken>, callback: TableEmptyCallback, ) { let json_metadata = if !custom_metadata.is_null() { @@ -784,11 +825,13 @@ pub extern "C" fn table_restore( } else { None }; - do_with_table_and_runtime_and_cancel( + run_async_with_cancellation!( runtime, table, cancellation_token, - move |rt, tbl| async move { + rt, + tbl, + { let snapshot = match tbl.table.snapshot() { Ok(snapshot) => snapshot.clone(), Err(err) => unsafe { @@ -843,16 +886,16 @@ pub extern "C" fn table_restore( } }; }, - move || unsafe { callback(std::ptr::null()) }, + { callback(std::ptr::null()) } ); } #[no_mangle] pub extern "C" fn table_update( mut runtime: NonNull, - table: NonNull, + mut table: NonNull, query: *const ByteArrayRef, - cancellation_token: *const CancellationToken, + cancellation_token: Option<&CancellationToken>, callback: GenericErrorCallback, ) { let query = { @@ -876,11 +919,13 @@ pub extern "C" fn table_update( return; }, }; - do_with_table_and_runtime_and_cancel( + run_async_with_cancellation!( runtime, table, cancellation_token, - move |rt, tbl| async move { + rt, + tbl, + { let snapshot = match tbl.table.snapshot() { Ok(snapshot) => snapshot.clone(), Err(err) => unsafe { @@ -922,29 +967,26 @@ pub extern "C" fn table_update( }, }; }, - move || unsafe { callback(std::ptr::null(), std::ptr::null()) }, + { callback(std::ptr::null(), std::ptr::null()) } ); } #[no_mangle] pub extern "C" fn table_delete( - runtime: NonNull, - table: NonNull, - predicate: *const ByteArrayRef, - cancellation_token: *const CancellationToken, + mut runtime: NonNull, + mut table: NonNull, + predicate: Option<&ByteArrayRef>, + cancellation_token: Option<&CancellationToken>, callback: GenericErrorCallback, ) { - let predicate = if !predicate.is_null() { - let predicate = unsafe { &*predicate }; - Some(predicate.to_owned_string()) - } else { - None - }; - do_with_table_and_runtime_and_cancel( + let predicate = predicate.map(|p| p.to_owned_string()); + run_async_with_cancellation!( runtime, table, cancellation_token, - move |rt, tbl| async move { + rt, + tbl, + { let snapshot = match tbl.table.snapshot() { Ok(snapshot) => snapshot.clone(), Err(err) => unsafe { @@ -978,38 +1020,30 @@ pub extern "C" fn table_delete( }, }; }, - move || unsafe { callback(std::ptr::null(), std::ptr::null()) }, + { callback(std::ptr::null(), std::ptr::null()) } ); } #[no_mangle] pub extern "C" fn table_query( - runtime: NonNull, - table: NonNull, - query: *const ByteArrayRef, - table_name: *const ByteArrayRef, - cancellation_token: *const CancellationToken, + mut runtime: NonNull, + mut table: NonNull, + query: NonNull, + table_name: NonNull, + cancellation_token: Option<&CancellationToken>, callback: GenericErrorCallback, ) { - let query = { - let query = unsafe { &*query }; - query.to_str() - }; - let table_name = if table_name.is_null() { - None - } else { - unsafe { (*table_name).to_option_str() } - }; - - do_with_table_and_runtime_and_cancel( + let (query, table_name) = unsafe { (query.as_ref().to_str(), table_name.as_ref().to_str()) }; + run_async_with_cancellation!( runtime, table, cancellation_token, - move |rt, tbl| async move { + rt, + tbl, + { let ctx = SessionContext::new(); let arc = Arc::new(tbl.table.clone()); - let name = table_name.unwrap_or("demo"); - if let Err(err) = ctx.register_table(name, arc) { + if let Err(err) = ctx.register_table(table_name, arc) { unsafe { callback( std::ptr::null(), @@ -1070,7 +1104,7 @@ pub extern "C" fn table_query( }, }; }, - move || unsafe { callback(std::ptr::null(), std::ptr::null()) }, + { callback(std::ptr::null(), std::ptr::null()) } ); } @@ -1083,7 +1117,7 @@ pub extern "C" fn table_insert( mode: &ByteArrayRef, max_rows_per_group: usize, overwrite_schema: bool, - cancellation_token: *const CancellationToken, + cancellation_token: Option<&CancellationToken>, callback: GenericErrorCallback, ) { let save_mode = unsafe { @@ -1168,13 +1202,11 @@ pub extern "C" fn table_insert( /// Must free the error #[no_mangle] pub extern "C" fn table_schema( - runtime: NonNull, - table: NonNull, + mut runtime: NonNull, + mut table: NonNull, ) -> GenericOrError { - do_with_table_and_runtime_sync( - runtime, - table, - move |rt, tbl| match crate::schema::get_schema(rt, &tbl.table) { + run_sync!(runtime, table, rt, tbl, { + match crate::schema::get_schema(rt, &tbl.table) { Ok(schema) => { let schema: arrow::ffi::FFI_ArrowSchema = match schema.try_into() { Ok(converted) => converted, @@ -1199,34 +1231,43 @@ pub extern "C" fn table_schema( bytes: std::ptr::null(), error: err.into_raw(), }, - }, - ) + } + }) } #[no_mangle] pub extern "C" fn table_checkpoint( - runtime: NonNull, - table: NonNull, + mut runtime: NonNull, + mut table: NonNull, + cancellation_token: Option<&CancellationToken>, callback: TableEmptyCallback, ) { - do_with_table_and_runtime(runtime, table, move |rt, tbl| async move { - match deltalake::checkpoints::create_checkpoint(&tbl.table).await { - Ok(_) => unsafe { - callback(std::ptr::null()); - }, - Err(err) => { - let error = - DeltaTableError::new(rt, DeltaTableErrorCode::Protocol, &err.to_string()); - unsafe { callback(error.into_raw()) } - } - }; - }) + run_async_with_cancellation!( + runtime, + table, + cancellation_token, + rt, + tbl, + { + match deltalake::checkpoints::create_checkpoint(&tbl.table).await { + Ok(_) => unsafe { + callback(std::ptr::null()); + }, + Err(err) => { + let error = + DeltaTableError::new(rt, DeltaTableErrorCode::Protocol, &err.to_string()); + unsafe { callback(error.into_raw()) } + } + }; + }, + { callback(std::ptr::null()) } + ); } #[no_mangle] pub extern "C" fn table_vacuum( - runtime: NonNull, - table: NonNull, + mut runtime: NonNull, + mut table: NonNull, options: *const VacuumOptions, callback: GenericErrorCallback, ) { @@ -1245,28 +1286,36 @@ pub extern "C" fn table_vacuum( custom_metadata, ) }; - do_with_table_and_runtime(runtime, table, move |rt, tbl| async move { - match vacuum( - &mut tbl.table, - dry_run, - retention_hours, - enforce_retention_duration, - custom_metadata, - ) - .await + run_async_with_cancellation!( + runtime, + table, + None::<&CancellationToken>, + rt, + tbl, { - Ok(strings) => { - let dyn_array = Box::into_raw(Box::new(DynamicArray::from_vec_string(strings))); - unsafe { - callback(dyn_array as *const c_void, std::ptr::null()); + match vacuum( + &mut tbl.table, + dry_run, + retention_hours, + enforce_retention_duration, + custom_metadata, + ) + .await + { + Ok(strings) => { + let dyn_array = Box::into_raw(Box::new(DynamicArray::from_vec_string(strings))); + unsafe { + callback(dyn_array as *const c_void, std::ptr::null()); + } + } + Err(err) => { + let error = DeltaTableError::from_error(rt, err); + unsafe { callback(std::ptr::null_mut(), Box::into_raw(Box::new(error))) } } } - Err(err) => { - let error = DeltaTableError::from_error(rt, err); - unsafe { callback(std::ptr::null_mut(), Box::into_raw(Box::new(error))) } - } - }; - }); + }, + { callback(std::ptr::null(), std::ptr::null()) } + ); } async fn vacuum( @@ -1307,10 +1356,10 @@ pub extern "C" fn table_version(table_handle: NonNull) -> i64 { #[no_mangle] pub extern "C" fn table_metadata( - runtime: NonNull, - table_handle: NonNull, + mut runtime: NonNull, + mut table_handle: NonNull, ) -> MetadataOrError { - do_with_table_and_runtime_sync(runtime, table_handle, |rt, table| { + run_sync!(runtime, table_handle, rt, table, { match table.table.metadata() { Ok(metadata) => { let partition_columns = metadata @@ -1371,11 +1420,11 @@ pub extern "C" fn table_metadata( #[no_mangle] pub extern "C" fn table_add_constraints( - runtime: NonNull, - table: NonNull, + mut runtime: NonNull, + mut table: NonNull, constraints: *mut Map, custom_metadata: *mut Map, - cancellation_token: *const CancellationToken, + cancellation_token: Option<&CancellationToken>, callback: TableEmptyCallback, ) { let constraints: HashMap = unsafe { @@ -1399,11 +1448,13 @@ pub extern "C" fn table_add_constraints( } }; - do_with_table_and_runtime_and_cancel( + run_async_with_cancellation!( runtime, table, cancellation_token, - move |rt, tbl| async move { + rt, + tbl, + { let snapshot = match tbl.table.snapshot() { Ok(snapshot) => snapshot.clone(), Err(err) => unsafe { @@ -1433,66 +1484,10 @@ pub extern "C" fn table_add_constraints( }, } }, - move || unsafe { callback(std::ptr::null()) }, + { callback(std::ptr::null()) } ); } -fn do_with_table_and_runtime<'a, F, Fut>( - mut rt: NonNull, - mut table: NonNull, - work: F, -) where - F: FnOnce(&'a mut Runtime, &'a mut RawDeltaTable) -> Fut + Send + 'static, - Fut: std::future::Future + Send, -{ - let runtime = unsafe { rt.as_mut() }; - let table = unsafe { table.as_mut() }; - let runtime_handle = runtime.handle(); - runtime_handle.spawn(async move { - work(runtime, table).await; - }); -} - -fn do_with_table_and_runtime_and_cancel<'a, F, Fut>( - mut rt: NonNull, - mut table: NonNull, - cancellation_token: *const CancellationToken, - work: F, - on_cancel: impl FnOnce() + Send + 'static, -) where - F: FnOnce(&'a mut Runtime, &'a mut RawDeltaTable) -> Fut + Send + 'static, - Fut: std::future::Future + Send + 'static, -{ - let runtime = unsafe { rt.as_mut() }; - let table = unsafe { table.as_mut() }; - let cancel_token = unsafe { cancellation_token.as_ref() }.map(|v| v.token.clone()); - let runtime_handle = runtime.handle(); - let call_future = work(runtime, table); - runtime_handle.spawn(async move { - if let Some(cancel_token) = cancel_token { - tokio::select! { - _ = cancel_token.cancelled() => on_cancel(), - _ = call_future => {}, - } - } else { - call_future.await - } - }); -} - -fn do_with_table_and_runtime_sync<'a, F, T>( - mut rt: NonNull, - mut table: NonNull, - work: F, -) -> T -where - F: FnOnce(&'a mut Runtime, &'a mut RawDeltaTable) -> T, -{ - let runtime = unsafe { rt.as_mut() }; - let table = unsafe { table.as_mut() }; - work(runtime, table) -} - impl RawDeltaTable { fn new(table: deltalake::DeltaTable) -> Self { RawDeltaTable { table }