diff --git a/.github/workflows/debug.yaml b/.github/workflows/debug.yaml index d42079ca..c9c374c7 100644 --- a/.github/workflows/debug.yaml +++ b/.github/workflows/debug.yaml @@ -8,9 +8,10 @@ jobs: fail-fast: false matrix: profile: [ - { rust: "1.72", features: "" }, - { rust: "1.72", features: "--features=yyjson" }, - { rust: "nightly-2024-04-15", features: "--features=yyjson,unstable-simd" }, + { rust: "1.72", features: "", env "" }, + { rust: "1.72", features: "--features=yyjson", env: "" }, + { rust: "nightly-2024-04-15", features: "--features=yyjson,unstable-simd", env: "ORJSON_DISABLE_AVX512=1" }, + { rust: "nightly-2024-04-15", features: "--features=yyjson,unstable-simd", env: "" }, ] python: [ { version: '3.13' }, @@ -37,7 +38,7 @@ jobs: - name: build run: | - PATH="$HOME/.cargo/bin:$PATH" maturin build --release \ + PATH="$HOME/.cargo/bin:$PATH" "${{ matrix.profile.env }}" maturin build --release \ --out=dist \ --profile=dev \ --interpreter python${{ matrix.python.version }} \ diff --git a/build.rs b/build.rs index 0d8477d5..8e26f450 100644 --- a/build.rs +++ b/build.rs @@ -32,6 +32,11 @@ fn main() { if env::var("ORJSON_DISABLE_SIMD").is_err() { if let Some(true) = version_check::supports_feature("portable_simd") { println!("cargo:rustc-cfg=feature=\"unstable-simd\""); + + #[cfg(target_arch = "x86_64")] + if env::var("ORJSON_DISABLE_AVX512").is_err() { + println!("cargo:rustc-cfg=feature=\"avx512\""); + } } } diff --git a/src/lib.rs b/src/lib.rs index 5d8cbad5..44dc8c5f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,5 +1,7 @@ // SPDX-License-Identifier: (Apache-2.0 OR MIT) +#![cfg_attr(feature = "avx512", feature(avx512_target_feature))] +#![cfg_attr(feature = "avx512", feature(stdarch_x86_avx512))] #![cfg_attr(feature = "intrinsics", feature(core_intrinsics))] #![cfg_attr(feature = "optimize", feature(optimize_attribute))] #![cfg_attr(feature = "strict_provenance", feature(strict_provenance))] diff --git a/src/serialize/writer/json.rs b/src/serialize/writer/json.rs index c6a51056..7bdc10d6 100644 --- a/src/serialize/writer/json.rs +++ b/src/serialize/writer/json.rs @@ -2,6 +2,7 @@ // This is an adaptation of `src/value/ser.rs` from serde-json. use crate::serialize::writer::formatter::{CompactFormatter, Formatter, PrettyFormatter}; +use crate::serialize::writer::str::*; use crate::serialize::writer::WriteExt; use serde::ser::{self, Impossible, Serialize}; use serde_json::error::{Error, Result}; @@ -148,9 +149,10 @@ where unreachable!(); } - #[inline] + #[inline(always)] fn serialize_str(self, value: &str) -> Result<()> { - format_escaped_str(&mut self.writer, value).map_err(Error::io) + format_escaped_str(&mut self.writer, value); + Ok(()) } fn serialize_bytes(self, value: &[u8]) -> Result<()> { @@ -396,7 +398,7 @@ where type SerializeStruct = Impossible<(), Error>; type SerializeStructVariant = Impossible<(), Error>; - #[inline] + #[inline(always)] fn serialize_str(self, value: &str) -> Result<()> { self.ser.serialize_str(value) } @@ -553,17 +555,25 @@ where } } -#[cfg(feature = "unstable-simd")] +macro_rules! reserve_str { + ($writer:expr, $value:expr) => { + $writer.reserve($value.len() * 8 + 32); + }; +} + +#[cfg(all( + feature = "unstable-simd", + any(not(target_arch = "x86_64"), not(feature = "avx512")) +))] #[inline(always)] -fn format_escaped_str(writer: &mut W, value: &str) -> io::Result<()> +fn format_escaped_str(writer: &mut W, value: &str) where W: ?Sized + io::Write + WriteExt, { unsafe { - let num_reserved_bytes = value.len() * 8 + 32; - writer.reserve(num_reserved_bytes); + reserve_str!(writer, value); - let written = crate::serialize::writer::escape::format_escaped_str_impl_128( + let written = format_escaped_str_impl_128( writer.as_mut_buffer_ptr(), value.as_bytes().as_ptr(), value.len(), @@ -571,28 +581,51 @@ where writer.set_written(written); } - Ok(()) +} + +#[cfg(all(feature = "unstable-simd", target_arch = "x86_64", feature = "avx512"))] +#[inline(always)] +fn format_escaped_str(writer: &mut W, value: &str) +where + W: ?Sized + io::Write + WriteExt, +{ + unsafe { + reserve_str!(writer, value); + + if std::is_x86_feature_detected!("avx512vl") { + let written = format_escaped_str_impl_512vl( + writer.as_mut_buffer_ptr(), + value.as_bytes().as_ptr(), + value.len(), + ); + writer.set_written(written); + } else { + let written = format_escaped_str_impl_128( + writer.as_mut_buffer_ptr(), + value.as_bytes().as_ptr(), + value.len(), + ); + writer.set_written(written); + }; + } } #[cfg(not(feature = "unstable-simd"))] #[inline(always)] -fn format_escaped_str(writer: &mut W, value: &str) -> io::Result<()> +fn format_escaped_str(writer: &mut W, value: &str) where W: ?Sized + io::Write + WriteExt, { unsafe { - let num_reserved_bytes = value.len() * 8 + 32; - writer.reserve(num_reserved_bytes); + reserve_str!(writer, value); - let written = crate::serialize::writer::escape::format_escaped_str_scalar( + let written = format_escaped_str_scalar( writer.as_mut_buffer_ptr(), value.as_bytes().as_ptr(), value.len(), ); - writer.set_written(written); } - Ok(()) } #[inline] diff --git a/src/serialize/writer/mod.rs b/src/serialize/writer/mod.rs index 47fdf503..34248dc7 100644 --- a/src/serialize/writer/mod.rs +++ b/src/serialize/writer/mod.rs @@ -1,9 +1,9 @@ // SPDX-License-Identifier: Apache-2.0 mod byteswriter; -mod escape; mod formatter; mod json; +mod str; pub use byteswriter::{BytesWriter, WriteExt}; pub use json::{to_writer, to_writer_pretty}; diff --git a/src/serialize/writer/escape.rs b/src/serialize/writer/str.rs similarity index 50% rename from src/serialize/writer/escape.rs rename to src/serialize/writer/str.rs index 2a68274c..53cc5a22 100644 --- a/src/serialize/writer/escape.rs +++ b/src/serialize/writer/str.rs @@ -2,9 +2,102 @@ // Copyright 2023-2024 liuq19, ijl // adapted from sonic-rs' src/util/string.rs +#[cfg(feature = "avx512")] +use core::mem::transmute; + +#[cfg(feature = "unstable-simd")] +const STRIDE_128_BIT: usize = 16; + #[cfg(feature = "unstable-simd")] use core::simd::cmp::{SimdPartialEq, SimdPartialOrd}; +#[cfg(feature = "avx512")] +use core::arch::x86_64::{ + __m256i, _mm256_cmpeq_epu8_mask, _mm256_cmplt_epu8_mask, _mm256_lddqu_si256, + _mm256_maskz_loadu_epi8, _mm256_set1_epi8, _mm256_storeu_epi8, _tzcnt_u32, +}; + +#[cfg(feature = "avx512")] +macro_rules! splat_mm256 { + ($val:expr) => { + _mm256_set1_epi8(transmute::($val)) + }; +} + +#[cfg(feature = "avx512")] +macro_rules! impl_format_simd_avx512vl { + ($dst:expr, $src:expr, $value_len:expr) => { + let mut nb: usize = $value_len; + + let blash = splat_mm256!(b'\\'); + let quote = splat_mm256!(b'"'); + let x20 = splat_mm256!(32); + + unsafe { + while nb >= STRIDE { + let str_vec = _mm256_lddqu_si256(transmute::<*const u8, *const __m256i>($src)); + + _mm256_storeu_epi8($dst as *mut i8, str_vec); + + #[allow(unused_mut)] + let mut mask = _mm256_cmpeq_epu8_mask(str_vec, blash) + | _mm256_cmpeq_epu8_mask(str_vec, quote) + | _mm256_cmplt_epu8_mask(str_vec, x20); + + if unlikely!(mask > 0) { + let cn = _tzcnt_u32(mask) as usize; + + $src = $src.add(cn); + let escape = QUOTE_TAB[*($src) as usize]; + $src = $src.add(1); + + nb -= cn; + nb -= 1; + + $dst = $dst.add(cn); + core::ptr::write($dst as *mut u64, *(escape.0.as_ptr() as *const u64)); + $dst = $dst.add(escape.1 as usize); + } else { + nb -= STRIDE; + $dst = $dst.add(STRIDE); + $src = $src.add(STRIDE); + } + } + + while nb > 0 { + let remainder_mask: u32 = !(u32::MAX << nb); + let str_vec = _mm256_maskz_loadu_epi8(remainder_mask, $src as *const i8); + + _mm256_storeu_epi8($dst as *mut i8, str_vec); + + let mut mask = _mm256_cmpeq_epu8_mask(str_vec, blash) + | _mm256_cmpeq_epu8_mask(str_vec, quote) + | _mm256_cmplt_epu8_mask(str_vec, x20); + + mask &= remainder_mask; + + if unlikely!(mask > 0) { + let cn = _tzcnt_u32(mask) as usize; + + $src = $src.add(cn); + let escape = QUOTE_TAB[*($src) as usize]; + $src = $src.add(1); + + nb -= cn; + nb -= 1; + + $dst = $dst.add(cn); + core::ptr::write($dst as *mut u64, *(escape.0.as_ptr() as *const u64)); + $dst = $dst.add(escape.1 as usize); + } else { + $dst = $dst.add(nb); + break; + } + } + } + }; +} + #[cfg(feature = "unstable-simd")] macro_rules! impl_escape_unchecked { ($src:expr, $dst:expr, $nb:expr, $omask:expr, $cn:expr, $v:expr, $rotate:expr) => { @@ -24,7 +117,7 @@ macro_rules! impl_escape_unchecked { $nb -= 1; $omask >>= 1; let escape = QUOTE_TAB[*($src) as usize]; - core::ptr::copy_nonoverlapping(escape.0.as_ptr(), $dst, 6); + core::ptr::write($dst as *mut u64, *(escape.0.as_ptr() as *const u64)); $dst = $dst.add(escape.1 as usize); $src = $src.add(1); if likely!($omask & 1 != 1) { @@ -35,46 +128,41 @@ macro_rules! impl_escape_unchecked { } #[cfg(feature = "unstable-simd")] -macro_rules! impl_format_simd { - ($odptr:expr, $value_ptr:expr, $value_len:expr) => { - let mut dptr = $odptr; - let dstart = $odptr; - let mut sptr = $value_ptr; +macro_rules! impl_format_simd_generic_128 { + ($dst:expr, $src:expr, $value_len:expr) => { + let last_stride_src = $src.add($value_len).sub(STRIDE); let mut nb: usize = $value_len; + assume!($value_len >= STRIDE_128_BIT); + const BLASH: StrVector = StrVector::from_array([b'\\'; STRIDE]); const QUOTE: StrVector = StrVector::from_array([b'"'; STRIDE]); const X20: StrVector = StrVector::from_array([32; STRIDE]); unsafe { - *dptr = b'"'; - dptr = dptr.add(1); - { const ROTATE: bool = false; - while nb >= STRIDE { - let mut v = StrVector::from_slice(core::slice::from_raw_parts(sptr, STRIDE)); + while nb > STRIDE { + let mut v = StrVector::from_slice(core::slice::from_raw_parts($src, STRIDE)); let mut mask = (v.simd_eq(BLASH) | v.simd_eq(QUOTE) | v.simd_lt(X20)).to_bitmask() as u32; - v.copy_to_slice(core::slice::from_raw_parts_mut(dptr, STRIDE)); + v.copy_to_slice(core::slice::from_raw_parts_mut($dst, STRIDE)); - if likely!(mask == 0) { - nb -= STRIDE; - dptr = dptr.add(STRIDE); - sptr = sptr.add(STRIDE); - } else { + if unlikely!(mask > 0) { let cn = mask.trailing_zeros() as usize; - impl_escape_unchecked!(sptr, dptr, nb, mask, cn, v, ROTATE); + impl_escape_unchecked!($src, $dst, nb, mask, cn, v, ROTATE); + } else { + nb -= STRIDE; + $dst = $dst.add(STRIDE); + $src = $src.add(STRIDE); } } } - { + if nb > 0 { const ROTATE: bool = true; - let mut v = StrVector::from_slice(core::slice::from_raw_parts( - sptr.add(nb).sub(STRIDE), - STRIDE, - )); + let mut v = + StrVector::from_slice(core::slice::from_raw_parts(last_stride_src, STRIDE)); let mut to_skip = STRIDE - nb; while to_skip >= 4 { to_skip -= 4; @@ -90,98 +178,117 @@ macro_rules! impl_format_simd { & (STRIDE_SATURATION >> (32 - STRIDE - nb)); while nb > 0 { - v.copy_to_slice(core::slice::from_raw_parts_mut(dptr, STRIDE)); + v.copy_to_slice(core::slice::from_raw_parts_mut($dst, STRIDE)); - if likely!(mask == 0) { - dptr = dptr.add(nb); - break; - } else { + if unlikely!(mask > 0) { let cn = mask.trailing_zeros() as usize; - impl_escape_unchecked!(sptr, dptr, nb, mask, cn, v, ROTATE); + impl_escape_unchecked!($src, $dst, nb, mask, cn, v, ROTATE); + } else { + $dst = $dst.add(nb); + break; } } } - - *dptr = b'"'; - dptr = dptr.add(1); } + }; +} - return dptr as usize - dstart as usize; +macro_rules! impl_format_scalar { + ($dst:expr, $src:expr, $value_len:expr) => { + unsafe { + for _ in 0..$value_len { + core::ptr::write($dst, *($src)); + $src = $src.add(1); + $dst = $dst.add(1); + if unlikely!(NEED_ESCAPED[*($src.sub(1)) as usize] > 0) { + let escape = QUOTE_TAB[*($src.sub(1)) as usize]; + core::ptr::write($dst.sub(1) as *mut u64, *(escape.0.as_ptr() as *const u64)); + $dst = $dst.add(escape.1 as usize - 1); + } + } + } }; } +#[inline(never)] +#[cfg(all(feature = "unstable-simd", feature = "avx512"))] +#[cfg_attr(target_arch = "x86_64", target_feature(enable = "avx512f"))] +#[cfg_attr(target_arch = "x86_64", target_feature(enable = "avx512bw"))] +#[cfg_attr(target_arch = "x86_64", target_feature(enable = "avx512vl"))] +#[cfg_attr(target_arch = "x86_64", target_feature(enable = "bmi2"))] +pub unsafe fn format_escaped_str_impl_512vl( + odst: *mut u8, + value_ptr: *const u8, + value_len: usize, +) -> usize { + const STRIDE: usize = 32; + + let mut dst = odst; + let mut src = value_ptr; + + core::ptr::write(dst, b'"'); + dst = dst.add(1); + + impl_format_simd_avx512vl!(dst, src, value_len); + + core::ptr::write(dst, b'"'); + dst = dst.add(1); + + dst as usize - odst as usize +} + #[inline(never)] #[cfg(feature = "unstable-simd")] #[cfg_attr(target_arch = "aarch64", target_feature(enable = "neon"))] #[cfg_attr(target_arch = "x86_64", target_feature(enable = "sse2"))] pub unsafe fn format_escaped_str_impl_128( - odptr: *mut u8, + odst: *mut u8, value_ptr: *const u8, value_len: usize, ) -> usize { - const STRIDE: usize = 16; + const STRIDE: usize = STRIDE_128_BIT; const STRIDE_SATURATION: u32 = u16::MAX as u32; type StrVector = core::simd::u8x16; - impl_format_simd!(odptr, value_ptr, value_len); + let mut dst = odst; + let mut src = value_ptr; + + core::ptr::write(dst, b'"'); + dst = dst.add(1); + + if value_len < STRIDE_128_BIT { + impl_format_scalar!(dst, src, value_len) + } else { + impl_format_simd_generic_128!(dst, src, value_len); + } + + core::ptr::write(dst, b'"'); + dst = dst.add(1); + + dst as usize - odst as usize } #[inline(never)] #[cfg(not(feature = "unstable-simd"))] pub unsafe fn format_escaped_str_scalar( - odptr: *mut u8, + odst: *mut u8, value_ptr: *const u8, value_len: usize, ) -> usize { - unsafe { - let mut dst = odptr; - let mut src = value_ptr; - let mut nb = value_len; - - *dst = b'"'; - dst = dst.add(1); - - let mut clean = 0; - while clean <= value_len.saturating_sub(8) { - let mut escapes = 0; - escapes |= *NEED_ESCAPED.get_unchecked(*(src.add(clean)) as usize); - escapes |= *NEED_ESCAPED.get_unchecked(*(src.add(clean + 1)) as usize); - escapes |= *NEED_ESCAPED.get_unchecked(*(src.add(clean + 2)) as usize); - escapes |= *NEED_ESCAPED.get_unchecked(*(src.add(clean + 3)) as usize); - escapes |= *NEED_ESCAPED.get_unchecked(*(src.add(clean + 4)) as usize); - escapes |= *NEED_ESCAPED.get_unchecked(*(src.add(clean + 5)) as usize); - escapes |= *NEED_ESCAPED.get_unchecked(*(src.add(clean + 6)) as usize); - escapes |= *NEED_ESCAPED.get_unchecked(*(src.add(clean + 7)) as usize); - if unlikely!(escapes != 0) { - break; - } - clean += 8; - } - if clean > 0 { - core::ptr::copy_nonoverlapping(src, dst, clean); - nb -= clean; - src = src.add(clean); - dst = dst.add(clean); - } - for _ in 0..nb { - core::ptr::write(dst, *(src)); - src = src.add(1); - dst = dst.add(1); - if unlikely!(NEED_ESCAPED[*(src.sub(1)) as usize] != 0) { - let escape = QUOTE_TAB[*(src.sub(1)) as usize]; - core::ptr::copy_nonoverlapping(escape.0.as_ptr(), dst.sub(1), 6); - dst = dst.add(escape.1 as usize - 1); - } - } + let mut dst = odst; + let mut src = value_ptr; - *dst = b'"'; - dst = dst.add(1); + core::ptr::write(dst, b'"'); + dst = dst.add(1); - dst as usize - odptr as usize - } + impl_format_scalar!(odst, value_ptr, value_len); + + core::ptr::write(dst, b'"'); + dst = dst.add(1); + + dst as usize - odst as usize } -#[cfg(not(feature = "unstable-simd"))] const NEED_ESCAPED: [u8; 256] = [ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,