internalize the mean computation
JPapir committed Nov 11, 2024
1 parent f887291 commit 0cea1d1
Showing 4 changed files with 16 additions and 10 deletions.
src/qumin/calc_paradigm_entropy.py (16 changes: 10 additions & 6 deletions)
@@ -11,6 +11,8 @@

from .entropy.distribution import PatternDistribution, SplitPatternDistribution
from .representations import segments, patterns, create_paradigms, create_features
+from .representations.frequencies import Frequencies
+from itertools import permutations

log = logging.getLogger()

@@ -24,6 +26,8 @@ def H_command(cfg, md):
cfg.patterns) == 2, "You must pass either a single dataset and patterns file, or a list of two of each (coindexed)."
md.bipartite = True

+Frequencies.initialize(md.datasets[0], real=True)

patterns_file_path = cfg.patterns if md.bipartite else [cfg.patterns]
sounds_file_name = md.get_table_path("sounds")

@@ -86,10 +90,10 @@ def H_command(cfg, md):
features=features)

distrib.mutual_information()
-mean1 = distrib.distribs[0].get_results().loc[:, "value"].mean()
-mean2 = distrib.distribs[1].get_results().loc[:, "value"].mean()
-mean3 = distrib.get_results(measure="mutual_information").loc[:, "value"].mean()
-mean4 = distrib.get_results(measure="normalized_mutual_information").loc[:, "value"].mean()
+mean1 = distrib.distribs[0].get_mean()
+mean2 = distrib.distribs[1].get_mean()
+mean3 = distrib.get_mean(measure="mutual_information")
+mean4 = distrib.get_mean(measure="normalized_mutual_information")
log.debug("Mean remaining H(c1 -> c2) for %s = %s", names[0], mean1)
log.debug("Mean remaining H(c1 -> c2) for %s = %s", names[1], mean2)
log.debug("Mean I(%s,%s) = %s", *names, mean3)
@@ -109,7 +113,7 @@ def H_command(cfg, md):
if onePred:
if not md.bipartite: # Already computed in bipartite systems :)
distrib.one_pred_entropy()
-mean = distrib.get_results().loc[:, "value"].mean()
+mean = distrib.get_mean()
log.info("Mean H(c1 -> c2) = %s ", mean)
if verbose:
distrib.one_pred_distrib_log()
@@ -119,7 +123,7 @@

for n in preds:
distrib.n_preds_entropy_matrix(n)
-mean = distrib.get_results(n=n).loc[:, "value"].mean()
+mean = distrib.get_mean(n=n)
log.info(f"Mean H(c1, ..., c{n} -> c) = {mean}")

if verbose:
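The call sites above show the refactor named in the commit message: each one previously pulled the full results frame via get_results() and averaged its "value" column by hand; get_mean() now hides that reduction. A minimal standalone sketch of the same pattern, using a toy stand-in rather than Qumin's real PatternDistribution class (data and numbers are invented for illustration):

import pandas as pd

class Distribution:
    # Toy stand-in for PatternDistribution: results live in a long-form
    # DataFrame with "measure", "n_preds" and "value" columns.
    def __init__(self, data):
        self.data = data

    def get_mean(self, measure="cond_entropy", n=1):
        # Internalized mean: filter to one measure and one predictor
        # count, then average the "value" column.
        keep = (self.data["measure"] == measure) & (self.data["n_preds"] == n)
        return self.data.loc[keep, "value"].mean()

distrib = Distribution(pd.DataFrame({
    "measure": ["cond_entropy", "cond_entropy", "mutual_information"],
    "n_preds": [1, 1, 1],
    "value": [0.5, 1.5, 0.25],
}))
print(distrib.get_mean())                              # 1.0
print(distrib.get_mean(measure="mutual_information"))  # 0.25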
src/qumin/config/qumin.yaml (3 changes: 2 additions & 1 deletion)
@@ -69,7 +69,8 @@ entropy:
# with any file, use to compute entropy heatmap
# with n-1 predictors, allows for acceleration on nPreds entropy computation.
merged: False # Whether identical columns are merged in the input.
-stacked: False # whether to stack results in long form
+stacked: False # whether to stack results in long form.
+weighting: True # whether to use cell frequencies for weighting.

eval:
iter: 10 # How many 90/10 train/test folds to do.
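The new weighting flag pairs with the Frequencies.initialize(...) call added in the first file. The diff itself does not show how the weights are applied; the usual contrast is between an unweighted mean over cell pairs and a frequency-weighted one, as in this hypothetical illustration (column names and numbers are invented, not Qumin's actual schema):

import pandas as pd

results = pd.DataFrame({
    "value": [0.2, 1.0],    # entropy per cell pair (invented numbers)
    "frequency": [90, 10],  # hypothetical cell frequencies
})
unweighted = results["value"].mean()  # 0.6: both pairs count equally
weighted = ((results["value"] * results["frequency"]).sum()
            / results["frequency"].sum())  # 0.28: frequent cells dominate
print(unweighted, weighted)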
src/qumin/entropy/distribution.py (5 changes: 3 additions & 2 deletions)
@@ -98,10 +98,11 @@ def __init__(self, paradigms, patterns, classes, name, features=None):
"dataset"
])

-def get_results(self, measure="cond_entropy", n=1):
+def get_mean(self, measure="cond_entropy", n=1):
is_cond_ent = self.data.loc[:, "measure"] == measure
is_one_pred = self.data.loc[:, "n_preds"] == n
-return self.data.loc[is_cond_ent & is_one_pred, :]
+
+return self.data.loc[is_cond_ent & is_one_pred, "value"].mean()

def export_file(self, filename):
""" Export the data DataFrame to file
src/qumin/representations/__init__.py (2 changes: 1 addition & 1 deletion)
@@ -87,7 +87,7 @@ def get_unknown_segments(forms, unknowns, name):
paradigms = paradigms[~paradigms.loc[:, lexemes].isin(defective_lexemes)]

if most_freq:
-inflected = paradigms.loc[:,lexemes].unique()
+inflected = paradigms.loc[:, lexemes].unique()
lexemes_file_name = Path(dataset.basepath) / dataset.get_resource("lexemes").path
lexemes_df = pd.read_csv(lexemes_file_name, usecols=["lexeme_id", "frequency"])
# Restrict to lexemes we have kept, if we dropped defectives
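For context, the surrounding most_freq branch restricts the paradigms to the N most frequent lexemes, using the frequency column read from the lexemes table. A rough standalone sketch of that selection (the frames and the most_freq value are invented; only the lexeme_id/frequency layout comes from the diff context):

import pandas as pd

paradigms = pd.DataFrame({"lexeme": ["walk", "run", "be", "walk"]})
lexemes_df = pd.DataFrame({"lexeme_id": ["walk", "run", "be"],
                           "frequency": [120, 80, 900]})

most_freq = 2  # keep the two most frequent lexemes
inflected = paradigms["lexeme"].unique()
top = (lexemes_df[lexemes_df["lexeme_id"].isin(inflected)]
       .nlargest(most_freq, "frequency")["lexeme_id"])
paradigms = paradigms[paradigms["lexeme"].isin(top)]
print(paradigms)  # rows for "be" and "walk" remain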
