Skip to content

Commit

Permalink
Add weighted histograms
Browse files Browse the repository at this point in the history
  • Loading branch information
matteobachetti committed Sep 25, 2023
1 parent 642e681 commit a8f10f1
Showing 1 changed file with 96 additions and 0 deletions.
96 changes: 96 additions & 0 deletions stingray/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1684,6 +1684,51 @@ def hist3d_numba_seq(tracks, bins, ranges, use_memmap=False, tmp=None):
return _hist3d_numba_seq(H, np.asarray(tracks), np.asarray(list(bins)), np.asarray(ranges))


@njit(nogil=True, parallel=False)
def _hist1d_numba_seq_weight(H, tracks, weights, bins, ranges):
delta = 1 / ((ranges[1] - ranges[0]) / bins)

for t in range(tracks.size):
i = (tracks[t] - ranges[0]) * delta
if 0 <= i < bins:
H[int(i)] += weights[t]

return H


def hist1d_numba_seq_weight(a, weights, bins, ranges, use_memmap=False, tmp=None):
"""
Examples
--------
>>> if os.path.exists('out.npy'): os.unlink('out.npy')
>>> x = np.random.uniform(0., 1., 100)
>>> weights = np.random.uniform(0, 1, 100)
>>> H, xedges = np.histogram(x, bins=5, range=[0., 1.], weights=weights)
>>> Hn = hist1d_numba_seq_weight(x, weights, bins=5, ranges=[0., 1.], tmp='out.npy',
... use_memmap=True)
>>> assert np.all(H == Hn)
>>> # The number of bins is small, memory map was not used!
>>> assert not os.path.exists('out.npy')
>>> H, xedges = np.histogram(x, bins=10**8, range=[0., 1.], weights=weights)
>>> Hn = hist1d_numba_seq_weight(x, weights, bins=10**8, ranges=[0., 1.], tmp='out.npy',
... use_memmap=True)
>>> assert np.all(H == Hn)
>>> assert os.path.exists('out.npy')
>>> # Now use memmap but do not specify a tmp file
>>> Hn = hist1d_numba_seq_weight(x, weights, bins=10**8, ranges=[0., 1.],
... use_memmap=True)
>>> assert np.all(H == Hn)
"""
if bins > 10**7 and use_memmap:
if tmp is None:
tmp = tempfile.NamedTemporaryFile("w+").name
hist_arr = np.lib.format.open_memmap(tmp, mode="w+", dtype=a.dtype, shape=(bins,))
else:
hist_arr = np.zeros((bins,), dtype=a.dtype)

return _hist1d_numba_seq_weight(hist_arr, a, weights, bins, np.asarray(ranges))


@njit(nogil=True, parallel=False)
def _hist2d_numba_seq_weight(H, tracks, weights, bins, ranges):
delta = 1 / ((ranges[:, 1] - ranges[:, 0]) / bins)
Expand Down Expand Up @@ -1914,13 +1959,64 @@ def histnd_numba_seq(tracks, bins, ranges, use_memmap=False, tmp=None):
if HAS_NUMBA:

def histogram2d(*args, **kwargs):
"""
Examples
--------
>>> x = np.random.uniform(0., 1., 100)
>>> y = np.random.uniform(2., 3., 100)
>>> weight = np.random.uniform(0, 1, 100)
>>> H, xedges, yedges = np.histogram2d(x, y, bins=(5, 5),
... range=[(0., 1.), (2., 3.)],
... weights=weight)
>>> Hn = histogram2d(x, y, bins=(5, 5),
... ranges=[[0., 1.], [2., 3.]],
... weights=weight)
>>> assert np.all(H == Hn)
>>> Hn1 = histogram2d(x, y, bins=(5, 5),
... ranges=[[0., 1.], [2., 3.]],
... weights=None)
>>> Hn2 = histogram2d(x, y, bins=(5, 5),
... ranges=[[0., 1.], [2., 3.]])
>>> assert np.all(Hn1 == Hn2)
"""
if "range" in kwargs:
kwargs["ranges"] = kwargs.pop("range")

if "weights" not in kwargs:
return hist2d_numba_seq(*args, **kwargs)

weights = kwargs.pop("weights")

if weights is not None:
return hist2d_numba_seq_weight(*args, weights, **kwargs)

return hist2d_numba_seq(*args, **kwargs)

def histogram(*args, **kwargs):
"""
Examples
--------
>>> x = np.random.uniform(0., 1., 100)
>>> weights = np.random.uniform(0, 1, 100)
>>> H, xedges = np.histogram(x, bins=5, range=[0., 1.], weights=weights)
>>> Hn = histogram(x, weights=weights, bins=5, ranges=[0., 1.], tmp='out.npy',
... use_memmap=True)
>>> assert np.all(H == Hn)
>>> Hn1 = histogram(x, weights=None, bins=5, ranges=[0., 1.])
>>> Hn2 = histogram(x, bins=5, ranges=[0., 1.])
>>> assert np.all(Hn1 == Hn2)
"""
if "range" in kwargs:
kwargs["ranges"] = kwargs.pop("range")

if "weights" not in kwargs:
return hist1d_numba_seq(*args, **kwargs)

weights = kwargs.pop("weights")

if weights is not None:
return hist1d_numba_seq_weight(*args, weights, **kwargs)

return hist1d_numba_seq(*args, **kwargs)

else:
Expand Down

0 comments on commit a8f10f1

Please sign in to comment.