Rename fine distributions to BMMs #160

Merged 7 commits on Jun 28, 2024

Changes from all commits
3 changes: 2 additions & 1 deletion .github/workflows/build.yml
@@ -15,7 +15,8 @@ jobs:
     strategy:
       fail-fast: true
       matrix:
-        python-version: ["3.11", "3.12"]
+        # Use only 3.11, since pytype does not support 3.12 yet
+        python-version: ["3.11"]
         poetry-version: ["1.3.2"]

     steps:
4 changes: 2 additions & 2 deletions docs/api/fine-distributions.md → docs/api/bmm.md
@@ -1,4 +1,4 @@
-# Fine distributions
+# Bend and Mix Models

 ## Core utilities

@@ -12,7 +12,7 @@

 ::: bmi.samplers._tfp.ProductDistribution

-::: bmi.samplers._tfp.FineSampler
+::: bmi.samplers._tfp.BMMSampler

 ## Basic distributions
4 changes: 2 additions & 2 deletions docs/api/index.md
@@ -12,8 +12,8 @@

 [Samplers](samplers.md) represent joint probability distributions with known mutual information from which one can sample. They are lower level than `Tasks` and can be used to define new tasks by transformations which preserve mutual information.

-### Fine distributions
-[Subpackage](fine-distributions.md) implementing distributions in which the ground-truth mutual information may not be known analytically, but can be efficiently approximated using Monte Carlo methods.
+### Bend and Mix Models
+[Subpackage](bmm.md) implementing distributions known as *Bend and Mix Models*, for which the ground-truth mutual information may not be known analytically, but can be efficiently approximated using Monte Carlo methods.

 ## Interfaces
 [Interfaces](interfaces.md) defines the main interfaces used in the package.
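Editor's note: to make the renamed API concrete before the code diffs below, here is a minimal sketch assembled from calls visible in this PR (it is an illustration, not code shipped by the PR). A two-component Gaussian mixture is a BMM whose ground-truth MI has no closed form but can be approximated by Monte Carlo.

```python
# Editor's sketch based on the calls visible in this PR's diffs; not part of
# the PR itself. Builds a two-component Gaussian mixture (a BMM) and
# approximates its mutual information by Monte Carlo.
import jax
import jax.numpy as jnp

from bmi.samplers import bmm  # previously: from bmi.samplers import fine

dist = bmm.mixture(
    proportions=jnp.array([0.5, 0.5]),
    components=[
        bmm.MultivariateNormalDistribution(
            dim_x=1,
            dim_y=1,
            mean=jnp.array([x, 0.0]),
            covariance=jnp.eye(2),
        )
        for x in [-1.0, 1.0]
    ],
)

# Monte Carlo approximation of the ground-truth MI
# (the same call appears in gmm.py further down).
mi, _ = bmm.monte_carlo_mi_estimate(key=jax.random.PRNGKey(0), dist=dist, n=100_000)
print(mi)
```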
4 changes: 2 additions & 2 deletions docs/api/samplers.md
@@ -26,9 +26,9 @@ Samplers represent probability distributions with known mutual information.

 ::: bmi.samplers.ZeroInflatedPoissonizationSampler

-## Fine distributions
+## Bend and Mix Models

-See the [fine distributions subpackage API](fine-distributions.md) for more information.
+See the [Bend and Mix Models subpackage API](bmm.md) for more information.

 ### Auxiliary
106 changes: 54 additions & 52 deletions docs/fine-distributions.md → docs/bmm.md

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion mkdocs.yml
@@ -48,7 +48,7 @@ nav:
   - Home: index.md
   - Estimators: estimators.md
   - Benchmark: benchmarking-new-estimator.md
-  - Fine distributions: fine-distributions.md
+  - Bend and Mix Models: bmm.md
   - Contributing: contributing.md
   - API: api/index.md
5 changes: 5 additions & 0 deletions pyproject.toml
@@ -22,6 +22,11 @@ scipy = "^1.10.1"
 tqdm = "^4.64.1"
 tensorflow-probability = {extras = ["jax"], version = "^0.20.1"}

+[tool.poetry.group.bayes]
+optional = true
+
+[tool.poetry.group.bayes.dependencies]
+numpyro = "^0.14.0"

 [tool.poetry.group.dev]
 optional = true
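Editor's note: `numpyro` powers the Bayesian GMM estimator in `src/bmi/estimators/external/gmm.py` (further down in this diff) and is deliberately kept optional. That module's docstring advertises `pip install benchmark-mi[bayes]`; in a development checkout the Poetry equivalent should be `poetry install --with bayes`, assuming the optional-group syntax of the Poetry version pinned above.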
7 changes: 4 additions & 3 deletions references.bib
@@ -1,10 +1,11 @@
-@misc{mixtures-neural-critics-2023,
-      title={The Mixtures and the Neural Critics: On the Pointwise Mutual Information Profiles of Fine Distributions},
+@misc{pmi-profiles-bmms-2023,
+      title={On the Properties and Estimation of Pointwise Mutual Information Profiles},
       author={Paweł Czyż and Frederic Grabowski and Julia E. Vogt and Niko Beerenwinkel and Alexander Marx},
       year={2023},
       eprint={2310.10240},
       archivePrefix={arXiv},
-      primaryClass={stat.ML}
+      primaryClass={stat.ML},
+      url={https://arxiv.org/abs/2310.10240}
 }

 @inproceedings{beyond-normal-2023,
50 changes: 25 additions & 25 deletions src/bmi/benchmark/tasks/mixtures.py
@@ -4,7 +4,7 @@
 import bmi.samplers as samplers
 import bmi.transforms as transforms
 from bmi.benchmark.task import Task
-from bmi.samplers import fine
+from bmi.samplers import bmm

 _MC_MI_ESTIMATE_SAMPLE = 100_000

@@ -15,10 +15,10 @@ def task_x(
 ) -> Task:
     """The X distribution."""

-    dist = fine.mixture(
+    dist = bmm.mixture(
         proportions=jnp.array([0.5, 0.5]),
         components=[
-            fine.MultivariateNormalDistribution(
+            bmm.MultivariateNormalDistribution(
                 covariance=samplers.canonical_correlation([x * gaussian_correlation]),
                 mean=jnp.zeros(2),
                 dim_x=1,
@@ -27,7 +27,7 @@
             for x in [-1, 1]
         ],
     )
-    sampler = fine.FineSampler(dist, mi_estimate_sample=mi_estimate_sample)
+    sampler = bmm.BMMSampler(dist, mi_estimate_sample=mi_estimate_sample)

     return Task(
         sampler=sampler,
@@ -47,44 +47,44 @@ def task_ai(
     corr = 0.95
     var_x = 0.04

-    dist = fine.mixture(
+    dist = bmm.mixture(
         proportions=jnp.full(6, fill_value=1 / 6),
         components=[
             # I components
-            fine.MultivariateNormalDistribution(
+            bmm.MultivariateNormalDistribution(
                 dim_x=1,
                 dim_y=1,
                 mean=jnp.array([1.0, 0.0]),
                 covariance=np.diag([0.01, 0.2]),
             ),
-            fine.MultivariateNormalDistribution(
+            bmm.MultivariateNormalDistribution(
                 dim_x=1,
                 dim_y=1,
                 mean=jnp.array([1.0, 1]),
                 covariance=np.diag([0.05, 0.001]),
             ),
-            fine.MultivariateNormalDistribution(
+            bmm.MultivariateNormalDistribution(
                 dim_x=1,
                 dim_y=1,
                 mean=jnp.array([1.0, -1]),
                 covariance=np.diag([0.05, 0.001]),
             ),
             # A components
-            fine.MultivariateNormalDistribution(
+            bmm.MultivariateNormalDistribution(
                 dim_x=1,
                 dim_y=1,
                 mean=jnp.array([-0.8, -0.2]),
                 covariance=np.diag([0.03, 0.001]),
             ),
-            fine.MultivariateNormalDistribution(
+            bmm.MultivariateNormalDistribution(
                 dim_x=1,
                 dim_y=1,
                 mean=jnp.array([-1.2, 0.0]),
                 covariance=jnp.array(
                     [[var_x, jnp.sqrt(var_x * 0.2) * corr], [jnp.sqrt(var_x * 0.2) * corr, 0.2]]
                 ),
             ),
-            fine.MultivariateNormalDistribution(
+            bmm.MultivariateNormalDistribution(
                 dim_x=1,
                 dim_y=1,
                 mean=jnp.array([-0.4, 0.0]),
@@ -94,7 +94,7 @@ def task_ai(
             ),
         ],
     )
-    sampler = fine.FineSampler(dist, mi_estimate_sample=mi_estimate_sample)
+    sampler = bmm.BMMSampler(dist, mi_estimate_sample=mi_estimate_sample)

     return Task(
         sampler=sampler,
@@ -110,10 +110,10 @@ def task_galaxy(
 ) -> Task:
     """The Galaxy distribution."""

-    balls_mixt = fine.mixture(
+    balls_mixt = bmm.mixture(
         proportions=jnp.array([0.5, 0.5]),
         components=[
-            fine.MultivariateNormalDistribution(
+            bmm.MultivariateNormalDistribution(
                 covariance=samplers.canonical_correlation([0.0], additional_y=1),
                 mean=jnp.array([x, x, x]) * distance / 2,
                 dim_x=2,
@@ -123,7 +123,7 @@
         ],
     )

-    base_sampler = fine.FineSampler(balls_mixt, mi_estimate_sample=mi_estimate_sample)
+    base_sampler = bmm.BMMSampler(balls_mixt, mi_estimate_sample=mi_estimate_sample)
     a = jnp.array([[0, -1], [1, 0]])
     spiral = transforms.Spiral(a, speed=speed)

@@ -150,10 +150,10 @@ def task_waves(

     assert n_components > 0

-    base_dist = fine.mixture(
+    base_dist = bmm.mixture(
         proportions=jnp.ones(n_components) / n_components,
         components=[
-            fine.MultivariateNormalDistribution(
+            bmm.MultivariateNormalDistribution(
                 covariance=jnp.diag(jnp.array([0.1, 1.0, 0.1])),
                 mean=jnp.array([x, 0, x % 4]) * 1.5,
                 dim_x=2,
@@ -162,7 +162,7 @@
             for x in range(n_components)
         ],
     )
-    base_sampler = fine.FineSampler(base_dist, mi_estimate_sample=mi_estimate_sample)
+    base_sampler = bmm.BMMSampler(base_dist, mi_estimate_sample=mi_estimate_sample)
     aux_sampler = samplers.TransformedSampler(
         base_sampler,
         transform_x=lambda x: x
@@ -193,10 +193,10 @@ def task_concentric_multinormal(

     assert n_components > 0

-    dist = fine.mixture(
+    dist = bmm.mixture(
         proportions=jnp.ones(n_components) / n_components,
         components=[
-            fine.MultivariateNormalDistribution(
+            bmm.MultivariateNormalDistribution(
                 covariance=jnp.diag(jnp.array(dim_x * [i**2] + [0.0001])),
                 mean=jnp.array(dim_x * [0.0] + [1.0 * i]),
                 dim_x=dim_x,
@@ -205,7 +205,7 @@
             for i in range(1, 1 + n_components)
         ],
     )
-    sampler = fine.FineSampler(dist, mi_estimate_sample=mi_estimate_sample)
+    sampler = bmm.BMMSampler(dist, mi_estimate_sample=mi_estimate_sample)

     return Task(
         sampler=sampler,
@@ -238,23 +238,23 @@ def task_multinormal_sparse_w_inliers(
         eta_x=strength,
     )

-    signal_dist = fine.MultivariateNormalDistribution(
+    signal_dist = bmm.MultivariateNormalDistribution(
         dim_x=dim_x,
         dim_y=dim_y,
         covariance=params.correlation,
     )

-    noise_dist = fine.ProductDistribution(
+    noise_dist = bmm.ProductDistribution(
         dist_x=signal_dist.dist_x,
         dist_y=signal_dist.dist_y,
     )

-    dist = fine.mixture(
+    dist = bmm.mixture(
         proportions=jnp.array([1 - inlier_fraction, inlier_fraction]),
         components=[signal_dist, noise_dist],
     )

-    sampler = fine.FineSampler(dist, mi_estimate_sample=mi_estimate_sample)
+    sampler = bmm.BMMSampler(dist, mi_estimate_sample=mi_estimate_sample)

     task_id = f"mult-sparse-w-inliers-{dim_x}-{dim_y}-{n_interacting}-{strength}-{inlier_fraction}"
     return Task(
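Editor's note on the last hunk: the `ProductDistribution` built from `signal_dist`'s own marginals keeps the same marginals but is independent by construction, so its mutual information is zero; mixing it into the signal therefore simulates uninformative inliers. For orientation, a hedged sketch of consuming one of these tasks (the sampler method names follow the usual `bmi` interface and are assumptions here, not part of this diff):

```python
# Editor's sketch, not from this PR: drawing data from the "X" task above.
# The `sampler.sample(...)` signature and `task.mutual_information` attribute
# are assumed from the usual bmi interface; treat them as placeholders.
from bmi.benchmark.tasks.mixtures import task_x

task = task_x()  # arguments elided; defaults assumed
x, y = task.sampler.sample(n_points=5_000, rng=0)

# The "ground-truth" MI here is itself a Monte Carlo estimate,
# computed from mi_estimate_sample points as set above.
print(task.mutual_information)
```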
30 changes: 20 additions & 10 deletions src/bmi/estimators/external/gmm.py
@@ -1,3 +1,13 @@
+"""A Gaussian mixture model estimator, allowing for model-based
+Bayesian estimation of mutual information.
+The full description can be found [here](https://arxiv.org/abs/2310.10240).
+
+Note that to use this estimator you need to install external dependencies:
+```bash
+$ pip install benchmark-mi[bayes]
+```
+"""
+
 try:
     import numpyro  # type: ignore
     import numpyro.distributions as dist  # type: ignore
@@ -12,7 +22,7 @@
 from numpy.typing import ArrayLike

 from bmi.interface import BaseModel, IMutualInformationPointEstimator
-from bmi.samplers import fine
+from bmi.samplers import bmm
 from bmi.utils import ProductSpace

@@ -74,14 +84,14 @@ def model(
     )


-def sample_into_fine_distribution(
+def sample_into_bmm_distribution(
     means: jnp.ndarray,
     covariances: jnp.ndarray,
     proportions: jnp.ndarray,
     dim_x: int,
     dim_y: int,
-) -> fine.JointDistribution:
-    """Builds a fine distribution from a Gaussian mixture model parameters."""
+) -> bmm.JointDistribution:
+    """Builds a BMM distribution from Gaussian mixture model parameters."""
     # Check if the dimensions are right
     n_components = proportions.shape[0]
     n_dims = dim_x + dim_y
@@ -90,7 +100,7 @@ def sample_into_fine_distribution(

     # Build components
     components = [
-        fine.MultivariateNormalDistribution(
+        bmm.MultivariateNormalDistribution(
             dim_x=dim_x,
             dim_y=dim_y,
             mean=mean,
@@ -100,7 +110,7 @@ def sample_into_fine_distribution(
     ]

     # Build a mixture model
-    return fine.mixture(proportions=proportions, components=components)
+    return bmm.mixture(proportions=proportions, components=components)


 class GMMEstimatorParams(BaseModel):
@@ -185,12 +195,12 @@ def run_mcmc(self, x: ArrayLike, y: ArrayLike):
         self._dim_x = space.dim_x
         self._dim_y = space.dim_y

-    def get_fine_distribution(self, idx: int) -> fine.JointDistribution:
+    def get_bmm_distribution(self, idx: int) -> bmm.JointDistribution:
         if self._mcmc is None:
             raise ValueError("You need to run MCMC first. See the `run_mcmc` method.")

         samples = self._mcmc.get_samples()
-        return sample_into_fine_distribution(
+        return sample_into_bmm_distribution(
             means=samples["mu"][idx],
             covariances=samples["cov"][idx],
             proportions=samples["pi"][idx],
@@ -204,8 +214,8 @@ def get_sample_mi(self, idx: int, mc_samples: Optional[int] = None, key=None) ->
         if key is None:
             self.key, key = jax.random.split(self.key)

-        distribution = self.get_fine_distribution(idx)
-        mi, _ = fine.monte_carlo_mi_estimate(key=key, dist=distribution, n=mc_samples)
+        distribution = self.get_bmm_distribution(idx)
+        mi, _ = bmm.monte_carlo_mi_estimate(key=key, dist=distribution, n=mc_samples)
         return mi

     def get_posterior_mi(
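Editor's note: the renamed methods compose into a model-based Bayesian workflow: fit a Gaussian mixture to paired samples by MCMC, turn each posterior draw into a BMM via `sample_into_bmm_distribution`, and Monte Carlo estimate the MI of each draw, yielding a posterior over MI values. A sketch of that loop; the estimator class name and its construction are assumptions, while the methods are taken from the diff above:

```python
# Editor's sketch of the workflow implied by the methods in this diff.
# `GMMEstimator` and its constructor are assumed names; run_mcmc,
# get_bmm_distribution and get_sample_mi appear in the diff above.
import numpy as np

from bmi.estimators.external.gmm import GMMEstimator  # assumed class name

rng = np.random.default_rng(0)
x = rng.normal(size=(500, 1))
y = x + rng.normal(scale=0.5, size=(500, 1))

estimator = GMMEstimator()  # assumed default construction
estimator.run_mcmc(x, y)    # fit the Gaussian mixture posterior

# One MI value per posterior draw: each draw defines a BMM whose MI is
# approximated by bmm.monte_carlo_mi_estimate inside get_sample_mi.
mi_samples = [estimator.get_sample_mi(idx) for idx in range(10)]
```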
4 changes: 2 additions & 2 deletions src/bmi/samplers/__init__.py
@@ -13,7 +13,7 @@
 )

 # isort: on
-import bmi.samplers._tfp as fine
+import bmi.samplers._tfp as bmm
 from bmi.samplers._independent_coordinates import IndependentConcatenationSampler
 from bmi.samplers._split_student_t import SplitStudentT
 from bmi.samplers._splitmultinormal import BivariateNormalSampler, SplitMultinormal
@@ -33,7 +33,7 @@
     "AdditiveUniformSampler",
     "BaseSampler",
     "canonical_correlation",
-    "fine",
+    "bmm",
     "parametrised_correlation_matrix",
     "BivariateNormalSampler",
     "SplitMultinormal",
4 changes: 2 additions & 2 deletions src/bmi/samplers/_tfp/__init__.py
@@ -19,7 +19,7 @@

 # isort: on
 from bmi.samplers._tfp._product import ProductDistribution
-from bmi.samplers._tfp._wrapper import FineSampler
+from bmi.samplers._tfp._wrapper import BMMSampler

 __all__ = [
     "JointDistribution",
@@ -30,7 +30,7 @@
     "MultivariateNormalDistribution",
     "MultivariateStudentDistribution",
     "ProductDistribution",
-    "FineSampler",
+    "BMMSampler",
     "construct_multivariate_normal_distribution",
     "construct_multivariate_student_distribution",
 ]
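Editor's note: with both `__init__.py` files updated, the renamed sampler is reachable along either public route. A quick sanity check (editor's illustration, grounded in the two diffs above):

```python
# The two public routes to the renamed sampler after this PR.
from bmi.samplers import bmm
from bmi.samplers._tfp import BMMSampler

assert bmm.BMMSampler is BMMSampler
```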