Skip to content

Commit

Permalink
Add wrapper for fine distributions (#115)
Browse files Browse the repository at this point in the history
  • Loading branch information
pawel-czyz authored Sep 14, 2023
1 parent 136eb7d commit c8035ef
Show file tree
Hide file tree
Showing 7 changed files with 83 additions and 10 deletions.
2 changes: 2 additions & 0 deletions src/bmi/samplers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
)

# isort: on
import bmi.samplers._tfp as fine
from bmi.samplers._split_student_t import SplitStudentT
from bmi.samplers._splitmultinormal import BivariateNormalSampler, SplitMultinormal
from bmi.samplers._transformed import TransformedSampler
Expand All @@ -21,6 +22,7 @@
"AdditiveUniformSampler",
"BaseSampler",
"canonical_correlation",
"fine",
"parametrised_correlation_matrix",
"BivariateNormalSampler",
"SplitMultinormal",
Expand Down
2 changes: 2 additions & 0 deletions src/bmi/samplers/_tfp/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
# isort: on
from bmi.samplers._tfp._normal import MultivariateNormalDistribution
from bmi.samplers._tfp._student import MultivariateStudentDistribution
from bmi.samplers._tfp._wrapper import FineSampler

__all__ = [
"JointDistribution",
Expand All @@ -19,4 +20,5 @@
"monte_carlo_mi_estimate",
"MultivariateNormalDistribution",
"MultivariateStudentDistribution",
"FineSampler",
]
12 changes: 7 additions & 5 deletions src/bmi/samplers/_tfp/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,17 +31,19 @@ class JointDistribution:
dim_y: int
analytic_mi: Optional[float] = None

def sample(self, key: jax.random.PRNGKeyArray, n: int) -> tuple[jnp.ndarray, jnp.ndarray]:
def sample(
self, n_points: int, key: jax.random.PRNGKeyArray
) -> tuple[jnp.ndarray, jnp.ndarray]:
"""Sample from the joint distribution.
Args:
n_points: number of samples to draw
key: JAX random key
n: number of samples to draw
"""
if n < 1:
if n_points < 1:
raise ValueError("n must be positive")

xy = self.dist_joint.sample(seed=key, sample_shape=(n,))
xy = self.dist_joint.sample(seed=key, sample_shape=(n_points,))
return xy[..., : self.dim_x], xy[..., self.dim_x :] # noqa: E203 (formatting discrepancy)

