diff --git a/bpl/_util.py b/bpl/_util.py index 548ab21..d400ab3 100644 --- a/bpl/_util.py +++ b/bpl/_util.py @@ -10,6 +10,7 @@ def dixon_coles_correlation_term( home_rate: jnp.array, away_rate: jnp.array, corr_coef: jnp.array, + tol: float = 0, # FIXME workaround to clip negative values to tol to avoid NaNs ) -> jnp.array: # correlation term from dixon and coles paper corr_term = jnp.zeros_like(home_rate) @@ -19,8 +20,13 @@ def dixon_coles_correlation_term( corr_term, (..., nil_nil), jnp.log( - 1.0 - - corr_coef[..., None] * home_rate[..., nil_nil] * away_rate[..., nil_nil] + jnp.clip( + 1.0 + - corr_coef[..., None] + * home_rate[..., nil_nil] + * away_rate[..., nil_nil], + a_min=tol, + ) ), ) @@ -28,19 +34,25 @@ def dixon_coles_correlation_term( corr_term = jax.ops.index_update( corr_term, (..., one_nil), - jnp.log(1.0 + corr_coef[..., None] * away_rate[..., one_nil]), + jnp.log( + jnp.clip(1.0 + corr_coef[..., None] * away_rate[..., one_nil], a_min=tol) + ), ) nil_one = (home_goals == 0) & (away_goals == 1) corr_term = jax.ops.index_update( corr_term, (..., nil_one), - jnp.log(1.0 + corr_coef[..., None] * home_rate[..., nil_one]), + jnp.log( + jnp.clip(1.0 + corr_coef[..., None] * home_rate[..., nil_one], a_min=tol) + ), ) one_one = (home_goals == 1) & (away_goals == 1) corr_term = jax.ops.index_update( - corr_term, (..., one_one), jnp.log(1.0 - corr_coef[..., None]) + corr_term, + (..., one_one), + jnp.log(jnp.clip(1.0 - corr_coef[..., None], a_min=tol)), ) return corr_term diff --git a/bpl/base.py b/bpl/base.py index f19b03f..515b7b6 100644 --- a/bpl/base.py +++ b/bpl/base.py @@ -155,4 +155,3 @@ def predict_concede_n_proba( # sum probability all scorelines where team conceded n goals return probs.sum(axis=0) - diff --git a/bpl/dixon_coles.py b/bpl/dixon_coles.py index e42718c..64d7c9e 100644 --- a/bpl/dixon_coles.py +++ b/bpl/dixon_coles.py @@ -12,8 +12,8 @@ from numpyro.infer import MCMC, NUTS from numpyro.infer.reparam import LocScaleReparam -from bpl.base import BaseMatchPredictor from bpl._util import dixon_coles_correlation_term +from bpl.base import BaseMatchPredictor __all__ = ["DixonColesMatchPredictor"] diff --git a/bpl/extended_dixon_coles.py b/bpl/extended_dixon_coles.py index c9ce025..e0b5cd7 100644 --- a/bpl/extended_dixon_coles.py +++ b/bpl/extended_dixon_coles.py @@ -12,8 +12,8 @@ from numpyro.infer import MCMC, NUTS from numpyro.infer.reparam import LocScaleReparam -from bpl.base import BaseMatchPredictor from bpl._util import dixon_coles_correlation_term +from bpl.base import BaseMatchPredictor __all__ = ["ExtendedDixonColesMatchPredictor"]