Skip to content

Commit

Permalink
model from current bpl now runs
Browse files Browse the repository at this point in the history
  • Loading branch information
anguswilliams91 committed May 16, 2021
1 parent 34dd6d6 commit c5e9a8f
Showing 1 changed file with 36 additions and 13 deletions.
49 changes: 36 additions & 13 deletions bpl/extended_dixon_coles.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,10 @@ def __init__(self):
self.attack = None
self.defence = None
self.home_advantage = None
self.corr_coef = Nonen
self.corr_coef = None
self.rho = None
self.attack_coefficients = None
self.defence_coefficients = None

# pylint: disable=too-many-locals
@staticmethod
Expand Down Expand Up @@ -55,9 +58,9 @@ def _model(

if team_covariates is not None:
standardised_covariates = (
team_covariates - team_covariates.mean(axis=-1)
) / team_covariates.std(axis=-1)
num_covariates = standardised_covariates.shape[0]
team_covariates - team_covariates.mean(axis=0)
) / team_covariates.std(axis=0)
num_covariates = standardised_covariates.shape[1]

with numpyro.plate("covariates", num_covariates):
attack_coefficients = numpyro.sample(
Expand All @@ -67,21 +70,22 @@ def _model(
"defence_coefficients", dist.Normal(loc=0.0, scale=1.0)
)

attack_prior_mean = attack_coefficients @ standardised_covariates
defence_prior_mean = (
mean_defence + defence_coefficients @ standardised_covariates
)

attack_prior_mean = jnp.matmul(
standardised_covariates, attack_coefficients[..., None]
).squeeze(-1)
defence_prior_mean = mean_defence + jnp.matmul(
standardised_covariates, defence_coefficients[..., None]
).squeeze(-1)
else:
attack_prior_mean = 0.0
defence_prior_mean = mean_defence

with numpyro.plate("teams", num_teams):
standardised_attack = numpyro.sample(
"attack", dist.Normal(loc=0.0, scale=1.0)
"standardised_attack", dist.Normal(loc=0.0, scale=1.0)
)
standardised_defence = numpyro.sample(
"defence",
"standardised_defence",
dist.Normal(
loc=rho * standardised_attack, scale=jnp.sqrt(1.0 - rho ** 2.0)
),
Expand All @@ -93,14 +97,22 @@ def _model(
dist.Normal(mean_home_advantage, std_home_advantage),
)

attack = attack_prior_mean + standardised_attack * std_attack
defence = defence_prior_mean + standardised_defence * std_defence
attack = numpyro.deterministic(
"attack", attack_prior_mean + standardised_attack * std_attack
)
defence = numpyro.deterministic(
"defence", defence_prior_mean + standardised_defence * std_defence
)

expected_home_goals = jnp.exp(
attack[home_team] - defence[away_team] + home_advantage[home_team]
)
expected_away_goals = jnp.exp(attack[away_team] - defence[home_team])

# FIXME: this is because the priors allow crazy simulated data before inference
expected_home_goals = jnp.clip(expected_home_goals, a_max=15.0)
expected_away_goals = jnp.clip(expected_away_goals, a_max=15.0)

numpyro.sample(
"home_goals", dist.Poisson(expected_home_goals).to_event(1), obs=home_goals
)
Expand Down Expand Up @@ -132,6 +144,14 @@ def fit(
home_ind = jnp.array([self.teams.index(t) for t in home_team])
away_ind = jnp.array([self.teams.index(t) for t in away_team])

if team_covariates:
if set(team_covariates.keys()) == set(self.teams):
team_covariates = jnp.array([team_covariates[t] for t in self.teams])
else:
raise ValueError(
"team_covariates must contain all the teams in the data."
)

nuts_kernel = NUTS(self._model)
mcmc = MCMC(nuts_kernel, num_warmup, num_samples, **(mcmc_kwargs or {}))
rng_key = jax.random.PRNGKey(random_state)
Expand All @@ -151,6 +171,9 @@ def fit(
self.defence = samples["defence"]
self.home_advantage = samples["home_advantage"]
self.corr_coef = samples["corr_coef"]
self.rho = samples["rho"]
self.attack_coefficients = samples.get("attack_coefficients", None)
self.defence_coefficients = samples.get("defence_coefficients", None)

return self

Expand Down

0 comments on commit c5e9a8f

Please sign in to comment.