# posterior_sampling.py
'''
Helper code for MS&E 338 Reinforcement Learning implementation assignment.

The functions in this script are designed to help you update posteriors
for the reward and transition functions. We will use a simple (and standard)
set of conjugate families.

You may find it helpful to use these functions to update and sample from
posterior distributions for PSRL. Note, however, that none of this code is
particularly optimized/vectorized. There may be better pre-packaged solutions
if you want to do this in practice, but this is a nice way to look "beneath
the hood" in a simple example.

Rewards
    Distribution approximation: Normal with unknown mean and unknown variance
    Conjugate prior: Normal-Gamma
    Wikipedia: http://en.wikipedia.org/wiki/Normal-gamma_distribution

Transitions
    Distribution approximation: Multinomial distribution
    Conjugate prior: Dirichlet
    Wikipedia: http://en.wikipedia.org/wiki/Dirichlet_distribution

author: [email protected]
'''
import numpy as np
#---------------------------------------------------------------------------
# Rewards functions
def convert_prior(mu, n_mu, tau, n_tau):
    '''
    Convert the natural way of specifying priors to our parameterization.

    Args:
        mu - 1x1 - prior mean
        n_mu - 1x1 - number of observations of mean
        tau - 1x1 - prior precision (1 / variance)
        n_tau - 1x1 - number of observations of tau

    Returns:
        prior - 4x1 - (mu, lambda, alpha, beta)
    '''
    prior = (mu, n_mu, n_tau * 0.5, (0.5 * n_tau) / tau)
    return prior
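As a quick sanity check, the conversion can be exercised with a hypothetical prior (the numbers below are illustrative, not from the assignment): a mean of 0 backed by one pseudo-observation and a precision of 1 backed by two pseudo-observations.

```python
def convert_prior(mu, n_mu, tau, n_tau):
    # Same conversion as above: returns (mu, lambda, alpha, beta)
    return (mu, n_mu, n_tau * 0.5, (0.5 * n_tau) / tau)

# Hypothetical prior: mean 0 (1 pseudo-obs), precision 1 (2 pseudo-obs)
prior = convert_prior(mu=0.0, n_mu=1.0, tau=1.0, n_tau=2.0)
print(prior)  # (0.0, 1.0, 1.0, 1.0)
```

Note that `alpha = n_tau / 2` and `beta = n_tau / (2 * tau)`, so the prior mean of the sampled precision, `alpha / beta`, recovers `tau`.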
def update_normal_ig(prior, data):
    '''
    Update the parameters of a normal-gamma posterior.

        T | a, b ~ Gamma(a, b)
        X | T ~ Normal(mu, 1 / (lambda * T))

    Args:
        prior - 4 x 1 - tuple containing (in this order):
            mu0 - prior mean
            lambda0 - pseudo observations for prior mean
            alpha0 - inverse gamma shape
            beta0 - inverse gamma scale
        data - n x 1 - numpy array of {y_i} observations

    Returns:
        posterior - 4 x 1 - tuple containing updated posterior params.
            NB this is in the same format as the prior input.
    '''
    # Unpack the prior
    (mu0, lambda0, alpha0, beta0) = prior
    n = len(data)
    y_bar = np.mean(data)

    # Update the normal component
    lambda1 = lambda0 + n
    mu1 = (lambda0 * mu0 + n * y_bar) / lambda1

    # Update the inverse-gamma component
    alpha1 = alpha0 + (n * 0.5)
    ssq = n * np.var(data)
    prior_disc = lambda0 * n * ((y_bar - mu0) ** 2) / lambda1
    beta1 = beta0 + 0.5 * (ssq + prior_disc)

    posterior = (mu1, lambda1, alpha1, beta1)
    return posterior
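The update can be verified by hand on a small made-up data set. With prior `(0, 1, 1, 1)` and observations `[1, 2, 3]`: the posterior mean is a precision-weighted average `(1*0 + 3*2)/4 = 1.5`, `alpha` grows by `n/2`, and `beta` absorbs both the within-sample sum of squares and the prior-discrepancy term. A condensed restatement of the update (not new logic):

```python
import numpy as np

def update_normal_ig(prior, data):
    # Same normal-gamma update as above, condensed
    (mu0, lambda0, alpha0, beta0) = prior
    n = len(data)
    y_bar = np.mean(data)
    lambda1 = lambda0 + n
    mu1 = (lambda0 * mu0 + n * y_bar) / lambda1
    alpha1 = alpha0 + n * 0.5
    ssq = n * np.var(data)
    prior_disc = lambda0 * n * ((y_bar - mu0) ** 2) / lambda1
    beta1 = beta0 + 0.5 * (ssq + prior_disc)
    return (mu1, lambda1, alpha1, beta1)

# Hypothetical prior (mu, lambda, alpha, beta) and three observations
prior = (0.0, 1.0, 1.0, 1.0)
data = np.array([1.0, 2.0, 3.0])
posterior = update_normal_ig(prior, data)
# posterior: mu=1.5, lambda=4.0, alpha=2.5, beta=3.5
```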
def sample_normal_ig(prior):
    '''
    Sample a single normal distribution from a normal inverse gamma prior.

    Args:
        prior - 4 x 1 - tuple containing (in this order):
            mu - prior mean
            lambda0 - pseudo observations for prior mean
            alpha - inverse gamma shape
            beta - inverse gamma scale

    Returns:
        params - 2 x 1 - tuple, sampled mean and precision.
    '''
    # Unpack the prior
    (mu, lambda0, alpha, beta) = prior
    # Sample the precision tau from a gamma distribution
    tau = np.random.gamma(shape=alpha, scale=1. / beta)
    var = 1. / (lambda0 * tau)
    # Sample the mean from a normal with mean mu and variance var
    mean = np.random.normal(loc=mu, scale=np.sqrt(var))
    return (mean, tau)
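A minimal sketch of how this sampler might be exercised, using a hypothetical posterior (the parameter values are illustrative). The draw is precision-first, then the mean conditional on that precision, so `tau` is always positive:

```python
import numpy as np

def sample_normal_ig(prior):
    # Same sampling scheme as above: precision first, then mean | precision
    (mu, lambda0, alpha, beta) = prior
    tau = np.random.gamma(shape=alpha, scale=1. / beta)
    var = 1. / (lambda0 * tau)
    mean = np.random.normal(loc=mu, scale=np.sqrt(var))
    return (mean, tau)

np.random.seed(0)  # reproducible draws for the sketch
# Hypothetical posterior (mu, lambda, alpha, beta)
mean, tau = sample_normal_ig((1.5, 4.0, 2.5, 3.5))
# tau is a positive precision; mean is a draw centered at mu = 1.5
```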
#---------------------------------------------------------------------------
# Transition functions
def update_dirichlet(prior, data):
    '''
    Update the parameters of a Dirichlet distribution.

    We assume the data is drawn from a multinomial over n discrete states.

    Args:
        prior - n x 1 - numpy array, pseudocounts of discrete observations.
        data - n x 1 - numpy array, counts of observations of each draw.

    Returns:
        posterior - n x 1 - numpy array, overall pseudocounts.
    '''
    # Updating a Dirichlet is trivial: just add the observed counts
    posterior = prior + data
    return posterior
def sample_dirichlet(prior):
    '''
    Sample a multinomial distribution from a Dirichlet prior.

    Args:
        prior - n x 1 - numpy array, pseudocounts of discrete observations.

    Returns:
        dist - n x 1 - numpy array, probability distribution over n states.
    '''
    n = len(prior)
    dist = np.zeros(n)
    for i in range(n):
        # Sample an independent gamma variate for each entry
        dist[i] = np.random.gamma(prior[i])
    # Normalize so the entries form a probability distribution
    dist = dist / dist.sum()
    return dist
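Putting the two transition functions together, a hypothetical 3-state example (the counts below are made up) shows the full update-then-sample cycle used in PSRL. The normalized-gamma construction is equivalent to a Dirichlet draw, which is why the result is a valid probability vector:

```python
import numpy as np

def update_dirichlet(prior, data):
    # Conjugate update: add observed counts to pseudocounts
    return prior + data

def sample_dirichlet(prior):
    # Normalized independent gamma draws = one Dirichlet sample
    n = len(prior)
    dist = np.zeros(n)
    for i in range(n):
        dist[i] = np.random.gamma(prior[i])
    return dist / dist.sum()

np.random.seed(0)
# Hypothetical 3-state example: uniform pseudocounts, then observe
# 5 transitions to state 0, 1 to state 1, none to state 2.
prior = np.ones(3)
counts = np.array([5.0, 1.0, 0.0])
posterior = update_dirichlet(prior, counts)  # pseudocounts [6, 2, 1]
p = sample_dirichlet(posterior)
# p is a probability vector over the 3 states, concentrated near state 0
```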