diff --git a/Cargo.lock b/Cargo.lock index 0130e1f7..1c92a0e2 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -141,6 +141,16 @@ dependencies = [ "wasi", ] +[[package]] +name = "half" +version = "2.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b5eceaaeec696539ddaf7b333340f1af35a5aa87ae3e4f3ead0532f72affab2e" +dependencies = [ + "cfg-if", + "crunchy", +] + [[package]] name = "itoa" version = "1.0.11" @@ -201,6 +211,7 @@ dependencies = [ "chrono", "compact_str", "encoding_rs", + "half", "itoa", "itoap", "once_cell", diff --git a/Cargo.toml b/Cargo.toml index b001eb4c..428197fa 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -53,6 +53,7 @@ bytecount = { version = "^0.6.7", default_features = false, features = ["runtime chrono = { version = "=0.4.34", default_features = false } compact_str = { version = "0.7", default_features = false, features = ["serde"] } encoding_rs = { version = "0.8", default_features = false } +half = { version = "2", default_features = false, features = ["std"] } itoa = { version = "1", default_features = false } itoap = { version = "1", features = ["std", "simd"] } once_cell = { version = "1", default_features = false, features = ["race"] } diff --git a/README.md b/README.md index 6d1f0644..73c17131 100644 --- a/README.md +++ b/README.md @@ -813,7 +813,7 @@ JSONEncodeError: Integer exceeds 53-bit range ### numpy orjson natively serializes `numpy.ndarray` and individual -`numpy.float64`, `numpy.float32`, +`numpy.float64`, `numpy.float32`, `numpy.float16` (`numpy.half`), `numpy.int64`, `numpy.int32`, `numpy.int16`, `numpy.int8`, `numpy.uint64`, `numpy.uint32`, `numpy.uint16`, `numpy.uint8`, `numpy.uintp`, `numpy.intp`, `numpy.datetime64`, and `numpy.bool` diff --git a/script/pynumpy b/script/pynumpy index 102a395d..bfaeaa31 100755 --- a/script/pynumpy +++ b/script/pynumpy @@ -23,25 +23,42 @@ os.sched_setaffinity(os.getpid(), {0, 1}) kind = sys.argv[1] if len(sys.argv) >= 1 else "" -if kind == "int32": - array = numpy.random.randint(((2**31) - 1), size=(100000, 100), dtype=numpy.int32) + +if kind == "float16": + dtype = numpy.float16 + array = numpy.random.random(size=(50000, 100)).astype(dtype) +elif kind == "float32": + dtype = numpy.float32 + array = numpy.random.random(size=(50000, 100)).astype(dtype) elif kind == "float64": + dtype = numpy.float64 array = numpy.random.random(size=(50000, 100)) assert array.dtype == numpy.float64 elif kind == "bool": + dtype = numpy.bool array = numpy.random.choice((True, False), size=(100000, 200)) elif kind == "int8": - array = numpy.random.randint(((2**7) - 1), size=(100000, 100), dtype=numpy.int8) + dtype = numpy.int8 + array = numpy.random.randint(((2**7) - 1), size=(100000, 100), dtype=dtype) elif kind == "int16": - array = numpy.random.randint(((2**15) - 1), size=(100000, 100), dtype=numpy.int16) + dtype = numpy.int16 + array = numpy.random.randint(((2**15) - 1), size=(100000, 100), dtype=dtype) elif kind == "int32": - array = numpy.random.randint(((2**31) - 1), size=(100000, 100), dtype=numpy.int32) + dtype = numpy.int32 + array = numpy.random.randint(((2**31) - 1), size=(100000, 100), dtype=dtype) elif kind == "uint8": - array = numpy.random.randint(((2**8) - 1), size=(100000, 100), dtype=numpy.uint8) + dtype = numpy.uint8 + array = numpy.random.randint(((2**8) - 1), size=(100000, 100), dtype=dtype) elif kind == "uint16": - array = numpy.random.randint(((2**16) - 1), size=(100000, 100), dtype=numpy.uint16) + dtype = numpy.uint16 + array = numpy.random.randint(((2**16) - 1), size=(100000, 100), dtype=dtype) +elif kind == "uint32": + dtype = numpy.uint32 + array = numpy.random.randint(((2**31) - 1), size=(100000, 100), dtype=dtype) else: - print("usage: pynumpy (bool|int16|int32|float64|int8|uint8|uint16)") + print( + "usage: pynumpy (bool|int16|int32|float16|float32|float64|int8|uint8|uint16|uint32)" + ) sys.exit(1) proc = psutil.Process() @@ -49,6 +66,7 @@ proc = psutil.Process() def default(__obj): if isinstance(__obj, numpy.ndarray): return __obj.tolist() + raise TypeError headers = ("Library", "Latency (ms)", "RSS diff (MiB)", "vs. orjson") @@ -92,7 +110,7 @@ def per_iter_latency(val): def test_correctness(func): - return orjson.loads(func()) == array.tolist() + return numpy.array_equal(array, numpy.array(orjson.loads(func()), dtype=dtype)) table = [] diff --git a/src/serialize/per_type/numpy.rs b/src/serialize/per_type/numpy.rs index d6515f98..9f5577cf 100644 --- a/src/serialize/per_type/numpy.rs +++ b/src/serialize/per_type/numpy.rs @@ -117,6 +117,7 @@ pub struct PyArrayInterface { pub enum ItemType { BOOL, DATETIME64(NumpyDatetimeUnit), + F16, F32, F64, I8, @@ -137,6 +138,7 @@ impl ItemType { let unit = NumpyDatetimeUnit::from_pyobject(ptr); Some(ItemType::DATETIME64(unit)) } + (102, 2) => Some(ItemType::F16), (102, 4) => Some(ItemType::F32), (102, 8) => Some(ItemType::F64), (105, 1) => Some(ItemType::I8), @@ -312,6 +314,10 @@ impl Serialize for NumpyArray { NumpyF32Array::new(slice!(self.data() as *const f32, self.num_items())) .serialize(serializer) } + ItemType::F16 => { + NumpyF16Array::new(slice!(self.data() as *const u16, self.num_items())) + .serialize(serializer) + } ItemType::U64 => { NumpyU64Array::new(slice!(self.data() as *const u64, self.num_items())) .serialize(serializer) @@ -439,6 +445,49 @@ impl Serialize for DataTypeF32 { } } +#[repr(transparent)] +struct NumpyF16Array<'a> { + data: &'a [u16], +} + +impl<'a> NumpyF16Array<'a> { + fn new(data: &'a [u16]) -> Self { + Self { data } + } +} + +impl<'a> Serialize for NumpyF16Array<'a> { + #[cold] + #[cfg_attr(feature = "optimize", optimize(size))] + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + let mut seq = serializer.serialize_seq(None).unwrap(); + for &each in self.data.iter() { + seq.serialize_element(&DataTypeF16 { obj: each }).unwrap(); + } + seq.end() + } +} + +#[repr(transparent)] +struct DataTypeF16 { + obj: u16, +} + +impl Serialize for DataTypeF16 { + #[cold] + #[cfg_attr(feature = "optimize", optimize(size))] + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + let as_f16 = half::f16::from_bits(self.obj); + serializer.serialize_f32(as_f16.to_f32()) + } +} + #[repr(transparent)] struct NumpyU64Array<'a> { data: &'a [u64], @@ -826,6 +875,8 @@ impl Serialize for NumpyScalar { (*(self.ptr as *mut NumpyFloat64)).serialize(serializer) } else if ob_type == scalar_types.float32 { (*(self.ptr as *mut NumpyFloat32)).serialize(serializer) + } else if ob_type == scalar_types.float16 { + (*(self.ptr as *mut NumpyFloat16)).serialize(serializer) } else if ob_type == scalar_types.int64 { (*(self.ptr as *mut NumpyInt64)).serialize(serializer) } else if ob_type == scalar_types.int32 { @@ -994,6 +1045,36 @@ impl Serialize for NumpyUint64 { } } +#[repr(C)] +pub struct NumpyFloat16 { + ob_refcnt: Py_ssize_t, + ob_type: *mut PyTypeObject, + value: [u8; 2], +} + +impl Serialize for NumpyFloat16 { + #[cfg(target_endian = "little")] + #[cold] + #[cfg_attr(feature = "optimize", optimize(size))] + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + let as_f16 = half::f16::from_le_bytes(self.value); + serializer.serialize_f32(as_f16.to_f32()) + } + #[cfg(target_endian = "big")] + #[cold] + #[cfg_attr(feature = "optimize", optimize(size))] + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + let as_f16 = half::f16::from_be_bytes(self.value); + serializer.serialize_f32(as_f16.to_f32()) + } +} + #[repr(C)] pub struct NumpyFloat32 { ob_refcnt: Py_ssize_t, diff --git a/src/typeref.rs b/src/typeref.rs index eb54306c..0fe13f9f 100644 --- a/src/typeref.rs +++ b/src/typeref.rs @@ -16,6 +16,7 @@ pub struct NumpyTypes { pub array: *mut PyTypeObject, pub float64: *mut PyTypeObject, pub float32: *mut PyTypeObject, + pub float16: *mut PyTypeObject, pub int64: *mut PyTypeObject, pub int32: *mut PyTypeObject, pub int16: *mut PyTypeObject, @@ -239,6 +240,7 @@ pub fn load_numpy_types() -> Box>> { let numpy_module_dict = PyObject_GenericGetDict(numpy, null_mut()); let types = Box::new(NumpyTypes { array: look_up_numpy_type(numpy_module_dict, "ndarray\0"), + float16: look_up_numpy_type(numpy_module_dict, "half\0"), float32: look_up_numpy_type(numpy_module_dict, "float32\0"), float64: look_up_numpy_type(numpy_module_dict, "float64\0"), int8: look_up_numpy_type(numpy_module_dict, "int8\0"), diff --git a/test/test_numpy.py b/test/test_numpy.py index e1b60c5c..6460fdef 100644 --- a/test/test_numpy.py +++ b/test/test_numpy.py @@ -12,7 +12,9 @@ def numpy_default(obj): - return obj.tolist() + if isinstance(obj, numpy.ndarray): + return obj.tolist() + raise TypeError @pytest.mark.skipif(numpy is None, reason="numpy is not installed") @@ -114,6 +116,94 @@ def test_numpy_array_d1_f32(self): == b"[1.0,3.4028235e38]" ) + def test_numpy_array_d1_f16(self): + assert ( + orjson.dumps( + numpy.array([-1.0, 0.0009765625, 1.0, 65504.0], numpy.float16), + option=orjson.OPT_SERIALIZE_NUMPY, + ) + == b"[-1.0,0.0009765625,1.0,65504.0]" + ) + + def test_numpy_array_f16_roundtrip(self): + ref = [ + -1.0, + -2.0, + 0.000000059604645, + 0.000060975552, + 0.00006103515625, + 0.0009765625, + 0.33325195, + 0.99951172, + 1.0, + 1.00097656, + 65504.0, + ] + obj = numpy.array(ref, numpy.float16) + serialized = orjson.dumps( + obj, + option=orjson.OPT_SERIALIZE_NUMPY, + ) + deserialized = numpy.array(orjson.loads(serialized), numpy.float16) + assert numpy.array_equal(obj, deserialized) + + def test_numpy_array_f16_edge(self): + assert ( + orjson.dumps( + numpy.array( + [ + numpy.inf, + numpy.NINF, + numpy.nan, + numpy.NZERO, + numpy.PZERO, + numpy.pi, + ], + numpy.float16, + ), + option=orjson.OPT_SERIALIZE_NUMPY, + ) + == b"[null,null,null,-0.0,0.0,3.140625]" + ) + + def test_numpy_array_f32_edge(self): + assert ( + orjson.dumps( + numpy.array( + [ + numpy.inf, + numpy.NINF, + numpy.nan, + numpy.NZERO, + numpy.PZERO, + numpy.pi, + ], + numpy.float32, + ), + option=orjson.OPT_SERIALIZE_NUMPY, + ) + == b"[null,null,null,-0.0,0.0,3.1415927]" + ) + + def test_numpy_array_f64_edge(self): + assert ( + orjson.dumps( + numpy.array( + [ + numpy.inf, + numpy.NINF, + numpy.nan, + numpy.NZERO, + numpy.PZERO, + numpy.pi, + ], + numpy.float64, + ), + option=orjson.OPT_SERIALIZE_NUMPY, + ) + == b"[null,null,null,-0.0,0.0,3.141592653589793]" + ) + def test_numpy_array_d1_f64(self): assert ( orjson.dumps( @@ -375,13 +465,10 @@ def test_numpy_array_non_contiguous_message(self): ) def test_numpy_array_unsupported_dtype(self): - array = numpy.array([[1, 2], [3, 4]], numpy.float16) # type: ignore + array = numpy.array([[1, 2], [3, 4]], numpy.csingle) # type: ignore with pytest.raises(orjson.JSONEncodeError) as cm: orjson.dumps(array, option=orjson.OPT_SERIALIZE_NUMPY) assert "unsupported datatype in numpy array" in str(cm) - assert orjson.dumps( - array, default=numpy_default, option=orjson.OPT_SERIALIZE_NUMPY - ) == orjson.dumps(array.tolist()) def test_numpy_array_d1(self): array = numpy.array([1])