Skip to content

Commit

Permalink
Merge pull request #7 from anguswilliams91/fix-nan-corr-term
Browse files Browse the repository at this point in the history
Clip negative values to to avoid NaN correlation term
  • Loading branch information
jack89roberts authored Feb 4, 2022
2 parents 760bb2e + 433c4c6 commit 9caa266
Showing 1 changed file with 17 additions and 5 deletions.
22 changes: 17 additions & 5 deletions bpl/_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ def dixon_coles_correlation_term(
home_rate: jnp.array,
away_rate: jnp.array,
corr_coef: jnp.array,
tol: float = 0, # FIXME workaround to clip negative values to tol to avoid NaNs
) -> jnp.array:
# correlation term from dixon and coles paper
corr_term = jnp.zeros_like(home_rate)
Expand All @@ -19,28 +20,39 @@ def dixon_coles_correlation_term(
corr_term,
(..., nil_nil),
jnp.log(
1.0
- corr_coef[..., None] * home_rate[..., nil_nil] * away_rate[..., nil_nil]
jnp.clip(
1.0
- corr_coef[..., None]
* home_rate[..., nil_nil]
* away_rate[..., nil_nil],
a_min=tol,
)
),
)

one_nil = (home_goals == 1) & (away_goals == 0)
corr_term = jax.ops.index_update(
corr_term,
(..., one_nil),
jnp.log(1.0 + corr_coef[..., None] * away_rate[..., one_nil]),
jnp.log(
jnp.clip(1.0 + corr_coef[..., None] * away_rate[..., one_nil], a_min=tol)
),
)

nil_one = (home_goals == 0) & (away_goals == 1)
corr_term = jax.ops.index_update(
corr_term,
(..., nil_one),
jnp.log(1.0 + corr_coef[..., None] * home_rate[..., nil_one]),
jnp.log(
jnp.clip(1.0 + corr_coef[..., None] * home_rate[..., nil_one], a_min=tol)
),
)

one_one = (home_goals == 1) & (away_goals == 1)
corr_term = jax.ops.index_update(
corr_term, (..., one_one), jnp.log(1.0 - corr_coef[..., None])
corr_term,
(..., one_one),
jnp.log(jnp.clip(1.0 - corr_coef[..., None], a_min=tol)),
)

return corr_term

0 comments on commit 9caa266

Please sign in to comment.