From 570362820681f93f3fab3ff5bad23dfb4f5b1c30 Mon Sep 17 00:00:00 2001 From: Steffen Schneider Date: Sun, 27 Oct 2024 19:00:40 +0100 Subject: [PATCH] Fix examples --- cebra/integrations/sklearn/metrics.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/cebra/integrations/sklearn/metrics.py b/cebra/integrations/sklearn/metrics.py index 41dc67f..4b3c08e 100644 --- a/cebra/integrations/sklearn/metrics.py +++ b/cebra/integrations/sklearn/metrics.py @@ -141,7 +141,7 @@ def goodness_of_fit_score(cebra_model: cebra_sklearn_cebra.CEBRA, >>> cebra_model = cebra.CEBRA(max_iterations=10) >>> cebra_model.fit(neural_data) CEBRA(max_iterations=10) - >>> gof = cebra.goodness_of_fit_score(cebra_model, neural_data) + >>> gof = cebra.sklearn.metrics.goodness_of_fit_score(cebra_model, neural_data) """ loss = infonce_loss(cebra_model, X, @@ -172,7 +172,7 @@ def goodness_of_fit_history(model): >>> cebra_model = cebra.CEBRA(max_iterations=10) >>> cebra_model.fit(neural_data) CEBRA(max_iterations=10) - >>> gof_history = cebra.goodness_of_fit_history(cebra_model) + >>> gof_history = cebra.sklearn.metrics.goodness_of_fit_history(cebra_model) """ infonce = np.array(model.state_dict_["log"]["total"]) return infonce_to_goodness_of_fit(infonce, model) @@ -215,7 +215,7 @@ def infonce_to_goodness_of_fit(infonce: Union[float, Iterable[float]], num_sessions = model.num_sessions_ if num_sessions is None: num_sessions = 1 - chance_level = np.log(model.batch_size * (model.num_sessions_ or 1)) + chance_level = np.log(model.batch_size * num_sessions) return (chance_level - infonce) * nats_to_bits