Skip to content

Commit

Permalink
Adjust the plots for fitting GMMs
Browse files (browse the repository at this point in the history)
  • Loading branch information
pawel-czyz committed Sep 23, 2023
1 parent 2e53a8f commit 1ed30a3
Showing 1 changed file with 23 additions and 5 deletions.
28 changes: 23 additions & 5 deletions workflows/Mixtures/fitting_gmm.smk
Original file line number | Diff line number | Diff line change
Expand Up @@ -162,12 +162,21 @@ def visualise_points(xs, ys, ax):
dim_x = xs.shape[1]

if dim_x == 2:
ax.scatter(xs[..., 0], xs[..., 1], c=ys, s=3)
ax.scatter(xs[..., 0], xs[..., 1], c=ys, s=3, rasterized=True)
ax.set_xlabel("$X_1$")
ax.set_ylabel("$X_2$")
elif dim_x == 1:
ax.scatter(xs[..., 0], ys[..., 0], c="k", s=3, alpha=0.3)
ax.scatter(xs[..., 0], ys[..., 0], c="k", s=3, alpha=0.3, rasterized=True)
ax.set_xlabel("$X$")
ax.set_ylabel("$Y$")
else:
raise ValueError(f"X dimension is {dim_x} and cannot be visualised")

ax.set_xlim(-2, 2)
ax.set_ylim(-2, 2)
ticks = [-1, 0, 1]
ax.set_xticks(ticks, ticks)
ax.set_yticks(ticks, ticks)

rule plot_pdf:
input:
Expand All @@ -177,17 +186,20 @@ rule plot_pdf:
approx_sample = "approx_samples/{dist_name}-{n_points}-{n_components}-0.npz",
output: "plots/{dist_name}-{n_points}-{n_components}.pdf"
run:
fig, axs = subplots_from_axsize(1, 4, axsize=(2, 1.5), top=0.3)
fig, axs = subplots_from_axsize(1, 4, axsize=(2, 1.5), top=0.3, wspace=0.5)

for ax in axs:
ax.spines[['right', 'top']].set_visible(False)

# Visualise true sample
ax = axs[0]
ax.set_title("True distribution")
ax.set_title("Ground-truth sample")
true_sample = np.load(input.true_sample)
visualise_points(true_sample["xs"], true_sample["ys"], ax)

# Visualise approximate sample
ax = axs[1]
ax.set_title("Approximate distribution")
ax.set_title("Simulated sample")
approx_sample = np.load(input.approx_sample)
visualise_points(approx_sample["xs"], approx_sample["ys"], ax)

Expand All @@ -199,6 +211,7 @@ rule plot_pdf:
ax.set_title("Posterior MI")
mi_true = np.mean(pmi_true)
mi_approx = np.mean(pmi_approx, axis=1) # (num_mcmc_samples,)
ax.set_xlabel("MI")

ax.hist(mi_approx, bins=50, density=True, alpha=0.5, color="red")
ax.axvline(mi_approx.mean(), color="red") # Visualise posterior mean
Expand All @@ -207,6 +220,7 @@ rule plot_pdf:
# Visualise posterior on profile
ax = axs[3]
ax.set_title("Posterior PMI profile")
ax.set_xlabel("PMI")

min_val = np.min([pmi_true.min(), pmi_approx.min()])
max_val = np.max([pmi_true.max(), pmi_approx.max()])
Expand All @@ -219,6 +233,10 @@ rule plot_pdf:
prof_true, _ = np.histogram(pmi_true, bins=bins, density=True)
ax.stairs(prof_true, edges=bins, color="k", alpha=1)

for ax in [axs[2], axs[3]]:
ax.set_ylabel("")
ax.set_yticks([])
ax.spines[['right', 'top', 'left']].set_visible(False)

fig.savefig(str(output))

Expand Down

0 comments on commit 1ed30a3

Please sign in to comment.