Skip to content

Commit

Permalink
Add 'order' argument to xnd.reshape().
Browse files Browse the repository at this point in the history
  • Loading branch information
skrah committed Mar 15, 2019
1 parent a0bb1ae commit 0956c62
Show file tree
Hide file tree
Showing 4 changed files with 62 additions and 11 deletions.
43 changes: 38 additions & 5 deletions libxnd/shape.c
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ zero_in_shape(const ndt_ndarray_t *x)
}

static void
init_contiguous_strides(ndt_ndarray_t *dest, const ndt_ndarray_t *src)
init_contiguous_c_strides(ndt_ndarray_t *dest, const ndt_ndarray_t *src)
{
int64_t q;
int64_t i;
Expand All @@ -102,13 +102,43 @@ init_contiguous_strides(ndt_ndarray_t *dest, const ndt_ndarray_t *src)
}
}

static void
init_contiguous_f_strides(ndt_ndarray_t *dest, const ndt_ndarray_t *src)
{
int64_t q;
int64_t i;

if (src->ndim == 0 && dest->ndim == 0) {
return;
}

q = 1;
for (i = 0; i < dest->ndim; i++) {
dest->steps[i] = q;
q *= dest->shape[i];
}
}

xnd_t
xnd_reshape(const xnd_t *x, int64_t shape[], int ndim, ndt_context_t *ctx)
xnd_reshape(const xnd_t *x, int64_t shape[], int ndim, char order,
ndt_context_t *ctx)
{
const ndt_t *t = x->type;
ndt_ndarray_t src, dest;
int64_t p, q;
int ret;
int use_fortran = 0;

if (order == 'F') {
use_fortran = 1;
}
else if (order == 'A') {
use_fortran = ndt_is_f_contiguous(t);
}
else if (order != 'C') {
ndt_err_format(ctx, NDT_ValueError, "'order' must be 'C', 'F' or 'A'");
return xnd_error;
}

if (ndt_as_ndarray(&src, t, ctx) < 0) {
return xnd_error;
Expand Down Expand Up @@ -141,12 +171,15 @@ xnd_reshape(const xnd_t *x, int64_t shape[], int ndim, ndt_context_t *ctx)
else if (zero_in_shape(&dest)) {
;
}
else if (ndt_is_c_contiguous(t) || ndt_is_c_contiguous(t)) {
init_contiguous_strides(&dest, &src);
else if (!use_fortran && ndt_is_c_contiguous(t)) {
init_contiguous_c_strides(&dest, &src);
}
else if (use_fortran && ndt_is_f_contiguous(t)) {
init_contiguous_f_strides(&dest, &src);
}
else {
ret = xnd_nocopy_reshape(dest.shape, dest.steps, dest.ndim,
src.shape, src.steps, src.ndim, 0);
src.shape, src.steps, src.ndim, use_fortran);
if (!ret) {
ndt_err_format(ctx, NDT_ValueError, "inplace reshape not possible");
return xnd_error;
Expand Down
2 changes: 1 addition & 1 deletion libxnd/xnd.h
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ XND_API xnd_t xnd_subtree(const xnd_t *x, const xnd_index_t indices[], int len,
XND_API xnd_t xnd_subscript(const xnd_t *x, const xnd_index_t indices[], int len,
ndt_context_t *ctx);

XND_API xnd_t xnd_reshape(const xnd_t *x, int64_t shape[], int ndim, ndt_context_t *ctx);
XND_API xnd_t xnd_reshape(const xnd_t *x, int64_t shape[], int ndim, char order, ndt_context_t *ctx);

XND_API xnd_t *xnd_split(const xnd_t *x, int64_t *n, int max_outer, ndt_context_t *ctx);

Expand Down
13 changes: 13 additions & 0 deletions python/test_xnd.py
Original file line number Diff line number Diff line change
Expand Up @@ -3133,6 +3133,18 @@ def test_readonly(self):
check_copy_contiguous(self, y)


class TestReshape(XndTestCase):

def test_reshape_api(self):
x = xnd([[1,2,3], [4,5,6]], type="!2 * 3 * float32")
self.assertRaises(ValueError, x.reshape, 2**32, 2**32)

def test_reshape_fortran(self):
x = xnd([[1,2,3], [4,5,6]], type="!2 * 3 * float32")
y = x.reshape(3,2,order='F')
self.assertEqual(y, [[1,5], [4,3], [2,6]])


class TestSplit(XndTestCase):

def test_split(self):
Expand Down Expand Up @@ -3587,6 +3599,7 @@ def test_transpose_and_reshape(self):
TestAPI,
TestRepr,
TestBuffer,
TestReshape,
TestSplit,
TestTranspose,
TestView,
Expand Down
15 changes: 10 additions & 5 deletions python/xnd/_xnd.c
Original file line number Diff line number Diff line change
Expand Up @@ -2126,17 +2126,22 @@ pyxnd_reshape(PyObject *self, PyObject *args, PyObject *kwds)
PyObject *tuple = NULL;
PyObject *order = Py_None;
int64_t shape[NDT_MAX_DIM];
char ord = 'C';
Py_ssize_t n;

if (!PyArg_ParseTupleAndKeywords(args, kwds, "O|O", kwlist, &tuple,
&order)) {
return NULL;
}

if (order && order != Py_None) {
PyErr_SetString(PyExc_NotImplementedError,
"'order' argument is not implemented");
return NULL;
if (order != Py_None) {
const char *c = PyUnicode_AsUTF8(order);
if (strlen(c) != 1) {
PyErr_SetString(PyExc_TypeError,
"'order' argument must be a 'C', 'F' or 'A'");
return NULL;
}
ord = c[0];
}

if (!PyTuple_Check(tuple)) {
Expand All @@ -2162,7 +2167,7 @@ pyxnd_reshape(PyObject *self, PyObject *args, PyObject *kwds)
}
}

xnd_t view = xnd_reshape(XND(self), shape, n, &ctx);
xnd_t view = xnd_reshape(XND(self), shape, n, ord, &ctx);
if (xnd_err_occurred(&view)) {
return seterr(&ctx);
}
Expand Down

0 comments on commit 0956c62

Please sign in to comment.