Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Figure presenting BMMs vs other estimators #157

Merged
merged 7 commits into from
May 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 7 additions & 13 deletions scripts/Mixtures/plot_appearing_and_vanishing_mi.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,36 +20,30 @@ def main() -> None:

X, Y = np.meshgrid(x, y)

fig, axs = plt.subplots(2, 3, dpi=300)
fig, axs = plt.subplots(1, 5, dpi=300, sharex=True, sharey=True, figsize=(5.5, 1.2))

# First row: appearing MI
# Component 1 (bottom left)
ax = axs[0, 0]
ax = axs[0]
mask1 = (0 < X) & (X < 1) & (0 < Y) & (Y < 1)
plot_density(ax, mask1, "$I=0$")

# Component 2 (top right)
ax = axs[0, 1]
ax = axs[1]
mask2 = (1 < X) & (X < 2) & (1 < Y) & (Y < 2)
plot_density(ax, mask2, "$I=0$")

# Mixture
ax = axs[0, 2]
ax = axs[2]
mask3 = mask1 | mask2
plot_density(ax, 0.5 * mask3, "$I=\\log 2$")

# Second row
# Component 1: mixture from first row
ax = axs[1, 0]
plot_density(ax, 0.5 * mask3, "$I=\\log 2$")

# Component 2: symmetric mixture
ax = axs[1, 1]
# A "complementary" mixture
ax = axs[3]
mask4 = (0 < X) & (X < 1) & (1 < Y) & (Y < 2) | (1 < X) & (X < 2) & (0 < Y) & (Y < 1)
plot_density(ax, 0.5 * mask4, "$I=\\log 2$")

# Mixture: independent
ax = axs[1, 2]
ax = axs[4]
plot_density(ax, 0.25 * (mask3 | mask4), "$I=0$")

fig.tight_layout()
Expand Down
7 changes: 4 additions & 3 deletions workflows/projects/Mixtures/distinct_profiles.smk
Original file line number Diff line number Diff line change
Expand Up @@ -127,10 +127,10 @@ rule plot_samples:
output:
"figure_distinct_profiles.pdf"
run:
fig, axs = plt.subplots(1, 4, figsize=(7, 2))
fig, axs = plt.subplots(1, 4, figsize=(7, 1.5), dpi=500)

color1 = "navy"
color2 = "salmon"
color1 = "mediumblue"
color2 = "forestgreen"

# Plot normal distribution
ax = axs[0]
Expand Down Expand Up @@ -163,6 +163,7 @@ rule plot_samples:
ax.hist(pmi_u, bins=bins, density=True, color=color2, alpha=0.5, label="Mixture")
ax.set_title("PMI profiles")
ax.set_xlabel("PMI")
ax.set_xlim(-1, 2)
ax.set_ylabel("")
ax.set_yticks([])
ax.spines[['right', 'top', 'left']].set_visible(False)
Expand Down
150 changes: 150 additions & 0 deletions workflows/projects/Mixtures/figure_bmm_vs_other.smk
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
# Figure comparing BMMs and other estimators on selected problems.
# Note: to run this workflow, you need to have the results.csv files from:
# - the benchmark (version 2) in `generated/benchmark/v2/results.csv`
# - the BMM minibenchmark in `generated/projects/Mixtures/gmm_benchmark/results.csv`
from dataclasses import dataclass

import matplotlib
import matplotlib.pyplot as plt
matplotlib.use("Agg")

import numpy as np
import pandas as pd
from subplots_from_axsize import subplots_from_axsize



rule all:
input: "generated/projects/Mixtures/figure_bmm_vs_other.pdf"


class YScaler:
def __init__(self, estimator_ids: list[str], eps: float = 0.1):
self._estimator_ids = estimator_ids
assert eps > 0
self._eps = eps

@property
def n(self) -> int:
return len(self._estimator_ids)

@property
def offset(self) -> float:
return self._eps * 0.5 * 1 / self.n

def get_y(self, estimator_id: str, n_points: int) -> np.ndarray:
index = self._estimator_ids.index(estimator_id)
y0 = index / self.n
y1 = (index + 1) / self.n

return np.linspace(y0 + self._eps, y1 - self._eps, n_points)

def get_tick_locations(self) -> list[float]:
return (np.arange(self.n, dtype=float) + 0.5) / self.n


