From b9f5dc08536c903615d00e49052072227f68099f Mon Sep 17 00:00:00 2001 From: ijl Date: Tue, 2 Jul 2024 15:27:10 +0000 Subject: [PATCH] str various refactor and perf --- Cargo.toml | 2 +- build.rs | 4 - src/deserialize/utf8.rs | 8 +- src/lib.rs | 10 +- src/serialize/per_type/unicode.rs | 15 +- src/serialize/writer/byteswriter.rs | 10 + src/serialize/writer/formatter.rs | 16 +- src/serialize/writer/json.rs | 26 +- src/serialize/writer/mod.rs | 2 +- src/serialize/writer/str.rs | 405 ------------------------ src/serialize/writer/str/avx512.rs | 117 +++++++ src/serialize/writer/str/escape.rs | 119 +++++++ src/serialize/writer/str/generic.rs | 110 +++++++ src/serialize/writer/str/mod.rs | 22 ++ src/serialize/writer/str/scalar.rs | 41 +++ src/str/avx512.rs | 104 ++++++ src/str/check.rs | 32 -- src/str/mod.rs | 14 +- src/str/{create.rs => pyunicode_new.rs} | 49 +-- src/str/scalar.rs | 45 +++ src/util.rs | 32 ++ 21 files changed, 669 insertions(+), 514 deletions(-) delete mode 100644 src/serialize/writer/str.rs create mode 100644 src/serialize/writer/str/avx512.rs create mode 100644 src/serialize/writer/str/escape.rs create mode 100644 src/serialize/writer/str/generic.rs create mode 100644 src/serialize/writer/str/mod.rs create mode 100644 src/serialize/writer/str/scalar.rs create mode 100644 src/str/avx512.rs delete mode 100644 src/str/check.rs rename src/str/{create.rs => pyunicode_new.rs} (54%) create mode 100644 src/str/scalar.rs diff --git a/Cargo.toml b/Cargo.toml index 326e64d2..6a323404 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -71,7 +71,7 @@ pyo3-ffi = { version = "^0.22", default-features = false, features = ["extension ryu = { version = "1", default-features = false } serde = { version = "1", default-features = false } serde_json = { version = "1", default-features = false, features = ["std", "float_roundtrip"] } -simdutf8 = { version = "0.1", default-features = false, features = ["std", "aarch64_neon"] } +simdutf8 = { version = "0.1", default-features = false, 
features = ["std", "public_imp", "aarch64_neon"] } smallvec = { version = "^1.11", default-features = false, features = ["union", "write"] } unwinding = { version = "0.2", features = ["unwinder"], optional = true } xxhash-rust = { version = "^0.8", default-features = false, features = ["xxh3"] } diff --git a/build.rs b/build.rs index 9d9527b7..2b705489 100644 --- a/build.rs +++ b/build.rs @@ -41,10 +41,6 @@ fn main() { if env::var("ORJSON_DISABLE_SIMD").is_err() { if let Some(true) = version_check::supports_feature("portable_simd") { println!("cargo:rustc-cfg=feature=\"unstable-simd\""); - #[cfg(all(target_arch = "x86_64", target_feature = "avx512vl"))] - if env::var("ORJSON_DISABLE_AVX512").is_err() { - println!("cargo:rustc-cfg=feature=\"avx512\""); - } } } diff --git a/src/deserialize/utf8.rs b/src/deserialize/utf8.rs index fa1bb2dc..2d2842c6 100644 --- a/src/deserialize/utf8.rs +++ b/src/deserialize/utf8.rs @@ -8,16 +8,16 @@ use crate::util::INVALID_STR; use core::ffi::c_char; use std::borrow::Cow; -#[cfg(all(target_arch = "x86_64", not(target_feature = "sse4.2")))] +#[cfg(all(target_arch = "x86_64", not(target_feature = "avx2")))] fn is_valid_utf8(buf: &[u8]) -> bool { - if std::is_x86_feature_detected!("sse4.2") { - simdutf8::basic::from_utf8(buf).is_ok() + if std::is_x86_feature_detected!("avx2") { + unsafe { simdutf8::basic::imp::x86::avx2::validate_utf8(buf).is_ok() } } else { encoding_rs::Encoding::utf8_valid_up_to(buf) == buf.len() } } -#[cfg(all(target_arch = "x86_64", target_feature = "sse4.2"))] +#[cfg(all(target_arch = "x86_64", target_feature = "avx2"))] fn is_valid_utf8(buf: &[u8]) -> bool { simdutf8::basic::from_utf8(buf).is_ok() } diff --git a/src/lib.rs b/src/lib.rs index 0ae348ee..b7a1563e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,13 +1,7 @@ // SPDX-License-Identifier: (Apache-2.0 OR MIT) -#![cfg_attr( - all(target_arch = "x86_64", feature = "avx512"), - feature(avx512_target_feature) -)] -#![cfg_attr( - all(target_arch = "x86_64", feature 
= "avx512"), - feature(stdarch_x86_avx512) -)] +#![cfg_attr(feature = "avx512", feature(avx512_target_feature))] +#![cfg_attr(feature = "avx512", feature(stdarch_x86_avx512))] #![cfg_attr(feature = "intrinsics", feature(core_intrinsics))] #![cfg_attr(feature = "optimize", feature(optimize_attribute))] #![cfg_attr(feature = "strict_provenance", feature(strict_provenance))] diff --git a/src/serialize/per_type/unicode.rs b/src/serialize/per_type/unicode.rs index a44cda09..ecc8ae6c 100644 --- a/src/serialize/per_type/unicode.rs +++ b/src/serialize/per_type/unicode.rs @@ -17,16 +17,19 @@ impl StrSerializer { } impl Serialize for StrSerializer { - #[inline] + #[inline(always)] fn serialize(&self, serializer: S) -> Result where S: Serializer, { - let uni = unicode_to_str(self.ptr); - if unlikely!(uni.is_none()) { - err!(SerializeError::InvalidStr) - } - serializer.serialize_str(uni.unwrap()) + let uni = { + let tmp = unicode_to_str(self.ptr); + if unlikely!(tmp.is_none()) { + err!(SerializeError::InvalidStr) + }; + tmp.unwrap() + }; + serializer.serialize_str(uni) } } diff --git a/src/serialize/writer/byteswriter.rs b/src/serialize/writer/byteswriter.rs index 72510456..086d7202 100644 --- a/src/serialize/writer/byteswriter.rs +++ b/src/serialize/writer/byteswriter.rs @@ -113,6 +113,11 @@ pub trait WriteExt: std::io::Write { let _ = len; } + #[inline] + fn has_capacity(&mut self, _len: usize) -> bool { + false + } + #[inline] fn set_written(&mut self, len: usize) { let _ = len; @@ -157,6 +162,11 @@ impl WriteExt for &mut BytesWriter { } } + #[inline] + fn has_capacity(&mut self, len: usize) -> bool { + return self.len + len <= self.cap; + } + #[inline(always)] fn set_written(&mut self, len: usize) { self.len += len; diff --git a/src/serialize/writer/formatter.rs b/src/serialize/writer/formatter.rs index cc020a30..15971f3a 100644 --- a/src/serialize/writer/formatter.rs +++ b/src/serialize/writer/formatter.rs @@ -4,6 +4,12 @@ use crate::serialize::writer::WriteExt; use 
std::io; +macro_rules! debug_assert_has_capacity { + ($writer:expr) => { + debug_assert!($writer.has_capacity(4)) + }; +} + pub trait Formatter { #[inline] fn write_null(&mut self, writer: &mut W) -> io::Result<()> @@ -196,7 +202,7 @@ pub trait Formatter { where W: ?Sized + io::Write + WriteExt, { - reserve_minimum!(writer); + debug_assert_has_capacity!(writer); unsafe { writer.write_reserved_punctuation(b']').unwrap() }; Ok(()) } @@ -206,7 +212,7 @@ pub trait Formatter { where W: ?Sized + io::Write + WriteExt, { - reserve_minimum!(writer); + debug_assert_has_capacity!(writer); if !first { unsafe { writer.write_reserved_punctuation(b',').unwrap() } } @@ -238,7 +244,7 @@ pub trait Formatter { where W: ?Sized + io::Write + WriteExt, { - reserve_minimum!(writer); + debug_assert_has_capacity!(writer); unsafe { writer.write_reserved_punctuation(b'}').unwrap(); } @@ -250,7 +256,7 @@ pub trait Formatter { where W: ?Sized + io::Write + WriteExt, { - reserve_minimum!(writer); + debug_assert_has_capacity!(writer); if !first { unsafe { writer.write_reserved_punctuation(b',').unwrap(); @@ -272,7 +278,7 @@ pub trait Formatter { where W: ?Sized + io::Write + WriteExt, { - reserve_minimum!(writer); + debug_assert_has_capacity!(writer); unsafe { writer.write_reserved_punctuation(b':') } } diff --git a/src/serialize/writer/json.rs b/src/serialize/writer/json.rs index 0c7a4257..fe575831 100644 --- a/src/serialize/writer/json.rs +++ b/src/serialize/writer/json.rs @@ -572,9 +572,29 @@ macro_rules! 
reserve_str { }; } +#[cfg(all(feature = "unstable-simd", not(target_arch = "x86_64")))] +#[inline(always)] +fn format_escaped_str(writer: &mut W, value: &str) +where + W: ?Sized + io::Write + WriteExt, +{ + unsafe { + reserve_str!(writer, value); + + let written = format_escaped_str_impl_generic_128( + writer.as_mut_buffer_ptr(), + value.as_bytes().as_ptr(), + value.len(), + ); + + writer.set_written(written); + } +} + #[cfg(all( feature = "unstable-simd", - any(not(target_arch = "x86_64"), not(feature = "avx512")) + target_arch = "x86_64", + not(feature = "avx512") ))] #[inline(always)] fn format_escaped_str(writer: &mut W, value: &str) @@ -584,7 +604,7 @@ where unsafe { reserve_str!(writer, value); - let written = format_escaped_str_impl_128( + let written = format_escaped_str_impl_generic_128( writer.as_mut_buffer_ptr(), value.as_bytes().as_ptr(), value.len(), @@ -611,7 +631,7 @@ where ); writer.set_written(written); } else { - let written = format_escaped_str_impl_128( + let written = format_escaped_str_impl_generic_128( writer.as_mut_buffer_ptr(), value.as_bytes().as_ptr(), value.len(), diff --git a/src/serialize/writer/mod.rs b/src/serialize/writer/mod.rs index 34248dc7..3e945b2c 100644 --- a/src/serialize/writer/mod.rs +++ b/src/serialize/writer/mod.rs @@ -1,4 +1,4 @@ -// SPDX-License-Identifier: Apache-2.0 +// SPDX-License-Identifier: (Apache-2.0 OR MIT) mod byteswriter; mod formatter; diff --git a/src/serialize/writer/str.rs b/src/serialize/writer/str.rs deleted file mode 100644 index 1a58330f..00000000 --- a/src/serialize/writer/str.rs +++ /dev/null @@ -1,405 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// Copyright 2023-2024 liuq19, ijl -// adapted from sonic-rs' src/util/string.rs - -#[cfg(all(feature = "unstable-simd", target_arch = "x86_64", feature = "avx512"))] -use core::mem::transmute; - -#[cfg(feature = "unstable-simd")] -const STRIDE_128_BIT: usize = 16; - -#[cfg(feature = "unstable-simd")] -use core::simd::cmp::{SimdPartialEq, 
SimdPartialOrd}; - -#[cfg(all(feature = "unstable-simd", target_arch = "x86_64", feature = "avx512"))] -use core::arch::x86_64::{ - __m256i, _mm256_cmpeq_epu8_mask, _mm256_cmplt_epu8_mask, _mm256_lddqu_si256, - _mm256_maskz_loadu_epi8, _mm256_set1_epi8, _mm256_storeu_epi8, -}; - -#[cfg(all(feature = "unstable-simd", target_arch = "x86_64", feature = "avx512"))] -macro_rules! splat_mm256 { - ($val:expr) => { - _mm256_set1_epi8(transmute::($val)) - }; -} - -macro_rules! write_escape { - ($escape:expr, $dst:expr) => { - core::ptr::copy_nonoverlapping($escape.0.as_ptr(), $dst, 8); - }; -} - -#[cfg(all(feature = "unstable-simd", target_arch = "x86_64", feature = "avx512"))] -macro_rules! impl_format_simd_avx512vl { - ($dst:expr, $src:expr, $value_len:expr) => { - let mut nb: usize = $value_len; - - let blash = splat_mm256!(b'\\'); - let quote = splat_mm256!(b'"'); - let x20 = splat_mm256!(32); - - unsafe { - while nb >= STRIDE { - let str_vec = _mm256_lddqu_si256(transmute::<*const u8, *const __m256i>($src)); - - _mm256_storeu_epi8($dst as *mut i8, str_vec); - - #[allow(unused_mut)] - let mut mask = _mm256_cmpeq_epu8_mask(str_vec, blash) - | _mm256_cmpeq_epu8_mask(str_vec, quote) - | _mm256_cmplt_epu8_mask(str_vec, x20); - - if unlikely!(mask > 0) { - let cn = mask.trailing_zeros() as usize; // _tzcnt_u32() not inlining - - $src = $src.add(cn); - let escape = QUOTE_TAB[*($src) as usize]; - $src = $src.add(1); - - nb -= cn; - nb -= 1; - - $dst = $dst.add(cn); - write_escape!(escape, $dst); - $dst = $dst.add(escape.1 as usize); - } else { - nb -= STRIDE; - $dst = $dst.add(STRIDE); - $src = $src.add(STRIDE); - } - } - - while nb > 0 { - let remainder_mask: u32 = !(u32::MAX << nb); - let str_vec = _mm256_maskz_loadu_epi8(remainder_mask, $src as *const i8); - - _mm256_storeu_epi8($dst as *mut i8, str_vec); - - let mut mask = _mm256_cmpeq_epu8_mask(str_vec, blash) - | _mm256_cmpeq_epu8_mask(str_vec, quote) - | _mm256_cmplt_epu8_mask(str_vec, x20); - - mask &= remainder_mask; 
- - if unlikely!(mask > 0) { - let cn = mask.trailing_zeros() as usize; // _tzcnt_u32() not inlining - - $src = $src.add(cn); - let escape = QUOTE_TAB[*($src) as usize]; - $src = $src.add(1); - - nb -= cn; - nb -= 1; - - $dst = $dst.add(cn); - write_escape!(escape, $dst); - $dst = $dst.add(escape.1 as usize); - } else { - $dst = $dst.add(nb); - break; - } - } - } - }; -} - -#[cfg(feature = "unstable-simd")] -macro_rules! impl_escape_unchecked { - ($src:expr, $dst:expr, $nb:expr, $omask:expr, $cn:expr, $v:expr, $rotate:expr) => { - if $rotate == true { - for _ in 0..$cn { - $v = $v.rotate_elements_left::<1>(); - } - } - $nb -= $cn; - $dst = $dst.add($cn); - $src = $src.add($cn); - $omask >>= $cn; - loop { - if $rotate == true { - $v = $v.rotate_elements_left::<1>(); - } - $nb -= 1; - $omask >>= 1; - let escape = QUOTE_TAB[*($src) as usize]; - write_escape!(escape, $dst); - $dst = $dst.add(escape.1 as usize); - $src = $src.add(1); - if likely!($omask & 1 != 1) { - break; - } - } - }; -} - -#[cfg(feature = "unstable-simd")] -macro_rules! 
impl_format_simd_generic_128 { - ($dst:expr, $src:expr, $value_len:expr) => { - let last_stride_src = $src.add($value_len).sub(STRIDE); - let mut nb: usize = $value_len; - - assume!($value_len >= STRIDE_128_BIT); - - const BLASH: StrVector = StrVector::from_array([b'\\'; STRIDE]); - const QUOTE: StrVector = StrVector::from_array([b'"'; STRIDE]); - const X20: StrVector = StrVector::from_array([32; STRIDE]); - - unsafe { - { - const ROTATE: bool = false; - while nb > STRIDE { - let mut v = StrVector::from_slice(core::slice::from_raw_parts($src, STRIDE)); - let mut mask = - (v.simd_eq(BLASH) | v.simd_eq(QUOTE) | v.simd_lt(X20)).to_bitmask() as u32; - v.copy_to_slice(core::slice::from_raw_parts_mut($dst, STRIDE)); - - if unlikely!(mask > 0) { - let cn = mask.trailing_zeros() as usize; - impl_escape_unchecked!($src, $dst, nb, mask, cn, v, ROTATE); - } else { - nb -= STRIDE; - $dst = $dst.add(STRIDE); - $src = $src.add(STRIDE); - } - } - } - - if nb > 0 { - const ROTATE: bool = true; - let mut v = - StrVector::from_slice(core::slice::from_raw_parts(last_stride_src, STRIDE)); - let mut to_skip = STRIDE - nb; - while to_skip >= 4 { - to_skip -= 4; - v = v.rotate_elements_left::<4>(); - } - while to_skip > 0 { - to_skip -= 1; - v = v.rotate_elements_left::<1>(); - } - - let mut mask = (v.simd_eq(BLASH) | v.simd_eq(QUOTE) | v.simd_lt(X20)).to_bitmask() - as u32 - & (STRIDE_SATURATION >> (32 - STRIDE - nb)); - - while nb > 0 { - v.copy_to_slice(core::slice::from_raw_parts_mut($dst, STRIDE)); - - if unlikely!(mask > 0) { - let cn = mask.trailing_zeros() as usize; - impl_escape_unchecked!($src, $dst, nb, mask, cn, v, ROTATE); - } else { - $dst = $dst.add(nb); - break; - } - } - } - } - }; -} - -macro_rules! 
impl_format_scalar { - ($dst:expr, $src:expr, $value_len:expr) => { - unsafe { - for _ in 0..$value_len { - core::ptr::write($dst, *($src)); - $src = $src.add(1); - $dst = $dst.add(1); - if unlikely!(NEED_ESCAPED[*($src.sub(1)) as usize] > 0) { - let escape = QUOTE_TAB[*($src.sub(1)) as usize]; - write_escape!(escape, $dst.sub(1)); - $dst = $dst.add(escape.1 as usize - 1); - } - } - } - }; -} - -#[inline(never)] -#[cfg(all(feature = "unstable-simd", target_arch = "x86_64", feature = "avx512"))] -#[cfg_attr(target_arch = "x86_64", target_feature(enable = "avx512f"))] -#[cfg_attr(target_arch = "x86_64", target_feature(enable = "avx512bw"))] -#[cfg_attr(target_arch = "x86_64", target_feature(enable = "avx512vl"))] -#[cfg_attr(target_arch = "x86_64", target_feature(enable = "bmi2"))] -pub unsafe fn format_escaped_str_impl_512vl( - odst: *mut u8, - value_ptr: *const u8, - value_len: usize, -) -> usize { - const STRIDE: usize = 32; - - let mut dst = odst; - let mut src = value_ptr; - - core::ptr::write(dst, b'"'); - dst = dst.add(1); - - impl_format_simd_avx512vl!(dst, src, value_len); - - core::ptr::write(dst, b'"'); - dst = dst.add(1); - - dst as usize - odst as usize -} - -#[inline(never)] -#[cfg(feature = "unstable-simd")] -#[cfg_attr(target_arch = "aarch64", target_feature(enable = "neon"))] -#[cfg_attr(target_arch = "x86_64", target_feature(enable = "sse2"))] -pub unsafe fn format_escaped_str_impl_128( - odst: *mut u8, - value_ptr: *const u8, - value_len: usize, -) -> usize { - const STRIDE: usize = STRIDE_128_BIT; - const STRIDE_SATURATION: u32 = u16::MAX as u32; - type StrVector = core::simd::u8x16; - - let mut dst = odst; - let mut src = value_ptr; - - core::ptr::write(dst, b'"'); - dst = dst.add(1); - - if value_len < STRIDE_128_BIT { - impl_format_scalar!(dst, src, value_len) - } else { - impl_format_simd_generic_128!(dst, src, value_len); - } - - core::ptr::write(dst, b'"'); - dst = dst.add(1); - - dst as usize - odst as usize -} - -#[cfg(not(feature = 
"unstable-simd"))] -pub unsafe fn format_escaped_str_scalar( - odst: *mut u8, - value_ptr: *const u8, - value_len: usize, -) -> usize { - let mut dst = odst; - let mut src = value_ptr; - - core::ptr::write(dst, b'"'); - dst = dst.add(1); - - impl_format_scalar!(dst, src, value_len); - - core::ptr::write(dst, b'"'); - dst = dst.add(1); - - dst as usize - odst as usize -} - -const NEED_ESCAPED: [u8; 256] = [ - 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, - 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -]; - -const QUOTE_TAB: [([u8; 7], u8); 96] = [ - (*b"\\u0000\0", 6), - (*b"\\u0001\0", 6), - (*b"\\u0002\0", 6), - (*b"\\u0003\0", 6), - (*b"\\u0004\0", 6), - (*b"\\u0005\0", 6), - (*b"\\u0006\0", 6), - (*b"\\u0007\0", 6), - (*b"\\b\0\0\0\0\0", 2), - (*b"\\t\0\0\0\0\0", 2), - (*b"\\n\0\0\0\0\0", 2), - (*b"\\u000b\0", 6), - (*b"\\f\0\0\0\0\0", 2), - (*b"\\r\0\0\0\0\0", 2), - (*b"\\u000e\0", 6), - (*b"\\u000f\0", 6), - (*b"\\u0010\0", 6), - (*b"\\u0011\0", 6), - (*b"\\u0012\0", 6), - (*b"\\u0013\0", 6), - (*b"\\u0014\0", 6), - (*b"\\u0015\0", 6), - (*b"\\u0016\0", 6), - (*b"\\u0017\0", 6), - (*b"\\u0018\0", 6), - (*b"\\u0019\0", 6), - (*b"\\u001a\0", 6), - (*b"\\u001b\0", 6), - (*b"\\u001c\0", 6), - (*b"\\u001d\0", 6), - (*b"\\u001e\0", 6), - (*b"\\u001f\0", 6), - ([0; 7], 0), - ([0; 7], 0), - 
(*b"\\\"\0\0\0\0\0", 2), - ([0; 7], 0), - ([0; 7], 0), - ([0; 7], 0), - ([0; 7], 0), - ([0; 7], 0), - ([0; 7], 0), - ([0; 7], 0), - ([0; 7], 0), - ([0; 7], 0), - ([0; 7], 0), - ([0; 7], 0), - ([0; 7], 0), - ([0; 7], 0), - ([0; 7], 0), - ([0; 7], 0), - ([0; 7], 0), - ([0; 7], 0), - ([0; 7], 0), - ([0; 7], 0), - ([0; 7], 0), - ([0; 7], 0), - ([0; 7], 0), - ([0; 7], 0), - ([0; 7], 0), - ([0; 7], 0), - ([0; 7], 0), - ([0; 7], 0), - ([0; 7], 0), - ([0; 7], 0), - ([0; 7], 0), - ([0; 7], 0), - ([0; 7], 0), - ([0; 7], 0), - ([0; 7], 0), - ([0; 7], 0), - ([0; 7], 0), - ([0; 7], 0), - ([0; 7], 0), - ([0; 7], 0), - ([0; 7], 0), - ([0; 7], 0), - ([0; 7], 0), - ([0; 7], 0), - ([0; 7], 0), - ([0; 7], 0), - ([0; 7], 0), - ([0; 7], 0), - ([0; 7], 0), - ([0; 7], 0), - ([0; 7], 0), - ([0; 7], 0), - ([0; 7], 0), - ([0; 7], 0), - ([0; 7], 0), - ([0; 7], 0), - ([0; 7], 0), - ([0; 7], 0), - (*b"\\\\\0\0\0\0\0", 2), - ([0; 7], 0), - ([0; 7], 0), - ([0; 7], 0), -]; diff --git a/src/serialize/writer/str/avx512.rs b/src/serialize/writer/str/avx512.rs new file mode 100644 index 00000000..d245ed53 --- /dev/null +++ b/src/serialize/writer/str/avx512.rs @@ -0,0 +1,117 @@ +// SPDX-License-Identifier: (Apache-2.0 OR MIT) + +use core::mem::transmute; + +use super::escape::QUOTE_TAB; + +use core::arch::x86_64::{ + __m256i, _mm256_cmpeq_epu8_mask, _mm256_cmplt_epu8_mask, _mm256_load_si256, _mm256_loadu_si256, + _mm256_maskz_loadu_epi8, _mm256_storeu_epi8, +}; + +#[repr(C, align(32))] +struct ConstArray { + pub data: [u8; 32], +} + +const BLASH: ConstArray = ConstArray { data: [b'\\'; 32] }; +const QUOTE: ConstArray = ConstArray { data: [b'"'; 32] }; +const X20: ConstArray = ConstArray { data: [32; 32] }; + +macro_rules! 
impl_format_simd_avx512vl { + ($dst:expr, $src:expr, $value_len:expr) => { + let mut nb: usize = $value_len; + + let blash = _mm256_load_si256(BLASH.data.as_ptr() as *const __m256i); + let quote = _mm256_load_si256(QUOTE.data.as_ptr() as *const __m256i); + let x20 = _mm256_load_si256(X20.data.as_ptr() as *const __m256i); + + unsafe { + while nb >= STRIDE { + let str_vec = _mm256_loadu_si256(transmute::<*const u8, *const __m256i>($src)); + + _mm256_storeu_epi8($dst as *mut i8, str_vec); + + let mask = _mm256_cmpeq_epu8_mask(str_vec, blash) + | _mm256_cmpeq_epu8_mask(str_vec, quote) + | _mm256_cmplt_epu8_mask(str_vec, x20); + + if unlikely!(mask > 0) { + let cn = trailing_zeros!(mask); + $src = $src.add(cn); + $dst = $dst.add(cn); + nb -= cn; + nb -= 1; + + let escape = QUOTE_TAB[*($src) as usize]; + $src = $src.add(1); + + write_escape!(escape, $dst); + $dst = $dst.add(escape.1 as usize); + } else { + nb -= STRIDE; + $dst = $dst.add(STRIDE); + $src = $src.add(STRIDE); + } + } + + if nb > 0 { + loop { + let remainder_mask = !(u32::MAX << nb); + let str_vec = _mm256_maskz_loadu_epi8(remainder_mask, $src as *const i8); + + _mm256_storeu_epi8($dst as *mut i8, str_vec); + + let mask = (_mm256_cmpeq_epu8_mask(str_vec, blash) + | _mm256_cmpeq_epu8_mask(str_vec, quote) + | _mm256_cmplt_epu8_mask(str_vec, x20)) + & remainder_mask; + + if unlikely!(mask > 0) { + let cn = trailing_zeros!(mask); + $src = $src.add(cn); + $dst = $dst.add(cn); + nb -= cn; + nb -= 1; + + let escape = QUOTE_TAB[*($src) as usize]; + $src = $src.add(1); + + write_escape!(escape, $dst); + $dst = $dst.add(escape.1 as usize); + } else { + $dst = $dst.add(nb); + break; + } + } + } + } + }; +} + +#[inline(never)] +#[cfg_attr(feature = "avx512", target_feature(enable = "avx512f"))] +#[cfg_attr(feature = "avx512", target_feature(enable = "avx512bw"))] +#[cfg_attr(feature = "avx512", target_feature(enable = "avx512vl"))] +#[cfg_attr(feature = "avx512", target_feature(enable = "bmi2"))] +#[cfg_attr(feature = 
"avx512", target_feature(enable = "bmi1"))] +pub unsafe fn format_escaped_str_impl_512vl( + odst: *mut u8, + value_ptr: *const u8, + value_len: usize, +) -> usize { + const STRIDE: usize = 32; + + let mut dst = odst; + let mut src = value_ptr; + + core::ptr::write(dst, b'"'); + dst = dst.add(1); + + impl_format_simd_avx512vl!(dst, src, value_len); + + core::ptr::write(dst, b'"'); + dst = dst.add(1); + + dst as usize - odst as usize +} diff --git a/src/serialize/writer/str/escape.rs b/src/serialize/writer/str/escape.rs new file mode 100644 index 00000000..6195a24d --- /dev/null +++ b/src/serialize/writer/str/escape.rs @@ -0,0 +1,119 @@ +// SPDX-License-Identifier: (Apache-2.0 OR MIT) +// the constants and SIMD approach are adapted from cloudwego's sonic-rs + + +macro_rules! write_escape { + ($escape:expr, $dst:expr) => { + core::ptr::copy_nonoverlapping($escape.0.as_ptr(), $dst, 8); + }; +} + +pub const NEED_ESCAPED: [u8; 256] = [ + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, +]; + +pub const QUOTE_TAB: [([u8; 7], u8); 96] = [ + (*b"\\u0000\0", 6), + (*b"\\u0001\0", 6), + (*b"\\u0002\0", 6), + (*b"\\u0003\0", 6), + (*b"\\u0004\0", 6), + (*b"\\u0005\0", 6), + (*b"\\u0006\0", 6), + (*b"\\u0007\0", 6), + (*b"\\b\0\0\0\0\0", 2), + 
(*b"\\t\0\0\0\0\0", 2), + (*b"\\n\0\0\0\0\0", 2), + (*b"\\u000b\0", 6), + (*b"\\f\0\0\0\0\0", 2), + (*b"\\r\0\0\0\0\0", 2), + (*b"\\u000e\0", 6), + (*b"\\u000f\0", 6), + (*b"\\u0010\0", 6), + (*b"\\u0011\0", 6), + (*b"\\u0012\0", 6), + (*b"\\u0013\0", 6), + (*b"\\u0014\0", 6), + (*b"\\u0015\0", 6), + (*b"\\u0016\0", 6), + (*b"\\u0017\0", 6), + (*b"\\u0018\0", 6), + (*b"\\u0019\0", 6), + (*b"\\u001a\0", 6), + (*b"\\u001b\0", 6), + (*b"\\u001c\0", 6), + (*b"\\u001d\0", 6), + (*b"\\u001e\0", 6), + (*b"\\u001f\0", 6), + ([0; 7], 0), + ([0; 7], 0), + (*b"\\\"\0\0\0\0\0", 2), + ([0; 7], 0), + ([0; 7], 0), + ([0; 7], 0), + ([0; 7], 0), + ([0; 7], 0), + ([0; 7], 0), + ([0; 7], 0), + ([0; 7], 0), + ([0; 7], 0), + ([0; 7], 0), + ([0; 7], 0), + ([0; 7], 0), + ([0; 7], 0), + ([0; 7], 0), + ([0; 7], 0), + ([0; 7], 0), + ([0; 7], 0), + ([0; 7], 0), + ([0; 7], 0), + ([0; 7], 0), + ([0; 7], 0), + ([0; 7], 0), + ([0; 7], 0), + ([0; 7], 0), + ([0; 7], 0), + ([0; 7], 0), + ([0; 7], 0), + ([0; 7], 0), + ([0; 7], 0), + ([0; 7], 0), + ([0; 7], 0), + ([0; 7], 0), + ([0; 7], 0), + ([0; 7], 0), + ([0; 7], 0), + ([0; 7], 0), + ([0; 7], 0), + ([0; 7], 0), + ([0; 7], 0), + ([0; 7], 0), + ([0; 7], 0), + ([0; 7], 0), + ([0; 7], 0), + ([0; 7], 0), + ([0; 7], 0), + ([0; 7], 0), + ([0; 7], 0), + ([0; 7], 0), + ([0; 7], 0), + ([0; 7], 0), + ([0; 7], 0), + ([0; 7], 0), + ([0; 7], 0), + ([0; 7], 0), + ([0; 7], 0), + ([0; 7], 0), + ([0; 7], 0), + (*b"\\\\\0\0\0\0\0", 2), + ([0; 7], 0), + ([0; 7], 0), + ([0; 7], 0), +]; diff --git a/src/serialize/writer/str/generic.rs b/src/serialize/writer/str/generic.rs new file mode 100644 index 00000000..77352487 --- /dev/null +++ b/src/serialize/writer/str/generic.rs @@ -0,0 +1,110 @@ +// SPDX-License-Identifier: (Apache-2.0 OR MIT) + +use super::escape::{NEED_ESCAPED, QUOTE_TAB}; +use core::simd::cmp::{SimdPartialEq, SimdPartialOrd}; + +macro_rules! 
impl_format_simd_generic_128 { + ($dst:expr, $src:expr, $value_len:expr) => { + let last_stride_src = $src.add($value_len).sub(STRIDE); + let mut nb: usize = $value_len; + + assume!($value_len >= STRIDE); + + const BLASH: StrVector = StrVector::from_array([b'\\'; STRIDE]); + const QUOTE: StrVector = StrVector::from_array([b'"'; STRIDE]); + const X20: StrVector = StrVector::from_array([32; STRIDE]); + + unsafe { + { + while nb >= STRIDE { + let v = StrVector::from_slice(core::slice::from_raw_parts($src, STRIDE)); + let mask = + (v.simd_eq(BLASH) | v.simd_eq(QUOTE) | v.simd_lt(X20)).to_bitmask() as u32; + v.copy_to_slice(core::slice::from_raw_parts_mut($dst, STRIDE)); + + if unlikely!(mask > 0) { + let cn = trailing_zeros!(mask) as usize; + nb -= cn; + $dst = $dst.add(cn); + $src = $src.add(cn); + nb -= 1; + let escape = QUOTE_TAB[*($src) as usize]; + write_escape!(escape, $dst); + $dst = $dst.add(escape.1 as usize); + $src = $src.add(1); + } else { + nb -= STRIDE; + $dst = $dst.add(STRIDE); + $src = $src.add(STRIDE); + } + } + } + + if nb > 0 { + let mut scratch: [u8; 32] = [b'a'; 32]; + let mut v = + StrVector::from_slice(core::slice::from_raw_parts(last_stride_src, STRIDE)); + v.copy_to_slice(core::slice::from_raw_parts_mut( + scratch.as_mut_ptr(), + STRIDE, + )); + + let mut scratch_ptr = scratch.as_mut_ptr().add(16 - nb); + v = StrVector::from_slice(core::slice::from_raw_parts(scratch_ptr, STRIDE)); + let mut mask = + (v.simd_eq(BLASH) | v.simd_eq(QUOTE) | v.simd_lt(X20)).to_bitmask() as u32; + + while nb > 0 { + v.copy_to_slice(core::slice::from_raw_parts_mut($dst, STRIDE)); + if unlikely!(mask > 0) { + let cn = trailing_zeros!(mask) as usize; + nb -= cn; + $dst = $dst.add(cn); + scratch_ptr = scratch_ptr.add(cn); + nb -= 1; + mask >>= cn + 1; + let escape = QUOTE_TAB[*(scratch_ptr) as usize]; + write_escape!(escape, $dst); + $dst = $dst.add(escape.1 as usize); + scratch_ptr = scratch_ptr.add(1); + v = 
StrVector::from_slice(core::slice::from_raw_parts(scratch_ptr, STRIDE)); + } else { + $dst = $dst.add(nb); + break; + } + } + } + } + }; +} + +#[allow(dead_code)] +#[inline(never)] +#[cfg_attr(target_arch = "x86_64", target_feature(enable = "sse2"))] +#[cfg_attr(target_arch = "x86_64", target_feature(enable = "bmi1"))] +#[cfg_attr(target_arch = "aarch64", target_feature(enable = "neon"))] +pub unsafe fn format_escaped_str_impl_generic_128( + odst: *mut u8, + value_ptr: *const u8, + value_len: usize, +) -> usize { + const STRIDE: usize = 16; + type StrVector = core::simd::u8x16; + + let mut dst = odst; + let mut src = value_ptr; + + core::ptr::write(dst, b'"'); + dst = dst.add(1); + + if value_len < STRIDE { + impl_format_scalar!(dst, src, value_len) + } else { + impl_format_simd_generic_128!(dst, src, value_len); + } + + core::ptr::write(dst, b'"'); + dst = dst.add(1); + + dst as usize - odst as usize +} diff --git a/src/serialize/writer/str/mod.rs b/src/serialize/writer/str/mod.rs new file mode 100644 index 00000000..0836ecba --- /dev/null +++ b/src/serialize/writer/str/mod.rs @@ -0,0 +1,22 @@ +// SPDX-License-Identifier: Apache-2.0 + +#[macro_use] +mod escape; +#[macro_use] +mod scalar; + +#[cfg(all(feature = "unstable-simd", target_arch = "x86_64", feature = "avx512"))] +mod avx512; + +#[cfg(feature = "unstable-simd")] +mod generic; + +#[cfg(all(feature = "unstable-simd", target_arch = "x86_64", feature = "avx512"))] +pub use avx512::format_escaped_str_impl_512vl; + +#[allow(unused_imports)] +#[cfg(feature = "unstable-simd")] +pub use generic::format_escaped_str_impl_generic_128; + +#[cfg(not(feature = "unstable-simd"))] +pub use scalar::format_escaped_str_scalar; diff --git a/src/serialize/writer/str/scalar.rs b/src/serialize/writer/str/scalar.rs new file mode 100644 index 00000000..f324e098 --- /dev/null +++ b/src/serialize/writer/str/scalar.rs @@ -0,0 +1,41 @@ +// SPDX-License-Identifier: (Apache-2.0 OR MIT) + +#[cfg(not(feature = "unstable-simd"))] +use 
super::escape::{NEED_ESCAPED, QUOTE_TAB}; + +macro_rules! impl_format_scalar { + ($dst:expr, $src:expr, $value_len:expr) => { + unsafe { + for _ in 0..$value_len { + core::ptr::write($dst, *($src)); + $src = $src.add(1); + $dst = $dst.add(1); + if unlikely!(NEED_ESCAPED[*($src.sub(1)) as usize] > 0) { + let escape = QUOTE_TAB[*($src.sub(1)) as usize]; + write_escape!(escape, $dst.sub(1)); + $dst = $dst.add(escape.1 as usize - 1); + } + } + } + }; +} + +#[cfg(not(feature = "unstable-simd"))] +pub unsafe fn format_escaped_str_scalar( + odst: *mut u8, + value_ptr: *const u8, + value_len: usize, +) -> usize { + let mut dst = odst; + let mut src = value_ptr; + + core::ptr::write(dst, b'"'); + dst = dst.add(1); + + impl_format_scalar!(dst, src, value_len); + + core::ptr::write(dst, b'"'); + dst = dst.add(1); + + dst as usize - odst as usize +} diff --git a/src/str/avx512.rs b/src/str/avx512.rs new file mode 100644 index 00000000..291ed9a1 --- /dev/null +++ b/src/str/avx512.rs @@ -0,0 +1,104 @@ +// SPDX-License-Identifier: (Apache-2.0 OR MIT) + +use crate::str::pyunicode_new::*; + +use core::arch::x86_64::{ + __m256i, _mm256_and_si256, _mm256_cmpgt_epu8_mask, _mm256_cmpneq_epi8_mask, _mm256_loadu_si256, + _mm256_mask_cmpneq_epi8_mask, _mm256_maskz_loadu_epi8, _mm256_max_epu8, _mm256_set1_epi8, +}; + +macro_rules! u8_as_i8 { + ($val:expr) => { + core::mem::transmute::($val) + }; +} + +macro_rules! 
impl_kind_simd_avx512vl { + ($buf:expr) => { + unsafe { + const STRIDE: usize = 32; + + assume!($buf.len() > 0); + + let num_loops = $buf.len() / STRIDE; + let remainder = $buf.len() % STRIDE; + + let remainder_mask: u32 = !(u32::MAX << remainder); + let mut str_vec = + _mm256_maskz_loadu_epi8(remainder_mask, $buf.as_bytes().as_ptr() as *const i8); + let sptr = $buf.as_bytes().as_ptr().add(remainder); + + for i in 0..num_loops { + str_vec = _mm256_max_epu8( + str_vec, + _mm256_loadu_si256(sptr.add(STRIDE * i) as *const __m256i), + ); + } + + let vec_128 = _mm256_set1_epi8(u8_as_i8!(0b10000000)); + if _mm256_cmpgt_epu8_mask(str_vec, vec_128) == 0 { + pyunicode_ascii($buf.as_bytes().as_ptr(), $buf.len()) + } else { + let is_four = _mm256_cmpgt_epu8_mask(str_vec, _mm256_set1_epi8(u8_as_i8!(239))) > 0; + let is_not_latin = + _mm256_cmpgt_epu8_mask(str_vec, _mm256_set1_epi8(u8_as_i8!(195))) > 0; + let multibyte = _mm256_set1_epi8(u8_as_i8!(0b11000000)); + + let mut num_chars = popcnt!(_mm256_mask_cmpneq_epi8_mask( + remainder_mask, + _mm256_and_si256( + _mm256_maskz_loadu_epi8( + remainder_mask, + $buf.as_bytes().as_ptr() as *const i8 + ), + multibyte + ), + vec_128 + )); + + for i in 0..num_loops { + num_chars += popcnt!(_mm256_cmpneq_epi8_mask( + _mm256_and_si256( + _mm256_loadu_si256(sptr.add(STRIDE * i) as *const __m256i), + multibyte + ), + vec_128, + )) as usize; + } + + if is_four { + pyunicode_fourbyte($buf, num_chars) + } else if is_not_latin { + pyunicode_twobyte($buf, num_chars) + } else { + pyunicode_onebyte($buf, num_chars) + } + } + } + }; +} + +#[inline(never)] +#[cfg_attr(feature = "avx512", target_feature(enable = "avx512f"))] +#[cfg_attr(feature = "avx512", target_feature(enable = "avx512bw"))] +#[cfg_attr(feature = "avx512", target_feature(enable = "avx512vl"))] +#[cfg_attr(feature = "avx512", target_feature(enable = "bmi2"))] +#[cfg_attr(feature = "avx512", target_feature(enable = "bmi1"))] +#[cfg_attr(feature = "avx512", target_feature(enable = 
"popcnt"))] +pub unsafe fn create_str_impl_avx512vl(buf: &str) -> *mut pyo3_ffi::PyObject { + impl_kind_simd_avx512vl!(buf) +} + +#[inline(always)] +pub fn unicode_from_str(buf: &str) -> *mut pyo3_ffi::PyObject { + unsafe { + if unlikely!(buf.is_empty()) { + return use_immortal!(crate::typeref::EMPTY_UNICODE); + } + if std::is_x86_feature_detected!("avx512vl") { + create_str_impl_avx512vl(buf) + } else { + super::scalar::unicode_from_str(buf) + } + } +} diff --git a/src/str/check.rs b/src/str/check.rs deleted file mode 100644 index 91b90079..00000000 --- a/src/str/check.rs +++ /dev/null @@ -1,32 +0,0 @@ -// SPDX-License-Identifier: (Apache-2.0 OR MIT) - -const STRIDE_SIZE: usize = 8; - -pub fn is_four_byte(buf: &str) -> bool { - let as_bytes = buf.as_bytes(); - let len = as_bytes.len(); - unsafe { - let mut idx = 0; - while idx < len.saturating_sub(STRIDE_SIZE) { - let mut val: bool = false; - val |= *as_bytes.get_unchecked(idx) > 239; - val |= *as_bytes.get_unchecked(idx + 1) > 239; - val |= *as_bytes.get_unchecked(idx + 2) > 239; - val |= *as_bytes.get_unchecked(idx + 3) > 239; - val |= *as_bytes.get_unchecked(idx + 4) > 239; - val |= *as_bytes.get_unchecked(idx + 5) > 239; - val |= *as_bytes.get_unchecked(idx + 6) > 239; - val |= *as_bytes.get_unchecked(idx + 7) > 239; - idx += STRIDE_SIZE; - if val { - return true; - } - } - let mut ret = false; - while idx < len { - ret |= *as_bytes.get_unchecked(idx) > 239; - idx += 1; - } - ret - } -} diff --git a/src/str/mod.rs b/src/str/mod.rs index 66ccb67c..75149e9b 100644 --- a/src/str/mod.rs +++ b/src/str/mod.rs @@ -1,9 +1,15 @@ // SPDX-License-Identifier: (Apache-2.0 OR MIT) -mod check; -mod create; +#[cfg(feature = "avx512")] +mod avx512; mod ffi; +mod pyunicode_new; +mod scalar; + +#[cfg(not(feature = "avx512"))] +pub use scalar::unicode_from_str; + +#[cfg(feature = "avx512")] +pub use avx512::unicode_from_str; -pub use check::*; -pub use create::*; pub use ffi::*; diff --git a/src/str/create.rs 
b/src/str/pyunicode_new.rs similarity index 54% rename from src/str/create.rs rename to src/str/pyunicode_new.rs index c84d9df7..b5496891 100644 --- a/src/str/create.rs +++ b/src/str/pyunicode_new.rs @@ -1,49 +1,14 @@ // SPDX-License-Identifier: (Apache-2.0 OR MIT) -use crate::str::is_four_byte; +use pyo3_ffi::{PyASCIIObject, PyCompactUnicodeObject}; -use crate::typeref::EMPTY_UNICODE; -use pyo3_ffi::*; - -enum PyUnicodeKind { - Ascii, - OneByte, - TwoByte, - FourByte, -} - -fn find_str_kind(buf: &str, num_chars: usize) -> PyUnicodeKind { - if buf.len() == num_chars { - PyUnicodeKind::Ascii - } else if is_four_byte(buf) { - PyUnicodeKind::FourByte - } else if encoding_rs::mem::is_str_latin1(buf) { - PyUnicodeKind::OneByte - } else { - PyUnicodeKind::TwoByte - } -} - -pub fn unicode_from_str(buf: &str) -> *mut pyo3_ffi::PyObject { - if buf.is_empty() { - use_immortal!(EMPTY_UNICODE) - } else { - let num_chars = bytecount::num_chars(buf.as_bytes()); - match find_str_kind(buf, num_chars) { - PyUnicodeKind::Ascii => pyunicode_ascii(buf), - PyUnicodeKind::OneByte => pyunicode_onebyte(buf, num_chars), - PyUnicodeKind::TwoByte => pyunicode_twobyte(buf, num_chars), - PyUnicodeKind::FourByte => pyunicode_fourbyte(buf, num_chars), - } - } -} - -pub fn pyunicode_ascii(buf: &str) -> *mut pyo3_ffi::PyObject { +#[inline(never)] +pub fn pyunicode_ascii(buf: *const u8, num_chars: usize) -> *mut pyo3_ffi::PyObject { unsafe { - let ptr = ffi!(PyUnicode_New(buf.len() as isize, 127)); + let ptr = ffi!(PyUnicode_New(num_chars as isize, 127)); let data_ptr = ptr.cast::<PyASCIIObject>().offset(1) as *mut u8; - core::ptr::copy_nonoverlapping(buf.as_ptr(), data_ptr, buf.len()); - core::ptr::write(data_ptr.add(buf.len()), 0); + core::ptr::copy_nonoverlapping(buf, data_ptr, num_chars); + core::ptr::write(data_ptr.add(num_chars), 0); ptr } } @@ -63,6 +28,7 @@ pub fn pyunicode_onebyte(buf: &str, num_chars: usize) -> *mut pyo3_ffi::PyObject } } +#[inline(never)] pub fn pyunicode_twobyte(buf: &str, num_chars: 
usize) -> *mut pyo3_ffi::PyObject { unsafe { let ptr = ffi!(PyUnicode_New(num_chars as isize, 65535)); @@ -76,6 +42,7 @@ pub fn pyunicode_twobyte(buf: &str, num_chars: usize) -> *mut pyo3_ffi::PyObject } } +#[inline(never)] pub fn pyunicode_fourbyte(buf: &str, num_chars: usize) -> *mut pyo3_ffi::PyObject { unsafe { let ptr = ffi!(PyUnicode_New(num_chars as isize, 1114111)); diff --git a/src/str/scalar.rs b/src/str/scalar.rs new file mode 100644 index 00000000..0aa84c7d --- /dev/null +++ b/src/str/scalar.rs @@ -0,0 +1,45 @@ +// SPDX-License-Identifier: (Apache-2.0 OR MIT) + +use crate::str::pyunicode_new::*; +use crate::typeref::EMPTY_UNICODE; + +#[inline(always)] +pub fn str_impl_kind_scalar(buf: &str, num_chars: usize) -> *mut pyo3_ffi::PyObject { + unsafe { + let len = buf.len(); + assume!(len > 0); + + if unlikely!(*(buf.as_bytes().as_ptr()) > 239) { + return pyunicode_fourbyte(buf, num_chars); + } + + let sptr = buf.as_bytes().as_ptr(); + + let mut is_four = false; + let mut not_latin = false; + for i in 0..len { + is_four |= *sptr.add(i) > 239; + not_latin |= *sptr.add(i) > 195; + } + if is_four { + pyunicode_fourbyte(buf, num_chars) + } else if not_latin { + pyunicode_twobyte(buf, num_chars) + } else { + pyunicode_onebyte(buf, num_chars) + } + } +} + +#[inline(never)] +pub fn unicode_from_str(buf: &str) -> *mut pyo3_ffi::PyObject { + if unlikely!(buf.is_empty()) { + return use_immortal!(EMPTY_UNICODE); + } + let num_chars = bytecount::num_chars(buf.as_bytes()); + if buf.len() == num_chars { + pyunicode_ascii(buf.as_ptr(), num_chars) + } else { + str_impl_kind_scalar(buf, num_chars) + } +} diff --git a/src/util.rs b/src/util.rs index db6e8a8c..8add0f75 100644 --- a/src/util.rs +++ b/src/util.rs @@ -273,3 +273,35 @@ macro_rules! assume { }; }; } + +#[allow(unused_macros)] +#[cfg(feature = "intrinsics")] +macro_rules! 
trailing_zeros { + ($val:expr) => { + core::intrinsics::cttz_nonzero($val) as usize + }; +} + +#[allow(unused_macros)] +#[cfg(not(feature = "intrinsics"))] +macro_rules! trailing_zeros { + ($val:expr) => { + $val.trailing_zeros() as usize + }; +} + +#[allow(unused_macros)] +#[cfg(feature = "intrinsics")] +macro_rules! popcnt { + ($val:expr) => { + core::intrinsics::ctpop(core::mem::transmute::<u32, i32>($val)) as usize + }; +} + +#[allow(unused_macros)] +#[cfg(not(feature = "intrinsics"))] +macro_rules! popcnt { + ($val:expr) => { + core::mem::transmute::<u32, i32>($val).count_ones() as usize + }; +}