Skip to content

Commit

Permalink
Update to the integration workflow (#120)
Browse files Browse the repository at this point in the history
  • Loading branch information
pawel-czyz authored Sep 21, 2023
1 parent 21dbdca commit de2264a
Showing 1 changed file with 44 additions and 15 deletions.
59 changes: 44 additions & 15 deletions workflows/Mixtures/how_good_integration_is.smk
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import json
import numpy as np
import pandas as pd
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import seaborn as sns

Expand Down Expand Up @@ -40,38 +41,66 @@ class DistributionAndPMI:
workdir: "generated/mixtures/how_good_integration_is"



ESTIMATORS: dict[str, Callable] = {
"NWJ": nwj,
"NWJ-Shifted": nwj_shifted,
"NWJ": nwj_shifted,
"InfoNCE": infonce,
"DV": donsker_varadhan,
"MC": monte_carlo
}

_normal_dist = fine.MultivariateNormalDistribution(dim_x=2, dim_y=2, covariance=bmi.samplers.canonical_correlation(rho=[0.8, 0.8]))

four_balls = fine.mixture(
proportions=jnp.array([0.3, 0.3, 0.2, 0.2]),
components=[
fine.MultivariateNormalDistribution(
covariance=bmi.samplers.canonical_correlation([0.0]),
mean=jnp.array([-1.25, -1.25]),
dim_x=1, dim_y=1,
),
fine.MultivariateNormalDistribution(
covariance=bmi.samplers.canonical_correlation([0.0]),
mean=jnp.array([+1.25, +1.25]),
dim_x=1, dim_y=1,
),
fine.MultivariateNormalDistribution(
covariance=0.2 * bmi.samplers.canonical_correlation([0.0]),
mean=jnp.array([-2.5, +2.5]),
dim_x=1, dim_y=1,
),
fine.MultivariateNormalDistribution(
covariance=0.2 * bmi.samplers.canonical_correlation([0.0]),
mean=jnp.array([+2.5, -2.5]),
dim_x=1, dim_y=1,
),
]
)


_DISTRIBUTIONS: dict[str, DistributionAndPMI] = {
"Normal": DistributionAndPMI(
dist=_normal_dist,
"Four_Balls": DistributionAndPMI(
dist=four_balls,
),
"NormalBiased": DistributionAndPMI(
dist=_normal_dist,
pmi=lambda x, y: _normal_dist.pmi(x, y) + 0.5,
"Four_Balls_Biased": DistributionAndPMI(
dist=four_balls,
pmi=lambda x, y: four_balls.pmi(x, y) + 0.5,
),
"NormalSinSquare": DistributionAndPMI(
dist=_normal_dist,
pmi=lambda x, y: _normal_dist.pmi(x, y) + jnp.sin(jnp.square(x[..., 0])),
"Four_Balls_SinSquare": DistributionAndPMI(
dist=four_balls,
pmi=lambda x, y: four_balls.pmi(x, y) + jnp.sin(jnp.square(x[..., 0])),
),
"Student": DistributionAndPMI(
dist=fine.MultivariateStudentDistribution(dim_x=2, dim_y=2, dispersion=bmi.samplers.canonical_correlation(rho=[0.8, 0.8]), df=3),
"Normal-25Dim": DistributionAndPMI(
dist=fine.MultivariateNormalDistribution(dim_x=25, dim_y=25, covariance=bmi.samplers.canonical_correlation(rho=[0.8] * 25))
),
}
# If PMI is left as None, override it with the PMI of the distribution
DISTRIBUTION_AND_PMIS = {
name: DistributionAndPMI(dist=value.dist, pmi=value.dist.pmi) if value.pmi is None else value for name, value in _DISTRIBUTIONS.items()
}

N_POINTS = [8, 16, 32, 64, 128, 256, 512, 1024, 2048]
SEEDS = list(range(10))
N_POINTS = [8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192]
SEEDS = list(range(20))


rule all:
input:
Expand Down

0 comments on commit de2264a

Please sign in to comment.