diff --git a/bpl/_util.py b/bpl/_util.py index 4d2695c..c0189f6 100644 --- a/bpl/_util.py +++ b/bpl/_util.py @@ -1,4 +1,5 @@ """Private utility functions.""" + from typing import Iterable, Optional, Tuple, Union import jax diff --git a/bpl/base.py b/bpl/base.py index f481ab3..ff3d357 100644 --- a/bpl/base.py +++ b/bpl/base.py @@ -1,4 +1,5 @@ """Implementation of the probabilistic model for soccer matches.""" + from __future__ import annotations from abc import abstractmethod diff --git a/bpl/dixon_coles.py b/bpl/dixon_coles.py index c0d2709..2e90f1c 100644 --- a/bpl/dixon_coles.py +++ b/bpl/dixon_coles.py @@ -1,4 +1,5 @@ """Implementation of a simple team level model.""" + from __future__ import annotations from typing import Any, Dict, Iterable, Optional, Tuple, Union diff --git a/bpl/dynamic_dixon_coles.py b/bpl/dynamic_dixon_coles.py index 848c08c..7d5e5ab 100644 --- a/bpl/dynamic_dixon_coles.py +++ b/bpl/dynamic_dixon_coles.py @@ -1,4 +1,5 @@ """Implementation of the neutral model with dynamic parameters in the current version of bpl.""" + from __future__ import annotations import warnings diff --git a/bpl/extended_dixon_coles.py b/bpl/extended_dixon_coles.py index a614455..0540c39 100644 --- a/bpl/extended_dixon_coles.py +++ b/bpl/extended_dixon_coles.py @@ -1,4 +1,5 @@ """Implementation of the model in the current version of bpl.""" + from __future__ import annotations import warnings diff --git a/bpl/neutral_dixon_coles.py b/bpl/neutral_dixon_coles.py index 2dba0d7..98d7340 100644 --- a/bpl/neutral_dixon_coles.py +++ b/bpl/neutral_dixon_coles.py @@ -1,4 +1,5 @@ """Implementation of the neutral model for predicting the World Cup.""" + from __future__ import annotations import warnings diff --git a/bpl/neutral_dixon_coles_WC.py b/bpl/neutral_dixon_coles_WC.py index 272213e..151cea9 100644 --- a/bpl/neutral_dixon_coles_WC.py +++ b/bpl/neutral_dixon_coles_WC.py @@ -1,4 +1,5 @@ """Implementation of the neutral model for predicting the World Cup.""" + from __future__ import annotations import warnings diff --git a/tests/test_all_models.py b/tests/test_all_models.py index 21ba19e..2b0917b 100644 --- a/tests/test_all_models.py +++ b/tests/test_all_models.py @@ -1,4 +1,5 @@ """Shared tests across all models, e.g. checking probabilities are valid.""" + import jax.numpy as jnp import pytest diff --git a/tests/test_neutral_dixon_coles.py b/tests/test_neutral_dixon_coles.py index 00a9340..bd4ff08 100644 --- a/tests/test_neutral_dixon_coles.py +++ b/tests/test_neutral_dixon_coles.py @@ -6,6 +6,7 @@ TOL = 1e-02 + @pytest.fixture def model(neutral_dummy_data): return NeutralDixonColesMatchPredictor().fit(neutral_dummy_data) diff --git a/tests/test_neutral_dixon_coles_WC.py b/tests/test_neutral_dixon_coles_WC.py index fecd0ef..8c56911 100644 --- a/tests/test_neutral_dixon_coles_WC.py +++ b/tests/test_neutral_dixon_coles_WC.py @@ -6,6 +6,7 @@ TOL = 5e-02 + @pytest.fixture def model(neutral_dummy_data): return NeutralDixonColesMatchPredictorWC().fit(neutral_dummy_data) @@ -97,7 +98,9 @@ def test_predict_concede_n_proba(model): assert len(proba_team_concede) == 1 assert (proba_team_concede[0] >= 0) and (proba_team_concede[0] <= 1) - proba_opponent_score = model.predict_score_n_proba(1, "1", "0", "0", "1", home=False) + proba_opponent_score = model.predict_score_n_proba( + 1, "1", "0", "0", "1", home=False + ) assert proba_team_concede.tolist() == pytest.approx( proba_opponent_score.tolist(), abs=TOL )