@dataclass
class TaskConfig:
name: str
xlim: tuple[float, float]
xticks: list[float] | tuple[float, ...]

@dataclass
class EstimatorConfig:
id: str
name: str
color: str

TASKS = {# task_id: task_name,
'1v1-AI': TaskConfig(name="AI", xlim=(0.5, 0.85), xticks=[0.6, 0.7, 0.8]),
'mult-sparse-w-inliers-5-5-2-2.0-0.2': TaskConfig(name="Inliers (5-dim, 0.2)", xlim=(0.4, 0.8), xticks=[0.45, 0.55, 0.65, 0.75]),
'5v1-concentric_gaussians-5': TaskConfig(name="Concentric (5-dim, 5)", xlim=(0.35, 0.75), xticks=[0.4, 0.5, 0.6, 0.7]),
'multinormal-sparse-5-5-2-2.0': TaskConfig(name="Normal (5-dim, sparse)", xlim=(0.65, 1.15), xticks=[0.7, 0.8, 0.9, 1.0, 1.1]),
}


# NAMES = {
# # one-dimensional
# '1v1-additive-0.75': "Additive",
# '1v1-AI': "AI",
# '1v1-X-0.9': "X",
# '2v1-galaxy-0.5-3.0': "Galaxy",
# # Concentric
# '3v1-concentric_gaussians-10': "Concentric (3-dim, 10)",
# '3v1-concentric_gaussians-5': "Concentric (3-dim, 5)",
# '5v1-concentric_gaussians-10': "Concentric (5-dim, 10)",
# '5v1-concentric_gaussians-5': "Concentric (5-dim, 5)",
# # Inliers
# 'mult-sparse-w-inliers-5-5-2-2.0-0.2': "Inliers (5-dim, 0.2)",
# 'mult-sparse-w-inliers-5-5-2-2.0-0.5': "Inliers (5-dim, 0.5)",
# # Multivariate normal
# 'multinormal-dense-5-5-0.5': "Normal (5-dim, dense)",
# 'multinormal-sparse-5-5-2-2.0': "Normal (5-dim, sparse)",
# # Student
# 'asinh-student-identity-1-1-1': "Student (1-dim)",
# 'asinh-student-identity-2-2-1': "Student (2-dim)",
# 'asinh-student-identity-3-3-2': "Student (3-dim)",
# 'asinh-student-identity-5-5-2': "Student (5-dim)",
# }

# TASKS = {
# id_v: TaskConfig(name=name, xlim=(0.2, 1), xticks=[]) for id_v, name in NAMES.items()
# }


N_SAMPLES = 5_000
POINT_ESTIMATORS = [
EstimatorConfig(id="KSG-10", name="KSG", color="green"),
EstimatorConfig(id="InfoNCE", name="InfoNCE", color="magenta"),
]

DOT_SIZE = 7


rule generate_figure:
output: "generated/projects/Mixtures/figure_bmm_vs_other.pdf"
input:
v2 = "generated/benchmark/v2/results.csv",
bmm = "generated/projects/Mixtures/gmm_benchmark/results.csv"
run:
data_v2 = pd.read_csv(input.v2)
data_bmm = pd.read_csv(input.bmm)

fig, axs = subplots_from_axsize(1, len(TASKS), (2.3, 0.8), left=0.8, right=0.05, top=0.3, bottom=0.3, dpi=350, wspace=0.05)

y_scaler = YScaler(estimator_ids=["BMM"] + [config.id for config in POINT_ESTIMATORS], eps=0.12)

for ax, (task_id, task_config) in zip(axs.ravel(), TASKS.items()):
ax.set_title(task_config.name)
ax.set_xlim(*task_config.xlim)
ax.set_xticks(task_config.xticks)
ax.set_yticks([])
ax.set_ylim(-0.05, 1.01)
ax.spines[["top", "left", "right"]].set_visible(False)

mi_true = data_v2.groupby("task_id")["mi_true"].mean()[task_id]
ax.axvline(mi_true, linestyle=":", color="black", linewidth=2)

# Plot credible intervals from the BMM
bmm_subtable = data_bmm[(data_bmm["task_id"] == task_id)].copy()
bmm_subtable["errorbar_low"] = bmm_subtable["mi_mean"] - bmm_subtable["mi_q_low"]
bmm_subtable["errorbar_high"] = bmm_subtable["mi_q_high"] - bmm_subtable["mi_mean"]

