Skip to content

Commit

Permalink
Fix array shapes for the language -> information-plane measurement
Browse files Browse the repository at this point in the history
  • Loading branch information
shanest committed Sep 20, 2024
1 parent 3ad8c8c commit d3ac8ed
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 20 deletions.
7 changes: 6 additions & 1 deletion src/examples/colors/ib.py
Original file line number Diff line number Diff line change
def language_to_encoder(language: "Language") -> "np.ndarray":
    """Build an encoder matrix for a language.

    Args:
        language: the Language whose expressions' meanings define the encoder.

    Returns:
        A (|referents|, |expressions|) array where entry (r, w) is
        ``expression w``'s meaning value for referent r.

    NOTE(review): the docstring contract for this function is p(word | referent),
    but the value returned below is still the raw meaning matrix, which the
    inline comment identifies as p(r | w)-shaped — see the TODO. Confirm with
    the caller before relying on normalization.
    """
    universe = language.universe
    # (|referents|, |words|): p(r | w)
    prob_chip_given_expression = np.array(
        [
            [expression.meaning[referent] for expression in language.expressions]
            for referent in universe.referents
        ]
    )
    # (|referents|, 1): p(r)
    # prior = np.array(universe.prior)[:, None]
    # TODO: get encoder from prob_chip_given_expression (i.e. prob_expression_given_chip)
    encoder = prob_chip_given_expression
    # (removed leftover debug print of the full encoder matrix)
    return encoder
5 changes: 4 additions & 1 deletion src/examples/colors/scripts/measure_natural_languages.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

if __name__ == "__main__":

"""
with open("colors/outputs/natural_languages.yaml", "r") as f:
languages = load(f, Loader=Loader)
Expand All @@ -21,14 +22,15 @@
encoder = language_to_encoder(language)
print(encoder)
print(encoder.sum(axis=0))
print(encoder.sum(axis=1))

print(
ib_encoder_to_point(
np.array(color_universe.prior), meaning_distributions, encoder
)
)
"""

"""
prior = np.array(color_universe.prior)
information_plane = pd.DataFrame.from_records(
[
Expand All @@ -43,3 +45,4 @@
information_plane.to_csv(
"colors/outputs/natural_language_information_plane.csv", index=False
)
"""
10 changes: 1 addition & 9 deletions src/examples/colors/scripts/read_color_universe.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,11 @@
import pandas as pd
import pickle

from ultk.language.semantics import Universe


if __name__ == "__main__":
    # Load the WCS chip data (tab-separated) and put it in chip-number order.
    referents = pd.read_csv("colors/data/cnum-vhcm-lab-new.txt", delimiter="\t")
    referents = referents.sort_values(by="#cnum")
    # ULTK requires a "name" column; reuse the chip number for it.
    referents["name"] = referents["#cnum"]
    # CIELAB columns renamed so they are valid python attribute names.
    referents = referents.rename(columns={"L*": "L", "a*": "a", "b*": "b"})
    referents.to_csv("colors/outputs/color_universe.csv", index=False)
    """
    color_universe = Universe.from_dataframe(referents)
    with open("colors/outputs/color_universe.pkl", "wb") as f:
        pickle.dump(color_universe, f)
    """
14 changes: 5 additions & 9 deletions src/ultk/effcomm/rate_distortion.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,39 +122,35 @@ def ib_encoder_to_point(
# IB complexity = info rate of encoder = I(meanings; words)
complexity = information_cond(prior, encoder)

# (|meanings|, 1)
prior = prior[:, None]

# IB accuracy/informativity = I(words; world states)
pMW = encoder * prior[:, None]
pMW = encoder * prior
pWU = pMW.T @ meaning_dists
accuracy = mutual_info(pWU)

# expected distortion
I_mu = information_cond(prior, meaning_dists)
distortion = I_mu - accuracy

# TODO: debug the below!

# TODO: the above is only IB optimal; should we look at the emergent listener accuracy? To do that we'll need to compute kl divergence
# and then do I(M;U) - distortion to get the accuracy.

# pu_w = decoder @ meaning_dists
# dist_mat = ib_kl(meaning_dists, pu_w,) # getting infs; I confirmed that this because there exists an x s.t. p(x) > 0 but q(x) = 0. Ask Noga what to do here. Add a little epsilon?
# distortion = np.sum( prior * ( encoder @ decoder ) * dist_mat )

"""
decoder_smoothed = decoder + 1e-20
decoder_smoothed /= decoder_smoothed.sum(axis=1, keepdims=True)
print(decoder_smoothed.shape)
pu_w = decoder_smoothed @ meaning_dists
print(pu_w.shape)
dist_mat = ib_kl(
meaning_dists,
pu_w,
)
print(dist_mat.shape) # NOTE: this is the wrong shape
distortion = np.sum(prior * (encoder @ decoder) * dist_mat)
distortion = np.sum(prior * encoder * dist_mat)
# but this measure of distortion is almost an order magnitude higher than bayesian decoder
# breakpoint()
"""

return (complexity, accuracy, distortion)

Expand Down

0 comments on commit d3ac8ed

Please sign in to comment.