Skip to content

Commit

Permalink
Fix examples
Browse files Browse the repository at this point in the history
  • Loading branch information
stes committed Oct 27, 2024
1 parent db9df82 commit 5703628
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions cebra/integrations/sklearn/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ def goodness_of_fit_score(cebra_model: cebra_sklearn_cebra.CEBRA,
>>> cebra_model = cebra.CEBRA(max_iterations=10)
>>> cebra_model.fit(neural_data)
CEBRA(max_iterations=10)
>>> gof = cebra.goodness_of_fit_score(cebra_model, neural_data)
>>> gof = cebra.sklearn.metrics.goodness_of_fit_score(cebra_model, neural_data)
"""
loss = infonce_loss(cebra_model,
X,
Expand Down Expand Up @@ -172,7 +172,7 @@ def goodness_of_fit_history(model):
>>> cebra_model = cebra.CEBRA(max_iterations=10)
>>> cebra_model.fit(neural_data)
CEBRA(max_iterations=10)
>>> gof_history = cebra.goodness_of_fit_history(cebra_model)
>>> gof_history = cebra.sklearn.metrics.goodness_of_fit_history(cebra_model)
"""
infonce = np.array(model.state_dict_["log"]["total"])
return infonce_to_goodness_of_fit(infonce, model)
Expand Down Expand Up @@ -215,7 +215,7 @@ def infonce_to_goodness_of_fit(infonce: Union[float, Iterable[float]],
num_sessions = model.num_sessions_
if num_sessions is None:
num_sessions = 1
chance_level = np.log(model.batch_size * (model.num_sessions_ or 1))
chance_level = np.log(model.batch_size * num_sessions)
return (chance_level - infonce) * nats_to_bits


Expand Down

0 comments on commit 5703628

Please sign in to comment.