y = y_scaler.get_y(estimator_id="BMM", n_points=len(bmm_subtable))
ax.errorbar(x=bmm_subtable["mi_mean"].values, y=y, xerr=bmm_subtable[["errorbar_low", "errorbar_high"]].T, capsize=3, ls="none", color="darkblue")
ax.scatter(x=bmm_subtable["mi_mean"].values, y=y, color="darkblue", s=DOT_SIZE)

# Plot the scatterplot representing estimators
for estimator_config in POINT_ESTIMATORS:
estimator_id = estimator_config.id

index = (data_v2["task_id"] == task_id) & (data_v2["estimator_id"] == estimator_id) & (data_v2["n_samples"] == N_SAMPLES)
estimates = data_v2[index]["mi_estimate"].values
y = y_scaler.get_y(estimator_id=estimator_id, n_points=len(estimates))
ax.scatter(estimates, y, color=estimator_config.color, s=DOT_SIZE, alpha=0.4)

ax = axs[0]
ax.set_yticks(y_scaler.get_tick_locations(), ["BMM"] + [config.name for config in POINT_ESTIMATORS])
ax.spines["left"].set_visible(True)

fig.savefig(str(output))
20 changes: 12 additions & 8 deletions workflows/projects/Mixtures/fitting_gmm.smk
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ DISTRIBUTIONS = {
rule all:
# For the main part of the manuscript
input:
expand("plots/{dist_name}-{n_points}-10.pdf", dist_name=["AI", "Galaxy"], n_points=[250])
expand("plots/{dist_name}-{n_points}-10.pdf", dist_name=["AI", "Galaxy"], n_points=[500])


rule plots_all:
Expand Down Expand Up @@ -192,20 +192,22 @@ rule plot_pdf:
approx_sample = "approx_samples/{dist_name}-{n_points}-{n_components}-0.npz",
output: "plots/{dist_name}-{n_points}-{n_components}.pdf"
run:
fig, axs = subplots_from_axsize(1, 4, axsize=(1.5, 1.5), top=0.3, wspace=0.3)
fig, axs = subplots_from_axsize(1, 4, axsize=(1.2, 1.2), top=0.3, wspace=[0.3, 0.05, 0.05], left=0.5, right=0.15)

for ax in axs:
ax.spines[['right', 'top']].set_visible(False)

FONTDICT = {'fontsize': 10}

# Visualise true sample
ax = axs[0]
ax.set_title("Ground-truth sample")
ax.set_title("Ground-truth sample", fontdict=FONTDICT)
true_sample = np.load(input.true_sample)
visualise_points(true_sample["xs"], true_sample["ys"], ax)

# Visualise approximate sample
ax = axs[1]
ax.set_title("Simulated sample")
ax.set_title("Simulated sample", fontdict=FONTDICT)
approx_sample = np.load(input.approx_sample)
visualise_points(approx_sample["xs"], approx_sample["ys"], ax)

Expand All @@ -214,7 +216,7 @@ rule plot_pdf:

# Visualise posterior on mutual information
ax = axs[2]
ax.set_title("Posterior MI")
ax.set_title("Posterior MI", fontdict=FONTDICT)
mi_true = np.mean(pmi_true)
mi_approx = np.mean(pmi_approx, axis=1) # (num_mcmc_samples,)
ax.set_xlabel("MI")
Expand All @@ -225,11 +227,13 @@ rule plot_pdf:

# Visualise posterior on profile
ax = axs[3]
ax.set_title("Posterior PMI profile")
ax.set_title("Posterior PMI profile", fontdict=FONTDICT)
ax.set_xlabel("PMI")

min_val = np.min([pmi_true.min(), pmi_approx.min()])
max_val = np.max([pmi_true.max(), pmi_approx.max()])
quantile_min = 0.02
quantile_max = 1 - quantile_min
min_val = np.min([np.quantile(pmi_true, quantile_min), np.quantile(pmi_approx, quantile_min)])
max_val = np.max([np.quantile(pmi_true, quantile_max), np.quantile(pmi_approx, quantile_max)])

bins = np.linspace(min_val, max_val, 50)
for pmi_vals in pmi_approx:
Expand Down
Loading