diff --git a/libxnd/shape.c b/libxnd/shape.c index 8934095..1c401bb 100644 --- a/libxnd/shape.c +++ b/libxnd/shape.c @@ -86,7 +86,7 @@ zero_in_shape(const ndt_ndarray_t *x) } static void -init_contiguous_strides(ndt_ndarray_t *dest, const ndt_ndarray_t *src) +init_contiguous_c_strides(ndt_ndarray_t *dest, const ndt_ndarray_t *src) { int64_t q; int64_t i; @@ -102,13 +102,43 @@ init_contiguous_strides(ndt_ndarray_t *dest, const ndt_ndarray_t *src) } } +static void +init_contiguous_f_strides(ndt_ndarray_t *dest, const ndt_ndarray_t *src) +{ + int64_t q; + int64_t i; + + if (src->ndim == 0 && dest->ndim == 0) { + return; + } + + q = 1; + for (i = 0; i < dest->ndim; i++) { + dest->steps[i] = q; + q *= dest->shape[i]; + } +} + xnd_t -xnd_reshape(const xnd_t *x, int64_t shape[], int ndim, ndt_context_t *ctx) +xnd_reshape(const xnd_t *x, int64_t shape[], int ndim, char order, + ndt_context_t *ctx) { const ndt_t *t = x->type; ndt_ndarray_t src, dest; int64_t p, q; int ret; + int use_fortran = 0; + + if (order == 'F') { + use_fortran = 1; + } + else if (order == 'A') { + use_fortran = ndt_is_f_contiguous(t); + } + else if (order != 'C') { + ndt_err_format(ctx, NDT_ValueError, "'order' must be 'C', 'F' or 'A'"); + return xnd_error; + } if (ndt_as_ndarray(&src, t, ctx) < 0) { return xnd_error; @@ -141,12 +171,15 @@ xnd_reshape(const xnd_t *x, int64_t shape[], int ndim, ndt_context_t *ctx) else if (zero_in_shape(&dest)) { ; } - else if (ndt_is_c_contiguous(t) || ndt_is_c_contiguous(t)) { - init_contiguous_strides(&dest, &src); + else if (!use_fortran && ndt_is_c_contiguous(t)) { + init_contiguous_c_strides(&dest, &src); + } + else if (use_fortran && ndt_is_f_contiguous(t)) { + init_contiguous_f_strides(&dest, &src); } else { ret = xnd_nocopy_reshape(dest.shape, dest.steps, dest.ndim, - src.shape, src.steps, src.ndim, 0); + src.shape, src.steps, src.ndim, use_fortran); if (!ret) { ndt_err_format(ctx, NDT_ValueError, "inplace reshape not possible"); return xnd_error; diff --git a/libxnd/xnd.h b/libxnd/xnd.h index 0c533a0..115be31 100644 --- a/libxnd/xnd.h +++ b/libxnd/xnd.h @@ -184,7 +184,7 @@ XND_API xnd_t xnd_subtree(const xnd_t *x, const xnd_index_t indices[], int len, XND_API xnd_t xnd_subscript(const xnd_t *x, const xnd_index_t indices[], int len, ndt_context_t *ctx); -XND_API xnd_t xnd_reshape(const xnd_t *x, int64_t shape[], int ndim, ndt_context_t *ctx); +XND_API xnd_t xnd_reshape(const xnd_t *x, int64_t shape[], int ndim, char order, ndt_context_t *ctx); XND_API xnd_t *xnd_split(const xnd_t *x, int64_t *n, int max_outer, ndt_context_t *ctx); diff --git a/python/test_xnd.py b/python/test_xnd.py index bf19145..9c0c8cb 100644 --- a/python/test_xnd.py +++ b/python/test_xnd.py @@ -3133,6 +3133,18 @@ def test_readonly(self): check_copy_contiguous(self, y) +class TestReshape(XndTestCase): + + def test_reshape_api(self): + x = xnd([[1,2,3], [4,5,6]], type="!2 * 3 * float32") + self.assertRaises(ValueError, x.reshape, 2**32, 2**32) + + def test_reshape_fortran(self): + x = xnd([[1,2,3], [4,5,6]], type="!2 * 3 * float32") + y = x.reshape(3,2,order='F') + self.assertEqual(y, [[1,5], [4,3], [2,6]]) + + class TestSplit(XndTestCase): def test_split(self): @@ -3587,6 +3599,7 @@ def test_transpose_and_reshape(self): TestAPI, TestRepr, TestBuffer, + TestReshape, TestSplit, TestTranspose, TestView, diff --git a/python/xnd/_xnd.c b/python/xnd/_xnd.c index 2d990ff..130ddfa 100644 --- a/python/xnd/_xnd.c +++ b/python/xnd/_xnd.c @@ -2126,6 +2126,7 @@ pyxnd_reshape(PyObject *self, PyObject *args, PyObject *kwds) PyObject *tuple = NULL; PyObject *order = Py_None; int64_t shape[NDT_MAX_DIM]; + char ord = 'C'; Py_ssize_t n; if (!PyArg_ParseTupleAndKeywords(args, kwds, "O|O", kwlist, &tuple, @@ -2133,10 +2134,14 @@ pyxnd_reshape(PyObject *self, PyObject *args, PyObject *kwds) return NULL; } - if (order && order != Py_None) { - PyErr_SetString(PyExc_NotImplementedError, - "'order' argument is not implemented"); - return NULL; + if (order != Py_None) { + const char *c = PyUnicode_AsUTF8(order); + if (strlen(c) != 1) { + PyErr_SetString(PyExc_TypeError, + "'order' argument must be a 'C', 'F' or 'A'"); + return NULL; + } + ord = c[0]; } if (!PyTuple_Check(tuple)) { @@ -2162,7 +2167,7 @@ pyxnd_reshape(PyObject *self, PyObject *args, PyObject *kwds) } } - xnd_t view = xnd_reshape(XND(self), shape, n, &ctx); + xnd_t view = xnd_reshape(XND(self), shape, n, ord, &ctx); if (xnd_err_occurred(&view)) { return seterr(&ctx); }