diff --git a/workflows/Mixtures/cool_tasks.smk b/workflows/Mixtures/cool_tasks.smk index 9dc2fe94..89c1b1c7 100644 --- a/workflows/Mixtures/cool_tasks.smk +++ b/workflows/Mixtures/cool_tasks.smk @@ -36,7 +36,7 @@ x_sampler = fine.FineSampler(x_dist) # The fence distribution n_components = 12 -base_dist = fine.mixture( +fence_base_dist = fine.mixture( proportions=jnp.ones(n_components) / n_components, components=[ fine.MultivariateNormalDistribution( @@ -46,7 +46,7 @@ base_dist = fine.mixture( ) for x in range(n_components) ] ) -base_sampler = fine.FineSampler(base_dist) +base_sampler = fine.FineSampler(fence_base_dist) fence_aux_sampler = bmi.samplers.TransformedSampler( base_sampler, transform_x=lambda x: x + jnp.array([5., 0.]) * jnp.sin(3 * x[1]), @@ -192,7 +192,8 @@ rule all: input: 'cool_tasks.pdf', 'results.csv', - 'results.pdf' + 'results.pdf', + 'profiles.pdf' rule plot_distributions: output: "cool_tasks.pdf" @@ -235,6 +236,23 @@ rule plot_distributions: fig.savefig(str(output)) +rule plot_pmi_profiles: + output: "profiles.pdf" + run: + fig, axs = subplots_from_axsize(1, 4, axsize=(4, 3)) + dists = [x_dist, ai_dist, fence_base_dist, balls_mixt] + tasks_official = ['X', 'AI', 'Waves', 'Galaxy'] + for dist, task_name, ax in zip(dists, tasks_official, axs): + import jax + key = jax.random.PRNGKey(1024) + pmi_values = fine.pmi_profile(key=key, dist=dist, n=100_000) + bins = np.linspace(-5, 5, 101) + ax.hist(pmi_values, bins=bins, density=True, alpha=0.5) + ax.set_xlabel(task_name) + axs[0].set_ylabel("Density") + fig.savefig(str(output)) + + rule plot_results: output: 'results.pdf' input: 'results.csv'