Skip to content

Commit

Permalink
Refactoring getitem, fixing bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
perimosocordiae committed Jun 1, 2016
1 parent a181f92 commit dc34b55
Show file tree
Hide file tree
Showing 3 changed files with 89 additions and 58 deletions.
123 changes: 67 additions & 56 deletions sparray/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,9 @@ def diagonal(self, offset=0, axis1=0, axis2=1):
dtype=self.indices.dtype)
if n < 0:
return SpArray([], [], shape=(0,), is_canonical=True)
return self._slice_ranges(ranges, (n,), inner=True)

flat_idx = combine_ranges(ranges, self.shape, n, inner=True)
return self._getitem_flatidx(flat_idx, (n,))

def setdiag(self, values, offset=0):
if self.ndim < 2:
Expand All @@ -160,7 +162,7 @@ def setdiag(self, values, offset=0):
ranges = np.array([[-offset, n - offset, 1], [0, n, 1]],
dtype=self.indices.dtype)

if n < 0:
if n <= 0:
return self

diag_indices = combine_ranges(ranges, self.shape, n, inner=True)
Expand Down Expand Up @@ -261,49 +263,6 @@ def _prepare_indices(self, index):
idx_type |= ARRAY_INDEX_MASK
return tuple(mut_indices), idx_type

def _slice_multi(self, indices, inner=True):
'''Helper for making a new SpArray using (int,array-like) indices.
dense[ii,jj] -> sparse._slice_multi((ii, jj), inner=True)
dense[ii[:,None],jj] -> sparse._slice_multi((ii, jj), inner=False)
'''
shape = tuple(len(idx) for idx in indices
if not isinstance(idx, numbers.Integral))
# easy when there's a zero dimension
if any(s == 0 for s in shape):
return SpArray([], [], shape, is_canonical=True)

if inner:
assert len(set(shape)) == 1
shape = (shape[0],)
flat_idx = np.ravel_multi_index(indices, self.shape)
else:
# outer indexing is more tricky
# TODO: share more code with combine_ranges
strides = np.ones(len(self.shape), dtype=self.indices.dtype)
np.cumprod(self.shape[:0:-1], out=strides[1:])
strides = strides[::-1]
flat_idx = indices[0] * strides[0]
for idx, s in zip(indices[1:], strides[1:]):
flat_idx = np.add.outer(flat_idx, idx * s).ravel()

_, data_inds, new_indices = intersect1d_sorted(self.indices, flat_idx,
return_inds=True)
new_data = self.data[data_inds]
return SpArray(new_indices, new_data, shape, is_canonical=True)

def _slice_ranges(self, ranges, new_shape, inner=False):
'''Helper for making a new SpArray using slice/range indices.
ranges : a (d, 3) array with rows of [start, stop, step] values
new_shape : the resulting shape after slicing
inner : boolean, see _slice_multi for explanation
'''
result_size = np.product(new_shape)
flat_idx = combine_ranges(ranges, self.shape, result_size, inner=inner)
_, data_inds, new_indices = intersect1d_sorted(self.indices, flat_idx,
return_inds=True)
return SpArray(new_indices, self.data[data_inds], new_shape,
is_canonical=True)

def __getitem__(self, indices):
indices, idx_type = self._prepare_indices(indices)

Expand All @@ -322,19 +281,60 @@ def __getitem__(self, indices):
# non-fancy case: all indices are slices or integers
if not (idx_type & ARRAY_INDEX_MASK):
ranges, new_shape = self._indices_to_ranges(indices)
return self._slice_ranges(ranges, tuple(new_shape), inner=False)
flat_idx = combine_ranges(ranges, self.shape, np.product(new_shape),
inner=False)
return self._getitem_flatidx(flat_idx, new_shape)

# fancy indexing cases
# some slices are present, trigger outer indexing
if idx_type & (EMPTY_SLICE_INDEX_MASK | SLICE_INDEX_MASK):
mut_indices = list(indices)
for i, (idx, dim) in enumerate(zip(indices, self.shape)):
if isinstance(idx, slice):
mut_indices[i] = np.arange(*idx.indices(dim))
return self._slice_multi(mut_indices, inner=False)
# inner-only fancy indexing
# TODO: ndim index arrays are NYI for now
if not (idx_type & (EMPTY_SLICE_INDEX_MASK | SLICE_INDEX_MASK)):
flat_idx = np.ravel_multi_index(indices, self.shape)
return self._getitem_flatidx(flat_idx, (len(flat_idx),))

# remaining case: inner indexing (ndim index arrays are NYI)
return self._slice_multi(indices, inner=True)
# compute the new shape, pulling out int/array indices
new_shape = []
inner_indices, outer_indices = [], []
inner_shape_idx = None
non_slice_idxs = []
for i, idx in enumerate(indices):
if isinstance(idx, slice):
x = np.arange(*idx.indices(self.shape[i]))
new_shape.append(len(x))
inner_indices.append(0) # placeholder
outer_indices.append(x)
else:
non_slice_idxs.append(i)
inner_indices.append(idx)
if inner_shape_idx is None:
inner_shape_idx = len(new_shape)
# make placeholders
if isinstance(idx, numbers.Integral):
new_shape.append(-1)
else:
new_shape.append(len(idx))
outer_indices.append(None)
elif not isinstance(idx, numbers.Integral):
new_shape[inner_shape_idx] = max(len(idx), new_shape[inner_shape_idx])

