From 2e53a8f0a85ad400daeb4946df860155532ee5ad Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Czy=C5=BC?= Date: Sat, 23 Sep 2023 15:21:22 +0200 Subject: [PATCH] Make plots better --- workflows/Mixtures/cool_tasks.smk | 33 ++++++++++++++++-------- workflows/Mixtures/distinct_profiles.smk | 12 ++++++--- 2 files changed, 31 insertions(+), 14 deletions(-) diff --git a/workflows/Mixtures/cool_tasks.smk b/workflows/Mixtures/cool_tasks.smk index cf49ec11..4714aed0 100644 --- a/workflows/Mixtures/cool_tasks.smk +++ b/workflows/Mixtures/cool_tasks.smk @@ -36,24 +36,29 @@ assert set(ESTIMATOR_NAMES.keys()) == set(ESTIMATORS.keys()) _SAMPLE_ESTIMATE: int = 200_000 +x_sampler = ed.create_x_distribution(_sample=_SAMPLE_ESTIMATE).sampler +ai_sampler = ed.create_ai_distribution(_sample=_SAMPLE_ESTIMATE).sampler +waves_sampler = ed.create_waves_distribution(_sample=_SAMPLE_ESTIMATE).sampler +galaxy_sampler = ed.create_galaxy_distribution(_sample=_SAMPLE_ESTIMATE).sampler + UNSCALED_TASKS = { "X": bmi.benchmark.Task( - sampler=ed.create_x_distribution(_sample=_SAMPLE_ESTIMATE).sampler, + sampler=x_sampler, task_id="X", task_name="X", ), "AI": bmi.benchmark.Task( - sampler=ed.create_ai_distribution(_sample=_SAMPLE_ESTIMATE).sampler, + sampler=ai_sampler, task_id="AI", task_name="AI", ), "Fence": bmi.benchmark.Task( - sampler=ed.create_waves_distribution(_sample=_SAMPLE_ESTIMATE).sampler, + sampler=waves_sampler, task_id="Fence", task_name="Fence", ), "Balls": bmi.benchmark.Task( - sampler=ed.create_galaxy_distribution(_sample=_SAMPLE_ESTIMATE).sampler, + sampler=galaxy_sampler, task_id="Balls", task_name="Balls", ), @@ -73,41 +78,47 @@ rule all: rule plot_distributions: output: "cool_tasks.pdf" run: - fig, axs = subplots_from_axsize(1, 4, axsize=(3, 3)) + fig, axs = subplots_from_axsize(1, 4, axsize=(1.5, 1.5)) # Plot the X distribution ax = axs[0] xs, ys = x_sampler.sample(1000, 0) - ax.scatter(xs[:, 0], ys[:, 0], s=4**2, alpha=0.3, color="k", rasterized=True) + size = 2**2 + + ax.scatter(xs[:, 0], ys[:, 0], s=size, alpha=0.3, color="k", rasterized=True) ax.set_xlabel("$X$") ax.set_ylabel("$Y$") # Plot the AI distribution ax = axs[1] xs, ys = ai_sampler.sample(2000, 0) - ax.scatter(xs[:, 0], ys[:, 0], s=4**2, alpha=0.3, color="k", rasterized=True) + ax.scatter(xs[:, 0], ys[:, 0], s=size, alpha=0.3, color="k", rasterized=True) ax.set_xlabel("$X$") ax.set_ylabel("$Y$") # Plot the fence distribution ax = axs[2] - xs, ys = fence_sampler.sample(2000, 0) + xs, ys = waves_sampler.sample(2000, 0) - ax.scatter(xs[:, 0], xs[:, 1], c=ys[:, 0], s=4**2, alpha=0.3, rasterized=True) + ax.scatter(xs[:, 0], xs[:, 1], c=ys[:, 0], s=size, alpha=0.3, rasterized=True) ax.set_xlabel("$X_1$") ax.set_ylabel("$X_2$") # Plot transformed balls distribution ax = axs[3] - xs, ys = sampler_balls_transformed.sample(2000, 0) - ax.scatter(xs[:, 0], xs[:, 1], c=ys[:, 0], s=4**2, alpha=0.3, rasterized=True) + xs, ys = galaxy_sampler.sample(2000, 0) + ax.scatter(xs[:, 0], xs[:, 1], c=ys[:, 0], s=size, alpha=0.3, rasterized=True) ax.set_xlabel("$X_1$") ax.set_ylabel("$X_2$") for ax in axs: + ticks = [-1, 0, 1] + ax.set_xticks(ticks, ticks) + ax.set_yticks(ticks, ticks) ax.set_xlim(-2., 2.) ax.set_ylim(-2., 2.) + ax.spines[['right', 'top']].set_visible(False) fig.savefig(str(output)) diff --git a/workflows/Mixtures/distinct_profiles.smk b/workflows/Mixtures/distinct_profiles.smk index 42865303..f912c7f1 100644 --- a/workflows/Mixtures/distinct_profiles.smk +++ b/workflows/Mixtures/distinct_profiles.smk @@ -15,6 +15,8 @@ import jax.numpy as jnp import bmi.samplers._tfp as bmi_tfp from bmi.transforms import invert_cdf, normal_cdf +from subplots_from_axsize import subplots_from_axsize + mpl.use("Agg") @@ -114,6 +116,8 @@ def hide_ticks(ax): ax.set_yticks([]) ax.set_xlabel("$X$") ax.set_ylabel("$Y$") + ax.spines[['right', 'top']].set_visible(False) + rule plot_samples: input: @@ -123,7 +127,7 @@ rule plot_samples: output: "figure_distinct_profiles.pdf" run: - fig, axs = plt.subplots(1, 4, figsize=(8, 2)) + fig, axs = plt.subplots(1, 4, figsize=(7, 2)) color1 = "navy" color2 = "salmon" @@ -159,7 +163,9 @@ rule plot_samples: ax.hist(pmi_u, bins=bins, density=True, color=color2, alpha=0.5, label="Mixture") ax.set_title("PMI profiles") ax.set_xlabel("PMI") - ax.set_ylabel("Density") + ax.set_ylabel("") + ax.set_yticks([]) + ax.spines[['right', 'top', 'left']].set_visible(False) mi_1 = jnp.mean(pmi_normal) mi_2 = jnp.mean(pmi_u) @@ -167,7 +173,7 @@ rule plot_samples: if abs(mi_1 - mi_2) > 0.01: raise ValueError(f"MI different: {mi_1:.2f} != {mi_2:.2f}") - ax.axvline(mi_1, c="k", linewidth=0.5, linestyle="--") + ax.axvline(mi_1, c="k", linewidth=1, linestyle="--") fig.tight_layout() fig.savefig(str(output))