diff --git a/bpl/_util.py b/bpl/_util.py index 548ab21..d400ab3 100644 --- a/bpl/_util.py +++ b/bpl/_util.py @@ -10,6 +10,7 @@ def dixon_coles_correlation_term( home_rate: jnp.array, away_rate: jnp.array, corr_coef: jnp.array, + tol: float = 0, # FIXME workaround to clip negative values to tol to avoid NaNs ) -> jnp.array: # correlation term from dixon and coles paper corr_term = jnp.zeros_like(home_rate) @@ -19,8 +20,13 @@ def dixon_coles_correlation_term( corr_term, (..., nil_nil), jnp.log( - 1.0 - - corr_coef[..., None] * home_rate[..., nil_nil] * away_rate[..., nil_nil] + jnp.clip( + 1.0 + - corr_coef[..., None] + * home_rate[..., nil_nil] + * away_rate[..., nil_nil], + a_min=tol, + ) ), ) @@ -28,19 +34,25 @@ def dixon_coles_correlation_term( corr_term = jax.ops.index_update( corr_term, (..., one_nil), - jnp.log(1.0 + corr_coef[..., None] * away_rate[..., one_nil]), + jnp.log( + jnp.clip(1.0 + corr_coef[..., None] * away_rate[..., one_nil], a_min=tol) + ), ) nil_one = (home_goals == 0) & (away_goals == 1) corr_term = jax.ops.index_update( corr_term, (..., nil_one), - jnp.log(1.0 + corr_coef[..., None] * home_rate[..., nil_one]), + jnp.log( + jnp.clip(1.0 + corr_coef[..., None] * home_rate[..., nil_one], a_min=tol) + ), ) one_one = (home_goals == 1) & (away_goals == 1) corr_term = jax.ops.index_update( - corr_term, (..., one_one), jnp.log(1.0 - corr_coef[..., None]) + corr_term, + (..., one_one), + jnp.log(jnp.clip(1.0 - corr_coef[..., None], a_min=tol)), ) return corr_term