diff --git a/src/crested/tl/_crested.py b/src/crested/tl/_crested.py index ea4b93d..80ab82a 100644 --- a/src/crested/tl/_crested.py +++ b/src/crested/tl/_crested.py @@ -578,9 +578,9 @@ def test(self, return_metrics: bool = False) -> dict | None: # Log the evaluation results for metric_name, metric_value in evaluation_metrics.items(): logger.info(f"Test {metric_name}: {metric_value:.4f}") - return None if return_metrics: return evaluation_metrics + return None def get_embeddings( self, diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index a455e9a..986f01e 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -79,7 +79,9 @@ def test_peak_regression(): "tests/data/test_pipeline/test_peak_regression/checkpoints/01.keras", compile=True, ) - trainer.test() + + test_metrics = trainer.test(return_metrics=True) + assert isinstance(test_metrics, dict) trainer.predict(adata, model_name="01") trainer.predict_regions(region_idx=["chr1:1000-1600", "chr2:2000-2600"])