# exit now if there's a zero dimension
if any(s == 0 for s in new_shape):
return SpArray([], [], tuple(new_shape), is_canonical=True)

# coalesce the inner indices
if inner_shape_idx is not None:
x = np.ravel_multi_index(inner_indices, self.shape)
new_shape[inner_shape_idx] = len(x)
outer_indices[inner_shape_idx] = x

# only outer indexes remain
strides = np.ones(len(self.shape), dtype=self.indices.dtype)
np.cumprod(self.shape[:0:-1], out=strides[1:])
strides = strides[::-1]
strides[non_slice_idxs] = 1
flat_idx = outer_indices[0] * strides[0]
for idx, s in zip(outer_indices[1:], strides[1:]):
flat_idx = np.add.outer(flat_idx, idx * s).ravel()
return self._getitem_flatidx(flat_idx, new_shape, is_sorted=(self.ndim<2))

def __setitem__(self, indices, val):
indices, idx_type = self._prepare_indices(indices)
Expand Down Expand Up @@ -392,6 +392,17 @@ def _indices_to_ranges(self, indices):
ranges[i,:] = (idx, idx + 1, 1)
return ranges, new_shape

def _getitem_flatidx(self, flat_idx, new_shape, is_sorted=True):
if not is_sorted:
order = np.argsort(flat_idx, kind='mergesort')
flat_idx = flat_idx[order]
_, data_inds, new_indices = intersect1d_sorted(self.indices, flat_idx,
return_inds=True)
new_data = self.data[data_inds]
if not is_sorted:
new_indices = order[new_indices]
return SpArray(new_indices, new_data, tuple(new_shape), is_canonical=True)

def _setitem_flatidx(self, flat_idx, values):
idx, lut, lhs_only, rhs_only = union1d_sorted(self.indices, flat_idx,
return_masks=True)
Expand Down
8 changes: 7 additions & 1 deletion sparray/tests/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@
dense1d_indices = [0,1,3,4]
dense1d_data = [-2,-1,1,2]

dense3d = np.arange(24).reshape((3,2,4))[::-1]
dense3d[[0,2],:,2:] = 0
dense3d[1,0,:] = 0


def assert_sparse_equal(a, b, err_msg=''):
if hasattr(a, 'A'):
Expand All @@ -40,11 +44,13 @@ class BaseSpArrayTest(unittest.TestCase):
def setUp(self):
self.sp1d = SpArray(dense1d_indices, dense1d_data, shape=dense1d.shape)
self.sp2d = SpArray(dense2d_indices, dense2d_data, shape=dense2d.shape)
self.sp3d = SpArray.from_ndarray(dense3d)
self.pairs = [
(dense1d, self.sp1d),
(dense2d, self.sp2d),
(np.array([]), SpArray([],[],shape=(0,))),
(np.zeros((1,2,3)), SpArray([],[],shape=(1,2,3))),
(dense3d, self.sp3d),
]

def _same_op(self, op, assertFn):
Expand All @@ -62,7 +68,7 @@ def test_init(self):
assert_array_equal(b.toarray(), dense1d)

def test_from_ndarray(self):
for arr in (dense2d, dense1d):
for arr in (dense2d, dense1d, dense3d):
a = SpArray.from_ndarray(arr)
assert_array_equal(a.toarray(), arr)

Expand Down
16 changes: 15 additions & 1 deletion sparray/tests/test_indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from numpy.testing import assert_array_equal

from .test_base import (
BaseSpArrayTest, dense2d, dense1d, sparse2d, assert_sparse_equal)
BaseSpArrayTest, dense1d, dense2d, sparse2d, dense3d, assert_sparse_equal)


class TestIndexing(BaseSpArrayTest):
Expand Down Expand Up @@ -52,6 +52,20 @@ def test_slicing(self):
assert_array_equal(dense1d[1:], self.sp1d[1:].toarray())
assert_array_equal(dense2d[1:,1:], self.sp2d[1:,1:].toarray())

def test_mixed_fancy_indexing(self):
idx = [0,2]
assert_array_equal(dense2d[:,idx], self.sp2d[:,idx].toarray())
assert_array_equal(dense2d[idx,:], self.sp2d[idx,:].toarray())

assert_array_equal(dense3d[idx,:,idx], self.sp3d[idx,:,idx].toarray())
assert_array_equal(dense3d[[1],:,idx], self.sp3d[[1],:,idx].toarray())
assert_array_equal(dense3d[:,[1],idx], self.sp3d[:,[1],idx].toarray())
assert_array_equal(dense3d[idx,[1],:], self.sp3d[idx,[1],:].toarray())

assert_array_equal(dense3d[2,:,idx], self.sp3d[2,:,idx].toarray())
assert_array_equal(dense3d[:,1,idx], self.sp3d[:,1,idx].toarray())
assert_array_equal(dense3d[idx,1,:], self.sp3d[idx,1,:].toarray())

def test_inner_indexing(self):
idx = [0,2]
assert_array_equal(dense1d[idx], self.sp1d[idx].toarray())
Expand Down

0 comments on commit dc34b55

Please sign in to comment.