Skip to content

Commit

Permalink
Add PMI profile drawing
Browse files Browse the repository at this point in the history
  • Loading branch information
pawel-czyz committed Sep 14, 2023
1 parent 4b3ff09 commit 50b3961
Showing 1 changed file with 21 additions and 3 deletions.
24 changes: 21 additions & 3 deletions workflows/Mixtures/cool_tasks.smk
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ x_sampler = fine.FineSampler(x_dist)
# The fence distribution
n_components = 12

base_dist = fine.mixture(
fence_base_dist = fine.mixture(
proportions=jnp.ones(n_components) / n_components,
components=[
fine.MultivariateNormalDistribution(
Expand All @@ -46,7 +46,7 @@ base_dist = fine.mixture(
) for x in range(n_components)
]
)
base_sampler = fine.FineSampler(base_dist)
base_sampler = fine.FineSampler(fence_base_dist)
fence_aux_sampler = bmi.samplers.TransformedSampler(
base_sampler,
transform_x=lambda x: x + jnp.array([5., 0.]) * jnp.sin(3 * x[1]),
Expand Down Expand Up @@ -192,7 +192,8 @@ rule all:
input:
'cool_tasks.pdf',
'results.csv',
'results.pdf'
'results.pdf',
'profiles.pdf'

rule plot_distributions:
output: "cool_tasks.pdf"
Expand Down Expand Up @@ -235,6 +236,23 @@ rule plot_distributions:

fig.savefig(str(output))

rule plot_pmi_profiles:
output: "profiles.pdf"
run:
fig, axs = subplots_from_axsize(1, 4, axsize=(4, 3))
dists = [x_dist, ai_dist, fence_base_dist, balls_mixt]
tasks_official = ['X', 'AI', 'Waves', 'Galaxy']
for dist, task_name, ax in zip(dists, tasks_official, axs):
import jax
key = jax.random.PRNGKey(1024)
pmi_values = fine.pmi_profile(key=key, dist=dist, n=100_000)
bins = np.linspace(-5, 5, 101)
ax.hist(pmi_values, bins=bins, density=True, alpha=0.5)
ax.set_xlabel(task_name)
axs[0].set_ylabel("Density")
fig.savefig(str(output))


rule plot_results:
output: 'results.pdf'
input: 'results.csv'
Expand Down

0 comments on commit 50b3961

Please sign in to comment.