Skip to content

Commit

Permalink
Make plots better
Browse files Browse the repository at this point in the history
  • Loading branch information
pawel-czyz committed Sep 23, 2023
1 parent bd69dad commit 2e53a8f
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 14 deletions.
33 changes: 22 additions & 11 deletions workflows/Mixtures/cool_tasks.smk
Original file line number Diff line number Diff line change
Expand Up @@ -36,24 +36,29 @@ assert set(ESTIMATOR_NAMES.keys()) == set(ESTIMATORS.keys())

_SAMPLE_ESTIMATE: int = 200_000

x_sampler = ed.create_x_distribution(_sample=_SAMPLE_ESTIMATE).sampler
ai_sampler = ed.create_ai_distribution(_sample=_SAMPLE_ESTIMATE).sampler
waves_sampler = ed.create_waves_distribution(_sample=_SAMPLE_ESTIMATE).sampler
galaxy_sampler = ed.create_galaxy_distribution(_sample=_SAMPLE_ESTIMATE).sampler

UNSCALED_TASKS = {
"X": bmi.benchmark.Task(
sampler=ed.create_x_distribution(_sample=_SAMPLE_ESTIMATE).sampler,
sampler=x_sampler,
task_id="X",
task_name="X",
),
"AI": bmi.benchmark.Task(
sampler=ed.create_ai_distribution(_sample=_SAMPLE_ESTIMATE).sampler,
sampler=ai_sampler,
task_id="AI",
task_name="AI",
),
"Fence": bmi.benchmark.Task(
sampler=ed.create_waves_distribution(_sample=_SAMPLE_ESTIMATE).sampler,
sampler=waves_sampler,
task_id="Fence",
task_name="Fence",
),
"Balls": bmi.benchmark.Task(
sampler=ed.create_galaxy_distribution(_sample=_SAMPLE_ESTIMATE).sampler,
sampler=galaxy_sampler,
task_id="Balls",
task_name="Balls",
),
Expand All @@ -73,41 +78,47 @@ rule all:
rule plot_distributions:
output: "cool_tasks.pdf"
run:
fig, axs = subplots_from_axsize(1, 4, axsize=(3, 3))
fig, axs = subplots_from_axsize(1, 4, axsize=(1.5, 1.5))

# Plot the X distribution
ax = axs[0]
xs, ys = x_sampler.sample(1000, 0)

ax.scatter(xs[:, 0], ys[:, 0], s=4**2, alpha=0.3, color="k", rasterized=True)
size = 2**2

ax.scatter(xs[:, 0], ys[:, 0], s=size, alpha=0.3, color="k", rasterized=True)
ax.set_xlabel("$X$")
ax.set_ylabel("$Y$")

# Plot the AI distribution
ax = axs[1]
xs, ys = ai_sampler.sample(2000, 0)
ax.scatter(xs[:, 0], ys[:, 0], s=4**2, alpha=0.3, color="k", rasterized=True)
ax.scatter(xs[:, 0], ys[:, 0], s=size, alpha=0.3, color="k", rasterized=True)
ax.set_xlabel("$X$")
ax.set_ylabel("$Y$")

# Plot the fence distribution
ax = axs[2]
xs, ys = fence_sampler.sample(2000, 0)
xs, ys = waves_sampler.sample(2000, 0)

ax.scatter(xs[:, 0], xs[:, 1], c=ys[:, 0], s=4**2, alpha=0.3, rasterized=True)
ax.scatter(xs[:, 0], xs[:, 1], c=ys[:, 0], s=size, alpha=0.3, rasterized=True)
ax.set_xlabel("$X_1$")
ax.set_ylabel("$X_2$")

# Plot transformed balls distribution
ax = axs[3]
xs, ys = sampler_balls_transformed.sample(2000, 0)
ax.scatter(xs[:, 0], xs[:, 1], c=ys[:, 0], s=4**2, alpha=0.3, rasterized=True)
xs, ys = galaxy_sampler.sample(2000, 0)
ax.scatter(xs[:, 0], xs[:, 1], c=ys[:, 0], s=size, alpha=0.3, rasterized=True)
ax.set_xlabel("$X_1$")
ax.set_ylabel("$X_2$")

for ax in axs:
ticks = [-1, 0, 1]
ax.set_xticks(ticks, ticks)
ax.set_yticks(ticks, ticks)
ax.set_xlim(-2., 2.)
ax.set_ylim(-2., 2.)
ax.spines[['right', 'top']].set_visible(False)

fig.savefig(str(output))

Expand Down
12 changes: 9 additions & 3 deletions workflows/Mixtures/distinct_profiles.smk
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ import jax.numpy as jnp
import bmi.samplers._tfp as bmi_tfp
from bmi.transforms import invert_cdf, normal_cdf

from subplots_from_axsize import subplots_from_axsize

mpl.use("Agg")


Expand Down Expand Up @@ -114,6 +116,8 @@ def hide_ticks(ax):
ax.set_yticks([])
ax.set_xlabel("$X$")
ax.set_ylabel("$Y$")
ax.spines[['right', 'top']].set_visible(False)


rule plot_samples:
input:
Expand All @@ -123,7 +127,7 @@ rule plot_samples:
output:
"figure_distinct_profiles.pdf"
run:
fig, axs = plt.subplots(1, 4, figsize=(8, 2))
fig, axs = plt.subplots(1, 4, figsize=(7, 2))

color1 = "navy"
color2 = "salmon"
Expand Down Expand Up @@ -159,15 +163,17 @@ rule plot_samples:
ax.hist(pmi_u, bins=bins, density=True, color=color2, alpha=0.5, label="Mixture")
ax.set_title("PMI profiles")
ax.set_xlabel("PMI")
ax.set_ylabel("Density")
ax.set_ylabel("")
ax.set_yticks([])
ax.spines[['right', 'top', 'left']].set_visible(False)

mi_1 = jnp.mean(pmi_normal)
mi_2 = jnp.mean(pmi_u)

if abs(mi_1 - mi_2) > 0.01:
raise ValueError(f"MI different: {mi_1:.2f} != {mi_2:.2f}")

ax.axvline(mi_1, c="k", linewidth=0.5, linestyle="--")
ax.axvline(mi_1, c="k", linewidth=1, linestyle="--")

fig.tight_layout()
fig.savefig(str(output))

0 comments on commit 2e53a8f

Please sign in to comment.