From 6c8702dfc6377afbfe5057e428fce0822fe01070 Mon Sep 17 00:00:00 2001 From: jack89roberts Date: Fri, 6 Aug 2021 20:55:59 +0100 Subject: [PATCH 1/4] Clip negative values to small positive value to avoid NaNs --- bpl/_util.py | 22 +++++++++++++++++----- 1 file changed, 17 insertions(+), 5 deletions(-) diff --git a/bpl/_util.py b/bpl/_util.py index 548ab21..10c8bd3 100644 --- a/bpl/_util.py +++ b/bpl/_util.py @@ -10,6 +10,7 @@ def dixon_coles_correlation_term( home_rate: jnp.array, away_rate: jnp.array, corr_coef: jnp.array, + tol: float = 1e-9, # FIXME workaround to clip negative values to tol to avoid NaNs ) -> jnp.array: # correlation term from dixon and coles paper corr_term = jnp.zeros_like(home_rate) @@ -19,8 +20,13 @@ def dixon_coles_correlation_term( corr_term, (..., nil_nil), jnp.log( - 1.0 - - corr_coef[..., None] * home_rate[..., nil_nil] * away_rate[..., nil_nil] + jnp.clip( + 1.0 + - corr_coef[..., None] + * home_rate[..., nil_nil] + * away_rate[..., nil_nil], + a_min=tol, + ) ), ) @@ -28,19 +34,25 @@ def dixon_coles_correlation_term( corr_term = jax.ops.index_update( corr_term, (..., one_nil), - jnp.log(1.0 + corr_coef[..., None] * away_rate[..., one_nil]), + jnp.log( + jnp.clip(1.0 + corr_coef[..., None] * away_rate[..., one_nil], a_min=tol) + ), ) nil_one = (home_goals == 0) & (away_goals == 1) corr_term = jax.ops.index_update( corr_term, (..., nil_one), - jnp.log(1.0 + corr_coef[..., None] * home_rate[..., nil_one]), + jnp.log( + jnp.clip(1.0 + corr_coef[..., None] * home_rate[..., nil_one], a_min=TOL) + ), ) one_one = (home_goals == 1) & (away_goals == 1) corr_term = jax.ops.index_update( - corr_term, (..., one_one), jnp.log(1.0 - corr_coef[..., None]) + corr_term, + (..., one_one), + jnp.log(jnp.clip(1.0 - corr_coef[..., None], a_min=TOL)), ) return corr_term From 89d4d448490f3170f0dbd3da3ad7bacfc4e0fab2 Mon Sep 17 00:00:00 2001 From: jack89roberts Date: Fri, 6 Aug 2021 20:59:52 +0100 Subject: [PATCH 2/4] fix typos --- bpl/_util.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/bpl/_util.py b/bpl/_util.py index 10c8bd3..d400ab3 100644 --- a/bpl/_util.py +++ b/bpl/_util.py @@ -10,7 +10,7 @@ def dixon_coles_correlation_term( home_rate: jnp.array, away_rate: jnp.array, corr_coef: jnp.array, - tol: float = 1e-9, # FIXME workaround to clip negative values to tol to avoid NaNs + tol: float = 0, # FIXME workaround to clip negative values to tol to avoid NaNs ) -> jnp.array: # correlation term from dixon and coles paper corr_term = jnp.zeros_like(home_rate) @@ -44,7 +44,7 @@ def dixon_coles_correlation_term( corr_term, (..., nil_one), jnp.log( - jnp.clip(1.0 + corr_coef[..., None] * home_rate[..., nil_one], a_min=TOL) + jnp.clip(1.0 + corr_coef[..., None] * home_rate[..., nil_one], a_min=tol) ), ) @@ -52,7 +52,7 @@ def dixon_coles_correlation_term( corr_term = jax.ops.index_update( corr_term, (..., one_one), - jnp.log(jnp.clip(1.0 - corr_coef[..., None], a_min=TOL)), + jnp.log(jnp.clip(1.0 - corr_coef[..., None], a_min=tol)), ) return corr_term From f3a1572ac408e5b38b6c4d1e5037443df8c91769 Mon Sep 17 00:00:00 2001 From: jack89roberts Date: Fri, 6 Aug 2021 21:06:08 +0100 Subject: [PATCH 3/4] black --- bpl/base.py | 1 - 1 file changed, 1 deletion(-) diff --git a/bpl/base.py b/bpl/base.py index f19b03f..515b7b6 100644 --- a/bpl/base.py +++ b/bpl/base.py @@ -155,4 +155,3 @@ def predict_concede_n_proba( # sum probability all scorelines where team conceded n goals return probs.sum(axis=0) - From 433c4c6e0f28240d8ca8639ede7441108eaf60d3 Mon Sep 17 00:00:00 2001 From: jack89roberts Date: Fri, 6 Aug 2021 21:09:08 +0100 Subject: [PATCH 4/4] isort --- bpl/dixon_coles.py | 2 +- bpl/extended_dixon_coles.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/bpl/dixon_coles.py b/bpl/dixon_coles.py index e42718c..64d7c9e 100644 --- a/bpl/dixon_coles.py +++ b/bpl/dixon_coles.py @@ -12,8 +12,8 @@ from numpyro.infer import MCMC, NUTS from numpyro.infer.reparam import LocScaleReparam -from bpl.base import BaseMatchPredictor from bpl._util import dixon_coles_correlation_term +from bpl.base import BaseMatchPredictor __all__ = ["DixonColesMatchPredictor"] diff --git a/bpl/extended_dixon_coles.py b/bpl/extended_dixon_coles.py index c9ce025..e0b5cd7 100644 --- a/bpl/extended_dixon_coles.py +++ b/bpl/extended_dixon_coles.py @@ -12,8 +12,8 @@ from numpyro.infer import MCMC, NUTS from numpyro.infer.reparam import LocScaleReparam -from bpl.base import BaseMatchPredictor from bpl._util import dixon_coles_correlation_term +from bpl.base import BaseMatchPredictor __all__ = ["ExtendedDixonColesMatchPredictor"]