From 1ed30a3345f4b6148ea0426072acb719f258faad Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Czy=C5=BC?= Date: Sat, 23 Sep 2023 16:01:16 +0200 Subject: [PATCH] Adjust the plots for fitting GMMs --- workflows/Mixtures/fitting_gmm.smk | 28 +++++++++++++++++++++++----- 1 file changed, 23 insertions(+), 5 deletions(-) diff --git a/workflows/Mixtures/fitting_gmm.smk b/workflows/Mixtures/fitting_gmm.smk index d38d8acd..5d74beb2 100644 --- a/workflows/Mixtures/fitting_gmm.smk +++ b/workflows/Mixtures/fitting_gmm.smk @@ -162,12 +162,21 @@ def visualise_points(xs, ys, ax): dim_x = xs.shape[1] if dim_x == 2: - ax.scatter(xs[..., 0], xs[..., 1], c=ys, s=3) + ax.scatter(xs[..., 0], xs[..., 1], c=ys, s=3, rasterized=True) + ax.set_xlabel("$X_1$") + ax.set_ylabel("$X_2$") elif dim_x == 1: - ax.scatter(xs[..., 0], ys[..., 0], c="k", s=3, alpha=0.3) + ax.scatter(xs[..., 0], ys[..., 0], c="k", s=3, alpha=0.3, rasterized=True) + ax.set_xlabel("$X$") + ax.set_ylabel("$Y$") else: raise ValueError(f"X dimension is {dim_x} and cannot be visualised") + ax.set_xlim(-2, 2) + ax.set_ylim(-2, 2) + ticks = [-1, 0, 1] + ax.set_xticks(ticks, ticks) + ax.set_yticks(ticks, ticks) rule plot_pdf: input: @@ -177,17 +186,20 @@ rule plot_pdf: approx_sample = "approx_samples/{dist_name}-{n_points}-{n_components}-0.npz", output: "plots/{dist_name}-{n_points}-{n_components}.pdf" run: - fig, axs = subplots_from_axsize(1, 4, axsize=(2, 1.5), top=0.3) + fig, axs = subplots_from_axsize(1, 4, axsize=(2, 1.5), top=0.3, wspace=0.5) + + for ax in axs: + ax.spines[['right', 'top']].set_visible(False) # Visualise true sample ax = axs[0] - ax.set_title("True distribution") + ax.set_title("Ground-truth sample") true_sample = np.load(input.true_sample) visualise_points(true_sample["xs"], true_sample["ys"], ax) # Visualise approximate sample ax = axs[1] - ax.set_title("Approximate distribution") + ax.set_title("Simulated sample") approx_sample = np.load(input.approx_sample) visualise_points(approx_sample["xs"], approx_sample["ys"], ax) @@ -199,6 +211,7 @@ rule plot_pdf: ax.set_title("Posterior MI") mi_true = np.mean(pmi_true) mi_approx = np.mean(pmi_approx, axis=1) # (num_mcmc_samples,) + ax.set_xlabel("MI") ax.hist(mi_approx, bins=50, density=True, alpha=0.5, color="red") ax.axvline(mi_approx.mean(), color="red") # Visualise posterior mean @@ -207,6 +220,7 @@ rule plot_pdf: # Visualise posterior on profile ax = axs[3] ax.set_title("Posterior PMI profile") + ax.set_xlabel("PMI") min_val = np.min([pmi_true.min(), pmi_approx.min()]) max_val = np.max([pmi_true.max(), pmi_approx.max()]) @@ -219,6 +233,10 @@ rule plot_pdf: prof_true, _ = np.histogram(pmi_true, bins=bins, density=True) ax.stairs(prof_true, edges=bins, color="k", alpha=1) + for ax in [axs[2], axs[3]]: + ax.set_ylabel("") + ax.set_yticks([]) + ax.spines[['right', 'top', 'left']].set_visible(False) fig.savefig(str(output))