Skip to content

Commit

Permalink
numpy serialization rejects non-native endianness
Browse files Browse the repository at this point in the history
  • Loading branch information
ijl committed Apr 15, 2024
1 parent 62edcb5 commit ee81a96
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 3 deletions.
8 changes: 6 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -864,8 +864,12 @@ b'"2021-01-01T00:00:00+00:00"'
If an array is not a contiguous C array, contains an unsupported datatype,
or contains a `numpy.datetime64` using an unsupported representation
(e.g., picoseconds), orjson falls through to `default`. In `default`,
`obj.tolist()` can be specified. If an array is malformed, which
is not expected, `orjson.JSONEncodeError` is raised.
`obj.tolist()` can be specified.

If an array is not in the native endianness, e.g., an array of big-endian values
on a little-endian system, `orjson.JSONEncodeError` is raised.

If an array is malformed, `orjson.JSONEncodeError` is raised.

This measures serializing 92MiB of JSON from an `numpy.ndarray` with
dimensions of `(50000, 100)` and `numpy.float64` values:
Expand Down
5 changes: 5 additions & 0 deletions src/serialize/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ pub enum SerializeError {
DictKeyInvalidType,
NumpyMalformed,
NumpyNotCContiguous,
NumpyNotNativeEndian,
NumpyUnsupportedDatatype,
UnsupportedType(NonNull<pyo3_ffi::PyObject>),
}
Expand Down Expand Up @@ -48,6 +49,10 @@ impl std::fmt::Display for SerializeError {
f,
"numpy array is not C contiguous; use ndarray.tolist() in default"
),
SerializeError::NumpyNotNativeEndian => write!(
f,
"numpy array is not native-endianness"
),
SerializeError::NumpyUnsupportedDatatype => {
write!(f, "unsupported datatype in numpy array")
}
Expand Down
13 changes: 12 additions & 1 deletion src/serialize/per_type/numpy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,9 @@ impl<'a> Serialize for NumpySerializer<'a> {
Err(PyArrayError::NotContiguous) => {
err!(SerializeError::NumpyNotCContiguous)
}
Err(PyArrayError::NotNativeEndian) => {
err!(SerializeError::NumpyNotNativeEndian)
}
Err(PyArrayError::UnsupportedDataType) => {
err!(SerializeError::NumpyUnsupportedDatatype)
}
Expand Down Expand Up @@ -101,6 +104,9 @@ pub struct PyCapsule {

// https://docs.scipy.org/doc/numpy/reference/arrays.interface.html#c.__array_struct__

const NPY_ARRAY_C_CONTIGUOUS: c_int = 0x1;
const NPY_ARRAY_NOTSWAPPED: c_int = 0x200;

#[repr(C)]
pub struct PyArrayInterface {
pub two: c_int,
Expand Down Expand Up @@ -154,9 +160,11 @@ impl ItemType {
}
}
}

pub enum PyArrayError {
Malformed,
NotContiguous,
NotNativeEndian,
UnsupportedDataType,
}

Expand Down Expand Up @@ -187,9 +195,12 @@ impl NumpyArray {
if unsafe { (*array).two != 2 } {
ffi!(Py_DECREF(capsule));
Err(PyArrayError::Malformed)
} else if unsafe { (*array).flags } & 0x1 != 0x1 {
} else if unsafe { (*array).flags } & NPY_ARRAY_C_CONTIGUOUS != NPY_ARRAY_C_CONTIGUOUS {
ffi!(Py_DECREF(capsule));
Err(PyArrayError::NotContiguous)
} else if unsafe { (*array).flags } & NPY_ARRAY_NOTSWAPPED != NPY_ARRAY_NOTSWAPPED {
ffi!(Py_DECREF(capsule));
Err(PyArrayError::NotNativeEndian)
} else {
let num_dimensions = unsafe { (*array).nd as usize };
if num_dimensions == 0 {
Expand Down
11 changes: 11 additions & 0 deletions test/test_numpy.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# SPDX-License-Identifier: (Apache-2.0 OR MIT)


import sys

import pytest

import orjson
Expand Down Expand Up @@ -864,3 +866,12 @@ def test_numpy_float64(self):
[-1.7976931348623157e308, 1.7976931348623157e308], numpy.float64
)
)


@pytest.mark.skipif(numpy is None, reason="numpy is not installed")
class NumpyEndianness:
def test_numpy_array_dimension_zero(self):
wrong_endianness = ">" if sys.byteorder == "little" else "<"
array = numpy.array([0, 1, 0.4, 5.7], dtype=f"{wrong_endianness}f8")
with pytest.raises(orjson.JSONEncodeError):
orjson.dumps(array, option=orjson.OPT_SERIALIZE_NUMPY)

0 comments on commit ee81a96

Please sign in to comment.