Skip to content

Commit

Permalink
add weighting by the cell frequency
Browse files Browse the repository at this point in the history
  • Loading branch information
JPapir committed Nov 11, 2024
1 parent 0cea1d1 commit c8c15d0
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 5 deletions.
5 changes: 2 additions & 3 deletions src/qumin/calc_paradigm_entropy.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
from .entropy.distribution import PatternDistribution, SplitPatternDistribution
from .representations import segments, patterns, create_paradigms, create_features
from .representations.frequencies import Frequencies
from itertools import permutations

log = logging.getLogger()

Expand Down Expand Up @@ -113,7 +112,7 @@ def H_command(cfg, md):
if onePred:
if not md.bipartite: # Already computed in bipartite systems :)
distrib.one_pred_entropy()
mean = distrib.get_mean()
mean = distrib.get_mean(weighting=cfg.entropy.weighting)
log.info("Mean H(c1 -> c2) = %s ", mean)
if verbose:
distrib.one_pred_distrib_log()
Expand All @@ -123,7 +122,7 @@ def H_command(cfg, md):

for n in preds:
distrib.n_preds_entropy_matrix(n)
mean = distrib.get_mean(n=n)
mean = distrib.get_mean(n=n, weighting=cfg.entropy.weighting)
log.info(f"Mean H(c1, ..., c{n} -> c) = {mean}")

if verbose:
Expand Down
27 changes: 25 additions & 2 deletions src/qumin/entropy/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,12 @@
from collections import Counter, defaultdict
from functools import reduce
from itertools import combinations, chain
from itertools import permutations

import pandas as pd
from tqdm import tqdm

from ..representations.frequencies import Frequencies
from . import cond_entropy

log = logging.getLogger(__name__)
Expand Down Expand Up @@ -98,11 +100,32 @@ def __init__(self, paradigms, patterns, classes, name, features=None):
"dataset"
])

def get_mean(self, measure="cond_entropy", n=1, weighting=False):
    """Return the mean of a measure over all (predictor, predicted) pairs.

    Arguments:
        measure (str): name of the measure to average (default "cond_entropy").
        n (int): number of predictors the rows were computed with.
        weighting (bool): if True, weight each (predictor, predicted) pair
            by the probability of drawing that ordered pair from the cell
            frequencies, instead of averaging uniformly over pairs. Falls
            back to the uniform mean when no cell frequencies are available.

    Returns:
        float: the (possibly weighted) mean value of the measure.
    """
    is_cond_ent = self.data.loc[:, "measure"] == measure
    is_one_pred = self.data.loc[:, "n_preds"] == n
    results = self.data.loc[is_cond_ent & is_one_pred, :].set_index(['predictor', 'predicted'])

    if weighting and Frequencies.source['cells'] != "empty":
        cell_freq = Frequencies.get_relative_freq(data="cells")
        # Start from zero weights: pairs with no frequency information must
        # not contribute to the weighted sum. (Seeding from results.value,
        # as before, left stale entropy values as weights for any pair not
        # visited below, silently adding value**2 terms.)
        weight = pd.Series(0.0, index=results.index)

        for pred, out in permutations(cell_freq.index.to_list(), 2):
            nopred = cell_freq[cell_freq.index != pred]

            # Probability of predicting from the predictor cell
            p_pred = cell_freq.loc[pred, 'result']
            # Probability of the target cell among the cells other than the predictor
            p_out = (nopred.loc[nopred.index == out, 'result']/nopred.result.sum()).iloc[0]
            # Setting the result
            weight.loc[(pred, out)] = p_out * p_pred

        # Weighted mean: the weights sum to 1 over all ordered cell pairs,
        # so a plain sum of weight * value is the weighted average.
        mean = (weight * results.value).sum()
    else:
        if weighting:
            log.warning("Couldn't find cell frequencies. Falling back on weighting by the number of pairs.")
        mean = results.loc[:, "value"].mean()

    return mean

def export_file(self, filename):
""" Export the data DataFrame to file
Expand Down

0 comments on commit c8c15d0

Please sign in to comment.