make rng a property of RandomizableMap
ntessore committed Sep 19, 2023
1 parent d9b6ad6 commit e52adc1
Showing 8 changed files with 94 additions and 82 deletions.
34 changes: 27 additions & 7 deletions heracles/maps.py
@@ -35,7 +35,6 @@
from .catalog import Catalog, CatalogPage

logger = logging.getLogger(__name__)
_RANDOM_SEED = 30


def _nativebyteorder(fn):
@@ -171,9 +170,19 @@ class RandomizableMap(Map):
"""

def __init__(self, randomize: bool, **kwargs) -> None:
default_rng: np.random.Generator = np.random.default_rng()
"""Default random number generator for randomizable maps."""

def __init__(
self,
randomize: bool,
*,
rng: t.Optional[np.random.Generator] = None,
**kwargs,
) -> None:
"""Initialise map with the given randomize property."""
self._randomize = randomize
self._rng = rng
super().__init__(**kwargs)

@property
@@ -185,6 +194,16 @@ def randomize(self, randomize: bool) -> None:
"""Set the randomize flag."""
self._randomize = randomize

@property
def rng(self) -> np.random.Generator:
"""Random number generator of this map."""
return self._rng or self.default_rng

@rng.setter
def rng(self, rng: np.random.Generator) -> None:
"""Set the random number generator of this map."""
self._rng = rng


class NormalizableMap(Map):
"""Abstract base class for normalisable maps.
@@ -227,9 +246,10 @@ def __init__(
*,
overdensity: bool = True,
randomize: bool = False,
rng: t.Optional[np.random.Generator] = None,
) -> None:
"""Create a position map with the given properties."""
super().__init__(columns=(lon, lat), nside=nside, randomize=randomize)
super().__init__(columns=(lon, lat), nside=nside, randomize=randomize, rng=rng)
self._overdensity: bool = overdensity

@property
@@ -285,8 +305,7 @@ def mapper(page: "CatalogPage") -> None:
p = np.full(npix, 1 / npix)
else:
p = vmap / np.sum(vmap)
rng = np.random.default_rng(_RANDOM_SEED)
pos[:] = rng.multinomial(ngal, p)
pos[:] = self.rng.multinomial(ngal, p)

# compute average number density
nbar = ngal / npix
@@ -449,6 +468,7 @@ def __init__(
conjugate: bool = False,
normalize: bool = True,
randomize: bool = False,
rng: t.Optional[np.random.Generator] = None,
) -> None:
"""Create a new shear map."""

@@ -459,6 +479,7 @@ def __init__(
nside=nside,
normalize=normalize,
randomize=randomize,
rng=rng,
)

@property
@@ -522,8 +543,7 @@ def mapper(page: "CatalogPage") -> None:
im = -im

if randomize:
rng = np.random.default_rng(_RANDOM_SEED)
a = rng.uniform(0.0, 2 * np.pi, size=page.size)
a = self.rng.uniform(0.0, 2 * np.pi, size=page.size)
r = np.hypot(re, im)
re, im = r * np.cos(a), r * np.sin(a)
del a, r
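
The heart of this commit is the rng-property pattern on RandomizableMap: each map can carry its own numpy Generator and falls back to a shared class-level default_rng when none is set, replacing the old module-level _RANDOM_SEED constant. Below is a minimal standalone sketch of that pattern; it is simplified (no Map base class, no columns/nside handling) and the usage lines at the end are illustrative, not part of the commit.

import typing as t

import numpy as np


class RandomizableMap:
    """Sketch of the rng-property pattern introduced in this commit."""

    # shared fallback generator, used whenever no per-instance rng is set
    default_rng: np.random.Generator = np.random.default_rng()

    def __init__(
        self,
        randomize: bool,
        *,
        rng: t.Optional[np.random.Generator] = None,
    ) -> None:
        self._randomize = randomize
        self._rng = rng

    @property
    def rng(self) -> np.random.Generator:
        """Per-instance generator if set, otherwise the class-level default."""
        return self._rng or self.default_rng

    @rng.setter
    def rng(self, rng: np.random.Generator) -> None:
        self._rng = rng


# pass a seeded generator for reproducible randomisation, or set one later;
# otherwise the instance falls back to the shared default_rng
m = RandomizableMap(randomize=True, rng=np.random.default_rng(42))
m.rng = np.random.default_rng(12345)

Because default_rng is an ordinary class attribute, reassigning it (for example RandomizableMap.default_rng = np.random.default_rng(seed)) would seed every map that has not set its own generator, while the rng constructor argument or the setter overrides it per instance.
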
12 changes: 2 additions & 10 deletions tests/conftest.py
@@ -26,13 +26,5 @@ def warns(*types):


@pytest.fixture(scope="session")
def random_generator(random_seed: int = 50) -> np.random.Generator:
"""A generator object consistent across all tests
Args:
random_seed: A seed to initialise the BitGenerator
Returns:
The initialised generator object
"""
return np.random.default_rng(random_seed)
def rng(seed: int = 50) -> np.random.Generator:
return np.random.default_rng(seed)
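
The renamed, session-scoped rng fixture above is what the test modules below pick up by argument name. An illustrative test consuming it (the test name and assertion are examples only, not part of the commit):

import numpy as np


def test_uses_shared_rng(rng: np.random.Generator) -> None:
    # pytest injects the session-scoped generator defined in conftest.py
    sample = rng.standard_normal(10)
    assert sample.shape == (10,)
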
32 changes: 16 additions & 16 deletions tests/test_catalog.py
@@ -4,14 +4,14 @@


@pytest.fixture
def catalog(random_generator):
def catalog(rng):
from heracles.catalog import CatalogBase, CatalogPage

# fix a set of rows to be returned for testing
size = 100
x = random_generator.random(size)
y = random_generator.random(size)
z = random_generator.random(size)
x = rng.random(size)
y = rng.random(size)
z = rng.random(size)

class TestCatalog(CatalogBase):
SIZE = size
@@ -299,19 +299,19 @@ def test_invalid_value_filter(catalog):
npt.assert_array_equal(page.get(k), v[2:])


def test_footprint_filter(catalog, random_generator):
def test_footprint_filter(catalog, rng):
from healpy import ang2pix

from heracles.catalog import FootprintFilter

# footprint for northern hemisphere
nside = 8
m = np.round(random_generator.random(12 * nside**2))
m = np.round(rng.random(12 * nside**2))

# replace x and y in catalog with lon and lat
catalog.DATA["x"] = lon = random_generator.uniform(-180, 180, size=catalog.SIZE)
catalog.DATA["x"] = lon = rng.uniform(-180, 180, size=catalog.SIZE)
catalog.DATA["y"] = lat = np.degrees(
np.arcsin(random_generator.uniform(-1, 1, size=catalog.SIZE)),
np.arcsin(rng.uniform(-1, 1, size=catalog.SIZE)),
)

filt = FootprintFilter(m, "x", "y")
@@ -329,12 +329,12 @@ def test_footprint_filter(catalog, random_generator):
np.testing.assert_array_equal(page[k], v[good])


def test_array_catalog(random_generator):
def test_array_catalog(rng):
from heracles.catalog import ArrayCatalog, Catalog

arr = np.empty(100, [("lon", float), ("lat", float), ("x", float), ("y", float)])
for name in arr.dtype.names:
arr[name] = random_generator.random(len(arr))
arr[name] = rng.random(len(arr))

catalog = ArrayCatalog(arr)

@@ -372,15 +372,15 @@ def test_array_catalog(random_generator):
assert copied.__dict__ == catalog.__dict__


def test_fits_catalog(random_generator, tmp_path):
def test_fits_catalog(rng, tmp_path):
import fitsio

from heracles.catalog import Catalog
from heracles.catalog.fits import FitsCatalog

size = 100
ra = random_generator.uniform(-180, 180, size=size)
dec = random_generator.uniform(-90, 90, size=size)
ra = rng.uniform(-180, 180, size=size)
dec = rng.uniform(-90, 90, size=size)

filename = str(tmp_path / "catalog.fits")

@@ -436,16 +436,16 @@ def test_fits_catalog(random_generator, tmp_path):
assert copied._ext == catalog._ext


def test_fits_catalog_caching(random_generator, tmp_path):
def test_fits_catalog_caching(rng, tmp_path):
import gc

import fitsio

from heracles.catalog.fits import FitsCatalog

size = 100
ra = random_generator.uniform(-180, 180, size=size)
dec = random_generator.uniform(-90, 90, size=size)
ra = rng.uniform(-180, 180, size=size)
dec = rng.uniform(-90, 90, size=size)

filename = str(tmp_path / "cached.fits")

12 changes: 6 additions & 6 deletions tests/test_covariance.py
@@ -2,15 +2,15 @@
import pytest


def test_sample_covariance(random_generator):
def test_sample_covariance(rng):
from heracles.covariance import SampleCovariance, add_sample

n = 10
size = 3
size2 = 5

samples = [random_generator.standard_normal(size) for _ in range(n)]
samples2 = [random_generator.standard_normal(size2) for _ in range(n)]
samples = [rng.standard_normal(size) for _ in range(n)]
samples2 = [rng.standard_normal(size2) for _ in range(n)]

cov = SampleCovariance(size)

@@ -53,7 +53,7 @@ def test_sample_covariance(random_generator):
add_sample(cov, np.zeros(size + 1), np.zeros(size2 - 1))


def test_update_covariance(random_generator):
def test_update_covariance(rng):
from itertools import combinations_with_replacement

from heracles.covariance import update_covariance
@@ -62,7 +62,7 @@ def test_update_covariance(random_generator):

cov = {}

sample = {i: random_generator.standard_normal(i + 1) for i in range(n)}
sample = {i: rng.standard_normal(i + 1) for i in range(n)}
update_covariance(cov, sample)

assert len(cov) == n * (n + 1) // 2
@@ -71,7 +71,7 @@
assert cov[k1, k2].shape == (sample[k1].size, sample[k2].size)
assert np.all(cov[k1, k2] == 0)

sample2 = {i: random_generator.standard_normal(i + 1) for i in range(n)}
sample2 = {i: rng.standard_normal(i + 1) for i in range(n)}
update_covariance(cov, sample2)

assert len(cov) == n * (n + 1) // 2
36 changes: 18 additions & 18 deletions tests/test_io.py
@@ -9,7 +9,7 @@ def zbins():


@pytest.fixture
def mock_alms(random_generator, zbins):
def mock_alms(rng, zbins):
import numpy as np

lmax = 32
@@ -21,18 +21,18 @@ def mock_alms(random_generator, zbins):
alms = {}
for n in names:
for i in zbins:
a = random_generator.standard_normal((Nlm, 2)) @ [1, 1j]
a = rng.standard_normal((Nlm, 2)) @ [1, 1j]
a.dtype = np.dtype(a.dtype, metadata={"nside": 32})
alms[n, i] = a

return alms


@pytest.fixture
def mock_cls(random_generator):
def mock_cls(rng):
import numpy as np

cl = random_generator.random(101)
cl = rng.random(101)
cl.dtype = np.dtype(cl.dtype, metadata={"nside_1": 32, "nside_2": 64})

return {
@@ -71,13 +71,13 @@ def datadir(tmp_path_factory):


@pytest.fixture(scope="session")
def mock_mask_fields(nside, random_generator):
def mock_mask_fields(nside, rng):
import healpy as hp
import numpy as np

npix = hp.nside2npix(nside)
maps = random_generator.random(npix * NFIELDS_TEST).reshape((npix, NFIELDS_TEST))
pixels = np.unique(random_generator.integers(0, npix, size=npix // 3))
maps = rng.random(npix * NFIELDS_TEST).reshape((npix, NFIELDS_TEST))
pixels = np.unique(rng.integers(0, npix, size=npix // 3))
maskpix = np.delete(np.arange(0, npix), pixels)
for i in range(NFIELDS_TEST):
maps[:, i][maskpix] = 0
@@ -141,13 +141,13 @@ def mock_writemask_full(mock_mask_fields, nside, datadir):


@pytest.fixture(scope="session")
def mock_mask_extra(nside, random_generator):
def mock_mask_extra(nside, rng):
import healpy as hp
import numpy as np

npix = hp.nside2npix(nside)
maps = random_generator.random(npix)
pixels = np.unique(random_generator.integers(0, npix, size=npix // 3))
maps = rng.random(npix)
pixels = np.unique(rng.integers(0, npix, size=npix // 3))
maskpix = np.delete(np.arange(0, npix), pixels)
maps[maskpix] = 0
return [maps, pixels]
@@ -178,7 +178,7 @@ def mock_writemask_extra(mock_mask_extra, nside, datadir):
return filename


def test_write_read_maps(random_generator, tmp_path):
def test_write_read_maps(rng, tmp_path):
import healpy as hp
import numpy as np

@@ -187,9 +187,9 @@ def test_write_read_maps(random_generator, tmp_path):
nside = 4
npix = 12 * nside**2

p = random_generator.random(npix)
v = random_generator.random(npix)
g = random_generator.random((2, npix))
p = rng.random(npix)
v = rng.random(npix)
g = rng.random((2, npix))

p.dtype = np.dtype(p.dtype, metadata={"spin": 0})
v.dtype = np.dtype(v.dtype, metadata={"spin": 0})
@@ -254,7 +254,7 @@ def test_write_read_cls(mock_cls, tmp_path):
assert cl.dtype.metadata == mock_cls[key].dtype.metadata


def test_write_read_mms(random_generator, tmp_path):
def test_write_read_mms(rng, tmp_path):
import numpy as np

from heracles.io import read_mms, write_mms
@@ -263,9 +263,9 @@ def test_write_read_mms(random_generator, tmp_path):
workdir = str(tmp_path)

mms = {
("00", 0, 1): random_generator.standard_normal((10, 10)),
("0+", 1, 2): random_generator.standard_normal((20, 5)),
("++", 2, 3): random_generator.standard_normal((10, 5, 2)),
("00", 0, 1): rng.standard_normal((10, 10)),
("0+", 1, 2): rng.standard_normal((20, 5)),
("++", 2, 3): rng.standard_normal((10, 5, 2)),
}

write_mms(filename, mms, workdir=workdir)
6 changes: 3 additions & 3 deletions tests/test_kmeans_radec.py
@@ -2,7 +2,7 @@


@pytest.mark.flaky(reruns=2)
def test_kmeans_sample(random_generator):
def test_kmeans_sample(rng):
import numpy as np

from heracles._kmeans_radec import kmeans_sample
@@ -11,8 +11,8 @@ def test_kmeans_sample(random_generator):
ncen = 20

pts = np.empty((npts, 2))
pts[:, 0] = random_generator.uniform(-180, 180, size=npts)
pts[:, 1] = np.degrees(np.arcsin(random_generator.uniform(-1, 1, size=npts)))
pts[:, 0] = rng.uniform(-180, 180, size=npts)
pts[:, 1] = np.degrees(np.arcsin(rng.uniform(-1, 1, size=npts)))

km = kmeans_sample(pts, ncen, verbose=2)

