Skip to content

Commit

Permalink
Adjust color scheme in the plots (#172)
Browse files Browse the repository at this point in the history
  • Loading branch information
pawel-czyz authored Dec 4, 2024
1 parent 92b2c4e commit 9030cbe
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 31 deletions.
55 changes: 38 additions & 17 deletions workflows/projects/Mixtures/cool_tasks.smk
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,25 @@ ESTIMATOR_NAMES = {
"Hist-10": "Histogram",
"CCA": "CCA",
}
ESTIMATOR_COLORS = {
"MINE": '#377eb8',
"InfoNCE": '#ff7f00',
"KSG-10": '#4daf4a',
"Hist-10": '#f781bf',
"CCA": '#a65628',
}

ESTIMATOR_MARKERS = {
"MINE": 'o',
"InfoNCE": 'v',
"KSG-10": '^',
"Hist-10": 'D',
"CCA": 'X',
}

assert set(ESTIMATOR_NAMES.keys()) == set(ESTIMATORS.keys())
assert set(ESTIMATOR_COLORS.keys()) == set(ESTIMATORS.keys())
assert set(ESTIMATOR_MARKERS.keys()) == set(ESTIMATORS.keys())

_SAMPLE_ESTIMATE: int = 200_000

Expand Down Expand Up @@ -73,7 +91,7 @@ rule all:
'cool_tasks.pdf',
'results.csv',
'cool_tasks-results.pdf',
'profiles.pdf'
# 'profiles.pdf'

rule plot_distributions:
output: "cool_tasks.pdf"
Expand Down Expand Up @@ -122,21 +140,21 @@ rule plot_distributions:

fig.savefig(str(output), dpi=300)

rule plot_pmi_profiles:
output: "profiles.pdf"
run:
fig, axs = subplots_from_axsize(1, 4, axsize=(4, 3))
dists = [x_dist, ai_dist, fence_base_dist, balls_mixt]
tasks_official = ['X', 'AI', 'Waves', 'Galaxy']
for dist, task_name, ax in zip(dists, tasks_official, axs):
import jax
key = jax.random.PRNGKey(1024)
pmi_values = bmm.pmi_profile(key=key, dist=dist, n=100_000)
bins = np.linspace(-5, 5, 101)
ax.hist(pmi_values, bins=bins, density=True, alpha=0.5)
ax.set_xlabel(task_name)
axs[0].set_ylabel("Density")
fig.savefig(str(output))
# rule plot_pmi_profiles:
# output: "profiles.pdf"
# run:
# fig, axs = subplots_from_axsize(1, 4, axsize=(4, 3))
# dists = [x_dist, ai_dist, fence_base_dist, balls_mixt]
# tasks_official = ['X', 'AI', 'Waves', 'Galaxy']
# for dist, task_name, ax in zip(dists, tasks_official, axs):
# import jax
# key = jax.random.PRNGKey(1024)
# pmi_values = bmm.pmi_profile(key=key, dist=dist, n=100_000)
# bins = np.linspace(-5, 5, 101)
# ax.hist(pmi_values, bins=bins, density=True, alpha=0.5)
# ax.set_xlabel(task_name)
# axs[0].set_ylabel("Density")
# fig.savefig(str(output))


rule plot_results:
Expand All @@ -155,7 +173,10 @@ rule plot_results:
data_est['task_id'].apply(lambda e: tasks.index(e)) + 0.05 * np.random.normal(size=len(data_est)),
data_est['mi_estimate'],
label=ESTIMATOR_NAMES[estimator_id],
alpha=0.4, s=3**2,
alpha=0.4, s=5**2,
marker=ESTIMATOR_MARKERS[estimator_id],
c=ESTIMATOR_COLORS[estimator_id],
edgecolor="none",
)

