From 0da7a818bc3b88ea16101ca7c8db4aa37b150fd6 Mon Sep 17 00:00:00 2001 From: ijl Date: Wed, 13 Nov 2024 15:37:07 +0000 Subject: [PATCH] alloc using PyMem functions --- src/alloc.rs | 38 ++++++++++++++++++ src/deserialize/backend/yyjson.rs | 67 ++++++++++++++++++------------- src/deserialize/pyobject.rs | 2 +- src/lib.rs | 1 + src/typeref.rs | 48 ---------------------- src/util.rs | 6 +++ 6 files changed, 85 insertions(+), 77 deletions(-) create mode 100644 src/alloc.rs diff --git a/src/alloc.rs b/src/alloc.rs new file mode 100644 index 00000000..4c54938d --- /dev/null +++ b/src/alloc.rs @@ -0,0 +1,38 @@ +// SPDX-License-Identifier: (Apache-2.0 OR MIT) + +use std::alloc::{GlobalAlloc, Layout}; +use std::ffi::c_void; + +struct PyMemAllocator {} + +#[global_allocator] +static ALLOCATOR: PyMemAllocator = PyMemAllocator {}; + +unsafe impl Sync for PyMemAllocator {} + +unsafe impl GlobalAlloc for PyMemAllocator { + #[inline] + unsafe fn alloc(&self, layout: Layout) -> *mut u8 { + unsafe { pyo3_ffi::PyMem_Malloc(layout.size()) as *mut u8 } + } + + #[inline] + unsafe fn dealloc(&self, ptr: *mut u8, _layout: Layout) { + unsafe { pyo3_ffi::PyMem_Free(ptr as *mut c_void) } + } + + #[inline] + unsafe fn alloc_zeroed(&self, layout: Layout) -> *mut u8 { + unsafe { + let len = layout.size(); + let ptr = pyo3_ffi::PyMem_Malloc(len) as *mut u8; + core::ptr::write_bytes(ptr, 0, len); + ptr + } + } + + #[inline] + unsafe fn realloc(&self, ptr: *mut u8, _layout: Layout, new_size: usize) -> *mut u8 { + unsafe { pyo3_ffi::PyMem_Realloc(ptr as *mut c_void, new_size) as *mut u8 } + } +} diff --git a/src/deserialize/backend/yyjson.rs b/src/deserialize/backend/yyjson.rs index c5538b4b..8ce35d98 100644 --- a/src/deserialize/backend/yyjson.rs +++ b/src/deserialize/backend/yyjson.rs @@ -4,7 +4,6 @@ use crate::deserialize::pyobject::*; use crate::deserialize::DeserializeError; use crate::ffi::yyjson::*; use crate::str::unicode_from_str; -use crate::typeref::{yyjson_init, YYJSON_ALLOC, YYJSON_BUFFER_SIZE}; use core::ffi::c_char; use core::ptr::{null, null_mut, NonNull}; use std::borrow::Cow; @@ -57,25 +56,52 @@ fn unsafe_yyjson_get_next_non_container(val: *mut yyjson_val) -> *mut yyjson_val unsafe { ((val as *mut u8).add(YYJSON_VAL_SIZE)) as *mut yyjson_val } } +type DeserializeBuffer = Vec>; + +const MINIMUM_BUFFER_CAPACITY: usize = 4096 - core::mem::size_of::(); + pub(crate) fn deserialize( data: &'static str, ) -> Result, DeserializeError<'static>> { + let buffer_capacity = usize::max( + MINIMUM_BUFFER_CAPACITY, + yyjson_read_max_memory_usage(data.len()), + ); + let mut buffer: DeserializeBuffer = Vec::with_capacity(buffer_capacity); + let mut alloc = crate::ffi::yyjson::yyjson_alc { + malloc: None, + realloc: None, + free: None, + ctx: null_mut(), + }; + let alloc_ptr = core::ptr::addr_of_mut!(alloc); + unsafe { + crate::ffi::yyjson::yyjson_alc_pool_init( + alloc_ptr, + buffer.as_mut_ptr().cast::(), + buffer_capacity, + ); + } + let mut err = yyjson_read_err { code: YYJSON_READ_SUCCESS, msg: null(), pos: 0, }; - let doc = if yyjson_read_max_memory_usage(data.len()) < YYJSON_BUFFER_SIZE { - read_doc_with_buffer(data, &mut err) - } else { - read_doc_default(data, &mut err) + + let doc = unsafe { + yyjson_read_opts( + data.as_ptr() as *mut c_char, + data.len(), + alloc_ptr as *const crate::ffi::yyjson::yyjson_alc, + &mut err, + ) }; if unlikely!(doc.is_null()) { let msg: Cow = unsafe { core::ffi::CStr::from_ptr(err.msg).to_string_lossy() }; Err(DeserializeError::from_yyjson(msg, err.pos as i64, data)) } else { let val = yyjson_doc_get_root(doc); - if unlikely!(!unsafe_yyjson_is_ctn(val)) { let pyval = match ElementType::from_tag(val) { ElementType::String => parse_yy_string(val), @@ -85,8 +111,8 @@ pub(crate) fn deserialize( ElementType::Null => parse_none(), ElementType::True => parse_true(), ElementType::False => parse_false(), - ElementType::Array => unreachable!(), - ElementType::Object => unreachable!(), + ElementType::Array => unreachable_unchecked!(), + ElementType::Object => unreachable_unchecked!(), }; unsafe { yyjson_doc_free(doc) }; Ok(pyval) @@ -110,21 +136,6 @@ pub(crate) fn deserialize( } } -fn read_doc_default(data: &'static str, err: &mut yyjson_read_err) -> *mut yyjson_doc { - unsafe { yyjson_read_opts(data.as_ptr() as *mut c_char, data.len(), null_mut(), err) } -} - -fn read_doc_with_buffer(data: &'static str, err: &mut yyjson_read_err) -> *mut yyjson_doc { - unsafe { - yyjson_read_opts( - data.as_ptr() as *mut c_char, - data.len(), - &YYJSON_ALLOC.get_or_init(yyjson_init).alloc, - err, - ) - } -} - enum ElementType { String, Uint64, @@ -149,7 +160,7 @@ impl ElementType { TAG_FALSE => Self::False, TAG_ARRAY => Self::Array, TAG_OBJECT => Self::Object, - _ => unreachable!(), + _ => unreachable_unchecked!(), } } } @@ -221,8 +232,8 @@ fn populate_yy_array(list: *mut pyo3_ffi::PyObject, elem: *mut yyjson_val) { ElementType::Null => parse_none(), ElementType::True => parse_true(), ElementType::False => parse_false(), - ElementType::Array => unreachable!(), - ElementType::Object => unreachable!(), + ElementType::Array => unreachable_unchecked!(), + ElementType::Object => unreachable_unchecked!(), }; append_to_list!(dptr, pyval.as_ptr()); } @@ -283,8 +294,8 @@ fn populate_yy_object(dict: *mut pyo3_ffi::PyObject, elem: *mut yyjson_val) { ElementType::Null => parse_none(), ElementType::True => parse_true(), ElementType::False => parse_false(), - ElementType::Array => unreachable!(), - ElementType::Object => unreachable!(), + ElementType::Array => unreachable_unchecked!(), + ElementType::Object => unreachable_unchecked!(), }; add_to_dict!(dict, pykey, pyval.as_ptr()); reverse_pydict_incref!(pykey); diff --git a/src/deserialize/pyobject.rs b/src/deserialize/pyobject.rs index 9ad9b090..e486c916 100644 --- a/src/deserialize/pyobject.rs +++ b/src/deserialize/pyobject.rs @@ -16,7 +16,7 @@ pub fn get_unicode_key(key_str: &str) -> *mut pyo3_ffi::PyObject { unsafe { let entry = KEY_MAP .get_mut() - .unwrap_or_else(|| unreachable!()) + .unwrap_or_else(|| unreachable_unchecked!()) .entry(&hash) .or_insert_with( || hash, diff --git a/src/lib.rs b/src/lib.rs index e65cd7f5..ec7051ee 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -22,6 +22,7 @@ extern crate unwinding; #[macro_use] mod util; +mod alloc; mod deserialize; mod ffi; mod opt; diff --git a/src/typeref.rs b/src/typeref.rs index 0fe13f9f..c5fcd336 100644 --- a/src/typeref.rs +++ b/src/typeref.rs @@ -2,15 +2,9 @@ use crate::ffi::orjson_fragmenttype_new; use core::ffi::c_char; -#[cfg(feature = "yyjson")] -use core::ffi::c_void; -#[cfg(feature = "yyjson")] -use core::mem::MaybeUninit; use core::ptr::{null_mut, NonNull}; use once_cell::race::{OnceBool, OnceBox}; use pyo3_ffi::*; -#[cfg(feature = "yyjson")] -use std::cell::UnsafeCell; pub struct NumpyTypes { pub array: *mut PyTypeObject, @@ -76,48 +70,6 @@ pub static mut DESCR_STR: *mut PyObject = null_mut(); pub static mut VALUE_STR: *mut PyObject = null_mut(); pub static mut INT_ATTR_STR: *mut PyObject = null_mut(); -#[cfg(feature = "yyjson")] -pub const YYJSON_BUFFER_SIZE: usize = 1024 * 1024 * 8; - -#[cfg(feature = "yyjson")] -#[repr(align(64))] -struct YYJSONBuffer(UnsafeCell>); - -#[cfg(feature = "yyjson")] -pub struct YYJSONAlloc { - pub alloc: crate::ffi::yyjson::yyjson_alc, - _buffer: Box, -} - -#[cfg(feature = "yyjson")] -pub static mut YYJSON_ALLOC: OnceBox = OnceBox::new(); - -#[cfg(feature = "yyjson")] -pub fn yyjson_init() -> Box { - // Using unsafe to ensure allocation happens on the heap without going through the stack - // so we don't stack overflow in debug mode. Once rust-lang/rust#63291 is stable (Box::new_uninit) - // we can use that instead. - let layout = std::alloc::Layout::new::(); - let buffer = unsafe { Box::from_raw(std::alloc::alloc(layout).cast::()) }; - let mut alloc = crate::ffi::yyjson::yyjson_alc { - malloc: None, - realloc: None, - free: None, - ctx: null_mut(), - }; - unsafe { - crate::ffi::yyjson::yyjson_alc_pool_init( - &mut alloc, - buffer.0.get().cast::(), - YYJSON_BUFFER_SIZE, - ); - } - Box::new(YYJSONAlloc { - alloc, - _buffer: buffer, - }) -} - #[allow(non_upper_case_globals)] pub static mut JsonEncodeError: *mut PyObject = null_mut(); #[allow(non_upper_case_globals)] diff --git a/src/util.rs b/src/util.rs index 8add0f75..a6344343 100644 --- a/src/util.rs +++ b/src/util.rs @@ -305,3 +305,9 @@ macro_rules! popcnt { core::mem::transmute::($val).count_ones() as usize }; } + +macro_rules! unreachable_unchecked { + () => { + unsafe { core::hint::unreachable_unchecked() } + }; +}