diff --git a/workflows/Mixtures/how_good_integration_is.smk b/workflows/Mixtures/how_good_integration_is.smk
index 03c8bc99..47b434ee 100644
--- a/workflows/Mixtures/how_good_integration_is.smk
+++ b/workflows/Mixtures/how_good_integration_is.smk
@@ -9,6 +9,7 @@ import json
 import numpy as np
 import pandas as pd
 import matplotlib
+matplotlib.use("Agg")
 import matplotlib.pyplot as plt
 import seaborn as sns
 
@@ -40,29 +41,56 @@ class DistributionAndPMI:
 
 workdir: "generated/mixtures/how_good_integration_is"
 
+
 ESTIMATORS: dict[str, Callable] = {
-    "NWJ": nwj,
-    "NWJ-Shifted": nwj_shifted,
+    "NWJ": nwj_shifted,
     "InfoNCE": infonce,
     "DV": donsker_varadhan,
     "MC": monte_carlo
 }
 
-_normal_dist = fine.MultivariateNormalDistribution(dim_x=2, dim_y=2, covariance=bmi.samplers.canonical_correlation(rho=[0.8, 0.8]))
+
+four_balls = fine.mixture(
+    proportions=jnp.array([0.3, 0.3, 0.2, 0.2]),
+    components=[
+        fine.MultivariateNormalDistribution(
+            covariance=bmi.samplers.canonical_correlation([0.0]),
+            mean=jnp.array([-1.25, -1.25]),
+            dim_x=1, dim_y=1,
+        ),
+        fine.MultivariateNormalDistribution(
+            covariance=bmi.samplers.canonical_correlation([0.0]),
+            mean=jnp.array([+1.25, +1.25]),
+            dim_x=1, dim_y=1,
+        ),
+        fine.MultivariateNormalDistribution(
+            covariance=0.2 * bmi.samplers.canonical_correlation([0.0]),
+            mean=jnp.array([-2.5, +2.5]),
+            dim_x=1, dim_y=1,
+        ),
+        fine.MultivariateNormalDistribution(
+            covariance=0.2 * bmi.samplers.canonical_correlation([0.0]),
+            mean=jnp.array([+2.5, -2.5]),
+            dim_x=1, dim_y=1,
+        ),
+    ]
+)
+
+
 _DISTRIBUTIONS: dict[str, DistributionAndPMI] = {
-    "Normal": DistributionAndPMI(
-        dist=_normal_dist,
+    "Four_Balls": DistributionAndPMI(
+        dist=four_balls,
     ),
-    "NormalBiased": DistributionAndPMI(
-        dist=_normal_dist,
-        pmi=lambda x, y: _normal_dist.pmi(x, y) + 0.5,
+    "Four_Balls_Biased": DistributionAndPMI(
+        dist=four_balls,
+        pmi=lambda x, y: four_balls.pmi(x, y) + 0.5,
     ),
-    "NormalSinSquare": DistributionAndPMI(
-        dist=_normal_dist,
-        pmi=lambda x, y: _normal_dist.pmi(x, y) + jnp.sin(jnp.square(x[..., 0])),
+    "Four_Balls_SinSquare": DistributionAndPMI(
+        dist=four_balls,
+        pmi=lambda x, y: four_balls.pmi(x, y) + jnp.sin(jnp.square(x[..., 0])),
     ),
-    "Student": DistributionAndPMI(
-        dist=fine.MultivariateStudentDistribution(dim_x=2, dim_y=2, dispersion=bmi.samplers.canonical_correlation(rho=[0.8, 0.8]), df=3),
+    "Normal-25Dim": DistributionAndPMI(
+        dist=fine.MultivariateNormalDistribution(dim_x=25, dim_y=25, covariance=bmi.samplers.canonical_correlation(rho=[0.8] * 25))
     ),
 }
 # If PMI is left as None, override it with the PMI of the distribution
@@ -70,8 +98,9 @@ DISTRIBUTION_AND_PMIS = {
     name: DistributionAndPMI(dist=value.dist, pmi=value.dist.pmi) if value.pmi is None else value
     for name, value in _DISTRIBUTIONS.items()
 }
-N_POINTS = [8, 16, 32, 64, 128, 256, 512, 1024, 2048]
-SEEDS = list(range(10))
+N_POINTS = [8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192]
+SEEDS = list(range(20))
+
 
 rule all:
     input:
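
Context note (not part of the patch): the new Four_Balls target is a mixture of four 1D x 1D Gaussian blobs with independent X and Y inside each component, so all of the mutual information comes from which blob a sample lands in. The standalone Python sketch below rebuilds that mixture with plain NumPy/SciPy instead of the repository's fine/bmi helpers (whose exact APIs are only assumed here, and whose names are kept out of the sketch) and illustrates the kind of Monte Carlo baseline the "MC" entry in ESTIMATORS refers to: averaging the ground-truth pointwise mutual information over joint samples.

# Standalone illustration, NOT repository code: all names here are ad hoc.
import numpy as np
from scipy.stats import multivariate_normal, norm

# Mixture mirroring the patch: weights 0.3/0.3/0.2/0.2, rho = 0 within each
# component, and the last two components with covariance scaled by 0.2.
PROPORTIONS = np.array([0.3, 0.3, 0.2, 0.2])
MEANS = np.array([[-1.25, -1.25], [+1.25, +1.25], [-2.5, +2.5], [+2.5, -2.5]])
VAR_SCALES = np.array([1.0, 1.0, 0.2, 0.2])  # multiplier on the identity covariance


def sample_joint(n: int, rng: np.random.Generator) -> np.ndarray:
    """Draw n points (x, y) from the mixture."""
    comp = rng.choice(len(PROPORTIONS), size=n, p=PROPORTIONS)
    noise = rng.standard_normal((n, 2))
    return MEANS[comp] + np.sqrt(VAR_SCALES[comp])[:, None] * noise


def pmi(xy: np.ndarray) -> np.ndarray:
    """Pointwise mutual information log p(x, y) - log p(x) - log p(y)."""
    joint = np.zeros(len(xy))
    marginal_x = np.zeros(len(xy))
    marginal_y = np.zeros(len(xy))
    for w, mean, s in zip(PROPORTIONS, MEANS, VAR_SCALES):
        joint += w * multivariate_normal.pdf(xy, mean=mean, cov=s * np.eye(2))
        marginal_x += w * norm.pdf(xy[:, 0], loc=mean[0], scale=np.sqrt(s))
        marginal_y += w * norm.pdf(xy[:, 1], loc=mean[1], scale=np.sqrt(s))
    return np.log(joint) - np.log(marginal_x) - np.log(marginal_y)


rng = np.random.default_rng(0)
samples = sample_joint(100_000, rng)
# Monte Carlo estimate of I(X; Y): the sample mean of the PMI over joint draws.
print(f"MC estimate of I(X; Y): {pmi(samples).mean():.3f} nats")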