for task_id, data_task in data_5k.groupby('task_id'):
Expand Down
27 changes: 19 additions & 8 deletions workflows/projects/Mixtures/how_good_integration_is.smk
Original file line number Diff line number Diff line change
Expand Up @@ -51,11 +51,19 @@ ESTIMATORS: dict[str, Callable] = {
}

ESTIMATOR_COLORS = {
"InfoNCE": "magenta",
"DV": "red",
"NWJ": "limegreen",
"MC": "mediumblue",
"InfoNCE": '#ff7f00',
"DV": '#984ea3',
"NWJ": "#999999",
"MC": "#dede00",
}
ESTIMATOR_MARKERS = {
"InfoNCE": 'v',
"DV": 'D',
"NWJ": "X",
"MC": ".",
}



four_balls = bmm.mixture(
proportions=jnp.array([0.3, 0.3, 0.2, 0.2]),
Expand Down Expand Up @@ -210,9 +218,6 @@ def plot_estimates(ax: plt.Axes, estimates_path, ground_truth_path, alpha: float
with open(ground_truth_path) as fh:
ground_truth = json.load(fh)

# Add ground-truth information
x_axis =[df["n_points"].min(), df["n_points"].max()]
ax.plot(x_axis, [ground_truth["mi_mean"]] * 2, c="k", linestyle=":")
# ax.fill_between(
# x_axis,
# [ground_truth["mi_mean"] - ground_truth["mi_std"]] * 2,
Expand All @@ -232,10 +237,16 @@ def plot_estimates(ax: plt.Axes, estimates_path, ground_truth_path, alpha: float

color = ESTIMATOR_COLORS[estimator]

ax.plot(points, mean, color=color, label=estimator)
ax.plot(points, mean, color=color)
ax.scatter(points, mean, color=color, marker=ESTIMATOR_MARKERS[estimator], label=estimator)
ax.fill_between(points, mean - std, mean + std, alpha=alpha, color=color)


# Add ground-truth information
x_axis =[df["n_points"].min(), df["n_points"].max()]
ax.plot(x_axis, [ground_truth["mi_mean"]] * 2, c="k", linestyle=":")


rule plot_performance_all:
input:
simple_ground_truth="Four_Balls/ground_truth.json",
Expand Down
20 changes: 14 additions & 6 deletions workflows/projects/Mixtures/outliers.smk
Original file line number Diff line number Diff line change
Expand Up @@ -112,14 +112,21 @@ for variance in VARIANCES:

UNSCALED_TASKS = {**MIXING_TASKS, **VARIANCE_TASKS}


ESTIMATOR_COLORS = {
"InfoNCE": "magenta",
"MINE": "red",
"KSG": "green",
"CCA": "purple",
"InfoNCE": '#ff7f00',
"MINE": '#377eb8',
"KSG": '#4daf4a',
"CCA": '#a65628',
}

ESTIMATOR_MARKERS = {
"InfoNCE": 'v',
"MINE": '.',
"KSG": '^',
"CCA": 'X',
}


ESTIMATORS = {
"KSG": bmi.estimators.KSGEnsembleFirstEstimator(neighborhoods=(10,)),
"CCA": bmi.estimators.CCAMutualInformationEstimator(),
Expand Down Expand Up @@ -162,7 +169,8 @@ def plot_data(ax: plt.Axes, data: pd.DataFrame, key: str = "mixing", use_legend:
subset = grouped[grouped['estimator_id'] == estimator]

color = ESTIMATOR_COLORS[estimator]
ax.plot(subset[key], subset['mean'], color=color, label=estimator)
ax.plot(subset[key], subset['mean'], color=color)
ax.scatter(subset[key], subset['mean'], color=color, marker=ESTIMATOR_MARKERS[estimator], label=estimator)
ax.fill_between(subset[key], subset['mean'] - subset['std'], subset['mean'] + subset['std'], alpha=0.3, color=color)

if use_legend:
Expand Down

0 comments on commit 9030cbe

Please sign in to comment.