From 520525860b2ecbbf82e7dc154cc0f8e5cdfa61f7 Mon Sep 17 00:00:00 2001 From: ijl Date: Thu, 18 Jan 2024 13:46:20 +0000 Subject: [PATCH] Escape str using SIMD --- ci/azure-macos.yml | 2 +- src/lib.rs | 1 + src/serialize/writer/json.rs | 79 +++++++++++++++++++ src/serialize/writer/mod.rs | 3 + src/serialize/writer/simd.rs | 149 +++++++++++++++++++++++++++++++++++ src/util.rs | 16 ++++ test/test_type.py | 90 ++++++++++++++++++--- 7 files changed, 330 insertions(+), 10 deletions(-) create mode 100644 src/serialize/writer/simd.rs diff --git a/ci/azure-macos.yml b/ci/azure-macos.yml index 9f9ebc6d..f490383f 100644 --- a/ci/azure-macos.yml +++ b/ci/azure-macos.yml @@ -23,7 +23,7 @@ steps: PATH=$HOME/.cargo/bin:$PATH \ MACOSX_DEPLOYMENT_TARGET=$(macosx_deployment_target) \ PYO3_CROSS_LIB_DIR=$(python -c "import sysconfig;print(sysconfig.get_config_var('LIBDIR'))") \ - maturin build --release --strip --features=encoding_rs/simd-accel,no-panic,yyjson --interpreter $(interpreter) --target=universal2-apple-darwin + maturin build --release --strip --features=encoding_rs/simd-accel,no-panic,unstable-simd,yyjson --interpreter $(interpreter) --target=universal2-apple-darwin env: CC: "clang" CFLAGS: "-O2 -fstrict-aliasing -flto=full" diff --git a/src/lib.rs b/src/lib.rs index 2e0a86c0..bc512174 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -4,6 +4,7 @@ #![cfg_attr(feature = "optimize", feature(optimize_attribute))] #![cfg_attr(feature = "strict_provenance", feature(strict_provenance))] #![cfg_attr(feature = "strict_provenance", warn(fuzzy_provenance_casts))] +#![cfg_attr(feature = "unstable-simd", feature(portable_simd))] #![allow(unknown_lints)] // internal_features #![allow(internal_features)] // core_intrinsics #![allow(unused_unsafe)] diff --git a/src/serialize/writer/json.rs b/src/serialize/writer/json.rs index eec755a3..35f429cd 100644 --- a/src/serialize/writer/json.rs +++ b/src/serialize/writer/json.rs @@ -565,6 +565,85 @@ where } } +#[cfg(all( + feature = "unstable-simd", + target_arch = "x86_64", + target_feature = "avx2" +))] +#[inline(always)] +fn format_escaped_str(writer: &mut W, value: &str) -> io::Result<()> +where + W: ?Sized + io::Write + WriteExt, +{ + unsafe { + let num_reserved_bytes = value.len() * 8 + 32 + 3; + writer.reserve(num_reserved_bytes); + + let written = crate::serialize::writer::simd::format_escaped_str_impl_256( + writer.as_mut_buffer_ptr(), + value.as_bytes().as_ptr(), + value.len(), + ); + writer.set_written(written); + } + Ok(()) +} + +#[cfg(all( + feature = "unstable-simd", + target_arch = "x86_64", + not(target_feature = "avx2") +))] +#[inline(always)] +fn format_escaped_str(writer: &mut W, value: &str) -> io::Result<()> +where + W: ?Sized + io::Write + WriteExt, +{ + unsafe { + let num_reserved_bytes = value.len() * 8 + 32 + 3; + writer.reserve(num_reserved_bytes); + + if std::is_x86_feature_detected!("avx2") { + let written = crate::serialize::writer::simd::format_escaped_str_impl_256( + writer.as_mut_buffer_ptr(), + value.as_bytes().as_ptr(), + value.len(), + ); + writer.set_written(written); + } else { + let written = crate::serialize::writer::simd::format_escaped_str_impl_128( + writer.as_mut_buffer_ptr(), + value.as_bytes().as_ptr(), + value.len(), + ); + writer.set_written(written); + } + } + Ok(()) +} + +#[cfg(all(feature = "unstable-simd", not(target_arch = "x86_64")))] +#[inline(always)] +fn format_escaped_str(writer: &mut W, value: &str) -> io::Result<()> +where + W: ?Sized + io::Write + WriteExt, +{ + unsafe { + let num_reserved_bytes = value.len() * 8 + 32 + 3; + writer.reserve(num_reserved_bytes); + + let written = crate::serialize::writer::simd::format_escaped_str_impl_128( + writer.as_mut_buffer_ptr(), + value.as_bytes().as_ptr(), + value.len(), + ); + + writer.set_written(written); + } + Ok(()) +} + +#[cfg(not(feature = "unstable-simd"))] #[inline(always)] fn format_escaped_str(writer: &mut W, value: &str) -> io::Result<()> where diff --git a/src/serialize/writer/mod.rs b/src/serialize/writer/mod.rs index 47fdf503..fc489426 100644 --- a/src/serialize/writer/mod.rs +++ b/src/serialize/writer/mod.rs @@ -1,9 +1,12 @@ // SPDX-License-Identifier: Apache-2.0 mod byteswriter; +#[cfg(not(feature = "unstable-simd"))] mod escape; mod formatter; mod json; +#[cfg(feature = "unstable-simd")] +mod simd; pub use byteswriter::{BytesWriter, WriteExt}; pub use json::{to_writer, to_writer_pretty}; diff --git a/src/serialize/writer/simd.rs b/src/serialize/writer/simd.rs new file mode 100644 index 00000000..9d0a9e13 --- /dev/null +++ b/src/serialize/writer/simd.rs @@ -0,0 +1,149 @@ +// SPDX-License-Identifier: Apache-2.0 +// Copyright 2023-2024 liuq19, ijl +// adapted from sonic-rs' src/util/string.rs + +use std::simd::cmp::{SimdPartialEq, SimdPartialOrd}; + +macro_rules! impl_escape_unchecked { + ($src:expr, $dst:expr, $nb:expr, $omask:expr, $cn:expr) => { + $nb -= $cn; + $dst = $dst.add($cn); + $src = $src.add($cn); + let mut mask = $omask << $cn; + loop { + $nb -= 1; + mask = mask << 1; + let replacement = if *($src) == b'"' { + (*b"\\\"\0\0\0\0\0\0", 2) + } else if *($src) == b'\\' { + (*b"\\\\\0\0\0\0\0\0", 2) + } else { + match *($src) { + 0 => (*b"\\u0000\0\0", 6), + 1 => (*b"\\u0001\0\0", 6), + 2 => (*b"\\u0002\0\0", 6), + 3 => (*b"\\u0003\0\0", 6), + 4 => (*b"\\u0004\0\0", 6), + 5 => (*b"\\u0005\0\0", 6), + 6 => (*b"\\u0006\0\0", 6), + 7 => (*b"\\u0007\0\0", 6), + 8 => (*b"\\b\0\0\0\0\0\0", 2), + 9 => (*b"\\t\0\0\0\0\0\0", 2), + 10 => (*b"\\n\0\0\0\0\0\0", 2), + 11 => (*b"\\u000b\0\0", 6), + 12 => (*b"\\f\0\0\0\0\0\0", 2), + 13 => (*b"\\r\0\0\0\0\0\0", 2), + 14 => (*b"\\u000e\0\0", 6), + 15 => (*b"\\u000f\0\0", 6), + 16 => (*b"\\u0010\0\0", 6), + 17 => (*b"\\u0011\0\0", 6), + 18 => (*b"\\u0012\0\0", 6), + 19 => (*b"\\u0013\0\0", 6), + 20 => (*b"\\u0014\0\0", 6), + 21 => (*b"\\u0015\0\0", 6), + 22 => (*b"\\u0016\0\0", 6), + 23 => (*b"\\u0017\0\0", 6), + 24 => (*b"\\u0018\0\0", 6), + 25 => (*b"\\u0019\0\0", 6), + 26 => (*b"\\u001a\0\0", 6), + 27 => (*b"\\u001b\0\0", 6), + 28 => (*b"\\u001c\0\0", 6), + 29 => (*b"\\u001d\0\0", 6), + 30 => (*b"\\u001e\0\0", 6), + 31 => (*b"\\u001f\0\0", 6), + _ => unreachable!(), + } + }; + std::ptr::copy_nonoverlapping(replacement.0.as_ptr(), $dst, 8); + $dst = $dst.add(replacement.1 as usize); + $src = $src.add(1); + if likely!(mask & (1 << (STRIDE - 1)) != 1) { + break; + } + } + }; +} +macro_rules! impl_format_simd { + ($odptr:expr, $value_ptr:expr, $value_len:expr) => { + let mut dptr = $odptr; + let dstart = $odptr; + let mut sptr = $value_ptr; + let mut nb: usize = $value_len; + + let blash = StrVector::from_array([b'\\'; STRIDE]); + let quote = StrVector::from_array([b'"'; STRIDE]); + let x20 = StrVector::from_array([32; STRIDE]); + + unsafe { + *dptr = b'"'; + dptr = dptr.add(1); + + while nb >= STRIDE { + let v = StrVector::from_slice(std::slice::from_raw_parts(sptr, STRIDE)); + v.copy_to_slice(std::slice::from_raw_parts_mut(dptr, STRIDE)); + let mask = + (v.simd_eq(blash) | v.simd_eq(quote) | v.simd_lt(x20)).to_bitmask() as u32; + + if likely!(mask == 0) { + nb -= STRIDE; + dptr = dptr.add(STRIDE); + sptr = sptr.add(STRIDE); + } else { + let cn = mask.trailing_zeros() as usize; + impl_escape_unchecked!(sptr, dptr, nb, mask, cn); + } + } + + while nb > 0 { + let v = StrVector::from_slice(std::slice::from_raw_parts(sptr, STRIDE)); + v.copy_to_slice(std::slice::from_raw_parts_mut(dptr, STRIDE)); + let mask = (v.simd_eq(blash) | v.simd_eq(quote) | v.simd_lt(x20)).to_bitmask() + as u32 + & (STRIDE_SATURATION >> (STRIDE - nb)); + if likely!(mask == 0) { + dptr = dptr.add(nb); + break; + } else { + let cn = mask.trailing_zeros() as usize; + impl_escape_unchecked!(sptr, dptr, nb, mask, cn); + } + } + + *dptr = b'"'; + dptr = dptr.add(1); + } + + return dptr as usize - dstart as usize; + }; +} + +#[inline(never)] +#[cfg(not(target_feature = "avx2"))] +#[cfg_attr(target_arch = "x86_64", cold)] +pub unsafe fn format_escaped_str_impl_128( + odptr: *mut u8, + value_ptr: *const u8, + value_len: usize, +) -> usize { + const STRIDE: usize = 16; + const STRIDE_SATURATION: u32 = u16::MAX as u32; + type StrVector = std::simd::u8x16; + + impl_format_simd!(odptr, value_ptr, value_len); +} + +#[cfg(target_arch = "x86_64")] +#[inline(never)] +#[cfg_attr(not(target_feature = "avx2"), target_feature(enable = "avx2"))] +#[cfg_attr(not(target_feature = "avx2"), target_feature(enable = "bmi2"))] +pub unsafe fn format_escaped_str_impl_256( + odptr: *mut u8, + value_ptr: *const u8, + value_len: usize, +) -> usize { + const STRIDE: usize = 32; + const STRIDE_SATURATION: u32 = u32::MAX; + type StrVector = std::simd::u8x32; + + impl_format_simd!(odptr, value_ptr, value_len); +} diff --git a/src/util.rs b/src/util.rs index 95ed6b33..a51dc633 100644 --- a/src/util.rs +++ b/src/util.rs @@ -70,6 +70,22 @@ macro_rules! unlikely { }; } +#[allow(unused_macros)] +#[cfg(feature = "intrinsics")] +macro_rules! likely { + ($exp:expr) => { + std::intrinsics::likely($exp) + }; +} + +#[allow(unused_macros)] +#[cfg(not(feature = "intrinsics"))] +macro_rules! likely { + ($exp:expr) => { + $exp + }; +} + macro_rules! nonnull { ($exp:expr) => { unsafe { std::ptr::NonNull::new_unchecked($exp) } diff --git a/test/test_type.py b/test/test_type.py index 0e04ee50..35ca8aeb 100644 --- a/test/test_type.py +++ b/test/test_type.py @@ -75,33 +75,105 @@ def test_str_ascii_control(self): assert orjson.loads(orjson.dumps(ref)) == ref assert orjson.loads(orjson.dumps(ref, option=orjson.OPT_INDENT_2)) == ref - def test_str_escape_0(self): + def test_str_escape_quote_0(self): assert orjson.dumps('"aaaaaaabb') == b'"\\"aaaaaaabb"' - def test_str_escape_1(self): + def test_str_escape_quote_1(self): assert orjson.dumps('a"aaaaaabb') == b'"a\\"aaaaaabb"' - def test_str_escape_2(self): + def test_str_escape_quote_2(self): assert orjson.dumps('aa"aaaaabb') == b'"aa\\"aaaaabb"' - def test_str_escape_3(self): + def test_str_escape_quote_3(self): assert orjson.dumps('aaa"aaaabb') == b'"aaa\\"aaaabb"' - def test_str_escape_4(self): + def test_str_escape_quote_4(self): assert orjson.dumps('aaaa"aaabb') == b'"aaaa\\"aaabb"' - def test_str_escape_5(self): + def test_str_escape_quote_5(self): assert orjson.dumps('aaaaa"aabb') == b'"aaaaa\\"aabb"' - def test_str_escape_6(self): + def test_str_escape_quote_6(self): assert orjson.dumps('aaaaaa"abb') == b'"aaaaaa\\"abb"' - def test_str_escape_7(self): + def test_str_escape_quote_7(self): assert orjson.dumps('aaaaaaa"bb') == b'"aaaaaaa\\"bb"' - def test_str_escape_8(self): + def test_str_escape_quote_8(self): assert orjson.dumps('aaaaaaaab"') == b'"aaaaaaaab\\""' + def test_str_escape_quote_multi(self): + assert ( + orjson.dumps('aa"aaaaabbbbbbbbbbbbbbbbbbbb"bb') + == b'"aa\\"aaaaabbbbbbbbbbbbbbbbbbbb\\"bb"' + ) + + def test_str_escape_backslash_0(self): + assert orjson.dumps("\\aaaaaaabb") == b'"\\\\aaaaaaabb"' + + def test_str_escape_backslash_1(self): + assert orjson.dumps("a\\aaaaaabb") == b'"a\\\\aaaaaabb"' + + def test_str_escape_backslash_2(self): + assert orjson.dumps("aa\\aaaaabb") == b'"aa\\\\aaaaabb"' + + def test_str_escape_backslash_3(self): + assert orjson.dumps("aaa\\aaaabb") == b'"aaa\\\\aaaabb"' + + def test_str_escape_backslash_4(self): + assert orjson.dumps("aaaa\\aaabb") == b'"aaaa\\\\aaabb"' + + def test_str_escape_backslash_5(self): + assert orjson.dumps("aaaaa\\aabb") == b'"aaaaa\\\\aabb"' + + def test_str_escape_backslash_6(self): + assert orjson.dumps("aaaaaa\\abb") == b'"aaaaaa\\\\abb"' + + def test_str_escape_backslash_7(self): + assert orjson.dumps("aaaaaaa\\bb") == b'"aaaaaaa\\\\bb"' + + def test_str_escape_backslash_8(self): + assert orjson.dumps("aaaaaaaab\\") == b'"aaaaaaaab\\\\"' + + def test_str_escape_backslash_multi(self): + assert ( + orjson.dumps("aa\\aaaaabbbbbbbbbbbbbbbbbbbb\\bb") + == b'"aa\\\\aaaaabbbbbbbbbbbbbbbbbbbb\\\\bb"' + ) + + def test_str_escape_x32_0(self): + assert orjson.dumps("\taaaaaaabb") == b'"\\taaaaaaabb"' + + def test_str_escape_x32_1(self): + assert orjson.dumps("a\taaaaaabb") == b'"a\\taaaaaabb"' + + def test_str_escape_x32_2(self): + assert orjson.dumps("aa\taaaaabb") == b'"aa\\taaaaabb"' + + def test_str_escape_x32_3(self): + assert orjson.dumps("aaa\taaaabb") == b'"aaa\\taaaabb"' + + def test_str_escape_x32_4(self): + assert orjson.dumps("aaaa\taaabb") == b'"aaaa\\taaabb"' + + def test_str_escape_x32_5(self): + assert orjson.dumps("aaaaa\taabb") == b'"aaaaa\\taabb"' + + def test_str_escape_x32_6(self): + assert orjson.dumps("aaaaaa\tabb") == b'"aaaaaa\\tabb"' + + def test_str_escape_x32_7(self): + assert orjson.dumps("aaaaaaa\tbb") == b'"aaaaaaa\\tbb"' + + def test_str_escape_x32_8(self): + assert orjson.dumps("aaaaaaaab\t") == b'"aaaaaaaab\\t"' + + def test_str_escape_x32_multi(self): + assert ( + orjson.dumps("aa\taaaaabbbbbbbbbbbbbbbbbbbb\tbb") + == b'"aa\\taaaaabbbbbbbbbbbbbbbbbbbb\\tbb"' + ) + def test_str_emoji(self): ref = "®️" assert orjson.loads(orjson.dumps(ref)) == ref