Skip to content

Commit

Permalink
Merge pull request #428 from grlee77/backport_fixes_from_master
Browse files Browse the repository at this point in the history
Backport bug fixes from master to 1.0.x (#423 and #427)
  • Loading branch information
grlee77 authored Sep 25, 2018
2 parents 92f0115 + 07c1fc7 commit a3473b2
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 7 deletions.
28 changes: 21 additions & 7 deletions pywt/_multilevel.py
Original file line number Diff line number Diff line change
Expand Up @@ -1072,7 +1072,10 @@ def ravel_coeffs(coeffs, axes=None):
if np.any([d is None for d in coeff_dict.values()]):
raise ValueError("coeffs_to_array does not support missing "
"coefficients.")
for key, d in coeff_dict.items():
# sort to make sure key order is consistent across Python versions
keys = sorted(coeff_dict.keys())
for key in keys:
d = coeff_dict[key]
sl = slice(offset, offset + d.size)
offset += d.size
coeff_arr[sl] = d.ravel()
Expand Down Expand Up @@ -1171,8 +1174,8 @@ class FswavedecnResult(object):
----------
coeffs : ndarray
The coefficient array.
coeff_slices : dict
Dictionary of slices corresponding to each detail or approximation
coeff_slices : list
List of slices corresponding to each detail or approximation
coefficient array.
wavelets : list of pywt.DiscreteWavelet objects
The wavelets used. Will be a list with length equal to
Expand Down Expand Up @@ -1212,7 +1215,7 @@ def coeffs(self, c):

@property
def coeff_slices(self):
"""Dict: Dictionary of coeffficient slices."""
"""List: List of coefficient slices."""
return self._coeff_slices

@property
Expand Down Expand Up @@ -1256,7 +1259,7 @@ def _get_coef_sl(self, levels):
sl = [slice(None), ] * self.ndim
for n, (ax, lev) in enumerate(zip(self.axes, levels)):
sl[ax] = self.coeff_slices[n][lev]
return sl
return tuple(sl)

@property
def approx(self):
Expand Down Expand Up @@ -1319,12 +1322,13 @@ def __setitem__(self, levels, x):
"""
self._validate_index(levels)
sl = self._get_coef_sl(levels)
current_dtype = self._coeffs[sl].dtype
if self._coeffs[sl].shape != x.shape:
raise ValueError(
"x does not match the shape of the requested coefficient")
if x.dtype != sl.dtype:
if x.dtype != current_dtype:
warnings.warn("dtype mismatch: converting the provided array to"
"dtype {}".format(sl.dtype))
"dtype {}".format(current_dtype))
self._coeffs[sl] = x

def detail_keys(self):
Expand Down Expand Up @@ -1375,6 +1379,16 @@ def fswavedecn(data, wavelet, mode='symmetric', levels=None, axes=None):
the coefficients per detail or approximation level, and more.
See `FswavedecnResult` for details.
Examples
--------
>>> from pywt import fswavedecn
>>> fs_result = fswavedecn(np.ones((32, 32)), 'sym2', levels=(1, 3))
>>> print(fs_result.detail_keys())
[(0, 1), (0, 2), (0, 3), (1, 0), (1, 1), (1, 2), (1, 3)]
>>> approx_coeffs = fs_result.approx
>>> detail_1_2 = fs_result[(1, 2)]
Notes
-----
This transformation has been variously referred to as the (fully) separable
Expand Down
33 changes: 33 additions & 0 deletions pywt/tests/test_multilevel.py
Original file line number Diff line number Diff line change
Expand Up @@ -914,6 +914,39 @@ def test_fswavedecn_fswaverecn_axes_subsets():
assert_raises(ValueError, pywt.fswavedecn, data, 'haar', axes=(1, 5))


def test_fswavedecnresult():
data = np.ones((32, 32))
levels = (1, 2)
result = pywt.fswavedecn(data, 'sym2', levels=levels)

# can access the lowpass band via .approx or via __getitem__
approx_key = (0, ) * data.ndim
assert_array_equal(result[approx_key], result.approx)

dkeys = result.detail_keys()
# the approximation key shouldn't be present in the detail_keys
assert_(approx_key not in dkeys)

# can access all detail coefficients and they have matching ndim
for k in dkeys:
d = result[k]
assert_equal(d.ndim, data.ndim)

# can assign modified coefficients
result[k] = np.zeros_like(d)

# assigning a differently sized array raises a ValueError
assert_raises(ValueError, result.__setitem__,
k, np.zeros(tuple([s + 1 for s in d.shape])))

# warns on assigning with a non-matching dtype
assert_warns(UserWarning, result.__setitem__,
k, np.zeros_like(d).astype(np.float32))

# all coefficients are stacked into result.coeffs (same ndim)
assert_equal(result.coeffs.ndim, data.ndim)


def test_error_on_continuous_wavelet():
# A ValueError is raised if a Continuous wavelet is selected
data = np.ones((16, 16))
Expand Down

0 comments on commit a3473b2

Please sign in to comment.