def pmi(self, x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
Expand Down Expand Up @@ -160,7 +162,7 @@ def pmi_profile(key: jax.random.PRNGKeyArray, dist: JointDistribution, n: int) -
Returns:
PMI profile, shape `(n,)`
"""
x, y = dist.sample(key, n)
x, y = dist.sample(key=key, n_points=n)
return dist.pmi(x, y)


Expand Down
48 changes: 48 additions & 0 deletions src/bmi/samplers/_tfp/_wrapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
"""A wrapper from TFP distributions to BMI samplers."""
from typing import Optional, Union

import jax

from bmi.samplers._tfp._core import JointDistribution, monte_carlo_mi_estimate
from bmi.samplers.base import BaseSampler, KeyArray, cast_to_rng


class FineSampler(BaseSampler):
    """Adapter exposing a fine ``JointDistribution`` through the ``BaseSampler`` API.

    The wrapped distribution provides the samples; the mutual information is
    either supplied by the caller or estimated once, at construction time,
    with ``monte_carlo_mi_estimate``.
    """

    def __init__(
        self,
        dist: JointDistribution,
        mi: Optional[float] = None,
        mi_estimate_seed: Union[KeyArray, int] = 0,
        mi_estimate_sample: int = 200_000,
    ) -> None:
        """Wrap ``dist`` and fix the mutual information value reported for it.

        Args:
            dist: fine distribution to be wrapped; its ``dim_x`` and ``dim_y``
                are forwarded to the ``BaseSampler`` constructor
            mi: mutual information of ``dist``, if it is already known.
                When ``None``, a Monte Carlo estimate is computed instead.
            mi_estimate_seed: seed (or key) for the Monte Carlo estimate,
                passed through ``cast_to_rng``. Only used when ``mi`` is ``None``.
            mi_estimate_sample: number of samples for the Monte Carlo estimate.
                Only used when ``mi`` is ``None``.
        """
        super().__init__(dim_x=dist.dim_x, dim_y=dist.dim_y)
        self._dist = dist

        # Start from the caller-supplied value; a user-provided MI carries
        # no standard error.
        self._mi = mi
        self._mi_stderr = None

        if mi is not None:
            return

        # No MI was given: estimate it (together with its standard error).
        estimate_key = cast_to_rng(mi_estimate_seed)
        estimate, stderr = monte_carlo_mi_estimate(
            key=estimate_key, dist=self._dist, n=mi_estimate_sample
        )
        self._mi = estimate
        self._mi_stderr = stderr

    def sample(
        self, n_points: int, rng: Union[int, KeyArray]
    ) -> tuple[jax.numpy.ndarray, jax.numpy.ndarray]:
        """Draw ``n_points`` paired samples from the wrapped distribution.

        Args:
            n_points: number of samples to draw
            rng: seed or key, converted with ``cast_to_rng`` before use

        Returns:
            the pair returned by the wrapped distribution's ``sample`` method
        """
        sampling_key = cast_to_rng(rng)
        return self._dist.sample(n_points=n_points, key=sampling_key)

    def mutual_information(self) -> float:
        """Return the mutual information stored at construction time.

        This is the ``mi`` argument if one was given, otherwise the
        Monte Carlo estimate.
        """
        return self._mi
6 changes: 3 additions & 3 deletions tests/samplers/tfp/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def distributions(dim_x: int = 2, dim_y: int = 3) -> list[bmi_tfp.JointDistribut
@pytest.mark.parametrize("dist", distributions())
def test_sample_and_pmi(dist: bmi_tfp.JointDistribution, n_samples: int = 10) -> None:
"""Checks whether we can sample from the distribution and calculate PMI."""
x, y = dist.sample(jax.random.PRNGKey(0), n=n_samples)
x, y = dist.sample(n_samples, jax.random.PRNGKey(0))

assert x.shape == (n_samples, dist.dim_x)
assert y.shape == (n_samples, dist.dim_y)
Expand All @@ -52,8 +52,8 @@ def test_transformed(dist: bmi_tfp.JointDistribution, n_points: int = 1_000) ->

key = jax.random.PRNGKey(0)

x_base, y_base = base_dist.sample(key, n=n_points)
x_tran, y_tran = transformed.sample(key, n=n_points)
x_base, y_base = base_dist.sample(n_points, key)
x_tran, y_tran = transformed.sample(n_points, key)

# Check shapes
assert x_base.shape == x_tran.shape
Expand Down
4 changes: 2 additions & 2 deletions tests/samplers/tfp/test_normal.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ def test_1v1(correlation: float = 0.5, n: int = 10):

key = jax.random.PRNGKey(0)

x, y = dist.sample(key, n=n)
x, y = dist.sample(n, key)

assert x.shape == (n, 1)
assert y.shape == (n, 1)
Expand All @@ -23,7 +23,7 @@ def test_1v1(correlation: float = 0.5, n: int = 10):
== BivariateNormalSampler(correlation=correlation).mutual_information()
)
# Check whether the Monte Carlo estimate is correct
estimate, _ = monte_carlo_mi_estimate(key, dist, n=5_000)
estimate, _ = monte_carlo_mi_estimate(key, dist=dist, n=5_000)
assert pytest.approx(estimate, abs=0.01) == dist.analytic_mi


Expand Down
19 changes: 19 additions & 0 deletions tests/samplers/tfp/test_wrapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
import jax.numpy as jnp
import pytest

from bmi.samplers import fine


def test_can_create_sampler() -> None:
    """Checks that a wrapped bivariate normal can be sampled and reports the right MI."""
    correlation = 0.5
    n_points = 10

    covariance = jnp.asarray([[1, correlation], [correlation, 1]])
    dist = fine.MultivariateNormalDistribution(dim_x=1, dim_y=1, covariance=covariance)
    # Closed-form mutual information of a bivariate normal with this correlation.
    expected_mi = -0.5 * jnp.log(1 - correlation**2)

    # No `mi` is passed, so the sampler has to estimate it via Monte Carlo.
    sampler = fine.FineSampler(dist=dist, mi_estimate_seed=0, mi_estimate_sample=1_000)

    x_sample, y_sample = sampler.sample(n_points=n_points, rng=0)
    assert x_sample.shape == (n_points, 1)
    assert y_sample.shape == (n_points, 1)

    assert sampler.mutual_information() == pytest.approx(expected_mi, abs=0.01)

0 comments on commit c8035ef

Please sign in to comment.